In [59]:
import torch
import torch.nn as nn
from torch.distributions import Normal

def mlp(input_size, hidden_sizes=(64, 64), activation="tanh"):
    """Build a fully connected feature trunk (no output head).

    Each hidden layer is an ``nn.Linear`` followed by the chosen activation,
    so the returned module maps ``input_size`` features to
    ``hidden_sizes[-1]`` features. With an empty ``hidden_sizes`` the result
    is an empty ``nn.Sequential`` (identity).

    Args:
        input_size: Number of input features.
        hidden_sizes: Width of each hidden layer; any sequence of ints
            (tuple or list).
        activation: One of ``"tanh"``, ``"relu"`` or ``"sigmoid"``.

    Returns:
        nn.Sequential alternating Linear and activation modules.

    Raises:
        NotImplementedError: If ``activation`` is not a supported name.
    """
    activations = {"tanh": nn.Tanh, "relu": nn.ReLU, "sigmoid": nn.Sigmoid}
    if activation not in activations:
        raise NotImplementedError(f"Activation {activation} is not supported")
    act_cls = activations[activation]

    # tuple() so that a list of sizes is accepted as well as a tuple.
    sizes = (input_size,) + tuple(hidden_sizes)
    layers = []
    for in_features, out_features in zip(sizes[:-1], sizes[1:]):
        # A fresh activation instance per layer keeps every module in the
        # Sequential distinct.
        layers += [nn.Linear(in_features, out_features), act_cls()]
    return nn.Sequential(*layers)

class GaussianPolicy(nn.Module):
    """Diagonal Gaussian policy with state-dependent mean and log-std.

    A shared MLP trunk (built by ``mlp``) feeds two linear heads: one for
    the action mean and one for the per-dimension log standard deviation.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes=(64, 64), activation="tanh", device="cpu"):
        """
        Args:
            obs_dim: Number of observation features fed to the trunk.
            act_dim: Number of action dimensions.
            hidden_sizes: Hidden layer widths of the shared trunk. An empty
                sequence gives a purely linear policy on the raw observation.
            activation: Activation name passed through to ``mlp``.
            device: Device the parameters are moved to and that
                ``select_action`` sends observations to.
        """
        super().__init__()

        self.obs_dim = obs_dim
        self.act_dim = act_dim
        self.device = device

        hidden_sizes = tuple(hidden_sizes)
        # Width of the trunk output; with no hidden layers the trunk is the identity.
        feat_dim = hidden_sizes[-1] if hidden_sizes else obs_dim

        self.mlp_net = mlp(obs_dim, hidden_sizes, activation)
        self.mean_layer = nn.Linear(feat_dim, act_dim)
        self.logstd_layer = nn.Linear(feat_dim, act_dim)

        # Shrink the mean head's weights and zero its bias so initial action
        # means start close to zero.
        with torch.no_grad():
            self.mean_layer.weight.mul_(0.1)
            self.mean_layer.bias.zero_()
        self.to(device)

    def forward(self, obs):
        """Compute the Gaussian parameters for a batch of observations.

        Args:
            obs: Tensor of shape ``(obs_dim,)`` or ``(batch, obs_dim)``. A
                single unbatched observation is treated as a batch of one.

        Returns:
            Tuple ``(mean, logstd, std)``; for 1-D or 2-D ``obs`` each has
            shape ``(batch, act_dim)``.
        """
        if obs.dim() == 1:
            # Add the batch dimension up front so mean, logstd and std all
            # come out with the same shape.
            obs = obs.unsqueeze(0)
        out = self.mlp_net(obs)
        mean = self.mean_layer(out)
        logstd = self.logstd_layer(out)
        # NOTE(review): logstd is unbounded, so exp() can overflow/underflow
        # for extreme trunk outputs; consider clamping if training diverges.
        std = torch.exp(logstd)
        return mean, logstd, std

    def get_act(self, obs, deterministic=False):
        """Return the mean action, or a sample from N(mean, std).

        Args:
            obs: Observation tensor accepted by ``forward``.
            deterministic: If True, return the mean instead of sampling.
        """
        mean, _, std = self.forward(obs)
        if deterministic:
            return mean
        return torch.normal(mean, std)

    # Agent interface
    def select_action(self, obs):
        """Sample a single action for one environment observation.

        Args:
            obs: Array-like observation convertible by ``torch.FloatTensor``.

        Returns:
            NumPy array holding the first row of the sampled action batch.
        """
        x = torch.FloatTensor(obs).to(self.device)
        # Acting needs no autograd graph.
        with torch.no_grad():
            out = self.get_act(x)
        return out.detach().cpu().numpy()[0]

    def logprob(self, obs, act):
        """Log-density of ``act`` under the policy at ``obs``.

        Returns:
            Tuple ``(logp, mean, std)``. ``logp`` is the per-dimension
            log-density summed over the last axis, keeping that axis as
            size 1.
        """
        mean, _, std = self.forward(obs)
        normal = Normal(mean, std)
        return normal.log_prob(act).sum(-1, keepdim=True), mean, std

In [73]:
import gym
import safety_gym
import numpy as np

# Safety Gym task used by the rest of the notebook.
env_name = 'Safexp-PointGoal1-v0'
max_ep_length = 1000  # presumably the per-episode step cap for later rollouts — verify

env = gym.make(env_name)
# Later cells use index [0] of these shapes as the observation / action dimension.
obs_shape, act_shape = env.observation_space.shape, env.action_space.shape



In [79]:
# Smoke test: build a policy, sample one action for the first observation,
# and score it. Seeding torch (weight init + sampling) and the env makes the
# printed values reproducible on re-run.
SEED = 0
torch.manual_seed(SEED)
env.seed(SEED)  # legacy gym API, matching the reset() -> obs usage below

policy = GaussianPolicy(obs_shape[0], act_shape[0])
obs = torch.FloatTensor(env.reset())
act = policy.get_act(obs)
lp, _, _ = policy.logprob(obs, act)

assert act.shape[-1] == act_shape[0], "action dimension mismatch"
assert lp.shape[-1] == 1, "log-prob should be summed over action dims"
print(act, lp)

tensor([[1.2676, 0.9684]], grad_fn=<NormalBackward3>) tensor([[-3.3132]], grad_fn=<SumBackward1>)
