In [4]:
import torch
import torch.nn as nn
import numpy as np

In [40]:
class MLPCritic(nn.Module):
    def __init__(self, obs_dim, act_dim, hidden_sizes = [256, 256]):
        super(MLPCritic, self).__init__()

        self.obs_dim = obs_dim
        self.act_dim = act_dim
        self.layer_sizes = [obs_dim+act_dim] + hidden_sizes + [1]

        self.layers = nn.ModuleList()
        # Hidden layers
        for h_i in range(len(self.layer_sizes)-2):
            self.layers += [nn.Linear(self.layer_sizes[h_i], self.layer_sizes[h_i+1]),
                            nn.Sigmoid()]
        # Output layer
        self.layers += [nn.Linear(self.layer_sizes[-2], self.layer_sizes[-1]),
                        nn.ReLU()]

    def forward(self, obs, act):
        x = torch.cat([obs, act], dim=-1)
        hidden_output = []
        hid_activation = []
        # Hidden layers
        for h_i in range(len(self.layers)-1):
            x = self.layers[h_i](x)
            # Store activation
            if h_i % 2 == 1:
                print(h_i)
                hid_activation.append(x)
        return x, hid_activation

    
class MLPActor(nn.Module):
    def __init__(self, obs_dim, act_dim, act_limit, hidden_sizes=[256, 256]):
        super(MLPActor, self).__init__()
        self.obs_dim = obs_dim
        self.act_dim = act_dim
        self.act_limit = act_limit
        self.layer_sizes = [obs_dim] + hidden_sizes + [act_dim]
        
        self.layers = nn.ModuleList()
        # Hidden layers
        for h_i in range(len(self.layer_sizes)-2):
            self.layers += [nn.Linear(self.layer_sizes[h_i], self.layer_sizes[h_i+1]),
                            nn.Sigmoid()]
        # Output layer
        self.layers += [nn.Linear(self.layer_sizes[-2], self.layer_sizes[-1]),
                        nn.Tanh()]

    def forward(self, obs):
        x = obs
        hidden_output = []
        hid_activation = []
        # Hidden layers
        for h_i in range(len(self.layers)-1):
            x = self.layers[h_i](x)
            # Store activation
            if h_i % 2 == 1:
                print(h_i)
                hid_activation.append(x)
        return self.act_limit * x, hid_activation
    
    
class MLPActorCritic(nn.Module):
    def __init__(self, obs_dim, act_dim, act_limit, critic_hidden_sizes=[256, 256], actor_hidden_sizes=[256, 256]):
        super(MLPActorCritic, self).__init__()
        self.q1 = MLPCritic(obs_dim, act_dim, critic_hidden_sizes)
        self.q2 = MLPCritic(obs_dim, act_dim, critic_hidden_sizes)
        self.pi = MLPActor(obs_dim, act_dim, act_limit, actor_hidden_sizes)
        
    def act(self, obs):
        with torch.no_grad():
            a, _ = self.pi(obs)
            return a.cpu().numpy()

In [41]:
obs_dim = 15
act_dim = 5
critic = MLPCritic(obs_dim, act_dim)
actor = MLPActor(obs_dim, act_dim)

In [43]:
m = 100
obs = torch.as_tensor(np.random.rand(m, obs_dim), dtype=torch.float32)
act = torch.as_tensor(np.random.rand(m, act_dim), dtype=torch.float32)
q, q_hid_activation = critic(obs, act)
a, a_hid_activation = actor(obs)

1
3
1
3


torch.Size([100, 256])

In [27]:
hid_a = hid_activation[0]
hid_a.mean(axis=0)

tensor([0.6067, 0.5047, 0.4744, 0.4898, 0.5349, 0.5409, 0.4726, 0.5370, 0.4718,
        0.3470, 0.3877, 0.5362, 0.3976, 0.5017, 0.3958, 0.5525, 0.5401, 0.3854,
        0.5290, 0.5130, 0.4945, 0.4786, 0.5506, 0.4881, 0.5183, 0.3692, 0.6020,
        0.3939, 0.5712, 0.5318, 0.6325, 0.5938, 0.4832, 0.5813, 0.3648, 0.4099,
        0.5243, 0.5920, 0.5059, 0.5398, 0.3487, 0.5012, 0.5444, 0.4800, 0.6469,
        0.4534, 0.4393, 0.5691, 0.4726, 0.5828, 0.4657, 0.5383, 0.4352, 0.3629,
        0.5125, 0.6026, 0.4704, 0.5292, 0.4727, 0.4528, 0.4770, 0.5008, 0.4778,
        0.5587, 0.5814, 0.5431, 0.4608, 0.6139, 0.6130, 0.4333, 0.4183, 0.3449,
        0.6047, 0.3676, 0.5674, 0.5013, 0.5519, 0.5632, 0.4392, 0.3591, 0.4591,
        0.5530, 0.6337, 0.3952, 0.4739, 0.4734, 0.4242, 0.4878, 0.5078, 0.6197,
        0.4293, 0.6240, 0.4365, 0.4395, 0.5399, 0.5173, 0.6532, 0.5311, 0.3310,
        0.5044, 0.5077, 0.3787, 0.4826, 0.5108, 0.5040, 0.4619, 0.4056, 0.5470,
        0.5123, 0.4193, 0.4274, 0.5585, 

In [None]:
rho = torch.as_tensor(0.05, dtype=torch.float32)
beta = 0.5
sparsity_penalty = torch.nn.functional.kl_div(torch.cat(hid_activation, dim=1).mean(axis=0), rho, reduction='sum')

tensor(-89.3844, grad_fn=<KlDivBackward>)