In [1]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.a2c import A2C

In [2]:
class Actor(nn.Module):
    """Policy network: maps a state vector to a probability distribution over actions.

    Architecture: Linear(nstates, hidden_size) -> ReLU -> Linear(hidden_size, nactions)
    -> softmax over the last dimension.

    Args:
        nstates: Size of the state (observation) vector.
        nactions: Number of discrete actions.
        hidden_size: Width of the single hidden layer (default 256).
        init_range: Half-width of the uniform distribution used to
            re-initialize both layers' weights (default 0.1). Biases keep
            PyTorch's default ``nn.Linear`` initialization.
    """

    def __init__(self, nstates: int, nactions: int,
                 hidden_size: int = 256, init_range: float = 0.1):
        super().__init__()
        self.init_range = init_range

        self.fc1 = nn.Linear(nstates, hidden_size)
        self.fc2 = nn.Linear(hidden_size, nactions)

        self.init_weight()

    def init_weight(self):
        """Re-draw both layers' weights from U(-init_range, init_range); biases are untouched."""
        for layer in (self.fc1, self.fc2):
            nn.init.uniform_(layer.weight, -self.init_range, self.init_range)

    def forward(self, x):
        """Return action probabilities for ``x``.

        Softmax is taken over the last dimension, so ``x`` may be a single
        state of shape (nstates,) or a batch (..., nstates). The output keeps
        the leading shape and ends in a dimension of size ``nactions`` whose
        entries sum to 1.
        """
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=-1)

In [3]:
class Critic(nn.Module):
    """Value network: maps a state vector to a single scalar output.

    Presumably the state-value estimate consumed by A2C (TODO confirm
    against models/a2c.py).

    Architecture: Linear(nstates, hidden_size) -> ReLU -> Linear(hidden_size, 1).

    Args:
        nstates: Size of the state (observation) vector.
        hidden_size: Width of the single hidden layer (default 256).
        init_range: Half-width of the uniform distribution used to
            re-initialize both layers' weights (default 0.1). Biases keep
            PyTorch's default ``nn.Linear`` initialization.
    """

    def __init__(self, nstates: int, hidden_size: int = 256, init_range: float = 0.1):
        super().__init__()
        self.init_range = init_range

        self.fc1 = nn.Linear(nstates, hidden_size)
        self.fc2 = nn.Linear(hidden_size, 1)

        self.relu1 = nn.ReLU()

        self.init_weight()

    def init_weight(self):
        """Re-draw both layers' weights from U(-init_range, init_range); biases are untouched."""
        for layer in (self.fc1, self.fc2):
            nn.init.uniform_(layer.weight, -self.init_range, self.init_range)

    def forward(self, x):
        """Return the critic output for ``x``.

        Input ``x`` has shape (..., nstates) and the output has shape (..., 1).
        """
        return self.fc2(self.relu1(self.fc1(x)))

In [4]:
# Hyperparameters for A2C on CartPole-v1.
lr = 0.0001  # Adam learning rate, shared by the actor and critic optimizers

# CartPole-v1 has a 4-dimensional observation and 2 discrete actions.
nactions = 2
nstates = 4

gamma = 0.98  # discount factor passed to A2C

# NOTE(review): exact meaning of these three is defined by A2C.train — confirm in models/a2c.py.
n_epi = 5000    # presumably the number of training episodes
n_rollout = 10  # presumably steps collected per rollout before an update
n_update = 3    # presumably the third positional argument of A2C.train

# Seed torch so the network weight initialization in the next cell is
# reproducible. The environment is not seeded here, so episode outcomes can
# still vary between runs.
seed = 0
torch.manual_seed(seed);

In [5]:
# Build the policy (actor) and value (critic) networks first, in the same
# order as before so the random weight-initialization draws are unchanged.
# Then give each network its own Adam optimizer with the shared learning rate.
act = Actor(nstates, nactions)
cri = Critic(nstates)

act_opt, cri_opt = (torch.optim.Adam(net.parameters(), lr=lr) for net in (act, cri))

In [6]:
# 'human' opens a window and renders every step, throttled to the env's
# display rate. That slows training to roughly real time, so train headless
# (None) and switch to 'human' only to watch an agent.
render_mode = None
env = gym.make('CartPole-v1', render_mode=render_mode)

In [7]:
# Train the agent. Pass the configured n_update instead of a literal 3 so
# that editing the hyperparameter cell actually takes effect.
# NOTE(review): assumes the third positional parameter of A2C.train is the
# update count — TODO confirm against models/a2c.py.
model = A2C(env, nactions, act, act_opt, cri, cri_opt, gamma=gamma)
model.train(n_epi, n_rollout, n_update)

  if not isinstance(terminated, (bool, np.bool8)):


step: 436, score: 21.8
step: 849, score: 20.65
step: 1320, score: 23.55
step: 1699, score: 18.95
step: 2221, score: 26.1
step: 2716, score: 24.75
step: 3180, score: 23.2
step: 3639, score: 22.95
step: 4174, score: 26.75
step: 4753, score: 28.95
step: 5542, score: 39.45
step: 6100, score: 27.9
step: 6895, score: 39.75
step: 7551, score: 32.8
step: 8429, score: 43.9
step: 9451, score: 51.1
step: 10486, score: 51.75
step: 11517, score: 51.55
step: 12505, score: 49.4
step: 13465, score: 48.0
step: 14613, score: 57.4
step: 15828, score: 60.75
step: 16987, score: 57.95
step: 18240, score: 62.65
step: 19524, score: 64.2
step: 20971, score: 72.35
step: 22653, score: 84.1
step: 24844, score: 109.55


KeyboardInterrupt: 

In [None]:
# Release the environment's resources (e.g. a render window, if one was opened).
env.close()