In [1]:
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical

In [2]:
gamma: float = 0.99  # discount factor applied to future rewards when train() builds returns

In [3]:
class Pi(nn.Module):
    """MLP policy over a discrete action space.

    The network maps an observation to one logit per action. `act` samples an
    action from the resulting Categorical distribution and records the
    sample's log-probability, so that `train` can build the policy-gradient
    loss once the episode is over. Per-episode buffers are cleared by
    `onpolicy_reset`.

    Args:
        in_dim: Length of the observation vector.
        out_dim: Number of discrete actions (size of the logit vector).
        hidden_dim: Width of the single hidden layer. Defaults to 64, the
            size that was previously hard-coded.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=64):
        super().__init__()
        # Layer order/indices match the original list, so state_dict keys
        # (model.0.*, model.2.*) are unchanged.
        self.model = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )
        self.onpolicy_reset()

    def onpolicy_reset(self):
        """Clear the per-episode log-probability and reward buffers."""
        self.log_probs = []  # one log-prob tensor (with autograd history) per act() call
        self.rewards = []  # appended by the rollout loop, one entry per env step

    def forward(self, x):
        """Return the unnormalised action logits for observation tensor `x`."""
        return self.model(x)

    def act(self, state):
        """Sample an action for `state` and record its log-probability.

        Args:
            state: Observation as a NumPy array or any numeric array-like.

        Returns:
            The sampled action index as a Python int.
        """
        # torch.tensor always copies, so the input autograd saves for the
        # backward pass cannot be altered if the caller's array is mutated
        # in place later; it also accepts lists/tuples, not just ndarrays.
        x = torch.tensor(state, dtype=torch.float32)
        pd = Categorical(logits=self.forward(x))
        action = pd.sample()
        self.log_probs.append(pd.log_prob(action))
        return action.item()

In [4]:
def train(pi, optimizer, discount=None):
    """Run one REINFORCE update from the episode stored on `pi`.

    Computes the discounted return-to-go for each step,
    G_t = r_t + discount * G_{t+1}, then minimises
    -sum_t log_prob_t * G_t with `optimizer`.

    Args:
        pi: Policy holding `rewards` (per-step rewards) and `log_probs`
            (per-step log-probability tensors with autograd history), as
            filled by `Pi.act` and the rollout loop.
        optimizer: Optimizer over the parameters that produced `log_probs`.
        discount: Discount factor. When None, the notebook-level `gamma` is
            used, read at call time as before.

    Returns:
        The scalar loss as a detached tensor, so the caller does not keep the
        (already freed) autograd graph alive.

    Raises:
        ValueError: If the episode is empty, or if `rewards` and `log_probs`
            differ in length (which would otherwise broadcast silently into a
            wrong loss when one of them has length 1).
    """
    if discount is None:
        discount = gamma  # notebook-level default from the config cell
    T = len(pi.rewards)
    if len(pi.log_probs) != T:
        raise ValueError(
            f"pi.rewards has {T} entries but pi.log_probs has "
            f"{len(pi.log_probs)}; they must match one-to-one"
        )
    if T == 0:
        raise ValueError("cannot train on an empty episode")

    # Returns-to-go, accumulated backwards from the last step.
    rets = np.empty(T, dtype=np.float32)
    future_ret = 0.0
    for t in reversed(range(T)):
        future_ret = pi.rewards[t] + discount * future_ret
        rets[t] = future_ret

    rets = torch.from_numpy(rets)
    log_probs = torch.stack(pi.log_probs)
    loss = -(log_probs * rets).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()

In [5]:
env = gym.make(id="CartPole-v0")  # CartPole environment the policy is trained on

In [6]:
in_dim = env.observation_space.shape[0]  # first dim of the observation shape; Pi's input width (4 for CartPole per the output below)

In [7]:
in_dim  # sanity check: observation vector length fed to Pi

4


In [8]:
out_dim = env.action_space.n  # number of actions = Pi's logit count; assumes a Discrete action space (exposes .n)

In [9]:
out_dim  # sanity check: number of discrete actions (Pi's output size)

2


In [10]:
pi = Pi(in_dim=in_dim, out_dim=out_dim)  # fresh policy sized to the env's observation/action spaces

In [11]:
optimizer = optim.Adam(params=pi.parameters(), lr=1e-2)  # Adam over all policy weights

In [12]:
# REINFORCE training loop: roll out one episode with the current policy, then
# take a single policy-gradient step on that episode via train().
# NOTE: the saved output below shows episodes numbered 240+, which range(240)
# cannot produce, so it came from an earlier version of this cell -- re-run
# the notebook to refresh it.
N_EPISODES = 240  # number of training episodes
MAX_STEPS = 300  # per-episode step cap; the env may end episodes earlier
RENDER = True  # set False to train headless and much faster
SOLVED_THRESHOLD = 195.0  # per-episode reward needed to print solved=True


def _reset_state(env):
    """Return the initial observation under both old and new Gym reset APIs."""
    out = env.reset()
    # gym>=0.26 / gymnasium return (obs, info); older gym returns obs alone.
    # Assumes the observation itself is not a tuple (true for CartPole).
    return out[0] if isinstance(out, tuple) else out


def _step_env(env, action):
    """Step `env`; return (state, reward, done) under either Gym step API."""
    out = env.step(action)
    if len(out) == 5:  # newer API: (obs, reward, terminated, truncated, info)
        state, reward, terminated, truncated, _ = out
        return state, reward, terminated or truncated
    state, reward, done, _ = out  # older API: (obs, reward, done, info)
    return state, reward, done


for epi in range(N_EPISODES):
    state = _reset_state(env)
    for t in range(MAX_STEPS):
        action = pi.act(state)
        state, reward, done = _step_env(env, action)
        pi.rewards.append(reward)
        if RENDER:
            env.render()
        if done:
            break
    loss = train(pi, optimizer)
    total_reward = sum(pi.rewards)
    # NOTE(review): this is a single-episode check; the usual CartPole-v0
    # "solved" criterion is believed to be an average over many consecutive
    # episodes -- confirm before reporting results.
    solved = total_reward > SOLVED_THRESHOLD
    pi.onpolicy_reset()
    print(f'Episode {epi}, loss: {loss}, total_reward: {total_reward}, solved: {solved}')

Episode 0, loss: 50.234378814697266, total_reward: 12.0, solved: False
Episode 1, loss: 597.0743408203125, total_reward: 44.0, solved: False
Episode 2, loss: 64.98707580566406, total_reward: 14.0, solved: False
Episode 3, loss: 229.73049926757812, total_reward: 27.0, solved: False
Episode 4, loss: 66.93148040771484, total_reward: 13.0, solved: False
Episode 5, loss: 211.13804626464844, total_reward: 25.0, solved: False
Episode 6, loss: 55.438838958740234, total_reward: 12.0, solved: False
Episode 7, loss: 250.60020446777344, total_reward: 28.0, solved: False
Episode 8, loss: 549.1503295898438, total_reward: 43.0, solved: False
Episode 9, loss: 193.25347900390625, total_reward: 24.0, solved: False
Episode 10, loss: 299.2527770996094, total_reward: 31.0, solved: False
Episode 11, loss: 886.17138671875, total_reward: 56.0, solved: False
Episode 12, loss: 68.13785552978516, total_reward: 14.0, solved: False
Episode 13, loss: 188.5084686279297, total_reward: 24.0, solved: False
Episode 14, 

Episode 117, loss: 4904.7392578125, total_reward: 163.0, solved: False
Episode 118, loss: 5164.2919921875, total_reward: 170.0, solved: False
Episode 119, loss: 112.10862731933594, total_reward: 19.0, solved: False
Episode 120, loss: 2641.803466796875, total_reward: 113.0, solved: False
Episode 121, loss: 4383.11572265625, total_reward: 151.0, solved: False
Episode 122, loss: 2067.0009765625, total_reward: 99.0, solved: False
Episode 123, loss: 49.962493896484375, total_reward: 13.0, solved: False
Episode 124, loss: 4349.85302734375, total_reward: 147.0, solved: False
Episode 125, loss: 217.04383850097656, total_reward: 25.0, solved: False
Episode 126, loss: 2875.186767578125, total_reward: 123.0, solved: False
Episode 127, loss: 2493.4208984375, total_reward: 114.0, solved: False
Episode 128, loss: 3656.85400390625, total_reward: 139.0, solved: False
Episode 129, loss: 3792.22119140625, total_reward: 136.0, solved: False
Episode 130, loss: 3313.2880859375, total_reward: 133.0, solved:

Episode 232, loss: 2325.193603515625, total_reward: 117.0, solved: False
Episode 233, loss: 3587.015380859375, total_reward: 166.0, solved: False
Episode 234, loss: 5401.2529296875, total_reward: 200.0, solved: True
Episode 235, loss: 5555.1220703125, total_reward: 200.0, solved: True
Episode 236, loss: 5222.90478515625, total_reward: 200.0, solved: True
Episode 237, loss: 2280.17919921875, total_reward: 114.0, solved: False
Episode 238, loss: 2316.193603515625, total_reward: 124.0, solved: False
Episode 239, loss: 5001.67578125, total_reward: 200.0, solved: True
Episode 240, loss: 1895.8753662109375, total_reward: 115.0, solved: False
Episode 241, loss: 4255.50439453125, total_reward: 171.0, solved: False
Episode 242, loss: 4751.08056640625, total_reward: 189.0, solved: False
Episode 243, loss: 5367.08349609375, total_reward: 200.0, solved: True
Episode 244, loss: 2646.878662109375, total_reward: 141.0, solved: False
Episode 245, loss: 2762.875732421875, total_reward: 140.0, solved: F