In [1]:
import gym
import numpy as np
import collections

import ptan
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [3]:
class PGN(nn.Module):
    """Policy network: maps an observation vector to action probabilities.

    A one-hidden-layer MLP whose output logits are normalised with a softmax,
    so ``forward`` returns probabilities (not logits) over ``n_actions``.

    Args:
        input_size: length of the observation vector fed to the first layer.
        n_actions: number of discrete actions (width of the output layer).
        hidden_size: width of the hidden layer (defaults to the original 128).
    """

    def __init__(self, input_size, n_actions, hidden_size=128):
        super(PGN, self).__init__()

        self.net = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        )

    def forward(self, x):
        """Return action probabilities for ``x``.

        The softmax is taken explicitly over the last dimension, which is the
        action axis produced by the final Linear layer. Calling softmax without
        ``dim`` is deprecated and makes PyTorch guess the axis; for 3-D input
        its guess is dim 0, which would normalise across the wrong axis.
        """
        logits = self.net(x)
        return F.softmax(logits, dim=-1)

In [4]:
class MeanBuffer:
    """Sliding-window mean over the most recent ``capacity`` values.

    A running total is maintained next to a bounded deque, so ``mean()`` does
    not need to re-sum the window each time it is called.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        # maxlen makes append() silently drop the oldest element once full
        self.deque = collections.deque(maxlen=capacity)
        self.sum = 0.0

    def add(self, val):
        """Push ``val`` into the window, evicting the oldest value when full."""
        window_full = len(self.deque) == self.capacity
        if window_full:
            # deque[0] is about to be evicted by append(); drop it from the total first
            evicted = self.deque[0]
            self.sum -= evicted
        self.deque.append(val)
        self.sum += val

    def mean(self):
        """Return the mean of the values currently held (0.0 when empty)."""
        count = len(self.deque)
        return self.sum / count if count else 0.0

In [5]:
# Training hyper-parameters
TARGET_REWARD = 195       # stop once the mean reward of the last 100 episodes exceeds this
GAMMA = 0.99              # discount factor handed to the experience source
LEARNING_RATE = 1e-3      # Adam step size
BATCH_SIZE = 32           # transitions collected per gradient update
ENTROPY_BETA = 0.01       # weight of the entropy bonus in the loss
BELLMAN_STEPS = 10        # steps_count of the n-step experience source
BASELINE_STEPS = 50_000   # window length of the moving-average reward baseline

In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make('CartPole-v0')
net = PGN(env.observation_space.shape[0], env.action_space.n).to(device)

# The agent feeds float32 observations to `net` and picks actions from its
# output. PGN.forward already returns softmax probabilities.
# NOTE(review): relies on ptan's PolicyAgent not applying a second softmax
# (apply_softmax is off by default in ptan) — confirm for the installed version.
agent = ptan.agent.PolicyAgent(net, preprocessor=ptan.agent.float32_preprocessor, device=device)

# Produces transitions (exp.state / exp.action / exp.reward) spanning up to
# BELLMAN_STEPS steps; presumably exp.reward is the GAMMA-discounted n-step
# return (ptan ExperienceSourceFirstLast semantics — confirm).
exp_source = ptan.experience.ExperienceSourceFirstLast(env, agent, gamma=GAMMA,
                                                       steps_count=BELLMAN_STEPS)

optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)

total_rewards = []
step_rewards = []
baseline_buf = MeanBuffer(BASELINE_STEPS)
step_idx = 0
done_episodes = 0

# Lower bound applied to probabilities before taking their log in the loss.
PROB_EPS = 1e-8

batch_states, batch_actions, batch_scales = [], [], []
for step_idx, exp in enumerate(exp_source):
    # Moving-average baseline: scaling log-probs by (reward - baseline) rather
    # than the raw reward lowers the variance of the policy-gradient estimate.
    baseline_buf.add(exp.reward)
    baseline = baseline_buf.mean()
    batch_states.append(exp.state)
    batch_actions.append(exp.action)
    batch_scales.append(exp.reward - baseline)

    # Report progress and check for convergence whenever an episode finishes.
    episode_rewards = exp_source.pop_total_rewards()
    if episode_rewards:
        done_episodes += 1
        reward = episode_rewards[0]
        total_rewards.append(reward)
        mean_rewards = float(np.mean(total_rewards[-100:]))
        print("%d: reward: %6.2f, mean_100: %6.2f, episodes: %d" % (
                step_idx, reward, mean_rewards, done_episodes))
        if mean_rewards > TARGET_REWARD:
            print("Solved in %d steps and %d episodes!" % (step_idx, done_episodes))
            break

    # Keep collecting until a full batch of transitions is available.
    if len(batch_states) < BATCH_SIZE:
        continue

    # Move the batch onto `device` (GPU if available, otherwise CPU).
    # Stacking into one numpy array first avoids torch's very slow
    # element-by-element conversion of a list of numpy arrays.
    states_v = torch.as_tensor(np.asarray(batch_states, dtype=np.float32), device=device)
    batch_actions_t = torch.as_tensor(np.asarray(batch_actions, dtype=np.int64), device=device)
    batch_scale_v = torch.as_tensor(np.asarray(batch_scales, dtype=np.float32), device=device)

    optimizer.zero_grad()
    prob_v = net(states_v)  # action probabilities, one row per state

    # Clamp before the log: a saturated softmax can underflow a probability to
    # 0, giving log(0) = -inf and then 0 * -inf = NaN in the entropy term.
    log_prob_v = torch.log(prob_v.clamp(min=PROB_EPS))

    # REINFORCE term: log pi(a|s) of the actions actually taken, each scaled by
    # its (reward - baseline). Maximising it == minimising its negative mean.
    batch_rows = torch.arange(len(batch_actions), device=device)
    log_prob_actions_v = batch_scale_v * log_prob_v[batch_rows, batch_actions_t]
    loss_policy_v = -log_prob_actions_v.mean()

    # Entropy bonus: subtracting beta * H(pi) from the loss rewards exploration
    # and discourages the policy from collapsing onto a single action early.
    entropy_v = -(prob_v * log_prob_v).sum(dim=1).mean()
    entropy_loss_v = -ENTROPY_BETA * entropy_v
    loss_v = loss_policy_v + entropy_loss_v

    loss_v.backward()
    optimizer.step()

    batch_states.clear()
    batch_actions.clear()
    batch_scales.clear()

  result = entry_point.load(False)
  del sys.path[0]


17: reward:  16.00, mean_100:  16.00, episodes: 1
30: reward:  12.00, mean_100:  14.00, episodes: 2
46: reward:  15.00, mean_100:  14.33, episodes: 3
86: reward:  39.00, mean_100:  20.50, episodes: 4
100: reward:  13.00, mean_100:  19.00, episodes: 5
122: reward:  21.00, mean_100:  19.33, episodes: 6
137: reward:  14.00, mean_100:  18.57, episodes: 7
146: reward:   9.00, mean_100:  17.38, episodes: 8
167: reward:  20.00, mean_100:  17.67, episodes: 9
201: reward:  33.00, mean_100:  19.20, episodes: 10
220: reward:  18.00, mean_100:  19.09, episodes: 11
241: reward:  20.00, mean_100:  19.17, episodes: 12
262: reward:  20.00, mean_100:  19.23, episodes: 13
289: reward:  26.00, mean_100:  19.71, episodes: 14
310: reward:  20.00, mean_100:  19.73, episodes: 15
323: reward:  12.00, mean_100:  19.25, episodes: 16
337: reward:  13.00, mean_100:  18.88, episodes: 17
362: reward:  24.00, mean_100:  19.17, episodes: 18
375: reward:  12.00, mean_100:  18.79, episodes: 19
405: reward:  29.00, mean

6003: reward:  51.00, mean_100:  44.66, episodes: 157
6065: reward:  61.00, mean_100:  45.02, episodes: 158
6114: reward:  48.00, mean_100:  45.21, episodes: 159
6152: reward:  37.00, mean_100:  44.72, episodes: 160
6195: reward:  42.00, mean_100:  44.99, episodes: 161
6293: reward:  97.00, mean_100:  45.60, episodes: 162
6317: reward:  23.00, mean_100:  45.43, episodes: 163
6383: reward:  65.00, mean_100:  45.40, episodes: 164
6435: reward:  51.00, mean_100:  45.47, episodes: 165
6497: reward:  61.00, mean_100:  45.26, episodes: 166
6576: reward:  78.00, mean_100:  45.83, episodes: 167
6629: reward:  52.00, mean_100:  46.14, episodes: 168
6795: reward: 165.00, mean_100:  47.51, episodes: 169
6842: reward:  46.00, mean_100:  47.52, episodes: 170
6911: reward:  68.00, mean_100:  47.83, episodes: 171
7001: reward:  89.00, mean_100:  48.26, episodes: 172
7050: reward:  48.00, mean_100:  48.58, episodes: 173
7087: reward:  36.00, mean_100:  48.81, episodes: 174
7121: reward:  33.00, mean_1

25021: reward: 200.00, mean_100: 148.88, episodes: 307
25141: reward: 119.00, mean_100: 149.82, episodes: 308
25342: reward: 200.00, mean_100: 150.28, episodes: 309
25536: reward: 193.00, mean_100: 150.96, episodes: 310
25721: reward: 184.00, mean_100: 151.49, episodes: 311
25922: reward: 200.00, mean_100: 152.40, episodes: 312
26073: reward: 150.00, mean_100: 152.27, episodes: 313
26274: reward: 200.00, mean_100: 152.27, episodes: 314
26433: reward: 158.00, mean_100: 153.23, episodes: 315
26634: reward: 200.00, mean_100: 154.33, episodes: 316
26713: reward:  78.00, mean_100: 153.11, episodes: 317
26851: reward: 137.00, mean_100: 154.04, episodes: 318
27052: reward: 200.00, mean_100: 154.22, episodes: 319
27167: reward: 114.00, mean_100: 153.36, episodes: 320
27368: reward: 200.00, mean_100: 153.79, episodes: 321
27569: reward: 200.00, mean_100: 154.26, episodes: 322
27770: reward: 200.00, mean_100: 155.00, episodes: 323
27949: reward: 178.00, mean_100: 154.78, episodes: 324
28126: rew

53181: reward: 200.00, mean_100: 191.36, episodes: 456
53382: reward: 200.00, mean_100: 191.36, episodes: 457
53498: reward: 115.00, mean_100: 190.57, episodes: 458
53666: reward: 167.00, mean_100: 190.24, episodes: 459
53806: reward: 139.00, mean_100: 189.63, episodes: 460
53939: reward: 132.00, mean_100: 189.39, episodes: 461
54140: reward: 200.00, mean_100: 189.92, episodes: 462
54313: reward: 172.00, mean_100: 189.64, episodes: 463
54452: reward: 138.00, mean_100: 190.22, episodes: 464
54647: reward: 194.00, mean_100: 190.16, episodes: 465
54848: reward: 200.00, mean_100: 190.16, episodes: 466
55049: reward: 200.00, mean_100: 190.99, episodes: 467
55166: reward: 116.00, mean_100: 190.15, episodes: 468
55367: reward: 200.00, mean_100: 190.15, episodes: 469
55568: reward: 200.00, mean_100: 190.15, episodes: 470
55769: reward: 200.00, mean_100: 190.55, episodes: 471
55940: reward: 170.00, mean_100: 190.25, episodes: 472
56140: reward: 199.00, mean_100: 190.24, episodes: 473
56311: rew

80779: reward: 169.00, mean_100: 185.30, episodes: 606
80926: reward: 146.00, mean_100: 184.76, episodes: 607
81065: reward: 138.00, mean_100: 184.14, episodes: 608
81198: reward: 132.00, mean_100: 183.46, episodes: 609
81338: reward: 139.00, mean_100: 182.85, episodes: 610
81474: reward: 135.00, mean_100: 182.20, episodes: 611
81604: reward: 129.00, mean_100: 181.49, episodes: 612
81735: reward: 130.00, mean_100: 180.79, episodes: 613
81885: reward: 149.00, mean_100: 180.28, episodes: 614
82034: reward: 148.00, mean_100: 179.88, episodes: 615
82163: reward: 128.00, mean_100: 179.16, episodes: 616
82296: reward: 132.00, mean_100: 178.48, episodes: 617
82465: reward: 168.00, mean_100: 178.16, episodes: 618
82666: reward: 200.00, mean_100: 178.16, episodes: 619
82867: reward: 200.00, mean_100: 178.16, episodes: 620
83007: reward: 139.00, mean_100: 177.55, episodes: 621
83184: reward: 176.00, mean_100: 177.31, episodes: 622
83365: reward: 180.00, mean_100: 177.11, episodes: 623
83506: rew

107258: reward: 200.00, mean_100: 176.47, episodes: 756
107459: reward: 200.00, mean_100: 176.61, episodes: 757
107660: reward: 200.00, mean_100: 176.91, episodes: 758
107861: reward: 200.00, mean_100: 177.53, episodes: 759
108062: reward: 200.00, mean_100: 178.29, episodes: 760
108263: reward: 200.00, mean_100: 179.24, episodes: 761
108464: reward: 200.00, mean_100: 179.77, episodes: 762
108665: reward: 200.00, mean_100: 181.47, episodes: 763
108866: reward: 200.00, mean_100: 182.22, episodes: 764
109067: reward: 200.00, mean_100: 182.93, episodes: 765
109268: reward: 200.00, mean_100: 183.67, episodes: 766
109469: reward: 200.00, mean_100: 184.11, episodes: 767
109670: reward: 200.00, mean_100: 184.99, episodes: 768
109871: reward: 200.00, mean_100: 185.75, episodes: 769
110072: reward: 200.00, mean_100: 186.29, episodes: 770
110273: reward: 200.00, mean_100: 186.94, episodes: 771
110474: reward: 200.00, mean_100: 187.67, episodes: 772
110675: reward: 200.00, mean_100: 187.71, episod