In [None]:
# Tested with Python 3.10.
# Uncomment to install the pinned dependencies into this kernel's environment:
# %pip install gymnasium==1.1.1
# %pip install gymnasium-robotics==1.3.1

In [None]:
import gymnasium as gym
from gymnasium.spaces import Box,Dict
from gymnasium.vector import AsyncVectorEnv,SyncVectorEnv
import gymnasium_robotics
gym.register_envs(gymnasium_robotics)
from IPython.display import clear_output

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
from dataclasses import dataclass
clear_output()

In [None]:
@dataclass
class Hypers:
    """Hyper-parameters shared by the environment and network cells.

    Every attribute is annotated so that ``@dataclass`` turns it into a real
    field: it can be overridden via the constructor (``Hypers(lr=1e-3)``) and
    shows up in ``repr`` / ``dataclasses.fields``.
    """

    # Number of sub-environments created by ``sync_env``.
    num_env: int = 2
    # Adam learning rate used by both the actor and the critic optimisers.
    lr: float = 3e-4
    # Device used for the networks' dry-run inputs; CUDA when available.
    device: torch.device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )


hypers = Hypers()

In [None]:
class FetchReachCustom(gym.Wrapper):
    """Reduce FetchReach to a 3-D Cartesian reaching task.

    * Actions are 3-D commands in [-1, 1]. Any remaining entries of the wrapped
      env's action vector (for FetchReach: the gripper command) are padded with
      zeros before stepping.
    * ``observation`` is truncated to its first three entries; ``achieved_goal``
      and ``desired_goal`` are passed through unchanged.
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)
        self.action_space = Box(-1, 1, (3,), np.float32)
        self.observation_space = Dict(
            {
                "observation": Box(-np.inf, np.inf, (3,), np.float64),
                "achieved_goal": Box(-np.inf, np.inf, (3,), np.float64),
                "desired_goal": Box(-np.inf, np.inf, (3,), np.float64),
            }
        )

    def process_obs(self, observation):
        """Return a copy of ``observation`` with ``"observation"`` cut to 3 entries.

        The incoming dict is copied rather than mutated so references held by
        the wrapped env or other wrappers are left untouched.
        """
        processed = observation.copy()
        processed["observation"] = observation["observation"][:3]
        return processed

    def step(self, action):
        """Step the wrapped env with ``action`` zero-padded to its action size."""
        action = np.asarray(action, dtype=self.env.action_space.dtype)
        # gymnasium-robotics' Fetch envs take a 4-D action (xyz + gripper) and
        # reject other shapes, so pad the missing gripper command with zeros.
        n_missing = self.env.action_space.shape[-1] - action.shape[-1]
        if n_missing > 0:
            action = np.concatenate([action, np.zeros(n_missing, dtype=action.dtype)])
        observation, reward, terminated, truncated, info = self.env.step(action)
        return self.process_obs(observation), reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Reset the wrapped env and return the processed observation and info."""
        observation, info = self.env.reset(seed=seed, options=options)
        return self.process_obs(observation), info


def tranform_observation(observation_dict: dict) -> torch.Tensor:
    """Flatten a goal-conditioned observation dict into a float32 network input.

    Concatenates ``observation`` with ``desired_goal`` along the last axis, so a
    single ``(3,)`` observation yields a ``(6,)`` tensor and a batched
    ``(num_env, 3)`` observation yields ``(num_env, 6)``.

    Args:
        observation_dict: mapping with at least the ``"observation"`` and
            ``"desired_goal"`` arrays (e.g. what ``FetchReachCustom`` returns).

    Returns:
        A ``torch.float32`` tensor of the concatenated arrays.

    Raises:
        KeyError: if a required key is missing.
        AssertionError: if the two arrays have different shapes.
    """
    observation = np.asarray(observation_dict["observation"])
    # Use the *desired* goal: in FetchReach the achieved goal is the gripper
    # position, i.e. the same values kept as ``observation``, so pairing it with
    # the observation would hide the target from the networks.
    target = np.asarray(observation_dict["desired_goal"])
    assert observation.shape == target.shape, (
        f"shape mismatch: observation {observation.shape} vs goal {target.shape}"
    )
    output = np.concatenate((observation, target), axis=-1)
    return torch.from_numpy(output).to(torch.float32)

def sync_env(num_env: int | None = None):
    """Build a ``SyncVectorEnv`` of wrapped ``FetchReach-v3`` environments.

    Args:
        num_env: number of sub-environments; ``None`` uses ``hypers.num_env``.

    Returns:
        A ``SyncVectorEnv`` whose sub-environments are ``FetchReachCustom``
        wrappers around ``gym.make("FetchReach-v3")``.
    """
    n_envs = hypers.num_env if num_env is None else num_env

    def make_env():
        # Factory called once per sub-environment by SyncVectorEnv.
        return FetchReachCustom(gym.make("FetchReach-v3"))

    return SyncVectorEnv([make_env for _ in range(n_envs)])

In [48]:
class Actor(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.LazyLinear(256)
        self.l2 = nn.LazyLinear(256)
        self.l3 = nn.LazyLinear(256)
        self.l4 = nn.LazyLinear(256)
        self.output = nn.LazyLinear(3)
        self.policy_optim = torch.optim.Adam(self.parameters(),hypers.lr)
    
    def forward(self,obs):
        obs = F.relu(self.l1(obs))
        obs = F.relu(self.l2(obs))
        obs = F.relu(self.l3(obs))
        obs = F.relu(self.l4(obs))
        output = F.tanh(self.output(obs))
        return output

class Critic(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.LazyLinear(256)
        self.l2 = nn.LazyLinear(256)
        self.l3 = nn.LazyLinear(256)
        self.output = nn.LazyLinear(1)
        self.critic_optim = torch.optim.Adam(self.parameters(),hypers.lr)
    
    def forward(self,state,action):
        cat = torch.cat((state,action),-1)
        x = F.relu(self.l1(cat))
        x = F.relu(self.l2(x))
        x = F.relu(self.l3(x))
        x = self.output(x)
        return x

# Build the policy and value networks (Actor first, then Critic); their lazy
# layers are only materialised by the init_networks dry run below.
actor, critic = Actor(), Critic()

def init_networks(n1, n2, device=None, obs_dim: int = 6, act_dim: int = 3):
    """Place both networks on ``device`` and dry-run them to materialise lazy layers.

    Lazy parameters take the device of the module they live in, so the modules
    are moved *before* the dry run; otherwise CPU-built networks would receive
    inputs on another device (e.g. CUDA) and fail.

    Args:
        n1: actor-like module called as ``n1(state)``.
        n2: critic-like module called as ``n2(state, action)``.
        device: target device; ``None`` uses ``hypers.device``.
        obs_dim: width of the dummy state input (default 6).
        act_dim: width of the dummy action input (default 3).
    """
    if device is None:
        device = hypers.device
    n1.to(device)
    n2.to(device)
    # The dry run only infers shapes; no autograd graph is needed.
    with torch.no_grad():
        n1(torch.rand((1, obs_dim), dtype=torch.float32, device=device))
        n2(
            torch.rand((1, obs_dim), dtype=torch.float32, device=device),
            torch.rand((1, act_dim), dtype=torch.float32, device=device),
        )

# One dry run per network so every LazyLinear layer learns its input width.
init_networks(n1=actor, n2=critic)
