In [1]:
import gym

import torch
from torch import nn
import torch.nn.functional as F

from models.ppo import PPO

In [2]:
class Actor(nn.Module):
    def __init__(self, nstats: int, nactions: int, hid_dim:int):
        super(Actor, self).__init__()
        self.fc1 = nn.Linear(nstats, hid_dim)
        self.fc2 = nn.Linear(hid_dim, nactions)
        self.relu1 = nn.ReLU()
        
    def forward(self, x):
        x = self.relu1(self.fc1(x))
        x = F.softmax(self.fc2(x), -1)
        return x


In [3]:
class Critic(nn.Module):
    def __init__(self, nstats: int, hid_dim:int):
        super(Critic, self).__init__()
        self.fc1 = nn.Linear(nstats, hid_dim)
        self.fc2 = nn.Linear(hid_dim, 1)
        self.relu1 = nn.ReLU()
        
    def forward(self, x):
        x = self.relu1(self.fc1(x))
        x = self.fc2(x)
        return x


In [6]:
learning_rate = 0.0005
nstates = 4
nactions = 2
hid_dim = 128
gamma = 0.98
gae = 0.95

In [7]:
env = gym.make("CartPole-v1", render_mode='human')

In [8]:
act = Actor(nstates, nactions, hid_dim)
act_opt = torch.optim.Adam(act.parameters(), lr=learning_rate)

cri = Critic(nstates, hid_dim)
cri_opt = torch.optim.Adam(cri.parameters(), lr=learning_rate)

In [9]:
ppo = PPO(env, 2, act, act_opt, cri, cri_opt)

In [10]:
ppo.train(5000, 10)

  if not isinstance(terminated, (bool, np.bool8)):


epi: 20, score: 22.3
epi: 40, score: 26.9
epi: 60, score: 23.85
epi: 80, score: 21.3
epi: 100, score: 21.05
epi: 120, score: 21.6
epi: 140, score: 22.85
epi: 160, score: 25.2
epi: 180, score: 29.75
epi: 200, score: 31.75
epi: 220, score: 34.4
epi: 240, score: 29.5
epi: 260, score: 33.2
epi: 280, score: 34.85
epi: 300, score: 36.15
epi: 320, score: 46.3
epi: 340, score: 50.5
epi: 360, score: 48.85
epi: 380, score: 64.65
epi: 400, score: 79.5
epi: 420, score: 77.25
epi: 440, score: 103.1
epi: 460, score: 112.0


KeyboardInterrupt: 

In [11]:
env.close()