In [1]:
from utils import SurrogatPyRepEnvironment
from basic_walk.utils import BaseAgent
import os
import sys
import time

from multiprocessing import Pool
import pickle
import matplotlib.pyplot as plt
import numpy as np

from tqc import structures, DEVICE
from tqc.trainer import Trainer
from tqc.structures import Actor, Critic, RescaleAction
from tqc.functions import eval_policy
from tqdm import tqdm
import copy

import warnings
warnings.filterwarnings('ignore')

## Заполняем буфер реплеев

In [2]:
def generate_correct_replays(goal_len, episode_length, seed=42):
    """Collect transitions from scripted-agent episodes that ran to full length.

    A randomised ``BaseAgent`` drives the foot-only simulator scene. Every step
    is recorded as ``(state, action, next_state, reward, done)``. An episode is
    kept only if it reaches ``episode_length`` steps with ``done`` still False
    on its final step; episodes that end with ``done`` set are dropped whole.

    Args:
        goal_len: Minimum number of transitions to gather. Collection stops at
            the first episode boundary where at least this many are stored.
        episode_length: Number of steps after which an episode is cut off.
        seed: Value passed to ``np.random.seed`` before anything else runs.

    Returns:
        A list of transition tuples. Its length is a multiple of
        ``episode_length`` because only full-length episodes are kept.
    """
    np.random.seed(seed)

    def new_agent():
        # A fresh agent is built for every episode.
        return BaseAgent(random_mode=True, foot_only_mode=True)

    agent = new_agent()
    collected = []
    with SurrogatPyRepEnvironment('scenes/basic_scene.ttt', headless_mode=True, foot_only_mode=True) as env:
        env = RescaleAction(env, -1., 1.)
        state = env.reset()
        steps_taken = 0
        episode = []

        while len(collected) < goal_len:
            steps_taken += 1
            action = np.array(agent.act(state))
            # Presumably converts agent output in radians to the wrapper's
            # [-1, 1] action range -- TODO confirm against BaseAgent.
            action /= np.pi

            next_state, reward, done, _ = env.step(action)
            episode.append((state, action, next_state, reward, done))
            state = next_state

            # Keep stepping until the episode terminates or hits the cap.
            if not done and steps_taken < episode_length:
                continue

            # Only episodes that reached the cap without terminating are kept.
            if not done:
                collected.extend(episode)

            # Start the next episode from scratch.
            agent = new_agent()
            state = env.reset()
            steps_taken = 0
            episode = []
    return collected

In [3]:
def generate_correct_replays_in_parallel(goal_len=1200, episode_length=600, n_proc=2, seed=42):
    """Collect replay transitions with several worker processes.

    The requested number of transitions is split evenly between ``n_proc``
    workers. Each worker runs ``generate_correct_replays`` with its own seed,
    and the per-worker lists are concatenated in submission order.

    Args:
        goal_len: Total number of transitions to collect. Must be a multiple
            of ``episode_length * n_proc``.
        episode_length: Step cap of a single episode, forwarded to the workers.
        n_proc: Number of worker processes.
        seed: Master seed. It seeds the global NumPy state and the generator
            that draws the distinct per-worker seeds, so the same ``seed``
            always yields the same worker seeds.

    Returns:
        A list with the transitions of all workers, worker 0 first.

    Raises:
        AssertionError: If ``goal_len`` is not a multiple of
            ``episode_length * n_proc``.
    """
    np.random.seed(seed)
    # The generator must be seeded explicitly: a bare default_rng() pulls fresh
    # OS entropy and would make the worker seeds differ on every run.
    rng = np.random.default_rng(seed)
    # Plain ints keep the arguments sent to the workers free of numpy scalars.
    worker_seeds = [int(s) for s in rng.choice(100000, size=n_proc, replace=False)]

    assert goal_len % (episode_length * n_proc) == 0, (
        "goal_len must be a multiple of episode_length * n_proc"
    )
    # Integer division: the per-worker quota is a transition count.
    sub_process_goal_len = goal_len // n_proc
    replay_buffer = []

    with Pool(processes=n_proc) as pool:
        multiple_results = [
            pool.apply_async(
                generate_correct_replays,
                (sub_process_goal_len, episode_length, worker_seed),
            )
            for worker_seed in worker_seeds
        ]
        # Report progress as each worker, in submission order, completes.
        for i, res in enumerate(multiple_results):
            res.wait()
            print("finish", i)

        # get() re-raises any exception that occurred inside a worker.
        for res in multiple_results:
            replay_buffer += res.get()
    return replay_buffer

In [4]:
%%time
# Collection settings: total number of transitions to gather, the step cap of
# one episode, and the master seed.
goal_len = 96000
episode_length = 500
seed = 42

file_name = f"replay_buffers/replay_buffer_foot_only_{goal_len}_{episode_length}_.pickle"
# Create the output directory before the long collection run, so a missing
# folder fails immediately instead of discarding the collected transitions.
os.makedirs(os.path.dirname(file_name), exist_ok=True)

replay_buffer_arr = generate_correct_replays_in_parallel(
    goal_len=goal_len,
    episode_length=episode_length,
    seed=seed,
    n_proc=8
)

# The simulator is opened only to read the observation/action dimensions and
# is closed again before the buffer is filled and written to disk.
with SurrogatPyRepEnvironment('scenes/basic_scene.ttt', headless_mode=True, foot_only_mode=True) as env:
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

replay_buffer = structures.ReplayBuffer(state_dim, action_dim)

# Each entry is a (state, action, next_state, reward, done) tuple.
for q in replay_buffer_arr:
    replay_buffer.add(*q)

with open(file_name, 'wb') as f:
    pickle.dump(replay_buffer, f)

finish 0
finish 1
finish 2
finish 3
finish 4
finish 5
finish 6
finish 7
CPU times: user 2.54 s, sys: 871 ms, total: 3.41 s
Wall time: 13min 39s
