In [None]:
#!pip install robosuite
#!pip install mujoco
#!pip install h5py
#!pip install gymnasium==1.2.0
#!pip install tqdm
...

In [None]:
import robosuite as suite
from robosuite.wrappers import GymWrapper
from gymnasium.vector import SyncVectorEnv
try:
    from gymnasium.wrappers import Autoreset # failled to import during some testing on kaggle
except ImportError:
    from gymnasium.wrappers import AutoResetWrapper as Autoreset
import torch,random,sys
from IPython.display import clear_output
from dataclasses import dataclass
from copy import deepcopy
clear_output()

In [None]:
@dataclass(frozen=False)
class Hypers:
    """Hyper-parameters shared by the networks, the collector and the trainer.

    Every attribute carries a type annotation so that ``@dataclass`` registers
    it as a real field: values can be overridden per instance
    (``Hypers(lr=1e-3)``) and show up in ``repr``. The defaults are still
    readable on the class itself (``Hypers.gamma``).
    """
    # Device on which log_alpha and rollout observations are placed.
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_env: int = 2        # number of environments in the vectorized env
    obs_dim: int = 214      # input width of the networks (original note: "observation dim -1")
    action_dim: int = 9     # width of the actor output / critic action input
    horizon: int = 1000     # not referenced in the visible code (env_configs sets its own horizon)
    batchsize: int = 2      # not referenced in the visible code
    lr: float = 3e-4        # learning rate of all Adam optimizers
    gamma: float = .99      # discount factor used in the Bellman target
    tau: float = 0.1        # interpolation factor of the target-network soft update
    alpha: float = 0.4      # not referenced in the visible code (alpha is learned via log_alpha)
    warmup: int = 12        # number of rollout steps taken with random actions
    

# Single shared hyper-parameter instance read by every class below.
hypers = Hypers()

# Keyword arguments forwarded to `suite.make` when building each environment.
env_configs = dict(
    robots=["Panda"],
    gripper_types=["JacoThreeFingerDexterousGripper"],
    has_renderer=False,
    use_camera_obs=False,
    has_offscreen_renderer=False,
    reward_shaping=True,  # activate dense reward
    horizon=500,
)

def vec_env():
    """Build a synchronous vector env holding `hypers.num_env` PickPlace copies.

    Each copy is created from `env_configs`, wrapped with `GymWrapper` over all
    keys of its observation spec, and then wrapped with `Autoreset`.
    """
    def make_env():
        base_env = suite.make(env_name="PickPlace", **env_configs)
        obs_keys = list(base_env.observation_spec())
        wrapped = GymWrapper(base_env, keys=obs_keys)
        # Metadata is overwritten before applying the autoreset wrapper.
        wrapped.metadata = {"render_mode": []}
        return Autoreset(wrapped)

    factories = [make_env for _ in range(hypers.num_env)]
    return SyncVectorEnv(factories)

In [None]:
from torch import Tensor
import torch.nn.functional as F
import torch.nn as nn
from torch.distributions import Normal

class Actor(nn.Module):
    """Gaussian policy with tanh-squashed actions.

    Three ReLU hidden layers of width 512 feed two linear heads: one for the
    mean and one for the log standard deviation of a diagonal Normal. The
    module owns its own Adam optimizer (`self.optim`).
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(hypers.obs_dim,512)
        self.l2 = nn.Linear(512,512)
        self.l3 = nn.Linear(512,512)
        self.lmean = nn.Linear(512,hypers.action_dim)
        self.lstd = nn.Linear(512,hypers.action_dim)
        self.optim = torch.optim.Adam(self.parameters(),hypers.lr)

    def forward(self,obs:Tensor):
        """Sample an action for `obs`.

        Returns:
            action: tanh of a reparameterized sample, same leading dims as `obs`.
            log: log-probability of `action`, summed over the last dim (kept as size 1).
            eval_action: tanh of the mean (deterministic action).
        """
        x = F.relu(self.l1(obs))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        mean = self.lmean(x)
        # The head output is treated as a log-std: clamp it, then exponentiate.
        std = self.lstd(x).clamp(-20,2).exp()
        dist = Normal(mean,std)
        pre_tanh = dist.rsample()
        action = torch.tanh(pre_tanh)

        log = dist.log_prob(pre_tanh).sum(-1,True)
        # Change-of-variable correction for the tanh squashing; 1e-9 avoids log(0).
        log = log - torch.log(1-action.pow(2) + 1e-9).sum(-1,True)

        eval_action = torch.tanh(mean)
        return action,log,eval_action

    def to(self,device=None):
        """Move the module to `device` and return it.

        Args:
            device: target device; when omitted, `hypers.device` is used
                (looked up at call time).

        Delegates to `nn.Module.to`; calling `self.to` here would recurse
        without end.
        """
        if device is None:
            device = hypers.device
        return super().to(device)

class Critic(nn.Module):
    """Q-network: maps a concatenated (observation, action) pair to one value.

    Three ReLU hidden layers of width 512 followed by a linear output of size 1.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(hypers.obs_dim + hypers.action_dim,512)
        self.l2 = nn.Linear(512,512)
        self.l3 = nn.Linear(512,512)
        self.output = nn.Linear(512,1)

    def forward(self,obs:Tensor,action:Tensor):
        """Return the Q-value for `obs` and `action` (joined on the last dim).

        The result keeps the leading dims of the inputs and has a last dim of 1.
        """
        cat = torch.cat((obs,action),dim=-1)
        x = F.relu(self.l1(cat))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = self.output(x)
        return x

    def to(self,device=None):
        """Move the module to `device` and return it.

        Args:
            device: target device; when omitted, `hypers.device` is used
                (looked up at call time).

        Delegates to `nn.Module.to`; calling `self.to` here would recurse
        without end.
        """
        if device is None:
            device = hypers.device
        return super().to(device)

In [None]:
import numpy as np

class Collector:
    """Steps the environment and keeps the visited transitions for sampling.

    Each stored entry is a list ``[observation, action, reward, next_observation]``
    of float32 CPU tensors.
    """

    def __init__(self,env,actor):
        self.data = []          # list of stored transitions
        self.env = env
        self.actor = actor
        self.to_tensor = lambda x : torch.from_numpy(x).to(torch.float32)
        self.warmup_step = 0    # number of random-action steps taken so far

    @staticmethod
    def _as_cpu_float32(value):
        """Return `value` as a detached float32 tensor on the CPU.

        Accepts a torch Tensor (on any device) or anything `np.asarray` can
        convert, so every stored action ends up with the same dtype and device.
        """
        if isinstance(value,Tensor):
            return value.detach().to(device="cpu",dtype=torch.float32)
        return torch.from_numpy(np.asarray(value)).to(torch.float32)

    @torch.no_grad()
    def rollout(self,batchsize):
        """Reset the env, take `batchsize` steps and append each transition to `self.data`.

        The first `hypers.warmup` steps (counted across calls) use actions
        sampled from the env's action space; later steps use the actor's
        sampled action.
        """
        observation,_ = self.env.reset()
        for _ in range(batchsize):
            if self.warmup_step < hypers.warmup:
                action = self.env.action_space.sample()
                self.warmup_step+=1
            else :
                obs_tensor = torch.from_numpy(observation).to(device=hypers.device,dtype=torch.float32)
                action,_,_ = self.actor(obs_tensor)
            # NOTE(review): `done` / `trunc` are received but not stored, so
            # consumers cannot mask terminal transitions — confirm this is intended.
            nx_state,reward,done,trunc,info = self.env.step(action.tolist()) # nx_state : next state

            # Store the action on the CPU like the other entries, whatever
            # device the actor produced it on.
            saved_action = self._as_cpu_float32(action)

            self.data.append([self.to_tensor(observation),saved_action,self.to_tensor(reward),self.to_tensor(nx_state)])
            observation = nx_state

    def sample(self,batchsize):
        """Draw `batchsize` distinct stored transitions and stack them field-wise.

        Returns:
            Tuple ``(states, actions, rewards, next_states)`` of stacked tensors.

        Raises:
            ValueError: from `random.sample` when `batchsize` exceeds `len(self)`.
        """
        output = random.sample(self.data,batchsize)
        states,actions,rewards,nstates= zip(*output)
        return (
            torch.stack(states),
            torch.stack(actions),
            torch.stack(rewards),
            torch.stack(nstates)
        )

    def __len__(self):
        """Number of stored transitions."""
        return len(self.data)

In [None]:
from tqdm import tqdm
import warnings,logging
# Silence every Python warning and suppress all logging calls of severity
# CRITICAL and below, so only the tqdm progress bar is printed during training.
warnings.filterwarnings("ignore")
logging.disable(logging.CRITICAL)

class main:
    """Soft Actor-Critic style trainer.

    Owns the actor, two critics with their target copies, the learnable
    entropy temperature (`log_alpha`), the vectorized environment and the
    transition collector.
    """

    def __init__(self):
        self.actor = Actor()
        self.q1 = Critic()
        self.q2 = Critic()
        # Put the networks on the same device as log_alpha and the rollout
        # observations. Actor/Critic define their own `to`, so the base
        # implementation is called directly (it moves parameters in place).
        for net in (self.actor,self.q1,self.q2):
            torch.nn.Module.to(net,hypers.device)

        # Target critics start as exact copies of the online critics.
        self.q1_target = deepcopy(self.q1)
        self.q2_target = deepcopy(self.q2)

        self.critic_optim = torch.optim.Adam(
            list(self.q1.parameters()) + list(self.q2.parameters()),lr=hypers.lr
            )

        # As seen in the original paper, page 17: target entropy = -dim(A)
        # (e.g. -6 for HalfCheetah-v1).
        self.entropy_target = -hypers.action_dim
        self.log_alpha = torch.zeros(1,requires_grad=True,device=hypers.device)
        self.alpha_optim = torch.optim.Adam([self.log_alpha],lr=hypers.lr)

        self.env = vec_env()
        self.collector = Collector(self.env,self.actor)

    def save(self,path="./model.pth"):
        """Write the actor's state dict to `path` under the key "model state"."""
        check = {"model state" : self.actor.state_dict()}
        torch.save(check,path)

    @staticmethod
    def _soft_update(online,target,tau):
        """Blend `online` parameters into `target`: target = tau*online + (1-tau)*target."""
        with torch.no_grad():
            for online_p,target_p in zip(online.parameters(),target.parameters()):
                target_p.copy_((tau * online_p) + (1 - tau) * target_p)

    def train(self,start=False):
        """Collect 100 environment steps, then run 10 update iterations.

        Does nothing unless `start` is true. Each iteration samples 20 stored
        transitions and updates the critics, the actor, the temperature and
        the target critics, in that order.
        """
        if not start:
            return

        self.collector.rollout(100)

        for n in tqdm(range(10),total=10):
            batch = self.collector.sample(20)
            # The buffer holds CPU tensors; move them next to the networks.
            states,actions,reward,nx_states = (item.to(hypers.device) for item in batch)
            # alpha is a constant for the critic and actor losses; only
            # alpha_loss below trains log_alpha.
            alpha = self.log_alpha.exp().detach()

            q1 = self.q1(states,actions).squeeze(-1)
            q2 = self.q2(states,actions).squeeze(-1)
            with torch.no_grad():
                nx_actions,log_nx_actions,_ = self.actor(nx_states)
                q1_target = self.q1_target(nx_states,nx_actions)
                q2_target = self.q2_target(nx_states,nx_actions)
                min_q_target = torch.min(q1_target,q2_target).squeeze(-1)
                # Bellman backup: r(s,a) + gamma * (min Q_target(s',a') - alpha * log pi(a'|s'))
                # NOTE(review): no terminal mask is applied because the
                # collector does not return done flags.
                q_target = reward + hypers.gamma * (min_q_target - alpha * log_nx_actions.squeeze(-1))

            assert q1.shape == q2.shape == q_target.shape
            critic_loss = F.mse_loss(q1,q_target) + F.mse_loss(q2,q_target)
            self.critic_optim.zero_grad()
            critic_loss.backward()
            self.critic_optim.step()

            new_action,log_pi,_ = self.actor(states)
            # alpha * log pi(a|s) - Q(s,a)
            # NOTE(review): only q1 is used here rather than min(q1, q2) — confirm intended.
            policy_loss = ((alpha * log_pi) -  self.q1(states,new_action)).mean()
            self.actor.optim.zero_grad()
            policy_loss.backward()
            self.actor.optim.step()

            alpha_loss = -(self.log_alpha * (log_pi + self.entropy_target).detach()).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self._soft_update(self.q1,self.q1_target,hypers.tau)
            self._soft_update(self.q2,self.q2_target,hypers.tau)

            if n%10 == 0:
                self.save()

# Build the trainer (creates the networks, optimizers, env and collector) and
# run one collection + update pass; `train` does nothing unless passed True.
t = main()
t.train(True)