In [2]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.a2c import A2C

In [3]:
class Actor(nn.Module):
    def __init__(self, n_sts: int, n_acts: int, dim: int):
        super(Actor, self).__init__()
                
        self.fc1 = nn.Linear(n_sts, dim)
        self.fc2 = nn.Linear(dim, n_acts)

        self.init_weight()
    
    def init_weight(self):
        initrange = 0.1
        self.fc1.weight.data.uniform_(-initrange, initrange)
        self.fc2.weight.data.uniform_(-initrange, initrange)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=-1)
        return x

In [5]:
class Critic(nn.Module):
    """Scalar-output network over states, used as the critic for A2C.

    Architecture: Linear(n_sts -> dim) -> ReLU -> Linear(dim -> 1). The output
    is presumably interpreted as the state value V(s) by the A2C trainer;
    confirm against models.a2c.

    Args:
        n_sts: length of the input state vector.
        dim: width of the single hidden layer.
    """

    def __init__(self, n_sts: int, dim: int):
        super().__init__()
        self.fc1 = nn.Linear(n_sts, dim)
        self.fc2 = nn.Linear(dim, 1)
        self.relu1 = nn.ReLU()
        # Replace the default weight init; biases are left untouched.
        self.init_weight()

    def init_weight(self):
        """Fill fc1's and then fc2's weight matrix with samples from U(-0.1, 0.1)."""
        limit = 0.1
        nn.init.uniform_(self.fc1.weight, -limit, limit)
        nn.init.uniform_(self.fc2.weight, -limit, limit)

    def forward(self, x):
        """Return one value per input row; the last output dimension has size 1."""
        return self.fc2(self.relu1(self.fc1(x)))

In [6]:
# Hyperparameters for the actor/critic networks and A2C training.
lr = 0.0001        # Adam learning rate, shared by the actor and critic optimizers

n_acts = 2         # number of discrete actions; must match the environment's action space
n_sts = 4          # length of the observation vector; must match the environment

gamma = 0.98       # discount factor handed to A2C

n_epi = 5000       # presumably the number of training episodes — confirm against A2C.train
n_rollout = 10     # presumably env steps collected per rollout — confirm against A2C.train
n_update = 3       # presumably gradient updates per rollout — confirm against A2C.train

dim = 256          # hidden-layer width for both Actor and Critic

# Seed torch's global RNG before any network is built, so the uniform weight
# init (and any torch-side sampling) is reproducible across Restart & Run All.
# NOTE(review): environment randomness is not covered here; it would need
# env.reset(seed=...) inside the training loop — TODO confirm in A2C.
seed = 0
torch.manual_seed(seed)

In [7]:
# Policy network (actor) plus the Adam optimizer that updates only its parameters.
act = Actor(n_sts, n_acts, dim)
act_opt = torch.optim.Adam(params=act.parameters(), lr=lr)

# Critic network plus its own, separate Adam optimizer at the same learning rate.
cri = Critic(n_sts, dim)
cri_opt = torch.optim.Adam(params=cri.parameters(), lr=lr)

In [8]:
# Rendering is opt-in: with render_mode='human', CartPole draws a frame on every
# step() and throttles to its render_fps, which makes a 5000-episode training run
# very slow and fails on headless machines. Set to 'human' to watch the agent.
render_mode = None
env = gym.make('CartPole-v1', render_mode=render_mode)

In [9]:
# Wire the actor/critic pair into the A2C trainer and run training.
model = A2C(env, n_acts, act, act_opt, cri, cri_opt, gamma=gamma)
# Pass the named n_update constant rather than a literal, so tuning it in the
# config cell actually takes effect. Presumably the third argument is the number
# of updates per rollout — confirm against models.a2c.A2C.train.
model.train(n_epi, n_rollout, n_update)

  if not isinstance(terminated, (bool, np.bool8)):


step: 427, score: 21.35
step: 801, score: 18.7
step: 1175, score: 18.7


In [None]:
# Shut down the environment created earlier and release its resources (e.g. any render window).
env.close()