In [None]:
#!pip install gymnasium-robotics==1.3.1 
#!pip install gymnasium

In [None]:
import gymnasium as gym
import gymnasium_robotics
gym.register_envs(gymnasium_robotics)
from gymnasium.vector import SyncVectorEnv
from gymnasium.wrappers import Autoreset
from gymnasium.spaces import Box,Dict
import numpy as np
from dataclasses import dataclass
import torch,random

class custom(gym.Wrapper):
    """Wrapper that truncates the "observation" entry to its first nine values.

    The advertised observation space is a Dict with a 9-vector "observation"
    and 3-vector "achieved_goal" / "desired_goal" entries. ``reset`` additionally
    pokes a few simulator coordinates and randomises ``target_in_the_air``.
    """

    def __init__(self,env):
        super().__init__(env)

        def unbounded(size):
            # Unbounded float64 box of the given length.
            return Box(-np.inf,np.inf,(size,),np.float64)

        self.observation_space = Dict(
            {
            "observation" : unbounded(9),
            "achieved_goal" : unbounded(3),
            "desired_goal" : unbounded(3)
            }
        )

    def _trim(self,obs):
        # Keep only the first nine features; the dict is edited in place.
        obs["observation"] = obs["observation"][:9]
        return obs

    def reset(self,**kwargs):
        """Reset the wrapped env, then adjust simulator state.

        Returns the (trimmed) observation produced by the wrapped reset together
        with its info dict.
        """
        obs,info = super().reset(**kwargs)
        sim = self.env.unwrapped
        # Air targets are enabled with probability 1/3. The flag is assigned after
        # the wrapped reset has returned.
        # NOTE(review): presumably it only influences goals sampled by later resets - confirm.
        sim.unwrapped.target_in_the_air = random.choice([True,False,False])
        self._trim(obs)
        # NOTE(review): these writes happen after `obs` was produced, so the returned
        # observation does not reflect them - confirm this is intended.
        sim.data.qpos[0] = .3 # robot base x pos
        sim.data.qpos[1] = .5 # robot base y pos
        # sim.data.qpos[15]   # block's x pos
        # sim.data.qpos[16]   # block's y pos
        sim.data.qpos[17] = .4 # presumably the block's z pos - TODO confirm
        return obs,info

    def step(self,action):
        """Step the wrapped env and trim the returned observation."""
        state,reward,done,trunc,info = super().step(action)
        return self._trim(state),reward,done,trunc,info

def vec_env():
    """Create a SyncVectorEnv holding ``hypers.num_envs`` wrapped Fetch environments.

    Each sub-environment is ``FetchPickAndPlaceDense-v3`` limited to 50 steps per
    episode, wrapped first by ``custom`` and then by ``Autoreset``.
    """
    def make_env():
        base = gym.make("FetchPickAndPlaceDense-v3",max_episode_steps=50)
        return Autoreset(custom(base))

    factories = [make_env] * hypers.num_envs
    return SyncVectorEnv(factories)

@dataclass()
class Hypers:
    """Hyperparameters shared by the networks, replay buffer and training loop.

    Every attribute is annotated so that ``@dataclass`` registers it as a real
    field: values can be overridden per instance (``Hypers(lr=1e-3)``) and show up
    in ``repr``. Without annotations the decorator generates no fields at all.
    Defaults stay available as class attributes (``Hypers.lr``).
    """
    # Device used for the networks and for sampled batches.
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Default replay-buffer capacity (number of vectorised steps).
    max_steps: int = int(4e6)+1
    # Learning rate for every Adam optimiser.
    lr: float = 3e-4
    # Width of the action vector produced by the policy.
    action_dim: int = 4
    # Width of the flattened observation (observation + achieved + desired goal).
    obs_dim: int = 15
    # Weight of the log-probability (entropy) term in the SAC losses.
    alpha: float = 0.2
    # Buffer entries collected with random actions before learning starts.
    warmup: int = 2_000 # * num_envs = 10_000
    # Discount factor.
    gamma: float = 0.99
    # Soft-update coefficient for the target critics.
    tau: float = 5e-3
    # Number of buffer entries drawn per gradient step.
    batch_size: int = 128
    # Number of parallel environments in the vector env.
    num_envs: int = 5

hypers = Hypers()

In [None]:
import torch
from torch import Tensor
import torch.nn as nn
from torch.distributions import Normal
from torch.optim import Adam
import torch.nn.functional as F

class policy(nn.Module):
    """Tanh-squashed Gaussian actor with two hidden layers and its own Adam optimiser."""

    def __init__(self):
        super().__init__()
        width = 64
        self.l1 = nn.Linear(hypers.obs_dim,width)
        self.l2 = nn.Linear(width,width)
        self.mean = nn.Linear(width,hypers.action_dim)
        self.std = nn.Linear(width,hypers.action_dim)
        self.optim = Adam(self.parameters(),lr=hypers.lr)

    def forward(self,obs:Tensor):
        """Sample a squashed action for ``obs``.

        Returns:
            (action, log, mean): the tanh-squashed reparameterised sample, its
            log-probability summed over the last axis (kept as a size-1 axis),
            and the pre-squash mean of the Gaussian.
        """
        features = F.relu(self.l2(F.relu(self.l1(obs))))
        mean = self.mean(features)
        # The head predicts a log-std; clamp it before exponentiating.
        log_std = self.std(features).clamp(-10,1)
        gaussian = Normal(mean,log_std.exp())

        raw_sample = gaussian.rsample()
        action = F.tanh(raw_sample)
        # Change-of-variables correction for the tanh squashing (epsilon avoids log(0)).
        squash_term = torch.log(1-action.pow(2) + 1e-6)
        log = (gaussian.log_prob(raw_sample) - squash_term).sum(-1,True)
        return action,log,mean
        
class q_network(nn.Module):
    """State-action value network: two ReLU hidden layers and a linear scalar head."""

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(hypers.obs_dim+hypers.action_dim,64)
        self.l2 = nn.Linear(64,64)
        self.l3 = nn.Linear(64,1)
        self.optim = Adam(self.parameters(),lr=hypers.lr)
    
    def forward(self,obs:Tensor,action:Tensor):
        """Return Q(obs, action) with a trailing axis of size 1.

        ``obs`` and ``action`` are concatenated along the last axis.
        """
        x = torch.cat((obs,action),dim=-1)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        # No activation on the output: a ReLU here would clamp every Q-value to
        # >= 0 and make negative returns unrepresentable.
        return self.l3(x)

In [None]:
from collections import deque

class hindsight_buffer: 
    """Replay storage that also steps the vector env to collect transitions.

    Each slot holds one vectorised step, i.e. ``hypers.num_envs`` transitions.
    ``pointer`` counts every step ever stored; writes wrap around once the
    storage is full, so the oldest entries are overwritten. The hindsight
    relabelling itself is currently disabled (see the comment block in ``add``).
    """

    def _init_storage(self,capacity=None): 
        """Allocate zero-filled CPU storage for ``capacity`` vectorised steps.

        ``capacity`` defaults to ``hypers.max_steps``, looked up at call time so
        that a changed config is honoured without redefining the class.
        """
        if capacity is None:
            capacity = hypers.max_steps
        obs_dim = (hypers.num_envs,hypers.obs_dim)    # observation shape
        act_dim = (hypers.num_envs,hypers.action_dim) # action shape
        self.stor_curr_states = torch.zeros((capacity,*obs_dim))
        self.stor_nx_states = torch.zeros((capacity,*obs_dim))
        self.stor_rewards = torch.zeros((capacity,hypers.num_envs,))
        self.stor_truncs = torch.zeros((capacity,hypers.num_envs,))
        self.stor_actions = torch.zeros((capacity,*act_dim))
        self.pointer = 0

    def __init__(self,env,policy):
        """Allocate storage, reset ``env`` and remember the first observation."""
        self._init_storage()
        self.env = env
        self.policy = policy
        self.obs = self.env.reset()[0]
        # Returns of the most recently finished episodes (at most one per env).
        self.epi_reward = deque(maxlen=hypers.num_envs)
        # Running return of the episode in progress in each env; float32 avoids
        # the rounding error float16 accumulates over an episode.
        self.reward = np.zeros(hypers.num_envs,np.float32)
    
    def store(self,curr_state,nx_state,reward,trunc,action):
        """Write one vectorised transition at the current (wrapped) write slot."""
        slot = self.pointer % self.stor_curr_states.shape[0]
        self.stor_curr_states[slot] = curr_state
        self.stor_nx_states[slot] = nx_state
        self.stor_rewards[slot] = reward
        self.stor_truncs[slot] = trunc
        self.stor_actions[slot] = action

    @torch.no_grad()
    def add(self):
        """Step every env once and store the resulting transition.

        Actions are sampled uniformly from the action space until ``pointer``
        reaches ``hypers.warmup``; afterwards they come from the policy.
        """
        if self.pointer<hypers.warmup:
            action = self.env.action_space.sample()
        else:
            action,_,_ = self.policy(self.process_obs(self.obs))
            # Drop only the leading batch axis so a single-env setup keeps its env axis.
            action = action.squeeze(0)
        nx_state,reward,done,trunc,info = self.env.step(action.tolist())

        for i in range(hypers.num_envs):
            self.reward[i]+=reward[i]
            if trunc[i]:
                self.epi_reward.append(self.reward[i])
                self.reward[i] = 0
 
        saved_action = (
            torch.from_numpy(np.array(action)).to(torch.float32) if isinstance(action,np.ndarray) else action
        )
        # NOTE(review): the truncation flag is stored as the "done" mask, and the
        # step right after a truncation presumably pairs the old episode's last
        # observation with the new episode's first one (depends on the vector
        # env's autoreset mode) - confirm both are intended.
        self.store(
            self._flatten_obs(self.obs),
            self._flatten_obs(nx_state),
            self.process_reward(reward),
            torch.tensor(trunc),
            saved_action
        )
        self.pointer+=1 # step pointer one time 
        # Disabled hindsight relabelling, kept from the original notebook.
        # NOTE(review): it relabels with nx_state's "desired_goal" - confirm
        # whether "achieved_goal" was intended.
        # # HER transition
        # her_reward = np.zeros(hypers.num_envs,dtype=np.float16)
        # for n in range(hypers.num_envs):
        #     _her_nx_reward = self.env.envs[n].unwrapped.unwrapped.compute_reward(
        #         self.obs.get("achieved_goal")[n],nx_state.get("desired_goal")[n],info={}
        #     ) # r' := r(st,at,g') as seen in the Hindsight Experience Replay on page 5
        #     her_reward[n] = _her_nx_reward
        #
        # her_curr_state = self.obs.copy()
        # her_curr_state["desired_goal"] = nx_state.get("desired_goal")
        # her_nx_state = nx_state.copy()
        # self.store(
        #     self.process_obs(her_curr_state),
        #     self.process_obs(her_nx_state),
        #     self.process_reward(her_reward),
        #     torch.tensor(trunc),
        #     saved_action
        # )
        # self.pointer+=1 # step pointer second time

        # Advance the current observation; without this every transition would
        # start from the very first reset observation.
        self.obs = nx_state
    
    def _flatten_obs(self,obs:dict):
        """Concatenate the three observation entries into a float32 CPU tensor."""
        flat = np.concatenate(
            [obs.get("observation"),obs.get("achieved_goal"),obs.get("desired_goal")],axis=-1
        )
        output = torch.from_numpy(flat).to(dtype=torch.float32)
        assert output.shape == torch.Size([hypers.num_envs,hypers.obs_dim])
        return output

    def process_obs(self,obs:dict):
        """Flatten ``obs`` and return it with a leading batch axis on ``hypers.device``."""
        return self._flatten_obs(obs).unsqueeze(0).to(device=hypers.device)

    def process_reward(self,reward):
        """Convert a reward sequence/array to a float32 tensor."""
        return torch.from_numpy(np.array(reward)).to(torch.float32)
    
    def sample(self,batch):
        """Draw ``batch`` stored steps uniformly (with replacement).

        Returns (states, next_states, rewards, truncs, actions) on
        ``hypers.device``; rewards and truncs get a trailing axis of size 1.
        """
        # Only slots that have been written are eligible.
        filled = min(self.pointer,self.stor_curr_states.shape[0])
        idx = torch.randint(0,filled,(batch,))
        return (
            self.stor_curr_states[idx].float().to(device=hypers.device),
            self.stor_nx_states[idx].float().to(device=hypers.device),
            self.stor_rewards[idx].unsqueeze(-1).to(device=hypers.device),
            self.stor_truncs[idx].float().unsqueeze(-1).to(device=hypers.device),
            self.stor_actions[idx].float().to(device=hypers.device)
        )
    
    def save(self):
        """Write the filled part of the buffer to ``./data.pth`` in compact dtypes."""
        filled = min(self.pointer,self.stor_curr_states.shape[0])
        data = {
            "curr_states":self.stor_curr_states[:filled].half(),
            "nx_states":self.stor_nx_states[:filled].half(),
            "rewards":self.stor_rewards[:filled].half(),
            "truncs":self.stor_truncs[:filled].bool(),
            "actions":self.stor_actions[:filled].half()
        }
        torch.save(data,"./data.pth")
    
    def util(self): 
        """Mean return of the recently finished episodes (NaN before any has finished)."""
        if len(self.epi_reward) == 0:
            return torch.tensor(float("nan"))
        return torch.tensor([float(r) for r in self.epi_reward]).mean()

In [None]:
from copy import deepcopy
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

class main:
    """Soft actor-critic trainer: actor, twin critics with target copies, env, buffer, logger."""

    def __init__(self):
        """Build the networks on ``hypers.device`` plus the env, replay buffer and writer."""
        self.policy = policy().to(device=hypers.device)
        self.q1 = q_network().to(device=hypers.device)
        self.q2 = q_network().to(device=hypers.device)
        self.q1_target = deepcopy(self.q1).to(device=hypers.device)
        self.q2_target = deepcopy(self.q2).to(device=hypers.device)
        # One optimiser drives both critics; the actor uses its own (policy.optim).
        self.q_optim = Adam(list(self.q1.parameters())+list(self.q2.parameters()),lr=hypers.lr)
        self.env = vec_env()
        self.buffer = hindsight_buffer(self.env,self.policy)
        self.writter = SummaryWriter("./")

    @staticmethod
    def _soft_update(online,target):
        """Polyak-average ``online``'s parameters into ``target`` with rate ``hypers.tau``."""
        for online_param,target_param in zip(online.parameters(),target.parameters()):
            target_param.data.mul_(1.0-hypers.tau).add_(online_param.data,alpha=hypers.tau)

    def train(self,start=False):
        """Run the collect/update loop; does nothing unless ``start`` is true.

        Every iteration adds one vectorised env step to the buffer. Once the
        buffer holds ``hypers.warmup`` entries, each iteration also performs one
        critic update, one actor update and a soft target update. The policy is
        checkpointed every 1e5 iterations and once more at the end.
        """
        if not start:
            return
        t = 0
        for n in tqdm(range(int(2e6)-1),total=int(2e6)-1):
            self.buffer.add()
            if n%100 == 0:
                self.writter.add_scalar("Main/epi reward",self.buffer.util(),n,new_style=True)
            if self.buffer.pointer >= hypers.warmup: 
                states,nx_state,reward,trunc,action = self.buffer.sample(hypers.batch_size)
                with torch.no_grad():
                    # The bootstrap action must be drawn at the NEXT state,
                    # a' ~ pi(.|s'), to match the state the target critics evaluate.
                    target_action,log_target_action,_ = self.policy(nx_state)
                    q1_target = self.q1_target(nx_state,target_action)
                    q2_target = self.q2_target(nx_state,target_action)
                    q_target = reward + (1-trunc) * hypers.gamma * (torch.min(q1_target,q2_target) - hypers.alpha * log_target_action)
                q1 = self.q1(states,action) 
                q2 = self.q2(states,action)
                q_loss = F.mse_loss(q1,q_target) + F.mse_loss(q2,q_target)
                self.q_optim.zero_grad()
                q_loss.backward()
                self.q_optim.step()

                # Actor update: maximise min-Q minus the entropy penalty. Gradients
                # that leak into the critics here are cleared by the next
                # q_optim.zero_grad().
                p_action,log_p_action,_ = self.policy(states)
                min_q = torch.min(self.q1(states,p_action),self.q2(states,p_action))
                policy_loss = ((hypers.alpha*log_p_action) - min_q).mean()
                self.policy.optim.zero_grad()
                policy_loss.backward()
                self.policy.optim.step()

                self._soft_update(self.q1,self.q1_target)
                self._soft_update(self.q2,self.q2_target)
                
                self.writter.add_scalar("Main/action variance",action.var(),n,new_style=True)
                self.writter.add_scalar("Main/policy loss action variance",p_action.var(),n,new_style=True)
                self.writter.add_scalar("Main/policy loss",policy_loss,n,new_style=True)
                self.writter.add_scalar("Main/q loss",q_loss,n,new_style=True)
                self.writter.flush()
        
            if n%int(1e5)==0:
                t+=1
                torch.save(self.policy.state_dict(),f"./model-{t}.pth")

        torch.save(self.policy.state_dict(),f"./model-final.pth")

# Build the trainer and start training right away; train() is a no-op unless start=True.
main().train(True)