In [1]:
import gym

import torch
from torch import nn
import torch.nn.functional as F

from models.ppo import PPO

In [2]:
class Actor(nn.Module):
    """Small MLP that maps an input vector to a softmax distribution.

    The network is ``Linear -> ReLU -> Linear -> softmax``.

    Args:
        nstats: Width of the input vector fed to the first linear layer.
        nactions: Number of outputs produced by the final linear layer.
        hid_dim: Width of the hidden layer.
    """

    def __init__(self, nstats: int, nactions: int, hid_dim:int):
        super().__init__()
        # fc1 is built before fc2; the construction order decides which
        # random draws initialise which layer under a fixed seed.
        self.fc1 = nn.Linear(in_features=nstats, out_features=hid_dim)
        self.fc2 = nn.Linear(in_features=hid_dim, out_features=nactions)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        """Return probabilities that sum to 1 along the last dimension of ``x``."""
        hidden = self.relu1(self.fc1(x))
        logits = self.fc2(hidden)
        return F.softmax(logits, dim=-1)


In [3]:
class Critic(nn.Module):
    """Small MLP that maps an input vector to a single scalar per sample.

    The network is ``Linear -> ReLU -> Linear`` with one output unit.

    Args:
        nstats: Width of the input vector fed to the first linear layer.
        hid_dim: Width of the hidden layer.
    """

    def __init__(self, nstats: int, hid_dim:int):
        super().__init__()
        # fc1 is built before fc2; the construction order decides which
        # random draws initialise which layer under a fixed seed.
        self.fc1 = nn.Linear(in_features=nstats, out_features=hid_dim)
        self.fc2 = nn.Linear(in_features=hid_dim, out_features=1)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        """Return the un-activated output of the final layer (last dim is 1)."""
        hidden = self.relu1(self.fc1(x))
        return self.fc2(hidden)


In [4]:
# Run configuration and hyperparameters.

# Seed torch's global RNG before any nn.Module is constructed, so that
# parameter initialisation is repeatable across kernel restarts.
# NOTE(review): the environment's own RNG is not seeded here — that has to
# happen wherever env.reset() is called; confirm in models.ppo.
seed = 0
torch.manual_seed(seed)

learning_rate = 0.0005  # Adam step size, shared by the actor and critic optimizers
nstates = 4             # input width of both networks
nactions = 2            # output width of the actor
hid_dim = 128           # hidden-layer width of both networks

# NOTE(review): gamma and gae are not referenced by any other visible cell —
# confirm whether PPO is supposed to receive them.
gamma = 0.98
gae = 0.95

In [5]:
# "human" draws every step in a window, which throttles stepping to the
# display frame rate. Set render_mode to None for headless, faster training.
render_mode = "human"
env = gym.make("CartPole-v1", render_mode=render_mode)

In [6]:
# Build both networks first. The actor is created before the critic so the
# order in which they draw their random initial weights stays fixed.
act = Actor(nstats=nstates, nactions=nactions, hid_dim=hid_dim)
cri = Critic(nstats=nstates, hid_dim=hid_dim)

# One Adam optimizer per network, both using the same learning rate.
act_opt = torch.optim.Adam(act.parameters(), lr=learning_rate)
cri_opt = torch.optim.Adam(cri.parameters(), lr=learning_rate)

In [7]:
# Hand the environment, both networks and their optimizers to the PPO object
# whose train() method is called in the next cell.
# NOTE(review): the meaning of the positional `2` is not visible here (PPO is
# defined in models.ppo) — confirm, and prefer a named constant or keyword.
ppo = PPO(env, 2, act, act_opt, cri, cri_opt)

In [8]:
# Start training.
# NOTE(review): what the positional arguments 5000 and 10 control is defined
# in models.ppo, not in this notebook — confirm and name them.
ppo.train(5000, 10)

  if not isinstance(terminated, (bool, np.bool8)):


epi: 20, score: 22.05
epi: 40, score: 20.6
epi: 60, score: 18.25
epi: 80, score: 24.45
epi: 100, score: 18.85
epi: 120, score: 19.1
epi: 140, score: 20.9
epi: 160, score: 18.25
epi: 180, score: 30.0
epi: 200, score: 34.9
epi: 220, score: 43.05
epi: 240, score: 45.55
epi: 260, score: 46.6
epi: 280, score: 61.55
epi: 300, score: 58.3
epi: 320, score: 49.2
epi: 340, score: 112.35


KeyboardInterrupt: 

In [9]:
# Shut the environment down once training has finished or been interrupted.
env.close()