In [1]:
import os
import shutil
import time

from offline_dataset.dataset_creater import GymParallelSampler

from envs.env_creator import ibgym_env_creator, env_creator, IBGymModelQ_creator
from state_quantization.transforms import quantize_transform_creator
from q_learning.algorithm import QLPolicy
from ppo.policy import LSTMPPOPolicy

In [2]:
# --- Sampling budget ---------------------------------------------------------
episodes = 1000
steps_per_episode = 1000
workers = 8
# Earlier env_kwargs variant, kept for reference (env_kwargs is built per mode
# when the sampler is constructed):
#env_kwargs = {'steps_per_episode': steps_per_episode, 'device':'cpu'}

# --- Output location for the sampler ----------------------------------------
writer_path = os.path.join("tmp", "ibqf-out")

# --- State-quantization model -----------------------------------------------
quant_model = 'model_aeq-16bits'
model_path = f'tmp/state_quantization/{quant_model}'
# Keyword arguments intended for the quantize transform; 'reshape' is tied to
# the episode length configured above.
q_transform_kwargs = dict(
    device='cpu',
    keys=['obs', 'new_obs'],
    reshape=(steps_per_episode, -1, 6),
    model_path=model_path,
)

# --- Behaviour policy --------------------------------------------------------
# NOTE(review): the policy is loaded unconditionally, i.e. even while
# use_policy is False. The file is a .pkl, so it presumably gets unpickled --
# only load files from a trusted source.
policy_save_path = 'tmp/q_learning/mb_q_policy_best_model_aeq-16bits_203871.pkl'
policy = QLPolicy.load(policy_save_path)
# Selects which sampler configuration is built later in the notebook.
use_policy = False

In [3]:
# Start from a clean output directory: drop whatever a previous run left behind.
# os.path.isdir() is already False for paths that do not exist, so one check suffices.
if os.path.isdir(writer_path):
    shutil.rmtree(writer_path)

In [4]:


# Wall-clock reference taken before the sampler is constructed.
start = time.time()

# Keyword arguments shared by both sampling modes.
sampler_kwargs = dict(
    path=writer_path,
    episodes=episodes,
    workers=workers,
    reward_threshold=None,
)

if use_policy:
    # Policy-driven mode: env_creator additionally receives the quantization
    # model path through env_kwargs, and the loaded policy is handed over.
    env_kwargs = {'steps_per_episode': steps_per_episode, 'model_path': model_path}
    sampler_kwargs.update(
        env_creator=env_creator,
        env_kwargs=env_kwargs,
        policy=policy,
    )
else:
    # No policy: plain ibgym environment, with the quantize transform (and its
    # kwargs) passed to the sampler as the buffer transform.
    env_kwargs = {'steps_per_episode': steps_per_episode}
    sampler_kwargs.update(
        env_creator=ibgym_env_creator,
        env_kwargs=env_kwargs,
        buffer_transform=quantize_transform_creator,
        buffer_transform_kwargs=q_transform_kwargs,
        policy=None,
    )

parallel_sampler = GymParallelSampler(**sampler_kwargs)



In [5]:
# Restart the timer right before sampling so the reported duration covers only
# the sampling run, not sampler construction or the idle time between executing
# the notebook cells.
start = time.time()
try:
    parallel_sampler.sample()
finally:
    # Report the elapsed seconds even if sampling is cut short (for example by a
    # kernel interrupt); the exception itself still propagates afterwards.
    end = time.time()
    print(end - start)

Episodes Sampled: 37
Episodes Sampled: 38
Episodes Sampled: 40
Episodes Sampled: 44
Episodes Sampled: 46
Episodes Sampled: 47
Episodes Sampled: 52
Episodes Sampled: 54
Episodes Sampled: 57
Episodes Sampled: 60
Episodes Sampled: 63
Episodes Sampled: 66
Episodes Sampled: 68
Episodes Sampled: 71
Episodes Sampled: 75
Episodes Sampled: 76
Episodes Sampled: 79
Episodes Sampled: 83
Episodes Sampled: 84
Episodes Sampled: 86
Episodes Sampled: 91
Episodes Sampled: 92
Episodes Sampled: 94
Episodes Sampled: 99
Episodes Sampled: 100
Episodes Sampled: 102
Episodes Sampled: 107
Episodes Sampled: 108
Episodes Sampled: 110
Episodes Sampled: 115
Episodes Sampled: 116
Episodes Sampled: 118
Episodes Sampled: 123
Episodes Sampled: 123
Episodes Sampled: 124
Episodes Sampled: 129
Episodes Sampled: 131
Episodes Sampled: 132
Episodes Sampled: 137
Episodes Sampled: 139
Episodes Sampled: 140
Episodes Sampled: 145
Episodes Sampled: 147
Episodes Sampled: 147
Episodes Sampled: 153
Episodes Sampled: 155
Episodes Sam

  logger.warn(
Process GymEnvSamplerProcess-5:
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/hamza/PycharmProjects/StateCompression/offline_dataset/dataset_creater.py", line 61, in run
    new_obs, rew, done, info = self.env.step(action)
  File "/home/hamza/PycharmProjects/StateCompression/envs/IBGym_mod_envs.py", line 340, in step
    discrete_obs = self.lstm_quantize(self.last_observation)[0]
  File "/home/hamza/PycharmProjects/StateCompression/state_quantization/transforms.py", line 70, in __call__
    x = self.normalize_transformer.transform(x)
KeyboardInterrupt
  logger.warn(
Process GymEnvSamplerProcess-8:
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/home/hamza/PycharmProjects/StateCompression/offline_dataset/dataset_creater.py", line 61, in run
    new_obs, rew, done, info = self.env.step(a

KeyboardInterrupt: 

In [6]:
# The file name embeds the *requested* episode count and the quantization model
# name; it does not reflect how many episodes were actually collected.
trajectory_file = f"trajectory_ep{episodes}_{quant_model}.npy"
save_path = os.path.join("tmp", "offline_rl_trajectories", trajectory_file)
# Kept as the final expression so the notebook displays whatever the call returns.
parallel_sampler.create_merged_dataset(save_path=save_path)

tmp/offline_rl_trajectories/trajectory_ep1000_model_aeq-16bits.npy


{'obs': array([37597., 49877., 58069., ..., 62109., 62109., 62141.]),
 'actions': array([ 1,  9, 15, ...,  3,  3, 15]),
 'rewards': array([-208.12060547, -226.02709961, -248.12010193, ..., -229.80134583,
        -228.11502075, -229.41885376]),
 'dones': array([False, False, False, ..., False, False,  True]),
 'new_obs': array([49877., 58069., 49877., ..., 62109., 62141., 53949.]),
 'unroll_id': array([ 0,  0,  0, ..., 66, 66, 66])}