# Dueling Deep Q-Network (Dueling-DQN)
---
In this notebook, you will train a Dueling-DQN agent on OpenAI Gym's Breakout environment.
The agent takes a stack of preprocessed game frames as its input state and predicts a Q-value for each action.

### 1. Import the Necessary Packages

In [None]:
import gym
# !pip3 install box2d
import random
import torch
import numpy as np
from collections import deque
import cv2
import os
import matplotlib.pyplot as plt
%matplotlib inline

# !python -m pip install pyvirtualdisplay
from pyvirtualdisplay import Display

# Invisible (visible=0) virtual display, presumably so env.render() works on a
# headless machine. It gets its own name because the rendering cells below use
# `display` for IPython's display module (display.display / display.clear_output);
# sharing the name would discard this handle, or with a non-inline backend leave
# `display` bound to this object instead of the module those cells expect.
# Stop it with virtual_display.stop() when done.
virtual_display = Display(visible=0, size=(1400, 900))
virtual_display.start()

is_ipython = 'inline' in plt.get_backend()
# The rendering cells call display.display / display.clear_output unconditionally,
# so bind IPython's display module regardless of the matplotlib backend.
from IPython import display

plt.ion()

In [None]:
from wrappers import *

In [None]:
# Directory for the periodic training checkpoints; exist_ok makes this idempotent
# and avoids the exists-then-create race of a separate existence check.
os.makedirs('./models', exist_ok=True)

In [None]:
# Reload edited project modules (e.g. wrappers, dueling_dqn_agent) automatically
# before each cell runs, so changes to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

### 2. Instantiate the Environment and Agent

Initialize the environment and the agent in the code cells below.

In [None]:
# BreakoutDeterministic-v4 picks a new action only every 4th frame and repeats
# it on the skipped frames; make_env (presumably from the local `wrappers`
# module) then wraps the raw environment.
env = make_env(gym.make('BreakoutDeterministic-v4'))
env.seed(0)

print('State shape: ', env.observation_space.shape)
print('Number of actions: ', env.action_space.n)

In [None]:
def pre_process(state, crop_h=(31, 16), crop_w=(6,6)):
    """
    Convert a raw environment frame into the network's input format:
    1. Convert to grayscale by averaging the colour channels (a 2-D input is
       treated as already grayscale)
    2. Crop the image as specified by the params
    3. Resize the image to 84 x 84 (bicubic interpolation)

    Params
    ========================
    state   (h, w, c) or (h, w): pixel image representing the state of the environment
    crop_h (int, int): the number of pixels to be cropped from the (top, bottom) of the image
    crop_w (int, int): the number of pixels to be cropped from the (left, right) of the image
                       (a crop of 0 on any side keeps that edge intact)

    Returns
    ========================
    the processed uint8 state of shape (1, 84, 84)
    """
    state = np.asarray(state)
    if state.ndim == 3:
        state = np.mean(state, axis=2)                # grayscale: average the channels
    state = state.astype(np.uint8)                    # truncates the channel mean

    # Crop with explicit end indices: a negative bound of -0 would select nothing,
    # so a zero bottom/right crop would otherwise produce an empty image.
    height, width = state.shape[:2]
    top, bottom = crop_h
    left, right = crop_w
    state = state[top:height - bottom, left:width - right]   # remove useless pixels

    state = cv2.resize(state, dsize=(84, 84), interpolation=cv2.INTER_CUBIC)

    return state.reshape((1, 84, 84))

In [None]:
# How many consecutive preprocessed frames are stacked into one agent input.
FRAME_HISTORY: int = 4

In [None]:
from dueling_dqn_agent import Agent

# Size the action head from the environment rather than hard-coding it, so the
# network cannot silently disagree with the action set exposed by `env`.
agent = Agent(action_size=env.action_space.n, frame_history=FRAME_HISTORY, seed=0)

In [None]:
# Roll out the untrained agent for at most 200 steps, drawing each rendered
# frame inline and keeping a copy of every frame in `frames`.
frames = []
state = pre_process(env.reset())

# Left-pad the history with blank frames so the network input always holds
# FRAME_HISTORY frames, oldest first and the current frame last.
stacked_state = deque(
    [np.zeros_like(state) for _ in range(FRAME_HISTORY - 1)] + [state],
    maxlen=FRAME_HISTORY,
)

frame = env.render(mode='rgb_array')
frames.append(frame)
img = plt.imshow(frame)

for _step in range(200):
    network_input = np.concatenate(stacked_state, axis=0)
    action = agent.act(network_input)

    frame = env.render(mode='rgb_array')
    frames.append(frame)
    img.set_data(frame)
    plt.axis('off')
    display.display(plt.gcf())
    display.clear_output(wait=True)

    raw_obs, reward, done, _ = env.step(action)
    state = pre_process(raw_obs)
    stacked_state.append(state)
    if done:
        break

env.close()

### 3. Train the Agent with DQN

Run the code cell below to train the agent from scratch. You are welcome to change the supplied parameter values of the function to see whether you can get better performance!

In [None]:
def dqn(n_episodes=50000, max_t=20000, eps_start=1.0, eps_end=0.1,
        eps_decay=None, target_score=200.0, checkpoint_every=1000):
    """Deep Q-Learning: train the global `agent` on the global `env` with epsilon-greedy exploration.

    The agent acts on a stack of the last FRAME_HISTORY preprocessed frames;
    `agent.step` receives the single preprocessed current and next frames.

    Params
    ======
        n_episodes (int): maximum number of training episodes
        max_t (int): maximum number of timesteps per episode
        eps_start (float): starting value of epsilon, for epsilon-greedy action selection
        eps_end (float): minimum value of epsilon
        eps_decay (float or None): amount subtracted from epsilon after every episode
            (linear decay, clamped at eps_end). None uses (eps_start - eps_end) / 1e5,
            i.e. epsilon reaches eps_end only after 100,000 episodes.
            NOTE(review): with the default n_episodes=50000 that leaves epsilon at
            about 0.55 when training ends — consider a larger decrement.
        target_score (float): stop once the mean score over a full 100-episode window reaches this
        checkpoint_every (int): save the local Q-network to ./models every this many episodes
            (0 disables periodic checkpoints)

    Returns
    =======
        scores (list): total reward of each episode
        moving_avgs (list): mean of the last (up to) 100 scores after each episode
    """
    scores = []                        # list containing scores from each episode
    scores_window = deque(maxlen=100)  # last 100 scores
    moving_avgs = []                   # list containing scores window averages
    eps = eps_start                    # initialize epsilon
    if eps_decay is None:
        eps_decay = (eps_start - eps_end) / 1e5

    for i_episode in range(1, n_episodes+1):
        state = pre_process(env.reset())

        # Pad the history with blank frames so the very first stacked input
        # already contains FRAME_HISTORY frames, the current one last.
        stacked_state = deque(maxlen=FRAME_HISTORY)
        for _ in range(FRAME_HISTORY-1):
            stacked_state.append(np.zeros_like(state))
        stacked_state.append(state)

        score = 0
        for _step in range(max_t):
            action = agent.act(np.concatenate(stacked_state, axis=0), eps)
            next_state, reward, done, _ = env.step(action)
            next_state = pre_process(next_state)
            # NOTE(review): single frames (not the stacked input given to act()) are
            # passed here; presumably the agent rebuilds FRAME_HISTORY stacks from its
            # replay memory — confirm in dueling_dqn_agent.
            agent.step(state, action, reward, next_state, done)
            state = next_state
            stacked_state.append(state)
            score += reward
            if done:
                break

        scores_window.append(score)       # save most recent score
        scores.append(score)              # save most recent score
        avg_score = np.mean(scores_window)
        moving_avgs.append(avg_score)     # save the moving average
        eps = max(eps_end, eps-eps_decay) # decrease epsilon
        print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, avg_score), end="")
        if i_episode % 100 == 0:
            print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, avg_score))
        # Only judge success on a full 100-episode window, so a lucky first few
        # episodes cannot end training early.
        if len(scores_window) == scores_window.maxlen and avg_score >= target_score:
            print('\nEnvironment solved in {:d} episodes!\tAverage Score: {:.2f}'.format(i_episode-100, avg_score))
            torch.save(agent.qnetwork_local.state_dict(), 'checkpoint_best.pth')
            break
        if checkpoint_every and i_episode % checkpoint_every == 0:
            torch.save(agent.qnetwork_local.state_dict(), f'./models/checkpoint_episode_{i_episode}.pth')
    return scores, moving_avgs

# Train from scratch with dqn()'s default hyper-parameters (up to 50000 episodes).
scores, moving_avgs = dqn()

In [None]:
# Use the .npy extension explicitly: np.save appends '.npy' to any filename that
# lacks it, so a name like 'scores.np' would actually be written as 'scores.np.npy'.
np.save('scores.npy', scores)
np.save('moving_avgs.npy', moving_avgs)

In [None]:
# Plot every episode's score together with its 100-episode moving average,
# drawing explicitly on one Axes instead of through the pyplot state machine.
fig, ax = plt.subplots(figsize=(10, 6))
episodes = np.arange(len(scores))
ax.plot(episodes, scores, c='k', label='DQN agent score', alpha=0.3)
ax.plot(episodes, moving_avgs, c='r', label='Moving average score')
ax.set(title='Dueling-DQN training on Breakout', xlabel='Episode #', ylabel='Score')
ax.legend()
ax.grid()
# fig.savefig('scores_plot.png')   # uncomment to keep a copy of the plot
plt.show()

### 4. Watch a Smart Agent!

In the next code cell, you will load the trained weights from file to watch a smart agent!

In [None]:
# Load the trained weights from file, then watch the agent play a few episodes.
# Training writes './models/checkpoint_episode_<N>.pth' periodically and
# 'checkpoint_best.pth' (in the working directory) when the target score is
# reached — make sure CHECKPOINT_PATH points at a file your run produced.
CHECKPOINT_PATH = './models/checkpoint_episode_30000.pth'
N_WATCH_EPISODES = 3
MAX_WATCH_STEPS = 200

# Load the saved tensors onto the GPU when one is available, otherwise the CPU.
map_location = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
agent.qnetwork_local.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=map_location))

frames = []
for i in range(N_WATCH_EPISODES):
    state = pre_process(env.reset())

    # Pad the history with blank frames so the first input has FRAME_HISTORY frames.
    stacked_state = deque(maxlen=FRAME_HISTORY)
    for _ in range(FRAME_HISTORY-1):
        stacked_state.append(np.zeros_like(state))
    stacked_state.append(state)

    frame = env.render(mode='rgb_array')
    frames.append(frame)
    img = plt.imshow(frame)
    plt.axis('off')   # loop-invariant: hide the axes once per episode image

    for j in range(MAX_WATCH_STEPS):
        action = agent.act(np.concatenate(stacked_state, axis=0))
        frame = env.render(mode='rgb_array')
        frames.append(frame)
        img.set_data(frame)
        display.display(plt.gcf())
        display.clear_output(wait=True)
        state, reward, done, _ = env.step(action)
        state = pre_process(state)
        stacked_state.append(state)
        if done:
            break

env.close()

In [None]:
from matplotlib import animation

# NOTE: for imagemagick to work, you may need to run `conda install imagemagick`
# and then re run the trained agent

def display_frames_as_gif(frames, path='./breakout_result.gif', fps=30, writer='imagemagick'):
    """Save a sequence of frames as an animated GIF.

    Params
    ======
        frames (sequence): images accepted by `plt.imshow`, e.g. the RGB arrays
            collected from `env.render(mode='rgb_array')`
        path (str): output file
        fps (int): playback speed of the saved GIF, in frames per second
        writer (str): matplotlib animation writer; 'imagemagick' needs the
            ImageMagick program (see the note above), 'pillow' needs none

    Raises
    ======
        ValueError: if `frames` is empty
    """
    # len() rather than truthiness so a numpy array of frames also works.
    if len(frames) == 0:
        raise ValueError('no frames to save: run the agent first to collect frames')

    # Draw on a fresh figure rather than whatever figure is current (e.g. the one
    # left by the watch cell), and close it afterwards so it does not linger.
    fig = plt.figure()
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])
        return (patch,)

    # interval is in milliseconds; keep it consistent with the saved fps.
    anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=1000 / fps)
    try:
        anim.save(path, writer=writer, fps=fps)
    finally:
        plt.close(fig)

In [None]:
# Save the frames collected while watching the trained agent as ./breakout_result.gif.
display_frames_as_gif(frames)

![](./breakout_result.gif)