In [1]:
# Written for Python 3.10. Uncomment the lines below to install the pinned dependencies:
# %pip install gymnasium==1.1.1
# %pip install gymnasium-robotics==1.3.1

In [2]:
import gymnasium as gym
from gymnasium.spaces import Box,Dict
from gymnasium.vector import SyncVectorEnv
import gymnasium_robotics
gym.register_envs(gymnasium_robotics)

import numpy as np
import torch,sys
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from dataclasses import dataclass
from copy import deepcopy
from IPython.display import clear_output
clear_output()

In [3]:
@dataclass
class Hypers:
    """Run-wide hyperparameters shared by the environments, networks and optimizers.

    Every attribute carries a type annotation so that ``dataclass`` registers it
    as a real field. Un-annotated class attributes are ignored by ``dataclass``:
    they cannot be overridden through the constructor and do not take part in
    ``repr`` or ``==``.

    Attributes:
        num_env: Number of environment copies created by ``sync_env``.
        lr: Learning rate handed to the Adam optimizers.
        device: Device on which tensors are created; CUDA when available, else CPU.
    """
    num_env: int = 2
    lr: float = 3e-4
    device: torch.device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )

# Single shared configuration instance read by the rest of the notebook.
hypers = Hypers()

In [4]:
class FetchReachCustom(gym.Wrapper):
    """Wrapper exposing reduced action and observation spaces for a Fetch reach env.

    * Actions are 3-vectors in ``[-1, 1]``; a trailing ``0`` is appended before the
      action reaches the wrapped environment (presumably the unused gripper
      command -- confirm against the Fetch environment documentation).
    * The ``"observation"`` entry of each observation dict is cut down to its
      first three components; the goal entries are passed through untouched.
    """

    def __init__(self,env : gym.Env):
        super().__init__(env)
        self.action_space = Box(-1,1,(3,),np.float32)
        # All three dict entries are unbounded float64 3-vectors.
        self.observation_space = Dict(
            {
                key: Box(-np.inf,np.inf,(3,),np.float64)
                for key in ("observation","achieved_goal","desired_goal")
            }
        )

    def process_obs(self,observation):
        """Trim ``observation["observation"]`` to its first 3 entries (in place) and return the dict."""
        trimmed = observation["observation"][:3]
        observation["observation"] = trimmed
        return observation

    def step(self, action):
        """Pad the 3-D action with a trailing 0, step the wrapped env and trim the observation."""
        padded_action = np.append(action,0)
        obs, reward, done, truncated, info = self.env.step(padded_action)
        return self.process_obs(obs), reward, done, truncated, info

    def reset(self,seed=None, options=None):
        """Reset the wrapped env and return the trimmed observation together with ``info``."""
        obs, info = self.env.reset(seed=seed,options=options)
        return self.process_obs(obs), info

def tranform_observation(observation_dict: dict, device=None):  # -> Tensor of shape (..., 6)
    """Flatten a goal-conditioned observation dict into a single float32 tensor.

    The result is ``observation`` concatenated with ``desired_goal`` along the
    last axis. The goal entry must be ``desired_goal``: per the
    Gymnasium-Robotics Fetch documentation, ``achieved_goal`` in the reach task
    is the current end-effector position, i.e. the same quantity already held in
    the trimmed ``observation`` entry, so using it would leave the networks with
    no information about where the target is.

    Args:
        observation_dict: Mapping with NumPy arrays under ``"observation"`` and
            ``"desired_goal"``; both must have the same shape.
        device: Optional device for the returned tensor. Defaults to
            ``hypers.device`` when omitted.

    Returns:
        ``torch.float32`` tensor on ``device`` whose last dimension is the sum
        of the two inputs' last dimensions.

    Raises:
        KeyError: If either required key is missing.
        AssertionError: If the two arrays differ in shape.
    """
    observation = observation_dict["observation"]
    target = observation_dict["desired_goal"]
    assert observation.shape == target.shape, (
        f"observation {observation.shape} and desired_goal {target.shape} must match"
    )
    if device is None:
        device = hypers.device
    output = np.concatenate((observation, target), axis=-1)
    return torch.from_numpy(output).to(device=device, dtype=torch.float32)

def sync_env():
    """Create a ``SyncVectorEnv`` of ``hypers.num_env`` wrapped ``FetchReach-v3`` environments."""
    def build_single_env():
        # Factory invoked by SyncVectorEnv once per sub-environment.
        return FetchReachCustom(gym.make("FetchReach-v3"))

    factories = [build_single_env] * hypers.num_env
    return SyncVectorEnv(factories)

In [5]:
class Actor(nn.Module):
    """Deterministic policy: four 256-unit ReLU layers and a tanh-squashed 3-unit head.

    All layers are ``LazyLinear``, so input widths are inferred on the first
    forward pass. The Adam optimizer for the network is kept on the instance as
    ``policy_optim``.
    """

    def __init__(self):
        super().__init__()
        # Hidden layers l1..l4 are registered in order, then the output head.
        for layer_name in ("l1", "l2", "l3", "l4"):
            setattr(self, layer_name, nn.LazyLinear(256))
        self.output = nn.LazyLinear(3)
        self.policy_optim = torch.optim.Adam(self.parameters(), lr=hypers.lr)

    def forward(self,obs: Tensor):
        """Map an observation batch to actions with every component in [-1, 1]."""
        hidden = obs
        for layer in (self.l1, self.l2, self.l3, self.l4):
            hidden = F.relu(layer(hidden))
        return torch.tanh(self.output(hidden))

class Critic(nn.Module):
    """Q-network: scores a (state, action) pair with two 256-unit ReLU layers and a scalar head.

    Layers are ``LazyLinear``, so the input width is inferred on the first
    forward pass. The Adam optimizer for the network is kept on the instance as
    ``critic_optim``.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.LazyLinear(256)
        self.l2 = nn.LazyLinear(256)
        self.output = nn.LazyLinear(1)
        self.critic_optim = torch.optim.Adam(self.parameters(), lr=hypers.lr)

    def forward(self,state: Tensor,action: Tensor):
        """Return the value estimate (last dimension of size 1) for each state/action row."""
        # State and action are joined along the feature axis before the MLP.
        joint = torch.cat([state, action], dim=-1)
        hidden = F.relu(self.l2(F.relu(self.l1(joint))))
        return self.output(hidden)

def init_weights(w):
    """Orthogonally initialise an ``nn.Linear`` layer and zero its bias.

    Intended for ``Module.apply``: modules that are not ``nn.Linear`` are left
    untouched. Layers built with ``bias=False`` only get their weight
    initialised instead of crashing on the missing bias.

    Args:
        w: Any module; only ``nn.Linear`` instances (including subclasses) are modified.
    """
    if not isinstance(w, nn.Linear):
        return
    torch.nn.init.orthogonal_(w.weight)
    if w.bias is not None:
        torch.nn.init.constant_(w.bias, 0.0)

def rand_obs():
    """Return a uniform-random float32 tensor of shape (1, 6) on ``hypers.device``."""
    return torch.rand((1, 6), dtype=torch.float32, device=hypers.device)


def rand_action():
    """Return a uniform-random float32 tensor of shape (1, 3) on ``hypers.device``."""
    return torch.rand((1, 3), dtype=torch.float32, device=hypers.device)

def init_networks(a: Actor,q1: Critic,q2: Critic):
    """Materialise the lazy layers of the actor and both critics, then re-initialise them.

    Each network is first called once on random dummy input so that its
    ``LazyLinear`` layers infer their input widths; ``init_weights`` is then
    applied to every sub-module.
    """
    # Actor takes only an observation.
    a(rand_obs())
    a.apply(init_weights)

    # Both critics take an (observation, action) pair.
    for critic in (q1, q2):
        critic(rand_obs(), rand_action())
        critic.apply(init_weights)

# Online networks. They are moved to hypers.device *before* the dry-run in
# init_networks: the dummy inputs are created on hypers.device, so networks left
# on the CPU would fail with a device mismatch whenever CUDA is available.
actor = Actor().to(hypers.device)
critic1 = Critic().to(hypers.device)
critic2 = Critic().to(hypers.device)

# Materialise the lazy layers and apply the orthogonal initialisation.
init_networks(actor,critic1,critic2)

# Target networks start as exact copies of the initialised online networks.
actor_target = deepcopy(actor)
critic1_target = deepcopy(critic1)
critic2_target = deepcopy(critic2)

In [6]:
class replay_buffer:
    """Transition store filled by rolling out the actor in a vectorised environment.

    Each stored item is a ``(state, action, reward, next_state)`` tuple of
    tensors, one row per sub-environment. ``sample`` hands the items back in
    insertion order, advancing an internal read pointer.
    """

    def __init__(self,env: "SyncVectorEnv",actor: "Actor"):
        self.env = env
        self.actor = actor
        self.data = []      # collected transitions, in rollout order
        self.pointer = 0    # index of the next transition `sample` will return

    @torch.no_grad()
    def rollout(self,batchsize,noise_std=0.1):
        """Discard old data and collect ``batchsize`` fresh environment steps.

        The action sent to the environment is the actor output plus Gaussian
        exploration noise, clipped to ``[-1, 1]``; that same action is what gets
        stored, so the buffer always describes what was actually executed.

        Args:
            batchsize: Number of vectorised environment steps to record.
            noise_std: Standard deviation of the exploration noise, drawn
                independently at every step.
        """
        self.clear()
        obs,_ = self.env.reset()
        curr_state = tranform_observation(obs)
        for _ in range(batchsize):
            mean_action = self.actor(curr_state)
            # Noise takes its shape, dtype and device from the actor output, so it
            # works for any number of sub-environments and on CPU or GPU.
            noise = noise_std * torch.randn_like(mean_action)
            # Keep the action inside the wrapper's Box(-1, 1) action space.
            action = (mean_action + noise).clamp(-1.0, 1.0)
            next_obs,reward,_,_,_ = self.env.step(action.cpu().numpy())
            next_state = tranform_observation(next_obs)
            self.data.append(
                (
                    curr_state,
                    action,
                    # Stored on the same device as the states so batches can be
                    # used together without extra transfers.
                    torch.as_tensor(reward, dtype=torch.float32, device=curr_state.device),
                    next_state,
                )
            )
            curr_state = next_state

    def sample(self,sample):
        """Return the next ``sample`` transitions as four stacked tensors.

        Args:
            sample: Number of consecutive transitions to read from the pointer.

        Returns:
            ``(states, actions, rewards, next_states)``, each stacked along a new
            leading dimension. Fewer than ``sample`` items are returned when the
            buffer runs short.

        Raises:
            ValueError: If no transitions are left to read.
        """
        start = self.pointer
        output = self.data[start:start + sample]
        if not output:
            raise ValueError(
                f"replay buffer exhausted: pointer={start}, stored={len(self.data)}; "
                "call rollout() to collect new data"
            )
        self.pointer += sample
        curr_state,action,reward,next_state = zip(*output)
        return (
            torch.stack(curr_state),
            torch.stack(action),
            torch.stack(reward),
            torch.stack(next_state),
        )

    def clear(self):
        """Drop all stored transitions and rewind the read pointer."""
        self.data = []
        self.pointer = 0