In [1]:
import gym

import torch
from torch import nn
import torch.nn.functional as F

from models.ppo import PPO

In [2]:
class Actor(nn.Module):
    """Policy network: maps state features to a probability distribution over actions.

    Architecture: Linear(nstats, hid_dim) -> ReLU -> Linear(hid_dim, nactions),
    followed by a softmax over the last dimension.

    Args:
        nstats: Size of the input (state) feature dimension.
        nactions: Number of discrete actions, i.e. the size of the output distribution.
        hid_dim: Width of the single hidden layer.
    """

    def __init__(self, nstats: int, nactions: int, hid_dim: int):
        super().__init__()
        # Attribute names and creation order are kept stable: they define the
        # state_dict keys and the order in which parameter init draws from the RNG.
        self.fc1 = nn.Linear(nstats, hid_dim)
        self.fc2 = nn.Linear(hid_dim, nactions)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        """Return action probabilities for ``x``; values along the last dim sum to 1."""
        hidden = self.relu1(self.fc1(x))
        logits = self.fc2(hidden)
        return F.softmax(logits, dim=-1)


In [3]:
class Critic(nn.Module):
    """Value network: maps state features to a single scalar estimate.

    Architecture: Linear(nstats, hid_dim) -> ReLU -> Linear(hid_dim, 1). The
    trailing output dimension of size 1 is kept (no squeeze).

    Args:
        nstats: Size of the input (state) feature dimension.
        hid_dim: Width of the single hidden layer.
    """

    def __init__(self, nstats: int, hid_dim: int):
        super().__init__()
        # Same attribute names/order as before so state_dict keys and RNG usage match.
        self.fc1 = nn.Linear(nstats, hid_dim)
        self.fc2 = nn.Linear(hid_dim, 1)
        self.relu1 = nn.ReLU()

    def forward(self, x):
        """Return value estimate(s) for ``x`` with shape ``(*x.shape[:-1], 1)``."""
        return self.fc2(self.relu1(self.fc1(x)))


In [4]:
# Seed torch so network initialisation (and any torch-based sampling during
# training) is reproducible across "Restart Kernel -> Run All".
SEED = 0
torch.manual_seed(SEED)

learning_rate = 0.0005  # Adam step size, used for both the actor and critic optimizers
nstates = 4             # network input size; should equal the env's observation size
nactions = 2            # network output size; should equal the env's number of actions
hid_dim = 128           # hidden-layer width of both networks
gamma = 0.98            # presumably the discount factor
gae = 0.95              # presumably the GAE lambda
# NOTE(review): gamma and gae are not passed to PPO(...) in this notebook;
# confirm whether models.ppo.PPO reads them, otherwise they have no effect.

In [5]:
# In gym's classic-control envs, render_mode="human" draws a frame on every
# step() and throttles it to the env's render frame rate, which makes training
# wall-clock bound by the display. Train headless by default; set RENDER_MODE
# to "human" only when you want to watch the agent.
RENDER_MODE = None  # or "human"
env = gym.make("CartPole-v1", render_mode=RENDER_MODE)

In [6]:
# Policy (actor) and value (critic) networks. The actor is still built first,
# so parameter initialisation draws from the RNG in the same order.
act = Actor(nstats=nstates, nactions=nactions, hid_dim=hid_dim)
cri = Critic(nstats=nstates, hid_dim=hid_dim)

# One Adam optimizer per network, sharing the same learning rate.
act_opt = torch.optim.Adam(act.parameters(), lr=learning_rate)
cri_opt = torch.optim.Adam(cri.parameters(), lr=learning_rate)

In [7]:
# Wire the environment, the actor/critic networks and their optimizers into the PPO agent.
# NOTE(review): the positional `2` presumably is the number of actions (it matches
# `nactions`); confirm against models.ppo.PPO and prefer passing `nactions`
# (or a keyword argument) instead of the literal.
ppo = PPO(env, 2, act, act_opt, cri, cri_opt)

In [8]:
# Run PPO training; progress lines ("step: ..., score: ...") are printed by ppo.train.
# NOTE(review): the meaning of the positional arguments (5000, 10, 3) is defined in
# models.ppo.PPO.train and is not visible here; confirm it, then move the values
# into the config cell as named constants (or pass them as keyword arguments).
ppo.train(5000, 10, 3)

  if not isinstance(terminated, (bool, np.bool8)):


step: 406, score: 20.3
step: 822, score: 20.8
step: 1347, score: 26.25
step: 2057, score: 35.5
step: 2878, score: 41.05
step: 3753, score: 43.75
step: 4959, score: 60.3
step: 7762, score: 140.15
step: 10501, score: 136.95
step: 13691, score: 159.5
step: 17846, score: 207.75


KeyboardInterrupt: 

In [None]:
# Release the environment's resources (including any render window) once
# training has finished or been interrupted.
env.close()