In [1]:
import gym
import numpy as np
import matplotlib.pyplot as plt
from tqdm import trange
from tensorboardX import SummaryWriter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.distributions import Categorical

from torchvision import transforms

%matplotlib inline

In [2]:
env = gym.make("Pong-v0")

[2018-01-05 01:56:21,493] Making new env: Pong-v0


In [3]:
downsample = 2

def preprocess(frame):
    '''from karpathy.'''
    I = frame
    I = I[35:195] # crop
    I = I[::downsample,::downsample,0] # downsample by factor of 2
    I[I == 144] = 0 # erase background (background type 1)
    I[I == 109] = 0 # erase background (background type 2)
    I[I != 0] = 1 # everything else (paddles, ball) just set to 1
    tensor = torch.from_numpy(I).float()
    return tensor.unsqueeze(0).unsqueeze(0) #BCHW

def clip_grads(net, low=-10, high=10):
    """Gradient clipping to the range [low, high]."""
    parameters = [param for param in net.parameters()
                  if param.grad is not None]
    for p in parameters:
        p.grad.data.clamp_(low, high)
        
if torch.cuda.is_available():
    def to_var(x, requires_grad=False, gpu=None):
        x = x.cuda(gpu)
        return Variable(x, requires_grad=requires_grad)
else:
    def to_var(x, requires_grad=False, vgpu=None):
        return Variable(x, requires_grad=requires_grad)

In [4]:
class Net(nn.Module):
    def __init__(self, action_n):
        super().__init__()
        self.conv = nn.Sequential(nn.Conv2d(1,16,kernel_size=5),nn.ReLU(),nn.MaxPool2d(kernel_size=2),
#                                   nn.Conv2d(32,128,kernel_size=5),nn.ReLU(),nn.MaxPool2d(kernel_size=2),
                                  nn.Conv2d(16,64,kernel_size=5),nn.ReLU(),nn.AdaptiveMaxPool2d(output_size=1))
        self.fc = nn.Linear(64, action_n)
        
    def forward(self, x):
        feat = self.conv(x)
        logit = self.fc(feat.view(feat.size(0),-1))
        return logit

In [5]:
class PolicyGradient:
    
    def __init__(self, model, gamma=0.99, eps=1.e-6, running_gamma=0.99, running_start=0,
#                 episode2thresh=lambda i: 0.05+0.9*np.exp(-1. * i / 100) if i>150 else 0): # eploration will start after 150 episodes
                 episode2thresh=lambda i: 0): # without exploration
        self.model = model
        self.gamma = gamma
        self.eps = eps
        self.log_probs = []
        self.rewards = []
        self.total_rewards = []
        self.running_reward = running_start
        self.running_gamma = running_gamma
        self.episode2thresh = episode2thresh
        
    @property
    def episodes(self):
        return len(self.total_rewards)
        
    def select_action(self,obs):
        self.model.train()
        thresh=self.episode2thresh(self.episodes)
        action, log_prob = select_action(obs, self.model, thresh=thresh)
        self.log_probs.append(log_prob)
        return action
    
    def get_loss_and_clear(self):
        total_reward = sum(self.rewards)
        self.total_rewards.append(total_reward)
        self.running_reward = self.running_gamma*self.running_reward+(1-self.running_gamma)*total_reward
        policy_loss = get_policy_loss(self.log_probs, self.rewards, self.gamma, self.eps)
        del self.log_probs[:]
        del self.rewards[:]
        return policy_loss
    
    def take_action(self, action, env, render=False):
        obs, reward, done, info = env.step(action)
        self.rewards.append(reward)
        if render:
            env.render()
        return obs, reward, done, info
    
    def greedy_policy(self, obs):
        self.model.eval()
        state = to_var(obs)
        prob = self.model(state)
        _, action = prob.max(dim=1)
        return action.data[0]

def select_action(obs, model, thresh=0):
    state = to_var(obs)
    logits = model(state)
    probs = F.softmax(logits, dim=1)
    m = Categorical(probs)
    action = m.sample()
#     if np.random.random()>thresh:
# #         print(probs)
#         try:
#             action = m.sample()
#         except:
#             print(probs,m)
#             raise
#     else:
#         action_space = probs.size(1)
#         action = to_var(torch.from_numpy(np.random.randint(action_space,size=1)))
    return action.data[0],m.log_prob(action)
    
def get_normalized_rewards(rewards, gamma, eps):
    acc = []
    R = 0
    for r in reversed(rewards):
        R = r + gamma * R
        acc.append(R)
    ret = to_var(torch.Tensor(acc[::-1]),requires_grad=False)
    ret = (ret - ret.mean()) / (ret.std()+eps)
#     print(ret)
    return ret

def get_policy_loss(log_probs,rewards, gamma,eps):
    log_probs_v = torch.cat(log_probs)
    rewards_v = get_normalized_rewards(rewards, gamma, eps)
    return -log_probs_v.dot(rewards_v)

In [6]:
net = Net(env.action_space.n)
if torch.cuda.is_available():
    net = net.cuda()
optimizer = optim.Adam(net.parameters(), lr=1.e-4,weight_decay=0.001)
trainer = PolicyGradient(model=net,running_start=-21)
writer = SummaryWriter()

In [7]:
import os

writer_path = list(writer.all_writers.keys())[0]
weight_join = lambda p: os.path.join(writer_path, p)

In [8]:
for episode in trange(100000):
    frame = env.reset()
    last_obs = preprocess(frame)
    curr_obs = preprocess(frame)
    total_reward = 0
    for step in range(100000): # not exceed 10000 steps
        action = trainer.select_action(obs=curr_obs-last_obs)
        frame, reward, done, _ = trainer.take_action(action, env, render=False)
        last_obs = curr_obs
        curr_obs = preprocess(frame)
        total_reward+=reward
        if done:
             break
    if step==100000:
        print("not enough!!!!!!!!!!!!!!!")
    policy_loss = trainer.get_loss_and_clear()
    writer.add_scalar("loss",policy_loss.data[0],episode)
    writer.add_scalar("reward",total_reward,episode)
#     print(policy_loss)
    optimizer.zero_grad()
    policy_loss.backward()
    clip_grads(trainer.model,-5,5)
    optimizer.step()
    running_reward = trainer.running_reward
    if episode%100==0:
        print(episode, total_reward,running_reward)
        torch.save(net.state_dict(), weight_join("episode%s.pth"%episode))
    if running_reward>1:
        break
print("Finished: %s@%s" %(trainer.running_reward,episode))

  0%|          | 1/100000 [00:04<116:21:31,  4.19s/it]

0 -21.0 -21.0


  0%|          | 100/100000 [04:42<82:05:51,  2.96s/it]

100 -21.0 -20.427130787534544


  0%|          | 201/100000 [09:31<82:09:34,  2.96s/it]

200 -20.0 -20.250640724142155


  0%|          | 301/100000 [13:55<68:14:15,  2.46s/it]

300 -21.0 -20.2555824350713


  0%|          | 401/100000 [17:59<71:42:18,  2.59s/it]

400 -19.0 -20.2938765476465


  1%|          | 501/100000 [22:18<71:30:41,  2.59s/it]

500 -20.0 -20.31941319740614


  1%|          | 601/100000 [26:42<77:09:47,  2.79s/it]

600 -20.0 -20.323572110162306


  1%|          | 701/100000 [31:08<89:22:45,  3.24s/it]

700 -15.0 -20.203858915167345


  1%|          | 801/100000 [35:33<76:42:18,  2.78s/it]

800 -19.0 -20.190157013239343


  1%|          | 901/100000 [39:52<72:55:04,  2.65s/it]

900 -21.0 -20.3404920567844


  1%|          | 1001/100000 [44:05<70:14:53,  2.55s/it]

1000 -19.0 -20.397655294953395


  1%|          | 1101/100000 [48:12<71:44:30,  2.61s/it]

1100 -19.0 -20.43738764513271


  1%|          | 1201/100000 [52:31<79:57:14,  2.91s/it]

1200 -20.0 -20.330208908683755


  1%|▏         | 1301/100000 [56:50<72:05:48,  2.63s/it]

1300 -20.0 -20.27141169764255


  1%|▏         | 1401/100000 [1:00:44<61:22:08,  2.24s/it]

1400 -21.0 -20.261107125547408


  2%|▏         | 1501/100000 [1:04:32<61:42:29,  2.26s/it]

1500 -20.0 -20.298721724286853


  2%|▏         | 1601/100000 [1:08:31<71:06:18,  2.60s/it]

1600 -20.0 -20.371616895976935


  2%|▏         | 1701/100000 [1:12:56<69:21:35,  2.54s/it]

1700 -21.0 -20.27028576962773


  2%|▏         | 1801/100000 [1:17:25<72:10:35,  2.65s/it]

1800 -21.0 -20.173252579710308


  2%|▏         | 1901/100000 [1:21:54<68:39:35,  2.52s/it]

1900 -21.0 -20.07046338029367


  2%|▏         | 2001/100000 [1:26:19<79:39:38,  2.93s/it]

2000 -20.0 -20.126293713585838


  2%|▏         | 2101/100000 [1:30:43<78:34:39,  2.89s/it]

2100 -19.0 -20.1457436053611


  2%|▏         | 2201/100000 [1:35:12<68:39:34,  2.53s/it]

2200 -21.0 -20.235237272274773


  2%|▏         | 2301/100000 [1:39:43<76:32:02,  2.82s/it]

2300 -21.0 -20.169645468822594


  2%|▏         | 2401/100000 [1:44:11<73:50:46,  2.72s/it]

2400 -19.0 -20.216455877372333


  3%|▎         | 2501/100000 [1:48:26<68:25:15,  2.53s/it]

2500 -21.0 -20.08777390603611


  3%|▎         | 2601/100000 [1:52:32<66:58:07,  2.48s/it]

2600 -20.0 -20.037333626962823


  3%|▎         | 2701/100000 [1:56:46<71:10:16,  2.63s/it]

2700 -21.0 -20.13561945899216


  3%|▎         | 2801/100000 [2:01:19<73:55:00,  2.74s/it]

2800 -18.0 -20.137195545450094


  3%|▎         | 2901/100000 [2:05:47<68:51:54,  2.55s/it]

2900 -21.0 -20.15252547467973


  3%|▎         | 3001/100000 [2:10:17<76:01:09,  2.82s/it]

3000 -20.0 -20.122476719157717


  3%|▎         | 3101/100000 [2:14:59<72:09:22,  2.68s/it]

3100 -21.0 -20.10307634770331


  3%|▎         | 3201/100000 [2:19:31<75:38:59,  2.81s/it]

3200 -20.0 -20.106143102065637


  3%|▎         | 3301/100000 [2:24:05<72:48:12,  2.71s/it]

3300 -21.0 -20.04106875842242


  3%|▎         | 3401/100000 [2:28:36<73:01:41,  2.72s/it]

3400 -21.0 -20.190307456220587


  4%|▎         | 3501/100000 [2:33:19<83:29:39,  3.11s/it]

3500 -18.0 -20.001218770577864


  4%|▎         | 3601/100000 [2:38:01<75:49:21,  2.83s/it]

3600 -20.0 -20.041912587371193


  4%|▎         | 3701/100000 [2:42:42<71:15:38,  2.66s/it]

3700 -21.0 -20.067494474149985


  4%|▍         | 3801/100000 [2:47:21<77:48:04,  2.91s/it]

3800 -21.0 -19.99948933827048


  4%|▍         | 3901/100000 [2:52:04<70:12:20,  2.63s/it]

3900 -21.0 -19.991970643329832


  4%|▍         | 4001/100000 [2:56:49<79:02:16,  2.96s/it]

4000 -21.0 -19.953386243820237


  4%|▍         | 4101/100000 [3:01:34<73:26:04,  2.76s/it]

4100 -21.0 -19.90491568350774


  4%|▍         | 4201/100000 [3:06:20<74:58:17,  2.82s/it]

4200 -21.0 -19.97793232595239


  4%|▍         | 4301/100000 [3:11:04<73:48:56,  2.78s/it]

4300 -21.0 -20.00509157743619


  4%|▍         | 4401/100000 [3:15:51<71:42:25,  2.70s/it]

4400 -20.0 -19.996584554806297


  5%|▍         | 4501/100000 [3:20:36<82:47:24,  3.12s/it]

4500 -20.0 -19.92475992025258


  5%|▍         | 4601/100000 [3:25:24<71:19:45,  2.69s/it]

4600 -20.0 -19.992262949408683


  5%|▍         | 4701/100000 [3:30:12<73:45:55,  2.79s/it]

4700 -21.0 -20.044009132274958


  5%|▍         | 4801/100000 [3:35:07<79:55:32,  3.02s/it]

4800 -19.0 -19.82822104873077


  5%|▍         | 4901/100000 [3:40:00<72:36:55,  2.75s/it]

4900 -20.0 -19.97351495951153


  5%|▌         | 5001/100000 [3:44:58<80:54:59,  3.07s/it]

5000 -21.0 -20.00030226223695


  5%|▌         | 5101/100000 [3:50:02<82:16:42,  3.12s/it]

5100 -19.0 -19.912473278911925


  5%|▌         | 5201/100000 [3:55:03<78:21:06,  2.98s/it]

5200 -19.0 -19.90872563473473


  5%|▌         | 5301/100000 [4:00:14<84:36:57,  3.22s/it]

5300 -18.0 -19.702710008044328


  5%|▌         | 5401/100000 [4:05:16<76:15:25,  2.90s/it]

5400 -21.0 -19.829844395269298


  6%|▌         | 5501/100000 [4:09:59<73:17:43,  2.79s/it]

5500 -20.0 -19.907795709013335


  6%|▌         | 5601/100000 [4:14:28<74:50:45,  2.85s/it]

5600 -20.0 -19.90188154821085


  6%|▌         | 5701/100000 [4:19:11<75:08:53,  2.87s/it]

5700 -20.0 -19.86083263909687


  6%|▌         | 5801/100000 [4:24:15<74:57:18,  2.86s/it]

5800 -21.0 -19.917739466992586


  6%|▌         | 5901/100000 [4:29:22<72:55:36,  2.79s/it]

5900 -21.0 -19.86206886139366


  6%|▌         | 6001/100000 [4:34:27<82:38:08,  3.16s/it]

6000 -20.0 -19.84452358362551


  6%|▌         | 6101/100000 [4:39:32<80:49:09,  3.10s/it]

6100 -19.0 -19.84552950230182


  6%|▌         | 6201/100000 [4:44:42<83:01:57,  3.19s/it]

6200 -18.0 -19.80994456220307


  6%|▋         | 6301/100000 [4:49:54<82:32:37,  3.17s/it]

6300 -20.0 -19.733579965531188


  6%|▋         | 6401/100000 [4:55:03<73:18:25,  2.82s/it]

6400 -21.0 -19.808783112634618


  7%|▋         | 6501/100000 [5:00:17<90:18:03,  3.48s/it]

6500 -16.0 -19.900186472710267


  7%|▋         | 6601/100000 [5:05:32<79:57:29,  3.08s/it]

6600 -20.0 -19.70636305350933


  7%|▋         | 6701/100000 [5:10:49<77:44:54,  3.00s/it]

6700 -19.0 -19.772596614528503


  7%|▋         | 6801/100000 [5:15:40<67:21:02,  2.60s/it]

6800 -21.0 -19.8871867716992


  7%|▋         | 6901/100000 [5:20:40<90:30:11,  3.50s/it]

6900 -18.0 -19.75711093296784


  7%|▋         | 7001/100000 [5:25:46<62:56:56,  2.44s/it]

7000 -21.0 -19.741964861305267


  7%|▋         | 7101/100000 [5:31:00<75:32:03,  2.93s/it]

7100 -20.0 -19.76008259828115


  7%|▋         | 7201/100000 [5:36:20<82:43:21,  3.21s/it]

7200 -20.0 -19.838456691591087


  7%|▋         | 7301/100000 [5:41:40<84:41:30,  3.29s/it]

7300 -18.0 -19.7827008447155


  7%|▋         | 7401/100000 [5:47:02<81:31:39,  3.17s/it]

7400 -21.0 -19.707728800437884


  8%|▊         | 7501/100000 [5:52:33<91:45:43,  3.57s/it]

7500 -20.0 -19.59637840153245


  8%|▊         | 7601/100000 [5:58:08<86:33:18,  3.37s/it]

7600 -21.0 -19.653161390242115


  8%|▊         | 7701/100000 [6:03:34<90:58:30,  3.55s/it]

7700 -21.0 -19.712895607410708


  8%|▊         | 7801/100000 [6:09:12<92:18:33,  3.60s/it]

7800 -18.0 -19.59886315351107


  8%|▊         | 7901/100000 [6:14:36<80:18:41,  3.14s/it]

7900 -20.0 -19.669952226336722


  8%|▊         | 8001/100000 [6:20:21<91:39:43,  3.59s/it]

8000 -20.0 -19.57095549510572


  8%|▊         | 8101/100000 [6:25:56<88:34:26,  3.47s/it]

8100 -19.0 -19.774856860747956


  8%|▊         | 8201/100000 [6:31:38<89:26:30,  3.51s/it]

8200 -20.0 -19.688593075009333


  8%|▊         | 8301/100000 [6:37:24<78:48:08,  3.09s/it]

8300 -21.0 -19.6543674959591


  8%|▊         | 8401/100000 [6:43:10<83:23:49,  3.28s/it]

8400 -20.0 -19.704424492732358


  9%|▊         | 8501/100000 [6:49:01<85:05:04,  3.35s/it]

8500 -20.0 -19.69906453364814


  9%|▊         | 8601/100000 [6:54:53<98:18:45,  3.87s/it] 

8600 -20.0 -19.6203520275043


  9%|▊         | 8701/100000 [7:00:41<83:31:00,  3.29s/it]

8700 -16.0 -19.51616031312804


  9%|▉         | 8801/100000 [7:06:29<89:17:10,  3.52s/it]

8800 -21.0 -19.61926991899348


  9%|▉         | 8901/100000 [7:12:20<93:59:57,  3.71s/it]

8900 -20.0 -19.556741071580337


  9%|▉         | 9001/100000 [7:18:26<87:03:18,  3.44s/it]

9000 -20.0 -19.417986176418317


  9%|▉         | 9101/100000 [7:24:18<99:13:38,  3.93s/it]

9100 -18.0 -19.554079109708244


  9%|▉         | 9201/100000 [7:30:18<91:53:37,  3.64s/it]

9200 -20.0 -19.599171484108613


  9%|▉         | 9301/100000 [7:36:22<87:30:05,  3.47s/it]

9300 -20.0 -19.473839776443835


  9%|▉         | 9401/100000 [7:42:38<89:47:41,  3.57s/it]

9400 -18.0 -19.485781375172188


 10%|▉         | 9501/100000 [7:48:57<96:34:47,  3.84s/it]

9500 -19.0 -19.495210006807206


 10%|▉         | 9601/100000 [7:55:07<82:07:06,  3.27s/it]

9600 -19.0 -19.602041890870908


 10%|▉         | 9701/100000 [8:01:33<98:15:41,  3.92s/it] 

9700 -20.0 -19.53076997846288


 10%|▉         | 9801/100000 [8:07:32<84:51:07,  3.39s/it]

9800 -20.0 -19.468837733403177


 10%|▉         | 9901/100000 [8:13:09<81:12:41,  3.24s/it]

9900 -20.0 -19.54352291288031


 10%|█         | 10001/100000 [8:19:06<93:22:27,  3.74s/it]

10000 -19.0 -19.564347620879378


 10%|█         | 10101/100000 [8:25:30<95:50:39,  3.84s/it]

10100 -21.0 -19.378343339358562


 10%|█         | 10201/100000 [8:31:50<96:10:01,  3.86s/it]

10200 -18.0 -19.461453783758188


 10%|█         | 10301/100000 [8:38:18<100:40:17,  4.04s/it]

10300 -17.0 -19.3797803417963


 10%|█         | 10401/100000 [8:44:42<101:57:03,  4.10s/it]

10400 -18.0 -19.480229946037976


 11%|█         | 10501/100000 [8:51:13<96:36:52,  3.89s/it]

10500 -18.0 -19.276077438878538


 11%|█         | 10601/100000 [8:57:54<96:01:14,  3.87s/it]

10600 -17.0 -19.11386383513537


 11%|█         | 10701/100000 [9:04:33<93:59:04,  3.79s/it]

10700 -20.0 -19.28456853630293


 11%|█         | 10801/100000 [9:11:15<96:24:51,  3.89s/it]

10800 -18.0 -19.083584037026245


 11%|█         | 10901/100000 [9:18:07<108:45:56,  4.39s/it]

10900 -18.0 -19.09363636913442


 11%|█         | 11001/100000 [9:24:48<96:08:46,  3.89s/it]

11000 -21.0 -19.303339845435087


 11%|█         | 11101/100000 [9:31:29<99:24:07,  4.03s/it] 

11100 -20.0 -19.34516371336201


 11%|█         | 11201/100000 [9:38:15<110:38:22,  4.49s/it]

11200 -18.0 -19.24300462376038


 11%|█▏        | 11301/100000 [9:45:00<120:30:42,  4.89s/it]

11300 -17.0 -19.29612400270262


 11%|█▏        | 11401/100000 [9:51:56<96:16:24,  3.91s/it]

11400 -19.0 -19.17405672444303


 12%|█▏        | 11501/100000 [9:58:54<109:35:22,  4.46s/it]

11500 -18.0 -19.217946470665357


 12%|█▏        | 11601/100000 [10:05:51<91:59:14,  3.75s/it]

11600 -20.0 -19.162353621640715


 12%|█▏        | 11701/100000 [10:12:41<97:07:46,  3.96s/it]

11700 -18.0 -19.297187984252204


 12%|█▏        | 11801/100000 [10:19:43<105:45:09,  4.32s/it]

11800 -19.0 -18.952412204120705


 12%|█▏        | 11901/100000 [10:27:13<116:03:48,  4.74s/it]

11900 -17.0 -18.961640556610273


 12%|█▏        | 12001/100000 [10:34:41<116:50:23,  4.78s/it]

12000 -18.0 -18.88424501905022


 12%|█▏        | 12101/100000 [10:41:47<96:03:08,  3.93s/it] 

12100 -20.0 -19.015095008525826


 12%|█▏        | 12201/100000 [10:49:08<110:19:11,  4.52s/it]

12200 -16.0 -18.954355549051684


 12%|█▏        | 12301/100000 [10:56:16<104:23:53,  4.29s/it]

12300 -21.0 -19.07705080088033


 12%|█▏        | 12401/100000 [11:03:43<113:51:18,  4.68s/it]

12400 -18.0 -19.00770316609197


 13%|█▎        | 12501/100000 [11:11:19<112:03:20,  4.61s/it]

12500 -19.0 -19.108406797046094


 13%|█▎        | 12601/100000 [11:18:51<107:30:18,  4.43s/it]

12600 -18.0 -18.986223397075317


 13%|█▎        | 12701/100000 [11:26:16<100:59:27,  4.16s/it]

12700 -18.0 -18.971417352129425


 13%|█▎        | 12801/100000 [11:33:53<118:27:32,  4.89s/it]

12800 -19.0 -18.95009721735541


 13%|█▎        | 12901/100000 [11:41:33<112:10:53,  4.64s/it]

12900 -16.0 -18.846221220150444


 13%|█▎        | 13001/100000 [11:49:09<118:30:00,  4.90s/it]

13000 -19.0 -18.83733176616729


 13%|█▎        | 13101/100000 [11:56:45<111:10:06,  4.61s/it]

13100 -21.0 -18.789849365593938


 13%|█▎        | 13201/100000 [12:04:37<112:01:17,  4.65s/it]

13200 -19.0 -18.67963490956379


 13%|█▎        | 13301/100000 [12:12:36<121:48:13,  5.06s/it]

13300 -18.0 -18.658215365832167


 13%|█▎        | 13401/100000 [12:20:26<120:55:49,  5.03s/it]

13400 -19.0 -18.62763970705473


 14%|█▎        | 13501/100000 [12:28:25<108:19:48,  4.51s/it]

13500 -19.0 -18.682398026268135


 14%|█▎        | 13601/100000 [12:36:35<117:16:57,  4.89s/it]

13600 -21.0 -18.701567850395378


 14%|█▎        | 13701/100000 [12:44:27<115:56:11,  4.84s/it]

13700 -15.0 -18.618866085265868


 14%|█▍        | 13801/100000 [12:52:45<118:29:45,  4.95s/it]

13800 -19.0 -18.534094573147886


 14%|█▍        | 13901/100000 [13:00:50<109:27:00,  4.58s/it]

13900 -20.0 -18.72356416758137


 14%|█▍        | 14001/100000 [13:09:04<120:44:24,  5.05s/it]

14000 -15.0 -18.491482503674764


 14%|█▍        | 14101/100000 [13:17:13<116:18:28,  4.87s/it]

14100 -18.0 -18.556630481858402


 14%|█▍        | 14201/100000 [13:25:28<119:31:19,  5.01s/it]

14200 -15.0 -18.457019463489434


 14%|█▍        | 14301/100000 [13:33:55<119:10:32,  5.01s/it]

14300 -20.0 -18.407304788296674


 14%|█▍        | 14401/100000 [13:42:08<120:34:34,  5.07s/it]

14400 -17.0 -18.471365883552647


 15%|█▍        | 14501/100000 [13:50:35<123:38:16,  5.21s/it]

14500 -19.0 -18.50714699228112


 15%|█▍        | 14601/100000 [13:58:53<119:26:57,  5.04s/it]

14600 -18.0 -18.65223816470362


 15%|█▍        | 14701/100000 [14:07:07<105:49:09,  4.47s/it]

14700 -19.0 -18.53083624762402


 15%|█▍        | 14801/100000 [14:15:19<108:25:59,  4.58s/it]

14800 -20.0 -18.68864039426268


 15%|█▍        | 14901/100000 [14:23:50<122:14:58,  5.17s/it]

14900 -17.0 -18.22694875260322


 15%|█▌        | 15001/100000 [14:32:30<120:51:10,  5.12s/it]

15000 -20.0 -18.259607722491463


 15%|█▌        | 15101/100000 [14:41:11<119:46:49,  5.08s/it]

15100 -21.0 -18.17333940450969


 15%|█▌        | 15201/100000 [14:49:48<120:30:24,  5.12s/it]

15200 -16.0 -18.172880397190383


 15%|█▌        | 15301/100000 [14:58:26<122:15:41,  5.20s/it]

15300 -20.0 -18.381160465520633


 15%|█▌        | 15401/100000 [15:06:55<115:20:08,  4.91s/it]

15400 -21.0 -18.37279512915695


 16%|█▌        | 15501/100000 [15:15:28<107:49:04,  4.59s/it]

15500 -18.0 -18.469384568787575


 16%|█▌        | 15601/100000 [15:24:06<128:27:46,  5.48s/it]

15600 -19.0 -18.255604332130996


 16%|█▌        | 15701/100000 [15:32:54<120:54:24,  5.16s/it]

15700 -21.0 -18.353974773488037


 16%|█▌        | 15801/100000 [15:41:19<133:54:39,  5.73s/it]

15800 -16.0 -18.309910984659957


 16%|█▌        | 15901/100000 [15:50:08<126:21:41,  5.41s/it]

15900 -20.0 -18.317576223085943


 16%|█▌        | 16001/100000 [15:59:10<129:27:08,  5.55s/it]

16000 -19.0 -18.11409909458799


 16%|█▌        | 16101/100000 [16:08:25<130:23:52,  5.60s/it]

16100 -16.0 -17.902590026217172


 16%|█▌        | 16201/100000 [16:17:12<129:24:49,  5.56s/it]

16200 -20.0 -18.25310238256096


 16%|█▋        | 16301/100000 [16:26:18<123:23:00,  5.31s/it]

16300 -16.0 -18.1035564763608


 16%|█▋        | 16401/100000 [16:35:00<128:36:11,  5.54s/it]

16400 -18.0 -18.032619290934786


 17%|█▋        | 16501/100000 [16:43:59<125:31:17,  5.41s/it]

16500 -19.0 -18.061530033428117


 17%|█▋        | 16601/100000 [16:53:16<143:10:39,  6.18s/it]

16600 -18.0 -18.07320956863183


 17%|█▋        | 16701/100000 [17:02:22<129:40:12,  5.60s/it]

16700 -18.0 -18.13429465197164


 17%|█▋        | 16801/100000 [17:11:48<137:02:20,  5.93s/it]

16800 -16.0 -17.950328846831916


 17%|█▋        | 16901/100000 [17:21:06<121:10:42,  5.25s/it]

16900 -16.0 -18.036834559648213


 17%|█▋        | 17001/100000 [17:30:39<130:29:50,  5.66s/it]

17000 -20.0 -17.98256441228991


 17%|█▋        | 17101/100000 [17:40:07<141:35:39,  6.15s/it]

17100 -20.0 -17.835640140848238


 17%|█▋        | 17201/100000 [17:49:33<127:43:42,  5.55s/it]

17200 -18.0 -17.736191133382867


 17%|█▋        | 17301/100000 [17:59:09<142:51:51,  6.22s/it]

17300 -17.0 -17.81957392263824


 17%|█▋        | 17401/100000 [18:08:46<134:13:01,  5.85s/it]

17400 -19.0 -17.786226450091725


 18%|█▊        | 17501/100000 [18:18:24<120:47:29,  5.27s/it]

17500 -19.0 -17.905132003144384


 18%|█▊        | 17601/100000 [18:27:48<129:08:52,  5.64s/it]

17600 -19.0 -17.82279653411278


 18%|█▊        | 17701/100000 [18:37:15<131:18:46,  5.74s/it]

17700 -16.0 -17.81919869002885


 18%|█▊        | 17801/100000 [18:47:00<121:42:07,  5.33s/it]

17800 -17.0 -17.83437524270996


 18%|█▊        | 17901/100000 [18:56:50<133:15:32,  5.84s/it]

17900 -13.0 -17.954957704530727


 18%|█▊        | 18001/100000 [19:06:30<112:48:14,  4.95s/it]

18000 -16.0 -17.95254771815579


 18%|█▊        | 18101/100000 [19:16:11<154:16:24,  6.78s/it]

18100 -19.0 -18.03305470574161


 18%|█▊        | 18201/100000 [19:26:02<132:37:08,  5.84s/it]

18200 -15.0 -17.89126419990739


 18%|█▊        | 18301/100000 [19:35:31<123:18:27,  5.43s/it]

18300 -18.0 -17.714606670396975


 18%|█▊        | 18401/100000 [19:44:46<119:05:12,  5.25s/it]

18400 -18.0 -17.592888394741255


 19%|█▊        | 18501/100000 [19:54:07<132:50:04,  5.87s/it]

18500 -20.0 -17.5918300453632


 19%|█▊        | 18601/100000 [20:03:51<130:58:44,  5.79s/it]

18600 -19.0 -17.591067143433943


 19%|█▊        | 18701/100000 [20:14:01<151:49:13,  6.72s/it]

18700 -14.0 -17.345188488929505


 19%|█▉        | 18801/100000 [20:24:19<154:58:21,  6.87s/it]

18800 -14.0 -17.412021250107486


 19%|█▉        | 18901/100000 [20:34:55<140:34:09,  6.24s/it]

18900 -20.0 -17.278827936048526


 19%|█▉        | 19001/100000 [20:45:12<148:15:20,  6.59s/it]

19000 -15.0 -17.30261796752017


 19%|█▉        | 19101/100000 [20:56:08<167:37:33,  7.46s/it]

19100 -17.0 -17.352504879344185


 19%|█▉        | 19201/100000 [21:06:22<113:41:56,  5.07s/it]

19200 -21.0 -17.493304989587354


 19%|█▉        | 19301/100000 [21:16:22<122:05:26,  5.45s/it]

19300 -19.0 -17.530946778309488


 19%|█▉        | 19401/100000 [21:27:02<137:13:45,  6.13s/it]

19400 -18.0 -17.378557895954884


 20%|█▉        | 19501/100000 [21:37:50<140:15:47,  6.27s/it]

19500 -15.0 -17.34303230274575


 20%|█▉        | 19601/100000 [21:48:32<157:56:45,  7.07s/it]

19600 -15.0 -17.476562748894082


 20%|█▉        | 19701/100000 [21:59:28<141:17:43,  6.33s/it]

19700 -20.0 -17.197450510283968


 20%|█▉        | 19801/100000 [22:09:58<131:22:32,  5.90s/it]

19800 -21.0 -17.392861967236875


 20%|█▉        | 19901/100000 [22:20:22<121:41:51,  5.47s/it]

19900 -17.0 -17.436715158985134


 20%|██        | 20001/100000 [22:31:20<122:56:58,  5.53s/it]

20000 -19.0 -17.240440539383112


 20%|██        | 20101/100000 [22:42:22<143:42:25,  6.47s/it]

20100 -12.0 -16.999565117951178


 20%|██        | 20201/100000 [22:53:47<157:19:46,  7.10s/it]

20200 -18.0 -16.820289112811434


 20%|██        | 20301/100000 [23:04:50<141:34:48,  6.40s/it]

20300 -17.0 -17.02229659412314


 20%|██        | 20401/100000 [23:16:05<152:14:15,  6.89s/it]

20400 -17.0 -16.988957441990085


 21%|██        | 20501/100000 [23:27:14<174:56:27,  7.92s/it]

20500 -16.0 -16.705847243729043


 21%|██        | 20601/100000 [23:38:09<141:31:59,  6.42s/it]

20600 -15.0 -16.97086473214362


 21%|██        | 20701/100000 [23:49:24<136:11:32,  6.18s/it]

20700 -18.0 -16.828417257646528


 21%|██        | 20801/100000 [23:59:59<154:09:50,  7.01s/it]

20800 -11.0 -17.154996387089916


 21%|██        | 20901/100000 [24:11:33<166:21:08,  7.57s/it]

20900 -17.0 -16.829318899440953


 21%|██        | 21001/100000 [24:23:11<154:28:24,  7.04s/it]

21000 -14.0 -16.97272660565705


 21%|██        | 21100/100000 [24:34:24<151:07:07,  6.90s/it]

21100 -11.0 -16.76337492814893


OSError: [Errno 28] No space left on device

In [None]:
torch.save(net.state_dict(), weight_join("final.pth"))

In [None]:
plt.plot(trainer.total_rewards)

In [None]:
writer_path