In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [20]:
class ActorCritic(nn.Module):
    def __init__(self, state_size, action_size, seed = 0):
        super(ActorCritic, self).__init__()
        self.hidden = 32
        
        self.actor = nn.Sequential(
                        nn.Linear(state_size, self.hidden),
                        nn.Tanh(),
                        nn.Linear(self.hidden, self.hidden),
                        nn.Tanh(),
                        nn.Linear(self.hidden, action_size),
                        nn.Tanh())
        
        self.critic = nn.Sequential(
                        nn.Linear(state_size, self.hidden),
                        nn.Tanh(),
                        nn.Linear(self.hidden, self.hidden),
                        nn.Tanh(),
                        nn.Linear(self.hidden, 1),
                        nn.Tanh())
        
        self.std = nn.Parameter(torch.zeros(action_size))  #returns an array of action_size for continous zero
    
    def forward(self, obs):
        mean = self.actor(obs)
        v = self.critic(obs)
        dist = torch.distributions.Normal(mean, F.softplus(self.std))
        return (v,dist)
        
actorcritic = ActorCritic(6,3).to(device)
print(actorcritic)

ActorCritic(
  (actor): Sequential(
    (0): Linear(in_features=6, out_features=32, bias=True)
    (1): Tanh()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): Tanh()
    (4): Linear(in_features=32, out_features=3, bias=True)
    (5): Tanh()
  )
  (critic): Sequential(
    (0): Linear(in_features=6, out_features=32, bias=True)
    (1): Tanh()
    (2): Linear(in_features=32, out_features=32, bias=True)
    (3): Tanh()
    (4): Linear(in_features=32, out_features=1, bias=True)
    (5): Tanh()
  )
)


In [25]:
import gym
env = gym.make('Acrobot-v1')
env.seed(0)
print('State shape: ', env.observation_space.shape)
print('Number of actions: ', env.action_space.n)

State shape:  (6,)
Number of actions:  3


In [15]:
d[0]

tensor(0., grad_fn=<SelectBackward>)