In [14]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.a2c import A2C

In [15]:
class Actor(nn.Module):
    """Policy network: maps a state vector to a probability distribution over actions.

    Architecture: Linear(n_sts -> dim) -> ReLU -> Linear(dim -> n_acts) -> softmax
    over the last dimension.

    Args:
        n_sts: size of the input state vector.
        n_acts: number of discrete actions (size of the output distribution).
        dim: hidden-layer width.
    """

    def __init__(self, n_sts: int, n_acts: int, dim: int):
        super().__init__()
        self.fc1 = nn.Linear(n_sts, dim)
        self.fc2 = nn.Linear(dim, n_acts)
        self.init_weight()

    def init_weight(self):
        """Redraw both layers' weight matrices from U(-0.1, 0.1).

        Biases are not touched and keep PyTorch's default nn.Linear initialisation.
        """
        bound = 0.1
        for layer in (self.fc1, self.fc2):
            layer.weight.data.uniform_(-bound, bound)

    def forward(self, x):
        """Return action probabilities; the last dimension sums to 1."""
        hidden = F.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return F.softmax(logits, dim=-1)

In [16]:
class Critic(nn.Module):
    """Value network: maps a state vector to a scalar value estimate.

    Architecture: Linear(n_sts -> dim) -> ReLU -> Linear(dim -> 1). The output
    keeps a trailing dimension of size 1.

    Args:
        n_sts: size of the input state vector.
        dim: hidden-layer width.
    """

    def __init__(self, n_sts: int, dim: int):
        super().__init__()
        self.fc1 = nn.Linear(n_sts, dim)
        self.fc2 = nn.Linear(dim, 1)
        self.relu1 = nn.ReLU()
        self.init_weight()

    def init_weight(self):
        """Redraw both layers' weight matrices from U(-0.1, 0.1); biases keep their default init."""
        bound = 0.1
        for layer in (self.fc1, self.fc2):
            layer.weight.data.uniform_(-bound, bound)

    def forward(self, x):
        """Return the value estimate for each input state."""
        return self.fc2(self.relu1(self.fc1(x)))

In [17]:
# --- Optimisation ---
lr = 1e-4        # Adam learning rate, shared by the actor and critic optimizers
gamma = 0.98     # discount factor handed to A2C

# --- Environment dimensions (CartPole-v1) ---
n_sts = 4        # length of the observation vector (network input size)
n_acts = 2       # number of discrete actions (actor output size)

# --- Training schedule ---
# n_epi and n_rollout are passed to A2C.train; n_update is presumably its third
# argument (the training cell passes that value) -- TODO confirm against models.a2c.
n_epi = 2000
n_rollout = 10
n_update = 3

# --- Network size ---
dim = 256        # hidden-layer width for both Actor and Critic

In [18]:
# Policy (actor) and value (critic) networks, each with its own Adam optimizer.
# The actor is built first so the weight-initialisation RNG draws happen in a fixed order.
act = Actor(n_sts=n_sts, n_acts=n_acts, dim=dim)
act_opt = torch.optim.Adam(params=act.parameters(), lr=lr)

cri = Critic(n_sts=n_sts, dim=dim)
cri_opt = torch.optim.Adam(params=cri.parameters(), lr=lr)

In [19]:
# render_mode='human' renders on every env.step() and throttles to the environment's
# render FPS, which makes training dramatically slower. Keep rendering off while
# training; set RENDER_MODE = 'human' to watch the agent instead.
RENDER_MODE = None
env = gym.make('CartPole-v1', render_mode=RENDER_MODE)

In [20]:
# Build the A2C agent from the environment, both networks and their optimizers, then train.
# Pass the configured n_update instead of a hard-coded 3 so the config cell remains the
# single source of truth for the training schedule.
# NOTE(review): the meaning of train()'s positional arguments is defined in models.a2c
# (not visible here) -- confirm the third one is the number of updates.
model = A2C(env, n_acts, act, act_opt, cri, cri_opt, gamma=gamma)
model.train(n_epi, n_rollout, n_update)

step: 416, score: 20.8
step: 876, score: 23.0
step: 1366, score: 24.5
step: 1748, score: 19.1
step: 2160, score: 20.6
step: 2617, score: 22.85
step: 3080, score: 23.15
step: 3528, score: 22.4
step: 4096, score: 28.4
step: 4662, score: 28.3
step: 5159, score: 24.85
step: 5811, score: 32.6
step: 6331, score: 26.0
step: 7012, score: 34.05
step: 7926, score: 45.7
step: 8706, score: 39.0
step: 9551, score: 42.25
step: 10374, score: 41.15
step: 11280, score: 45.3
step: 12388, score: 55.4
step: 13321, score: 46.65
step: 14393, score: 53.6
step: 15500, score: 55.35
step: 16702, score: 60.1
step: 17846, score: 57.2
step: 19058, score: 60.6
step: 20427, score: 68.45
step: 21662, score: 61.75
step: 23446, score: 89.2
step: 25432, score: 99.3
step: 27975, score: 127.15
step: 30571, score: 129.8
step: 33441, score: 143.5
step: 37386, score: 197.25
step: 41488, score: 205.1
step: 46185, score: 234.85
step: 49772, score: 179.35
step: 53528, score: 187.8
step: 57118, score: 179.5
step: 60839, score: 1

KeyboardInterrupt: 

In [21]:
# Release resources held by the Gym environment (including any render window).
env.close()