In [1]:
import os

# Expose only physical GPU 1 to this process. This must run before torch
# initialises CUDA, which is why it sits above the torch import.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

In [2]:
from tqdm import trange
import itertools
import gym
import torch

In [3]:
import _init_paths

Add the code root path (the directory containing `rllib`) so the `rllib` imports below can be resolved.


In [4]:
from rllib.models import ConvNetPV
from rllib.misc import EnhancedWriter
from rllib.actor_critic import ActorCritic

In [5]:
# Factor by which preprocess() subsamples both spatial axes.
downsample = 2
# Rows kept by the crop (the playing field) and the width of a raw frame.
CROP_TOP, CROP_BOTTOM = 35, 195
FRAME_WIDTH = 160
# Channel-0 pixel values of the two background types that preprocess() erases.
BACKGROUND_VALUES = (144, 109)

# Side length of the square preprocessed image. A stride-`downsample` slice keeps
# ceil(n / downsample) samples, so floor division (`160 // downsample`) is only
# correct when downsample divides 160 (e.g. 160 // 3 == 53, but [::3] keeps 54).
output_size = len(range(0, CROP_BOTTOM - CROP_TOP, downsample))
assert output_size == len(range(0, FRAME_WIDTH, downsample)), "crop must stay square"


def preprocess(frame):
    '''Convert a raw HxWxC frame into a (1, output_size, output_size) float tensor.

    Adapted from karpathy: crop to rows CROP_TOP..CROP_BOTTOM-1, keep every
    `downsample`-th row and column of channel 0, erase the background values in
    BACKGROUND_VALUES, and set every other non-zero pixel (paddles, ball) to 1.
    The result is a binary 0/1 float tensor in CHW layout.

    `frame` is left unmodified: a new mask is built instead of assigning into a
    slice, which for a NumPy array is a view that would write back into `frame`.
    '''
    pixels = frame[CROP_TOP:CROP_BOTTOM:downsample, ::downsample, 0]
    foreground = (pixels != 0)
    for value in BACKGROUND_VALUES:
        foreground &= (pixels != value)
    return torch.from_numpy(foreground.astype("float32")).unsqueeze(0)  # CHW

In [6]:
# Seed torch and the environment so weight initialisation (and any torch-based
# action sampling) and the env's own RNG are reproducible across runs. GPU
# kernels may still introduce some nondeterminism.
SEED = 0
torch.manual_seed(SEED)

env = gym.make("Pong-v0")
env.seed(SEED)

# Single input channel: the training loop feeds the difference of two
# preprocessed (1, output_size, output_size) frames.
net = ConvNetPV(input_shape=(1, output_size, output_size), action_n=env.action_space.n)
print(net)
if torch.cuda.is_available():
    net = net.cuda()

# gamma is presumably the return discount factor -- see rllib.actor_critic.
agent = ActorCritic(model=net, gamma=0.99, learning_rate=1.e-3, batch_size=10)
writer = EnhancedWriter()

[2018-01-10 02:28:06,932] Making new env: Pong-v0


Initialized Conv2d(1, 32, kernel_size=(8, 8), stride=(4, 4))
Initialized Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
Initialized Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
Initialized Linear(in_features=2304, out_features=512, bias=True)
Initialized Linear(in_features=512, out_features=6, bias=True)
Initialized Linear(in_features=512, out_features=1, bias=True)
Network size: 1255591
ConvNetPV(
  (conv): Sequential(
    (0): Conv2d(1, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): ReLU()
  )
  (fc): Sequential(
    (0): Linear(in_features=2304, out_features=512, bias=True)
    (1): ReLU()
  )
  (policy_head): Linear(in_features=512, out_features=6, bias=True)
  (value_head): Linear(in_features=512, out_features=1, bias=True)
)


In [7]:
# Optionally resume from a saved state_dict instead of the fresh initialisation.
# Leave as None to train from scratch.
RESUME_FROM = None  # e.g. 'runs/Jan08_15-07-08_amax/episode9100.pth'
if RESUME_FROM is not None:
    # map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine;
    # load_state_dict copies the tensors onto whatever device `net` is on.
    net.load_state_dict(torch.load(RESUME_FROM, map_location="cpu"))

In [8]:
# Phase 1: train until the smoothed episode reward exceeds TARGET_REWARD.
MAX_STEPS_PER_EPISODE = 50000  # hard cap on steps per episode
COUNT_GAMMA = 0.5   # running_reward smoothing: ~2-episode window, so the stop test is noisy
TARGET_REWARD = 1   # stop once the smoothed reward exceeds this

running_reward = best_reward = -21

# Log through the bar (pbar.write) instead of print() so messages don't
# break the tqdm progress line.
pbar = trange(100000)
for episode in pbar:
    frame = env.reset()
    last_obs = preprocess(frame)
    curr_obs = preprocess(frame)
    for step in itertools.count(start=1):
        # The agent sees the difference of consecutive preprocessed frames
        # (all zeros on the first step of an episode).
        action, log_prob, state_value = agent.select_action(obs=curr_obs - last_obs)
        frame, reward, done, _ = env.step(action)
        agent.keep_for_grad(log_prob, state_value, reward)
        last_obs, curr_obs = curr_obs, preprocess(frame)
        if done:
            break
        if step >= MAX_STEPS_PER_EPISODE:
            pbar.write("Episode %s truncated after %s steps" % (episode, step))
            break
    # Presumably performs the update from the transitions kept above and appends
    # this episode's summary to agent.history -- see rllib.actor_critic.
    agent.step()

    total_reward, n_round, policy_loss, value_loss = agent.history[-1]
    writer.add_scalar("reward", total_reward, episode)
    writer.add_scalar("n_round", n_round, episode)
    writer.add_scalar("policy_loss", policy_loss, episode)
    writer.add_scalar("value_loss", value_loss, episode)

    if total_reward > best_reward:
        pbar.write("New record: %s" % total_reward)
        best_reward = total_reward
        writer.save(net, "best.pth")

    running_reward = COUNT_GAMMA * running_reward + (1 - COUNT_GAMMA) * total_reward
    if (episode + 1) % 100 == 0:
        pbar.write("%s %s %s" % (episode, total_reward, running_reward))
        writer.save(net, "episode%s.pth" % episode)
    if running_reward > TARGET_REWARD:
        break
pbar.close()

writer.save(net, "final.pth")
writer.export_logs()
print("Finished: %s@%s" % (agent.history[-1], episode))

  0%|          | 1/100000 [00:04<114:47:31,  4.13s/it]

New record: -20.0


  0%|          | 13/100000 [00:42<87:46:02,  3.16s/it]

New record: -18.0


  0%|          | 37/100000 [02:02<103:15:46,  3.72s/it]

New record: -17.0


  0%|          | 100/100000 [05:24<86:51:18,  3.13s/it]

99 -20.0 -20.353415588559955


  0%|          | 184/100000 [10:09<99:03:50,  3.57s/it]

New record: -16.0


  0%|          | 200/100000 [11:04<100:00:39,  3.61s/it]

199 -20.0 -19.97219906563897


  0%|          | 300/100000 [17:05<98:39:38,  3.56s/it] 

299 -20.0 -19.996437434868334


  0%|          | 400/100000 [23:03<96:27:38,  3.49s/it]

399 -18.0 -19.319092088057833


  0%|          | 500/100000 [29:25<101:13:32,  3.66s/it]

499 -21.0 -20.38189649796133


  1%|          | 511/100000 [30:07<111:15:32,  4.03s/it]

New record: -15.0


  1%|          | 600/100000 [36:02<114:51:00,  4.16s/it]

599 -19.0 -18.613086860447417


  1%|          | 680/100000 [41:32<114:31:23,  4.15s/it]

New record: -14.0


  1%|          | 700/100000 [42:59<121:21:52,  4.40s/it]

699 -16.0 -17.272099122479407


  1%|          | 800/100000 [50:22<121:24:08,  4.41s/it]

799 -18.0 -17.61537864224409


  1%|          | 900/100000 [58:03<127:01:22,  4.61s/it]

899 -18.0 -18.453171637465346


  1%|          | 1000/100000 [1:06:23<141:17:47,  5.14s/it]

999 -18.0 -18.411931823749804


  1%|          | 1029/100000 [1:09:00<159:58:42,  5.82s/it]

New record: -13.0


  1%|          | 1075/100000 [1:13:16<177:24:58,  6.46s/it]

New record: -12.0


  1%|          | 1080/100000 [1:13:43<155:33:31,  5.66s/it]

New record: -11.0


  1%|          | 1100/100000 [1:15:44<153:36:13,  5.59s/it]

1099 -18.0 -16.950163307410936


  1%|          | 1133/100000 [1:19:13<195:19:57,  7.11s/it]

New record: -10.0


  1%|          | 1167/100000 [1:22:47<186:05:27,  6.78s/it]

New record: -9.0


  1%|          | 1200/100000 [1:26:42<209:07:48,  7.62s/it]

1199 -16.0 -13.969490186072516


  1%|▏         | 1288/100000 [1:37:21<225:49:54,  8.24s/it]

New record: -8.0


  1%|▏         | 1298/100000 [1:38:32<227:07:00,  8.28s/it]

New record: -7.0


  1%|▏         | 1300/100000 [1:38:47<217:16:39,  7.93s/it]

1299 -15.0 -13.91553941111881


  1%|▏         | 1348/100000 [1:45:32<238:36:39,  8.71s/it]

New record: -6.0


  1%|▏         | 1400/100000 [1:52:57<240:30:40,  8.78s/it]

1399 -12.0 -11.789188212741138


  1%|▏         | 1498/100000 [2:07:33<296:58:06, 10.85s/it]

New record: -4.0


  2%|▏         | 1500/100000 [2:07:55<292:55:07, 10.71s/it]

1499 -13.0 -11.250539308652337


  2%|▏         | 1600/100000 [2:24:31<249:20:50,  9.12s/it]

1599 -19.0 -17.452836733548047


  2%|▏         | 1636/100000 [2:30:08<272:21:21,  9.97s/it]

New record: -3.0


  2%|▏         | 1700/100000 [2:40:59<267:40:12,  9.80s/it]

1699 -12.0 -12.265513548649942


  2%|▏         | 1713/100000 [2:43:13<302:32:54, 11.08s/it]

New record: -1.0


  2%|▏         | 1800/100000 [2:57:47<281:54:47, 10.33s/it]

1799 -7.0 -8.65919489382396


  2%|▏         | 1900/100000 [3:15:30<280:42:15, 10.30s/it]

1899 -6.0 -9.65861048973379


  2%|▏         | 2000/100000 [3:35:00<345:19:11, 12.69s/it]

1999 -14.0 -12.001715493913093


  2%|▏         | 2100/100000 [3:54:54<342:49:26, 12.61s/it]

2099 -10.0 -11.172092490195126


  2%|▏         | 2200/100000 [4:15:10<340:07:03, 12.52s/it]

2199 -9.0 -11.177373676870285


  2%|▏         | 2246/100000 [4:25:02<419:28:21, 15.45s/it]

New record: 3.0


  2%|▏         | 2300/100000 [4:37:20<383:39:23, 14.14s/it]

2299 -11.0 -9.74968344704965


  2%|▏         | 2400/100000 [4:59:22<341:38:08, 12.60s/it]

2399 -12.0 -10.822045662631517


  2%|▎         | 2500/100000 [5:22:36<410:43:37, 15.17s/it]

2499 -6.0 -6.606292393441941


  3%|▎         | 2600/100000 [5:46:33<414:55:07, 15.34s/it]

2599 -9.0 -7.359188908415954


  3%|▎         | 2700/100000 [6:11:18<409:17:08, 15.14s/it]

2699 -9.0 -9.455243362709528


  3%|▎         | 2800/100000 [6:36:13<439:34:57, 16.28s/it]

2799 -4.0 -7.996004723074184


  3%|▎         | 2821/100000 [6:41:36<424:20:53, 15.72s/it]

New record: 4.0


  3%|▎         | 2900/100000 [7:02:10<406:05:21, 15.06s/it]

2899 -6.0 -7.854151488245545


  3%|▎         | 2978/100000 [7:23:05<474:55:18, 17.62s/it]

New record: 7.0


  3%|▎         | 3000/100000 [7:29:15<439:26:24, 16.31s/it]

2999 -10.0 -9.146464312922108


  3%|▎         | 3100/100000 [7:57:10<409:24:12, 15.21s/it]

3099 -9.0 -10.0824142141144


  3%|▎         | 3200/100000 [8:24:20<470:56:51, 17.51s/it]

3199 -7.0 -5.652142101283301


  3%|▎         | 3300/100000 [8:51:37<450:24:44, 16.77s/it]

3299 -5.0 -7.459515084179211


  3%|▎         | 3400/100000 [9:19:18<474:45:47, 17.69s/it]

3399 4.0 -2.39579628390031


  4%|▎         | 3500/100000 [9:48:24<517:35:15, 19.31s/it]

3499 -5.0 -5.914647406649923


  4%|▎         | 3528/100000 [9:56:53<445:42:53, 16.63s/it]

New record: 9.0


  4%|▎         | 3600/100000 [10:17:12<392:24:24, 14.65s/it]

3599 -3.0 -9.121102987448092


  4%|▎         | 3700/100000 [10:47:04<478:55:20, 17.90s/it]

3699 -8.0 -7.864575841432248


  4%|▍         | 3800/100000 [11:15:53<482:25:27, 18.05s/it]

3799 -2.0 -5.362674818875475


  4%|▍         | 3900/100000 [11:46:45<525:23:26, 19.68s/it]

3899 -6.0 -3.5676233704441978


  4%|▍         | 3918/100000 [11:52:56<553:09:39, 20.73s/it]

Finished: [5.0, 8653, -112.47208404541016, 2408.460205078125]@3918


In [9]:
# Smoothed episode reward at the moment the training loop above stopped.
running_reward

2.7200344703228563

In [10]:
# Index of the last episode run by the training loop above.
episode

3918

In [11]:
# Phase 2: keep training with a slower-moving running reward.
MAX_STEPS_PER_EPISODE = 50000  # hard cap on steps per episode
COUNT_GAMMA = 0.95  # running_reward smoothing: ~20-episode effective window
TARGET_REWARD = 1   # stop once the smoothed reward exceeds this

# Continue numbering right after the last episode of the previous loop so the
# writer's scalars and checkpoint names stay contiguous. A hard-coded start
# (previously 3919) only matches one particular run of the cell above.
start_episode = episode + 1
# Reset below where the previous phase stopped, presumably so the stop test
# cannot fire before the slower average has built up.
running_reward = -10

pbar = trange(start_episode, 100000)
for episode in pbar:
    frame = env.reset()
    last_obs = preprocess(frame)
    curr_obs = preprocess(frame)
    for step in itertools.count(start=1):
        # The agent sees the difference of consecutive preprocessed frames
        # (all zeros on the first step of an episode).
        action, log_prob, state_value = agent.select_action(obs=curr_obs - last_obs)
        frame, reward, done, _ = env.step(action)
        agent.keep_for_grad(log_prob, state_value, reward)
        last_obs, curr_obs = curr_obs, preprocess(frame)
        if done:
            break
        if step >= MAX_STEPS_PER_EPISODE:
            pbar.write("Episode %s truncated after %s steps" % (episode, step))
            break
    agent.step()

    total_reward, n_round, policy_loss, value_loss = agent.history[-1]
    writer.add_scalar("reward", total_reward, episode)
    writer.add_scalar("n_round", n_round, episode)
    writer.add_scalar("policy_loss", policy_loss, episode)
    writer.add_scalar("value_loss", value_loss, episode)

    if total_reward > best_reward:
        pbar.write("New record: %s" % total_reward)
        best_reward = total_reward
        writer.save(net, "best.pth")

    running_reward = COUNT_GAMMA * running_reward + (1 - COUNT_GAMMA) * total_reward
    if (episode + 1) % 100 == 0:
        pbar.write("%s %s %s" % (episode, total_reward, running_reward))
        writer.save(net, "episode%s.pth" % episode)
    if running_reward > TARGET_REWARD:
        break
pbar.close()


  0%|          | 0/96081 [00:00<?, ?it/s][A
  0%|          | 81/96081 [25:42<510:47:57, 19.15s/it]

3999 -4.0 -6.304692098715199


  0%|          | 139/96081 [42:54<464:15:59, 17.42s/it]

New record: 11.0


  0%|          | 181/96081 [55:42<542:58:21, 20.38s/it]

4099 -2.0 -6.556893718657285


  0%|          | 281/96081 [1:28:52<512:04:30, 19.24s/it]

4199 -7.0 -5.281192051148353


  0%|          | 381/96081 [2:02:44<565:13:56, 21.26s/it]

4299 -2.0 -3.3653954866249474


  0%|          | 447/96081 [2:24:34<424:40:20, 15.99s/it]

New record: 12.0


  1%|          | 481/96081 [2:35:28<463:14:48, 17.44s/it]

4399 -11.0 -6.7889682231141935


  1%|          | 581/96081 [3:08:56<501:01:32, 18.89s/it]

4499 7.0 -4.606927478033457


  1%|          | 681/96081 [3:42:44<548:54:23, 20.71s/it]

4599 -7.0 -4.789204420684547


  1%|          | 781/96081 [4:17:45<474:45:45, 17.93s/it]

4699 -11.0 -5.564791654593409


  1%|          | 881/96081 [4:51:59<568:02:07, 21.48s/it]

4799 1.0 -3.8562273657281896


  1%|          | 981/96081 [5:24:17<517:04:22, 19.57s/it]

4899 6.0 -2.9790923381070655


  1%|          | 1081/96081 [5:59:22<653:31:41, 24.77s/it]

4999 1.0 -3.6662262730140727


  1%|          | 1181/96081 [6:34:12<574:25:06, 21.79s/it]

5099 -3.0 -4.518388585091654


  1%|▏         | 1281/96081 [7:08:18<564:18:35, 21.43s/it]

5199 -3.0 -3.2641551856247704


  1%|▏         | 1381/96081 [7:41:38<572:08:04, 21.75s/it]

5299 -3.0 -2.8940433898956437


  2%|▏         | 1481/96081 [8:14:02<512:47:37, 19.51s/it]

5399 -4.0 -3.1532289105649878


  2%|▏         | 1581/96081 [8:46:32<526:22:26, 20.05s/it]

5499 -4.0 -4.755609536876327


  2%|▏         | 1681/96081 [9:21:28<545:10:27, 20.79s/it]

5599 -7.0 -4.402140040030038


  2%|▏         | 1781/96081 [9:56:54<563:42:27, 21.52s/it]

5699 -7.0 -3.806086285567739


  2%|▏         | 1881/96081 [10:32:27<539:29:41, 20.62s/it]

5799 6.0 -2.272635616060908


  2%|▏         | 1981/96081 [11:08:33<621:49:51, 23.79s/it]

5899 5.0 -0.930503998549929


  2%|▏         | 2081/96081 [11:44:56<630:45:29, 24.16s/it]

5999 2.0 -3.60130649227082


  2%|▏         | 2181/96081 [12:21:02<536:35:13, 20.57s/it]

6099 3.0 -2.701209347503945


  2%|▏         | 2281/96081 [12:57:57<590:29:18, 22.66s/it]

6199 -5.0 -2.0352405820376664


  2%|▏         | 2381/96081 [13:35:14<591:04:16, 22.71s/it]

6299 -4.0 -1.4770383806160807


  3%|▎         | 2481/96081 [14:11:41<540:33:47, 20.79s/it]

6399 8.0 0.15361726994744937


  3%|▎         | 2537/96081 [14:32:48<583:34:43, 22.46s/it]

New record: 13.0


  3%|▎         | 2581/96081 [14:49:00<543:49:54, 20.94s/it]

6499 -3.0 -2.1070353953950427


  3%|▎         | 2681/96081 [15:26:16<542:44:48, 20.92s/it]

6599 1.0 -2.49519451900251


  3%|▎         | 2724/96081 [15:42:04<548:01:30, 21.13s/it]

In [12]:
# Report where this phase ended, then persist the logged scalars and the weights.
last_summary = agent.history[-1]  # [total_reward, n_round, policy_loss, value_loss]
print("Finished: %s@%s" % (last_summary, episode))
writer.export_logs()
writer.save(net, "final.pth")

Finished: [10.0, 10154, 16.15831184387207, 2717.9013671875]@6643


In [13]:
# Smoothed episode reward at the moment the training loop above stopped.
running_reward

1.2750143365464455

In [14]:
# Index of the last episode run by the training loop above.
episode

6643

In [15]:
# Phase 3: train with a long running-reward window and a higher target.
MAX_STEPS_PER_EPISODE = 50000  # hard cap on steps per episode
COUNT_GAMMA = 0.99  # running_reward smoothing: ~100-episode effective window
TARGET_REWARD = 5   # stop once the smoothed reward exceeds this

# Continue numbering right after the last episode of the previous loop instead
# of a start index copied from one run's output (previously 6643+1).
start_episode = episode + 1
# Reset below where the previous phase stopped, presumably so the stop test
# cannot fire before the slower average has built up.
running_reward = -10

pbar = trange(start_episode, 100000)
for episode in pbar:
    frame = env.reset()
    last_obs = preprocess(frame)
    curr_obs = preprocess(frame)
    for step in itertools.count(start=1):
        # The agent sees the difference of consecutive preprocessed frames
        # (all zeros on the first step of an episode).
        action, log_prob, state_value = agent.select_action(obs=curr_obs - last_obs)
        frame, reward, done, _ = env.step(action)
        agent.keep_for_grad(log_prob, state_value, reward)
        last_obs, curr_obs = curr_obs, preprocess(frame)
        if done:
            break
        if step >= MAX_STEPS_PER_EPISODE:
            pbar.write("Episode %s truncated after %s steps" % (episode, step))
            break
    agent.step()

    total_reward, n_round, policy_loss, value_loss = agent.history[-1]
    writer.add_scalar("reward", total_reward, episode)
    writer.add_scalar("n_round", n_round, episode)
    writer.add_scalar("policy_loss", policy_loss, episode)
    writer.add_scalar("value_loss", value_loss, episode)

    if total_reward > best_reward:
        pbar.write("New record: %s" % total_reward)
        best_reward = total_reward
        writer.save(net, "best.pth")

    running_reward = COUNT_GAMMA * running_reward + (1 - COUNT_GAMMA) * total_reward
    if (episode + 1) % 100 == 0:
        pbar.write("%s %s %s" % (episode, total_reward, running_reward))
        writer.save(net, "episode%s.pth" % episode)
    if running_reward > TARGET_REWARD:
        break
pbar.close()


  0%|          | 0/93356 [00:00<?, ?it/s][A
  0%|          | 56/93356 [21:04<575:52:04, 22.22s/it]

6699 7.0 -5.849091488371331


  0%|          | 156/93356 [58:52<588:12:09, 22.72s/it]

6799 -3.0 -3.382344437867607


  0%|          | 256/93356 [1:37:10<601:00:01, 23.24s/it]

6899 4.0 -1.5652927005216837


  0%|          | 356/93356 [2:12:19<559:10:29, 21.65s/it]

6999 -3.0 -1.4294474766629957


  0%|          | 456/93356 [2:49:35<591:44:11, 22.93s/it]

7099 5.0 -1.538605701515573


  1%|          | 556/93356 [3:27:17<645:55:48, 25.06s/it]

7199 -5.0 -1.2825904001118926


  1%|          | 656/93356 [4:05:24<534:23:07, 20.75s/it]

7299 -5.0 -1.0097033474291597


  1%|          | 756/93356 [4:42:15<588:35:45, 22.88s/it]

7399 -4.0 -0.03136116722193831


  1%|          | 856/93356 [5:21:20<561:57:00, 21.87s/it]

7499 3.0 0.1462778061186727


  1%|          | 896/93356 [5:37:16<539:25:18, 21.00s/it]

New record: 14.0


  1%|          | 956/93356 [6:00:43<558:50:38, 21.77s/it]

7599 -6.0 0.29849294361602824


  1%|          | 1056/93356 [6:39:50<651:52:38, 25.43s/it]

7699 2.0 0.9680810661554184


  1%|          | 1156/93356 [7:18:07<605:42:38, 23.65s/it]

7799 -5.0 0.10308640148001505


  1%|▏         | 1256/93356 [7:57:47<540:22:44, 21.12s/it]

7899 10.0 0.08408195804490814


  1%|▏         | 1332/93356 [8:27:02<551:11:22, 21.56s/it]

New record: 16.0


  1%|▏         | 1356/93356 [8:36:17<557:18:00, 21.81s/it]

7999 6.0 0.7552959374297694


  2%|▏         | 1456/93356 [9:16:35<588:11:25, 23.04s/it]

8099 10.0 1.2397442699373127


  2%|▏         | 1556/93356 [9:54:48<574:24:06, 22.53s/it]

8199 9.0 0.6718811518198331


  2%|▏         | 1655/93356 [10:33:48<545:57:00, 21.43s/it]

New record: 17.0


  2%|▏         | 1656/93356 [10:34:13<570:02:10, 22.38s/it]

8299 -11.0 1.7780999794900023


  2%|▏         | 1756/93356 [11:12:18<631:54:30, 24.83s/it]

8399 -1.0 1.5795190524359963


  2%|▏         | 1856/93356 [11:51:47<639:36:22, 25.16s/it]

8499 3.0 1.9842638450803738


  2%|▏         | 1956/93356 [12:31:10<624:41:41, 24.61s/it]

8599 -1.0 1.19849782385303


  2%|▏         | 2056/93356 [13:11:44<652:11:49, 25.72s/it]

8699 2.0 1.4246113243835503


  2%|▏         | 2156/93356 [13:51:20<608:41:55, 24.03s/it]

8799 -8.0 2.2592050116782865


  2%|▏         | 2256/93356 [14:32:05<599:18:25, 23.68s/it]

8899 4.0 2.9687476974015192


  3%|▎         | 2356/93356 [15:12:21<655:20:02, 25.93s/it]

8999 5.0 2.756019412298607


  3%|▎         | 2456/93356 [15:51:51<496:48:31, 19.68s/it]

9099 4.0 2.567475843976699


  3%|▎         | 2556/93356 [16:32:25<765:08:33, 30.34s/it]

9199 -2.0 2.4953401955976355


  3%|▎         | 2656/93356 [17:15:21<637:50:13, 25.32s/it]

9299 5.0 2.106735628859486


  3%|▎         | 2756/93356 [17:57:48<570:15:52, 22.66s/it]

9399 9.0 2.7511312145988827


  3%|▎         | 2856/93356 [18:39:17<630:06:01, 25.06s/it]

9499 9.0 3.363849409802315


  3%|▎         | 2956/93356 [19:20:55<666:39:40, 26.55s/it]

9599 1.0 3.0478336056277078


  3%|▎         | 3056/93356 [20:03:13<669:00:41, 26.67s/it]

9699 2.0 2.819328849286409


  3%|▎         | 3156/93356 [20:43:19<594:17:57, 23.72s/it]

9799 1.0 2.964009671459623


  3%|▎         | 3256/93356 [21:25:12<592:39:41, 23.68s/it]

9899 12.0 3.518997616102268


  4%|▎         | 3356/93356 [22:05:46<517:02:13, 20.68s/it]

9999 -3.0 3.5779659358849925


  4%|▎         | 3456/93356 [22:45:31<558:24:46, 22.36s/it]

10099 -10.0 3.629981065794688


  4%|▍         | 3556/93356 [23:28:33<725:20:14, 29.08s/it]

10199 4.0 3.6948076060096


  4%|▍         | 3656/93356 [24:09:18<547:12:23, 21.96s/it]

10299 -5.0 3.9129904317726236


  4%|▍         | 3688/93356 [24:22:22<593:32:15, 23.83s/it]

New record: 18.0


  4%|▍         | 3756/93356 [24:50:19<513:32:15, 20.63s/it]

10399 8.0 3.899731204935927


  4%|▍         | 3856/93356 [25:31:48<658:57:29, 26.51s/it]

10499 8.0 3.6788725348311964


  4%|▍         | 3956/93356 [26:13:06<598:50:32, 24.11s/it]

10599 9.0 4.143268478264385


  4%|▍         | 4056/93356 [26:55:18<683:46:18, 27.57s/it]

10699 3.0 4.414005820162901


  4%|▍         | 4156/93356 [27:36:02<641:34:54, 25.89s/it]

10799 3.0 4.38303863150212


  5%|▍         | 4256/93356 [28:17:07<580:50:56, 23.47s/it]

10899 9.0 4.767271184784362


  5%|▍         | 4267/93356 [28:21:09<485:00:51, 19.60s/it]

In [16]:
# Report where this phase ended, then persist the logged scalars and the weights.
last_summary = agent.history[-1]  # [total_reward, n_round, policy_loss, value_loss]
print("Finished: %s@%s" % (last_summary, episode))
writer.export_logs()
writer.save(net, "final.pth")

Finished: [18.0, 7553, -198.44850158691406, 1888.2464599609375]@10911


In [17]:
# Smoothed episode reward at the moment the training loop above stopped.
running_reward

5.1138222762730585