In [2]:
import gym
import numpy as np

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage.measurements import center_of_mass

import random
from collections import namedtuple, deque

torch.manual_seed(0)

<torch._C.Generator at 0x7ff13819be30>

In [3]:
Transition = namedtuple('Transition',
                        ('state', 'action', 'reward', 'next_state', 'done'))


class ReplayMemory(object):

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def __len__(self):
        return len(self.memory)

    def push(self, *args):
        self.memory.append(Transition(*args))

    # Returns list of Transitions
    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def count(self):
        return len(self.memory)
    
    
class DQN(nn.Module):
    '''DQN'''

    def __init__(self, in_features=4, n_actions=6, device='cpu'):
        self.device = device

        super(DQN, self).__init__()
        self.gru = nn.GRU(in_features, 32, 2)
        self.linear = nn.Linear(32, n_actions)

    def forward(self, x):
        x,h = self.gru(x.float())
        x = F.leaky_relu(x)
        return self.linear(x)

In [6]:
def prepare_state(s, prev_s=None):
    if prev_s is None:
        prev_s = np.zeros((5,1,4))

    # Get rid of useless rows and the green & blue colour chanels
    reduced_rows = s[34:194, :, 0]

    # Background is 0, paddles & ball is 1 (R value of backround 144)
    masked = (reduced_rows != 144).astype(int)

    # Center of our paddle (dqn paddle), opponent paddle, and ball y and x coordinates
    dqn_y, _ = center_of_mass(masked[:, 140:144])
    opp_y, _ = center_of_mass(masked[:, 16:20])
    ball_y, ball_x = center_of_mass(masked[:, 20:140])

    dqn_y = 80 if np.isnan(dqn_y) else dqn_y
    opp_y = 80 if np.isnan(opp_y) else opp_y
    # x position of ball is offset by 21 px to the center of img
    ball_x = 80 if np.isnan(ball_x) else ball_x + 21
    ball_y = 80 if np.isnan(ball_y) else ball_y

    # Scale the positions to [0, 10] and leave velocities in [0,4]
    # Hypothesis: Before, we were scaling an int in [0,160] to a float in [0,1]
    # Maybe this range is too small?
    state_vec = np.array([[[opp_y, dqn_y, ball_x, ball_y]]]) # / np.array([160, 160, 160, 160, 4, 4])
    state_vec = np.concatenate((state_vec, prev_s))[:5]

    return state_vec

In [4]:
# DQN takes input of shape (seq_len, batch, input_size)
dqn = DQN()
x = np.random.random((5,1,4))
y = dqn.forward(torch.from_numpy(x))
print(y)

tensor([[[ 0.0448, -0.1416, -0.1056, -0.1174, -0.1570,  0.1197]],

        [[ 0.0537, -0.1509, -0.1044, -0.1204, -0.1790,  0.1349]],

        [[ 0.0656, -0.1598, -0.1017, -0.1234, -0.2061,  0.1391]],

        [[ 0.0784, -0.1652, -0.1004, -0.1255, -0.2264,  0.1413]],

        [[ 0.0885, -0.1736, -0.0959, -0.1286, -0.2411,  0.1456]]],
       grad_fn=<AddBackward0>)


In [5]:
BATCH_SIZE = 64
GAMMA = 0.99
EPSILON_START = 0.3
EPSILON_FINAL = 0.02
EPSILON_DECAY = 1e6
TARGET_UPDATE = 100
lr = 1e-5
INITIAL_MEMORY = 10000
MEMORY_SIZE = 10 * INITIAL_MEMORY

policy_net = DQN()
target_net = DQN()
memory = ReplayMemory(MEMORY_SIZE)
optimizer = optim.Adam(policy_net.parameters(), lr=lr)



