In [1]:
import gymnasium as gym
import numpy as np
import torch
import ray

from torch.optim import Adam
import time
import matplotlib.pyplot as plt

import torch
import torch.nn as nn

import torchopt
import wandb
from copy import deepcopy

pygame 2.5.2 (SDL 2.28.3, Python 3.11.7)
Hello from the pygame community. https://www.pygame.org/contribute.html




In [2]:

from sac_v_data_collection import ReplayBuffer ,compute_loss_in_virtual_MDP , take_M_steps_in_env

from sac_v_neural_nets import State_encoder_stochastic, State_encoder_deterministic, SquashedGaussianMLPActor , MLPQFunction , Model 

from sac_v_updates import distributed_sac_update ,update_model_target_networks ,update_agent_with_virtual_data_v1,update_agent_with_virtual_data_v2 

from sac_test_and_logs import policy_evaluation , test_policy_with_adaptations


from sac_v_config import get_config

In [3]:
config_setting='custom'
config=get_config(config_setting)
wyb=True

if config.seeding==True:
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)

device_real=config.device_real_world
device_virtual=config.device_virtual_world





In [4]:

env_id='Hopper-v4' 
def make_env(env_id):
    env = gym.make(env_id)
    env = gym.wrappers.RecordEpisodeStatistics(env)
    if config.seeding==True:
        env.action_space.seed(config.seed)
        env.observation_space.seed(config.seed)

    return env

In [5]:


env, test_env = make_env(env_id), make_env(env_id)
obs_dim = env.observation_space.shape
act_dim = env.action_space.shape
act_limit = env.action_space.high[0]

# Replay buffer
replay_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, max_size=config.replay_buffer_size)


# build policy and value functions
if not config.stochastic_encoder:
    state_encoder=State_encoder_deterministic(obs_dim=obs_dim[0] ,obs_encoding_size=config.virtual_state_dim)
else:
    state_encoder=State_encoder_stochastic(obs_dim=obs_dim[0] ,obs_encoding_size=config.virtual_state_dim)


policy = SquashedGaussianMLPActor(input_dim=config.virtual_state_dim, act_dim=act_dim[0],
                                    hidden_sizes=(256,256), activation=nn.ReLU, act_limit=act_limit) 
q1 = MLPQFunction(obs_dim=obs_dim[0], act_dim=act_dim[0],
                    hidden_sizes=(256,256), activation=nn.ReLU)
q2 = MLPQFunction(obs_dim=obs_dim[0], act_dim=act_dim[0],
                    hidden_sizes=(256,256), activation=nn.ReLU)

#target networks
target_q1 =  deepcopy(q1)  
target_q2 =  deepcopy(q2)  
# Freeze target networks with respect to optimizers (only update via polyak averaging)
for p in target_q1.parameters(): 
    p.requires_grad = False  
for p in target_q2.parameters(): 
    p.requires_grad = False  


# List of parameters for both Q-networks
q_params = list(q1.parameters()) + list(q2.parameters()) 

# Set up optimizers for policy and q-functions
policy_optimizer = Adam(policy.parameters(), lr=config.policy_lr)
qs_optimizer = Adam(q_params, lr=config.q_func_lr)
state_encoder_optimizer=  Adam(state_encoder.parameters(), lr=config.encoder_lr)



model=Model( virtual_state_dim=config.virtual_state_dim ,
            env=env , len_virtual_trayectories=config.len_virtual_trayectories)

model_optimizer= Adam(model.parameters(), lr=config.model_lr)

In [7]:
class Logger:
    def __init__(self):
        self.episodes_returns=[0]
        self.episodes_lengths=[0]
        self.num_episodes=0

        #metrics during updates
        self.num_updates=0
        self.num_real_steps_at_time_of_update=[]
        self.policy_loss = []
        self.q_loss= []
        self.actions_logprobs=[]
        self.entropies=[]
        self.virtual_loss=[0]
        self.kl_regu_loss=[0]

        self.q1_means= []
        self.q1_stds = []
        self.q2_means= []
        self.q2_stds = []
        self.model_consistency_loss=[]

        #tests metrics
        self.num_real_steps_at_time_of_test=[]
        self.test_episodes_returns = []
        self.test_episodes_lengths = []
        self.test_base_params_episodes_returns=[]

    def prepare_for_wyb_logging(self):
        # define our custom x axis metric
        wandb.define_metric("num updates")
        # define which metrics will be plotted against it
        wandb.define_metric("q1 means", step_metric="num updates")
        wandb.define_metric("q2 means", step_metric="num updates")
        wandb.define_metric("q1 stds", step_metric="num updates")
        wandb.define_metric("q2 stds", step_metric="num updates")
        wandb.define_metric("policy loss", step_metric="num updates")
        wandb.define_metric("model_consistency_loss", step_metric="num updates")
        wandb.define_metric("q loss", step_metric="num updates")
        wandb.define_metric("actions logprobs", step_metric="num updates")
        wandb.define_metric("entropy", step_metric="num updates")
        wandb.define_metric("virtual loss", step_metric="num updates")
        wandb.define_metric("kl regularization loss",step_metric="num updates")
        wandb.define_metric("real steps at time of update", step_metric="num updates")
        

    def log_update_metrics(self,total_real_steps):
        self.num_updates+=1
        wandb.log({'q1 means': self.q1_means[-1] ,'q2 means':self.q1_means[-1] ,
                    'q1 stds': self.q1_stds[-1] ,'q2 stds':self.q2_stds[-1] , 
                    'policy loss': self.policy_loss[-1] , 'model_consistency_loss':self.model_consistency_loss[-1],
                    'q loss':self.q_loss[-1] ,
                    'actions logprobs': self.actions_logprobs[-1] , 'entropy': self.entropies[-1] ,
                    'virtual loss':self.virtual_loss[-1] ,'kl regularization loss':self.kl_regu_loss[-1],
                    'real steps at time of update': self.num_real_steps_at_time_of_update[-1],
                     'num updates' : self.num_updates},step=total_real_steps)


    def log_training_performance( self,total_real_steps):
        if len(self.episodes_returns) > self.num_episodes:
            wandb.log({'episodes returns': np.mean(self.episodes_returns[-1]),
                    'episodes lengths':np.mean(self.episodes_lengths[-1]) } ,step=total_real_steps)
        self.num_episodes= len(self.episodes_returns) 
            

    def log_test_performance(self ,total_real_steps):
        wandb.log({'test_base_params_episodes_returns': np.mean(self.test_base_params_episodes_returns[-1]),
                'test_episodes_returns':np.mean(self.test_episodes_returns[-1]) ,
                'test_episodes_lengths':np.mean(self.test_episodes_lengths[-1]),
                 'real steps at time of test': self.num_real_steps_at_time_of_test[-1] } ,step=total_real_steps  )

logger=Logger()

In [8]:
if wyb:
    model_id=int(time.time())
    run_name = f"{env_id}__{model_id}" 
    wandb.init(project='project_name',
                    name= run_name,
                    config=vars(config))
    logger.prepare_for_wyb_logging()

if ray.is_initialized:
    ray.shutdown()
ray.init()


c=0 #dummy variable for determining when to do an evaluation with a deterministic version of the policy and an evaluation of the base policy
update_number=0


new_state=torch.tensor(env.reset()[0],dtype=torch.float32).to(device_virtual) #get a first observation from environment

total_performed_steps=0

while total_performed_steps< config.total_timesteps:
    #each loop makes use of [config.num_real_mdp_steps_per_update] new steps of real data
    base_policy_state_dict = torchopt.extract_state_dict(policy)


    ######## ############## ------ REAL MDP DATA COLLECTION ----------############## ##############
    update_current_step_num=0

    while (update_current_step_num + config.num_real_mdp_steps_per_adaptation)<config.num_real_mdp_steps_per_update:

        #perofrm random actions in the first steps
        if total_performed_steps < config.num_initial_random_steps:
            new_state, num_steps_taken = take_M_steps_in_env(random_actions=True, state_encoder=state_encoder,policy=policy,env=env, replay_buffer=replay_buffer , 
                                        m=config.num_real_mdp_steps_per_adaptation ,current_state=new_state ,
                                        finish_if_episode_ends=True, logger=logger,device=device_real)
        
        #adapt the base policy and perform actions with it for data collection
        else:
            adaptation_optimizer =torchopt.MetaSGD(policy, lr=config.adaptation_lr) 
            #adapt base policy using model data 
            for l in range(config.num_updates_in_adaptation):
                policy_loss_for_adaptation=compute_loss_in_virtual_MDP(state_encoder=state_encoder,policy=policy, model=model , current_real_state=new_state ,
                                                                        config=config,device=device_virtual)
                adaptation_optimizer.step(policy_loss_for_adaptation)
            #steps in real world - collect data with adapted policy in real world for m steps and add them to the replay_buffer . 
            new_state, num_steps_taken = take_M_steps_in_env(state_encoder=state_encoder,policy=policy,env=env, replay_buffer=replay_buffer , 
                                                        m=config.num_real_mdp_steps_per_adaptation ,current_state=new_state ,
                                                        finish_if_episode_ends=True,logger=logger, device=device_real)
            torchopt.recover_state_dict(policy, base_policy_state_dict)

        update_current_step_num+=num_steps_taken
        total_performed_steps+=num_steps_taken

    print(f'return at {total_performed_steps} steps taken = {np.mean(logger.episodes_returns[-1:])}' )
    if wyb:
        logger.log_training_performance(total_real_steps=total_performed_steps)

    ######## ############## ############## ############## ############## ##############



    ######## ############## ------   UPDATE MODELS ---------- ############## ##############


    start_time=time.time()
    if total_performed_steps< config.num_steps_to_start_updating_after:
        continue

    else:
        for j in range(update_current_step_num * config.updates_to_steps_ratio):  #perform as many update iterations as num of steps collected times a ration defined in config #range(config.num_real_mdp_steps_per_update):
            

            distributed_sac_update(replay_buffer =replay_buffer,state_encoder=state_encoder, policy=policy , model=model ,
                    q1=q1 ,q2=q2 , target_q1=target_q1, target_q2=target_q2  , state_encoder_optimizer=state_encoder_optimizer,
                    policy_optimizer=policy_optimizer ,qs_optimizer=qs_optimizer, model_optimizer=model_optimizer ,
                        config=config , logger=logger)
            
            update_model_target_networks(model,config)

            update_agent_with_virtual_data_v2(replay_buffer=replay_buffer , num_states_to_consider=config.num_states_for_estimating_virtual_loss,
                                              state_encoder=state_encoder, policy=policy  , model=model , policy_optimizer=policy_optimizer ,config=config, logger=logger)
            #update_agent_with_virtual_data_v1(replay_buffer=replay_buffer , num_states_to_consider=config.num_states_for_estimating_virtual_loss
            #                                   , state_encoder=state_encoder, policy=policy  , model=model , policy_optimizer=policy_optimizer
            #                                   , config=config, logger=logger)


            if wyb:
                logger.num_real_steps_at_time_of_update.append(total_performed_steps)
                logger.log_update_metrics(total_real_steps=total_performed_steps)
            update_number+=1

    print(f'updating takes:{time.time()-start_time}')

    if total_performed_steps//3000 > c:
        logger.num_real_steps_at_time_of_test.append(total_performed_steps)
        base_policy_returns=policy_evaluation(state_encoder=state_encoder,policy=policy, env=test_env, num_episodes=config.num_test_episodes, deterministic=True,device=device_real) 
        logger.test_base_params_episodes_returns.append(base_policy_returns)
        test_policy_with_adaptations(state_encoder, policy ,model, env, num_episodes=config.num_test_episodes, config=config ,logger=logger )
        if wyb:
            logger.log_test_performance(total_real_steps=total_performed_steps)
        c+=1
        

if wyb:
    wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33moctaviopappalardo[0m. Use [1m`wandb login --relogin`[0m to force relogin


2024-07-01 03:52:38,858	INFO worker.py:1715 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


return at 40 steps taken = 0.0
return at 87 steps taken = 11.172168731689453
return at 135 steps taken = 7.063248157501221
return at 175 steps taken = 7.063248157501221
return at 217 steps taken = 5.982283115386963
return at 261 steps taken = 18.5384464263916
return at 301 steps taken = 8.281298637390137
return at 344 steps taken = 10.355484008789062
return at 384 steps taken = 10.355484008789062
return at 431 steps taken = 6.565544128417969
return at 478 steps taken = 12.113709449768066
return at 525 steps taken = 12.515090942382812
return at 565 steps taken = 10.037205696105957
return at 607 steps taken = 6.278228282928467
return at 647 steps taken = 9.557551383972168
return at 696 steps taken = 6.620285987854004
return at 737 steps taken = 8.151066780090332
return at 778 steps taken = 17.33385467529297
return at 822 steps taken = 15.673702239990234
return at 862 steps taken = 12.1116943359375
return at 902 steps taken = 6.5505571365356445
return at 946 steps taken = 21.7999286651611



updating takes:82.46932578086853
return at 1073 steps taken = 30.595699310302734
updating takes:76.87887573242188
return at 1114 steps taken = 9.554718971252441
updating takes:79.00002455711365
return at 1157 steps taken = 13.00902271270752
updating takes:88.44166564941406
return at 1198 steps taken = 10.813594818115234
updating takes:76.56311345100403
return at 1242 steps taken = 13.90247917175293
updating takes:76.05710434913635
return at 1284 steps taken = 12.424710273742676
updating takes:76.42059922218323
return at 1331 steps taken = 16.39723777770996
updating takes:90.3974940776825
return at 1377 steps taken = 12.07697868347168
updating takes:85.8263258934021
return at 1426 steps taken = 23.312734603881836
updating takes:94.75126576423645
return at 1469 steps taken = 10.417510032653809
updating takes:85.23220252990723
return at 1510 steps taken = 44.63209533691406
updating takes:77.42727303504944
return at 1557 steps taken = 19.57259178161621
updating takes:90.86662244796753
retu

KeyboardInterrupt: 