In [24]:
import torch 
from torch import nn 
from torch.nn.init import xavier_uniform_
import numpy as scinp

In [25]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using {} device".format(device))

Using cuda:0 device


In [71]:
class Critic(nn.Module):
    
    def __init__(self,data_dimension,hidden_dimensionality=[16,8]):
        super(Critic, self).__init__()
        self.flatten=nn.Flatten()
        self.layer_dimensions=[data_dimension]+hidden_dimensionality+[1]
        
        layers=[]
        
        for i in range(len(self.layer_dimensions)-1):
            in_dim=self.layer_dimensions[i]
            out_dim=self.layer_dimensions[i+1]
            linear_layer=nn.Linear(in_dim,out_dim)
            xavier_uniform_(linear_layer.weight)
            layers.append(linear_layer)
            
        self.network=nn.Sequential(*layers)
        
    def forward(self,x):
        x=self.flatten(x)
        critic_output=self.network(x)
        return critic_output

In [72]:
critic=Critic(4).to(device)
critic

Critic(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (network): Sequential(
    (0): Linear(in_features=4, out_features=16, bias=True)
    (1): Linear(in_features=16, out_features=8, bias=True)
    (2): Linear(in_features=8, out_features=1, bias=True)
  )
)

In [73]:
x=torch.Tensor(1,4).uniform_(-1,1).to(device)
x

tensor([[-0.0841, -0.2739,  0.6447, -0.8764]], device='cuda:0')

In [74]:
critic(x)

tensor([[-0.1063]], device='cuda:0', grad_fn=<AddmmBackward>)