In [None]:
# written on python 3.10.11
!pip install gymnasium==1.1.1
!pip install gymnasium-robotics==1.3.1 

In [None]:
import gymnasium as gym
from gymnasium.spaces import Box,Dict
from gymnasium.wrappers import Autoreset
import gymnasium_robotics
gym.register_envs(gymnasium_robotics)
import numpy as np
import torch,sys
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from dataclasses import dataclass
from copy import deepcopy
from IPython.display import clear_output
clear_output()

In [None]:
@dataclass
class Hypers:
    device = torch.device(
        "cuda" if torch.cuda.is_available()else "cpu"
    )
    xpos = 0.3 # position of the base of the robot
    ypos = 0.5

    max_env_steps = 50
    num_episode = 101
    buffer_max_size = 1e6
    buffer_update_size = 1e4
    warmups = 25e3
    batchsize = 256
    
    lr = 3e-4
    gamma = .99
    tau = 0.005

hypers = Hypers()

In [None]:
class FetchReachCustom(gym.Wrapper):
    def __init__(self,env : gym.Env):
        super().__init__(env)
        self.action_space = Box(-1,1,(3,),np.float32)
        self.observation_space = Dict(
            {
            "observation" : Box(-np.inf,np.inf,(3,),np.float64),
            "achieved_goal" : Box(-np.inf,np.inf,(3,),np.float64),
            "desired_goal" : Box(-np.inf,np.inf,(3,),np.float64)
            }
        )
    
    def process_obs(self,observation):
        observation["observation"] = observation["observation"][:3]
        return observation
         
    def step(self, action):
        action = np.append(action,0)
        observation, reward, done,truncated, info = self.env.step(action)
        return  self.process_obs(observation), reward, done,truncated, info
    
    def reset(self,seed=None, options=None):
        observation,info = self.env.reset(seed=seed,options=options)
        self.env.unwrapped.data.qpos[0] = hypers.xpos
        self.env.unwrapped.data.qpos[1] = hypers.ypos
        return self.process_obs(observation),info

def tranform_observation(observation_dict : Dict): # -> torch.Size([6])
    #observation = observation_dict.get("observation")
    current_pos = observation_dict.get("achieved_goal")
    target = observation_dict.get("desired_goal")
    #assert observation.shape == target.shape, f"{observation.shape},{target.shape}"
    output = np.concatenate((current_pos,target),axis=-1)
    return torch.from_numpy(output).to(device=hypers.device,dtype=torch.float32)

def make_env():
    x = gym.make("FetchReachDense-v3",max_episode_steps=hypers.max_env_steps)
    x = FetchReachCustom(x)
    x = Autoreset(x)
    return x

In [5]:
class Actor(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.LazyLinear(256)
        self.l2 = nn.LazyLinear(256)
        self.l3 = nn.LazyLinear(256)
        self.output = nn.LazyLinear(3)
        self.optim = torch.optim.Adam(self.parameters(),hypers.lr)
    
    def forward(self,obs: Tensor):
        obs = F.relu(self.l1(obs))
        obs = F.relu(self.l2(obs))
        obs = F.relu(self.l3(obs))
        output = F.tanh(self.output(obs))
        return output

class Critic(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.LazyLinear(256)
        self.l2 = nn.LazyLinear(256)
        self.output = nn.LazyLinear(1)
    
    def forward(self,state: Tensor,action: Tensor):
        x = torch.cat((state,action),-1)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = self.output(x)
        return x

def init_weights(w):
    if isinstance(w,nn.Linear):
        torch.nn.init.orthogonal_(w.weight)
        torch.nn.init.constant_(w.bias,0.0)

rand_obs = lambda : torch.rand((1,6),dtype = torch.float32,device=hypers.device)
rand_action = lambda : torch.rand((1,3),dtype = torch.float32,device=hypers.device)

def init_networks(a: Actor,q1: Critic,q2: Critic):
    a(rand_obs())
    a.apply(init_weights)

    q1(rand_obs(),rand_action())
    q1.apply(init_weights)

    q2(rand_obs(),rand_action())
    q2.apply(init_weights)

actor = Actor()
critic1 = Critic()
critic2 = Critic()

init_networks(actor,critic1,critic2)

In [None]:
import random

class replay_buffer:
    def __init__(self,env,actor: Actor):
        self.env = env
        self.actor = actor
        self.data = []
        self.rew = []
        self.episode_reward = np.zeros(1, dtype=np.float32)
        self.counter = 0
        self.random_action_num = 0

    @torch.no_grad()
    def rollout(self,batchsize):
        obs,_ = self.env.reset()
        curr_state  = tranform_observation(obs)
        for _ in range(batchsize):
            self.counter+=1
            if self.counter <= hypers.warmups:
                self.random_action_num+=1
                action = self.env.action_space.sample()
            else:
                noise_rollout = torch.normal(0.0,0.1,size=(self.env.action_space.shape))
                action = (self.actor(curr_state) + noise_rollout).clamp(-1,1).numpy()
    
            next_state,reward,done,trunc,_ = self.env.step(action)
            next_state = tranform_observation(next_state)
            self.episode_reward+=reward
            if trunc:
                self.rew.append(self.episode_reward)
                self.episode_reward = 0
     
            self.data.append(
                (
                    curr_state,
                    torch.from_numpy(action),
                    torch.from_numpy(np.array(reward)).to(torch.float32),
                    next_state,
                    torch.tensor(trunc).to(torch.float32)
                )
            )
            curr_state = next_state  
            
    def sample(self,sample):
        output = random.sample(self.data,sample)
        curr_state,action,reward,next_state,trunc = zip(*output)
        Stack = lambda x : torch.stack(x)
        return (
            Stack(curr_state),
            Stack(action),
            Stack(reward).unsqueeze(-1),
            Stack(next_state),
            Stack(trunc).unsqueeze(-1)
        )

    def util(self):
        return self.rew,self.random_action_num
    
    def __len__(self):
        return len(self.data)

In [None]:
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

class Training:
    def __init__(self):
        self.actor = actor
        self.actor_target = deepcopy(self.actor)
        
        self.q1 = critic1
        self.q1_target = deepcopy(self.q1)
        self.q2 = critic2
        self.q2_target = deepcopy(self.q2)
        self.critic_optim = torch.optim.Adam(
            list(self.q1.parameters()) + list(self.q2.parameters()),lr=hypers.lr
        )
        self.env = make_env() 
        self.replay_buffer = replay_buffer(self.env,self.actor)
        self.writter = SummaryWriter("data/")
        self.total_it = 0

    def save(self,num):
        checkpoint = {
            "actor state" : self.actor.state_dict(),
            "actor optim": self.actor.optim.state_dict(),
            "actor target" : self.actor_target.state_dict(),

            "q1 state" : self.q1.state_dict(),
            "q2 state" : self.q2.state_dict(),
            "critic optim state" : self.critic_optim.state_dict(),
            "q1 target state" : self.q1_target.state_dict(),
            "q2 target state":self.q2_target.state_dict()  
        }
        torch.save(checkpoint,f"data/model_{num}.pth")
    
    def load(self,path):
        checkpoint = torch.load(path,map_location=hypers.device)

        self.actor.load_state_dict(checkpoint["actor state"],strict=True)
        self.actor.optim.load_state_dict(checkpoint["actor optim"])
        self.actor_target.load_state_dict(checkpoint["actor target"])
        
        self.critic_optim.load_state_dict(checkpoint["critic optim state"])
        self.q1.load_state_dict(checkpoint["q1 state"],strict=True)
        self.q2.load_state_dict(checkpoint["q2 state"],strict=True)
        self.q1_target.load_state_dict(checkpoint["q1 target state"])
        self.q2_target.load_state_dict(checkpoint["q2 target state"])
        
    def train(self):
        for traj in tqdm(range(int(hypers.num_episode)),total=hypers.num_episode):
            if not len(self.replay_buffer) == hypers.buffer_max_size:
                self.replay_buffer.rollout(int(hypers.buffer_update_size))

            if self.replay_buffer.random_action_num >= hypers.buffer_max_size : 
                assert len(self.replay_buffer) >= hypers.warmups
                
                for n in range(int(hypers.buffer_update_size)):
                    self.total_it +=1
                    s,a,r,ns,trunc = self.replay_buffer.sample(hypers.batchsize) # s : state, a : action, r : reward, ns : next state, trunc : truncated
                    with torch.no_grad():
                        noise_train = torch.normal(0.0,0.2,size=self.env.action_space.shape).clamp(-0.5,0.5)
                        na = (self.actor_target(ns) + noise_train).clamp(-1,1) # na : next action
                        q1_target = self.q1_target(ns,na)
                        q2_target = self.q2_target(ns,na)
                        assert r.shape == trunc.shape == q1_target.shape == q2_target.shape 
                        q_target = r + (1-trunc) * hypers.gamma * torch.min(q1_target,q2_target) 
                    
                    q1 = self.q1(s,a) 
                    q2 = self.q2(s,a) 
                    critic_loss = F.mse_loss(q1,q_target) + F.mse_loss(q2,q_target)
        
                    self.critic_optim.zero_grad()
                    critic_loss.backward()
                    self.critic_optim.step()

                    if self.total_it % 2 == 0: 
                        actor_loss = -self.q1(s,self.actor(s)).mean()
                        self.writter.add_scalar("main/Policy loss",actor_loss.mean(),traj)

                        self.actor.optim.zero_grad()
                        actor_loss.backward()
                        self.actor.optim.step()

                        # polyak averaging
                        for actor_param,actor_target_param in zip(self.actor.parameters(),self.actor_target.parameters()):
                            actor_target_param.data.copy_(
                                (hypers.tau * actor_param) + (1-hypers.tau) * actor_target_param
                            )
                        for q1_param,q1_target_param in zip(self.q1.parameters(),self.q1_target.parameters()):
                            q1_target_param.data.copy_(
                                hypers.tau*q1_param + (1-hypers.tau) * q1_target_param
                            )
                        for q2_param,q2_target_param in zip(self.q2.parameters(),self.q2_target.parameters()):
                            q2_target_param.data.copy_(
                                hypers.tau*q2_param + (1-hypers.tau) * q2_target_param
                            )
                    
                    self.writter.add_scalar("main/Action variance",a.var(),n)
                    self.writter.add_scalar("main/Next action variance",na.var(),n)
                    self.writter.add_scalar("main/Critic Loss",critic_loss.mean(),n)
                    self.writter.add_scalar("main/Q target",q_target.mean(),n)
                    self.writter.add_scalar("main/Episode reward",torch.tensor(self.replay_buffer.util()[0]).mean(),n)
                    self.writter.add_scalar("main/Warmup batch size",self.replay_buffer.util()[1],n)
                
            if traj!=0 and traj%10 == 0:
                self.save(traj)

        self.save(traj)

Training().train()

100%|██████████| 101/101 [5:17:08<00:00, 188.40s/it] 
