## Imports

In [6]:
%matplotlib inline

import time

import gym
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple
from itertools import count
from copy import deepcopy
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as T



## Initialization

In [7]:
# Detect CUDA availability once; the agent code below reads this flag to
# decide whether networks and input tensors are moved to the GPU.
use_gpu = torch.cuda.is_available()
print('Use GPU: {}'.format(use_gpu))

Use GPU: True


## Replay memory

In [8]:
from collections import deque
import random

class ReplayMemory:
    '''
    Bounded FIFO store of (state, action, reward, done, next_state)
    transitions that can be sampled uniformly in batches.

    Once `capacity` transitions are held, every new insert discards the
    oldest one.
    '''
    def __init__(self, capacity):
        '''
        :param capacity: maximum number of transitions kept in memory
        '''
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, reward, done, next_state):
        '''
        Appends a single transition to the memory

        :param state: current state, atari_wrappers.LazyFrames object
        :param action: action taken in `state`
        :param reward: reward received for the action
        :param done: True when the episode finished with this transition
        :param next_state: resulting state, atari_wrappers.LazyFrames object
        '''
        self.buffer.append((state, action, reward, done, next_state))

    def sample(self, batch_size):
        '''
        Draws transitions uniformly at random, without replacement

        If fewer than `batch_size` transitions are stored, everything that
        is stored is returned (in random order).

        :param batch_size: requested number of transitions
        :return: tuple of (states, actions, rewards, dones, next states),
                 all numpy arrays. states and next states are expected to have
                 shape (batch_size, frames, width, height), where frames = 4;
                 actions, rewards and dones have shape (batch_size,)
        '''
        sample_size = min(batch_size, self.count())
        picked = random.sample(self.buffer, sample_size)

        def column(field, as_frames=False):
            # Frame stacks are converted one by one first so that lazy
            # containers are materialised before stacking.
            if as_frames:
                return np.array([np.array(entry[field]) for entry in picked])
            return np.array([entry[field] for entry in picked])

        return (column(0, True), column(1), column(2), column(3),
                column(4, True))

    def count(self):
        '''
        :return: number of transitions currently stored
        '''
        return len(self.buffer)


## Q-Network

In [9]:
class DQN(nn.Module):
    '''
    Deep Q-Network: three convolutional layers followed by two fully
    connected layers producing one Q value per action
    '''
    def __init__(self, num_actions):
        '''
        :param num_actions: size of the output layer (one Q value per action)
        '''
        super(DQN, self).__init__()

        # Convolutional feature extractor over a 4-channel input.
        self.conv1 = nn.Conv2d(in_channels=4, out_channels=32,
                               kernel_size=8, stride=4, padding=0)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64,
                               kernel_size=4, stride=2, padding=0)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64,
                               kernel_size=3, stride=1, padding=0)

        # The head expects 64 feature maps of 7 x 7, which is what the
        # convolutions above produce for an 84 x 84 input.
        flat_features = 64 * 7 * 7
        self.fc1 = nn.Linear(flat_features, 512)
        self.fc2 = nn.Linear(512, num_actions)

    def forward(self, inputs):
        '''
        Forward propagation

        :param inputs: images. expected shape is (batch_size, frames, width, height)
        :return: Q values, shape (batch_size, num_actions)
        '''
        features = inputs
        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))

        # Collapse everything except the batch dimension for the linear head.
        flat = features.view(features.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

## Agent

In [10]:
import os
from atari_wrappers import wrap_dqn
import datetime

class PongAgent:
    '''
    Pong agent. Owns the environment, the estimation and target Q-networks
    and the replay memory, and implements action selection, model loading
    and playing
    '''
    def __init__(self):
        '''
        Builds the environment, both networks, the replay memory, the loss
        and the optimizer, and makes sure the output directory exists
        '''
        self.env = wrap_dqn(gym.make('PongDeterministic-v0'))
        self.num_actions = self.env.action_space.n

        self.dqn = DQN(self.num_actions)
        self.target_dqn = DQN(self.num_actions)

        if use_gpu:
            self.dqn.cuda()
            self.target_dqn.cuda()

        self.buffer = ReplayMemory(1000000)

        self.gamma = 0.99

        self.mse_loss = nn.MSELoss()
        self.optim = optim.RMSprop(self.dqn.parameters(), lr=0.0001)

        self.out_dir = './model'

        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test.
        os.makedirs(self.out_dir, exist_ok=True)

    def to_var(self, x):
        '''
        Converts x to Variable, moving it to the GPU when one is in use

        :param x: torch Tensor
        :return: torch Variable
        '''
        x_var = Variable(x)
        if use_gpu:
            x_var = x_var.cuda()
        return x_var

    def predict_q_values(self, states):
        '''
        Computes Q values by passing states through the estimation network

        :param states: states, numpy array, the shape is (batch_size, frames, width, height)
        :return: Q values, Variable, the shape is (batch_size, num_actions)
        '''
        states = self.to_var(torch.from_numpy(states).float())
        return self.dqn(states)

    def select_action(self, state, epsilon):
        '''
        Selects an action according to the epsilon greedy policy: with
        probability epsilon a uniformly random action is returned, otherwise
        the action with the highest Q value predicted by the estimation
        network.

        :param state: state, atari_wrappers.LazyFrames object - list of 4 frames,
                      each is a shape of (1, width, height)
        :param epsilon: probability of choosing a random action; 0 always
                        uses the network, 1 always samples randomly

        :return: action index (int)
        '''
        if np.random.rand() < epsilon:
            return int(np.random.randint(self.num_actions))

        # Add the batch dimension expected by the network.
        batch = np.expand_dims(state, 0)
        q_values = self.predict_q_values(batch)
        return int(np.argmax(q_values.data.cpu().numpy()))

    def sync_target_network(self):
        '''
        Copies weights from estimation to target network
        '''
        # state_dict covers parameters and buffers alike, and
        # load_state_dict copies values into the target's own tensors.
        self.target_dqn.load_state_dict(self.dqn.state_dict())

    def load_model(self, filename):
        '''
        Loads model from the disk and mirrors it into the target network

        :param filename: model filename
        '''
        # Deserialize onto the CPU regardless of where the checkpoint was
        # saved, so a model trained on a GPU can be restored on a CPU-only
        # host; load_state_dict then copies the values onto whatever device
        # the network lives on.
        state_dict = torch.load(
            filename, map_location=lambda storage, location: storage)
        self.dqn.load_state_dict(state_dict)
        self.sync_target_network()

    def play(self, episodes):
        '''
        Plays the game and renders it

        :param episodes: number of episodes to play
        '''
        for _ in range(episodes):
            done = False
            state = self.env.reset()
            while not done:
                # epsilon = 0 forces the action to come from the network
                action = self.select_action(state, 0)
                state, _, done, _ = self.env.step(action)
                self.env.render()

    def close_env(self):
        '''
        Closes the environment. Should be called to clean-up
        '''
        self.env.close()
        

In [11]:
# Build the agent; the constructor creates the wrapped Pong environment,
# both Q-networks, the replay memory and the ./model output directory.
agent = PongAgent()

  result = entry_point.load(False)


## Load the model and play

In [12]:
# Restore saved weights from the checkpoint file, then play and render
# five episodes with actions taken from the network (epsilon = 0).
agent.load_model('./model_d_v0_ada/current_model_2000.pth')
agent.play(5)

In [13]:
# Close the environment once playing is finished.
agent.close_env()