In [1]:
#!pip install robosuite
#!pip install mujoco
#!pip install h5py
#!pip install gymnasium==1.2.0
#!pip install tqdm
#!pip install tensorboard
from IPython.display import clear_output
import warnings,logging
warnings.filterwarnings("ignore")
logging.disable(logging.CRITICAL)
clear_output()

In [2]:
import robosuite as suite
from robosuite import load_composite_controller_config
from robosuite.wrappers.gym_wrapper import GymWrapper
from gymnasium.wrappers import Autoreset
import torch,random,sys
from dataclasses import dataclass
from copy import deepcopy

@dataclass(frozen=False)
class Hypers:
    ROBOT = "Panda"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    obs_dim = 74     # observation space, dim -1
    action_dim = 9   # action space for a single env
    batchsize = 256
    lr = 3e-4
    gamma = .99
    tau = .005
    warmup = 2_000
    max_step = 5e6
    
hypers = Hypers()

cont_config = load_composite_controller_config(robot=hypers.ROBOT)
env_configs = {
    "robots":[hypers.ROBOT],
    "controller_configs": cont_config,
    "gripper_types":["JacoThreeFingerDexterousGripper"],
    "has_renderer":False,
    "use_camera_obs":False,
    "has_offscreen_renderer":False,
    "reward_shaping":True,            # activate dense reward 
    "horizon":500,                    # Max steps before reset or trunc = True
    "control_freq":20,
    "reward_scale":1.0
}
def make_env():
    x = suite.make(env_name ="Stack" ,**env_configs)
    x = GymWrapper(x,keys=list(x.active_observables))
    x.metadata = {"render_mode":[]}
    x = Autoreset(x)
    return x

In [3]:
from torch import Tensor
import torch.nn.functional as F
import torch.nn as nn
from torch.distributions import Normal

class Actor(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(hypers.obs_dim,256)
        self.l2 = nn.Linear(256,256)
        self.lmean = nn.Linear(256,hypers.action_dim)
        self.lstd = nn.Linear(256,hypers.action_dim)
        self.optim = torch.optim.Adam(self.parameters(),hypers.lr)
        self.apply(self.weights_init)

    def forward(self,obs:Tensor):
        x = F.relu(self.l1(obs))
        x = F.relu(self.l2(x))
        mean = self.lmean(x)
        std = self.lstd(x).clamp(-20,2).exp()
        dist = Normal(mean,std) 
        pre_tanh = dist.rsample()
        action = F.tanh(pre_tanh)
        log = dist.log_prob(pre_tanh)
        # change of variable correction 
        log -= torch.log(1-action.pow(2) + 1e-6)
        log = log.sum(-1,True)  
        return action,log,mean
    
    def weights_init(self,layer):
        if isinstance(layer,nn.Linear):
            torch.nn.init.orthogonal_(layer.weight)
            torch.nn.init.constant_(layer.bias,0.0)

class Critic(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(hypers.obs_dim + hypers.action_dim,256)
        self.l2 = nn.Linear(256,256)
        self.output = nn.Linear(256,1)
        self.apply(self.weights_init)

    def forward(self,obs:Tensor,action:Tensor):
        cat = torch.cat((obs,action),dim=-1)
        x = F.relu(self.l1(cat))
        x = F.relu(self.l2(x))
        x = self.output(x)
        return x
    
    def weights_init(self,layer):
        if isinstance(layer,nn.Linear):
            torch.nn.init.orthogonal_(layer.weight)
            torch.nn.init.constant_(layer.bias,0.0)

In [4]:
import numpy as np

class Collector:
    def __init__(self,env,actor):
        self.data = []
        self.env = env
        self.actor = actor
        self.to_tensor = lambda x : torch.from_numpy(np.array(x)).to(hypers.device,dtype=torch.float32)
        self.count = 0
        self.reward = 0 
        self.episode_reward = 0  
        self.observation,_ = self.env.reset()

    @torch.no_grad()
    def step(self):
        state = self.observation
        if len(self.data) < hypers.warmup:
            action = self.env.action_space.sample() 
        else :
            action,_,_ = self.actor(torch.from_numpy(state).to(device=hypers.device,dtype=torch.float32))
        nx_state,reward,done,_,_ = self.env.step(action.tolist()) # nx_state : next state
        self.reward+= reward
        if done:
            self.episode_reward = self.reward
            self.reward = 0 

        if isinstance(action,Tensor): 
            saved_action = torch.tensor(action)
        elif isinstance(action,np.ndarray): 
            saved_action = self.to_tensor(action).to(hypers.device)
        self.data.append(
            [
                self.to_tensor(state),
                saved_action,
                self.to_tensor(reward),
                self.to_tensor(nx_state),
                self.to_tensor(done)
            ]
        )
        self.observation = nx_state
        self.count+=1 # for saving data (at interval self.count)
  
    def sample(self,batchsize): 
        output = random.sample(self.data,batchsize)
        states,actions,rewards,nstates,dones= map(lambda x : torch.stack(x),zip(*output))
        return states,actions,rewards,nstates,dones

    def save(self,steps): # might be useful for resuming training (never been tested yet)
        torch.save(self.data,f"./data_{steps}.pth")

    def reward_data(self):
        return self.episode_reward

    def __len__(self):
        return len(self.data)

In [None]:
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

class main:
    def __init__(self):
        self.actor = Actor().to(hypers.device)
        self.q1 = Critic().to(hypers.device)
        self.q1_target = deepcopy(self.q1).to(hypers.device)
        self.q2 = Critic().to(hypers.device) 
        self.q2_target = deepcopy(self.q2).to(hypers.device)
        self.critic_optim = torch.optim.Adam(
            list(self.q1.parameters()) + list(self.q2.parameters()),lr=hypers.lr
            )
        self.entropy_target = -hypers.action_dim # as seen in the original paper, page 17 (-dim (A) (e.g. , -6 for HalfCheetah-v1))
        self.log_alpha = torch.zeros(1,requires_grad=True,device=hypers.device)
        self.alpha = 0.2 # self.log_alpha.exp() 
        self.alpha_optim = torch.optim.Adam([self.log_alpha],lr=hypers.lr)
        self.env = make_env()
        self.collector = Collector(self.env,self.actor)
        self.writter = SummaryWriter("./")
    
    def save(self,step):
        check = {
            "actor state" : self.actor.state_dict(),
            "actor optim" : self.actor.optim.state_dict(),
            "q1 state":self.q1.state_dict(),
            "q1 target":self.q1_target.state_dict(),
            "q2 state":self.q2.state_dict(),
            "q2 target":self.q2_target.state_dict(),
            "critic optim":self.critic_optim.state_dict(),
            "alpha optim":self.alpha_optim.state_dict() 
        }
        torch.save(check,f"./model_{step}.pth")
    
    def load(self,strict=True):
        check = torch.load("./model_12.pth",map_location=hypers.device)
        self.actor.load_state_dict(check["actor state"],strict)
        self.actor.optim.load_state_dict(check["actor optim"])
        self.q1.load_state_dict(check["q1 state"],strict)
        self.q1_target.load_state_dict(check["q1 target"],strict)
        self.q2.load_state_dict(check["q2 state"],strict)
        self.q2_target.load_state_dict(check["q2 target"],strict)
        self.critic_optim.load_state_dict(check["critic optim"])
        self.alpha_optim.load_state_dict(check["alpha optim"])
    
    def train(self,start=False):
        if start:
            self.load() 
            n = 0 # tracking saved model number
            for traj in tqdm(range(int(hypers.max_step)),total=int(hypers.max_step)):
                self.collector.step()
                
                if len(self.collector) > hypers.warmup:
                    states,actions,reward,nx_states,dones = self.collector.sample(10) #hypers.batchsize
                    q1 = self.q1(states,actions) 
                    q2 = self.q2(states,actions) 
                    with torch.no_grad():
                        nx_actions,log_nx_actions,_ = self.actor(nx_states)
                     
                        q1_target = self.q1_target(nx_states,nx_actions) 
                        q2_target = self.q2_target(nx_states,nx_actions) 
                        min_q_target = torch.min(q1_target,q2_target) 
                        # bellman backup operator... reward(st|at) + gamma * Q(st|at) - alpha*log policy(at|st))
                        #print( reward.shape,dones.shape ,log_nx_actions.shape,min_q_target.shape)
                        q_target = reward.unsqueeze(-1) + hypers.gamma * (1-dones.unsqueeze(-1)) * (min_q_target - self.alpha * log_nx_actions) 
                    
                    assert q1.shape == q2.shape == q_target.shape, f"{q1.shape,q2.shape,q_target.shape}"
                    critic_loss = F.mse_loss(q1,q_target) + F.mse_loss(q2,q_target)
                    self.critic_optim.zero_grad()
                    critic_loss.backward()
                    self.critic_optim.step()

                    new_action,log_pi,_ = self.actor(states)
                    policy_loss = ((self.alpha * log_pi) -  self.q1(states,new_action)).mean() # alpla * log policy(at|st) - Q(st|at)
                    self.actor.optim.zero_grad()
                    policy_loss.backward()
                    self.actor.optim.step()

                    """alpha_loss = -(self.log_alpha * (log_pi + self.entropy_target).detach()).mean()
                    self.alpha_optim.zero_grad()
                    alpha_loss.backward()
                    self.alpha_optim.step()
                    self.alpha = round(self.log_alpha.exp().item(),1)"""

                    for q1_pars,q1_target_pars in zip(self.q1.parameters(),self.q1_target.parameters()):
                        q1_target_pars.data.copy_(
                            (hypers.tau * q1_pars) + (1.0 - hypers.tau) * q1_target_pars
                        )
                    for q2_pars,q2_target_pars in zip(self.q2.parameters(),self.q2_target.parameters()):
                        q2_target_pars.data.copy_(
                            (hypers.tau * q2_pars) + (1.0 - hypers.tau) * q2_target_pars
                        )

                    if traj != 0 and traj%int(1e5) == 0: 
                        n+=1
                        self.save(n)
                    
                    self.writter.add_scalar("Main/loss Policy",policy_loss,traj)
                    self.writter.add_scalar("Main/critic Loss",critic_loss,traj)
                    #self.writter.add_scalar("Main/entropy sac",self.log_alpha.item(),traj)
                    #self.writter.add_scalar("Main/entropy sac exp",self.alpha,traj)
                    self.writter.add_scalar("Main/action variance",actions.var(),traj)
                    self.writter.add_scalar("Main/policy loss action variance",new_action.var(),traj)
                    self.writter.add_scalar("Main/episodes rewards",self.collector.episode_reward,traj)
                
t = main().train(True)

 25%|██▌       | 1272539/5000000 [13:49:57<40:31:05, 25.55it/s]  


ValueError: Error: engine error: Could not allocate memory