In [1]:
import gymnasium as gym
import gymnasium_robotics,torch,sys
gym.register_envs(gymnasium_robotics)
from gymnasium.spaces import Box,Dict
from gymnasium.wrappers import Autoreset
import numpy as np
from dataclasses import dataclass

In [2]:
@dataclass()
class Hypers:
    """Hyper-parameters shared by the environment, networks, replay buffer and trainer.

    Every attribute is annotated so that ``@dataclass`` registers it as a real
    field: values can be overridden at construction time (for example
    ``Hypers(batch_size=64)``) and instances get a readable ``repr`` and
    field-wise equality. Without annotations the decorator generated nothing.
    """
    # Device used for the networks and for sampled batches.
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Default replay capacity; the training loop runs for ``max_steps - 1`` steps.
    max_steps: int = int(1e6)+1
    # Learning rate shared by the actor, critic and temperature optimizers.
    lr: float = 3e-4
    # Size of the action vector produced by the policy.
    action_dim: int = 3
    # Size of the flat observation fed to the networks.
    obs_dim: int = 6
    # Number of stored transitions collected with random actions before updates start.
    # Kept as an int because it is compared against an integer transition counter.
    warmup: int = int(25e3)
    # Discount factor used in the Bellman target.
    gamma: float = 0.99
    # Interpolation factor for the soft target-network update.
    tau: float = 5e-3
    # Number of transitions drawn per gradient update.
    batch_size: int = 256
    # Size of the per-step environment axis in replay storage.
    num_envs: int = 1
    # Not referenced by the visible cells of this notebook.
    horizon: int = 100


# Module-wide configuration instance read by the rest of the notebook.
hypers = Hypers()

class FetchReachCustom(gym.Wrapper):
    """Environment wrapper exposing a 3-D action and a 3-D ``observation`` entry.

    * Actions handed to ``step`` get a trailing ``0`` appended before being
      forwarded to the wrapped environment.
    * Only the first three values of ``observation["observation"]`` are kept.
    """

    def __init__(self,env : gym.Env):
        super().__init__(env)

        def unbounded_triplet():
            # Unbounded float64 box of shape (3,), one per observation key.
            return Box(-np.inf,np.inf,(3,),np.float64)

        self.action_space = Box(-1,1,(3,),np.float32)
        self.observation_space = Dict(
            {
            "observation" : unbounded_triplet(),
            "achieved_goal" : unbounded_triplet(),
            "desired_goal" : unbounded_triplet()
            }
        )

    def process_obs(self,observation):
        """Truncate ``observation["observation"]`` to three values, in place, and return the dict."""
        trimmed = observation["observation"][:3]
        observation["observation"] = trimmed
        return observation

    def step(self, action):
        """Append a zero to ``action``, step the wrapped env and post-process the observation."""
        # presumably the fourth entry is the gripper command -- confirm against the env docs
        padded_action = np.append(action,0)
        raw_obs, reward, done, truncated, info = self.env.step(padded_action)
        return self.process_obs(raw_obs), reward, done, truncated, info

    def reset(self,seed=None, options=None):
        """Reset the wrapped env, overwrite the first two ``qpos`` entries, return the processed observation."""
        raw_obs, info = self.env.reset(seed=seed,options=options)
        # NOTE(review): the observation above was produced before these writes,
        # so it does not reflect them -- confirm this is intended.
        sim_data = self.env.unwrapped.data
        sim_data.qpos[0] = 0.3
        sim_data.qpos[1] = 0.5
        return self.process_obs(raw_obs), info

def tranform_observation(observation_dict : Dict):
    """Build the network input from a goal-conditioned observation.

    Concatenates ``achieved_goal`` and ``desired_goal`` along the last axis
    and returns the result as a float32 tensor on ``hypers.device``.
    """
    goal_parts = tuple(
        observation_dict.get(key) for key in ("achieved_goal","desired_goal")
    )
    flat = np.concatenate(goal_parts,axis=-1)
    as_tensor = torch.from_numpy(flat)
    return as_tensor.to(device=hypers.device,dtype=torch.float32)

def make_env():
    """Create the training environment.

    ``FetchReachDense-v3`` with a 50-step episode limit, wrapped first in
    ``FetchReachCustom`` and then in ``Autoreset``.
    """
    base_env = gym.make("FetchReachDense-v3",max_episode_steps=50)
    return Autoreset(FetchReachCustom(base_env))

In [3]:
from torch import Tensor
import torch.nn as nn
from torch.distributions import Normal
from torch.optim import Adam
import torch.nn.functional as F

def weight_init(l):
    """Initialise an ``nn.Linear`` layer: orthogonal weights, zero bias.

    Intended for ``nn.Module.apply``; modules that are not ``nn.Linear`` are
    left untouched. Layers built with ``bias=False`` have ``bias is None``,
    so the bias fill is skipped for them instead of raising.
    """
    if not isinstance(l,nn.Linear):
        return
    torch.nn.init.orthogonal_(l.weight)
    if l.bias is not None:
        torch.nn.init.constant_(l.bias,0.0)

class policy(nn.Module):
    """Tanh-squashed Gaussian actor that owns its own Adam optimizer (``self.optim``)."""

    def __init__(self):
        super().__init__()
        hidden_units = 256
        self.l1 = nn.Linear(hypers.obs_dim,hidden_units)
        self.l2 = nn.Linear(hidden_units,hidden_units)
        # Two heads: Gaussian mean and log standard deviation.
        self.mean = nn.Linear(hidden_units,hypers.action_dim)
        self.std = nn.Linear(hidden_units,hypers.action_dim)
        self.apply(weight_init)
        self.optim = Adam(self.parameters(),lr=hypers.lr)

    def forward(self,obs:Tensor):
        """Sample an action for ``obs``.

        Returns:
            ``(action, log, mean)``: the tanh-squashed reparameterised sample,
            its log-probability summed over the last axis (kept as a size-1
            axis), and the un-squashed Gaussian mean.
        """
        hidden = F.relu(self.l2(F.relu(self.l1(obs))))
        mean = self.mean(hidden)
        # The head output is treated as a log-std: clamp, then exponentiate.
        log_std = self.std(hidden).clamp(-20,2)
        dist = Normal(mean,log_std.exp())
        pretanh = dist.rsample()
        action = torch.tanh(pretanh)
        # Change-of-variables correction for the tanh squashing.
        squash_correction = torch.log(1-action.pow(2) + 1e-6)
        log = (dist.log_prob(pretanh) - squash_correction).sum(dim=-1,keepdim=True)
        return action,log,mean
        
class q_network(nn.Module):
    """State-action value network: an MLP over the concatenated observation and action."""

    def __init__(self):
        super().__init__()
        input_units = hypers.obs_dim+hypers.action_dim
        self.l1 = nn.Linear(input_units,256)
        self.l2 = nn.Linear(256,256)
        self.l3 = nn.Linear(256,1)
        self.apply(weight_init)

    def forward(self,obs:Tensor,action:Tensor):
        """Return one value per (obs, action) pair, with a trailing axis of size 1."""
        hidden = torch.cat((obs,action),dim=-1)
        for layer in (self.l1,self.l2):
            hidden = F.relu(layer(hidden))
        return self.l3(hidden)

In [4]:
import numpy as np
from torch.utils.tensorboard import SummaryWriter

class buffer: 
    """Replay buffer that also steps the environment and logs rewards to TensorBoard.

    Storage is a set of pre-allocated tensors. ``self.pointer`` counts the
    transitions written so far; writes go to ``pointer % capacity`` so the
    oldest transitions are overwritten once the storage is full.
    """
    def _init_storage(self,data_path=None,capacity=hypers.max_steps): 
        """Allocate zero-filled storage, or restore it from a file written by ``save``.

        Args:
            data_path: optional path to a ``torch.save`` file holding the keys
                written by ``save``. Restored tensors keep the dtypes ``save``
                wrote (half precision / bool).
            capacity: number of transitions to allocate when starting empty.
        """
        obs_dim = (hypers.num_envs,hypers.obs_dim)     
        act_dim = (hypers.num_envs,hypers.action_dim) 
        if data_path is not None:
            # NOTE(security): weights_only=False unpickles arbitrary objects;
            # only load files you created yourself.
            self.data = torch.load(data_path,map_location=hypers.device,weights_only=False)
            self.stor_curr_states = self.data["curr_states"]
            self.stor_nx_states = self.data["nx_states"]
            self.stor_rewards = self.data["rewards"]
            self.stor_dones = self.data["dones"]
            self.stor_actions = self.data["actions"]
            self.pointer = self.data["pointer"]
        else:
            self.stor_curr_states = torch.zeros((capacity,*obs_dim))
            self.stor_nx_states = torch.zeros((capacity,*obs_dim))
            self.stor_rewards = torch.zeros((capacity,hypers.num_envs,))
            self.stor_dones = torch.zeros((capacity,hypers.num_envs,))
            self.stor_actions = torch.zeros((capacity,*act_dim))
            self.pointer = 0

    def __init__(self,env,policy):
        """Allocate storage, reset ``env`` and set up reward bookkeeping.

        Args:
            env: environment to interact with (expected to auto-reset).
            policy: actor called as ``policy(obs_tensor)`` once warm-up is over.
        """
        self._init_storage(data_path=None)
        self.env = env
        self.policy = policy
        self.obs = self.env.reset()[0]
        self.epi_reward = 0   # return of the most recently finished episode
        self.reward = 0       # running return of the current episode
        self.writter = SummaryWriter("./")
        self.steps = 0        # number of env.step() calls made
        # True when the previous step ended an episode, i.e. the next
        # env.step() call will be consumed by the Autoreset wrapper's reset.
        self._needs_reset = False
    
    def store(self,curr_state,nx_state,reward,done,action):
        """Write one transition into the slot addressed by ``self.pointer`` (wrapping at capacity)."""
        slot = self.pointer % self.stor_curr_states.shape[0]
        self.stor_curr_states[slot] = curr_state
        self.stor_nx_states[slot] = nx_state
        self.stor_rewards[slot] = reward
        self.stor_dones[slot] = done
        self.stor_actions[slot]= action

    @torch.no_grad()
    def step(self):
        """Take one environment step and store the resulting transition.

        Random actions are used while fewer than ``hypers.warmup`` transitions
        are stored; afterwards actions are sampled from the policy.
        """
        self.steps+=1
        # gymnasium>=1.0 Autoreset resets on the step AFTER an episode ends:
        # that call ignores the action and returns the reset observation, so
        # it is not a real transition and must not enter the replay buffer.
        resetting = getattr(self,"_needs_reset",False)
        if self.pointer<hypers.warmup:
            action = self.env.action_space.sample()
        else:
            action,_,_ = self.policy(tranform_observation(self.obs))
            action = action.squeeze()
        nx_state,reward,done,truncated,_ = self.env.step(action.tolist())
        self.writter.add_scalar("Main/rewards",reward,self.steps,new_style=True)
         
        self.reward += reward
        # An episode ends on termination OR on time-limit truncation. Only the
        # termination flag is stored, so truncated transitions still bootstrap.
        episode_over = bool(done) or bool(truncated)
        if episode_over:
            self.epi_reward = self.reward
            self.reward = 0

        if not resetting:
            saved_action = torch.from_numpy(action) if isinstance(action,np.ndarray) else action
            self.store(
                tranform_observation(self.obs),
                tranform_observation(nx_state),
                torch.tensor(reward),
                torch.tensor(done),
                saved_action
            )
            self.pointer+=1  
        self.obs = nx_state
        self._needs_reset = episode_over
  
    def sample(self,batch):
        """Draw ``batch`` stored transitions uniformly at random (with replacement).

        Returns:
            ``(states, next_states, rewards, dones, actions)`` as float tensors
            on ``hypers.device``; rewards and dones get a trailing size-1 axis.
        """
        filled = min(self.pointer,self.stor_curr_states.shape[0])
        idx = torch.randint(0,filled,(batch,))
        return (
            self.stor_curr_states[idx].float().to(device=hypers.device),
            self.stor_nx_states[idx].float().to(device=hypers.device),
            # Cast like the other fields: restored storage may be half precision.
            self.stor_rewards[idx].float().unsqueeze(-1).to(device=hypers.device),
            self.stor_dones[idx].float().unsqueeze(-1).to(device=hypers.device),
            self.stor_actions[idx].float().to(device=hypers.device)
        )
    
    def save(self):
        """Write the storage (half precision, bool done flags) and the pointer to ``./data.pth``."""
        data = {
            "curr_states":self.stor_curr_states.half(),
            "nx_states":self.stor_nx_states.half(),
            "rewards":self.stor_rewards.half(),
            "dones":self.stor_dones.bool(),
            "actions":self.stor_actions.half(),
            "pointer":self.pointer
        }
        torch.save(data,"./data.pth")

In [None]:
from copy import deepcopy
from tqdm import tqdm
from itertools import chain

class main:
    """Soft Actor-Critic trainer: owns the networks, optimizers, environment and replay buffer."""

    def __init__(self):
        """Build the actor, twin critics with target copies, the temperature parameter and the buffer."""
        self.actor = policy().to(hypers.device)
        self.q1 = q_network().to(hypers.device)
        self.q2 = q_network().to(hypers.device) 
        self.q1_target = deepcopy(self.q1).to(hypers.device)
        self.q2_target = deepcopy(self.q2).to(hypers.device)
        # A single optimizer steps both critics.
        self.critic_optim = Adam(chain(self.q1.parameters(),self.q2.parameters()),lr=hypers.lr)

        # The temperature alpha is learned in log space against this entropy target.
        self.entropy_target = -hypers.action_dim
        self.log_alpha = torch.tensor([0.0],requires_grad=True,device=hypers.device)
        self.alpha_optim = Adam([self.log_alpha],lr=hypers.lr)

        self.env = make_env()
        self.buffer = buffer(self.env,self.actor)
    
    def save(self,step):
        """Write all network, optimizer and temperature state to ``./{step}.pth``."""
        check = {
            "actor state":self.actor.state_dict(),
            "actor optim" : self.actor.optim.state_dict(),
            "q1 state":self.q1.state_dict(),
            "q1 target":self.q1_target.state_dict(),
            "q2 state":self.q2.state_dict(),
            "q2 target":self.q2_target.state_dict(),
            "critic optim":self.critic_optim.state_dict(),
            "alpha optim":self.alpha_optim.state_dict(),
            "log_alpha":self.log_alpha
        }
        torch.save(check,f"./{step}.pth")
    
    def load(self,model_path = None,strict=True):
        """Restore the state written by ``save``; does nothing when ``model_path`` is None.

        Args:
            model_path: checkpoint file to read.
            strict: forwarded to every ``load_state_dict`` call on a network.
        """
        if model_path is None:
            return
        print("--- Resuming training")
        check = torch.load(model_path,map_location=hypers.device)
        self.actor.load_state_dict(check["actor state"],strict)
        self.actor.optim.load_state_dict(check["actor optim"])
        self.q1.load_state_dict(check["q1 state"],strict)
        self.q1_target.load_state_dict(check["q1 target"],strict)
        self.q2.load_state_dict(check["q2 state"],strict)
        self.q2_target.load_state_dict(check["q2 target"],strict)
        self.critic_optim.load_state_dict(check["critic optim"])
        self.alpha_optim.load_state_dict(check["alpha optim"])
        # Copy INTO the existing tensor: alpha_optim holds a reference to this
        # exact object, so rebinding self.log_alpha would leave the optimizer
        # updating a tensor nobody reads.
        with torch.no_grad():
            self.log_alpha.copy_(check["log_alpha"])

    @staticmethod
    def _soft_update(online,target):
        """Move ``target`` parameters towards ``online``: target = (1 - tau) * target + tau * online."""
        for online_pars,target_pars in zip(online.parameters(),target.parameters()):
            target_pars.data.mul_(1.0 - hypers.tau).add_(online_pars.data,alpha=hypers.tau)
    
    def train(self,start=False,model_path=None):
        """Run the SAC training loop for ``hypers.max_steps - 1`` environment steps.

        Args:
            start: nothing happens unless this is True.
            model_path: optional checkpoint (written by ``save``) to resume from.
        """
        if not start:
            return
        self.load(model_path) 
        n = 0 
        # Detached: alpha is a constant in the critic target and the actor
        # loss; log_alpha is trained only through alpha_loss below.
        alpha = self.log_alpha.exp().detach()
        writer = self.buffer.writter
        for traj in tqdm(range(hypers.max_steps-1),total=hypers.max_steps-1):
            self.buffer.step()
            
            if traj%500==0:
                writer.add_scalar("Main/episodes rewards",self.buffer.epi_reward,traj,new_style=True)
            if self.buffer.pointer <= hypers.warmup:
                continue  # still collecting warm-up data

            states,nx_states,reward,dones,actions = self.buffer.sample(hypers.batch_size) 

            # ---- critic update ----
            q1 = self.q1(states,actions) 
            q2 = self.q2(states,actions) 
            with torch.no_grad():
                nx_actions,log_nx_actions,_ = self.actor(nx_states)
                q1_target = self.q1_target(nx_states,nx_actions) 
                q2_target = self.q2_target(nx_states,nx_actions) 
                min_q_target = torch.min(q1_target,q2_target) 
                # Soft Bellman backup:
                # r + gamma * (1 - done) * (min_i Q_target_i(s', a') - alpha * log pi(a'|s'))
                q_target = reward + hypers.gamma * (1-dones) * (min_q_target - alpha * log_nx_actions)
            critic_loss = F.mse_loss(q1,q_target) + F.mse_loss(q2,q_target)
            self.critic_optim.zero_grad()
            critic_loss.backward()
            self.critic_optim.step()

            # ---- actor update: minimise alpha * log pi(a|s) - min_i Q_i(s, a) ----
            new_action,log_pi,_ = self.actor(states)
            min_q = torch.min(self.q1(states,new_action),self.q2(states,new_action))
            policy_loss = ((alpha * log_pi) -  min_q).mean()
            
            self.actor.optim.zero_grad()
            policy_loss.backward()
            self.actor.optim.step()

            # ---- temperature update ----
            alpha_loss = -(self.log_alpha*(log_pi+self.entropy_target).detach()).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            alpha = self.log_alpha.exp().detach()

            # ---- target networks ----
            self._soft_update(self.q1,self.q1_target)
            self._soft_update(self.q2,self.q2_target)
                 
            # Periodic checkpoint of the networks and the replay data.
            if traj != 0 and traj%int(1e5) == 0: 
                n+=1
                self.save(n)
                self.buffer.save() 
            
            writer.add_scalar("Main/loss Policy",policy_loss,traj,new_style=True)
            writer.add_scalar("Main/entropy loss",alpha_loss,traj,new_style=True)
            writer.add_scalar("Main/alpha value",alpha,traj,new_style=True)
            writer.add_scalar("Main/critic Loss",critic_loss,traj,new_style=True)
            writer.add_scalar("Main/action variance",actions.var(),traj,new_style=True)
            writer.add_scalar("Main/policy loss action variance",new_action.var(),traj,new_style=True)
            writer.flush()
                
        self.save("final")
                
# Set to True to launch SAC training. While False nothing is constructed, so
# re-running the notebook does not build the environment, allocate the replay
# storage or create TensorBoard event files just to skip training.
RUN_TRAINING = False
if RUN_TRAINING:
    main().train(True)

In [None]:
class test:
    """Roll out a trained actor in a rendered FetchReach environment."""

    def __init__(self,model_path="./data/reach/reach_model.pth"):
        """Load the actor weights, then open the rendered environment.

        Args:
            model_path: checkpoint file containing an ``"actor state"`` entry.
        """
        # Load the model first so that a missing checkpoint does not leave a
        # render window open.
        self.model = policy().to(hypers.device)
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
        checkpoint = torch.load(model_path,map_location=hypers.device)
        self.model.load_state_dict(checkpoint["actor state"])
        self.model.eval()
        self.env = self.test_env()
        self.state = self.env.reset()[0]

    def test_env(self):
        """Same wrapper stack as training, with ``render_mode="human"``."""
        x = gym.make("FetchReachDense-v3",max_episode_steps=50,render_mode="human")
        x = FetchReachCustom(x)
        x = Autoreset(x)
        return x
    
    @torch.no_grad()
    def run(self,start=False,steps=1000):
        """Drive the environment with the deterministic policy action.

        Args:
            start: nothing happens unless this is True.
            steps: number of environment steps to roll out.
        """
        if not start:
            return
        for _ in range(steps):
            _,_,mean = self.model(tranform_observation(self.state))
            # The policy returns the un-squashed Gaussian mean; training
            # actions are tanh-squashed, so squash here as well. Convert to a
            # CPU numpy array before handing the action to the environment.
            action = torch.tanh(mean).cpu().numpy()
            self.state,_,_,_,_ = self.env.step(action)
            self.env.render()


tester = test()
try:
    tester.run(start=True)
finally:
    # Release the render window even if the rollout is interrupted.
    tester.env.close()