In [None]:
"""
!pip install robosuite
!pip install mujoco
!pip install h5py
!pip install gymnasium==1.2.0
!pip install tqdm
!pip install tensorboard
"""

In [None]:
import robosuite as suite
from robosuite import load_composite_controller_config
from robosuite.wrappers.gym_wrapper import GymWrapper
from gymnasium.wrappers import Autoreset
import torch,sys
from dataclasses import dataclass
from IPython.display import clear_output
import warnings,logging
# Silence all Python warnings and disable every logging record up to and
# including CRITICAL for the whole kernel. NOTE(review): this also hides
# genuine problems reported by robosuite/torch — relax it when debugging.
warnings.filterwarnings("ignore")
logging.disable(logging.CRITICAL)
# Wipe whatever the imports above printed into the cell output.
clear_output()

# Name of the robosuite task handed to suite.make() in make_env(); it is also
# embedded in checkpoint file names and "Lift" gets a different obs_dim below.
env_name = "Stack"

@dataclass(frozen=False)
class Hypers:
    """Hyperparameters and dimensions shared by every cell of the notebook.

    Each attribute is type-annotated so that ``@dataclass`` registers it as a
    real field: ``Hypers()`` still yields the defaults below, and individual
    values can now also be overridden at construction time, e.g.
    ``Hypers(obs_dim=122)``. Instances stay mutable (``frozen=False``) because
    later cells assign to ``hypers.obs_dim``.
    """
    # Robot model name passed to robosuite.
    ROBOT: str = "Panda"
    # Device used for the networks and for sampled batches.
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Size of the last dimension of the flattened observation.
    obs_dim: int = 148
    # Size of the action vector for a single environment.
    action_dim: int = 9
    # Number of transitions drawn from the replay buffer per update.
    batchsize: int = 256
    # Learning rate shared by the actor and critic Adam optimisers.
    lr: float = 3e-4
    # Discount factor applied to the bootstrapped target.
    gamma: float = .99
    # Entropy coefficient multiplying the policy log-probability.
    alpha: float = 0.2
    # Interpolation factor of the target-network soft update.
    tau: float = .005
    # Number of stored transitions collected with random actions before
    # the policy acts and before any gradient update happens.
    warmup: int = 10_000
    # Training loop length; also the default replay-buffer capacity.
    max_steps: int = int(10e6)
    
# Single shared hyperparameter object read by every later cell.
hypers = Hypers()

# Composite controller configuration for the selected robot.
cont_config = load_composite_controller_config(robot=hypers.ROBOT)

# Keyword arguments forwarded verbatim to suite.make() inside make_env().
env_configs = dict(
    robots=[hypers.ROBOT],
    controller_configs=cont_config,
    gripper_types=["JacoThreeFingerDexterousGripper"],
    # Headless simulation: no on-screen window, no camera observations and
    # no off-screen renderer.
    has_renderer=False,
    use_camera_obs=False,
    has_offscreen_renderer=False,
    # Dense (shaped) reward instead of the sparse one.
    reward_shaping=True,
    # Maximum number of steps before the episode is ended and reset.
    horizon=500,
    control_freq=20,
    reward_scale=1.0,
)
def make_env():
    """Create the robosuite task and wrap it for Gymnasium-style stepping.

    The task named by ``env_name`` is built with ``env_configs``, wrapped in
    ``GymWrapper`` using every key of its observation spec, given a replacement
    ``metadata`` dict and finally wrapped in ``Autoreset``.

    Returns:
        The ``Autoreset``-wrapped environment.
    """
    raw_env = suite.make(env_name=env_name, **env_configs)
    # Expose every entry of the observation spec through the wrapper.
    observation_keys = list(raw_env.observation_spec())
    gym_env = GymWrapper(raw_env, observation_keys)
    gym_env.metadata = {"render_mode":[]}
    return Autoreset(gym_env)

# The default obs_dim in Hypers is used with the "Stack" task; switch to the
# value for "Lift" before any network or buffer is created from it.
# NOTE(review): 122 is presumably the flattened observation length for Lift
# with this robot/gripper combination — verify against the wrapped env.
if env_name == "Lift":
    hypers.obs_dim = 122

In [None]:
from torch import Tensor
import torch.nn.functional as F
import torch.nn as nn
from torch.distributions import Normal
from torch.optim import Adam

class Actor(nn.Module):
    """Tanh-squashed Gaussian policy network that owns its Adam optimiser."""

    def __init__(self):
        super().__init__()
        hidden = 256
        self.l1 = nn.Linear(hypers.obs_dim, hidden)
        self.l2 = nn.Linear(hidden, hidden)
        # Two heads on the shared trunk: the Gaussian mean and the log of its
        # standard deviation (exponentiated in forward()).
        self.lmean = nn.Linear(hidden, hypers.action_dim)
        self.lstd = nn.Linear(hidden, hypers.action_dim)
        self.optim = Adam(self.parameters(), hypers.lr)

    def forward(self, obs: Tensor):
        """Sample a squashed action for ``obs``.

        Args:
            obs: observation tensor whose last dimension matches ``l1``.

        Returns:
            Tuple ``(action, log, mean)``: the tanh-squashed reparameterised
            sample, its log-probability summed over the last dimension
            (that dimension is kept with size 1), and the un-squashed mean.
        """
        features = F.relu(self.l2(F.relu(self.l1(obs))))
        mean = self.lmean(features)
        # Bound the log-std before exponentiating to keep the scale finite.
        log_std = torch.clamp(self.lstd(features), -20, 2)
        gaussian = Normal(mean, log_std.exp())

        pre_tanh = gaussian.rsample()
        action = pre_tanh.tanh()

        # Change of variables for the tanh squashing; 1e-6 guards log(0).
        squash_correction = torch.log(1 - action.pow(2) + 1e-6)
        log_prob = gaussian.log_prob(pre_tanh) - squash_correction
        return action, log_prob.sum(dim=-1, keepdim=True), mean
    
class Critic(nn.Module):
    """Q-network: maps an (observation, action) pair to a single value."""

    def __init__(self):
        super().__init__()
        hidden = 256
        # Observation and action are concatenated, so the first layer takes
        # the sum of both sizes.
        joint_dim = hypers.obs_dim + hypers.action_dim
        self.l1 = nn.Linear(joint_dim, hidden)
        self.l2 = nn.Linear(hidden, hidden)
        self.output = nn.Linear(hidden, 1)

    def forward(self, obs: Tensor, action: Tensor):
        """Return the value estimate for ``obs`` and ``action``.

        Both inputs are joined along their last dimension; the result has the
        same leading dimensions and a last dimension of size 1.
        """
        joint = torch.cat([obs, action], dim=-1)
        hidden = F.relu(self.l1(joint))
        hidden = F.relu(self.l2(hidden))
        return self.output(hidden)

In [None]:
import numpy as np

class buffer:
    """Replay buffer that also drives the environment interaction.

    Each ``step()`` call advances the environment once, using random actions
    during warm-up and the policy afterwards, and stores the resulting
    transition. ``sample()`` draws a uniform random batch for training.

    Storage is a ring buffer: ``pointer`` counts every transition stored so
    far and the write slot is ``pointer % capacity``, so a buffer that is
    resumed or smaller than the number of training steps overwrites its oldest
    entries instead of raising ``IndexError``.
    """

    def _init_storage(self,data_path=None,capacity=hypers.max_steps):
        """Allocate empty storage, or restore it from a file written by ``save()``.

        Args:
            data_path: optional path to a saved buffer; ``None`` allocates
                zero-filled tensors.
            capacity: number of transitions to allocate when not loading.
        """
        # Every stored observation/action keeps a singleton middle dimension.
        obs_dim = (1,hypers.obs_dim)
        act_dim = (1,hypers.action_dim)
        if data_path is not None:
            # NOTE(review): torch.load unpickles the file — only load buffers
            # from a trusted source.
            self.data = torch.load(data_path,map_location=hypers.device)
            self.stor_curr_states = self.data["curr_states"]
            self.stor_nx_states = self.data["nx_states"]
            self.stor_rewards = self.data["rewards"]
            self.stor_dones = self.data["dones"]
            self.stor_actions = self.data["actions"]
            self.pointer = self.data["pointer"]
        else:
            self.stor_curr_states = torch.zeros((capacity,*obs_dim))
            self.stor_nx_states = torch.zeros((capacity,*obs_dim))
            self.stor_rewards = torch.zeros((capacity,1,))
            self.stor_dones = torch.zeros((capacity,1,))
            self.stor_actions = torch.zeros((capacity,*act_dim))
            self.pointer = 0
        # Number of slots actually available, whichever branch created them.
        self.capacity = self.stor_curr_states.shape[0]

    def __init__(self,env,policy):
        """Bind the buffer to ``env`` and ``policy`` and reset the environment.

        Args:
            env: Gymnasium-style environment wrapped in ``Autoreset``.
            policy: callable returning ``(action, log_prob, mean)`` for an
                observation tensor.
        """
        self._init_storage(data_path=None)
        self.env = env
        self.policy = policy
        self.obs = self.env.reset()[0]
        self.epi_reward = 0   # return of the most recently finished episode
        self.reward = 0       # running return of the episode in progress
        # True when the previous env.step() ended an episode, i.e. the next
        # env.step() call will be turned into a reset by Autoreset.
        self._awaiting_reset = False

    def to_tensor(self,x):
        """Convert ``x`` to a float32 tensor on ``hypers.device``."""
        return torch.from_numpy(np.array(x)).to(hypers.device,dtype=torch.float32)

    def store(self,curr_state,nx_state,reward,done,action):
        """Write one transition into the current ring-buffer slot."""
        slot = self.pointer % self.capacity
        self.stor_curr_states[slot] = curr_state
        self.stor_nx_states[slot] = nx_state
        self.stor_rewards[slot] = reward
        self.stor_dones[slot] = done
        self.stor_actions[slot] = action

    @torch.no_grad()
    def step(self):
        """Advance the environment by one step and store the transition.

        Random actions are used while fewer than ``hypers.warmup`` transitions
        are stored; afterwards the policy chooses the action.

        With Gymnasium 1.x ``Autoreset`` the call that follows the end of an
        episode does not simulate anything: it resets the environment, ignores
        the action and returns the first observation of the new episode with
        reward 0. That call is used only to refresh ``self.obs`` and is not
        stored, otherwise the buffer would contain a fake transition joining
        two different episodes.
        """
        if self._awaiting_reset:
            # The action is ignored by Autoreset on this call.
            self.obs = self.env.step(self.env.action_space.sample().tolist())[0]
            self._awaiting_reset = False
            return

        if self.pointer<hypers.warmup:
            action = self.env.action_space.sample()
        else:
            action,_,_ = self.policy(self.to_tensor(self.obs))
            action = action.squeeze()
        nx_state,reward,terminated,truncated,_ = self.env.step(action.tolist())
        episode_over = bool(terminated or truncated)

        self.reward+=reward
        if episode_over:
            self.epi_reward=self.reward
            self.reward=0

        # numpy actions (warm-up) and tensor actions (policy) both end up as
        # float32 tensors.
        saved_action = torch.as_tensor(action,dtype=torch.float32)

        # NOTE(review): only `terminated` is stored as the done flag. If the
        # wrapped env signals its step limit through that flag, time-limit
        # endings are treated as true terminal states — confirm intended.
        self.store(
            self.to_tensor(self.obs),
            self.to_tensor(nx_state),
            self.to_tensor(reward),
            self.to_tensor(terminated),
            saved_action
        )
        self.obs = nx_state
        self.pointer+=1
        self._awaiting_reset = episode_over

    def sample(self,batch):
        """Draw ``batch`` transitions uniformly (with replacement).

        Returns:
            Tuple ``(states, next_states, rewards, dones, actions)`` of float
            tensors on ``hypers.device``; rewards and dones get an extra
            trailing dimension so they line up with the stored observations.
        """
        # Only slots that have been written at least once are eligible.
        filled = min(self.pointer,self.capacity)
        idx = torch.randint(0,filled,(batch,))
        return (
            self.stor_curr_states[idx].float().to(device=hypers.device),
            self.stor_nx_states[idx].float().to(device=hypers.device),
            # .float() matters after reloading a buffer saved in half precision.
            self.stor_rewards[idx].float().unsqueeze(-1).to(device=hypers.device),
            self.stor_dones[idx].float().unsqueeze(-1).to(device=hypers.device),
            self.stor_actions[idx].float().to(device=hypers.device)
        )

    def save(self):
        """Write the whole buffer to ``./data.pth`` in a compact format.

        Floating-point tensors are stored in half precision and the done flags
        as booleans; ``pointer`` is saved so collection can resume.
        """
        data = {
            "curr_states":self.stor_curr_states.half(),
            "nx_states":self.stor_nx_states.half(),
            "rewards":self.stor_rewards.half(),
            "dones":self.stor_dones.bool(),
            "actions":self.stor_actions.half(),
            "pointer":self.pointer
        }
        torch.save(data,"./data.pth")

In [None]:
from copy import deepcopy
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from itertools import chain

class main:
    """Soft Actor-Critic trainer: owns the networks, the environment, the
    replay buffer and the TensorBoard writer."""

    def __init__(self):
        device = hypers.device
        self.actor = Actor().to(device)
        self.q1 = Critic().to(device)
        self.q2 = Critic().to(device)
        # Target critics start as exact copies of the online critics.
        self.q1_target = deepcopy(self.q1).to(device)
        self.q2_target = deepcopy(self.q2).to(device)
        # One optimiser updates both critics together.
        critic_params = chain(self.q1.parameters(),self.q2.parameters())
        self.critic_optim = Adam(critic_params,lr=hypers.lr)
        self.env = make_env()
        self.buffer = buffer(self.env,self.actor)
        self.writter = SummaryWriter("./")

    @staticmethod
    def _soft_update(online,target):
        """Move ``target`` parameters a fraction ``hypers.tau`` towards ``online``."""
        for src,dst in zip(online.parameters(),target.parameters()):
            dst.data.mul_(1.0 - hypers.tau).add_(src.data,alpha=hypers.tau)

    def save(self,step):
        """Write every network and optimiser state to
        ``./{env_name}_model_{step}.pth``."""
        holders = {
            "actor state":self.actor,
            "actor optim":self.actor.optim,
            "q1 state":self.q1,
            "q1 target":self.q1_target,
            "q2 state":self.q2,
            "q2 target":self.q2_target,
            "critic optim":self.critic_optim,
        }
        check = {key:holder.state_dict() for key,holder in holders.items()}
        torch.save(check,f"./{env_name}_model_{step}.pth")

    def load(self,model_path = None,strict=True):
        """Restore networks and optimisers from a checkpoint written by ``save``.

        Args:
            model_path: checkpoint path; ``None`` makes the call a no-op.
            strict: forwarded to ``load_state_dict`` of the networks.
        """
        if model_path is None:
            return
        check = torch.load(model_path,map_location=hypers.device)
        # (checkpoint key, object to restore, whether it accepts `strict`)
        restore_plan = (
            ("actor state",self.actor,True),
            ("actor optim",self.actor.optim,False),
            ("q1 state",self.q1,True),
            ("q1 target",self.q1_target,True),
            ("q2 state",self.q2,True),
            ("q2 target",self.q2_target,True),
            ("critic optim",self.critic_optim,False),
        )
        for key,holder,is_network in restore_plan:
            if is_network:
                holder.load_state_dict(check[key],strict)
            else:
                holder.load_state_dict(check[key])

    def train(self,start=False):
        """Run the collect/update loop for ``hypers.max_steps - 1`` iterations.

        Does nothing unless ``start`` is true. Every iteration calls
        ``buffer.step()`` once; once the buffer holds more than
        ``hypers.warmup`` transitions each iteration also performs one critic
        update, one actor update and one soft update of the target critics.
        """
        if not start:
            return
        self.load(model_path=None)
        n = 0 # index of the next saved checkpoint
        total_iters = hypers.max_steps-1
        log = self.writter.add_scalar
        for traj in tqdm(range(total_iters),total=total_iters):
            self.buffer.step()
            if traj%500==0:
                log("Main/episodes rewards",self.buffer.epi_reward,traj,new_style=True)
            if self.buffer.pointer <= hypers.warmup:
                continue

            states,nx_states,reward,dones,actions = self.buffer.sample(hypers.batchsize)

            # ---- critic update -------------------------------------------
            q1_pred = self.q1(states,actions)
            q2_pred = self.q2(states,actions)
            with torch.no_grad():
                nx_actions,nx_log_pi,_ = self.actor(nx_states)
                q1_next = self.q1_target(nx_states,nx_actions)
                q2_next = self.q2_target(nx_states,nx_actions)
                # Soft Bellman backup:
                # r + gamma * (1 - done) * (min Q_target(s', a') - alpha * log pi(a'|s'))
                soft_value = torch.min(q1_next,q2_next) - hypers.alpha * nx_log_pi
                q_target = reward + hypers.gamma * (1-dones) * soft_value
            critic_loss = F.mse_loss(q1_pred,q_target) + F.mse_loss(q2_pred,q_target)
            self.critic_optim.zero_grad()
            critic_loss.backward()
            self.critic_optim.step()

            # ---- actor update: minimise alpha * log pi(a|s) - min Q(s, a) --
            new_action,log_pi,_ = self.actor(states)
            min_q = torch.min(self.q1(states,new_action),self.q2(states,new_action))
            policy_loss = ((hypers.alpha * log_pi) - min_q).mean()
            self.actor.optim.zero_grad()
            policy_loss.backward()
            self.actor.optim.step()

            # ---- target networks -----------------------------------------
            self._soft_update(self.q1,self.q1_target)
            self._soft_update(self.q2,self.q2_target)

            # Periodic checkpoint of the models and of the replay buffer.
            if traj != 0 and traj%int(1e5) == 0:
                n+=1
                self.save(n)
                self.buffer.save()

            log("Main/loss Policy",policy_loss,traj,new_style=True)
            log("Main/critic Loss",critic_loss,traj,new_style=True)
            log("Main/action variance",actions.var(),traj,new_style=True)
            log("Main/policy loss action variance",new_action.var(),traj,new_style=True)
            self.writter.flush()
                
# Build the trainer and start training right away; train() does nothing
# unless it is called with start=True.
main().train(True)