In [13]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import MultivariateNormal



class ActorCritic(nn.Module):
    """Actor-critic network built from two independent MLPs.

    Both heads consume the same observation vector and share no weights.
    Each has three Tanh hidden layers of 500, 300 and 100 units. The actor
    ends in a Tanh, so its ``action_dim`` outputs lie in [-1, 1]; the critic
    ends in a single unbounded linear unit.

    The module also carries a fixed diagonal covariance derived from
    ``action_std_init`` (presumably consumed by a Gaussian policy such as
    ``MultivariateNormal`` outside this class -- confirm against the caller).

    Args:
        obs_dim: Size of the last dimension of the observation input.
        action_dim: Number of outputs produced by the actor head.
        action_std_init: Initial per-dimension action standard deviation.
            It is squared before being placed on the covariance diagonal.
    """

    def __init__(self, obs_dim, action_dim, action_std_init):
        super().__init__()
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        # Kept for callers that read it; nothing in this class moves tensors
        # to this device.
        self.device = torch.device("cpu")

        # A covariance diagonal holds variances, so the standard deviation
        # must be squared. float() also keeps the tensor floating-point when
        # an integer std is passed.
        variance = float(action_std_init) ** 2

        # Registered as non-persistent buffers: they follow the module through
        # .to()/.cuda()/.double() but stay out of state_dict(), so existing
        # checkpoints keep loading unchanged.
        self.register_buffer(
            "cov_var",
            torch.full((self.action_dim,), variance),
            persistent=False,
        )
        # Shape (1, action_dim, action_dim): the leading axis broadcasts over
        # a batch.
        self.register_buffer(
            "cov_mat",
            torch.diag(self.cov_var).unsqueeze(dim=0),
            persistent=False,
        )

        # actor: observation -> action outputs bounded to [-1, 1]
        self.actor = nn.Sequential(
            nn.Linear(self.obs_dim, 500),
            nn.Tanh(),
            nn.Linear(500, 300),
            nn.Tanh(),
            nn.Linear(300, 100),
            nn.Tanh(),
            nn.Linear(100, self.action_dim),
            nn.Tanh(),
        )

        # critic: observation -> single unbounded scalar
        self.critic = nn.Sequential(
            nn.Linear(self.obs_dim, 500),
            nn.Tanh(),
            nn.Linear(500, 300),
            nn.Tanh(),
            nn.Linear(300, 100),
            nn.Tanh(),
            nn.Linear(100, 1),
        )

    def forward(self, x):
        """Run both heads on ``x``.

        Args:
            x: Tensor whose last dimension equals ``obs_dim``.

        Returns:
            A tuple ``(actor_out, critic_out)``. ``actor_out`` has last
            dimension ``action_dim`` and is Tanh-bounded (these are not
            logits); ``critic_out`` has last dimension 1.
        """
        action_mean = self.actor(x)
        value = self.critic(x)
        return action_mean, value

In [14]:
# Build the network for 256-dimensional observations and 2-dimensional actions.
model = ActorCritic(obs_dim=256, action_dim=2, action_std_init=0.2)

# Sum the element counts of every parameter tensor (actor and critic).
total_params = sum(weight.numel() for weight in model.parameters())
print(f"Total parameters: {total_params}")

Total parameters: 618103


In [18]:
from utils import Untils

# Use the GPU when one is present; otherwise fall back to the CPU so the cell
# still runs on machines without CUDA.
# NOTE(review): assumes Untils accepts "cpu" as a device string -- confirm
# against utils.py.
util_device = "cuda" if torch.cuda.is_available() else "cpu"
util = Untils(device=util_device)

# Show the real-valued action grid exposed by the helper.
util.action_space_real

[(0.0, 0.0),
 (0.0, 0.1),
 (0.0, 0.2),
 (0.0, 0.3),
 (0.1, 0.0),
 (0.1, 0.1),
 (0.1, 0.2),
 (0.1, 0.3),
 (0.2, 0.0),
 (0.2, 0.1),
 (0.2, 0.2),
 (0.2, 0.3),
 (0.3, 0.0),
 (0.3, 0.1),
 (0.3, 0.2),
 (0.3, 0.3)]