In [None]:
#!pip install gymnasium-robotics==1.3.1 
#!pip install gymnasium

In [None]:
import gymnasium as gym
import gymnasium_robotics
gym.register_envs(gymnasium_robotics)
import numpy as np
from dataclasses import dataclass

class custom(gym.Wrapper):
    """Wrapper that overrides a few simulator joint positions on every reset.

    After the wrapped environment resets, entries 0, 1 and 17 of the
    underlying ``data.qpos`` array are overwritten with fixed values. The
    observation returned from ``reset`` is then rebuilt so that it matches
    the overridden state instead of the state the base environment sampled.
    """

    def __init__(self,env):
        super().__init__(env)

    def reset(self,**kwargs):
        """Reset the wrapped env, apply the fixed ``qpos`` overrides, return ``(obs, info)``.

        Args:
            **kwargs: forwarded unchanged to the wrapped environment's ``reset``.

        Returns:
            Tuple ``(obs, info)``. ``info`` is whatever the wrapped env returned.
            ``obs`` is recomputed from the modified state when the base
            environment exposes the hooks needed to do so; otherwise the
            original (pre-override) observation is returned as before.
        """
        obs,info = super().reset(**kwargs)
        base = self.env.unwrapped
        base.data.qpos[0] = .3 # robot base x pos
        base.data.qpos[1] = .5 # robot base y pos
        # base.data.qpos[15]   # block's x pos
        # base.data.qpos[16]   # block's y pos
        base.data.qpos[17] = .4 # presumably the block's z pos (follows x/y at 15/16) - TODO confirm

        # The observation above was built before qpos was edited, so it is stale.
        # Writing qpos alone also leaves derived simulator quantities outdated,
        # so run the forward pass first and only then rebuild the observation.
        # NOTE(review): assumes the base env exposes `_mujoco`, `model` and
        # `_get_obs` (as Gymnasium-Robotics MuJoCo robot envs appear to) and
        # that no observation-transforming wrapper sits below this one - confirm.
        mujoco_module = getattr(base, "_mujoco", None)
        model = getattr(base, "model", None)
        rebuild_obs = getattr(base, "_get_obs", None)
        if mujoco_module is not None and model is not None and callable(rebuild_obs):
            mujoco_module.mj_forward(model, base.data)
            obs = rebuild_obs()
        return obs,info

    def step(self,action):
        """Forward ``action`` to the wrapped environment unchanged."""
        return super().step(action)

def process_obs(obs:dict):
    """Flatten a goal-style observation dict into a single 1-D array.

    The result is the first 9 entries of ``obs["observation"]`` followed by
    the flattened ``achieved_goal`` and ``desired_goal`` entries, in that order.

    Args:
        obs: mapping with ``"observation"``, ``"achieved_goal"`` and
            ``"desired_goal"`` entries.

    Returns:
        1-D numpy array holding the concatenated values.
    """
    head = np.ravel(obs.get("observation")[:9])
    goal_pair = (obs.get("achieved_goal"), obs.get("desired_goal"))
    tail = np.ravel(goal_pair)
    return np.concatenate((head, tail))

@dataclass()
class Hypers:
    """Hyperparameters shared by the networks defined in this notebook.

    Every attribute carries a type annotation: ``@dataclass`` only registers
    annotated attributes as fields, so without them the values cannot be
    overridden at construction time (``Hypers(lr=1e-3)``) and do not appear
    in the generated ``repr``/``==``.
    """
    # Learning rate handed to each network's Adam optimizer.
    lr: float = 3e-4
    # Size of the action vector (policy output width / Q-network action input width).
    action_dim: int = 4
    # Size of the flat observation vector fed to the networks; it has to match
    # the length of the array that process_obs builds.
    obs_dim: int = 15

# Module-level default instance read by the network constructors.
hypers = Hypers()

In [None]:
import torch
from torch import Tensor
import torch.nn as nn
from torch.distributions import Normal
from torch.optim import Adam
import torch.nn.functional as F

class policy(nn.Module):
    """Tanh-squashed Gaussian policy network with its own Adam optimizer.

    Two 256-unit ReLU layers feed two linear heads: one for the mean and one
    for the log standard deviation of a diagonal Normal distribution. Layer
    sizes and the learning rate are read from the module-level ``hypers``.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(hypers.obs_dim,256)
        self.l2 = nn.Linear(256,256)
        self.mean = nn.Linear(256,hypers.action_dim)
        self.std = nn.Linear(256,hypers.action_dim)
        self.optim = Adam(self.parameters(),lr=hypers.lr)
        # Initialisation happens in place, so the optimizer created above
        # still refers to the same parameter tensors.
        self.apply(self.weights_init)

    def forward(self,obs:Tensor):
        """Sample a squashed action and its log-probability for ``obs``.

        Args:
            obs: observation tensor whose last dimension is ``hypers.obs_dim``.

        Returns:
            Tuple ``(action, log, mean)``:
            ``action`` is the reparameterised sample passed through ``tanh``;
            ``log`` is the log-probability of ``action`` summed over the
            action dimension (that dimension is kept with size 1);
            ``mean`` is the Gaussian mean before the ``tanh`` squash.
        """
        x = F.relu(self.l1(obs))
        x = F.relu(self.l2(x))
        mean = self.mean(x)
        # The "std" head is treated as a log-std: clamp it to [-20, 2] and
        # exponentiate so the resulting scale is strictly positive and bounded.
        std = self.std(x).clamp(-20,2).exp()
        dist = Normal(mean,std)
        pretanh = dist.rsample()
        action = torch.tanh(pretanh)

        # Change-of-variables correction for the tanh squash; the 1e-6 keeps
        # the logarithm finite when |action| approaches 1.
        log = dist.log_prob(pretanh)
        log = log - torch.log(1-action.pow(2) + 1e-6)
        # Sum over the action (last) dimension; identical to dim=1 for the
        # usual (batch, action_dim) input.
        log = log.sum(-1,keepdim=True)
        return action,log,mean

    def weights_init(self,layer):
        """Orthogonal weights and zero biases for every ``nn.Linear`` layer."""
        if isinstance(layer,nn.Linear):
            torch.nn.init.orthogonal_(layer.weight)
            torch.nn.init.constant_(layer.bias,0.0)
        
class Qnetwork(nn.Module):
    """State-action value network with its own Adam optimizer.

    The observation and action are concatenated and passed through two
    256-unit ReLU layers followed by a linear layer producing one scalar.
    Layer sizes and the learning rate are read from the module-level ``hypers``.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(hypers.obs_dim+hypers.action_dim,256)
        self.l2 = nn.Linear(256,256)
        self.l3 = nn.Linear(256,1)
        self.optim = Adam(self.parameters(),lr=hypers.lr)
        # Initialisation happens in place, so the optimizer created above
        # still refers to the same parameter tensors.
        self.apply(self.weights_init)

    def forward(self,obs:Tensor,action:Tensor):
        """Return the value estimate for each ``(obs, action)`` pair.

        Args:
            obs: tensor whose last dimension is ``hypers.obs_dim``.
            action: tensor whose last dimension is ``hypers.action_dim``;
                its leading dimensions must match those of ``obs``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of 1.
        """
        # torch.cat takes a sequence of tensors, not separate positional tensors.
        x = torch.cat((obs,action),dim=-1)
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        # No activation on the output: a value estimate must be able to go
        # negative, which a trailing ReLU would make impossible.
        return self.l3(x)

    def weights_init(self,layer):
        """Orthogonal weights and zero biases for every ``nn.Linear`` layer."""
        if isinstance(layer,nn.Linear):
            torch.nn.init.orthogonal_(layer.weight)
            torch.nn.init.constant_(layer.bias,0.0)