In [None]:
#!pip install gymnasium-robotics==1.3.1 
#!pip install gymnasium

In [None]:
import gymnasium as gym
import gymnasium_robotics
gym.register_envs(gymnasium_robotics)
import numpy as np
import torch,sys
from dataclasses import dataclass

class custom(gym.Wrapper):
    """Wrapper that overrides part of the initial MuJoCo state after every reset.

    After the wrapped env resets, the robot base slide joints and one block
    coordinate are overwritten in ``data.qpos``. Forward kinematics is then
    re-run and a fresh observation is built, so the observation returned by
    ``reset`` describes the modified state rather than the pre-override one.

    Args:
        env: environment to wrap. Assumed to be a gymnasium-robotics Fetch env
            whose ``unwrapped`` exposes ``data``, ``model``, ``_mujoco`` and
            ``_get_obs`` — TODO confirm for other robot envs.
        base_xy: values written to qpos[0] / qpos[1] (robot base x / y).
        block_z: value written to qpos[17]. Presumably the block's z, since the
            block's x / y live at qpos[15] / qpos[16] — TODO confirm against the
            model XML.
    """

    def __init__(self, env, base_xy=(.3, .5), block_z=.4):
        super().__init__(env)
        self.base_xy = base_xy
        self.block_z = block_z

    def reset(self, **kwargs):
        # The observation from the inner reset predates the qpos overrides
        # below, so it is discarded and rebuilt afterwards.
        _stale_obs, info = super().reset(**kwargs)
        sim = self.env.unwrapped
        sim.data.qpos[0] = self.base_xy[0]  # robot base x pos
        sim.data.qpos[1] = self.base_xy[1]  # robot base y pos
        # sim.data.qpos[15]   # block's x pos
        # sim.data.qpos[16]   # block's y pos
        sim.data.qpos[17] = self.block_z
        # Writing qpos does not update body/site positions; run forward
        # kinematics so the observation (incl. achieved_goal) reflects it.
        sim._mujoco.mj_forward(sim.model, sim.data)
        # NOTE(review): the obs is built from the base env directly, so any
        # observation wrappers between this wrapper and the base env are
        # bypassed — presumably none transform observations here; verify.
        # NOTE(review): the gripper mocap target is not moved together with
        # the base; confirm the arm does not jump on the first step.
        return sim._get_obs(), info

def make_env(env_id="FetchPickAndPlace-v3", max_episode_steps=100):
    """Create the training environment.

    Args:
        env_id: registered Gymnasium id to instantiate.
        max_episode_steps: time limit passed to ``gym.make``.

    Returns:
        The environment wrapped in ``custom`` (which overrides initial positions).
    """
    base_env = gym.make(env_id, max_episode_steps=max_episode_steps)
    return custom(base_env)

@dataclass
class Hypers:
    """Hyperparameters shared by the actor, critics, replay buffer and trainer.

    Every entry carries a type annotation so ``@dataclass`` turns it into a
    real field; without annotations the dataclass would have no fields and
    overrides such as ``Hypers(lr=1e-3)`` would raise ``TypeError``.
    """
    lr: float = 3e-4       # Adam learning rate for actor and critics
    action_dim: int = 4    # size of the action vector
    obs_dim: int = 15      # length of the flat vector built by hindsight_buffer.process_obs
    alpha: float = 0.2     # fixed entropy temperature
    warmup: int = 2_000    # stored transitions gathered with random actions first
    gamma: float = 0.99    # discount factor
    tau: float = 5e-3      # Polyak coefficient for the target critics

hypers = Hypers()

In [None]:
from torch import Tensor
import torch.nn as nn
from torch.distributions import Normal
from torch.optim import Adam
import torch.nn.functional as F

class policy(nn.Module):
    """Squashed-Gaussian SAC actor.

    Maps a flat observation to a tanh-squashed, reparameterized action sample
    plus its log-probability under the squashed distribution.

    Args:
        obs_dim, action_dim, lr: default to ``hypers.obs_dim``,
            ``hypers.action_dim`` and ``hypers.lr`` when left as ``None``.
        hidden_dim: width of the two hidden layers.
    """

    def __init__(self, obs_dim=None, action_dim=None, hidden_dim=128, lr=None):
        super().__init__()
        obs_dim = hypers.obs_dim if obs_dim is None else obs_dim
        action_dim = hypers.action_dim if action_dim is None else action_dim
        lr = hypers.lr if lr is None else lr
        self.l1 = nn.Linear(obs_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.mean = nn.Linear(hidden_dim, action_dim)
        self.std = nn.Linear(hidden_dim, action_dim)  # predicts log-std
        self.optim = Adam(self.parameters(), lr=lr)

    def forward(self, obs: Tensor):
        """Sample an action for ``obs``.

        Returns:
            (action, log_prob, mean): ``action = tanh(u)`` with ``u`` drawn via
            ``rsample`` (so gradients flow through it); ``log_prob`` is summed
            over the last dim with ``keepdim=True``; ``mean`` is the pre-tanh
            Gaussian mean (apply ``tanh`` for a deterministic action).
        """
        x = F.relu(self.l1(obs))
        x = F.relu(self.l2(x))
        mean = self.mean(x)
        std = self.std(x).clamp(-20, 2).exp()
        dist = Normal(mean, std)
        pretanh = dist.rsample()
        action = torch.tanh(pretanh)

        # Change of variables for tanh: log|d tanh(u)/du| = log(1 - tanh(u)^2),
        # evaluated exactly as 2*(log 2 - u - softplus(-2u)). Unlike
        # log(1 - tanh(u)^2 + eps), this stays correct when tanh saturates.
        log_det = 2.0 * (np.log(2.0) - pretanh - F.softplus(-2.0 * pretanh))
        log = (dist.log_prob(pretanh) - log_det).sum(-1, keepdim=True)
        return action, log, mean
        
class q_network(nn.Module):
    """State-action critic returning Q(s, a) with a trailing dim of size 1.

    The output layer is linear: value estimates must be able to go negative
    (rewards and the ``-alpha * log_prob`` entropy terms in the SAC target are
    not sign-constrained). A ReLU head would clamp Q to >= 0 and zero the
    gradient whenever the true value is negative.

    Args:
        obs_dim, action_dim, lr: default to ``hypers.obs_dim``,
            ``hypers.action_dim`` and ``hypers.lr`` when left as ``None``.
        hidden_dim: width of the two hidden layers.
    """

    def __init__(self, obs_dim=None, action_dim=None, hidden_dim=128, lr=None):
        super().__init__()
        obs_dim = hypers.obs_dim if obs_dim is None else obs_dim
        action_dim = hypers.action_dim if action_dim is None else action_dim
        lr = hypers.lr if lr is None else lr
        self.l1 = nn.Linear(obs_dim + action_dim, hidden_dim)
        self.l2 = nn.Linear(hidden_dim, hidden_dim)
        self.l3 = nn.Linear(hidden_dim, 1)
        self.optim = Adam(self.parameters(), lr=lr)

    def forward(self, obs: Tensor, action: Tensor) -> Tensor:
        """Concatenate ``obs`` and ``action`` on the last dim and score the pair."""
        x = torch.cat((obs, action), dim=-1)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        return self.l3(x)  # linear head: Q may be negative

In [None]:
import random

class hindsight_buffer:
    """Replay buffer that collects experience from ``env`` and adds HER relabels.

    Every real environment step is stored once with its original goal. When
    an episode ends, each of its steps is stored a second time with the
    desired goal replaced by a goal actually achieved at that step or later in
    the same episode (the "future" strategy of Hindsight Experience Replay).
    The reward is recomputed with the env's ``compute_reward``.

    Stored transitions are ``(state, next_state, reward, terminated, action)``
    float32 tensors. ``terminated`` is 1.0 only for true terminal states, never
    for time-limit truncation, so it can be used directly as the bootstrap mask.

    Args:
        env: Gymnasium env with dict observations containing ``observation``,
            ``achieved_goal`` and ``desired_goal``; ``env.unwrapped`` must
            provide ``compute_reward(achieved_goal, goal, info)``.
        policy: callable mapping a processed observation to ``(action, _, _)``.
        warmup: stored-transition count below which actions are sampled from
            ``env.action_space``; defaults to ``hypers.warmup``.
    """

    def __init__(self, env, policy, warmup=None):
        self.data = []
        self.env = env
        self.policy = policy
        self.warmup = hypers.warmup if warmup is None else warmup
        self.obs = self.env.reset()[0]
        self.epi_reward = 0
        self.reward = 0
        # (obs, next_obs, terminated, action) for the running episode, kept
        # until it ends so that goals achieved later in it are known.
        self._episode = []

    @torch.no_grad()
    def add(self):
        """Take one env step, store it, and relabel the episode once it ends."""
        obs = self.obs  # s_t — must be captured before self.obs advances below
        if len(self.data) < self.warmup:
            action = self.env.action_space.sample()
        else:
            action, _, _ = self.policy(self.process_obs(obs))
        nx_state, reward, terminated, truncated, _info = self.env.step(action.tolist())
        self.reward += reward

        saved_action = torch.as_tensor(action, dtype=torch.float32).clone()
        self._store(obs, nx_state, reward, terminated, saved_action)
        self._episode.append((obs, nx_state, terminated, saved_action))

        if terminated or truncated:
            self._relabel_episode()
            self._episode = []
            self.epi_reward = self.reward
            self.obs = self.env.reset()[0]
            self.reward = 0
        else:
            self.obs = nx_state

    def _store(self, obs, nx_obs, reward, terminated, action):
        """Append one transition in the tensor layout ``sample`` expects."""
        self.data.append(
            (
                self.process_obs(obs),
                self.process_obs(nx_obs),
                torch.tensor(float(reward), dtype=torch.float32),
                torch.tensor(float(terminated), dtype=torch.float32),
                action,
            )
        )

    def _relabel_episode(self):
        """Store one HER copy of every step of the just-finished episode.

        For step t, a goal g' is drawn uniformly from the goals achieved after
        steps t..T-1. Both s_t and s_{t+1} get g' as their desired goal, and the
        reward is r' = compute_reward(achieved_goal(s_{t+1}), g').
        """
        compute_reward = self.env.unwrapped.compute_reward
        last = len(self._episode) - 1
        for t, (obs, nx_obs, terminated, action) in enumerate(self._episode):
            future = random.randint(t, last)
            new_goal = self._episode[future][1]["achieved_goal"]
            her_obs = {**obs, "desired_goal": new_goal}
            her_nx_obs = {**nx_obs, "desired_goal": new_goal}
            her_reward = compute_reward(nx_obs["achieved_goal"], new_goal, info={})
            self._store(her_obs, her_nx_obs, her_reward, terminated, action)

    def process_obs(self, obs: dict):
        """Flatten a dict observation into a float32 vector.

        Layout: the first 9 entries of ``obs["observation"]``, then
        ``achieved_goal``, then ``desired_goal`` (``np.append`` flattens all).
        """
        observation = obs.get("observation")[:9]
        achieved_goal = obs.get("achieved_goal")
        desired_goal = obs.get("desired_goal")
        return torch.from_numpy(np.append(observation, (achieved_goal, desired_goal))).to(dtype=torch.float32)

    def sample(self, batch):
        """Uniformly sample ``batch`` stored transitions without replacement.

        Returns:
            (state, next_state, reward, terminated, action), each stacked on
            dim 0. ``reward`` and ``terminated`` have shape ``(batch, 1)`` so
            they line up with the critics' ``(batch, 1)`` outputs.
        """
        transitions = random.sample(self.data, batch)
        state, nx_state, reward, done, action = map(torch.stack, zip(*transitions))
        return state, nx_state, reward.unsqueeze(-1), done.unsqueeze(-1), action

    def util(self):
        """Return the summed reward of the most recently completed episode."""
        return self.epi_reward

    def __len__(self):
        return len(self.data)

In [None]:
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

class main:
    """SAC + HER trainer for the wrapped Fetch environment.

    Builds the actor, twin critics with Polyak-averaged target copies, the
    hindsight replay buffer, and a TensorBoard writer logging to ``./``.
    """

    def __init__(self):
        self.policy = policy()
        self.q1 = q_network()
        self.q2 = q_network()
        self.q1_target = deepcopy(self.q1)
        self.q2_target = deepcopy(self.q2)
        # Target critics change only through Polyak averaging, never autograd.
        for target in (self.q1_target, self.q2_target):
            target.requires_grad_(False)
        self.q_optim = Adam(list(self.q1.parameters()) + list(self.q2.parameters()), lr=hypers.lr)
        self.env = make_env()
        self.buffer = hindsight_buffer(self.env, self.policy)
        self.writter = SummaryWriter("./")

    def train(self, start=False, total_steps=int(1e6), batch_size=256, log_every=100):
        """Run the interaction/update loop, then save the actor to ``./model.pth``.

        Does nothing unless ``start`` is truthy. Each iteration takes one env
        step through the buffer. Once at least ``max(hypers.warmup, batch_size)``
        transitions are stored, one SAC update follows every env step. Scalars
        are written to TensorBoard every ``log_every`` steps.
        """
        if not start:
            return
        min_buffer = max(hypers.warmup, batch_size)
        for n in tqdm(range(total_steps), total=total_steps):
            self.buffer.add()
            if len(self.buffer) < min_buffer:
                continue
            action, p_action, policy_loss, q_loss = self._update(batch_size)
            if n % log_every == 0:
                self._log(n, action, p_action, policy_loss, q_loss)

        torch.save(self.policy.state_dict(), "./model.pth")
        self.writter.close()

    def _update(self, batch_size):
        """One SAC gradient step on critics, actor and target critics.

        Returns the sampled actions, the actor's actions, and both losses (for logging).
        """
        states, nx_state, reward, done, action = self.buffer.sample(batch_size)
        # Force (B, 1) so the target lines up with the critics' (B, 1) outputs
        # instead of broadcasting to (B, B).
        reward = reward.reshape(-1, 1)
        done = done.reshape(-1, 1)

        with torch.no_grad():
            # The bootstrap action must come from the NEXT state: a' ~ pi(.|s').
            nx_action, nx_log_prob, _ = self.policy(nx_state)
            q_next = torch.min(self.q1_target(nx_state, nx_action), self.q2_target(nx_state, nx_action))
            q_target = reward + (1 - done) * hypers.gamma * (q_next - hypers.alpha * nx_log_prob)
        q1 = self.q1(states, action)
        q2 = self.q2(states, action)
        q_loss = F.mse_loss(q1, q_target) + F.mse_loss(q2, q_target)
        self.q_optim.zero_grad()
        q_loss.backward()
        self.q_optim.step()

        # Clipped double-Q for the actor as well, to limit overestimation.
        p_action, log_p_action, _ = self.policy(states)
        q_pi = torch.min(self.q1(states, p_action), self.q2(states, p_action))
        policy_loss = (hypers.alpha * log_p_action - q_pi).mean()
        self.policy.optim.zero_grad()
        policy_loss.backward()
        self.policy.optim.step()

        self._soft_update(self.q1, self.q1_target)
        self._soft_update(self.q2, self.q2_target)
        return action, p_action, policy_loss, q_loss

    @torch.no_grad()
    def _soft_update(self, net, target):
        """Polyak-average ``net`` into ``target``: theta' <- tau*theta + (1-tau)*theta'."""
        for param, target_param in zip(net.parameters(), target.parameters()):
            target_param.lerp_(param, hypers.tau)

    def _log(self, n, action, p_action, policy_loss, q_loss):
        """Write training scalars for step ``n`` to TensorBoard and flush."""
        scalars = {
            "Main/epi reward": self.buffer.util(),
            "Main/action variance": action.var().item(),
            "Main/policy loss action variance": p_action.var().item(),
            "Main/policy loss": policy_loss.item(),
            "Main/q loss": q_loss.item(),
        }
        for tag, value in scalars.items():
            self.writter.add_scalar(tag, value, n, new_style=True)
        self.writter.flush()

# Build the trainer and launch the full training run.
trainer = main()
trainer.train(start=True)