In [None]:
import torch
import torch.nn as nn
from torch.distributions import TanhTransform
import torch.nn.functional as F

def initialize_weights(m):
    """Kaiming-uniform init for conv/linear weights and zero init for biases.

    Intended to be passed to ``nn.Module.apply``; modules of any other type
    are left untouched.

    Args:
        m: a module visited by ``apply``.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.kaiming_uniform_(m.weight, nonlinearity="relu")
    elif isinstance(m, nn.Linear):
        # Default nonlinearity ("leaky_relu" with a=0) yields the same gain as "relu".
        nn.init.kaiming_uniform_(m.weight)
    else:
        return
    # Layers built with bias=False have `bias is None`; skip them instead of crashing.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0)

def build_network(input_size, hidden_size, num_layers, activation, output_size):
    """Build an MLP of ``num_layers`` Linear layers with activations in between.

    Args:
        input_size: input features of the first Linear layer.
        hidden_size: width of every hidden layer.
        num_layers: total number of Linear layers (must be >= 2). An activation
            follows every Linear layer except the last one.
        activation: name of a ``torch.nn`` module class (e.g. ``"ReLU"``);
            it is instantiated with no arguments.
        output_size: output features of the last Linear layer.

    Returns:
        An ``nn.Sequential`` whose modules have been passed through
        ``initialize_weights``.

    Raises:
        AssertionError: if ``num_layers < 2``.
        AttributeError: if ``activation`` is not a name in ``torch.nn``.
    """
    assert num_layers >= 2, "num_layers must be at least 2"
    activation_cls = getattr(nn, activation)

    # A fresh activation instance per layer: reusing one module object would make
    # parametric activations (e.g. "PReLU") share their weights across layers.
    layers = [nn.Linear(input_size, hidden_size), activation_cls()]
    for _ in range(num_layers - 2):
        layers += [nn.Linear(hidden_size, hidden_size), activation_cls()]
    layers.append(nn.Linear(hidden_size, output_size))

    network = nn.Sequential(*layers)
    network.apply(initialize_weights)
    return network

def create_normal_dist(
    x,
    std=None,
    mean_scale=1,
    init_std=0,
    min_std=0.1,
    activation=None,
    event_shape=None,
):
    """Build a ``Normal`` distribution from network outputs.

    Two modes:

    * ``std is None``: ``x`` is split in half along its last dimension into a
      mean and a raw std. The mean becomes
      ``mean_scale * activation(mean / mean_scale)`` (the mean is left as-is
      when ``activation`` is falsy). The std becomes
      ``softplus(raw_std + init_std) + min_std``, so it is bounded below by
      ``min_std``.
    * otherwise: ``x`` is the mean and ``std`` the scale, used unchanged;
      ``mean_scale``, ``init_std``, ``min_std`` and ``activation`` are ignored.

    Args:
        x: network output (mean and raw std concatenated on the last dim when
            ``std`` is None, otherwise the mean).
        std: optional explicit standard deviation.
        mean_scale: bound applied around a squashing ``activation``.
        init_std: offset added to the raw std before the softplus.
        min_std: floor added to the softplus output.
        activation: optional callable applied to the rescaled mean.
        event_shape: despite its name, an integer count of rightmost batch dims
            that ``torch.distributions.Independent`` reinterprets as event dims.
            A falsy value (None or 0) returns the plain ``Normal``.

    Returns:
        ``Normal`` or ``Independent(Normal, event_shape)``.
    """
    if std is None:
        mean, std = torch.chunk(x, 2, -1)
        # Divide, squash, rescale: with tanh this keeps the mean in (-mean_scale, mean_scale).
        mean = mean / mean_scale
        if activation:
            mean = activation(mean)
        mean = mean_scale * mean
        std = F.softplus(std + init_std) + min_std
    else:
        mean = x
    dist = torch.distributions.Normal(mean, std)
    if event_shape:
        dist = torch.distributions.Independent(dist, event_shape)
    return dist

class Actor(nn.Module):
    """Continuous-action policy head producing a tanh-squashed Gaussian.

    An MLP maps ``torch.cat((posterior, deterministic), -1)`` to a mean and a
    raw standard deviation for each of the ``actionSize`` action dimensions.
    """

    def __init__(self, inputSize, actionSize):
        """
        Args:
            inputSize: size of the last dimension of
                ``torch.cat((posterior, deterministic), -1)``.
            actionSize: number of continuous action dimensions.
        """
        super().__init__()

        # Two outputs per action dimension (mean and raw std), split apart again
        # by `create_normal_dist`. Two hidden layers of width 256, matching the
        # former `sequentialModel1D(inputSize, [256, 256], ...)` call, which is
        # not defined anywhere in this notebook.
        # NOTE(review): ReLU hidden activation is assumed — confirm it matches
        # what sequentialModel1D used.
        self.network = build_network(inputSize, 256, 3, "ReLU", 2 * actionSize)

    def forward(self, posterior, deterministic):
        """Sample an action with the reparameterization trick.

        Returns:
            action: tanh-squashed sample with entries in [-1, 1] and last
                dimension ``actionSize``.
            log_prob: log-density of ``action`` under the squashed distribution,
                summed over the action dimension.
            entropy: entropy of the *pre-tanh* Gaussian, summed over the action
                dimension (the squashed distribution has no closed-form entropy).
        """
        x = torch.cat((posterior, deterministic), -1)
        x = self.network(x)
        base_dist = create_normal_dist(
            x,
            mean_scale=5,
            init_std=5,
            min_std=0.0001,
            activation=torch.tanh,
        )
        entropy = base_dist.entropy()
        # cache_size=1 lets log_prob reuse the pre-tanh sample from rsample()
        # instead of recomputing atanh(action), which is +/-inf once tanh
        # saturates to exactly +/-1 in float32.
        squashed = torch.distributions.TransformedDistribution(
            base_dist, TanhTransform(cache_size=1)
        )
        policy = torch.distributions.Independent(squashed, 1)
        action = policy.rsample()
        return action, policy.log_prob(action), entropy.sum(-1)

In [12]:
def attrdict_monkeypatch_fix():
    """Re-expose every public ``collections.abc`` name on ``collections``.

    Mutates the ``collections`` module globally, so call it before importing
    code that still looks these ABCs up on ``collections``.
    """
    import collections
    import collections.abc as abc_module

    for abc_name in abc_module.__all__:
        abc_obj = getattr(abc_module, abc_name)
        setattr(collections, abc_name, abc_obj)
# Must run before importing `attrdict`: presumably attrdict still imports ABCs
# (e.g. `Mapping`) from `collections`, where Python >= 3.10 no longer has them.
attrdict_monkeypatch_fix()

from attrdict import AttrDict

# `torch` and `Actor` come from the definitions cell above; run it first.
torch.manual_seed(0)  # reproducible network init and sampled actions

# Dummy Configurations
config = AttrDict({
    "parameters": AttrDict({
        "dreamer": AttrDict({
            "agent": AttrDict({
                "actor": AttrDict({
                    "hidden_size": 128,
                    "num_layers": 3,
                    "activation": "ReLU",
                    "mean_scale": 5.0,
                    "init_std": 0.0,
                    "min_std": 0.1,
                })
            }),
            "stochastic_size": 30,
            "deterministic_size": 200,
        })
    })
})
# NOTE: `Actor(inputSize, actionSize)` hard-codes its own hyperparameters, so the
# `agent.actor` entries above are not read; only the state sizes are used here.

# Actor Initialization (continuous actions only)
action_size = 5  # Number of actions
dreamer_config = config.parameters.dreamer
input_size = dreamer_config.stochastic_size + dreamer_config.deterministic_size
actor = Actor(input_size, action_size)

# Inputs to Actor
batch_size = 10
posterior = torch.randn(batch_size, dreamer_config.stochastic_size)
deterministic = torch.randn(batch_size, dreamer_config.deterministic_size)

# Forward Pass
actions, logprobs, entropy = actor(posterior, deterministic)

# Sanity checks on the outputs
assert actions.shape == (batch_size, action_size)
assert logprobs.shape == (batch_size,)
assert entropy.shape == (batch_size,)
assert bool((actions.abs() <= 1).all()), "tanh-squashed actions must lie in [-1, 1]"

# Display Results
print("Action Shape:", actions.shape)
print("Sample Action:", actions[0])
print(f"logprobs: {logprobs} of shape {logprobs.shape}")
print(f"entropy: {entropy} of shape {entropy.shape}")


Action Shape: torch.Size([10, 5])
Sample Action: tensor([-0.2095,  0.8946, -0.7467,  0.3338,  0.3915],
       grad_fn=<SelectBackward0>)
logprobs: tensor([ 1.1874, 15.0395, -1.8972,  5.8088,  1.0603,  5.1731, -1.8401, -0.7238,
         1.7265,  0.7875], grad_fn=<SumBackward1>) of shape torch.Size([10])
entropy: tensor([3.1522, 5.0611, 4.2803, 3.6965, 7.3899, 4.6184, 6.2805, 4.3084, 4.7270,
        4.4729], grad_fn=<SumBackward1>) of shape torch.Size([10])
