In [1]:
import gym
import matplotlib.pyplot as plt
import torch
from wrappers import wrap_deepmind

In [2]:
import random
from collections import namedtuple

# One environment step as stored in the replay buffer.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward', 'ended'))


class ReplayMemory(object):
    """Fixed-capacity ring buffer of `Transition`s with uniform random sampling.

    Once `capacity` transitions are stored, each new push overwrites the oldest one.
    """

    def __init__(self, capacity):
        """
        :param capacity: maximum number of transitions kept (must be >= 1)
        :raises ValueError: if capacity < 1 (0 would make the ring index divide by zero)
        """
        if capacity < 1:
            raise ValueError("capacity must be at least 1")
        self.capacity = capacity
        self.memory = []   # stored transitions, never longer than `capacity`
        self.position = 0  # slot the next push writes to

    def push(self, *args):
        """Save Transition(*args), overwriting the oldest entry when the buffer is full."""
        # Build the transition first so a bad argument list cannot leave a
        # placeholder behind in the buffer.
        transition = Transition(*args)
        if len(self.memory) < self.capacity:
            # While filling up, `position` always equals len(self.memory).
            self.memory.append(transition)
        else:
            self.memory[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return `batch_size` distinct stored transitions chosen uniformly at random.

        :raises ValueError: if batch_size exceeds the number of stored transitions
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)


memory = ReplayMemory(5000)
# plt.imshow(memory.memory[0].state[:3,].permute(1,2,0))

In [3]:
import random


def random_agent(state, th=None, n_actions=None):
    """Pick an action uniformly at random; `state` and `th` are ignored.

    :param state: current observation (unused)
    :param th: unused; accepted because play_game calls agents as agent(state, th=...)
    :param n_actions: size of the discrete action space; when None, falls back
        to the global `env.action_space.n` (the original behaviour)
    :return: int in [0, n_actions - 1]
    """
    if n_actions is None:
        n_actions = env.action_space.n
    return random.randint(0, n_actions - 1)

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def n2t(vec):
    """Convert the NumPy array `vec` to a torch tensor placed on `device`."""
    as_tensor = torch.from_numpy(vec)
    return as_tensor.to(device)


def t2n(tensor):
    """Copy `tensor` to host memory and return it as a NumPy array."""
    on_host = tensor.cpu()
    return on_host.numpy()

In [5]:
def play_game(env=None, agent=None, th=0, skipframe=4, maxstep=5000, render=False):
    """Play one episode and push every transition into the global replay `memory`.

    :param env: environment to play in; when None a fresh
        wrap_deepmind(gym.make("Pong-v0"), frame_stack=True) is created for this call
    :param agent: callable invoked as agent(state, th=th) that returns an action
    :param th: forwarded unchanged to `agent`
    :param skipframe: number of consecutive env steps each chosen action is
        repeated for (values below 1 are treated as 1)
    :param maxstep: maximum number of agent decisions in the episode
    :param render: call env.render() after every decision when True
    :return: (cumulative reward, index of the last decision taken);
        (0.0, 0) when maxstep <= 0
    :raises TypeError: if no agent is given
    """
    if agent is None:
        raise TypeError("play_game() requires an agent callable")
    if env is None:
        # Created per call: as a default-argument value the env would be built
        # once at definition time and the same object shared by every call.
        env = wrap_deepmind(gym.make("Pong-v0"), frame_stack=True)
    n_repeats = max(1, skipframe)

    cum_reward = 0.0
    step = 0
    state = env.reset()
    for step in range(maxstep):
        action = agent(state, th=th)

        # Repeat the action, summing the reward of every repeated frame so none
        # is lost, and stop as soon as the episode ends rather than stepping a
        # finished episode.
        reward = 0.0
        for _ in range(n_repeats):
            next_state, frame_reward, ended, _info = env.step(action)
            reward += float(frame_reward)
            if ended:
                break
        cum_reward += reward

        memory.push(state, action, next_state, reward, ended)
        state = next_state

        if render:
            env.render()
        if ended:
            break
    return cum_reward, step

In [11]:
env = wrap_deepmind(gym.make("Breakout-v0"), frame_stack=True)
play_game(env=env, agent=random_agent, skipframe=4, render=False)

(0.0, 45)

## Train model

In [30]:
from importlib import reload

import numpy as np
import torch.nn.functional as F
from torch import nn, optim

import model as dqn_model  # project-local module providing DQN

# Reload so edits to model.py take effect without restarting the kernel.
reload(dqn_model)

LEARNING_RATE = 0.001

# The network instance is bound to `model` (the name train_batch uses), while
# the module keeps its own name instead of being overwritten by the instance.
model = dqn_model.DQN().to(device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)  # , weight_decay=0.001

In [31]:
# Training hyper-parameters read by train_batch.
param = dict(batch_size=8, GAMMA=0.9)

In [78]:
def default_states_preprocessor(states, target_device=None):
    """
    Stack a list of channel-last states into a float batch tensor for the model.

    Adapted from https://github.com/Shmuma/ptan/blob/master/ptan/agent.py

    :param states: non-empty list of array-likes (NumPy arrays or objects
        implementing __array__), each 3-D with channels in the last axis
    :param target_device: device for the result; defaults to the global `device`
    :return: float32 tensor of shape (N, C, H, W)
    :raises ValueError: if `states` is empty
    """
    # np.asarray rather than np.array(..., copy=False): under NumPy >= 2,
    # copy=False raises ValueError whenever a copy is unavoidable, and
    # stacking several states into one array always requires one.
    np_states = np.stack([np.asarray(s) for s in states])
    # Channel-last (N, H, W, C) -> channel-first (N, C, H, W).
    batch = torch.tensor(np_states).permute(0, 3, 1, 2).float()
    return batch.to(device if target_device is None else target_device)


def train_batch(param):
    """Run one DQN optimisation step on a random minibatch from the global `memory`.

    Reads the globals `memory`, `model`, `optimizer` and `device`.

    :param param: dict with 'batch_size' (int) and 'GAMMA' (discount factor)
    :return: 0 if the buffer holds fewer than batch_size transitions, otherwise
        the detached Huber (smooth L1) loss tensor of this step
    """
    batch_size = param['batch_size']
    gamma = param['GAMMA']
    if len(memory) < batch_size:
        return 0
    batch = memory.sample(batch_size)

    batch_states = default_states_preprocessor([m.state for m in batch])
    batch_next_states = default_states_preprocessor([m.next_state for m in batch])
    # Everything that enters the loss must live on the same device as the
    # model output; rewards are forced to float32 to match the predictions.
    batch_ended = torch.tensor([bool(m.ended) for m in batch], dtype=torch.bool, device=device)
    batch_rewards = torch.tensor([float(m.reward) for m in batch], dtype=torch.float32, device=device)
    batch_actions = torch.tensor([int(m.action) for m in batch], dtype=torch.long, device=device)

    # TD target: r + gamma * max_a' Q(s', a'), with the bootstrap term left at
    # zero for terminal transitions.
    # NOTE(review): the same network produces both predictions and targets (no
    # separate target network) — consider adding one if training is unstable.
    with torch.no_grad():
        not_ended = ~batch_ended
        next_state_values = torch.zeros(batch_size, device=device)
        if not_ended.any():  # skip the forward pass when every transition is terminal
            next_state_values[not_ended] = model(batch_next_states[not_ended]).max(1)[0]
        expected_state_action_values = batch_rewards + gamma * next_state_values

    # Q(s, a) for the actions actually taken; squeeze(1) keeps a 1-D result
    # even when batch_size == 1.
    q_values = model(batch_states)
    state_action_values = q_values.gather(1, batch_actions.unsqueeze(1)).squeeze(1)

    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values)
    optimizer.zero_grad()
    loss.backward()
    # Clip the gradients (not the weights) element-wise to [-1, 1].
    torch.nn.utils.clip_grad_value_(model.parameters(), 1.0)
    optimizer.step()
    return loss.detach()

In [77]:
# Bounded so that Restart & Run All terminates; raise it for longer training.
N_EPISODES = 500

episode_rewards = []  # cumulative reward of each episode, for later inspection
for episode in range(N_EPISODES):
    # NOTE(review): actions come from random_agent, so the DQN being trained
    # never selects actions here, and it gets only one gradient step per episode.
    episode_reward, _ = play_game(env, agent=random_agent, skipframe=4, render=False)
    episode_rewards.append(episode_reward)
    train_batch(param)

KeyboardInterrupt: 