In [1]:
import os
import gym
import numpy as np
from tqdm import trange
import itertools
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.autograd import Variable
from torch.distributions import Categorical

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [3]:
downsample = 2
output_size = 160//downsample

def preprocess(frame):
    '''from karpathy.'''
    I = frame
    I = I[35:195] # crop
    I = I[::downsample,::downsample,0] # downsample by factor of 2
    I[I == 144] = 0 # erase background (background type 1)
    I[I == 109] = 0 # erase background (background type 2)
    I[I != 0] = 1 # everything else (paddles, ball) just set to 1
    tensor = torch.from_numpy(I).float()
    return tensor.unsqueeze(0).unsqueeze(0) #BCHW

if torch.cuda.is_available():
    def to_var(x, requires_grad=False, gpu=None):
        x = x.cuda(gpu)
        return Variable(x, requires_grad=requires_grad)
else:
    def to_var(x, requires_grad=False, vgpu=None):
        return Variable(x, requires_grad=requires_grad)


def clip_grads(net, low=-10, high=10):
    """Gradient clipping to the range [low, high]."""
    for p in net.parameters():
        if p.grad is not None:
            p.grad.data.clamp_(low, high)

def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
        nn.init.kaiming_uniform(m.weight.data)
        nn.init.constant(m.bias.data,0)
        print("Initialized", m)
        
def total_weights(net):
    '''Count total weights size.'''
    ret = 0
    for p in net.parameters():
        ret+=p.data.cpu().numpy().size
    return ret

In [4]:
class REINFORCE:
    '''Implement REINFORCE algorithm.'''
    
    def __init__(self, model, gamma=0.99, learning_rate=1.e-3, batch_size=10):
        self.model = model
        self.gamma = gamma
        self.optimizer = Adam(model.parameters(), lr=learning_rate)
        self.optimizer.zero_grad() # need or not?
        self.batch_size=batch_size
        
        self.log_probs = []
        self.rewards = []
        
        self.history = []
        
    @property
    def episode(self):
        return len(self.history)
        
    def select_action(self, obs):
        self.model.train()
        state = to_var(obs)
        logits = self.model(state)
        probs = F.softmax(logits, dim=1)
        m = Categorical(probs)
        action = m.sample()
        log_prob = m.log_prob(action)
        return action, log_prob
    
    def keep_for_grad(self, log_prob, reward):
        self.log_probs.append(log_prob)
        self.rewards.append(reward)
    
    def accumulate_policy_grad(self):
        policy_loss = get_policy_loss(self.log_probs, self.rewards, self.gamma)
        
        self.history.append([sum(self.rewards), # total_reward
                             len(self.rewards), # n_round
                             policy_loss.data[0]]) # train_loss
        
        policy_loss.backward()
        del self.log_probs[:]
        del self.rewards[:]
        
    def train(self):
        clip_grads(self.model,-10,10)
        self.optimizer.step()
        self.optimizer.zero_grad()
        
    def step(self):
        self.accumulate_policy_grad()
        episode = self.episode
        if episode>0 and episode%self.batch_size==0:
            self.train()
    
    def play(self, obs):
        self.model.eval()
        state = to_var(obs)
        prob = self.model(state)
        _, action = prob.max(dim=1)
        return action.data[0]

def get_discounted_rewards(rewards, gamma):
    acc = []
    R = 0
    for r in reversed(rewards):
        R = r + gamma * R
        acc.append(R)
    ret = np.array(acc[::-1])
    return ret

def get_normalized_rewards(rewards, gamma):
    ret = get_discounted_rewards(rewards, gamma)
    return (ret-ret.mean()) / (ret.std()+np.finfo(np.float32).eps)

def get_policy_loss(log_probs,rewards, gamma):
    ret = 0
    normalized_rewards = get_normalized_rewards(rewards, gamma)
    for log_prob, reward in zip(log_probs, normalized_rewards):
        ret -= log_prob*reward # it's less memory consuming than dot product
    return ret

In [5]:
class EnhancedWriter(SummaryWriter):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logdir = list(self.all_writers.keys())[0]
        
    def in_logdir(self, path):
        return os.path.join(self.logdir, path)
        
    def save(self, model, path):
        torch.save(model.state_dict(), self.in_logdir(path))
        
    def export_logs(self, filename='training.json'):
        self.export_scalars_to_json(self.in_logdir(filename))

In [6]:
class Net(nn.Module):
    '''very similar to Nature DQN.'''
    def __init__(self, action_n, input_shape=(1,80,80)):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(input_shape[0],32,kernel_size=8, stride=4),nn.ReLU(),
                                  nn.Conv2d(32,64,kernel_size=4, stride=2),nn.ReLU(),
                                  nn.Conv2d(64,64,kernel_size=3, stride=1),nn.ReLU(),)
        flatten_size = self._get_flatten_size(input_shape)
        self.fc = nn.Sequential(nn.Linear(flatten_size, 512),nn.ReLU(),
                               nn.Linear(512, action_n))
        self.apply(weights_init)
        print("Network size:", total_weights(self))
    
    def _get_flatten_size(self, shape):
        x = Variable(torch.rand(1, *shape))
        output_feat = self.conv(x)
        n_size = output_feat.view(-1).size(0)
        return n_size
        
    def forward(self, x):
        feat = self.conv(x)
        logit = self.fc(feat.view(feat.size(0),-1))
        return logit

In [7]:
env = gym.make("Pong-v0")

net = Net(env.action_space.n, input_shape=(1,output_size,output_size))
print(net)
if torch.cuda.is_available():
    net = net.cuda()

agent = REINFORCE(model=net, gamma=0.99, learning_rate=1.e-3, batch_size=10)
writer = EnhancedWriter()

[2018-01-09 19:14:37,917] Making new env: Pong-v0


Initialized Conv2d(1, 32, kernel_size=(8, 8), stride=(4, 4))
Initialized Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
Initialized Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
Initialized Linear(in_features=2304, out_features=512, bias=True)
Initialized Linear(in_features=512, out_features=6, bias=True)
Network size: 1255078
Net(
  (conv): Sequential(
    (0): Conv2d(1, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): ReLU()
  )
  (fc): Sequential(
    (0): Linear(in_features=2304, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=6, bias=True)
  )
)


In [8]:
# net.load_state_dict(torch.load('runs/Jan08_15-07-08_amax/episode9100.pth'))

In [9]:
running_reward = best_reward = -21

for episode in trange(100000):
    frame = env.reset()
    last_obs = preprocess(frame)
    curr_obs = preprocess(frame)
    for step in itertools.count(start=1, step=1):
        action, log_prob = agent.select_action(obs=curr_obs-last_obs)
        frame, reward, done, _ = env.step(action)
        agent.keep_for_grad(log_prob, reward)
        last_obs = curr_obs
        curr_obs = preprocess(frame)
        if step>=50000: # don't exceed
            print("Seems much but not enough")
            break
        if done:
             break
    agent.step() 
    
    total_reward, n_round, train_loss = agent.history[-1]
    writer.add_scalar("reward",total_reward,episode)
    writer.add_scalar("n_round",n_round,episode)
    writer.add_scalar("loss",train_loss,episode)
    
    if total_reward>best_reward:
        print("New record:", total_reward)
        best_reward=total_reward
        writer.save(net, "best.pth")
    
    count_gamma = 0.5
    running_reward = count_gamma*running_reward+(1-count_gamma)*total_reward
    if (episode+1)%100==0:
        print(episode, total_reward, running_reward)
        writer.save(net, "episode%s.pth"%episode)
    if running_reward>1:
        break

  0%|          | 1/100000 [00:05<144:01:54,  5.19s/it]

New record: -20.0


  0%|          | 10/100000 [00:30<82:56:20,  2.99s/it]

New record: -19.0


  0%|          | 25/100000 [01:13<86:29:33,  3.11s/it]

New record: -18.0


  0%|          | 96/100000 [04:51<87:26:29,  3.15s/it]

New record: -17.0


  0%|          | 100/100000 [05:03<82:37:00,  2.98s/it]

99 -20.0 -20.2976141739793


  0%|          | 200/100000 [10:25<93:44:16,  3.38s/it]

199 -21.0 -20.92431060192001


  0%|          | 284/100000 [15:18<98:39:26,  3.56s/it]

New record: -16.0


  0%|          | 296/100000 [16:00<100:46:03,  3.64s/it]

New record: -15.0


  0%|          | 300/100000 [16:17<114:33:33,  4.14s/it]

299 -18.0 -18.7224849146816


  0%|          | 400/100000 [23:22<139:09:44,  5.03s/it]

399 -20.0 -19.645631841004548


  0%|          | 500/100000 [31:45<139:28:50,  5.05s/it]

499 -19.0 -18.602839319296155


  1%|          | 600/100000 [40:15<175:16:59,  6.35s/it]

599 -19.0 -19.140439157231377


  1%|          | 621/100000 [42:25<178:18:32,  6.46s/it]

New record: -14.0


  1%|          | 700/100000 [51:44<213:37:58,  7.75s/it]

699 -16.0 -15.933442829963996


  1%|          | 721/100000 [54:30<221:51:13,  8.04s/it]

New record: -13.0


  1%|          | 748/100000 [57:53<209:39:11,  7.60s/it]

New record: -12.0


  1%|          | 778/100000 [1:01:31<213:59:01,  7.76s/it]

New record: -11.0


  1%|          | 800/100000 [1:04:26<218:42:30,  7.94s/it]

799 -15.0 -15.983631565836342


  1%|          | 900/100000 [1:18:37<242:35:16,  8.81s/it]

899 -19.0 -18.09564634228236


  1%|          | 924/100000 [1:22:16<284:07:06, 10.32s/it]

New record: -9.0


  1%|          | 1000/100000 [1:34:31<264:31:52,  9.62s/it]

999 -19.0 -17.287637563386372


  1%|          | 1100/100000 [1:50:34<289:29:02, 10.54s/it]

1099 -17.0 -17.312852031121487


  1%|          | 1200/100000 [2:08:19<279:09:06, 10.17s/it]

1199 -21.0 -19.913758694523374


  1%|▏         | 1300/100000 [2:27:57<325:23:04, 11.87s/it]

1299 -20.0 -17.8506343559239


  1%|▏         | 1400/100000 [2:45:42<311:33:40, 11.38s/it]

1399 -14.0 -15.427046832570436


  2%|▏         | 1500/100000 [3:05:05<289:53:29, 10.60s/it]

1499 -20.0 -18.435528587211174


  2%|▏         | 1600/100000 [3:22:55<300:25:53, 10.99s/it]

1599 -14.0 -16.2024594468359


  2%|▏         | 1700/100000 [3:43:23<334:29:29, 12.25s/it]

1699 -16.0 -16.580898013274286


  2%|▏         | 1800/100000 [4:03:14<338:53:27, 12.42s/it]

1799 -18.0 -16.7427373020196


  2%|▏         | 1900/100000 [4:23:48<327:48:09, 12.03s/it]

1899 -11.0 -13.74034576600801


  2%|▏         | 2000/100000 [4:42:49<291:35:53, 10.71s/it]

1999 -15.0 -16.515169725012466


  2%|▏         | 2100/100000 [5:00:12<298:22:26, 10.97s/it]

2099 -17.0 -16.814032393751276


  2%|▏         | 2200/100000 [5:18:16<285:25:50, 10.51s/it]

2199 -19.0 -17.9540573704627


  2%|▏         | 2300/100000 [5:36:55<281:07:10, 10.36s/it]

2299 -19.0 -17.54474467166677


  2%|▏         | 2400/100000 [5:55:34<298:25:16, 11.01s/it]

2399 -20.0 -17.346930604076817


  2%|▎         | 2500/100000 [6:14:17<290:19:57, 10.72s/it]

2499 -17.0 -17.01660969207216


  3%|▎         | 2600/100000 [6:32:46<261:09:09,  9.65s/it]

2599 -14.0 -16.292193327629143


  3%|▎         | 2700/100000 [6:52:38<338:04:47, 12.51s/it]

2699 -19.0 -16.55105589742969


  3%|▎         | 2800/100000 [7:13:26<322:13:35, 11.93s/it]

2799 -17.0 -17.015846882834712


  3%|▎         | 2900/100000 [7:35:17<332:33:14, 12.33s/it]

2899 -15.0 -15.291144680143102


  3%|▎         | 2916/100000 [7:38:49<398:02:39, 14.76s/it]

New record: -2.0


  3%|▎         | 3000/100000 [7:57:22<347:12:38, 12.89s/it]

2999 -17.0 -16.02327810686349


  3%|▎         | 3100/100000 [8:20:13<356:15:56, 13.24s/it]

3099 -15.0 -15.289964318518123


  3%|▎         | 3200/100000 [8:43:09<367:29:16, 13.67s/it]

3199 -18.0 -16.889794244201926


  3%|▎         | 3300/100000 [9:06:29<396:40:27, 14.77s/it]

3299 -16.0 -14.668661889610481


  3%|▎         | 3400/100000 [9:29:34<382:21:29, 14.25s/it]

3399 -13.0 -13.266441200111718


  4%|▎         | 3500/100000 [9:52:29<388:37:16, 14.50s/it]

3499 -13.0 -12.561154219306305


  4%|▎         | 3600/100000 [10:17:14<355:23:25, 13.27s/it]

3599 -17.0 -14.871837925984414


  4%|▎         | 3700/100000 [10:41:57<412:36:05, 15.42s/it]

3699 -13.0 -12.461700242692775


  4%|▍         | 3800/100000 [11:06:51<358:44:52, 13.43s/it]

3799 -19.0 -17.088145036940727


  4%|▍         | 3900/100000 [11:31:03<348:44:51, 13.06s/it]

3899 -17.0 -15.846184231914936


  4%|▍         | 4000/100000 [11:56:01<368:27:36, 13.82s/it]

3999 -19.0 -17.07635015850931


  4%|▍         | 4100/100000 [12:20:55<472:04:16, 17.72s/it]

4099 -11.0 -12.213866196010654


  4%|▍         | 4200/100000 [12:45:42<411:20:15, 15.46s/it]

4199 -13.0 -13.33822347881367


  4%|▍         | 4300/100000 [13:09:37<376:26:28, 14.16s/it]

4299 -13.0 -12.453809850357969


  4%|▍         | 4400/100000 [13:31:55<359:23:09, 13.53s/it]

4399 -13.0 -12.358585120054467


  4%|▍         | 4500/100000 [13:56:18<424:15:37, 15.99s/it]

4499 -11.0 -13.046204241265706


  5%|▍         | 4600/100000 [14:20:51<343:04:13, 12.95s/it]

4599 -14.0 -15.782924684691377


  5%|▍         | 4700/100000 [14:46:36<467:51:46, 17.67s/it]

4699 -14.0 -12.928140303841124


  5%|▍         | 4800/100000 [15:11:47<376:27:22, 14.24s/it]

4799 -17.0 -13.349824115962084


  5%|▍         | 4900/100000 [15:35:54<378:23:38, 14.32s/it]

4899 -17.0 -14.503791634820644


  5%|▌         | 5000/100000 [16:02:12<385:55:43, 14.62s/it]

4999 -11.0 -12.41323747329914


  5%|▌         | 5100/100000 [16:26:45<367:03:12, 13.92s/it]

5099 -15.0 -14.663402492003609


  5%|▌         | 5200/100000 [16:50:46<350:48:50, 13.32s/it]

5199 -12.0 -12.871940797181928


  5%|▌         | 5209/100000 [16:52:58<386:25:51, 14.68s/it]

New record: -1.0


  5%|▌         | 5279/100000 [17:09:18<381:43:41, 14.51s/it]

New record: 3.0


  5%|▌         | 5300/100000 [17:14:00<357:20:36, 13.58s/it]

5299 -8.0 -7.983172181405186


  5%|▌         | 5400/100000 [17:36:23<353:34:48, 13.46s/it]

5399 -11.0 -10.351101178820226


  5%|▌         | 5417/100000 [17:40:20<377:47:25, 14.38s/it]

New record: 7.0


  6%|▌         | 5500/100000 [17:59:04<333:59:11, 12.72s/it]

5499 -9.0 -8.197230587423086


  6%|▌         | 5505/100000 [18:00:15<373:38:40, 14.23s/it]

New record: 12.0


AttributeError: 'REINFORCE' object has no attribute 'running_reward'

In [12]:
episode

5505

In [19]:
running_reward = -10

for episode in trange(5506, 100000):
    frame = env.reset()
    last_obs = preprocess(frame)
    curr_obs = preprocess(frame)
    for step in itertools.count(start=1, step=1):
        action, log_prob = agent.select_action(obs=curr_obs-last_obs)
        frame, reward, done, _ = env.step(action)
        agent.keep_for_grad(log_prob, reward)
        last_obs = curr_obs
        curr_obs = preprocess(frame)
        if step>=50000: # don't exceed
            print("Seems much but not enough")
            break
        if done:
             break
    agent.step() 
    
    total_reward, n_round, train_loss = agent.history[-1]
    writer.add_scalar("reward",total_reward,episode)
    writer.add_scalar("n_round",n_round,episode)
    writer.add_scalar("loss",train_loss,episode)
    
    if total_reward>best_reward:
        print("New record:", total_reward)
        best_reward=total_reward
        writer.save(net, "best.pth")
    
    count_gamma = 0.95
    running_reward = count_gamma*running_reward+(1-count_gamma)*total_reward
    if (episode+1)%100==0:
        print(episode, total_reward, running_reward)
        writer.save(net, "episode%s.pth"%episode)
    if running_reward>1:
        break


  0%|          | 0/94494 [00:00<?, ?it/s][A
  0%|          | 94/94494 [21:28<398:41:46, 15.20s/it]

5599 -1.0 -2.461317697216149


  0%|          | 160/94494 [35:51<320:53:54, 12.25s/it]

New record: 13.0


  0%|          | 194/94494 [42:28<291:15:19, 11.12s/it]

5699 -10.0 -2.9017533880251607


  0%|          | 223/94494 [48:01<266:51:28, 10.19s/it]

New record: 15.0


  0%|          | 294/94494 [1:01:33<265:28:16, 10.15s/it]

5799 9.0 -0.08577522061825438


  0%|          | 299/94494 [1:02:34<307:34:39, 11.76s/it]

In [20]:
print("Finished: %s@%s" %(agent.history[-1],episode))
writer.export_logs()        
writer.save(net, "final.pth")

Finished: [11.0, 4205, -15.009461402893066]@5805


In [21]:
episode

5805

In [22]:
running_reward = -10

for episode in trange(5805, 100000):
    frame = env.reset()
    last_obs = preprocess(frame)
    curr_obs = preprocess(frame)
    for step in itertools.count(start=1, step=1):
        action, log_prob = agent.select_action(obs=curr_obs-last_obs)
        frame, reward, done, _ = env.step(action)
        agent.keep_for_grad(log_prob, reward)
        last_obs = curr_obs
        curr_obs = preprocess(frame)
        if step>=50000: # don't exceed
            print("Seems much but not enough")
            break
        if done:
             break
    agent.step() 
    
    total_reward, n_round, train_loss = agent.history[-1]
    writer.add_scalar("reward",total_reward,episode)
    writer.add_scalar("n_round",n_round,episode)
    writer.add_scalar("loss",train_loss,episode)
    
    if total_reward>best_reward:
        print("New record:", total_reward)
        best_reward=total_reward
        writer.save(net, "best.pth")
    
    count_gamma = 0.95
    running_reward = count_gamma*running_reward+(1-count_gamma)*total_reward
    if (episode+1)%100==0:
        print(episode, total_reward, running_reward)
        writer.save(net, "episode%s.pth"%episode)
    if running_reward>18:
        break


  0%|          | 0/94195 [00:00<?, ?it/s][A
  0%|          | 95/94195 [19:30<315:26:15, 12.07s/it]

5899 6.0 0.6901370883556701


  0%|          | 156/94195 [31:40<294:21:09, 11.27s/it]

New record: 18.0


  0%|          | 195/94195 [39:22<304:39:31, 11.67s/it]

5999 4.0 1.7828627538012518


  0%|          | 295/94195 [57:01<262:38:15, 10.07s/it]

6099 -6.0 1.514309572328634


  0%|          | 345/94195 [1:06:09<228:54:21,  8.78s/it]

New record: 19.0


  0%|          | 395/94195 [1:14:59<270:11:31, 10.37s/it]

6199 9.0 0.465709546885731


  1%|          | 495/94195 [1:33:58<329:50:31, 12.67s/it]

6299 -4.0 2.7435254170887347


  1%|          | 595/94195 [1:53:49<323:47:32, 12.45s/it]

6399 -5.0 2.6634100733174106


  1%|          | 695/94195 [2:12:57<285:17:40, 10.98s/it]

6499 11.0 3.9745351827514455


  1%|          | 795/94195 [2:32:28<318:32:03, 12.28s/it]

6599 7.0 4.009885103215289


  1%|          | 895/94195 [2:52:41<329:50:34, 12.73s/it]

6699 1.0 3.662860754299958


  1%|          | 995/94195 [3:11:56<285:43:45, 11.04s/it]

6799 11.0 7.364418368505788


  1%|          | 1095/94195 [3:30:59<349:16:27, 13.51s/it]

6899 1.0 5.223480829441586


  1%|▏         | 1195/94195 [3:50:55<309:58:02, 12.00s/it]

6999 -3.0 4.448711442302216


  1%|▏         | 1295/94195 [4:09:53<294:39:05, 11.42s/it]

7099 6.0 5.483560037147505


  1%|▏         | 1395/94195 [4:29:30<326:25:32, 12.66s/it]

7199 5.0 2.9043812373964784


  2%|▏         | 1469/94195 [4:43:47<307:40:36, 11.95s/it]

KeyboardInterrupt: 

In [25]:
# manually stopped
running_reward

3.673095292282504

In [26]:
print("Finished: %s@%s" %(agent.history[-1],episode))
writer.export_logs()        
writer.save(net, "final.pth")

Finished: [3.0, 6611, -68.62922668457031]@7274
