In [None]:
#!pip install robosuite
#!pip install mujoco
#!pip install h5py
#!pip install gymnasium==1.2.0
...

In [None]:
import robosuite as suite
from robosuite.wrappers import GymWrapper
from gymnasium.vector import SyncVectorEnv
from gymnasium.wrappers import Autoreset
import torch,random,sys
from IPython.display import clear_output
from dataclasses import dataclass
from copy import deepcopy
clear_output()

In [None]:
@dataclass
class Hypers:
    """Hyper-parameters shared by the environment, networks and training loop.

    Every attribute carries a type annotation so that ``@dataclass`` actually
    registers it as a field: un-annotated class attributes are ignored by the
    decorator, which leaves the generated ``__init__``/``__repr__``/``__eq__``
    empty and makes per-run overrides such as ``Hypers(num_env=4)`` impossible.
    """

    num_env: int = 1            # number of environments in the vectorised env
    obs_dim: int = 214          # observation dim
    action_dim: int = 9         # action dim
    horizon: int = 1000
    batchsize: int = 2
    lr: float = 3e-4            # learning rate handed to the Adam optimisers
    gamma: float = .99
    # NOTE(review): tau is not read anywhere in the visible code; if it is meant
    # as a Polyak coefficient, 0 would freeze the target networks — confirm.
    tau: float = 0
    warmup: int = 12            # number of random-action steps before the policy acts
    # Resolved once, when the class body is executed.
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

hypers = Hypers()

# Keyword arguments forwarded to `suite.make` when an environment is built.
# Rendering and camera observations are all switched off.
# NOTE(review): horizon here is 500 while Hypers.horizon is 1000 — confirm
# which of the two is the intended episode length.
env_configs = dict(
    robots=["Panda"],
    gripper_types=["JacoThreeFingerDexterousGripper"],
    has_renderer=False,
    use_camera_obs=False,
    has_offscreen_renderer=False,
    horizon=500,
)

def vec_env():
    """Return a ``SyncVectorEnv`` holding ``hypers.num_env`` PickPlace environments.

    Each sub-environment is created from ``env_configs``, wrapped in
    ``GymWrapper`` using every key of its observation spec, given a
    ``metadata`` dict, and finally wrapped in ``Autoreset``.
    """
    def make_env():
        base_env = suite.make(env_name="PickPlace", **env_configs)
        obs_keys = list(base_env.observation_spec())
        wrapped = GymWrapper(base_env, keys=obs_keys)
        # Give the wrapper a dict-valued metadata attribute before vectorising.
        wrapped.metadata = {"render_mode": []}
        return Autoreset(wrapped)

    factories = [make_env] * hypers.num_env
    return SyncVectorEnv(factories)

In [None]:
from torch import Tensor
import torch.nn.functional as F
import torch.nn as nn
from torch.distributions import Normal

# Three-layer ReLU MLP trunk with 512 hidden units. The first layer is lazy, so
# its input width is fixed by the first tensor that is passed through it.
# This is a single module instance: every object that stores a reference to it
# uses (and trains) the same weights.
shared_net = nn.Sequential(
    nn.LazyLinear(512),
    nn.ReLU(),
    *(layer for _ in range(2) for layer in (nn.Linear(512, 512), nn.ReLU())),
)

class Actor(nn.Module):
    """Tanh-squashed Gaussian policy with its own MLP trunk and Adam optimiser.

    Args:
        obs_dim: width of the observation vector (defaults to ``hypers.obs_dim``).
        action_dim: width of the action vector (defaults to ``hypers.action_dim``).
        hidden_dim: width of the three hidden layers.

    The module places itself on ``hypers.device`` at construction time, because
    callers feed it tensors that live on that device.
    """

    def __init__(self, obs_dim=hypers.obs_dim, action_dim=hypers.action_dim, hidden_dim=512):
        super().__init__()
        # A private trunk with an explicit input width: re-using one global
        # module instance would tie the policy's weights (and this optimiser)
        # to every other network holding the same instance, and a lazy first
        # layer can only be sized for one input width.
        self.shared_network = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.lmean = nn.Linear(hidden_dim, action_dim)
        self.lstd = nn.Linear(hidden_dim, action_dim)
        # Move the parameters first so the optimiser is built on their final device.
        super().to(hypers.device)
        self.optim = torch.optim.Adam(self.parameters(), hypers.lr)

    def forward(self, obs: Tensor):
        """Sample an action for ``obs``.

        Returns:
            action: reparameterised sample squashed by tanh.
            log: log-probability of ``action`` summed over the last dim
                (kept as a trailing dim of size 1).
            eval_action: deterministic action, ``tanh(mean)``.
        """
        x = self.shared_network(obs)
        mean = self.lmean(x)
        # The head predicts log-std; clamp it before exponentiating.
        std = self.lstd(x).clamp(-20, 2).exp()
        dist = Normal(mean, std)
        pre_tanh = dist.rsample()
        action = torch.tanh(pre_tanh)

        log = dist.log_prob(pre_tanh).sum(-1, True)
        # change of variable correction for the tanh squashing
        log = log - torch.log(1 - action.pow(2) + 1e-9).sum(-1, True)

        eval_action = torch.tanh(mean)
        return action, log, eval_action

    def to(self, device=hypers.device, *args, **kwargs):
        """Move the module to ``device`` (default ``hypers.device``) and return it.

        Delegates to ``nn.Module.to``; calling ``self.to`` here would recurse
        forever.
        """
        return super().to(device, *args, **kwargs)

class Critic(nn.Module):
    """Q-network mapping an (observation, action) pair to a scalar value.

    Args:
        obs_dim: width of the observation vector (defaults to ``hypers.obs_dim``).
        action_dim: width of the action vector (defaults to ``hypers.action_dim``).
        hidden_dim: width of the three hidden layers.

    The module places itself on ``hypers.device`` at construction time, because
    callers feed it tensors that live on that device.
    """

    def __init__(self, obs_dim=hypers.obs_dim, action_dim=hypers.action_dim, hidden_dim=512):
        super().__init__()
        # Each critic owns its trunk. The input is the observation concatenated
        # with the action, so the first layer must be obs_dim + action_dim wide;
        # a trunk instance shared with a network that only sees observations
        # cannot satisfy both widths, and sharing it between two critics would
        # make them the same function up to the output layer.
        self.shared_net = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.output = nn.Linear(hidden_dim, 1)
        super().to(hypers.device)

    def forward(self, obs: Tensor, action: Tensor):
        """Return the value estimate for ``obs`` and ``action`` (last dim of size 1)."""
        cat = torch.cat((obs, action), dim=-1)
        return self.output(self.shared_net(cat))

    def to(self, device=hypers.device, *args, **kwargs):
        """Move the module to ``device`` (default ``hypers.device``) and return it.

        Delegates to ``nn.Module.to``; calling ``self.to`` here would recurse
        forever.
        """
        return super().to(device, *args, **kwargs)

In [None]:
class Collector:
    """Steps an environment and stores transitions for later random sampling.

    Each stored transition is ``[obs, action, reward, next_obs]``, all float32
    CPU tensors. Termination/truncation flags returned by the environment are
    not stored.

    Args:
        env: environment exposing ``reset()``, ``step(action)`` and
            ``action_space.sample()``.
        actor: callable returning ``(action, log_prob, eval_action)`` for an
            observation tensor.
        warmup: number of initial steps that use random actions; ``None``
            means "read ``hypers.warmup``".
        device: device the actor expects its input on; ``None`` means
            "read ``hypers.device``".
    """

    def __init__(self, env, actor, warmup=None, device=None):
        self.data = []
        self.env = env
        self.actor = actor
        self._warmup = warmup
        self._device = device
        # Always copy so stored tensors never alias buffers owned by the env.
        self.to_tensor = lambda x: torch.as_tensor(x).to(dtype=torch.float32, copy=True)
        self.stack = lambda x: torch.stack(x)
        self.warmup_step = 0

    @torch.no_grad()
    def rollout(self, batchsize):
        """Reset the environment, then collect ``batchsize`` transitions.

        Random actions are used until ``warmup_step`` reaches the warm-up
        limit; afterwards the actor's sampled action is used.
        """
        warmup_limit = hypers.warmup if self._warmup is None else self._warmup
        device = hypers.device if self._device is None else self._device

        obs, _ = self.env.reset()
        for _step in range(batchsize):
            if self.warmup_step < warmup_limit:
                env_action = self.env.action_space.sample()
                self.warmup_step += 1
            else:
                policy_input = self.to_tensor(obs).to(device)
                action, _log_prob, _eval_action = self.actor(policy_input)
                # The environment runs on the host and expects an array.
                env_action = action.detach().cpu().numpy()
            next_obs, reward, done, trunc, _info = self.env.step(env_action)
            # Store the action through the same conversion as everything else:
            # mixing NumPy arrays (warm-up) with tensors (policy) would make
            # torch.stack fail in `sample`.
            self.data.append([
                self.to_tensor(obs),
                self.to_tensor(env_action),
                self.to_tensor(reward),
                self.to_tensor(next_obs),
            ])
            obs = next_obs

    def sample(self, batchsize):
        """Return ``(states, actions, rewards, next_states)`` stacked along a new
        leading dim, drawn without replacement.

        Raises:
            ValueError: if ``batchsize`` exceeds the number of stored transitions.
        """
        batch = random.sample(self.data, batchsize)
        states, actions, rewards, nstates = zip(*batch)
        return (self.stack(states), self.stack(actions), self.stack(rewards), self.stack(nstates))

    def __len__(self):
        """Number of stored transitions."""
        return len(self.data)

In [None]:
class main:
    """Builds the actor, twin critics, target critics, optimiser, environment
    and collector, and drives the collect/sample loop."""

    def __init__(self):
        device = hypers.device
        # Use the configured sizes instead of repeating the literals.
        dummy_obs = torch.rand((1, hypers.obs_dim), dtype=torch.float32, device=device)
        dummy_action = torch.rand((1, hypers.action_dim), dtype=torch.float32, device=device)

        self.actor = Actor()
        self.q1 = Critic()
        self.q2 = Critic()
        # The dummy inputs (and the collector's observations) live on
        # hypers.device, so the networks must too. nn.Module.to is called
        # explicitly so placement does not depend on the subclasses' own
        # `to` overrides.
        for module in (self.actor, self.q1, self.q2):
            nn.Module.to(module, device)

        # Dry runs: materialise any lazily-shaped layers and check the shapes.
        with torch.no_grad():
            self.actor(dummy_obs)
            self.q1(dummy_obs, dummy_action)
            self.q2(dummy_obs, dummy_action)

        self.q1_target = deepcopy(self.q1)
        self.q2_target = deepcopy(self.q2)
        # Target networks are never handed to an optimiser; freeze them.
        for target in (self.q1_target, self.q2_target):
            for param in target.parameters():
                param.requires_grad_(False)

        self.critic_optim = torch.optim.Adam(
            list(self.q1.parameters()) + list(self.q2.parameters()), lr=hypers.lr
        )

        self.env = vec_env()
        self.collector = Collector(self.env, self.actor)

    def train(self, iterations=10, rollout_steps=100, updates_per_iteration=10, sample_size=20):
        """Alternate between collecting transitions and sampling batches.

        Args:
            iterations: number of collect/update rounds.
            rollout_steps: environment steps collected per round.
            updates_per_iteration: batches drawn per round once enough data exists.
            sample_size: number of transitions per batch.
        """
        for _iteration in range(iterations):
            self.collector.rollout(rollout_steps)
            if len(self.collector) < hypers.warmup:
                continue
            for _update in range(updates_per_iteration):
                # The fourth element returned by `sample` is the next states.
                states, actions, rewards, next_states = self.collector.sample(sample_size)
                # TODO: the critic/actor update that consumes this batch is not implemented yet.

        