In [1]:
import os
import sys
import time
import subprocess
from PIL import Image
import numpy as np
import torch
import yaml
from torchvision import transforms
from experiment import VAEXperiment
from models import *

import gymnasium as gym
from gymnasium import ObservationWrapper
from gymnasium.wrappers import PixelObservationWrapper, FrameStack
from gymnasium.spaces import Box, Discrete

from stable_baselines3 import SAC, PPO, A2C
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.utils import set_random_seed
from stable_baselines3.common.callbacks import CallbackList, CheckpointCallback, EvalCallback
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.evaluation import evaluate_policy


In [2]:
def get_vae(version='version_0',log_directory='logs/BCE_test_VAE_1/MSSIMVAE/',
            hparam_path = "configs/bces_vae.yaml"):
    """Rebuild a VAE from its YAML config and load weights from a checkpoint.

    Args:
        version: Name of the run sub-directory inside ``log_directory``
            (e.g. ``'version_0'``).
        log_directory: Directory that holds the per-version run folders.
        hparam_path: Path to the YAML file providing ``model_params`` and
            ``exp_params``.

    Returns:
        The ``model`` attribute of the restored ``VAEXperiment``.
    """
    ckpt_path = os.path.join(log_directory, version, 'checkpoints', 'last.ckpt')

    # Context manager so the config file handle is closed deterministically.
    with open(hparam_path) as config_file:
        config = yaml.safe_load(config_file)

    model_params = config['model_params']
    model = vae_models[model_params['name']](**model_params)

    # map_location='cpu' lets a checkpoint written on a GPU machine be read on a
    # CPU-only one; load_state_dict copies the values into the freshly built
    # model either way.
    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.
    ckpt = torch.load(ckpt_path, map_location='cpu')

    experiment = VAEXperiment(model, config['exp_params'])
    experiment.load_state_dict(ckpt['state_dict'])
    return experiment.model

# Factory returning an environment constructor, which allows the env to be vectorized
def make_env(env_id: str = "MountainCarContinuous-v0", rank: int = 0, seed: int = 42, 
            data_dir: str = "Data/MountainCar/test2", collect_frames: bool = True, env_iterator: int = 0,
            vae_version: int = 0, latent_dim: int = 1,
            vae_directory: str = 'logs/MountainCar/BCE_test_VAE_1/MSSIMVAE/',
            hparam_path: str = "configs/bces_no_pretrained.yaml"):
    """Build a zero-argument constructor for one VAE-encoded environment.

    Args:
        env_id: Gymnasium environment id passed to ``gym.make``.
        rank: Index of this environment; added to ``seed`` for the initial
            reset and embedded in the frame file prefix.
        seed: Base random seed. Passed to ``set_random_seed`` when this factory
            is called and, offset by ``rank``, to the first ``env.reset``.
        data_dir: Directory prefix under which frames are saved.
        collect_frames: When true, wrap the env in ``frame_saver``.
        env_iterator: Identifier embedded in the frame file prefix.
        vae_version: Numeric suffix of the ``version_<n>`` VAE run to load.
        latent_dim: Size of the latent vector exposed as the observation.
        vae_directory: Directory holding the VAE run folders.
        hparam_path: YAML config used to rebuild the VAE.

    Returns:
        A callable that creates the wrapped, monitored environment.
    """
    def _init():
        save_path = data_dir+'/train_env_id_'+str(env_iterator)+'_nenv_'+str(rank)+'_'
        vae = get_vae(version='version_'+str(vae_version),
                      log_directory = vae_directory,
                      hparam_path = hparam_path)

        env = gym.make(env_id,
                    render_mode ='rgb_array')

        # Use the caller-supplied seed (previously shadowed by a hard-coded 42);
        # the rank offset gives every parallel copy a different seed.
        env.reset(seed=seed + rank)
        env = PixelObservationWrapper(env)
        if collect_frames:
            env = frame_saver(env, save_path)
        env = VAE_ENC(env, vae, latent_dim)
        env = FrameStack(env, num_stack=2)
        env = Monitor(env)
        return env
    set_random_seed(seed)
    return _init

In [4]:
# Observation wrapper that saves every rendered frame to disk
class frame_saver(ObservationWrapper):
    """Pass observations through unchanged while optionally dumping frames as JPEGs.

    Args:
        env: Environment whose observations are dicts containing a ``'pixels'`` entry.
        collect_frames_dir: File-name prefix for saved frames; ``None`` disables saving.
        start_index: Index used for the first saved frame.
    """
    def __init__(self, env,
                 collect_frames_dir = None,
                 start_index = 0):
        super().__init__(env)

        self.collect_frames_dir = collect_frames_dir
        self.frame_idx = start_index

    def observation(self, obs):
        """Save ``obs['pixels']`` (if saving is enabled) and return ``obs`` untouched."""
        # Identity test: `!= None` would go through the operand's __ne__.
        if self.collect_frames_dir is not None:
            frame = obs['pixels']
            image = Image.fromarray(np.asarray(frame))
            image.save(f"{self.collect_frames_dir}_{self.frame_idx}.jpeg")
            self.frame_idx += 1
        return obs

# Transform that adds Gaussian noise to the observations
class AddGaussianNoise(object):
    """Callable transform returning ``tensor + N(0, 1) * std + mean``.

    Args:
        mean: Constant offset added to every element.
        std: Scale applied to the standard-normal noise.
    """
    def __init__(self, mean=0., std=0.1):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        """Return a noisy copy of ``tensor``; the input is not modified."""
        # Sample on the input's device so tensors that do not live on the CPU
        # do not trigger a device-mismatch error; CPU results are unchanged.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return '{0}(mean={1}, std={2})'.format(self.__class__.__name__, self.mean, self.std)

# Observation wrapper that replaces pixel frames with their VAE encoding
class VAE_ENC(ObservationWrapper):
    """Encode ``obs['pixels']`` with a VAE and expose the result as the observation.

    Args:
        env: Environment whose observations are dicts containing a ``'pixels'`` entry.
        vae: Model providing an ``encode`` method.
        latent_dim: Length of the vector declared in ``observation_space``.
        mean: Offset of the Gaussian noise added before encoding.
        std: Scale of the Gaussian noise added before encoding.
        size: Target size the frame is resized to before encoding.
        start_index: Initial value stored in ``frame_idx``.
    """
    def __init__(self, env, vae, latent_dim,
                 mean=0,std=0.1,
                 size=(64,64),
                 start_index = 0):
        super().__init__(env)
        # Only the first encoder output is exposed, hence a flat latent vector.
        self.observation_space = Box(shape=(latent_dim,), low=-np.inf, high=np.inf)

        self.vae = vae
        # Pre-processing parameters
        self.mean = mean
        self.std = std
        self.size = size

        self.frame_idx = start_index

    def observation(self, obs):
        """Return the first encoder output for the current frame as a numpy array."""
        frame = obs['pixels']
        # Built per call so changes to self.mean / self.std / self.size take effect.
        val_transforms = transforms.Compose([
            transforms.ToTensor(),
            AddGaussianNoise(self.mean, self.std),
            transforms.Resize(self.size),
        ])
        batch = torch.unsqueeze(val_transforms(frame), 0)  # (c,h,w) -> (1,c,h,w)

        # Inference only: skip building an autograd graph on every env step.
        with torch.no_grad():
            enc = self.vae.encode(batch)

        # First encoder output, first (only) batch element — presumably the
        # latent mean; confirm against the VAE's encode().
        return enc[0][0].detach().cpu().numpy()

In [73]:
# Configuration of the training environments
n_envs = 4
vae_version = 8
vae_name = "BCE_VAE_l1_test1_A2C"
vae_directory = 'logs/MountainCar/BCE_VAE_1_test1/MSSIMVAE/'  # directory for versions of the vae
config_path = "configs/bces_no_pretrained.yaml"
data_name = 'A2C_l1_envTest'
save_path = f'Data/MountainCar/{data_name}/'

# One constructor per rank, all run in the current process by DummyVecEnv
env = DummyVecEnv([
    make_env(
        env_id="MountainCarContinuous-v0",
        rank=i,
        data_dir=save_path,
        collect_frames=False,
        env_iterator=0,
        vae_version=vae_version,
        vae_directory=vae_directory,
        hparam_path=config_path,
    )
    for i in range(n_envs)
])

In [76]:
# Where the RL agents and their TensorBoard logs are saved
agent_model_dir = "RLmodels/MountainCarContinuous-v0/Double_loop"
agent_log_dir = f"{agent_model_dir}/logs"
n_steps = 5

# A2C on the VAE-encoded observations, using state-dependent exploration
agent = A2C(
    policy='MlpPolicy',
    env=env,
    n_steps=n_steps,
    ent_coef=0.0,
    use_sde=True,
    sde_sample_freq=16,
    policy_kwargs=dict(log_std_init=0.0, ortho_init=False),
    tensorboard_log=agent_log_dir,
    device='cpu',
    verbose=2,
)

Using cpu device


In [77]:
# Train the agent; TensorBoard runs are logged under the 'eval_test' name
agent.learn(
    total_timesteps=10000,
    tb_log_name='eval_test',
)

Logging to RLmodels/MountainCarContinuous-v0/Double_loop/logs/eval_test_3
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 423      |
|    ep_rew_mean        | 62.4     |
| time/                 |          |
|    fps                | 65       |
|    iterations         | 100      |
|    time_elapsed       | 30       |
|    total_timesteps    | 2000     |
| train/                |          |
|    entropy_loss       | -2.78    |
|    explained_variance | -0.975   |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|    policy_loss        | -0.459   |
|    std                | 1        |
|    value_loss         | 0.692    |
------------------------------------
------------------------------------
| rollout/              |          |
|    ep_len_mean        | 855      |
|    ep_rew_mean        | -49.1    |
| time/                 |          |
|    fps                | 69       |
|    iterations         | 200      |
|

<stable_baselines3.a2c.a2c.A2C at 0x7fe9b20cf550>

In [91]:
def eval_agent(agent, env, n_eval_episodes,deterministic = True):
    """Roll out ``agent`` in ``env`` and summarise the undiscounted episode returns.

    Both calling conventions are supported: environments whose ``reset``
    returns an ``(observation, info)`` tuple and whose ``step`` returns five
    values, and environments whose ``reset`` returns the observation alone and
    whose ``step`` returns four values.

    Args:
        agent: Object with ``predict(observation, deterministic=...)`` returning
            an ``(action, state)`` pair.
        env: A single environment, or a vectorized one holding exactly one env.
        n_eval_episodes: Number of episodes to run.
        deterministic: Forwarded to ``agent.predict``.

    Returns:
        ``(mean, std, returns)`` where ``returns`` is an array with one entry
        per episode.

    Raises:
        ValueError: If ``env`` reports more than one sub-environment.
    """
    if getattr(env, 'num_envs', 1) > 1:
        raise ValueError("eval_agent evaluates one environment at a time; "
                         "pass an env with num_envs == 1")

    observation_space = env.observation_space
    print('observation space:', observation_space)

    total_rewards = []
    for _ in range(n_eval_episodes):
        reset_result = env.reset()
        # A tuple means (observation, info). Checking the type rather than the
        # length avoids mistaking a batched observation array for that pair.
        if isinstance(reset_result, tuple):
            observation = reset_result[0]
        else:
            observation = reset_result

        episode_reward = 0
        done = False
        while not done:
            # predict returns (action, state); only the action goes to the env.
            action, _state = agent.predict(observation, deterministic=deterministic)
            step_result = env.step(action)
            if len(step_result) == 5:
                observation, reward, terminated, truncated, _info = step_result
                done = bool(terminated or truncated)
            else:
                observation, reward, done, _info = step_result
                done = bool(np.all(done))

            # Diagnostic only. The observation keeps the layout the env produced
            # so that predict/step stay consistent; a leading batch axis is
            # stripped solely for this validity check.
            if not observation_space.contains(observation):
                unbatched = observation[0]
                if not observation_space.contains(unbatched):
                    print("Observation is not valid:", unbatched)

            episode_reward += reward

        total_rewards.append(episode_reward)

    total_rewards = np.array(total_rewards)
    return np.mean(total_rewards), np.std(total_rewards), total_rewards

In [80]:
# Configuration of the evaluation environments
# NOTE(review): vae_version is 7 here while the training cell uses 8 — confirm this is intended.
n_envs = 1
vae_version = 7
vae_name = "BCE_VAE_l1_test1_A2C"
vae_directory = 'logs/MountainCar/BCE_VAE_1_test1/MSSIMVAE/'  # directory for versions of the vae
config_path = "configs/bces_no_pretrained.yaml"
data_name = 'A2C_l1_envTest'
save_path = f'Data/MountainCar/{data_name}/'

# Vectorized evaluation environment
eval_env = DummyVecEnv([
    make_env(
        env_id="MountainCarContinuous-v0",
        rank=i,
        data_dir=save_path,
        collect_frames=False,
        env_iterator=0,
        vae_version=vae_version,
        vae_directory=vae_directory,
        hparam_path=config_path,
    )
    for i in range(n_envs)
])

# The same environment built directly, without the VecEnv layer
build_eval_env2 = make_env(
    env_id="MountainCarContinuous-v0",
    rank=0,
    data_dir=save_path,
    collect_frames=False,
    env_iterator=0,
    vae_version=vae_version,
    vae_directory=vae_directory,
    hparam_path=config_path,
)
eval_env2 = build_eval_env2()
del build_eval_env2

In [97]:
# Display the environment returned by the agent's get_env() accessor.
agent.get_env()

<stable_baselines3.common.vec_env.dummy_vec_env.DummyVecEnv at 0x7fe9b2182220>

In [98]:
# Display the agent's `env` attribute.
agent.env

<stable_baselines3.common.vec_env.dummy_vec_env.DummyVecEnv at 0x7fe9b2182220>

In [100]:
# Evaluate on the agent's own environment; returns (mean reward, std of reward)
evaluate_policy(
    agent,
    agent.env,
    n_eval_episodes=100,
)

(-48.033374, 2.974568889760158)

In [None]:
# `reward_list` is not an evaluate_policy parameter; per-episode results are
# requested with `return_episode_rewards`, which makes the call return
# (episode_rewards, episode_lengths) instead of (mean, std).
evaluate_policy(agent, eval_env, n_eval_episodes=100, return_episode_rewards=True)

In [81]:
# Deterministic policy on the non-vectorized evaluation env
eval_agent(agent, eval_env2, n_eval_episodes=100, deterministic=True)

observation space: Box(-inf, inf, (2, 1), float32)


(18.19666718762395,
 57.211613506047726,
 array([-74.13086341,  46.61047441,  68.99497337, -79.689522  ,
         32.22826715, -75.76612705,  60.69805386,  50.25380659,
         55.62586184, -74.65099244,  45.43512275, -78.97720712,
         46.58800899,  60.00608876,  42.84013394,  68.43502433,
         34.63521894,  31.61845688, -71.61239745,  52.00564203,
        -82.09105642, -75.31023015,  39.59720439,  60.05953786,
         47.21847079,  64.93506172,  69.65559502, -75.22523841,
        -69.04763241, -80.37794727,  47.71136671,  56.08512363,
         44.88438658,  59.66337752,  68.88819833,  55.40425989,
         53.33746741, -71.26274004,  55.42624926, -78.68579963,
         39.79994321,  56.50712975, -74.12125301,  31.83864824,
         56.25681695,  55.12845546,  33.6463891 ,  60.4923198 ,
         63.95205082,  64.6637173 , -72.79917031, -69.43841077,
        -70.78432101, -75.2363962 ,  47.20230832, -71.64574196,
         62.90654553,  66.71274324,  36.14213414,  55.29058097,

In [92]:
# Deterministic policy on the vectorized evaluation env
eval_agent(agent, eval_env, n_eval_episodes=100, deterministic=True)

observation space: Box(-inf, inf, (2, 1), float32)


(20.510653,
 57.253517,
 array([[ 69.03984 ],
        [ 32.9879  ],
        [ 69.0011  ],
        [ 60.010036],
        [-71.97538 ],
        [ 46.19172 ],
        [ 41.35769 ],
        [-74.99716 ],
        [-71.63879 ],
        [-74.40619 ],
        [-73.180626],
        [-67.446144],
        [ 64.93894 ],
        [-71.16245 ],
        [ 55.742256],
        [ 41.644413],
        [ 67.87607 ],
        [ 59.08738 ],
        [ 43.032585],
        [-69.99064 ],
        [-78.06071 ],
        [ 51.852776],
        [ 42.553757],
        [ 59.382328],
        [-70.98329 ],
        [ 31.890465],
        [ 42.44451 ],
        [ 68.64244 ],
        [-76.222176],
        [ 41.296265],
        [ 31.743332],
        [ 41.013027],
        [ 69.49202 ],
        [ 63.870594],
        [-73.523705],
        [ 65.16844 ],
        [ 59.99257 ],
        [ 45.683575],
        [ 34.212196],
        [ 36.55285 ],
        [ 73.06198 ],
        [-69.373505],
        [ 65.45111 ],
        [ 69.63164 ],
        

In [93]:
# Stochastic policy on the vectorized evaluation env
eval_agent(
    agent,
    eval_env,
    n_eval_episodes=100,
    deterministic=False,
)

observation space: Box(-inf, inf, (2, 1), float32)


(-97.54545,
 0.3951888,
 array([[-96.704956],
        [-97.39471 ],
        [-97.77497 ],
        [-97.96563 ],
        [-97.653305],
        [-97.71136 ],
        [-97.42056 ],
        [-97.46863 ],
        [-98.09219 ],
        [-98.500885],
        [-97.2296  ],
        [-96.70008 ],
        [-97.906494],
        [-96.88454 ],
        [-97.7625  ],
        [-97.61216 ],
        [-97.295494],
        [-97.255775],
        [-97.8869  ],
        [-97.1741  ],
        [-97.41812 ],
        [-97.37732 ],
        [-97.401794],
        [-97.554276],
        [-97.73752 ],
        [-96.53244 ],
        [-97.644264],
        [-97.85418 ],
        [-97.59632 ],
        [-97.69297 ],
        [-98.45748 ],
        [-96.77757 ],
        [-97.46465 ],
        [-96.90003 ],
        [-97.431786],
        [-98.104675],
        [-97.715614],
        [-97.61924 ],
        [-97.459305],
        [-96.88639 ],
        [-97.38704 ],
        [-97.10182 ],
        [-97.38044 ],
        [-97.32304 ],
        

In [83]:
# Stochastic policy on the non-vectorized evaluation env
eval_agent(
    agent,
    eval_env2,
    n_eval_episodes=100,
    deterministic=False,
)

observation space: Box(-inf, inf, (2, 1), float32)


(-97.49539815458566,
 0.3944803442482428,
 array([-96.85236466, -97.56164347, -96.99394145, -97.61136478,
        -97.66042538, -96.90210404, -97.90576911, -97.56931211,
        -97.7740074 , -97.41650237, -97.03757497, -97.47012034,
        -97.5759842 , -97.21493622, -97.01103755, -97.63889487,
        -97.82222689, -97.74554024, -96.68862462, -97.60121528,
        -97.27993397, -96.85948496, -97.7580608 , -97.6078203 ,
        -97.83730635, -97.70256513, -97.73256868, -97.61603216,
        -97.07685386, -97.38630385, -97.73667439, -97.21137036,
        -98.10424578, -97.55198632, -96.75409409, -97.17446246,
        -98.08297848, -97.38171999, -97.99318251, -97.99493739,
        -97.28431583, -97.19468026, -97.45136121, -98.03961261,
        -97.55684607, -97.98114065, -97.68332776, -96.97079627,
        -97.01530107, -97.46135007, -96.78030142, -96.95532342,
        -97.56910454, -97.67121048, -97.68905473, -97.84071526,
        -98.28745041, -97.19540417, -97.92586882, -97.44072682