In [None]:
import os
import shutil
import time

from offline_dataset.dataset_creater import GymParallelSampler

from envs.env_creator import ibgym_env_creator, env_creator, IBGymModelQ_creator
from state_quantization.transforms import quantize_transform_creator
from q_learning.algorithm import QLPolicy
from ppo.policy import LSTMPPOPolicy

In [None]:
# ---- Sampling budget -------------------------------------------------------
# `episodes` and `workers` are handed to GymParallelSampler further down;
# `steps_per_episode` goes into the env kwargs and into the transform's
# `reshape` tuple below.
episodes = 1_000
steps_per_episode = 1_000
workers = 8

# ---- Locations on disk -----------------------------------------------------
# `writer_path` is the sampler's output directory (removed before each run).
writer_path = os.path.join("tmp", "ibqf-out")
# Only read when `use_policy` is True.
policy_save_path = 'tmp/q_learning/mb_q_policy_best_model_aeq-16bits_203871.pkl'

# ---- Quantisation model ----------------------------------------------------
# `quant_model` is also reused as a path component of the merged dataset.
quant_model = 'model_h_c-20bits3'
model_path = f'tmp/state_quantization/{quant_model}'

# Keyword arguments forwarded as `buffer_transform_kwargs` to the sampler.
q_transform_kwargs = dict(
    device='cpu',
    keys=['obs', 'new_obs'],
    reshape=(steps_per_episode, -1, 6),
    model_path=model_path,
)

# Switch between the policy-driven sampler (True) and the policy-free sampler
# with the quantisation buffer transform (False).
use_policy = False

In [None]:
# Start from an empty output directory: drop whatever a previous run left in
# `writer_path`. `os.path.isdir` is already False for a missing path, so a
# separate existence check is unnecessary.
if os.path.isdir(writer_path):
    shutil.rmtree(writer_path)

In [None]:


# Wall-clock reference point; the elapsed time is printed after sampling.
start = time.time()

# Arguments that are identical for both sampler configurations.
sampler_common = dict(
    path=writer_path,
    episodes=episodes,
    workers=workers,
    reward_threshold=None,
)

if not use_policy:
    # Policy-free configuration: environments come from `ibgym_env_creator`
    # and `quantize_transform_creator` is passed as the buffer transform.
    env_kwargs = {'steps_per_episode': steps_per_episode}
    parallel_sampler = GymParallelSampler(
        env_creator=ibgym_env_creator,
        env_kwargs=env_kwargs,
        buffer_transform=quantize_transform_creator,
        buffer_transform_kwargs=q_transform_kwargs,
        policy=None,
        **sampler_common
    )
else:
    # Policy-driven configuration: load the saved Q-learning policy and give
    # the environment creator the quantisation model path as well. No buffer
    # transform is passed in this case.
    policy = QLPolicy.load(policy_save_path)
    env_kwargs = {'steps_per_episode': steps_per_episode, 'model_path': model_path}
    parallel_sampler = GymParallelSampler(
        env_creator=env_creator,
        env_kwargs=env_kwargs,
        policy=policy,
        **sampler_common
    )

In [None]:
# Run the sampler, then report the elapsed wall-clock time in seconds.
# NOTE(review): `start` is set in the cell that builds the sampler, so the
# printed duration also covers sampler construction and any pause between
# executing the two cells — confirm that this is the intended measurement.
parallel_sampler.sample()
end = time.time()
elapsed_seconds = end - start
print(elapsed_seconds)

In [None]:
# Destination of the merged dataset:
#   tmp/offline_rl_trajectories/<quant_model>/<episodes>/rl_dataset.npy
save_path = os.path.join("tmp", "offline_rl_trajectories", quant_model, str(episodes), "rl_dataset.npy")

# The directory depends on `quant_model` and `episodes`, so it will not exist
# the first time a given configuration is run. Create it up front;
# `exist_ok=True` makes re-runs a no-op.
# NOTE(review): it is not visible here whether `create_merged_dataset` creates
# missing parent directories itself; this guard is harmless either way.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

parallel_sampler.create_merged_dataset(save_path=save_path)