In [1]:
from utils import SurrogatPyRepEnvironment
from basic_walk.utils import BaseAgent
import sys
import time

from multiprocessing import Pool
import pickle
import matplotlib.pyplot as plt
import numpy as np

from tqc import structures, DEVICE
from tqc.trainer import Trainer
from tqc.structures import Actor, Critic, RescaleAction
from tqc.functions import eval_policy
from tqdm import tqdm
import copy

import warnings
warnings.filterwarnings('ignore')

## Заполняем буфер реплеев

In [3]:
def generate_correct_replays(name, goal_len, episode_length, seed=42):
    """Collect transitions from episodes that reach the step limit without terminating.

    A fresh ``BaseAgent`` (random mode, foot-only mode) drives the simulated
    scene. An episode is kept only if it lasts ``episode_length`` steps without
    the environment reporting ``done``; episodes that end with ``done`` are
    discarded entirely. Collection stops once at least ``goal_len`` transitions
    have been gathered.

    Args:
        name: Label used as a prefix in the progress messages (the parallel
            driver passes the worker index).
        goal_len: Minimum number of transitions to collect.
        episode_length: Maximum number of steps per episode.
        seed: Seed for NumPy's global random state in this process.

    Returns:
        A list of ``(state, action, next_state, reward, done)`` tuples. Because
        only episodes that never reported ``done`` are kept, every stored
        ``done`` flag is falsy.
    """
    print(f"{name} started with goal_len: {goal_len}")
    np.random.seed(seed)
    agent = BaseAgent(random_mode=True, foot_only_mode=True)
    replay_buffer = []
    with SurrogatPyRepEnvironment('scenes/basic_scene.ttt', headless_mode=True, foot_only_mode=True) as env:
        env = RescaleAction(env, -1., 1.)
        state, done = env.reset(), False
        episode_timesteps = 0
        last_replay = []
        # Milestones are tracked as whole percents. Accumulating 0.05 in a float
        # drifts (0.15000000000000002, 0.39999999999999997, ...), which made the
        # old check skip a milestone hit exactly and would later print "39%".
        milestone_step = 5
        next_milestone = milestone_step

        while len(replay_buffer) < goal_len:
            episode_timesteps += 1
            # Out-of-place division also works when the agent returns integers
            # (in-place `/=` on an integer array raises).
            # NOTE(review): presumably the agent emits joint angles in radians
            # that are scaled into the [-1, 1] action range - confirm.
            action = np.array(agent.act(state)) / np.pi

            next_state, reward, done, _ = env.step(action)
            last_replay.append((state, action, next_state, reward, done))
            state = next_state

            if done or episode_timesteps >= episode_length:
                if not done:
                    # The episode survived to the step limit: keep all of it.
                    replay_buffer += last_replay
                    print(f"{name}: {len(replay_buffer)} ({float(len(replay_buffer)) / goal_len})")
                    # Integer form of `len / goal_len >= next_milestone / 100`.
                    if len(replay_buffer) * 100 >= next_milestone * goal_len:
                        print(f"{name}: {next_milestone}%")
                        next_milestone += milestone_step

                # Start the next episode with a new agent and a reset scene.
                agent = BaseAgent(random_mode=True, foot_only_mode=True)
                state, done = env.reset(), False
                episode_timesteps = 0
                last_replay = []

    return replay_buffer

In [4]:
def generate_correct_replays_in_parallel(goal_len=1200, episode_length=600, n_proc=2, seed=42):
    """Collect ``goal_len`` transitions by running several collectors in parallel.

    The target is split evenly between ``n_proc`` worker processes, each of
    which runs ``generate_correct_replays`` with its own seed derived from
    ``seed``. The per-worker results are concatenated in worker order.

    Args:
        goal_len: Total number of transitions to collect. Must be a multiple
            of ``episode_length * n_proc``.
        episode_length: Maximum number of steps per episode.
        n_proc: Number of worker processes.
        seed: Master seed from which the distinct per-worker seeds are drawn.

    Returns:
        A list with the transitions of all workers, worker 0 first.

    Raises:
        AssertionError: If ``goal_len`` cannot be split evenly between workers.
    """
    np.random.seed(seed)
    # The generator must be seeded explicitly: `default_rng()` without an
    # argument pulls fresh OS entropy and ignores `np.random.seed`, so the
    # worker seeds (and therefore the data) used to differ on every run.
    rng = np.random.default_rng(seed)
    seeds = rng.choice(100000, size=n_proc, replace=False)

    assert goal_len % (episode_length * n_proc) == 0, (
        "goal_len must be a multiple of episode_length * n_proc"
    )
    # Floor division keeps the per-worker target an integer (exact thanks to
    # the assertion above) instead of a float such as 20000.0.
    sub_process_goal_len = goal_len // n_proc
    replay_buffer = []

    with Pool(processes=n_proc) as pool:
        multiple_results = [
            pool.apply_async(
                generate_correct_replays,
                # int(...) hands the worker a plain Python int, not a NumPy scalar.
                (i, sub_process_goal_len, episode_length, int(worker_seed)),
            )
            for i, worker_seed in enumerate(seeds)
        ]
        for i, res in enumerate(multiple_results):
            res.wait()
            print("finish", i)

        # Results must be fetched while the pool is still alive.
        for res in multiple_results:
            replay_buffer += res.get()
    return replay_buffer

In [5]:
%%time
import os  # only needed here, to prepare the output directory

# Run configuration.
goal_len = 160000
episode_length = 500
seed = 42
buffer_name = "replay_buffer_leg_only"

# Create the output directory before the long generation run, so a missing
# folder cannot throw away the collected data at save time.
os.makedirs("data/replay_buffers", exist_ok=True)

replay_buffer_arr = generate_correct_replays_in_parallel(
    goal_len=goal_len,
    episode_length=episode_length,
    seed=seed,
    n_proc=8
)

# Save the raw list of (state, action, next_state, reward, done) tuples.
file_name = f"data/replay_buffers/{buffer_name}_array_{goal_len}_{episode_length}.pickle"
with open(file_name, 'wb') as f:
    pickle.dump(replay_buffer_arr, f)

# The environment is opened only to read the observation/action dimensions
# needed to size the ReplayBuffer.
with SurrogatPyRepEnvironment('scenes/basic_scene.ttt', headless_mode=True, foot_only_mode=True) as env:
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    replay_buffer = structures.ReplayBuffer(state_dim, action_dim)

    for transition in replay_buffer_arr:
        replay_buffer.add(*transition)

    # Save the populated ReplayBuffer object next to the raw list.
    file_name = f"data/replay_buffers/{buffer_name}_{goal_len}_{episode_length}.pickle"
    with open(file_name, 'wb') as f:
        pickle.dump(replay_buffer, f)

1 started with goal_len: 20000.02 started with goal_len: 20000.0

0 started with goal_len: 20000.03 started with goal_len: 20000.04 started with goal_len: 20000.07 started with goal_len: 20000.0
5 started with goal_len: 20000.0


6 started with goal_len: 20000.0

7: 500 (0.025)
6: 500 (0.025)
1: 500 (0.025)
3: 500 (0.025)
4: 500 (0.025)
0: 500 (0.025)
1: 1000 (0.05)
1: 5%
3: 1000 (0.05)
3: 5%
6: 1000 (0.05)
6: 5%
7: 1000 (0.05)
7: 5%
4: 1000 (0.05)
4: 5%
5: 500 (0.025)
2: 500 (0.025)
0: 1000 (0.05)
0: 5%
4: 1500 (0.075)
5: 1000 (0.05)
5: 5%
7: 1500 (0.075)
1: 1500 (0.075)
2: 1000 (0.05)
2: 5%
7: 2000 (0.1)
7: 10%
0: 1500 (0.075)
3: 1500 (0.075)
2: 1500 (0.075)
6: 1500 (0.075)
4: 2000 (0.1)
4: 10%
5: 1500 (0.075)
3: 2000 (0.1)
3: 10%
0: 2000 (0.1)
0: 10%
6: 2000 (0.1)
6: 10%
4: 2500 (0.125)
5: 2000 (0.1)
5: 10%
7: 2500 (0.125)
2: 2000 (0.1)
2: 10%
1: 2000 (0.1)
1: 10%
6: 2500 (0.125)
3: 2500 (0.125)
0: 2500 (0.125)
2: 2500 (0.125)
5: 2500 (0.125)
3: 3000 (0.15)
6: 3000 (0.15)
4: 3000 (0

In [6]:
# Sanity check: display the size reported by the assembled buffer.
# Presumably this is the number of stored transitions; the recorded output
# (160000) equals goal_len.
replay_buffer.size

160000