# without global pool
# batching episodes with a larger learning rate is much more stable

In [1]:
import gym
import numpy as np
import matplotlib.pyplot as plt
from tqdm import trange
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.distributions import Categorical

# from torchvision import transforms

%matplotlib inline

In [2]:
# Create the Atari Pong environment that the agent is trained on.
env = gym.make("Pong-v0")

[2018-01-07 21:47:50,315] Making new env: Pong-v0


In [3]:
# Spatial subsampling stride applied to each cropped frame in preprocess().
downsample = 2
# Side length of the square network input: the 160-row crop taken in
# preprocess() divided by the subsampling stride.
output_size = 160//downsample
# Number of episodes whose gradients are accumulated before one optimizer
# step; the learning rate below is scaled up by the same factor.
batch_size = 10
learning_rate = 3.e-5 * batch_size

In [4]:
def preprocess(frame, factor=None):
    """Turn a raw game frame into a binary 1x1xHxW float tensor.

    Adapted from Karpathy's Pong preprocessing (per the original author's
    note). Rows 35..194 are kept, rows and columns are subsampled, only
    channel 0 is used, the two background values (144 and 109) are zeroed
    and every remaining non-zero pixel is set to 1.

    Args:
        frame: numpy array indexed as [row, column, channel].
        factor: subsampling stride for rows and columns. When None (the
            default) the module-level ``downsample`` setting is used, which
            matches the previous behaviour.

    Returns:
        A float tensor of shape (1, 1, H, W) containing only 0s and 1s.
    """
    step = downsample if factor is None else factor
    # Work on a copy: basic slicing returns a view, so masking it in place
    # would silently overwrite the caller's frame (and fail on read-only
    # arrays).
    I = np.array(frame[35:195:step, ::step, 0])
    I[I == 144] = 0 # erase background (background type 1)
    I[I == 109] = 0 # erase background (background type 2)
    I[I != 0] = 1 # everything else (paddles, ball) just set to 1
    tensor = torch.from_numpy(I).float()
    return tensor.unsqueeze(0).unsqueeze(0) # BCHW

def clip_grads(net, low=-10, high=10):
    """Clamp every existing parameter gradient of ``net`` into [low, high].

    Parameters whose ``grad`` is None are left untouched. The clamp is
    applied element-wise and in place.
    """
    for param in net.parameters():
        if param.grad is None:
            continue
        param.grad.data.clamp_(low, high)
        
# Pick the to_var implementation once, depending on whether CUDA is usable.
# Both variants share the same signature so callers can pass ``gpu=`` by
# keyword on any machine (the CPU variant previously misspelled it as
# ``vgpu``, which made ``to_var(x, gpu=...)`` raise TypeError without CUDA).
if torch.cuda.is_available():
    def to_var(x, requires_grad=False, gpu=None):
        """Move tensor ``x`` to the CUDA device ``gpu`` and wrap it in a Variable.

        Args:
            x: tensor to wrap.
            requires_grad: forwarded to ``Variable``.
            gpu: device argument forwarded to ``x.cuda``.
        """
        x = x.cuda(gpu)
        return Variable(x, requires_grad=requires_grad)
else:
    def to_var(x, requires_grad=False, gpu=None):
        """Wrap tensor ``x`` in a Variable without moving it.

        Args:
            x: tensor to wrap.
            requires_grad: forwarded to ``Variable``.
            gpu: accepted for signature compatibility with the CUDA
                variant; ignored here.
        """
        return Variable(x, requires_grad=requires_grad)

In [5]:
class Net(nn.Module):
    """Small convolutional policy network producing one logit per action.

    Two conv+ReLU stages are followed by a single linear layer. The original
    author describes it as very similar to the Nature DQN network.
    """

    def __init__(self, action_n, input_shape=(1,80,80)):
        super().__init__()
        in_channels = input_shape[0]
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=8, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),
            nn.ReLU(),
        )
        # The linear layer's input width is measured rather than hard-coded.
        self.fc = nn.Linear(self._get_flatten_size(input_shape), action_n)

    def _get_flatten_size(self, shape):
        """Return the flattened feature count the conv stack yields for one input of ``shape``."""
        probe = Variable(torch.rand(1, *shape))
        flat = self.conv(probe).view(-1)
        return flat.size(0)

    def forward(self, x):
        """Map a batch of BCHW inputs to a (batch, action_n) tensor of logits."""
        features = self.conv(x)
        batch = features.size(0)
        return self.fc(features.view(batch, -1))

In [6]:
class PolicyGradient:
    """Per-episode bookkeeping for REINFORCE-style policy-gradient training.

    Collects the log-probabilities of sampled actions and the rewards of the
    current episode, and keeps a history of episode totals together with an
    exponentially smoothed running reward.

    Args:
        model: network mapping an observation batch to action logits.
        gamma: discount factor passed to the loss computation.
        eps: small constant passed to the loss computation.
        running_gamma: smoothing factor of the running reward.
        running_start: initial value of the running reward.
        episode2thresh: callable mapping the number of finished episodes to
            an exploration threshold. The default always returns 0, i.e. no
            exploration.
    """

    def __init__(self, model, gamma=0.99, eps=1.e-6, running_gamma=0.99, running_start=0,
                 episode2thresh=lambda i: 0):
        self.model = model
        self.gamma, self.eps = gamma, eps
        self.running_gamma = running_gamma
        self.running_reward = running_start
        self.episode2thresh = episode2thresh
        # Buffers for the episode in progress; emptied by get_loss_and_clear.
        self.log_probs = []
        self.rewards = []
        # One entry per finished episode.
        self.total_rewards = []

    @property
    def episodes(self):
        """Number of episodes finished so far."""
        return len(self.total_rewards)

    def select_action(self,obs):
        """Sample an action for ``obs`` in train mode and remember its log-probability."""
        self.model.train()
        threshold = self.episode2thresh(self.episodes)
        chosen, chosen_log_prob = select_action(obs, self.model, thresh=threshold)
        self.log_probs.append(chosen_log_prob)
        return chosen

    def get_loss_and_clear(self):
        """Close the current episode: update statistics, build the loss, reset the buffers."""
        episode_return = sum(self.rewards)
        self.total_rewards.append(episode_return)
        decay = self.running_gamma
        self.running_reward = decay*self.running_reward+(1-decay)*episode_return
        loss = get_policy_loss(self.log_probs, self.rewards, self.gamma, self.eps)
        # Empty the lists in place so the same list objects are reused.
        self.log_probs.clear()
        self.rewards.clear()
        return loss

    def take_action(self, action, env, render=False):
        """Step ``env`` with ``action``, record the reward and optionally render."""
        observation, reward, finished, extra = env.step(action)
        self.rewards.append(reward)
        if render:
            env.render()
        return observation, reward, finished, extra

    def greedy_policy(self, obs):
        """Return the index of the highest model output for ``obs`` in eval mode."""
        self.model.eval()
        scores = self.model(to_var(obs))
        best = scores.max(dim=1)[1]
        return best.data[0]

def select_action(obs, model, thresh=0):
    """Sample one action from the categorical policy defined by ``model``.

    Args:
        obs: observation tensor; it is wrapped with ``to_var`` before the
            forward pass.
        model: network returning action logits with actions along dim 1.
        thresh: exploration threshold. Accepted for interface compatibility
            but currently unused: the action is always sampled from the
            policy.

    Returns:
        A tuple ``(action, log_prob)``: the first element of the sampled
        action's data and the log-probability of the sampled action.
    """
    action_probs = F.softmax(model(to_var(obs)), dim=1)
    policy = Categorical(action_probs)
    sampled = policy.sample()
    sampled_log_prob = policy.log_prob(sampled)
    return sampled.data[0], sampled_log_prob
    
def get_normalized_rewards(rewards, gamma, eps):
    """Compute discounted returns and standardise them.

    For each step the return is the reward at that step plus ``gamma`` times
    the return of the following step. The resulting vector is shifted by its
    mean and divided by its standard deviation plus ``eps``.

    Args:
        rewards: sequence of per-step rewards of one episode.
        gamma: discount factor.
        eps: constant added to the standard deviation before dividing.

    Returns:
        The standardised returns wrapped by ``to_var`` (no gradient).
    """
    discounted = [0] * len(rewards)
    running = 0
    # Walk the episode backwards so each return can reuse the next one.
    for idx in range(len(rewards) - 1, -1, -1):
        running = rewards[idx] + gamma * running
        discounted[idx] = running
    returns_v = to_var(torch.Tensor(discounted), requires_grad=False)
    centered = returns_v - returns_v.mean()
    return centered / (returns_v.std()+eps)

def get_policy_loss(log_probs,rewards, gamma,eps):
    """Return the policy-gradient loss of one episode.

    The per-step log-probabilities are concatenated into one vector and the
    loss is the negated dot product of that vector with the standardised
    discounted returns.

    Args:
        log_probs: list of log-probability tensors, one per step.
        rewards: sequence of per-step rewards.
        gamma: discount factor forwarded to ``get_normalized_rewards``.
        eps: stabiliser forwarded to ``get_normalized_rewards``.
    """
    stacked_log_probs = torch.cat(log_probs)
    returns_v = get_normalized_rewards(rewards, gamma, eps)
    weighted_sum = stacked_log_probs.dot(returns_v)
    return -weighted_sum

In [7]:
# Build the policy network: one output per environment action, fed with a
# single-channel square image of side output_size.
net = Net(env.action_space.n, input_shape=(1,output_size,output_size))
# Move the model to the GPU when CUDA is available.
if torch.cuda.is_available():
    net = net.cuda()
optimizer = optim.Adam(net.parameters(), lr=learning_rate)
# running_start=-21 is the initial value of the smoothed running reward
# (presumably the worst possible Pong score — confirm).
trainer = PolicyGradient(model=net,running_start=-21)
# TensorBoard writer created with default arguments.
writer = SummaryWriter()

In [9]:
# Start from previously saved weights instead of a fresh initialisation.
# NOTE(review): hard-coded path; this fails if the file does not exist, and
# a checkpoint saved from a GPU model may need map_location on a CPU-only
# machine — confirm.
net.load_state_dict(torch.load('tmp/Jan06_23-26-49_amax/episode12500.pth'))

In [10]:
import os

# Directory of the first writer registered on the SummaryWriter; checkpoints
# are stored next to the TensorBoard logs.
writer_path = list(writer.all_writers.keys())[0]


def weight_join(p):
    """Return ``p`` joined onto the run directory ``writer_path``."""
    return os.path.join(writer_path, p)

In [11]:
# Display the run directory that checkpoints will be written to.
writer_path

'runs/Jan07_21-47-53_amax'

In [12]:
# Clear any existing gradients before the training loop starts accumulating them.
optimizer.zero_grad()

In [None]:
# Policy-gradient training loop. Each episode is played to completion, its
# loss is back-propagated immediately, and the accumulated gradients are
# applied once every batch_size episodes.
for episode in trange(100000):
    frame = env.reset()
    # Both observations start from the reset frame, so the first input
    # (their difference) is all zeros.
    last_obs = preprocess(frame)
    curr_obs = preprocess(frame)
    total_reward = 0
    for step in range(100000): # cap each episode at 100000 steps
        # The network sees the difference between consecutive frames.
        action = trainer.select_action(obs=curr_obs-last_obs)
        frame, reward, done, _ = trainer.take_action(action, env, render=False)
        last_obs = curr_obs
        curr_obs = preprocess(frame)
        total_reward+=reward
        if done:
            break
    else:
        # Runs only when the step cap was exhausted without `done`. The
        # previous `step==100000` check could never fire because
        # range(100000) stops at 99999.
        print("not enough!!!!!!!!!!!!!!!")
    policy_loss = trainer.get_loss_and_clear()
    writer.add_scalar("loss",policy_loss.data[0],episode)
    writer.add_scalar("reward",total_reward,episode)
    # Gradients accumulate across episodes until the next optimizer step.
    policy_loss.backward()
    if (episode+1)%batch_size==0:
        clip_grads(trainer.model,-10,10)
        optimizer.step()
        optimizer.zero_grad()
    running_reward = trainer.running_reward
    # Periodic progress report and checkpoint.
    if episode%100==0:
        print(episode, total_reward,running_reward)
        torch.save(net.state_dict(), weight_join("episode%s.pth"%episode))
    # Stop once the smoothed episode reward exceeds 1.
    if running_reward>1:
        break
print("Finished: %s@%s" %(trainer.running_reward,episode))

  0%|          | 1/100000 [00:10<298:33:41, 10.75s/it]

0 -5.0 -20.84


  0%|          | 101/100000 [12:02<202:14:52,  7.29s/it]

100 -8.0 -13.032413136282548


  0%|          | 201/100000 [23:40<191:58:49,  6.93s/it]

200 -4.0 -10.001809221867138


  0%|          | 301/100000 [35:27<194:08:48,  7.01s/it]

300 -10.0 -8.63411327820182


  0%|          | 401/100000 [46:56<194:07:33,  7.02s/it]

400 -10.0 -8.187331489494778


  1%|          | 501/100000 [58:22<206:48:12,  7.48s/it]

500 -8.0 -8.178441019120616


  1%|          | 601/100000 [1:10:18<197:47:41,  7.16s/it]

600 -5.0 -7.9682416582412845


  1%|          | 701/100000 [1:21:32<187:18:26,  6.79s/it]

700 -4.0 -8.423093397550858


  1%|          | 801/100000 [1:33:33<200:07:31,  7.26s/it]

800 -6.0 -8.311304543819134


  1%|          | 901/100000 [1:45:10<221:09:05,  8.03s/it]

900 -4.0 -8.433084789157736


  1%|          | 1001/100000 [1:57:21<209:32:49,  7.62s/it]

1000 -6.0 -8.52754632325572


  1%|          | 1101/100000 [2:09:47<213:57:55,  7.79s/it]

1100 -8.0 -8.325610711996891


  1%|          | 1201/100000 [2:21:40<182:02:03,  6.63s/it]

1200 -4.0 -8.356632539502222


  1%|▏         | 1301/100000 [2:33:48<189:25:18,  6.91s/it]

1300 -12.0 -8.582246544207852


  1%|▏         | 1401/100000 [2:45:58<221:33:09,  8.09s/it]

1400 -8.0 -8.530923027533195


  2%|▏         | 1501/100000 [2:58:10<194:46:35,  7.12s/it]

1500 -6.0 -8.197709990711026


  2%|▏         | 1601/100000 [3:10:42<204:06:03,  7.47s/it]

1600 -10.0 -7.998479547337398


  2%|▏         | 1701/100000 [3:22:36<190:58:38,  6.99s/it]

1700 -7.0 -8.61561329383776


  2%|▏         | 1801/100000 [3:34:51<221:04:04,  8.10s/it]

1800 -1.0 -8.329599594236049


  2%|▏         | 1901/100000 [3:47:03<199:09:50,  7.31s/it]

1900 -9.0 -8.05278513808355


  2%|▏         | 2001/100000 [3:59:05<207:03:36,  7.61s/it]

2000 -1.0 -8.109887921479979


  2%|▏         | 2101/100000 [4:10:56<192:53:03,  7.09s/it]

2100 -9.0 -8.231429999408778


  2%|▏         | 2201/100000 [4:23:23<199:46:20,  7.35s/it]

2200 -7.0 -8.002038181008649


  2%|▏         | 2301/100000 [4:35:33<205:34:51,  7.58s/it]

2300 -1.0 -8.250746936276403


  2%|▏         | 2401/100000 [4:47:51<196:29:22,  7.25s/it]

2400 -7.0 -7.906684185603353


  3%|▎         | 2501/100000 [4:59:54<201:52:52,  7.45s/it]

2500 -3.0 -7.660693073017663


  3%|▎         | 2601/100000 [5:12:04<196:53:14,  7.28s/it]

2600 -9.0 -7.6223856250875945


  3%|▎         | 2701/100000 [5:24:18<201:06:35,  7.44s/it]

2700 -8.0 -7.795413870185883


  3%|▎         | 2801/100000 [5:36:34<194:11:01,  7.19s/it]

2800 -13.0 -7.706416537789798


  3%|▎         | 2901/100000 [5:48:26<193:17:07,  7.17s/it]

2900 -14.0 -7.532067106972514


  3%|▎         | 3001/100000 [6:00:19<192:07:11,  7.13s/it]

3000 -7.0 -7.747952960384762


  3%|▎         | 3101/100000 [6:11:26<187:40:56,  6.97s/it]

3100 -1.0 -7.557195085086698


  3%|▎         | 3201/100000 [6:22:59<183:51:27,  6.84s/it]

3200 -4.0 -7.385133313796213


  3%|▎         | 3301/100000 [6:35:06<197:05:55,  7.34s/it]

3300 -6.0 -7.797089261243785


  3%|▎         | 3401/100000 [6:46:43<193:11:09,  7.20s/it]

3400 -5.0 -7.6859621009065755


  4%|▎         | 3501/100000 [6:58:25<180:44:13,  6.74s/it]

3500 -7.0 -8.231394406242444


  4%|▎         | 3601/100000 [7:10:08<182:40:48,  6.82s/it]

3600 -6.0 -7.822256402553039


  4%|▎         | 3701/100000 [7:21:45<175:59:41,  6.58s/it]

3700 -10.0 -8.194987563980584


  4%|▍         | 3801/100000 [7:33:48<191:32:45,  7.17s/it]

3800 -5.0 -7.169042541294032


  4%|▍         | 3901/100000 [7:45:27<177:28:53,  6.65s/it]

3900 -5.0 -7.258490827642756


  4%|▍         | 4001/100000 [7:57:36<204:38:58,  7.67s/it]

4000 -8.0 -7.831888726158844


  4%|▍         | 4101/100000 [8:09:24<186:35:12,  7.00s/it]

4100 -8.0 -7.670367059530903


  4%|▍         | 4201/100000 [8:20:59<186:20:11,  7.00s/it]

4200 -10.0 -8.326598088136025


  4%|▍         | 4301/100000 [8:32:37<194:20:15,  7.31s/it]

4300 -3.0 -7.6117808856456115


  4%|▍         | 4401/100000 [8:44:39<197:36:00,  7.44s/it]

4400 -7.0 -7.456187583153089


  5%|▍         | 4501/100000 [8:57:12<202:48:39,  7.65s/it]

4500 -6.0 -7.163230348747176


  5%|▍         | 4601/100000 [9:08:54<183:51:14,  6.94s/it]

4600 -11.0 -7.399183454122499


  5%|▍         | 4701/100000 [9:20:36<177:58:25,  6.72s/it]

4700 -10.0 -7.572968738927308


  5%|▍         | 4801/100000 [9:32:28<190:57:28,  7.22s/it]

4800 -13.0 -7.613938380669218


  5%|▍         | 4901/100000 [9:44:09<191:35:06,  7.25s/it]

4900 -11.0 -7.816096808537004


  5%|▌         | 5001/100000 [9:56:13<173:48:26,  6.59s/it]

5000 -7.0 -7.626768933798054


  5%|▌         | 5101/100000 [10:07:16<177:20:37,  6.73s/it]

5100 -6.0 -7.68127496815365


  5%|▌         | 5201/100000 [10:18:55<184:10:15,  6.99s/it]

5200 -11.0 -7.5452996818447255


  5%|▌         | 5301/100000 [10:30:55<189:59:13,  7.22s/it]

5300 -7.0 -7.336478575936303


  5%|▌         | 5401/100000 [10:42:39<191:21:10,  7.28s/it]

5400 -10.0 -7.617077584128175


  6%|▌         | 5501/100000 [10:54:31<186:45:10,  7.11s/it]

5500 -5.0 -7.553453672841033


  6%|▌         | 5601/100000 [11:06:08<192:47:55,  7.35s/it]

5600 2.0 -7.906317176038815


  6%|▌         | 5701/100000 [11:17:22<183:17:44,  7.00s/it]

5700 -6.0 -8.030440908205776


  6%|▌         | 5801/100000 [11:29:01<184:50:15,  7.06s/it]

5800 -9.0 -7.949370637088702


  6%|▌         | 5901/100000 [11:40:50<185:13:54,  7.09s/it]

5900 -6.0 -8.045163910292183


  6%|▌         | 6001/100000 [11:52:28<195:24:19,  7.48s/it]

6000 -1.0 -8.0733257829385


  6%|▌         | 6101/100000 [12:04:07<169:19:25,  6.49s/it]

6100 -13.0 -8.17616577094297


  6%|▌         | 6201/100000 [12:15:38<191:25:11,  7.35s/it]

6200 -2.0 -7.991312355706404


  6%|▋         | 6301/100000 [12:27:23<185:50:51,  7.14s/it]

6300 -12.0 -8.150626034123281


  6%|▋         | 6401/100000 [12:39:21<195:38:45,  7.52s/it]

6400 -4.0 -8.000532765243163


  7%|▋         | 6501/100000 [12:51:16<190:48:12,  7.35s/it]

6500 -5.0 -7.908928253212426


  7%|▋         | 6601/100000 [13:03:12<180:34:43,  6.96s/it]

6600 -9.0 -7.8415335801352315


  7%|▋         | 6701/100000 [13:14:53<190:32:55,  7.35s/it]

6700 -9.0 -7.5081917668391185


  7%|▋         | 6801/100000 [13:26:32<182:08:11,  7.04s/it]

6800 -5.0 -7.411552068144121


  7%|▋         | 6901/100000 [13:38:43<191:23:56,  7.40s/it]

6900 -7.0 -7.386996599168962


  7%|▋         | 7001/100000 [13:50:46<173:42:34,  6.72s/it]

7000 -10.0 -7.472343834113175


  7%|▋         | 7101/100000 [14:03:06<188:57:17,  7.32s/it]

7100 -10.0 -7.607443934811345


  7%|▋         | 7201/100000 [14:15:14<181:24:10,  7.04s/it]

7200 -12.0 -7.702738900261485


  7%|▋         | 7301/100000 [14:27:13<187:13:32,  7.27s/it]

7300 -3.0 -8.070333371412953


  7%|▋         | 7401/100000 [14:39:06<181:36:28,  7.06s/it]

7400 -9.0 -7.822933956063248


  8%|▊         | 7501/100000 [14:51:01<180:37:35,  7.03s/it]

7500 -5.0 -7.7601952369629625


  8%|▊         | 7601/100000 [15:03:31<175:17:31,  6.83s/it]

7600 -13.0 -7.848927104974513


  8%|▊         | 7701/100000 [15:15:22<183:15:49,  7.15s/it]

7700 -5.0 -7.889508087270774


  8%|▊         | 7801/100000 [15:27:08<189:25:57,  7.40s/it]

7800 -13.0 -8.029591508551002


  8%|▊         | 7854/100000 [15:33:50<196:19:25,  7.67s/it]

In [None]:
# Save the final network weights into the run directory.
torch.save(net.state_dict(), weight_join("final.pth"))

In [None]:
# Plot the total reward of every finished episode, in order.
plt.plot(trainer.total_rewards)

In [None]:
# Display the run directory that holds the logs and saved weights.
writer_path