In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import os
import pickle

import gym
import numpy as np
from stable_baselines3 import SAC
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.policies import ActorCriticPolicy, ActorCriticCnnPolicy
from stable_baselines3.ppo import MlpPolicy

  from urllib3.contrib.pyopenssl import orig_util_SSLContext as SSLContext


### Using trained SAC models
* The model checkpoints were downloaded from https://huggingface.co/sb3

In [3]:
def collect_trajectory(env, policy_net, n_trajectory=20, epsilon=0.0):
    """Roll out ``policy_net`` in ``env`` and record every episode.

    Args:
        env: Environment whose ``reset()`` returns ``(state, info)`` and whose
            ``step(action)`` returns a 5-tuple
            ``(next_state, reward, flag_a, flag_b, info)``; an episode ends as
            soon as either flag is truthy. The environment is closed before
            this function returns, including when a rollout raises.
        policy_net: Object exposing ``predict(state, deterministic=True)``
            that returns ``(action, extra)``.
        n_trajectory: Number of episodes to collect.
        epsilon: Exploration rate. With ``epsilon == 0`` every action comes
            from ``policy_net``; otherwise the action is drawn from
            ``env.action_space.sample()`` whenever
            ``np.random.random() > 1 - epsilon``.

    Returns:
        ``(scores, trajectories)``: ``scores`` holds the summed reward of each
        episode, and ``trajectories`` holds one ``(states, actions, rewards)``
        tuple per episode, each element stacked with ``np.vstack`` so that it
        has one row per step.
    """
    trajectories = []
    scores = []
    try:
        for episode in range(n_trajectory):
            state, _info = env.reset()
            score = 0
            states, actions, rewards = [], [], []
            while True:
                # Short-circuit: the RNG is not consumed when epsilon is exactly 0.
                explore = epsilon != 0 and np.random.random() > (1 - epsilon)
                if explore:
                    action = env.action_space.sample()
                else:
                    action, _ = policy_net.predict(state, deterministic=True)

                next_state, reward, done, truncated, _ = env.step(action)
                score += reward

                # Record the state the action was chosen in, not its successor.
                states.append(state)
                actions.append(action)
                rewards.append(reward)
                state = next_state
                if done or truncated:
                    break

            scores.append(score)
            trajectories.append(
                (np.vstack(states), np.vstack(actions), np.vstack(rewards))
            )
            print('{} episode score is {:.2f}'.format(episode, score))
    finally:
        # Release the environment even if predict()/step() raised mid-rollout.
        env.close()
    return scores, trajectories

### Ant

In [4]:
# Restore the SAC agent for Ant-v3 from the local checkpoint archive.
savepath = "trained_models/sb3_sac_Ant-v3.zip"
model = SAC.load(
    savepath,
    print_system_info=False,
)

In [5]:
# env_name is reused later when naming the saved trajectory file.
env_name = "Ant-v3"
env_t = gym.make(env_name)

In [6]:
# Roll out the loaded agent. epsilon keeps its default of 0.0, so every action
# comes from the model; collect_trajectory closes env_t when it is done.
print('collecting trajectories ...')
n_trajectory = 10
scores, trajectories = collect_trajectory(
    env=env_t,
    policy_net=model,
    n_trajectory=n_trajectory,
)
mean = np.mean(scores)
print('mean score:', mean)

collecting trajectories ...
0 episode score is 5246.79
1 episode score is 3756.22
2 episode score is 4987.11
3 episode score is 5392.41
4 episode score is 5097.51
5 episode score is 5185.53
6 episode score is 5309.34
7 episode score is 5258.45
8 episode score is 5468.29
9 episode score is 4914.24
mean score: 5061.589123630422


In [7]:
# The file name records the environment, the number of episodes and the
# truncated mean return.
filename=f"expert_data/{env_name}_{n_trajectory}_{int(mean)}.pkl"
# Create the output directory first so a fresh checkout does not fail in
# open() after the rollouts have already been computed.
os.makedirs(os.path.dirname(filename), exist_ok=True)
print('saving ',filename)

with open(filename, 'wb') as f:
    pickle.dump(trajectories, f)
print('trajectories saved.')

saving  expert_data/Ant-v3_10_5061.pkl
trajectories saved.


### HalfCheetah

In [8]:
# Build the HalfCheetah environment, then restore the matching SAC checkpoint.
env_name = "HalfCheetah-v3"
env_t = gym.make(env_name)

savepath = "trained_models/sb3_sac_HalfCheetah-v3.zip"
model = SAC.load(
    savepath,
    print_system_info=False,
)

In [9]:
# Roll out the loaded agent. epsilon keeps its default of 0.0, so every action
# comes from the model; collect_trajectory closes env_t when it is done.
print('collecting trajectories ...')
n_trajectory = 10
scores, trajectories = collect_trajectory(
    env=env_t,
    policy_net=model,
    n_trajectory=n_trajectory,
)
mean = np.mean(scores)
print('mean score:', mean)

collecting trajectories ...
0 episode score is 9545.70
1 episode score is 9369.82
2 episode score is 9424.86
3 episode score is 9485.17
4 episode score is 9508.30
5 episode score is 9501.30
6 episode score is 9530.53
7 episode score is 9674.54
8 episode score is 9502.12
9 episode score is 9441.51
mean score: 9498.384559390988


In [10]:
# The file name records the environment, the number of episodes and the
# truncated mean return.
filename=f"expert_data/{env_name}_{n_trajectory}_{int(mean)}.pkl"
# Create the output directory first so a fresh checkout does not fail in
# open() after the rollouts have already been computed.
os.makedirs(os.path.dirname(filename), exist_ok=True)
print('saving ',filename)

with open(filename, 'wb') as f:
    pickle.dump(trajectories, f)
print('trajectories saved.')

saving  expert_data/HalfCheetah-v3_10_9498.pkl
trajectories saved.


### Humanoid

In [11]:
# Build the Humanoid environment, then restore the matching SAC checkpoint.
env_name = "Humanoid-v3"
env_t = gym.make(env_name)

savepath = "trained_models/sb3_sac_Humanoid-v3.zip"
model = SAC.load(
    savepath,
    print_system_info=False,
)

In [12]:
# Roll out the loaded agent. epsilon keeps its default of 0.0, so every action
# comes from the model; collect_trajectory closes env_t when it is done.
print('collecting trajectories ...')
n_trajectory = 10
scores, trajectories = collect_trajectory(
    env=env_t,
    policy_net=model,
    n_trajectory=n_trajectory,
)
mean = np.mean(scores)
print('mean score:', mean)

collecting trajectories ...
0 episode score is 6263.36
1 episode score is 6255.51
2 episode score is 6288.52
3 episode score is 6249.82
4 episode score is 6265.81
5 episode score is 6258.31
6 episode score is 6238.61
7 episode score is 6086.96
8 episode score is 6184.13
9 episode score is 6264.66
mean score: 6235.567992548466


In [13]:
# The file name records the environment, the number of episodes and the
# truncated mean return.
filename=f"expert_data/{env_name}_{n_trajectory}_{int(mean)}.pkl"
# Create the output directory first so a fresh checkout does not fail in
# open() after the rollouts have already been computed.
os.makedirs(os.path.dirname(filename), exist_ok=True)
print('saving ',filename)

with open(filename, 'wb') as f:
    pickle.dump(trajectories, f)
print('trajectories saved.')

saving  expert_data/Humanoid-v3_10_6235.pkl
trajectories saved.
