# Without global pooling
# Batching episodes, with a proportionally larger learning rate, is much more stable

In [1]:
import gym
import numpy as np
import matplotlib.pyplot as plt
from tqdm import trange
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.distributions import Categorical

# from torchvision import transforms

%matplotlib inline

In [2]:
# Atari Pong environment. Its observations are image frames that `preprocess`
# indexes as [row, col, channel], then crops and downsamples.
env = gym.make("Pong-v0")

[2018-01-06 23:26:45,736] Making new env: Pong-v0


In [3]:
# Training configuration.
downsample = 2                    # keep every `downsample`-th row/column of the 160x160 crop
output_size = 160 // downsample   # side length of the preprocessed observation
batch_size = 10                   # episodes whose gradients are accumulated per optimizer step
learning_rate = batch_size * 1.e-4  # learning rate grows linearly with batch_size

In [4]:
def preprocess(frame, factor=None):
    '''Turn a raw Pong frame into a binary (1, 1, H, W) float tensor (from karpathy).

    Crops rows 35:195 and keeps every ``factor``-th row and column of channel 0.
    It then zeroes the two background values (144 and 109) and sets every other
    pixel (paddles, ball) to 1.

    Args:
        frame: numpy array indexed as frame[row, col, channel]. Presumably this
            is the (210, 160, 3) uint8 Atari observation; TODO confirm. It is
            not modified.
        factor: spatial downsampling stride. Defaults to the module-level
            ``downsample`` setting when omitted.

    Returns:
        torch.FloatTensor of shape (1, 1, H, W) (BCHW) with values in {0, 1}.
    '''
    if factor is None:
        factor = downsample
    img = frame[35:195]  # crop
    # Basic slicing returns a view into `frame`. Copy it so the erasures below
    # do not overwrite the caller's (the environment's) observation in place.
    img = img[::factor, ::factor, 0].copy()
    img[img == 144] = 0  # erase background (background type 1)
    img[img == 109] = 0  # erase background (background type 2)
    img[img != 0] = 1    # everything else (paddles, ball) just set to 1
    tensor = torch.from_numpy(img).float()
    return tensor.unsqueeze(0).unsqueeze(0)  # BCHW

def clip_grads(net, low=-10, high=10):
    """Clamp, in place, every existing parameter gradient of ``net`` to [low, high].

    Parameters whose ``.grad`` is None are left untouched.
    """
    for param in net.parameters():
        if param.grad is None:
            continue
        param.grad.data.clamp_(low, high)
        
if torch.cuda.is_available():
    def to_var(x, requires_grad=False, gpu=None):
        """Move tensor ``x`` to CUDA device ``gpu`` (None = current) and wrap it in a Variable."""
        x = x.cuda(gpu)
        return Variable(x, requires_grad=requires_grad)
else:
    def to_var(x, requires_grad=False, gpu=None):
        """Wrap tensor ``x`` in a Variable.

        ``gpu`` is accepted only so the signature matches the CUDA branch; it is
        ignored. It was previously misnamed ``vgpu``, which made
        ``to_var(x, gpu=...)`` raise TypeError on CPU-only machines.
        """
        return Variable(x, requires_grad=requires_grad)

In [5]:
class Net(nn.Module):
    '''Small convolutional policy network (very similar to Nature DQN).

    Two strided Conv2d+ReLU stages feed a single linear layer that emits one
    unnormalised score (logit) per action.
    '''
    def __init__(self, action_n, input_shape=(1,80,80)):
        super().__init__()
        in_channels = input_shape[0]
        layers = [
            nn.Conv2d(in_channels, 16, kernel_size=8, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*layers)
        self.fc = nn.Linear(self._get_flatten_size(input_shape), action_n)

    def _get_flatten_size(self, shape):
        '''Number of features the conv stack produces for one input of ``shape``.'''
        probe = Variable(torch.rand(1, *shape))
        return self.conv(probe).view(-1).size(0)

    def forward(self, x):
        '''Map a (B, C, H, W) batch to (B, action_n) logits.'''
        features = self.conv(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

In [6]:
class PolicyGradient:
    '''REINFORCE-style bookkeeping around a policy network.

    Records the log-probability of every sampled action and every reward of
    the current episode. At episode end it turns them into a loss and keeps a
    history of episode returns plus an exponential moving average of them.

    Args:
        model: policy network; its output is treated as action logits.
        gamma: discount factor used when computing returns.
        eps: added to the std when normalising returns.
        running_gamma: smoothing factor of ``running_reward``.
        running_start: initial value of ``running_reward``.
        episode2thresh: maps the number of finished episodes to the
            exploration threshold passed to ``select_action``. The default
            always returns 0 (no exploration). An earlier schedule, with
            exploration starting after 150 episodes, was
            ``lambda i: 0.05+0.9*np.exp(-1. * i / 100) if i>150 else 0``.
    '''

    def __init__(self, model, gamma=0.99, eps=1.e-6, running_gamma=0.99, running_start=0,
                 episode2thresh=lambda i: 0):
        self.model = model
        self.gamma, self.eps = gamma, eps
        # Per-episode buffers, plus the history of finished-episode returns.
        self.log_probs, self.rewards, self.total_rewards = [], [], []
        self.running_reward = running_start
        self.running_gamma = running_gamma
        self.episode2thresh = episode2thresh

    @property
    def episodes(self):
        '''Number of finished episodes (one per ``get_loss_and_clear`` call).'''
        return len(self.total_rewards)

    def select_action(self, obs):
        '''Sample an action for ``obs`` in training mode and record its log-prob.'''
        self.model.train()
        # Resolves to the module-level select_action function, not this method.
        action, log_prob = select_action(obs, self.model,
                                         thresh=self.episode2thresh(self.episodes))
        self.log_probs.append(log_prob)
        return action

    def get_loss_and_clear(self):
        '''Close the current episode and return its policy-gradient loss.

        Appends the episode return to ``total_rewards`` and updates the moving
        average ``running_reward``. It then empties the per-step buffers in place.
        '''
        episode_return = sum(self.rewards)
        self.total_rewards.append(episode_return)
        g = self.running_gamma
        self.running_reward = g * self.running_reward + (1 - g) * episode_return
        loss = get_policy_loss(self.log_probs, self.rewards, self.gamma, self.eps)
        self.log_probs.clear()
        self.rewards.clear()
        return loss

    def take_action(self, action, env, render=False):
        '''Step ``env`` with ``action``, record the reward and return the step tuple.'''
        obs, reward, done, info = env.step(action)
        self.rewards.append(reward)
        if render:
            env.render()
        return obs, reward, done, info

    def greedy_policy(self, obs):
        '''Return the highest-scoring action for ``obs`` with the model in eval mode.'''
        self.model.eval()
        scores = self.model(to_var(obs))
        best = scores.max(dim=1)[1]
        return best.data[0]

def select_action(obs, model, thresh=0):
    '''Sample an action from the policy ``model`` for observation ``obs``.

    With probability ``thresh``, a uniformly random action is taken instead of
    a policy sample (epsilon-style exploration). With ``thresh <= 0`` (the
    default) the action is always sampled from the policy, and no numpy random
    numbers are drawn. This matches the previous behavior, which ignored
    ``thresh`` entirely.

    Returns:
        (action, log_prob): ``action.data[0]`` of the chosen action, and the
        policy's log-probability of that action (differentiable w.r.t. the model).
    '''
    state = to_var(obs)
    logits = model(state)
    probs = F.softmax(logits, dim=1)
    m = Categorical(probs)
    if thresh > 0 and np.random.random() < thresh:
        # Explore: uniform action over the action dimension of `probs`.
        n_actions = probs.size(1)
        action = to_var(torch.from_numpy(np.random.randint(n_actions, size=1, dtype=np.int64)))
    else:
        action = m.sample()
    # log_prob is always taken under the policy, so gradients flow to the model.
    return action.data[0], m.log_prob(action)
    
def get_normalized_rewards(rewards, gamma, eps):
    '''Discounted returns of one episode, standardised to zero mean / unit std.

    G_t = r_t + gamma * G_{t+1}. The result is (G - mean(G)) / (std(G) + eps).

    Edge cases: with a single step the std is undefined (NaN), which would
    poison the loss. In that case only the centering is applied (result 0). An
    empty ``rewards`` yields an empty tensor.

    Returns:
        1-D tensor/Variable of the same length as ``rewards``, without grad.
    '''
    returns = []
    running = 0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    ret = to_var(torch.Tensor(returns), requires_grad=False)
    if len(returns) < 2:
        return ret - ret.mean() if returns else ret
    return (ret - ret.mean()) / (ret.std() + eps)

def get_policy_loss(log_probs, rewards, gamma, eps):
    '''REINFORCE loss for one episode: -sum_t log_prob_t * normalised_return_t.

    ``log_probs`` is a list of per-step log-prob tensors, concatenated along
    dim 0. ``rewards`` is the matching list of per-step rewards.
    '''
    log_prob_seq = torch.cat(log_probs)
    weights = get_normalized_rewards(rewards, gamma, eps)
    weighted = log_prob_seq.dot(weights)
    return -weighted

In [7]:
# Build the policy network (on GPU when available), its optimizer, the
# REINFORCE trainer and the tensorboard writer.
n_actions = env.action_space.n
net = Net(n_actions, input_shape=(1, output_size, output_size))
net = net.cuda() if torch.cuda.is_available() else net
optimizer = optim.Adam(net.parameters(), lr=learning_rate, weight_decay=0.001)
trainer = PolicyGradient(model=net, running_start=-21)  # running average starts at the worst Pong score
writer = SummaryWriter()

In [8]:
import os

# Directory of the first (here the only) writer registered in `writer.all_writers`.
# Checkpoints are stored next to its logs.
writer_path = list(writer.all_writers)[0]


def weight_join(p):
    """Return the path of file ``p`` inside the run directory ``writer_path``."""
    return os.path.join(writer_path, p)

In [9]:
# Display the run directory; checkpoints from the training loop are saved here too.
writer_path

'runs/Jan06_23-26-49_amax'

In [None]:
# Start from zero gradients: the training loop accumulates gradients over
# `batch_size` episodes before each optimizer step.
optimizer.zero_grad()

In [None]:
# REINFORCE training. Gradients of `batch_size` consecutive episodes are
# accumulated (one backward per episode) before each optimizer step. Training
# stops once the moving-average episode reward exceeds 1.
max_episodes = 100000
max_steps = 100000  # hard cap on environment steps per episode
for episode in trange(max_episodes):
    frame = env.reset()
    # The policy sees the difference of consecutive frames. At reset there is
    # no previous frame, so the first difference is all zeros.
    last_obs = curr_obs = preprocess(frame)
    total_reward = 0
    for step in range(max_steps):
        action = trainer.select_action(obs=curr_obs - last_obs)
        frame, reward, done, _ = trainer.take_action(action, env, render=False)
        last_obs = curr_obs
        curr_obs = preprocess(frame)
        total_reward += reward
        if done:
            break
    else:
        # Reached only when the step cap ran out before the episode ended.
        # `step` is at most max_steps - 1, so comparing it to max_steps never fires.
        print("not enough!!!!!!!!!!!!!!!")
    policy_loss = trainer.get_loss_and_clear()
    writer.add_scalar("loss", policy_loss.data[0], episode)
    writer.add_scalar("reward", total_reward, episode)
    policy_loss.backward()
    if (episode + 1) % batch_size == 0:
        clip_grads(trainer.model, -10, 10)
        optimizer.step()
        optimizer.zero_grad()
    running_reward = trainer.running_reward
    if episode % 100 == 0:
        print(episode, total_reward, running_reward)
        torch.save(net.state_dict(), weight_join("episode%s.pth" % episode))
    if running_reward > 1:
        break
print("Finished: %s@%s" % (trainer.running_reward, episode))

  0%|          | 1/100000 [00:04<124:39:29,  4.49s/it]

0 -19.0 -20.98


  0%|          | 101/100000 [04:09<65:01:43,  2.34s/it]

100 -21.0 -20.593320148522988


  0%|          | 201/100000 [08:12<65:37:39,  2.37s/it]

200 -20.0 -20.43658979275693


  0%|          | 301/100000 [12:20<66:03:26,  2.39s/it]

300 -21.0 -20.324941403698595


  0%|          | 401/100000 [16:35<74:33:18,  2.69s/it]

400 -20.0 -20.29747855383671


  1%|          | 501/100000 [20:58<72:35:06,  2.63s/it]

500 -21.0 -19.985940295657944


  1%|          | 601/100000 [25:33<81:31:55,  2.95s/it]

600 -19.0 -19.63815240524038


  1%|          | 701/100000 [30:24<78:49:49,  2.86s/it]

700 -20.0 -19.43256308065187


  1%|          | 801/100000 [35:31<71:33:18,  2.60s/it]

800 -19.0 -19.429698270801634


  1%|          | 901/100000 [40:40<84:59:07,  3.09s/it]

900 -21.0 -19.403330877153223


  1%|          | 1001/100000 [45:54<81:33:01,  2.97s/it]

1000 -20.0 -19.510369441712307


  1%|          | 1101/100000 [51:54<107:38:08,  3.92s/it]

1100 -18.0 -19.196990978756542


  1%|          | 1201/100000 [58:20<109:32:47,  3.99s/it]

1200 -19.0 -19.08992046960698


  1%|▏         | 1301/100000 [1:05:10<109:15:03,  3.98s/it]

1300 -18.0 -18.910432957351578


  1%|▏         | 1401/100000 [1:12:27<121:43:18,  4.44s/it]

1400 -15.0 -18.690822640443983


  2%|▏         | 1501/100000 [1:19:50<120:11:32,  4.39s/it]

1500 -20.0 -18.41175157240275


  2%|▏         | 1601/100000 [1:26:53<130:26:42,  4.77s/it]

1600 -16.0 -18.313298248248152


  2%|▏         | 1701/100000 [1:35:30<167:34:22,  6.14s/it]

1700 -15.0 -17.08492128729433


  2%|▏         | 1801/100000 [1:44:36<142:03:10,  5.21s/it]

1800 -20.0 -15.930765843272738


  2%|▏         | 1901/100000 [1:53:56<149:57:01,  5.50s/it]

1900 -18.0 -15.684703574373248


  2%|▏         | 2001/100000 [2:03:21<152:40:55,  5.61s/it]

2000 -13.0 -15.181284160638096


  2%|▏         | 2101/100000 [2:12:25<158:49:27,  5.84s/it]

2100 -17.0 -15.471478906484887


  2%|▏         | 2201/100000 [2:21:55<152:16:55,  5.61s/it]

2200 -12.0 -15.35677497731562


  2%|▏         | 2301/100000 [2:31:40<150:43:19,  5.55s/it]

2300 -17.0 -14.772674257142585


  2%|▏         | 2401/100000 [2:41:01<154:07:37,  5.69s/it]

2400 -11.0 -14.547133676329882


  3%|▎         | 2501/100000 [2:50:50<163:07:46,  6.02s/it]

2500 -19.0 -14.67345973165935


  3%|▎         | 2601/100000 [3:00:36<156:43:40,  5.79s/it]

2600 -13.0 -14.472608814421442


  3%|▎         | 2701/100000 [3:10:19<194:51:36,  7.21s/it]

2700 -8.0 -14.244923854253141


  3%|▎         | 2801/100000 [3:20:39<163:17:43,  6.05s/it]

2800 -15.0 -14.67451234398946


  3%|▎         | 2901/100000 [3:30:44<162:41:11,  6.03s/it]

2900 -17.0 -14.775291189329023


  3%|▎         | 3001/100000 [3:40:57<159:17:09,  5.91s/it]

3000 -13.0 -13.961154196469257


  3%|▎         | 3101/100000 [3:51:24<164:05:52,  6.10s/it]

3100 -12.0 -13.862416826279734


  3%|▎         | 3201/100000 [4:01:42<179:29:08,  6.68s/it]

3200 -5.0 -13.895729960675311


  3%|▎         | 3301/100000 [4:11:45<158:09:04,  5.89s/it]

3300 -16.0 -13.991666882155872


  3%|▎         | 3401/100000 [4:21:48<162:00:01,  6.04s/it]

3400 -14.0 -13.885584085160495


  4%|▎         | 3501/100000 [4:31:35<155:06:19,  5.79s/it]

3500 -17.0 -14.121357273136075


  4%|▎         | 3601/100000 [4:41:47<155:53:36,  5.82s/it]

3600 -18.0 -13.363968575759191


  4%|▎         | 3701/100000 [4:51:57<167:44:32,  6.27s/it]

3700 -7.0 -12.662717775419932


  4%|▍         | 3801/100000 [5:02:15<160:26:35,  6.00s/it]

3800 -15.0 -12.858893385195596


  4%|▍         | 3901/100000 [5:12:39<174:26:46,  6.53s/it]

3900 -15.0 -13.003787154628078


  4%|▍         | 4001/100000 [5:23:19<185:04:18,  6.94s/it]

4000 -10.0 -12.634505515755512


  4%|▍         | 4101/100000 [5:33:41<170:31:52,  6.40s/it]

4100 -10.0 -13.127838713299772


  4%|▍         | 4201/100000 [5:44:38<176:53:10,  6.65s/it]

4200 -15.0 -13.48986651935621


  4%|▍         | 4301/100000 [5:55:59<187:35:38,  7.06s/it]

4300 -11.0 -13.487427534257229


  4%|▍         | 4401/100000 [6:06:55<181:06:42,  6.82s/it]

4400 -10.0 -13.663321629431676


  5%|▍         | 4501/100000 [6:18:08<167:41:19,  6.32s/it]

4500 -16.0 -13.350207933751266


  5%|▍         | 4601/100000 [6:28:33<162:15:03,  6.12s/it]

4600 -11.0 -12.906172001119847


  5%|▍         | 4701/100000 [6:39:02<176:36:34,  6.67s/it]

4700 -14.0 -11.935700618149731


  5%|▍         | 4801/100000 [6:49:44<165:27:34,  6.26s/it]

4800 -13.0 -11.41613064468382


  5%|▍         | 4901/100000 [7:00:38<184:14:50,  6.97s/it]

4900 -1.0 -11.083711567458279


  5%|▌         | 5001/100000 [7:11:48<186:25:34,  7.06s/it]

5000 -10.0 -11.249601903759705


  5%|▌         | 5101/100000 [7:22:52<169:18:24,  6.42s/it]

5100 -16.0 -11.395774087249164


  5%|▌         | 5201/100000 [7:33:43<179:05:25,  6.80s/it]

5200 -6.0 -11.327536985261128


  5%|▌         | 5301/100000 [7:44:57<178:19:03,  6.78s/it]

5300 -15.0 -11.20690508944407


  5%|▌         | 5401/100000 [7:56:02<180:08:25,  6.86s/it]

5400 -13.0 -12.162041031536546


  6%|▌         | 5501/100000 [8:07:37<179:12:08,  6.83s/it]

5500 -11.0 -12.254860958229312


  6%|▌         | 5601/100000 [8:18:59<170:27:11,  6.50s/it]

5600 -16.0 -12.252959032953385


  6%|▌         | 5701/100000 [8:31:01<193:10:30,  7.37s/it]

5700 -13.0 -11.391078257909577


  6%|▌         | 5801/100000 [8:42:38<174:34:29,  6.67s/it]

5800 -11.0 -11.33357068751937


  6%|▌         | 5901/100000 [8:54:11<192:54:17,  7.38s/it]

5900 -12.0 -10.721497673121274


  6%|▌         | 6001/100000 [9:06:00<187:50:46,  7.19s/it]

6000 -14.0 -10.328076399413405


  6%|▌         | 6101/100000 [9:17:55<197:20:12,  7.57s/it]

6100 -9.0 -9.84302884141116


  6%|▌         | 6201/100000 [9:29:40<214:09:45,  8.22s/it]

6200 -6.0 -9.970381075404667


  6%|▋         | 6301/100000 [9:41:49<195:06:23,  7.50s/it]

6300 -7.0 -10.190336442764858


  6%|▋         | 6401/100000 [9:53:48<182:57:19,  7.04s/it]

6400 -4.0 -9.813581951936442


  7%|▋         | 6501/100000 [10:05:46<201:57:54,  7.78s/it]

6500 -10.0 -10.171650145377555


  7%|▋         | 6601/100000 [10:18:02<186:25:00,  7.19s/it]

6600 -13.0 -10.56593598909642


  7%|▋         | 6701/100000 [10:30:22<201:01:33,  7.76s/it]

6700 -7.0 -9.912434095673584


  7%|▋         | 6801/100000 [10:42:34<195:17:10,  7.54s/it]

6800 -7.0 -9.172270494084943


  7%|▋         | 6901/100000 [10:54:36<200:25:08,  7.75s/it]

6900 -1.0 -10.000017572976706


  7%|▋         | 7001/100000 [11:06:50<187:54:43,  7.27s/it]

7000 -12.0 -9.880019378435236


  7%|▋         | 7101/100000 [11:19:10<188:44:33,  7.31s/it]

7100 -9.0 -9.617432853679505


  7%|▋         | 7201/100000 [11:31:17<192:57:20,  7.49s/it]

7200 -6.0 -9.660655458904692


  7%|▋         | 7301/100000 [11:43:20<185:22:22,  7.20s/it]

7300 -7.0 -9.334606066714093


  7%|▋         | 7401/100000 [11:55:03<168:59:37,  6.57s/it]

7400 -14.0 -9.308347387523444


  8%|▊         | 7501/100000 [12:06:46<177:58:23,  6.93s/it]

7500 -9.0 -9.545301325785449


  8%|▊         | 7601/100000 [12:18:21<168:51:31,  6.58s/it]

7600 -15.0 -10.184963596063746


  8%|▊         | 7701/100000 [12:30:33<191:42:12,  7.48s/it]

7700 -7.0 -9.797249310551583


  8%|▊         | 7801/100000 [12:42:45<181:07:24,  7.07s/it]

7800 -8.0 -8.91685602991032


  8%|▊         | 7901/100000 [12:55:13<194:52:37,  7.62s/it]

7900 -4.0 -8.370305714461463


  8%|▊         | 8001/100000 [13:07:13<193:50:12,  7.58s/it]

8000 -3.0 -8.569021439973797


  8%|▊         | 8101/100000 [13:19:20<194:30:29,  7.62s/it]

8100 -8.0 -8.859530118655128


  8%|▊         | 8201/100000 [13:31:46<212:40:35,  8.34s/it]

8200 -1.0 -8.85860588821657


  8%|▊         | 8301/100000 [13:43:43<189:34:52,  7.44s/it]

8300 -3.0 -8.955536400465117


  8%|▊         | 8401/100000 [13:55:19<171:31:44,  6.74s/it]

8400 -13.0 -8.863229835535737


  9%|▊         | 8501/100000 [14:06:20<162:05:14,  6.38s/it]

8500 -17.0 -8.896738643605618


  9%|▊         | 8601/100000 [14:17:10<160:00:45,  6.30s/it]

8600 -12.0 -9.25728981172437


  9%|▊         | 8701/100000 [14:28:28<168:25:57,  6.64s/it]

8700 -15.0 -9.799944559648646


  9%|▉         | 8801/100000 [14:40:02<180:56:57,  7.14s/it]

8800 -9.0 -9.659031270303734


  9%|▉         | 8901/100000 [14:51:53<178:48:54,  7.07s/it]

8900 -12.0 -9.173200027539036


  9%|▉         | 9001/100000 [15:03:56<196:57:14,  7.79s/it]

9000 -5.0 -9.309666386138971


  9%|▉         | 9101/100000 [15:15:47<182:45:58,  7.24s/it]

9100 -9.0 -9.075059031032165


  9%|▉         | 9201/100000 [15:27:46<205:41:31,  8.16s/it]

9200 -1.0 -8.558874569157654


  9%|▉         | 9301/100000 [15:39:58<184:46:38,  7.33s/it]

9300 -9.0 -8.826073478170013


  9%|▉         | 9401/100000 [15:52:05<174:16:14,  6.92s/it]

9400 -9.0 -9.308668443422393


 10%|▉         | 9501/100000 [16:03:59<181:09:47,  7.21s/it]

9500 -6.0 -9.179622104840039


 10%|▉         | 9601/100000 [16:16:21<182:35:53,  7.27s/it]

9600 -7.0 -8.424582774713441


 10%|▉         | 9701/100000 [16:28:08<182:53:59,  7.29s/it]

9700 -7.0 -8.899729968503987


 10%|▉         | 9747/100000 [16:33:36<184:34:21,  7.36s/it]

In [None]:
# Save the final weights into the same run directory as the periodic checkpoints.
torch.save(net.state_dict(), weight_join("final.pth"))

In [None]:
# Learning curve: total reward of every finished episode.
fig, ax = plt.subplots()
ax.plot(trainer.total_rewards)
ax.set(title="Pong: total reward per episode", xlabel="episode", ylabel="total reward");

In [None]:
# Show the run directory again so the saved weights and logs are easy to locate.
writer_path