# Using Deep Q-learning with TF-Agents to solve the Atari Breakout game

![breakout_gif](https://i.imgur.com/rRxXF4H.gif "breakout gif")

In [None]:
import psutil
import base64
import imageio
import IPython
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import gym
import tensorflow as tf
from datetime import datetime
from tf_agents.agents.dqn import dqn_agent
from tf_agents.environments import suite_gym, gym_wrapper, tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import sequential
from tf_agents.policies import random_tf_policy, epsilon_greedy_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.specs import tensor_spec
from tf_agents.utils import common
from tf_agents.policies import PolicySaver

from rl_src.atari_games.environment_preprocessing import wrap_atari_deepmind, hard_reset, get_timelimit_env

# Fix NumPy's global RNG so the random actions drawn with np.random are reproducible.
np.random.seed(42)

# Report the machine's total memory, converted from bytes by dividing by 1024^3.
total_memory_gb = psutil.virtual_memory().total / (1024.0 ** 3)
print("Total memory: {}GB".format(total_memory_gb))


def now():
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``.

    Used to timestamp the progress lines printed during training.
    """
    # %m is the month number; %M (minutes) in the date part would print the
    # current minute where the month belongs.
    return datetime.now().strftime("%Y-%m-%d %H:%M:%S")

# Create and visualize environment

In [None]:
env_name = 'BreakoutNoFrameskip-v4'

# Training and evaluation environments are built separately but share the
# exact same preprocessing options.
preprocessing_options = dict(frame_skip=4, frame_stack=True, scale=True, steps_limit=10000)
train_gym_env = wrap_atari_deepmind(gym.make(env_name), **preprocessing_options)
eval_gym_env = wrap_atari_deepmind(gym.make(env_name), **preprocessing_options)

# Expose the gym environments through the TF-Agents python-environment wrapper.
train_py_env = gym_wrapper.GymWrapper(train_gym_env)
eval_py_env = gym_wrapper.GymWrapper(eval_gym_env)

In [None]:
# Print the observation / reward / action specs of the training environment,
# followed by the emulator's names for the available actions.
time_step_spec = train_py_env.time_step_spec()
print('Observation Spec:\n', time_step_spec.observation)
print('Reward Spec:\n', time_step_spec.reward)
print('Action Spec:\n', train_py_env.action_spec())
print(' Action meanings:', train_py_env.unwrapped.get_action_meanings())

In [None]:
# Render the training environment's current frame and display it.
current_frame = train_py_env.render()
plt.imshow(current_frame);

In [None]:
# Display the shape of the frame returned by render().
train_py_env.render().shape

In [None]:
# Confirm that episodes run smoothly by running a number of environment steps
# with random actions, recording step type, remaining lives and reward per step.

# Rows are collected in a plain list and turned into a DataFrame once at the
# end: DataFrame.append was removed in pandas 2.0 and copied the whole frame
# on every call.
step_records = []  # keep track of episode progress
for _ in range(1000):
    obs = eval_py_env.step(np.random.randint(2, 4))  # random step left or right
    step_records.append({
        'step_type': int(obs.step_type),
        'lives': eval_py_env.env.unwrapped.ale.lives(),
        'reward': float(obs.reward),
    })

df = pd.DataFrame(step_records, columns=['step_type', 'lives', 'reward'])

# Reset the evaluation environment after the rollout; `obs` (the last time
# step) stays available for inspection.
hard_reset(eval_py_env);

In [None]:
# Draw every recorded column of the random rollout on one wide axis.
fig, ax = plt.subplots(figsize=(20, 5))

line_style = dict(lw=1.5, alpha=0.75)
df.plot(ax=ax, **line_style)

ax.set_xlabel('Step', fontsize=16)
ax.legend(fontsize=14);

In [None]:
# Show each slice along the last axis of the latest observation in its own
# subplot (one subplot per stacked frame).
fig, axes = plt.subplots(1, 4, figsize=(20, 6))

for frame_idx, frame_ax in enumerate(axes):
    frame_ax.imshow(obs.observation[:, :, frame_idx], cmap='gray');

In [None]:
# Collapse the observation into one image by taking the per-pixel maximum
# over its last axis, then display it.
max_pixels = np.max(obs.observation, axis=-1)
plt.imshow(max_pixels, cmap='gray');

In [None]:
# Wrap both python environments so that numpy arrays are converted to tensors
# inside the environment (training environment first, then evaluation).
train_env, eval_env = (
    tf_py_environment.TFPyEnvironment(py_env)
    for py_env in (train_py_env, eval_py_env)
)

# Create agent

In [None]:
# Drop Keras state left over from a previous run of this cell.
tf.keras.backend.clear_session()

input_shape = train_py_env.observation_spec().shape
print("Input shape: {}".format(input_shape))

# The output layer has one linear unit per available action.
num_actions = train_py_env.action_spec().num_values

# Three convolutional layers, then a 512-unit hidden layer and the output layer.
network_layers = [
    tf.keras.layers.InputLayer(input_shape, dtype=tf.float32, name='input'),
    tf.keras.layers.Conv2D(
        filters=32, kernel_size=8, strides=4,
        activation='relu', name='conv2d_1', dtype=tf.float32,
    ),
    tf.keras.layers.Conv2D(
        filters=64, kernel_size=4, strides=2,
        activation='relu', name='conv2d_2', dtype=tf.float32,
    ),
    tf.keras.layers.Conv2D(
        filters=64, kernel_size=3, strides=1,
        activation='relu', name='conv2d_3', dtype=tf.float32,
    ),
    tf.keras.layers.Flatten(name='flatten'),
    tf.keras.layers.Dense(512, activation='relu', name='dense_1'),
    tf.keras.layers.Dense(num_actions, activation='linear', name='dense_2'),
]

q_net = sequential.Sequential(network_layers)

In [None]:
# Adam with gradient-norm clipping.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.00025, clipnorm=1.0)

# Counter handed to the agent; read back as agent.train_step_counter.
train_step_counter = tf.Variable(0)

# NOTE(review): epsilon_greedy=1.1 is above 1. Together with the per-iteration
# decay applied in the training loop this looks like a deliberate stretch of
# purely random actions before epsilon drops below 1 - confirm intent.
agent = dqn_agent.DqnAgent(
    time_step_spec=train_env.time_step_spec(),
    action_spec=train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    epsilon_greedy=1.1,
    boltzmann_temperature=None,
    target_update_period=10000,
    td_errors_loss_fn=common.element_wise_huber_loss,
    gamma=0.99,
    train_step_counter=train_step_counter,
)
agent.initialize()

# Print the layer summary of the network held by the agent.
agent._q_network.summary()

In [None]:
# Uniformly random baseline policy built from the training environment's specs.
random_policy = random_tf_policy.RandomTFPolicy(
    time_step_spec=train_env.time_step_spec(),
    action_spec=train_env.action_spec(),
)

# Policies provided by the agent itself.
collect_policy = agent.collect_policy  # epsilon-greedy policy
eval_policy = agent.policy  # greedy policy

In [None]:
def update_collect_policy_epsilon(agent, new_epsilon):
    """Replace the agent's collect policy with one that uses ``new_epsilon``.

    Stores ``new_epsilon`` on the agent's private ``_epsilon_greedy`` field
    and rebuilds ``_collect_policy`` as an epsilon-greedy wrapper around
    ``agent.policy``.
    """
    agent._epsilon_greedy = new_epsilon
    greedy_policy = agent.policy
    agent._collect_policy = epsilon_greedy_policy.EpsilonGreedyPolicy(
        greedy_policy,
        epsilon=agent._epsilon_greedy,
    )


def compute_avg_return(environment, policy, num_games=5):
    """Hard-reset the environment and play ``num_games`` games with ``policy``.

    A game runs until the emulator reports no lives left or the environment
    info contains the ``'TimeLimit.truncated'`` key.

    Args:
        environment: Either a ``TFPyEnvironment`` (info and lives are read
            from its wrapped python environment) or an environment exposing
            ``get_info`` and ``unwrapped.ale.lives`` directly.
        policy: Object whose ``action(time_step)`` result has an ``action``
            attribute accepted by ``environment.step``.
        num_games: Number of games to play and average over.

    Returns:
        Tuple ``(avg_return, avg_steps)``: the mean summed reward per game as
        a scalar, and the mean of the time-limit wrapper's elapsed step count
        read at the end of each game.
    """
    if isinstance(environment, tf_py_environment.TFPyEnvironment):
        get_info = environment.pyenv.get_info
        get_lives = environment.pyenv.envs[0].unwrapped.ale.lives
    else:
        get_info = environment.get_info
        get_lives = environment.unwrapped.ale.lives
    time_limit_env = get_timelimit_env(environment)

    total_return = 0.0
    total_steps = 0.0
    for _ in range(num_games):
        time_step = hard_reset(environment)
        game_return = 0.0
        truncated = False

        while get_lives() > 0 and not truncated:
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            game_return += time_step.reward
            info = get_info()
            if isinstance(info, dict):
                truncated = 'TimeLimit.truncated' in info

        total_return += game_return
        total_steps += time_limit_env._elapsed_steps

    avg_return = total_return / num_games
    avg_steps = total_steps / num_games

    # Rewards are tensors for TF environments but numpy values (or a plain
    # float if no step was played) otherwise; normalise all of them to a
    # scalar instead of assuming a `.numpy()` method exists.
    if hasattr(avg_return, 'numpy'):
        avg_return = avg_return.numpy()
    return np.asarray(avg_return).reshape(-1)[0], avg_steps

In [None]:
# Baseline: average return and average number of steps per game under the
# random policy, measured over 3 games on the evaluation environment.
compute_avg_return(eval_env, random_policy, num_games=3)

# Create replay buffer

In [None]:
replay_buffer_max_length = 100000

# Uniform replay buffer storing the data the agent's collect policy produces.
replay_buffer_options = dict(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,  # train_env.batch_size=1
    max_length=replay_buffer_max_length,
)
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(**replay_buffer_options)

In [None]:
def collect_step(environment, policy, buffer):
    """Take one step in ``environment`` with ``policy`` and store it in ``buffer``.

    The transition (current time step, chosen action step, resulting time
    step) is converted to a trajectory and appended with ``buffer.add_batch``.
    """
    before = environment.current_time_step()
    chosen = policy.action(before)
    after = environment.step(chosen.action)

    # Add trajectory to the replay buffer
    buffer.add_batch(trajectory.from_transition(before, chosen, after))
    
def collect_data(env, policy, buffer, steps):
    """Call ``collect_step`` ``steps`` times with the same env, policy and buffer."""
    for _step_index in range(steps):
        collect_step(env, policy, buffer)


# Put a first batch of random-policy transitions into the replay buffer.
initial_collect_steps = 100
collect_data(
    train_env,
    random_policy,
    replay_buffer,
    steps=initial_collect_steps,
)

In [None]:
batch_size = 32

# Dataset generates trajectories with shape [Bx2x...]:
# batch_size samples, each made of num_steps=2 steps.
sampled = replay_buffer.as_dataset(
    num_parallel_calls=3,
    sample_batch_size=batch_size,
    num_steps=2,
)
dataset = sampled.prefetch(3)

dataset

In [None]:
# Iterator over the replay-buffer dataset; the training loop draws one
# (experience, info) element from it per iteration via next().
iterator = iter(dataset)
iterator

# Train agent

In [None]:
# --- Training hyper-parameters ---
# num_iterations: loop iterations; each one collects steps and calls agent.train once.
# num_epsilon_greedy_steps: epsilon is lowered by 1 / num_epsilon_greedy_steps per iteration.
# num_eval_games: games played per evaluation.
# log_interval: evaluate and print every log_interval train steps.
num_iterations = 10000000
num_epsilon_greedy_steps = 260000
num_eval_games = 5
collect_steps_per_iteration = 4  # update agent every collect_steps_per_iteration steps
log_interval = 10000

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

# Reset the train step
agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
# `returns` and `steps` hold one entry per evaluation, starting with this one.
avg_return, avg_steps = compute_avg_return(eval_env, agent.policy, num_eval_games)
returns = [avg_return]
steps = [avg_steps]

best_return = 25  # start storing policies after reaching this amount of return
print("[{}] Starting training...".format(now()))
for _ in range(num_iterations):

    # Collect a few steps using collect_policy and save to the replay buffer.
    collect_data(train_env, agent.collect_policy, replay_buffer, collect_steps_per_iteration)

    # Sample a batch of data from the buffer and update the agent's network.
    experience, unused_info = next(iterator)
    train_loss = agent.train(experience).loss

    # Current value of the agent's train step counter, used for logging below.
    step = agent.train_step_counter.numpy()
    
    # Linear epsilon decay with a floor of 0.1; the collect policy is rebuilt
    # with the new value on every iteration.
    new_epsilon = max(agent._epsilon_greedy - 1.0 / num_epsilon_greedy_steps, 0.1)
    update_collect_policy_epsilon(agent, new_epsilon)

    # Periodic evaluation: play num_eval_games games with agent.policy, log
    # the result and record it for the plots.
    if step % log_interval == 0:
        avg_return, avg_steps = compute_avg_return(eval_env, agent.policy, num_eval_games)
        print("[{}]".format(now()) + f" step = {step}: loss = {train_loss:<17,.10f} avg return = {avg_return:<10,.2f} avg steps = {avg_steps:.2f}")
        returns.append(avg_return)
        steps.append(avg_steps)
        # Save the policy only when the evaluation return beats the best seen
        # so far; the directory name encodes the return and the evaluation index.
        if avg_return > best_return:
            PolicySaver(eval_policy).save('breakout_agents/eval_policy_ret{:03d}_st{:04d}'.format(int(avg_return), (step // log_interval)))
            best_return = avg_return

In [None]:
fig, [ax, ax2] = plt.subplots(2, 1, figsize=(15, 10))

# One x position per recorded evaluation, spaced log_interval apart.
iterations = list(range(0, len(returns) * log_interval, log_interval))

# Top panel: evaluation returns plus the mean of the `window` evaluations
# preceding each point.
window = 20
rol_mean = [
    np.mean(returns[start: start + window])
    for start in range(len(returns) - window)
]
ax.plot(iterations, returns, lw=2.5, alpha=0.8, label='returns')
ax.plot(iterations[window:], rol_mean, lw=2.5, alpha=0.8, label='rolling mean returns')

# Bottom panel: average number of steps per game at each evaluation.
ax2.plot(iterations, steps, lw=1.5, alpha=0.5, color='black', label='game steps')

# Shared formatting: labels, axes anchored at zero, dashed guide lines at the
# inner y ticks, legend.
for panel, y_label in ((ax, 'Average Return'), (ax2, 'Steps per game')):
    panel.set_ylabel(y_label, fontsize=14)
    panel.set_xlabel('Gradient Steps', fontsize=14)
    panel.set_xlim(left=0)
    panel.set_ylim(bottom=0)
    panel.hlines(panel.get_yticks()[1:-1], iterations[0], iterations[-1], lw=0.5, alpha=0.5, ls='--', color='black')
    panel.legend(fontsize=13)

# Show actions in video and see returns

The <a href="https://web.stanford.edu/class/psych209/Readings/MnihEtAlHassibis15NatureControlDeepRL.pdf">paper</a> reports the following performance: ![breakout_performance](paper-breakout-performance.png "Breakout Performance")

In [None]:
def embed_mp4(filename):
    """Embed an mp4 file in the notebook.

    Args:
        filename: Path of the mp4 file to read.

    Returns:
        An ``IPython.display.HTML`` object holding a ``<video>`` tag whose
        source is the file content as a base64 data URI.
    """
    # Read through a context manager so the file handle is closed right away
    # instead of being left open until garbage collection.
    with open(filename, 'rb') as video_file:
        video = video_file.read()
    b64 = base64.b64encode(video)
    tag = '''
    <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
    Your browser does not support the video tag.
    </video>'''.format(b64.decode())
    return IPython.display.HTML(tag)


def create_policy_eval_video(policy, filename, num_games=1, fps=30):
    """Play games with ``policy`` on the evaluation environment and film them.

    Uses the notebook-level ``eval_env`` / ``eval_py_env``. One frame is
    written after the reset and one after every step; a game ends when no
    lives are left or the info dict contains ``'TimeLimit.truncated'``.

    Args:
        policy: Object whose ``action(time_step)`` result has an ``action``
            attribute.
        filename: Output path without extension; ``".mp4"`` is appended.
        num_games: Number of games to record into the same video.
        fps: Frame rate passed to the video writer.

    Returns:
        The result of ``embed_mp4`` for the written file.
    """
    output_path = filename + ".mp4"
    with imageio.get_writer(output_path, fps=fps) as writer:
        for _game in range(num_games):
            episode_reward = 0
            read_info = eval_py_env.get_info
            lives_left = eval_py_env.unwrapped.ale.lives
            step_limit_env = get_timelimit_env(eval_py_env)

            # Start from a fresh game and record its first frame.
            time_step = hard_reset(eval_env)
            writer.append_data(eval_py_env.render())

            hit_step_limit = False
            while lives_left() > 0 and not hit_step_limit:
                chosen = policy.action(time_step)
                time_step = eval_env.step(chosen.action)
                writer.append_data(eval_py_env.render())
                episode_reward += time_step.reward.numpy()[0]

                info = read_info()
                if isinstance(info, dict):
                    hit_step_limit = 'TimeLimit.truncated' in info

            print("{} steps with reward {}".format(step_limit_env._elapsed_steps, episode_reward))
    return embed_mp4(output_path)

In [None]:
# Record one game played by the agent's current policy and embed it (writes trained-agent.mp4).
create_policy_eval_video(agent.policy, "trained-agent")

In [None]:
# Record one game played by the random policy for comparison (writes random-agent.mp4).
create_policy_eval_video(random_policy, "random-agent")

In [None]:
# Load a policy previously written by PolicySaver and record a game with it.
# NOTE(review): the ret999_st9999 path looks like a placeholder - point it at
# an actual saved directory under breakout_agents/ before running.
policy_path = 'breakout_agents/eval_policy_ret999_st9999/'
policy = tf.saved_model.load(policy_path)

# Uses the same "trained-agent" file name as the earlier video call, so that
# file is overwritten.
create_policy_eval_video(policy, "trained-agent")