In [1]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.a2c import A2C

In [2]:
class Actor(nn.Module):
    def __init__(self, nstates: int, nactions: int):
        super(Actor, self).__init__()
                
        self.fc1 = nn.Linear(nstates, 256)
        self.fc2 = nn.Linear(256, nactions)

        self.init_weight()
    
    def init_weight(self):
        initrange = 0.1
        self.fc1.weight.data.uniform_(-initrange, initrange)
        self.fc2.weight.data.uniform_(-initrange, initrange)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=-1)
        return x

In [3]:
class Critic(nn.Module):
    def __init__(self, nstates: int):
        super(Critic, self).__init__()
                
        self.fc1 = nn.Linear(nstates, 256)
        self.fc2 = nn.Linear(256, 1)
        
        self.relu1 = nn.ReLU()

        self.init_weight()
    
    def init_weight(self):
        initrange = 0.1
        self.fc1.weight.data.uniform_(-initrange, initrange)
        self.fc2.weight.data.uniform_(-initrange, initrange)
                
    def forward(self, x):

        x = self.relu1(self.fc1(x))
        x = self.fc2(x)
        
        return x

In [4]:
lr = 0.0001

nactions = 2
nstates = 4

gamma = 0.98

n_epi = 5000
n_rollout = 10

In [5]:
act = Actor(nstates, nactions)
act_opt = torch.optim.Adam(act.parameters(), lr=lr)

cri = Critic(nstates)
cri_opt = torch.optim.Adam(cri.parameters(), lr=lr)

In [6]:
env = gym.make('CartPole-v1', render_mode='human')

In [7]:
model = A2C(env, nactions, act, act_opt, cri, cri_opt, gamma=gamma)
model.train(n_epi, n_rollout)

  if not isinstance(terminated, (bool, np.bool8)):


epi: 20, score: 20.6
epi: 40, score: 25.65
epi: 60, score: 25.05
epi: 80, score: 21.1
epi: 100, score: 21.55
epi: 120, score: 22.5
epi: 140, score: 22.9
epi: 160, score: 23.25
epi: 180, score: 19.75
epi: 200, score: 22.75
epi: 220, score: 25.05
epi: 240, score: 22.75
epi: 260, score: 21.95
epi: 280, score: 22.1
epi: 300, score: 25.55
epi: 320, score: 22.0
epi: 340, score: 21.8
epi: 360, score: 24.5
epi: 380, score: 23.8
epi: 400, score: 20.8
epi: 420, score: 22.75
epi: 440, score: 26.4
epi: 460, score: 20.05
epi: 480, score: 25.25
epi: 500, score: 20.4
epi: 520, score: 28.35
epi: 540, score: 25.6
epi: 560, score: 31.6
epi: 580, score: 29.45
epi: 600, score: 29.05
epi: 620, score: 33.45
epi: 640, score: 36.7
epi: 660, score: 33.35
epi: 680, score: 29.45
epi: 700, score: 29.7
epi: 720, score: 36.0
epi: 740, score: 33.5
epi: 760, score: 30.35
epi: 780, score: 36.95
epi: 800, score: 48.05
epi: 820, score: 38.8
epi: 840, score: 43.25
epi: 860, score: 35.8
epi: 880, score: 49.1
epi: 900, sco

KeyboardInterrupt: 

In [8]:
env.close()