In [1]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F


from models.reinforce import Reinforce

In [2]:
class Policy(nn.Module):
    def __init__(self, n_sts: int, n_acts: int, dim: int):
        super(Policy, self).__init__()
                
        self.fc1 = nn.Linear(n_sts, dim)
        self.fc2 = nn.Linear(dim, n_acts)

        self.init_weight()
    
    def init_weight(self):
        initrange = 0.1
        self.fc1.weight.data.uniform_(-initrange, initrange)
        self.fc2.weight.data.uniform_(-initrange, initrange)
    
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=-1)
        return x

In [5]:
lr = 0.0002

n_acts = 2
n_sts = 4

gamma = 0.98

dim = 256

n_epi = 5000

In [6]:
pi = Policy(n_sts, n_acts, dim)
pi_opt = torch.optim.Adam(pi.parameters(), lr=lr)


In [7]:
env = gym.make('CartPole-v1', render_mode='human')

In [8]:
model = Reinforce(env, n_acts, pi, pi_opt, gamma=gamma)
model.train(n_epi)

  if not isinstance(terminated, (bool, np.bool8)):


step: 430, score: 21.5
step: 928, score: 24.9
step: 1387, score: 22.95
step: 1820, score: 21.65
step: 2250, score: 21.5
step: 2648, score: 19.9
step: 3128, score: 24.0
step: 3550, score: 21.1
step: 4102, score: 27.6
step: 4644, score: 27.1
step: 5186, score: 27.1
step: 5685, score: 24.95
step: 6183, score: 24.9
step: 6895, score: 35.6
step: 7494, score: 29.95
step: 8127, score: 31.65
step: 8803, score: 33.8
step: 9263, score: 23.0
step: 9933, score: 33.5
step: 10620, score: 34.35
step: 11222, score: 30.1
step: 11851, score: 31.45
step: 12522, score: 33.55
step: 13098, score: 28.8
step: 13723, score: 31.25
step: 14475, score: 37.6
step: 15295, score: 41.0
step: 16020, score: 36.25
step: 16624, score: 30.2
step: 17327, score: 35.15
step: 18269, score: 47.1
step: 18937, score: 33.4
step: 19772, score: 41.75
step: 20597, score: 41.25
step: 21335, score: 36.9
step: 22268, score: 46.65
step: 23227, score: 47.95
step: 24101, score: 43.7
step: 25218, score: 55.85
step: 26142, score: 46.2
step:

KeyboardInterrupt: 

In [9]:
env.close()