In [None]:
#%pip install robosuite
#%pip install mujoco
#%pip install h5py
#%pip install gymnasium==1.2.0
...

In [20]:
import robosuite as suite
from robosuite.wrappers import GymWrapper
from gymnasium.vector import SyncVectorEnv
from gymnasium.wrappers import Autoreset
import torch
from torch import Tensor
import torch.nn.functional as F
import torch.nn as nn
from IPython.display import clear_output
from dataclasses import dataclass
from copy import deepcopy
clear_output()

In [None]:
@dataclass
class Hypers:
    """Experiment hyperparameters.

    Every attribute carries a type annotation, so each one is a real dataclass
    field and can be overridden at construction time, e.g. ``Hypers(num_env=4)``.
    (Without annotations ``@dataclass`` generates no fields at all.)
    Defaults remain readable at class level, e.g. ``Hypers.lr``.
    """

    num_env: int = 1        # number of sub-environments built by vec_env()
    obs_dim: int = 214      # observation dim — NOTE(review): hard-coded; confirm against the env's observation space
    action_dim: int = 0     # placeholder — TODO presumably set from the env's action space
    horizon: int = 1000     # NOTE(review): env_configs passes horizon=500 to robosuite; confirm which is intended
    lr: float = 3e-4        # learning rate for Actor's Adam optimizer
    lambda_: float = 0
    tau: float = 0          # NOTE(review): if used as a soft-update coefficient, 0 means targets never move
    # torch.device is hashable, so it is a valid (immutable) dataclass default.
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


hypers = Hypers()

# Keyword arguments splatted into suite.make(...) inside vec_env().
# NOTE(review): "horizon" is 500 here while Hypers.horizon is 1000 — confirm which is intended.
env_configs = dict(
    robots=["Panda"],
    gripper_types=["JacoThreeFingerDexterousGripper"],
    has_renderer=False,
    use_camera_obs=False,
    has_offscreen_renderer=False,
    horizon=500,
)

def vec_env(num_envs=None, env_name="PickPlace"):
    """Build a SyncVectorEnv of robosuite tasks wrapped with GymWrapper.

    Args:
        num_envs: number of sub-environments; defaults to ``hypers.num_env``.
        env_name: robosuite task name passed to ``suite.make``; defaults to "PickPlace".

    Returns:
        A ``gymnasium.vector.SyncVectorEnv`` over ``num_envs`` independent copies.
    """
    if num_envs is None:
        num_envs = hypers.num_env

    def make_env():
        env = suite.make(env_name=env_name, **env_configs)
        # Hand every key of the robosuite observation spec to the Gym wrapper.
        env = GymWrapper(env, keys=list(env.observation_spec()))
        # Gymnasium expects env.metadata to be a dict; its standard key is the
        # plural "render_modes". NOTE(review): presumably GymWrapper's own
        # metadata is not a usable dict here — confirm.
        env.metadata = {"render_modes": []}
        # No Autoreset wrapper: SyncVectorEnv already resets finished
        # sub-environments itself, so wrapping each one again is redundant.
        return env

    # SyncVectorEnv takes factories (callables), not env instances.
    return SyncVectorEnv([make_env for _ in range(num_envs)])

In [19]:
HIDDEN_DIM = 512


def make_trunk(hidden_dim=HIDDEN_DIM):
    """Return a fresh MLP body: LazyLinear -> ReLU -> (Linear -> ReLU) x 2.

    The first layer is lazy, so its input width is fixed by the first forward
    call. That is why each network needs its OWN trunk: the actor sees obs-only
    inputs while the critic sees obs concatenated with the action.
    """
    return nn.Sequential(
        nn.LazyLinear(hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
    )


# Kept so cells that reference `shared_net` directly still find it. Actor and
# Critic no longer use it: a single LazyLinear cannot accept both the actor's
# and the critic's input widths.
shared_net = make_trunk()


class Actor(nn.Module):
    """Policy network: obs -> action squashed into [-1, 1] by tanh.

    Args:
        action_dim: size of the output action vector.
        lr: Adam learning rate; defaults to ``hypers.lr`` when None.
    """

    def __init__(self, action_dim, lr=None):
        super().__init__()
        self.shared_network = make_trunk()
        self.output = nn.Linear(HIDDEN_DIM, action_dim)
        # Adam needs an iterable of parameters, not the module itself. Lazy
        # parameters are materialized in place on the first forward, so the
        # optimizer keeps valid references.
        self.optim = torch.optim.Adam(self.parameters(), lr=hypers.lr if lr is None else lr)

    def forward(self, obs: Tensor):
        # Use this module's own trunk (not a global), so deepcopy'd target
        # networks evaluate with their own weights.
        features = self.shared_network(obs)
        return torch.tanh(self.output(features))


class Critic(nn.Module):
    """Scores an (obs, action) pair with a single output value."""

    def __init__(self):
        super().__init__()
        self.shared_net = make_trunk()
        self.output = nn.Linear(HIDDEN_DIM, 1)

    def forward(self, obs: Tensor, action: Tensor):
        # obs and action are joined along the last axis; leading dims must match.
        cat = torch.cat((obs, action), dim=-1)
        return self.output(self.shared_net(cat))

In [None]:
class collector:
    """Experience-collection scaffold (not yet implemented).

    ``data`` starts as an empty list owned by each instance; ``rollout`` and
    ``sample`` are stubs that currently return ``None``.
    """

    def __init__(self):
        self.data = list()

    @torch.no_grad()
    def rollout(self):
        """Stub; runs under ``torch.no_grad``. TODO implement."""
        return None

    def sample(self):
        """Stub. TODO implement."""
        return None