# Deep Q-Network: Atari Games

In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tqdm.notebook import tqdm
from itertools import count
import os

from PIL import Image
import matplotlib.pyplot as plt

In [10]:
import gym
from gym import Wrapper, ObservationWrapper, RewardWrapper
from collections import deque
import numpy as np
import cv2


class MaxFrame(ObservationWrapper):
    """Return the element-wise maximum of the two most recent frames.

    The wrapper keeps the last two observations produced by the wrapped
    environment and, on every step, emits ``np.maximum`` of both.
    """

    def __init__(self, env):
        super(MaxFrame, self).__init__(env)
        # Only the current and the previous frame are ever needed.
        self.frames = deque(maxlen=2)

    def reset(self):
        """Reset the environment and re-seed the frame buffer.

        When the action at index 1 is ``'FIRE'`` that action is issued once
        after the reset, and the observation it produces is the one that is
        returned (and buffered), so the caller sees the state it will
        actually act on.

        Returns:
            The first observation of the new episode.
        """
        observation = self.env.reset()

        meanings = self.env.unwrapped.get_action_meanings()
        if len(meanings) > 1 and meanings[1] == 'FIRE':
            observation, _, done, _ = self.env.step(1)
            if done:
                # The FIRE step ended the episode; start over.
                observation = self.env.reset()

        # Seed the buffer with the real frame rather than np.zeros(...):
        # zero-filled placeholders default to float64 and would promote
        # every max-pooled frame to float64 instead of the frame's dtype.
        self.frames.clear()
        for _ in range(2):
            self.frames.append(observation)
        return observation

    def observation(self, observation):
        """Buffer ``observation`` and return the max over the last two."""
        self.frames.append(observation)
        return np.maximum(self.frames[0], self.frames[1])


# repeat action
# frame skipping
class RepeatAction(Wrapper):
    """Repeat each chosen action for ``repeat`` consecutive frames.

    Rewards collected during the repeated frames are summed. The repetition
    stops early when the wrapped environment reports ``done`` or when the
    emulator's life counter drops, in which case ``done`` is reported too.
    """

    def __init__(self, env, repeat=4):
        super(RepeatAction, self).__init__(env)
        self.repeat = repeat

        # Emulator handle, queried for the number of remaining lives.
        self.ale = env.unwrapped.ale
        self.lives = 0

    def step(self, action):
        """Apply ``action`` up to ``repeat`` times and accumulate the reward.

        Returns:
            ``(observation, summed_reward, done, info)`` of the last frame
            that was actually executed.
        """
        total_reward = 0
        for _ in range(self.repeat):
            observation, reward, done, info = self.env.step(action)
            total_reward += reward

            # A lost life is treated like the end of an episode.
            lives_now = self.ale.lives()
            life_lost = lives_now < self.lives
            self.lives = lives_now
            done = done or life_lost

            if done:
                return observation, total_reward, done, info
        return observation, total_reward, done, info

    def reset(self):
        """Reset the wrapped environment and record the starting lives."""
        first_observation = self.env.reset()
        self.lives = self.ale.lives()
        return first_observation


# convert to grayscale,
# resize to the requested (1, height, width) shape
# and rescale pixel values to the range between 0 and 1
class PreprocessImage(ObservationWrapper):
    """Turn RGB frames into small, normalised grayscale images.

    Args:
        env: environment whose observations are RGB images.
        shape: target observation shape, indexed as
            ``(channels, height, width)``; the converted frame is reshaped
            to exactly this shape.
    """

    def __init__(self, env, shape):
        super(PreprocessImage, self).__init__(env)

        self.observation_space = gym.spaces.Box(
            0.0, 1.0, shape=shape, dtype=np.float32)

    def observation(self, observation):
        """Return ``observation`` as a float32 array of the target shape.

        Pixel intensities are divided by 255 so that 8-bit input lands in
        the [0, 1] range declared by ``observation_space``.
        """
        height, width = self.observation_space.shape[1:]
        observation = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        # cv2.resize expects dsize as (width, height), the reverse of the
        # numpy (height, width) ordering used by observation_space.shape.
        observation = cv2.resize(
            observation, (width, height), interpolation=cv2.INTER_AREA)
        observation = observation.reshape(
            self.observation_space.shape).astype('float32')
        # Dividing by observation_space.high (== 1.0) left the values in
        # 0..255; divide by the 8-bit maximum instead.
        observation = observation / 255.0
        return observation


# stack n frames (4 was used)
class StackFrames(ObservationWrapper):
    """Stack the ``maxlen`` most recent observations along the first axis."""

    def __init__(self, env, maxlen=4):
        super(StackFrames, self).__init__(env)
        self.maxlen = maxlen
        self.frames = deque(maxlen=maxlen)

        # The stacked space repeats the wrapped space's bounds maxlen times
        # along axis 0.
        inner_space = self.env.observation_space
        self.observation_space = gym.spaces.Box(
            low=inner_space.low.repeat(maxlen, axis=0),
            high=inner_space.high.repeat(maxlen, axis=0),
            dtype=inner_space.dtype)

    def reset(self):
        """Reset the environment and fill the stack with the first frame."""
        first_frame = self.env.reset()
        self.frames.extend([first_frame] * self.maxlen)
        return np.vstack(self.frames)

    def observation(self, observation):
        """Push ``observation`` onto the stack and return the stacked frames."""
        self.frames.append(observation)
        return np.vstack(self.frames)


def apply_wrappers(env, shape=(1, 88, 88), repeat=4, stack=4):
    """Wrap ``env`` with the preprocessing pipeline used for training.

    The wrappers are applied innermost-first: ``MaxFrame``, ``RepeatAction``,
    ``PreprocessImage`` and ``StackFrames``.

    Args:
        env: the environment to wrap.
        shape: per-frame observation shape passed to ``PreprocessImage``.
        repeat: number of frames each action is repeated for.
        stack: number of preprocessed frames stacked into one observation.

    Returns:
        The fully wrapped environment.
    """
    env = MaxFrame(env)
    env = RepeatAction(env, repeat=repeat)
    env = PreprocessImage(env, shape=shape)
    env = StackFrames(env, maxlen=stack)
    return env


In [3]:
#replay buffer for memory replay
class ReplayBuffer():
    '''
    Fixed-capacity circular store of (obs, action, reward, next_obs, done)
    transitions. Once mem_size transitions are stored, new ones overwrite
    the oldest. Random batches of batch_size transitions can be drawn for
    training.
    '''

    def __init__(self, input_size, mem_size=10000, batch_size=64):
        self.mem_size = mem_size
        self.index = 0  # slot the next transition is written to
        self.size = 0   # number of valid transitions currently stored
        self.batch_size = batch_size

        self.obs_memory = np.empty((self.mem_size, *input_size), dtype=np.float32)
        self.action_memory = np.empty((self.mem_size), dtype=np.int64)
        self.reward_memory = np.empty((self.mem_size), dtype=np.float32)
        self.next_obs_memory = np.empty((self.mem_size, *input_size), dtype=np.float32)
        # np.bool was removed from NumPy; np.bool_ is the boolean scalar type.
        self.terminal_memory = np.empty((self.mem_size), dtype=np.bool_)

    def add_memory(self, obs, action, reward, next_obs, done):
        '''
        Store one transition, overwriting the oldest one when full.
        '''
        self.obs_memory[self.index] = obs
        self.action_memory[self.index] = action
        self.reward_memory[self.index] = reward
        self.next_obs_memory[self.index] = next_obs
        self.terminal_memory[self.index] = done
        self.index = (self.index + 1) % self.mem_size
        # The write index wraps around, so the fill level is tracked
        # separately; otherwise a full buffer would look empty again.
        self.size = min(self.size + 1, self.mem_size)

    def get_memory_batch(self):
        '''
        Draw batch_size distinct stored transitions uniformly at random.

        Returns (obss, actions, rewards, next_obss, dones) as numpy arrays.
        Raises ValueError when fewer than batch_size transitions are stored.
        '''
        stored = len(self)
        if stored < self.batch_size:
            raise ValueError(
                f'cannot sample {self.batch_size} transitions from a buffer '
                f'holding only {stored}')
        idxs = np.random.choice(stored, self.batch_size, replace=False)

        obss = self.obs_memory[idxs]
        actions = self.action_memory[idxs]
        rewards = self.reward_memory[idxs]
        next_obss = self.next_obs_memory[idxs]
        dones = self.terminal_memory[idxs]

        return obss, actions, rewards, next_obss, dones

    def __len__(self):
        return self.size

In [4]:
#function approximator of the Q-Function
class CNN_QNN(nn.Module):
    '''
    A PyTorch based neural network approximator of the Q-function.
    Three convolutional layers are followed by two fully connected layers;
    the first fully connected layer is sized automatically from input_size
    and the last one has one output per action.
    The model assumes images (e.g. atari screen) as input.

    input_size:  (channels, height, width) of a single observation
    action_size: number of discrete actions (size of the output layer)
    lr:          learning rate of the RMSprop optimizer
    save_dir:    directory the weights are saved to / loaded from
    name:        file name of the weights inside save_dir
    '''

    def __init__(self, input_size, action_size, lr, save_dir, name):
        super(CNN_QNN, self).__init__()
        self.input_size = input_size
        self.output_file = os.path.join(save_dir, name)

        self.conv1 = nn.Conv2d(input_size[0], 32, 8, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 4, 2)
        self.conv3 = nn.Conv2d(64, 64, 3, 1)

        fc_dims = self.num_flat_features(input_size)

        self.fc1 = nn.Linear(fc_dims, 512)
        self.fc2 = nn.Linear(512, action_size)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.loss = nn.MSELoss()
        self.optimizer = optim.RMSprop(self.parameters(), lr=lr)
        self.to(self.device)

    def num_flat_features(self, input_size):
        '''
        Return the number of values per sample leaving the last conv layer.
        '''
        # Shape probe only: no gradients needed, and the dummy input must
        # live on the same device as the convolution weights.
        with torch.no_grad():
            x = torch.zeros(1, *input_size, device=self.conv1.weight.device)
            x = self.conv3(self.conv2(self.conv1(x)))
        return x[0].numel()

    def forward(self, state):
        '''
        Map a batch of observations to one Q-value per action.
        '''
        x = state
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return x

    def save(self):
        '''
        Write the network weights to output_file, creating its directory.
        '''
        directory = os.path.dirname(self.output_file)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self.state_dict(), self.output_file)

    def load(self):
        '''
        Load the network weights from output_file onto this model's device.
        '''
        # map_location lets weights saved on a GPU be restored on a
        # CPU-only machine (and vice versa).
        self.load_state_dict(
            torch.load(self.output_file, map_location=self.device))

In [5]:
class DQN:
    '''
    DQN Agent combining a memory buffer and separate online and target neural networks.

    The online network is trained on batches drawn from the replay buffer;
    the target network provides the bootstrap values and is only updated by
    copying the online weights via replace_target_network().
    Exploration is epsilon-greedy, with epsilon lowered by epsilon_decay on
    every learning step until it reaches min_epsilon.
    '''

    def __init__(self, input_size, action_size, min_epsilon, max_epsilon, epsilon_decay,
                 gamma, lr, mem_size, batch_size, save_dir, name):
        self.replay_buffer = ReplayBuffer(input_size, mem_size, batch_size)
        self.qnn_target = CNN_QNN(input_size, action_size, lr, save_dir, name=name+'_target.pt')
        self.qnn_online = CNN_QNN(input_size, action_size, lr, save_dir, name=name+'_online.pt')
        # Both networks start from identical weights.
        self.replace_target_network()

        self.epsilon = max_epsilon
        self.min_epsilon = min_epsilon
        self.epsilon_decay = epsilon_decay
        self.gamma = gamma
        self.action_size = action_size
        self.batch_size = batch_size

    def epsilon_greedy(self, obs):
        '''
        Return a random action with probability epsilon, else the greedy one.
        '''
        if np.random.random() > self.epsilon:
            action = self.greedy(obs)
        else:
            action = np.random.choice(self.action_size)
        return action

    def greedy(self, obs):
        '''
        Return the action with the highest online Q-value for a single obs.
        '''
        with torch.no_grad():
            # The network expects a batch dimension in front.
            obs = np.expand_dims(obs, axis=0)
            obs = torch.from_numpy(obs).to(self.qnn_online.device).float()
            action = np.argmax(self.qnn_online(obs).cpu().numpy())
            return action

    def decrement_epsilon(self):
        '''
        Lower epsilon by epsilon_decay, never going below min_epsilon.
        '''
        if self.epsilon <= self.min_epsilon:
            return

        epsilon = self.epsilon - self.epsilon_decay
        self.epsilon = max(epsilon, self.min_epsilon)

    def add_memory(self, obs, action, reward, next_obs, done):
        '''
        Store one transition in the replay buffer.
        '''
        self.replay_buffer.add_memory(obs, action, reward, next_obs, done)

    def get_memory_batch(self):
        '''
        Sample a batch from the replay buffer as tensors on the online
        network's device.
        '''
        obss, actions, rewards, next_obss, dones = self.replay_buffer.get_memory_batch()
        device = self.qnn_online.device
        obss = torch.from_numpy(obss).to(device)
        actions = torch.from_numpy(actions).to(device)
        rewards = torch.from_numpy(rewards).to(device)
        next_obss = torch.from_numpy(next_obss).to(device)
        dones = torch.from_numpy(dones).to(device)
        return obss, actions, rewards, next_obss, dones

    def learn(self):
        '''
        Run one gradient step on the online network and decay epsilon.
        Does nothing until the replay buffer holds at least batch_size
        transitions.
        '''
        if len(self.replay_buffer) < self.batch_size:
            return

        self.qnn_online.optimizer.zero_grad()
        obss, actions, rewards, next_obss, dones = self.get_memory_batch()
        with torch.no_grad():
            # Bootstrap from the target network; terminal transitions
            # contribute the reward only.
            next_values = torch.max(self.qnn_target(next_obss), dim=1)[0]
            target = rewards + self.gamma * next_values * torch.logical_not(dones)
        target = target.unsqueeze(1)
        # Q-value of the action that was actually taken in each transition.
        online = self.qnn_online(obss).gather(dim=1, index=actions.unsqueeze(1))

        loss = self.qnn_online.loss(online, target)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.qnn_online.parameters(), 1.0)

        self.qnn_online.optimizer.step()
        self.decrement_epsilon()

    def save(self):
        '''
        Save the online network's weights.
        '''
        self.qnn_online.save()

    def load(self):
        '''
        Load the online network's weights and copy them to the target
        network, so both networks agree after restoring a checkpoint.
        '''
        self.qnn_online.load()
        self.replace_target_network()

    def replace_target_network(self):
        '''
        Copy the online network's weights into the target network.
        '''
        self.qnn_target.load_state_dict(self.qnn_online.state_dict())


In [1]:
# Environments to run. Each entry holds the gym id ('NAME'), whether the
# environment's .unwrapped version is used ('UNWRAPPED') and the mean reward
# at which training stops ('SOLVED_REWARD').
ENVS = [
    dict(NAME='PongNoFrameskip-v4', UNWRAPPED=False, SOLVED_REWARD=15),
]

In [13]:
# Training length
EPISODES = 500

# Epsilon-greedy exploration: epsilon starts at MAX_EPSILON and is lowered
# by EPSILON_DECAY on every learning step until it reaches MIN_EPSILON.
MAX_EPSILON = 1
MIN_EPSILON = 0.1
EPSILON_DECAY = 1e-5

# Learning
GAMMA = 0.99            # discount factor
LEARNING_RATE = 1e-4    # optimizer step size
BATCH_SIZE = 32         # transitions per gradient step
MEMORY_SIZE = 50_000    # replay buffer capacity
REPLACE_TARGET = 1_000  # steps between target-network updates

# Output locations
SAVE_DIR = './progress'
RESULTS_DIR = './results'

In [14]:
#Main loop
for env in ENVS:

    NAME = env['NAME']
    UNWRAPPED = env['UNWRAPPED']
    SOLVED_REWARD = env['SOLVED_REWARD']

    print(f'--- TRAINING ENVIRONMENT {NAME}---\n')

    ENV = gym.make(NAME).unwrapped if UNWRAPPED else gym.make(NAME)
    ENV = apply_wrappers(ENV)
    ACTION_SIZE = ENV.action_space.n
    INPUT_SIZE = ENV.observation_space.shape

    # agent.save() writes into SAVE_DIR; make sure it exists so the first
    # checkpoint does not fail on a missing directory.
    os.makedirs(SAVE_DIR, exist_ok=True)

    agent = DQN(INPUT_SIZE, ACTION_SIZE,
                min_epsilon=MIN_EPSILON, max_epsilon=MAX_EPSILON, epsilon_decay=EPSILON_DECAY,
                gamma=GAMMA, lr=LEARNING_RATE,
                mem_size=MEMORY_SIZE, batch_size=BATCH_SIZE,
                save_dir=SAVE_DIR, name=NAME)

    reward_tracking = []
    best_mean = -1000
    reward_mean = -1000
    counter = 0  # environment steps taken so far
    for episode in tqdm(range(EPISODES)):
        obs, done = ENV.reset(), False
        reward_sum = 0
        while not done:
            action = agent.epsilon_greedy(obs)
            next_obs, reward, done, info = ENV.step(action)
            # An episode cut short by the time limit is not a true terminal
            # state, so it must still bootstrap from next_obs.
            is_truncated = info.get('TimeLimit.truncated', False)
            terminal = done and (not is_truncated)
            reward_sum += reward
            agent.add_memory(obs, action, reward, next_obs, terminal)
            obs = next_obs
            agent.learn()

            # Sync the target network every REPLACE_TARGET steps.
            counter += 1
            if counter % REPLACE_TARGET == 0:
                agent.replace_target_network()
            ENV.render()
        reward_tracking.append(reward_sum)

        # OUTPUT INFO
        if episode > 10:
            reward_mean = np.array(reward_tracking[-10:]).mean()
            if reward_mean > best_mean:
                # Checkpoint whenever the 10-episode average improves.
                best_mean = reward_mean
                agent.save()

            print(f'best_mean: {best_mean}, current_mean: {reward_mean}', end='\r')

            if best_mean >= SOLVED_REWARD:
                print('\n', flush=True)
                print(f'---GOAL REACHED AFTER {episode} EPISODES---')
                print('\n')
                break


--- TRAINING ENVIRONMENT BreakoutNoFrameskip-v4---



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))

5est_mean: 0.1, current_mean: 0.1


KeyboardInterrupt: 

In [15]:
# Close the environment left over from the training cell (ENV is the last
# environment created by the training loop).
ENV.close()

In [None]:
# Evaluation: load the saved online network for every environment and play
# one greedy episode while rendering it.
for env in ENVS:
#     frames = []

    NAME = env['NAME']

    ENV = gym.make(NAME)
    ENV = apply_wrappers(ENV)

    # Always release the environment, even if loading or playing fails.
    try:
        ACTION_SIZE = ENV.action_space.n
        INPUT_SIZE = ENV.observation_space.shape

        agent = DQN(INPUT_SIZE, ACTION_SIZE,
                    min_epsilon=MIN_EPSILON, max_epsilon=MAX_EPSILON, epsilon_decay=EPSILON_DECAY,
                    gamma=GAMMA, lr=LEARNING_RATE,
                    mem_size=MEMORY_SIZE, batch_size=BATCH_SIZE,
                    save_dir=SAVE_DIR, name=NAME)
        agent.load()

        obs, done = ENV.reset(), False

#         frames.append(Image.fromarray(ENV.render(mode='rgb_array')))
        eval_score = 0
        # Named "steps" so the itertools.count import is not shadowed.
        steps = 0
        while not done:
            action = agent.greedy(obs)
            next_obs, reward, done, _ = ENV.step(action)
            ENV.render()
            eval_score += reward
            obs = next_obs
            steps += 1
            print(steps, end='\r')
#             frames.append(Image.fromarray(ENV.render(mode='rgb_array')))

        print(f'\n{NAME}: evaluation score {eval_score} after {steps} steps')
    finally:
        ENV.close()

#     path = os.path.join(RESULTS_DIR, NAME+'.gif')
#     with open(path, 'wb') as f:
#         im = Image.new('RGB', frames[0].size)
#         im.save(f, save_all=True, append_images=frames, loop=0, duration=25) 

In [None]:
# Closes the last evaluation environment again; the evaluation loop already
# calls ENV.close() itself, so this looks redundant.
# NOTE(review): confirm close() is safe to call twice on this environment.
ENV.close()