In [1]:
#python 3.10
#!pip install gymnasium==1.1.1
#!pip install gymnasium-robotics==1.3.1 

In [2]:
import gymnasium as gym
from gymnasium.spaces import Box,Dict
from gymnasium.vector import SyncVectorEnv
import gymnasium_robotics
gym.register_envs(gymnasium_robotics)

import numpy as np
import torch,sys
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from dataclasses import dataclass
from copy import deepcopy
from IPython.display import clear_output
clear_output()

In [None]:
@dataclass
class Hypers:
    """Hyper-parameters for the TD3 / FetchReach notebook.

    Every attribute is annotated, so each one is a real dataclass field and can be
    overridden at construction time, e.g. ``Hypers(lr=1e-3, num_env=4)``.
    """

    num_env: int = 2  # number of FetchReach copies inside the SyncVectorEnv
    lr: float = 3e-4  # Adam learning rate used by the actor and both critics
    device: torch.device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    # presumably passed as the matching arguments of Training.train — confirm at call site
    batchsize: int = 10
    minibatch: int = 5
    num_episode: int = 5

    policy_noise: float = 0.2  # std of the TD3 target-policy smoothing noise
    noise_clip: float = 0.5  # smoothing noise is clipped to [-noise_clip, noise_clip]
    # Discount factor of the TD target. 0.99 is the TD3 paper default; the previous
    # 0.4 gave an effective horizon of only ~2 steps.
    gamma: float = 0.99
    # Polyak coefficient: target <- (1 - tau) * target + tau * online. It must lie in
    # [0, 1]; the previous 1.5 extrapolated past the online weights instead of averaging.
    tau: float = 0.005

hypers = Hypers()

In [4]:
class FetchReachCustom(gym.Wrapper):
    """FetchReach wrapper exposing a 3-D action space and a 3-entry ``observation``.

    * ``step`` appends a constant 0 as the 4th action component before forwarding the
      action to the wrapped env (presumably the gripper command in FetchReach).
    * ``step`` and ``reset`` keep only the first 3 entries of ``obs["observation"]``.
    """

    def __init__(self, env: gym.Env):
        super().__init__(env)
        self.action_space = Box(low=-1, high=1, shape=(3,), dtype=np.float32)
        obs_keys = ("observation", "achieved_goal", "desired_goal")
        self.observation_space = Dict(
            {
                key: Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float64)
                for key in obs_keys
            }
        )

    def process_obs(self, observation):
        """Truncate ``observation["observation"]`` to its first 3 entries, in place, and return the dict."""
        trimmed = observation["observation"][:3]
        observation["observation"] = trimmed
        return observation

    def step(self, action):
        """Step the wrapped env with ``action`` padded by a trailing 0; the observation is trimmed."""
        padded = np.append(action, 0)
        obs, reward, terminated, truncated, info = self.env.step(padded)
        return self.process_obs(obs), reward, terminated, truncated, info

    def reset(self, seed=None, options=None):
        """Reset the wrapped env and return the trimmed observation plus ``info``."""
        obs, info = self.env.reset(seed=seed, options=options)
        return self.process_obs(obs), info

def tranform_observation(observation_dict: dict, device=None) -> Tensor:
    """Flatten an observation dict into the agent's float32 state tensor.

    Concatenates ``observation_dict["observation"]`` with ``observation_dict["desired_goal"]``
    along the last axis. With ``FetchReachCustom`` both are 3 wide, so a single
    observation gives shape ``(6,)`` and a vectorised one from ``n`` envs gives ``(n, 6)``.

    Args:
        observation_dict: mapping holding ``"observation"`` and ``"desired_goal"`` arrays
            of identical shape.
        device: device of the returned tensor; defaults to ``hypers.device``.

    Returns:
        float32 tensor on ``device``.

    Raises:
        KeyError: if either required key is missing.
    """
    if device is None:
        device = hypers.device
    observation = observation_dict["observation"]
    # The policy must see the goal it is asked to reach. achieved_goal only describes
    # where the gripper already is (in FetchReach it presumably duplicates
    # observation[:3] — verify), so it carries no information about the target.
    target = observation_dict["desired_goal"]
    assert observation.shape == target.shape, (
        f"observation {observation.shape} and desired_goal {target.shape} shapes differ"
    )
    output = np.concatenate((observation, target), axis=-1)
    return torch.from_numpy(output).to(device=device, dtype=torch.float32)

def sync_env():
    """Build a ``SyncVectorEnv`` holding ``hypers.num_env`` independent ``FetchReachCustom`` envs."""

    def build_fetch_reach():
        # Each call creates a fresh FetchReach-v3 instance wrapped with the custom spaces.
        return FetchReachCustom(gym.make("FetchReach-v3"))

    factories = [build_fetch_reach] * hypers.num_env
    return SyncVectorEnv(factories)

In [13]:
class Actor(nn.Module):
    """Deterministic TD3 policy mapping a state to an action in ``[-1, 1]^3``.

    Four 256-unit ReLU hidden layers are followed by a tanh output layer. Every layer is
    a ``LazyLinear``, so the input width is inferred on the first forward call. The
    module owns its Adam optimiser as ``self.optim``.

    Args:
        lr: Adam learning rate; defaults to ``hypers.lr``.
        device: device the parameters are created on; defaults to ``hypers.device``.
    """

    def __init__(self, lr=None, device=None):
        super().__init__()
        if lr is None:
            lr = hypers.lr
        if device is None:
            device = hypers.device
        # Create the (lazy) parameters directly on `device`. Otherwise they stay on CPU
        # while the dry-run inputs from rand_obs live on hypers.device, which fails on CUDA.
        self.l1 = nn.LazyLinear(256, device=device)
        self.l2 = nn.LazyLinear(256, device=device)
        self.l3 = nn.LazyLinear(256, device=device)
        self.l4 = nn.LazyLinear(256, device=device)
        self.output = nn.LazyLinear(3, device=device)
        self.optim = torch.optim.Adam(self.parameters(), lr)

    def forward(self, obs: Tensor):
        """Map states of shape ``(..., state_dim)`` to actions of shape ``(..., 3)`` in [-1, 1]."""
        for hidden in (self.l1, self.l2, self.l3, self.l4):
            obs = F.relu(hidden(obs))
        return torch.tanh(self.output(obs))

class Critic(nn.Module):
    """Q-network mapping a ``(state, action)`` pair to a scalar value estimate.

    State and action are concatenated on the last dimension, then passed through two
    256-unit ReLU layers and a 1-unit linear head. Every layer is a ``LazyLinear``. The
    module owns its Adam optimiser as ``self.optim``.

    Args:
        lr: Adam learning rate; defaults to ``hypers.lr``.
        device: device the parameters are created on; defaults to ``hypers.device``.
    """

    def __init__(self, lr=None, device=None):
        super().__init__()
        lr = hypers.lr if lr is None else lr
        device = hypers.device if device is None else device
        # Create parameters on `device` so the dry run in init_networks, whose inputs are
        # on hypers.device, does not hit a CPU/CUDA mismatch.
        self.l1 = nn.LazyLinear(256, device=device)
        self.l2 = nn.LazyLinear(256, device=device)
        self.output = nn.LazyLinear(1, device=device)
        self.optim = torch.optim.Adam(self.parameters(), lr)

    def forward(self, state: Tensor, action: Tensor):
        """Return Q(state, action) with shape ``(..., 1)``."""
        x = torch.cat((state, action), -1)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        return self.output(x)

def init_weights(w):
    """Give a ``nn.Linear`` orthogonal weights and a zero bias; leave any other module alone.

    Intended for ``module.apply(init_weights)``.
    """
    if not isinstance(w, nn.Linear):
        return
    nn.init.orthogonal_(w.weight)
    nn.init.zeros_(w.bias)

def rand_obs():
    """Random ``(1, 6)`` float32 state on ``hypers.device``, used to dry-run lazy layers."""
    return torch.rand(1, 6, dtype=torch.float32, device=hypers.device)


def rand_action():
    """Random ``(1, 3)`` float32 action on ``hypers.device``, used to dry-run lazy layers."""
    return torch.rand(1, 3, dtype=torch.float32, device=hypers.device)

def init_networks(a: Actor, q1: Critic, q2: Critic):
    """Materialise and re-initialise the lazy layers of the actor and both critics.

    Each network is called once on random inputs so that every ``LazyLinear`` learns its
    input width. Only then is ``init_weights`` applied (orthogonal weights, zero biases).
    The networks are processed in the order actor, ``q1``, ``q2``.
    """
    a(rand_obs())
    a.apply(init_weights)
    for critic in (q1, q2):
        critic(rand_obs(), rand_action())
        critic.apply(init_weights)

# Online networks used by `Training`: build them, then dry-run and
# re-initialise their lazy layers with init_networks.
actor, critic1, critic2 = Actor(), Critic(), Critic()
init_networks(actor, critic1, critic2)

In [None]:
class replay_buffer:
    """Short-lived transition store filled by fresh rollouts of ``actor`` in ``env``.

    ``rollout`` clears the store and collects new transitions. ``sample`` then hands
    them out in insertion order. Each stored entry is a single-environment tuple
    ``(state, action, reward, next_state, done)`` with shapes
    ``(state_dim,), (action_dim,), (1,), (state_dim,), (1,)``, so a sampled minibatch
    lines up with the critic's ``(batch, 1)`` output.

    NOTE(review): this is not a persistent replay buffer. TD3 normally samples uniformly
    from a large history of past transitions; here each rollout replaces the data.
    """

    def __init__(self, env: "SyncVectorEnv", actor: "Actor"):
        self.env = env
        self.actor = actor
        self.data = []
        self.pointer = 0

    @torch.no_grad()
    def rollout(self, batchsize, noise_std=0.1):
        """Clear the store and collect at least ``batchsize`` single-env transitions.

        The env is stepped with the actor's action plus fresh Gaussian noise
        ``N(0, noise_std)`` for every step and env. The noisy action is clipped to
        the action bounds [-1, 1]. The executed action is exactly the one stored.

        Args:
            batchsize: minimum number of transitions to collect.
            noise_std: std of the exploration noise (0.1, as before).
        """
        self.clear()
        obs, _ = self.env.reset()
        curr_state = tranform_observation(obs)
        num_envs = curr_state.shape[0]
        # NOTE(review): assumes gymnasium's default NEXT_STEP autoreset. The step() after a
        # sub-env finishes only resets it and ignores its action, so that env's
        # transition on that step is not real and is skipped.
        awaiting_reset = np.zeros(num_envs, dtype=bool)
        while len(self.data) < batchsize:
            mean_action = self.actor(curr_state)
            action = (mean_action + noise_std * torch.randn_like(mean_action)).clamp(-1.0, 1.0)
            next_obs, reward, terminated, truncated, _ = self.env.step(action.cpu().numpy())
            next_state = tranform_observation(next_obs)
            device = curr_state.device
            reward_t = torch.as_tensor(reward, dtype=torch.float32, device=device).reshape(-1, 1)
            # Only true termination stops bootstrapping; truncation must still bootstrap.
            done_t = torch.as_tensor(terminated, dtype=torch.float32, device=device).reshape(-1, 1)
            for env_idx in range(num_envs):
                if awaiting_reset[env_idx]:
                    continue
                self.data.append(
                    (
                        curr_state[env_idx],
                        action[env_idx],
                        reward_t[env_idx],
                        next_state[env_idx],
                        done_t[env_idx],
                    )
                )
            awaiting_reset = np.logical_or(terminated, truncated)
            curr_state = next_state

    def sample(self, sample):
        """Return the next ``sample`` stored transitions, stacked along a new dim 0.

        Transitions are handed out in insertion order and the read pointer advances by
        ``sample``. Near the end of the store, fewer than ``sample`` may be returned.

        Returns:
            ``(curr_state, action, reward, next_state, dones)`` as stacked tensors.

        Raises:
            ValueError: if no transitions remain past the read pointer.
        """
        batch = self.data[self.pointer:self.pointer + sample]
        if not batch:
            raise ValueError(
                f"replay_buffer.sample: no transitions left "
                f"(pointer={self.pointer}, stored={len(self.data)})"
            )
        self.pointer += sample
        return tuple(torch.stack(column) for column in zip(*batch))

    def clear(self):
        """Drop all stored transitions and rewind the read pointer."""
        self.data = []
        self.pointer = 0

In [None]:
class Training:
    """TD3 trainer for the module-level ``actor``, ``critic1`` and ``critic2``.

    Owns target copies of all three networks, a vectorised env from ``sync_env()``
    and a ``replay_buffer`` that is refilled before every episode of updates.
    """

    def __init__(self):
        self.actor = actor
        self.actor_target = deepcopy(self.actor)

        self.q1 = critic1
        self.q1_target = deepcopy(self.q1)

        self.q2 = critic2
        self.q2_target = deepcopy(self.q2)

        self.env = sync_env()
        self.replay_buffer = replay_buffer(self.env, self.actor)

    def save(self, path="./td3.pth"):
        """Write online, target and optimiser states of all networks to ``path``."""
        checkpoint = {
            "actor state": self.actor.state_dict(),
            "actor optim": self.actor.optim.state_dict(),
            "actor target": self.actor_target.state_dict(),

            "q1 state": self.q1.state_dict(),
            "q1 optim": self.q1.optim.state_dict(),
            "q1 target state": self.q1_target.state_dict(),

            "q2 state": self.q2.state_dict(),
            "q2 optim": self.q2.optim.state_dict(),
            "q2 target state": self.q2_target.state_dict(),
        }
        torch.save(checkpoint, path)

    def load(self, path):
        """Restore every state written by ``save``, mapping tensors onto ``hypers.device``."""
        checkpoint = torch.load(path, map_location=hypers.device)

        self.actor.load_state_dict(checkpoint["actor state"], strict=True)
        self.actor.optim.load_state_dict(checkpoint["actor optim"])
        self.actor_target.load_state_dict(checkpoint["actor target"])

        self.q1.load_state_dict(checkpoint["q1 state"], strict=True)
        self.q1.optim.load_state_dict(checkpoint["q1 optim"])
        self.q1_target.load_state_dict(checkpoint["q1 target state"])

        self.q2.load_state_dict(checkpoint["q2 state"], strict=True)
        self.q2.optim.load_state_dict(checkpoint["q2 optim"])
        self.q2_target.load_state_dict(checkpoint["q2 target state"])

    def train(self, num_episode, batchsize, minibatch, policy_delay=2):
        """Run ``num_episode`` rounds of "collect ``batchsize`` transitions, then learn from them".

        Each round performs ``batchsize // minibatch`` critic updates. Every
        ``policy_delay``-th critic update, counted across the whole call, is followed by
        one actor update and a Polyak update of all three target networks. This is
        TD3's delayed policy update.
        """
        update_step = 0
        for _ in range(num_episode):
            self.replay_buffer.rollout(batchsize)
            for _ in range(batchsize // minibatch):
                curr_state, action, reward, next_state, done = self.replay_buffer.sample(minibatch)
                self._update_critics(curr_state, action, reward, next_state, done)
                if update_step % policy_delay == 0:
                    self._update_actor(curr_state)
                    self._update_targets(hypers.tau)
                update_step += 1

    def _update_critics(self, curr_state, action, reward, next_state, done):
        """One clipped double-Q regression step on both critics; returns the detached loss."""
        with torch.no_grad():
            next_mean = self.actor_target(next_state)
            # Target policy smoothing: clipped Gaussian noise with the action's shape/device.
            smoothing = (torch.randn_like(next_mean) * hypers.policy_noise).clamp(
                -hypers.noise_clip, hypers.noise_clip
            )
            next_action = (next_mean + smoothing).clamp(-1.0, 1.0)
            next_q = torch.min(
                self.q1_target(next_state, next_action),
                self.q2_target(next_state, next_action),
            )
            # Give reward/done the critic's (..., 1) layout so the sum below cannot broadcast.
            reward = reward.to(next_q).reshape(next_q.shape)
            done = done.to(next_q).reshape(next_q.shape)
            q_target = reward + (1.0 - done) * hypers.gamma * next_q

        action = action.to(device=curr_state.device, dtype=curr_state.dtype)
        q1 = self.q1(curr_state, action)
        q2 = self.q2(curr_state, action)
        critic_loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)

        self.q1.optim.zero_grad()
        self.q2.optim.zero_grad()
        critic_loss.backward()
        self.q1.optim.step()
        self.q2.optim.step()
        return critic_loss.detach()

    def _update_actor(self, curr_state):
        """Deterministic policy-gradient step that maximises Q1(s, actor(s)); returns the detached loss."""
        actor_loss = -self.q1(curr_state, self.actor(curr_state)).mean()
        self.actor.optim.zero_grad()
        # This also deposits gradients on q1's parameters. They are cleared by
        # q1.optim.zero_grad() before the next critic step.
        actor_loss.backward()
        self.actor.optim.step()
        return actor_loss.detach()

    def _update_targets(self, tau):
        """Polyak-average each target network toward its own online network."""
        pairs = (
            (self.actor, self.actor_target),
            (self.q1, self.q1_target),
            (self.q2, self.q2_target),
        )
        with torch.no_grad():
            for online, target in pairs:
                for param, target_param in zip(online.parameters(), target.parameters()):
                    # target <- (1 - tau) * target + tau * online
                    target_param.lerp_(param, tau)
