In [None]:
#!pip install gymnasium-robotics==1.3.1 
#!pip install gymnasium

In [None]:
import gymnasium as gym
import gymnasium_robotics
gym.register_envs(gymnasium_robotics)
from gymnasium.vector import SyncVectorEnv
from gymnasium.wrappers import Autoreset
from gymnasium.spaces import Box,Dict
import numpy as np
from dataclasses import dataclass
import torch,random,sys

class custom(gym.Wrapper):
    """Fetch wrapper that trims observations and perturbs the initial state on reset.

    * ``observation`` is cut to its first ``_OBS_DIM`` entries (the Dict space is
      re-declared to match); ``achieved_goal`` / ``desired_goal`` pass through.
    * Every reset randomly sets ``target_in_the_air`` (True with probability 1/3)
      and overwrites a few ``qpos`` entries before the first observation is built.
    """

    # Number of leading entries of the raw "observation" vector that are kept.
    _OBS_DIM = 9
    # qpos index -> value written on every reset. Meanings come from the original
    # comments (0/1 = robot base x/y; 15/16/17 = block x/y/z).
    # NOTE(review): confirm against the model's joint layout; z=.4 may sit at/below the table top.
    _RESET_QPOS = {0: .3, 1: .5, 17: .4}

    def __init__(self,env):
        super().__init__(env)
        self.observation_space = Dict(
            {
            "observation" : Box(-np.inf,np.inf,(self._OBS_DIM,),np.float64),
            "achieved_goal" : Box(-np.inf,np.inf,(3,),np.float64),
            "desired_goal" : Box(-np.inf,np.inf,(3,),np.float64)
            }
        )

    def reset(self,**kwargs):
        """Reset the env, apply the qpos overrides, and return the refreshed (obs, info)."""
        base = self.env.unwrapped
        # gymnasium_robotics samples the episode goal inside reset() using this
        # flag, so it must be set *before* resetting to apply to this episode
        # (setting it afterwards only affected the following episode).
        base.target_in_the_air = random.choice([True,False,False])
        _, info = super().reset(**kwargs)
        for qpos_idx, value in self._RESET_QPOS.items():
            base.data.qpos[qpos_idx] = value
        # Writing qpos does not update quantities derived from it (site/body
        # positions). Run forward kinematics and rebuild the observation so the
        # first observation of the episode reflects the edited state.
        base._mujoco.mj_forward(base.model, base.data)
        obs = base._get_obs()
        obs["observation"] = obs["observation"][:self._OBS_DIM]
        return obs,info

    def step(self,action):
        """Step the wrapped env and trim the observation vector."""
        state,reward,done,trunc,info = super().step(action)
        state["observation"] = state["observation"][:self._OBS_DIM]
        return state,reward,done,trunc,info

def vec_env():
    """Build a SyncVectorEnv of ``hypers.num_envs`` FetchPickAndPlace-v3 copies.

    Each copy has a 50-step time limit and is wrapped as ``Autoreset(custom(env))``.
    NOTE(review): Gymnasium >= 1.0 ``SyncVectorEnv`` already autoresets its
    sub-envs, so the inner ``Autoreset`` looks redundant — confirm before removing.
    """
    def make_env():
        base_env = gym.make("FetchPickAndPlace-v3",max_episode_steps=50)
        return Autoreset(custom(base_env))

    env_factories = [make_env] * hypers.num_envs
    return SyncVectorEnv(env_factories)

@dataclass()
class Hypers:
    """Global hyper-parameters read by every cell.

    Values are plain class attributes (no annotations), so they are not dataclass
    fields and ``Hypers()`` takes no arguments.
    """
    # --- compute ---
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # --- schedule ---
    max_steps = 2_000_001   # == int(2e6) + 1; default replay-storage capacity
    warmup = 2_000          # random actions / no learning until this many stored steps
    # --- environment & spaces ---
    num_envs = 5
    horizon = 50            # steps per episode (vec_env uses max_episode_steps=50)
    obs_dim = 15            # 9 observation + 3 achieved_goal + 3 desired_goal
    action_dim = 4
    # --- SAC ---
    lr = 0.0003
    gamma = 0.99
    tau = 0.005
    alpha = 0.2
    batch_size = 128

hypers = Hypers()

In [None]:
import torch
from torch import Tensor
import torch.nn as nn
from torch.distributions import Normal
from torch.optim import Adam
import torch.nn.functional as F

class policy(nn.Module):
    """Tanh-squashed Gaussian actor for SAC.

    Two 64-unit ReLU layers feed a mean head and a log-std head. ``forward``
    returns ``(action, log_prob, mean)``:
    * ``action`` is ``tanh`` of a reparameterised Gaussian sample;
    * ``log_prob`` is its log-density with the tanh Jacobian correction, summed
      over the last dim with that dim kept;
    * ``mean`` is the pre-tanh Gaussian mean.

    The actor owns its Adam optimiser as ``self.optim``.
    """

    def __init__(self):
        super().__init__()
        width = 64
        self.l1 = nn.Linear(hypers.obs_dim,width)
        self.l2 = nn.Linear(width,width)
        self.mean = nn.Linear(width,hypers.action_dim)
        # Despite its name this head predicts log(std); it is clamped and exponentiated in forward.
        self.std = nn.Linear(width,hypers.action_dim)
        self.optim = Adam(self.parameters(),lr=hypers.lr)

    def forward(self,obs:Tensor):
        hidden = obs
        for layer in (self.l1, self.l2):
            hidden = F.relu(layer(hidden))
        mean = self.mean(hidden)
        log_std = torch.clamp(self.std(hidden), min=-20, max=2)
        dist = Normal(mean, log_std.exp())
        pretanh = dist.rsample()
        action = torch.tanh(pretanh)
        # log|d tanh(u)/du| = log(1 - tanh(u)^2); the 1e-6 keeps the log finite at |action| -> 1.
        squash_correction = torch.log(1 - action.pow(2) + 1e-6)
        log = (dist.log_prob(pretanh) - squash_correction).sum(dim=-1, keepdim=True)
        return action,log,mean
        
class q_network(nn.Module):
    """Critic Q(s, a): an MLP over the concatenated observation and action.

    ``forward`` returns an unbounded scalar per row (trailing dim 1). The module
    also builds its own Adam optimiser as ``self.optim``.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(hypers.obs_dim+hypers.action_dim,64)
        self.l2 = nn.Linear(64,64)
        self.l3 = nn.Linear(64,1)
        self.optim = Adam(self.parameters(),lr=hypers.lr)

    def forward(self,obs:Tensor,action:Tensor):
        x = torch.cat((obs,action),dim=-1)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        # No activation on the output: rewards are in {-1, 0}, so Q-targets are
        # non-positive. A ReLU here would clamp every prediction to >= 0, and its
        # gradient would be zero wherever the pre-activation is negative.
        return self.l3(x)

In [None]:
from torch import linalg as LA

def process_obs(obs:dict):
    """Flatten a batched goal-env observation dict into one float32 tensor on ``hypers.device``.

    Concatenates ``observation`` | ``achieved_goal`` | ``desired_goal`` along the
    last axis and asserts the result is ``(hypers.num_envs, 15)``.
    """
    parts = [obs[key] for key in ("observation", "achieved_goal", "desired_goal")]
    flat = torch.from_numpy(np.concatenate(parts, axis=-1))
    assert flat.shape == torch.Size([hypers.num_envs,15])
    return flat.to(device=hypers.device,dtype=torch.float32)

def process_her_states(observation,achieved_goal,desired_goal):
    """Build a relabelled state tensor from separate numpy parts.

    Concatenates ``observation | achieved_goal | desired_goal`` along the last axis.
    The result must be ``(hypers.num_envs, 15)``. It is returned as float32 on
    ``hypers.device``, the same dtype/device as ``process_obs``, so real and HER
    states can be concatenated without implicit promotion to float64.
    """
    output = torch.from_numpy(np.concatenate([observation,achieved_goal,desired_goal],axis=-1))
    assert output.shape == torch.Size([hypers.num_envs,15])
    return output.to(device=hypers.device,dtype=torch.float32)

def her_reward(goal_a,goal_b,distance_threshold=0.05):
    """Sparse goal-reaching reward.

    Returns 0 where ``||goal_a - goal_b||`` (over the last axis) is at most
    ``distance_threshold``, and -1 otherwise.

    Args:
        goal_a, goal_b: numpy arrays or torch tensors of goal positions, broadcastable.
        distance_threshold: success radius (default 0.05, the previously hard-coded value).

    Returns:
        float32 tensor on ``hypers.device`` with the last axis reduced away.
    """
    diff = torch.as_tensor(goal_a) - torch.as_tensor(goal_b)
    distance = LA.norm(diff,dim=-1)
    return -(distance > distance_threshold).to(device=hypers.device,dtype=torch.float32)

In [None]:
from collections import deque

class hindsight_buffer: 
    """Episode-aligned replay storage for the vectorised env, consumed by ``her_sample``.

    Observation dicts live in Python lists (``curr_state`` / ``nx_states``).
    Rewards, truncation flags and actions go into preallocated CPU tensors,
    written at ``self.pointer``. Row ``i`` of every tensor belongs to
    ``curr_state[i]`` / ``nx_states[i]``.

    Autoreset pseudo-steps are not stored, so each episode occupies exactly
    ``hypers.horizon`` consecutive rows. This assumes all sub-envs end their
    episodes on the same step, as with the fixed time limit in ``vec_env``.
    """

    def _init_storage(self,capacity=hypers.max_steps): 
        """Allocate fixed-capacity reward/trunc/action tensors and reset the write pointer."""
        act_dim = (hypers.num_envs,hypers.action_dim) # action shape
        self.stor_rewards = torch.zeros((capacity,hypers.num_envs,),dtype=torch.float16) 
        self.stor_truncs = torch.zeros((capacity,hypers.num_envs,),dtype=torch.bool)
        self.stor_actions = torch.zeros((capacity,*act_dim),dtype=torch.float16)
        self.pointer = 0  # next row to write; equals len(self) after every add()

    def __init__(self,env,policy):
        self.curr_state = [] # current states storage
        self.nx_states = []  # next states storage
        self._init_storage()
        self.env = env
        self.policy = policy
        self.obs = self.env.reset()[0]
        self.epi_reward = deque(maxlen=hypers.num_envs)  # most recent finished-episode returns
        self.reward = np.zeros(hypers.num_envs,np.float16)  # running return per env
        self.her_storage = deque(maxlen=hypers.horizon)
        # Per-env flag: the previous step ended an episode, so the next env.step() is an autoreset step.
        self._autoreset = np.zeros(hypers.num_envs,dtype=bool)

    def store(self,reward,trunc,action):  
        """Write one vector-step row at ``self.pointer`` and advance the pointer.

        Raises:
            IndexError: if the preallocated storage is full.
        """
        capacity = self.stor_rewards.shape[0]
        if self.pointer >= capacity:
            raise IndexError(f"hindsight_buffer is full (capacity {capacity})")
        self.stor_rewards[self.pointer] = reward
        self.stor_truncs[self.pointer] = trunc
        self.stor_actions[self.pointer]= action
        self.pointer += 1

    @torch.no_grad()
    def add(self):
        """Take one vector-env step and record the transition.

        Actions are uniform random during warmup and policy samples afterwards.
        """
        if len(self)<hypers.warmup:
            action = self.env.action_space.sample()
        else:
            action,_,_ = self.policy(process_obs(self.obs))
            # Keep the (num_envs, action_dim) batch shape; squeezing broke num_envs == 1.
            action = action.cpu().numpy()
        nx_state,reward,done,trunc,info = self.env.step(action)

        ended = np.logical_or(done,trunc)
        for i in range(hypers.num_envs):
            self.reward[i]+=reward[i]
            if ended[i]:
                self.epi_reward.append(self.reward[i])
                self.reward[i] = 0

        # Gymnasium >= 1.0 vector envs autoreset on the *next* step: the step after an
        # episode ends ignores `action` and returns reset observations with reward 0.
        # Storing it would make episodes horizon+1 rows long and break her_sample's
        # fixed-size episode slicing, so only advance the observation.
        was_autoreset = self._autoreset.any()
        self._autoreset = ended
        if was_autoreset:
            self.obs = nx_state
            return

        self.store(
            torch.from_numpy(np.asarray(reward)),
            torch.from_numpy(np.asarray(trunc)),
            torch.as_tensor(action),
        )
        # Tensors first, then nx_states, then curr_state: the HER worker thread sizes
        # its reads from len(curr_state), so every row it can see is already written.
        self.nx_states.append(nx_state)
        self.curr_state.append(self.obs)
        self.obs = nx_state
    
    def save(self):
        """Persist the stored states and the filled part of the tensors to ./data.pth."""
        data = {
            "curr_states":self.curr_state,
            "nx_states":self.nx_states,
            "rewards":self.stor_rewards[:self.pointer].half(),
            "truncs":self.stor_truncs[:self.pointer].bool(),
            "actions":self.stor_actions[:self.pointer].half()
        }
        torch.save(data,"./data.pth")
    
    def util(self): 
        """Mean of the most recent finished-episode returns (NaN before any episode ends)."""
        return torch.tensor(list(self.epi_reward),dtype=torch.float32).mean()
    
    def __len__(self):
        return len(self.curr_state)

In [None]:
def her_sample(
    # target ratio 4:1, strategy : future
    batch_size,
    k, 
    curr_states,
    nx_states,
    rewards,
    truncs,
    actions         
    ): 
    """Draw a minibatch from one stored episode, mixing real and HER-relabelled transitions.

    This uses HER's "future" strategy with ``k`` relabels per sampled step (4:1
    relabelled:real for k=4):
    * One episode is chosen uniformly among the complete ``hypers.horizon``-row
      episodes.
    * ``hypers.horizon`` steps t are drawn from it.
    * For each step, ``k`` substitute goals are taken from achieved goals at later
      steps t' > t.
    * The relabelled reward compares the *next* state's achieved goal with the
      substitute goal, matching how the env rewards s_t -> s_{t+1}.

    Args:
        batch_size: number of rows (along dim 0) to return.
        k: relabelled transitions per sampled step.
        curr_states, nx_states: lists of observation dicts, index-aligned with the tensors.
        rewards, truncs, actions: per-step tensors; row i belongs to ``curr_states[i]``.

    Returns:
        ``(states, next_states, rewards, truncs, actions)``: float32 tensors on
        ``hypers.device``. Each row holds all ``hypers.num_envs`` envs; rewards
        and truncs get a trailing singleton dim.

    Raises:
        ValueError: if fewer than ``hypers.horizon`` steps are stored.
    """
    horizon = hypers.horizon
    device = hypers.device

    num_episodes = len(curr_states)//horizon
    if num_episodes == 0:
        raise ValueError(f"need at least {horizon} stored steps to sample an episode, got {len(curr_states)}")
    epi_start_idx = np.random.randint(num_episodes)*horizon
    epi_stop_idx = epi_start_idx+horizon
    batch = curr_states[epi_start_idx:epi_stop_idx]
    nx_batch = nx_states[epi_start_idx:epi_stop_idx]

    _her_curr = []
    _her_nx = []
    _her_rewards = []
    _her_truncs = []
    _her_actions = []

    for _ in range(horizon):
        idx = random.randint(0,horizon-2)
        curr = batch[idx]      # s_t
        nx = nx_batch[idx]     # s_{t+1} as returned by env.step
        row = epi_start_idx+idx  # global buffer row of this transition
        for _ in range(k):
            future_idx = random.randint(idx+1,horizon-1)
            goal = batch[future_idx]["achieved_goal"]  # achieved goal at t' > t
            _her_curr.append(process_her_states(curr["observation"],curr["achieved_goal"],goal))
            _her_nx.append(process_her_states(nx["observation"],nx["achieved_goal"],goal))
            _her_rewards.append(her_reward(nx["achieved_goal"],goal))
            _her_truncs.append(truncs[row])
            _her_actions.append(actions[row])

    assert (len(_her_actions)==len(_her_curr)==len(_her_nx)==len(_her_rewards)==len(_her_truncs)==horizon*k)

    # Real transitions of the episode. The storage tensors live on CPU, so move
    # them to the training device before concatenating with the relabelled ones.
    c = torch.stack([process_obs(o) for o in batch])
    nx_c = torch.stack([process_obs(o) for o in nx_batch])
    r = rewards[epi_start_idx:epi_stop_idx].to(device=device,dtype=torch.float32)
    tr = truncs[epi_start_idx:epi_stop_idx].to(device=device,dtype=torch.float32)
    a = actions[epi_start_idx:epi_stop_idx].to(device=device,dtype=torch.float32)

    s_c = torch.cat([c,torch.stack(_her_curr).to(device=device,dtype=torch.float32)])
    s_nx = torch.cat([nx_c,torch.stack(_her_nx).to(device=device,dtype=torch.float32)])
    s_r = torch.cat([r,torch.stack(_her_rewards).to(device=device,dtype=torch.float32)])
    s_tr = torch.cat([tr,torch.stack(_her_truncs).to(device=device,dtype=torch.float32)])
    s_a = torch.cat([a,torch.stack(_her_actions).to(device=device,dtype=torch.float32)])

    sample_idx = torch.randperm(s_c.size(0),device=device)[:batch_size]

    return (
        s_c[sample_idx],
        s_nx[sample_idx],
        s_r[sample_idx].unsqueeze(-1),
        s_tr[sample_idx].unsqueeze(-1),
        s_a[sample_idx],
    )

def her_worker(queue,buffer:hindsight_buffer):  # multithread worker
    """Background producer: forever push ``her_sample`` minibatches from ``buffer`` into ``queue``.

    The worker waits until at least one full ``hypers.horizon``-step episode is
    stored, sleeping briefly rather than spinning. ``queue.put`` blocks when a
    bounded queue is full, which throttles this thread to the consumer's pace.
    """
    import time

    k = 4  # HER relabels per sampled step ("future" strategy, 4:1)
    while True:
        if len(buffer) < hypers.horizon:
            # A tight polling loop would compete for the GIL with the thread filling the buffer.
            time.sleep(0.01)
            continue
        queue.put(her_sample(
            hypers.batch_size,
            k,
            buffer.curr_state,
            buffer.nx_states,
            buffer.stor_rewards,
            buffer.stor_truncs,
            buffer.stor_actions,
        ))

In [None]:
from tqdm import tqdm
import threading,queue,itertools,copy
from torch.utils.tensorboard import SummaryWriter

class main:
    """SAC + HER training driver.

    Builds the following:
    * the actor and twin critics, plus Polyak-averaged target copies of the critics;
    * the vectorised env and the replay buffer;
    * a daemon thread running ``her_worker``, which keeps ``self.queue`` filled
      with relabelled minibatches.

    ``train(True)`` runs the interaction/update loop.
    """

    def __init__(self):
        self.policy = policy().to(device=hypers.device)
        self.q1 = q_network().to(device=hypers.device)
        self.q2 = q_network().to(device=hypers.device)
        self.q1_target = copy.deepcopy(self.q1).to(device=hypers.device)
        self.q2_target = copy.deepcopy(self.q2).to(device=hypers.device)
        # One optimiser over both critics; the per-network `optim` of q_network is not used here.
        self.q_optim = Adam(itertools.chain(self.q1.parameters(),self.q2.parameters()),lr=hypers.lr)
        self.writter = SummaryWriter("./")

        self.env = vec_env()
        self.buffer = hindsight_buffer(self.env,self.policy)

        # Bounded so the sampler thread blocks instead of racing ahead of training.
        self.queue = queue.Queue(maxsize=40)
        self.thread = threading.Thread(target=her_worker,args=(self.queue,self.buffer),daemon=True)
        self.thread.start()

    def _next_batch(self):
        """Return the next HER minibatch; raise if the sampler thread has died."""
        while True:
            try:
                return self.queue.get(timeout=1.0)
            except queue.Empty:
                # Without this check, a crashed worker would leave training blocked forever.
                if not self.thread.is_alive():
                    raise RuntimeError("HER sampler thread exited; no training batches available") from None

    @staticmethod
    def _soft_update(online,target,tau):
        """Polyak-average ``online`` into ``target``: target <- (1 - tau) * target + tau * online."""
        with torch.no_grad():
            for online_param,target_param in zip(online.parameters(),target.parameters()):
                target_param.mul_(1.0-tau).add_(online_param,alpha=tau)

    def _update(self,n):
        """Run one SAC gradient step (critics, actor, then target nets) and log diagnostics at step ``n``."""
        states,nx_state,reward,_trunc,action = self._next_batch()

        with torch.no_grad():
            # The bootstrap action must be sampled at the *next* state.
            target_action,log_target_action,_ = self.policy(nx_state)
            q1_target = self.q1_target(nx_state,target_action)
            q2_target = self.q2_target(nx_state,target_action)
            # Only time-limit truncation is recorded (the buffer stores no termination
            # flag). A truncated episode still has a future, so always bootstrap.
            q_target = reward + hypers.gamma * (torch.min(q1_target,q2_target) - hypers.alpha * log_target_action)
        q1 = self.q1(states,action)
        q2 = self.q2(states,action)
        q_loss = F.mse_loss(q1,q_target) + F.mse_loss(q2,q_target)
        self.q_optim.zero_grad()
        q_loss.backward()
        self.q_optim.step()

        # This backward also leaves gradients on the critics; q_optim.zero_grad() clears them next update.
        p_action,log_p_action,_ = self.policy(states)
        min_q = torch.min(self.q1(states,p_action),self.q2(states,p_action))
        policy_loss = ((hypers.alpha*log_p_action) - min_q).mean()
        self.policy.optim.zero_grad()
        policy_loss.backward()
        self.policy.optim.step()

        self._soft_update(self.q1,self.q1_target,hypers.tau)
        self._soft_update(self.q2,self.q2_target,hypers.tau)

        self.writter.add_scalar("Main/action variance",action.var(),n,new_style=True)
        self.writter.add_scalar("Main/policy loss action variance",p_action.var(),n,new_style=True)
        self.writter.add_scalar("Main/policy loss",policy_loss,n,new_style=True)
        self.writter.add_scalar("Main/q loss",q_loss,n,new_style=True)

    def train(self,start=False):
        """Run the interaction/update loop when ``start`` is truthy; otherwise do nothing.

        Each iteration adds one vector-env step to the buffer. Once ``hypers.warmup``
        steps are stored, it also performs one SAC update. The policy is checkpointed
        to ``./model-{t}.pth`` every 100k iterations and to ``./model-final.pth`` at the end.
        """
        if not start:
            return
        num_iterations = int(2e6)
        log_every = 50
        checkpoint_every = int(1e5)
        t = 0
        for n in tqdm(range(num_iterations),total=num_iterations):
            self.buffer.add()
            if n%log_every == 0:
                self.writter.add_scalar("Main/epi reward",self.buffer.util(),n,new_style=True)
                # Flushing every iteration was needless disk I/O; once per log interval is enough.
                self.writter.flush()

            if len(self.buffer) >= hypers.warmup:
                self._update(n)

            if n%checkpoint_every == 0:
                t+=1
                torch.save(self.policy.state_dict(),f"./model-{t}.pth")

        torch.save(self.policy.state_dict(),"./model-final.pth")

if __name__ == "__main__":
    # In a notebook kernel __name__ is "__main__", so running this cell starts training.
    trainer = main()
    trainer.train(start=True)