In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from torch.optim.lr_scheduler import StepLR

import torchvision
import torchvision.datasets as dset
import torchvision.transforms as T

import os
import numpy as np
import timeit
import time
import platform
import random
import pickle as pickle
import matplotlib.pyplot as plt

%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

print("Python version: ", platform.python_version())

Python version:  3.6.4


## Actor-Critic Model with LSTM

In [35]:
class Policy(torch.nn.Module):

    def __init__(self, input_channels, num_actions):
        super(Policy, self).__init__()
        self.temperature = 1.0
        self.input_channels = input_channels
        self.num_actions = num_actions
        self.features = self._init_features()
        self.lstm = self._init_lstm()
        self.action_head = self._init_action_head()
        self.value_head = self._init_value_head()
        
        self.saved_actions = []
        self.rewards = []

    def _init_features(self):
        layers = []
        # 80 x 80 x in_channels initial dimensions 3D array
        layers.append(torch.nn.Conv2d(self.input_channels,
                                      16, kernel_size=8, stride=4, padding=2))
        layers.append(torch.nn.BatchNorm2d(16))
        layers.append(torch.nn.ReLU(inplace=True))
        # 20 x 20 x 16 feature maps
        layers.append(torch.nn.Conv2d(16,
                                      32, kernel_size=4, stride=2, padding=1))
        layers.append(torch.nn.BatchNorm2d(32))
        layers.append(torch.nn.ReLU(inplace=True))
        # 10 x 10 x 32 feature maps
        layers.append(torch.nn.Conv2d(32,
                                      32, kernel_size=4, stride=2, padding=0))
        layers.append(torch.nn.BatchNorm2d(32))
        layers.append(torch.nn.ReLU(inplace=True))
        # 4 x 4 x 32 feature maps
        return torch.nn.Sequential(*layers)
    
    def _init_lstm(self):
        return torch.nn.LSTMCell(32*4*4, 256)

    def _init_action_head(self):
        return torch.nn.Linear(256, self.num_actions)

    def _init_value_head(self):
        return torch.nn.Linear(256, 1)

    def forward(self, inputs):
        x, (hx, cx) = inputs
        x = self.features(x)
        x = x.view(x.size(0), -1)  # 1 x 512(4x4x32)
        
        hx, cx = self.lstm(x, (hx, cx))
        x = hx
        
        action = torch.nn.functional.softmax(self.action_head(x) /
                                             self.temperature, dim=-1)
        value = self.value_head(x)
        return action, value, (hx, cx)


In [36]:
model = Policy(4,6)
print (model)
    
x = torch.randn(1,4,80,80)
cx = Variable(torch.zeros(1, 256))
hx = Variable(torch.zeros(1, 256))

action, value, (hx, cx) = model((Variable(x), (hx, cx)))        # Feed it through the model! 

print (ans.shape)
print (value.shape)
print (hx.shape)
print (cx.shape)


Policy(
  (features): Sequential(
    (0): Conv2d (4, 16, kernel_size=(8, 8), stride=(4, 4), padding=(2, 2))
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU(inplace)
    (3): Conv2d (16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU(inplace)
    (6): Conv2d (32, 32, kernel_size=(4, 4), stride=(2, 2))
    (7): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)
    (8): ReLU(inplace)
  )
  (lstm): LSTMCell(512, 256)
  (action_head): Linear(in_features=256, out_features=6)
  (value_head): Linear(in_features=256, out_features=1)
)
torch.Size([1, 512])
torch.Size([1, 1])
torch.Size([1, 256])
torch.Size([1, 256])


In [14]:
from utils import preprocess_state
import gym

env = gym.make('Pong-v0')
num_frames = 4

state = env.reset()
state = preprocess_state(state)
state = np.stack([state]*num_frames)

num_frames, height, width = state.shape
state = torch.FloatTensor(state.reshape(-1, num_frames, height, width))
print (state.shape)
env.close()

torch.Size([1, 4, 80, 80])


torch.Size([1, 4, 80, 80])
