In [1]:
import gym
import numpy as np

from tqdm import trange

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributions import Categorical

In [2]:
# Gym environment id. This notebook uses the classic gym API:
# reset() returns only the observation and step() returns a 4-tuple.
ENV_ID = "Pong-v0"
env = gym.make(ENV_ID)

In [3]:
downsample = 2  # stride applied to both rows and columns of the cropped frame
# The crop keeps 160 rows (35:195) and, for a 160-pixel-wide screen, 160 columns.
# A strided slice of length 160 with step k has ceil(160 / k) elements, so use
# ceiling division; plain 160 // k under-counts whenever k does not divide 160.
output_size = -(-160 // downsample)


def preprocess(frame):
    '''Turn a raw Pong frame into a binary BCHW float tensor (after Karpathy).

    Crops rows 35:195 and keeps every ``downsample``-th row/column of channel 0.
    It then zeroes the two background values (144 and 109) and sets every other
    non-zero pixel (paddles, ball) to 1.

    Args:
        frame: observation array indexed as frame[row, col, channel]; presumably
            the 210x160x3 uint8 Atari screen -- TODO confirm for other envs.

    Returns:
        float32 tensor of shape (1, 1, H, W) with values in {0., 1.}. For a
        160-pixel-wide frame, H == W == ``output_size``. ``frame`` is not modified.
    '''
    # Strided slicing only creates a *view* of ``frame``. Assigning into it
    # would overwrite the caller's observation, so build a fresh mask instead.
    plane = frame[35:195:downsample, ::downsample, 0]
    foreground = (plane != 0) & (plane != 144) & (plane != 109)
    tensor = torch.from_numpy(foreground.astype(np.float32))
    return tensor.unsqueeze(0).unsqueeze(0)  # BCHW

# def clip_grads(net, low=-10, high=10):
#     """Gradient clipping to the range [low, high]."""
#     parameters = [param for param in net.parameters()
#                   if param.grad is not None]
#     for p in parameters:
#         p.grad.data.clamp_(low, high)
        
_USE_CUDA = torch.cuda.is_available()  # decided once, when this cell runs


def to_var(x, requires_grad=False, gpu=None, vgpu=None):
    '''Wrap tensor ``x`` in a ``Variable``, moving it to the GPU when CUDA is available.

    A single definition is used so the same keyword arguments work on CPU-only
    and CUDA machines.

    Args:
        x: a torch tensor.
        requires_grad: forwarded to ``Variable``.
        gpu: device index passed to ``x.cuda``; ``None`` means the current CUDA
            device. Ignored when CUDA is unavailable.
        vgpu: backward-compatible alias for ``gpu`` (the old CPU-only signature
            used this spelling). ``gpu`` takes precedence if both are given.

    Returns:
        ``Variable(x, requires_grad=requires_grad)``, placed on CUDA when available.
    '''
    if _USE_CUDA:
        x = x.cuda(vgpu if gpu is None else gpu)
    return Variable(x, requires_grad=requires_grad)

In [4]:
class Net(nn.Module):
    '''Small convolutional policy network in the spirit of the Nature DQN.

    Two strided conv + ReLU stages feed one linear layer that emits one
    unnormalised score (logit) per action.

    Args:
        action_n: number of discrete actions (width of the output layer).
        input_shape: (channels, height, width) of a single input.

    The attribute names ``conv`` / ``fc`` and the layer order inside ``conv``
    determine the state_dict keys, so keep them stable for checkpoint loading.
    '''

    def __init__(self, action_n, input_shape=(1, 80, 80)):
        super().__init__()
        in_channels = input_shape[0]
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=8, stride=2),
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=4, stride=2),
            nn.ReLU(),
        )
        # Size the linear layer by probing the conv stack with a dummy input.
        self.fc = nn.Linear(self._get_flatten_size(input_shape), action_n)

    def _get_flatten_size(self, shape):
        '''Return the number of features ``self.conv`` emits for one input of ``shape``.'''
        probe = Variable(torch.rand(1, *shape))
        return self.conv(probe).numel()

    def forward(self, x):
        '''Map a BCHW batch to a (batch, action_n) tensor of logits.'''
        features = self.conv(x)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

In [5]:
net = Net(env.action_space.n, input_shape=(1, output_size, output_size))

weights_path = "episode12100.pth"
# Always deserialise onto the CPU. That works for a checkpoint saved from any
# device; a {'cuda:0': 'cpu'} dict only remaps cuda:0. load_state_dict copies
# the tensors into the module's parameters, so they still end up on the GPU
# after net.cuda(). torch.load unpickles the file, so only load trusted checkpoints.
weights = torch.load(weights_path, map_location="cpu")
if torch.cuda.is_available():
    net = net.cuda()
net.load_state_dict(weights)
net.eval()  # the network is only used for inference below

In [6]:
N_EVAL_EPISODES = 10
MAX_STEPS = 100000  # hard cap on env steps per evaluation episode

progress = trange(N_EVAL_EPISODES)
for episode in progress:
    frame = env.reset()
    # The policy sees the difference of consecutive preprocessed frames. On the
    # first step both come from the reset frame, so that input is all zeros.
    last_obs = preprocess(frame)
    curr_obs = preprocess(frame)
    total_reward = 0
    for step in range(MAX_STEPS):
        with torch.no_grad():
            logits = net(to_var(curr_obs - last_obs))
            # Greedy evaluation: take the highest-scoring action as a plain
            # Python int (also moves it off the GPU) for env.step.
            _, best = logits.max(dim=1)
            action = best[0].item()
        frame, reward, done, _ = env.step(action)
        env.render()
        last_obs = curr_obs
        curr_obs = preprocess(frame)
        total_reward += reward
        if done:
            break
    else:
        # for/else: runs only when the loop was never broken, i.e. the step cap
        # was exhausted before the episode ended.
        progress.write("episode {}: stopped after {} steps without finishing".format(episode, MAX_STEPS))
    # tqdm.write keeps the progress bar intact instead of interleaving with print.
    progress.write("{} {}".format(episode, total_reward))

 10%|█         | 1/10 [01:23<12:29, 83.27s/it]

0 -1.0


 20%|██        | 2/10 [02:28<10:21, 77.75s/it]

1 -10.0


 30%|███       | 3/10 [03:33<08:38, 74.12s/it]

2 -7.0


 40%|████      | 4/10 [04:53<07:34, 75.81s/it]

3 -4.0


 50%|█████     | 5/10 [06:00<06:06, 73.28s/it]

4 -12.0


 60%|██████    | 6/10 [07:15<04:54, 73.72s/it]

5 -6.0


 70%|███████   | 7/10 [08:15<03:28, 69.40s/it]

6 -16.0


 80%|████████  | 8/10 [09:23<02:18, 69.17s/it]

7 -11.0


 90%|█████████ | 9/10 [10:35<01:09, 69.98s/it]

8 -8.0


100%|██████████| 10/10 [11:46<00:00, 70.22s/it]

9 -10.0



