## Watch the Trained Agent

### 1. Start Environment 

In [1]:
import sys
# NOTE(review): machine-specific absolute path (presumably a local gym checkout);
# adjust or remove it when running on another machine.
# Raw string: the backslashes are literal. In a normal string literal `\G` and
# `\g` are invalid escape sequences and trigger a warning on newer Pythons.
sys.path.append(r'D:\Gym\gym')
import gym

%matplotlib inline

import pong_utils

# PongDeterministic does not contain random frameskip,
# so it is faster to train on than the vanilla Pong-v4 environment.
env = gym.make('PongDeterministic-v4')
# Print the action names reported by the unwrapped environment.
print("List of available actions: ", env.unwrapped.get_action_meanings())

List of available actions:  ['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE']


### 2. Start Model

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

class Policy(nn.Module):
    """Small convolutional policy network for stacked Pong frames.

    Two convolutions followed by two fully connected layers. The forward pass
    expects input of shape (N, 2, 80, 80) (two stacked frames as channels) and
    returns a tensor of shape (N, 1) whose values are squashed into (0, 1) by
    a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # Spatial size after a convolution without padding:
        #   new_size = (size - kernel_size) / stride + 1

        # 2 x 80 x 80 -> 4 x 38 x 38, since (80 - 6) / 2 + 1 = 38.
        # The 2 input channels are the stacked frames; this layer has no bias.
        self.conv1 = nn.Conv2d(
            in_channels=2, out_channels=4, kernel_size=6, stride=2, bias=False
        )

        # 4 x 38 x 38 -> 16 x 9 x 9, since (38 - 6) / 4 + 1 = 9.
        self.conv2 = nn.Conv2d(
            in_channels=4, out_channels=16, kernel_size=6, stride=4
        )

        # Number of features once the 16 x 9 x 9 conv output is flattened.
        self.size = 9 * 9 * 16

        # Fully connected head: flattened features -> 512 -> 1.
        self.fc1 = nn.Linear(in_features=self.size, out_features=512)
        self.fc2 = nn.Linear(in_features=512, out_features=1)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        """Run the network on a batch and return sigmoid outputs of shape (N, 1)."""
        conv_out = F.relu(self.conv2(F.relu(self.conv1(x))))

        # Collapse channel and spatial dimensions into one feature vector per sample.
        flat = conv_out.view(-1, self.size)

        hidden = F.relu(self.fc1(flat))
        return self.sig(self.fc2(hidden))

# Build the policy network, then move its parameters onto the selected device.
model = Policy()
model = model.to(device)

### 3. Load Weights and Play, score = 1.77

Try out the solution: load the trained weights and watch the agent play.    
(The REINFORCE version can win more often than not!)

In [3]:
# Load the saved weights for this section into the policy network.
# map_location=device remaps the stored tensors onto the device the model
# lives on, so a checkpoint saved on a GPU also loads on a CPU-only machine.
model.load_state_dict(
    torch.load('REINFORCE_5A_2300ep_score_1-77.policy', map_location=device)
)
# Run the environment with the loaded policy (time=2000 is passed to pong_utils.play).
pong_utils.play(env, model, time=2000)

### 4. Load Weights and Play, score = 2.46
Try out the solution: load the trained weights and watch the agent play.
(The REINFORCE version can win more often than not!)

In [4]:
# Load the saved weights for this section into the policy network.
# map_location=device remaps the stored tensors onto the device the model
# lives on, so a checkpoint saved on a GPU also loads on a CPU-only machine.
model.load_state_dict(
    torch.load('REINFORCE_10_2000ep_score_2-46.policy', map_location=device)
)
# Run the environment with the loaded policy (time=2000 is passed to pong_utils.play).
pong_utils.play(env, model, time=2000)

### 5. Load Weights and Play, score = 4.82

In [5]:
# Load the saved weights for this section into the policy network.
# map_location=device remaps the stored tensors onto the device the model
# lives on, so a checkpoint saved on a GPU also loads on a CPU-only machine.
model.load_state_dict(
    torch.load('REINFORCE_11_2300ep_score_4-82.policy', map_location=device)
)
# Run the environment with the loaded policy (time=2000 is passed to pong_utils.play).
pong_utils.play(env, model, time=2000)