In [6]:
# Python version
import sys
assert sys.version_info >= (3, 5)

# In case colab is being used
gColab = "google.colab" in sys.modules

if gColab:
    !apt update && apt install -y libpq-dev libsdl2-dev swig xorg-dev xvfb
    %pip install -U tf-agents pyvirtualdisplay
    %pip install -U gym>=0.21.0
    %pip install -U gym[box2d,atari,accept-rom-license]

# Scikit-Learn version
import sklearn
assert sklearn.__version__ >= "0.20"

# TensorFlow version
import tensorflow as tf
from tensorflow import keras
assert tf.__version__ >= "2.0"

if not tf.config.list_physical_devices('GPU'):
    print("No GPU was detected. Training will be slow.")
    if gColab:
        print("Change runtime to GPU to accelerate performance.")

[33m0% [Working][0m            Hit:1 https://cloud.r-project.org/bin/linux/ubuntu bionic-cran40/ InRelease
Hit:2 http://security.ubuntu.com/ubuntu bionic-security InRelease
Ign:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64  InRelease
Hit:4 http://archive.ubuntu.com/ubuntu bionic InRelease
Hit:5 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu bionic InRelease
Ign:6 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64  InRelease
Hit:7 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64  Release
Hit:8 https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64  Release
Get:9 http://archive.ubuntu.com/ubuntu bionic-updates InRelease [88.7 kB]
Hit:10 http://ppa.launchpad.net/cran/libgit2/ubuntu bionic InRelease
Hit:11 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu bionic InRelease
Get:12 http://archive.ubuntu.com/ubuntu bionic-backports InRelease [74.6 kB]
Hit:13 http

In [18]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tf_agents.environments import suite_gym
from tf_agents.environments import suite_atari
from tf_agents.environments.atari_preprocessing import AtariPreprocessing
from tf_agents.environments.atari_wrappers import FrameStack4
from tf_agents.environments.tf_py_environment import TFPyEnvironment
from tf_agents.networks.q_network import QNetwork
from tensorflow import keras
from tf_agents.agents.dqn.dqn_agent import DqnAgent
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.metrics import tf_metrics
from tf_agents.drivers.dynamic_step_driver import DynamicStepDriver
from tf_agents.eval.metric_utils import log_metrics
import logging
from tf_agents.policies.random_tf_policy import RandomTFPolicy
from tf_agents.trajectories.trajectory import to_transition
from tf_agents.utils.common import function

%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc('axes', labelsize=14)
mpl.rc('xtick', labelsize=12)
mpl.rc('ytick', labelsize=12)

# animations
import matplotlib.animation as animation
mpl.rc('animation', html='jshtml')

class GameWithAutofire(AtariPreprocessing):
    """AtariPreprocessing variant that presses FIRE (action 1) automatically.

    FIRE is sent once right after every reset, and again whenever a step
    costs a life without ending the episode. The observation/reward/done/info
    returned are always those of the regular step (or reset); the result of
    the extra FIRE step is discarded.
    """

    def reset(self, **kwargs):
        first_obs = super().reset(**kwargs)
        super().step(1)  # FIRE to start
        return first_obs

    def step(self, action):
        lives_prior = self.ale.lives()
        obs, rewards, done, info = super().step(action)
        lost_a_life = self.ale.lives() < lives_prior
        if lost_a_life and not done:
            super().step(1)  # FIRE to restart after the lost life
        return obs, rewards, done, info

# Seeds for TensorFlow and NumPy randomness.
tf.random.set_seed(42)
np.random.seed(42)

max_episode_steps = 27000  # <=> 108k ALE frames since 1 step = 4 frames
environment_name = "BreakoutNoFrameskip-v4"

# Creating the environment.
# (A previous `suite_gym.load("Breakout-v4")` here was immediately overwritten
# and only left an unused emulator instance alive, so it was removed.)
environment = suite_atari.load(
    environment_name,
    max_episode_steps=max_episode_steps,
    # GameWithAutofire presses FIRE after a reset / lost life so play resumes;
    # FrameStack4 outputs observations made of the last frames stacked
    # along the channel axis.
    gym_env_wrappers=[GameWithAutofire, FrameStack4])

# Wrap the Python environment in a TFPyEnvironment so TF-Agents components
# can drive it from TensorFlow.
TensorflowEnv = TFPyEnvironment(environment)

# Setting the seed
environment.seed(42)
environment.reset()

def viewUpdate(num, frames, patch):
    """FuncAnimation callback: show ``frames[num]`` on the image artist ``patch``.

    Returns a one-element tuple holding the updated artist, as FuncAnimation
    expects from its update function.
    """
    current_frame = frames[num]
    patch.set_data(current_frame)
    return (patch,)

def displayAnimation(frames, repeat=False, interval=40):
    """Build a matplotlib animation that plays ``frames`` in order.

    Args:
        frames: non-empty sequence of images accepted by ``imshow``; the
            first one initialises the image artist.
        repeat: whether the animation loops when it reaches the end.
        interval: delay between frames, in milliseconds.

    Returns:
        The ``FuncAnimation``. The backing figure is closed before returning
        so that only the animation is displayed, not an extra static image.
    """
    fig = plt.figure()
    image_artist = plt.imshow(frames[0])
    plt.axis('off')
    anim = animation.FuncAnimation(
        fig,
        viewUpdate,
        fargs=(frames, image_artist),
        frames=len(frames),
        repeat=repeat,
        interval=interval,
    )
    plt.close(fig)
    return anim

def PlotObsFrames(obs):
    """Show a stacked 4-frame observation as a single RGB image.

    Only 3 colour channels exist, so the 4 frames cannot each get their own
    primary colour. Instead, the first three frames are used as R, G, B and
    the amount by which the last frame exceeds their mean (clamped at 0) is
    added to the red and blue channels, tinting the newest frame pink.
    Values are scaled by 1/150 and clipped to [0, 1] before display.

    ``obs`` is converted to a float32 copy first, so the caller's array is
    never modified.
    """
    stacked = obs.astype(np.float32)
    older_mean = stacked[..., :3].mean(axis=-1)
    highlight = np.maximum(stacked[..., 3] - older_mean, 0.)
    rgb = stacked[..., :3]
    for channel in (0, 2):  # red and blue
        rgb[..., channel] += highlight
    plt.imshow(np.clip(rgb / 150, 0, 1))
    plt.axis("off")


def GetNewScene(num, frames, patch):
    """FuncAnimation callback identical to ``viewUpdate``; kept for existing callers.

    Its body used to duplicate ``viewUpdate`` line for line. Delegating keeps
    the two callbacks from drifting apart: it sets ``frames[num]`` on
    ``patch`` and returns ``(patch,)``.
    """
    return viewUpdate(num, frames, patch)


# ---------------------------------------------------------------------------
# Creating the Deep Q-Network using the tf_agents library
# ---------------------------------------------------------------------------
# Preprocessing layer: cast observations to float32 and divide by 255
# (presumably raw 0-255 pixel values -> [0, 1]).
PreprocessLayer = keras.layers.Lambda(lambda obs: tf.cast(obs, np.float32) / 255.)
# 3 convolutional layers, given as (filters, kernel size, stride):
# Layer1 : 32 filters, filter size = (8,8), stride = 4
# Layer2 : 64 filters, filter size = (4,4), stride = 2
# Layer3 : 64 filters, filter size = (3,3), stride = 1
ConvLayerHyperparams = [(32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3), 1)]
# Fully connected layer(s) applied after the convolutions
DenseLayerParams = [550]


def _build_q_network():
    """Create a QNetwork with the architecture defined above.

    The online and target networks must share the same architecture, so both
    are built here instead of repeating the constructor call. As before, both
    receive the same PreprocessLayer instance (a weight-less Lambda layer).
    """
    return QNetwork(
        TensorflowEnv.observation_spec(),
        TensorflowEnv.action_spec(),
        preprocessing_layers=PreprocessLayer,
        conv_layer_params=ConvLayerHyperparams,
        fc_layer_params=DenseLayerParams)


# Online model
DQNetworkModel = _build_q_network()

# Target model
TargetModel = _build_q_network()

    
# ---------------------------------------------------------------------------
# Creating the agent to play the game
# ---------------------------------------------------------------------------
# Training-step counter, shared by the agent and the epsilon schedule.
TrainingStep = tf.Variable(0)
UpdatePeriod = 4  # run a training step every 4 collect steps

optimizer = keras.optimizers.RMSprop(
    learning_rate=2.5e-4,
    rho=0.95,
    momentum=0.0,
    epsilon=1e-5,
    centered=True,
)

# Exploration rate: a PolynomialDecay schedule reused for epsilon, going from
# 1.0 down to 0.01 as the training step advances.
EpsFunction = keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1.0,           # initial epsilon
    decay_steps=250000 // UpdatePeriod,  # train steps <=> 1,000,000 ALE frames
    end_learning_rate=0.01,              # final epsilon
)


def _current_epsilon():
    """Epsilon for the current training step; evaluated lazily by the agent."""
    return EpsFunction(TrainingStep)


# Per-element Huber loss on the TD errors (no reduction here).
huber_td_loss = keras.losses.Huber(reduction="none")

PlayerAgent = DqnAgent(
    TensorflowEnv.time_step_spec(),   # time-step spec of the environment
    TensorflowEnv.action_spec(),      # action spec of the environment
    q_network=DQNetworkModel,         # online Q-network
    optimizer=optimizer,              # optimizer for the online network
    target_q_network=TargetModel,     # target Q-network
    target_update_period=2000,        # train steps <=> 32,000 ALE frames
    td_errors_loss_fn=huber_td_loss,
    gamma=0.99,                       # discount factor
    train_step_counter=TrainingStep,  # starts at 0
    epsilon_greedy=_current_epsilon,  # decays as TrainingStep grows
)
# Initialize the agent
PlayerAgent.initialize()

# ---------------------------------------------------------------------------
# Replay memory holding the collected experiences
# ---------------------------------------------------------------------------
MemoryBuffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=PlayerAgent.collect_data_spec,  # structure of each stored item
    batch_size=TensorflowEnv.batch_size,      # batch dimension of added items (parallel envs)
    max_length=10000,                         # capacity per batch entry; reduce if OOM error
)

# Observer through which the collect driver pushes batches into memory.
Observer = MemoryBuffer.add_batch


class WarmingUp:
    """Driver observer that prints a one-line ``counter/total`` progress display.

    Only non-boundary trajectories are counted. Progress is printed each time
    the counter reaches a multiple of ``every``, and once more when it reaches
    ``total``. Boundary steps never trigger a print (previously they re-printed
    the unchanged count, and the final count was skipped whenever ``total``
    was not a multiple of 100).

    Args:
        total: expected number of counted steps, shown as the denominator.
        every: print interval in counted steps (default 100).
    """

    def __init__(self, total, every=100):
        self.counter = 0
        self.total = total
        self.every = every

    def __call__(self, trajectory):
        if trajectory.is_boundary():
            return
        self.counter += 1
        if self.counter % self.every == 0 or self.counter == self.total:
            # "\r" rewrites the same output line instead of adding a new one.
            print("\r{}/{}".format(self.counter, self.total), end="")

# Training metrics, fed by the collect driver.
TrainingMetrics = [
    metric_cls()
    for metric_cls in (
        tf_metrics.NumberOfEpisodes,
        tf_metrics.EnvironmentSteps,
        tf_metrics.AverageReturnMetric,
        tf_metrics.AverageEpisodeLengthMetric,
    )
]
# Let INFO-level records through, then log the metrics' initial values.
logging.getLogger().setLevel(logging.INFO)
log_metrics(TrainingMetrics)

# ---------------------------------------------------------------------------
# Defining the collect driver
# ---------------------------------------------------------------------------
# Runs the agent's collect policy for UpdatePeriod steps per call, feeding
# the replay buffer (Observer) and the training metrics.
CollectDriver = DynamicStepDriver(
    TensorflowEnv,  # Environment
    PlayerAgent.collect_policy,
    observers=[Observer] + TrainingMetrics,
    num_steps=UpdatePeriod)  # collect 4 steps for each training iteration

# Warm-up: fill the replay buffer with experience from a random policy before
# training. One constant drives both the step count and the progress
# display so they cannot disagree.
InitialCollectSteps = 10000  # <=> 40,000 ALE frames (1 step = 4 frames)

initial_collect_policy = RandomTFPolicy(TensorflowEnv.time_step_spec(),
                                        TensorflowEnv.action_spec())
InitializeDriver = DynamicStepDriver(
    TensorflowEnv,
    initial_collect_policy,
    observers=[MemoryBuffer.add_batch, WarmingUp(InitialCollectSteps)],
    num_steps=InitialCollectSteps)
FinalTimeStep, FinalPolicyState = InitializeDriver.run()

# Fixed seed for the sample peeked at below (chosen to show an example of
# trajectory at the end of an episode).
tf.random.set_seed(9)

# ---------------------------------------------------------------------------
# Creating the dataset
# ---------------------------------------------------------------------------
# Peek at one small sample (2 sequences of 3 steps) to inspect transitions.
_peek_dataset = MemoryBuffer.as_dataset(
    sample_batch_size=2,
    num_steps=3,
    single_deterministic_pass=False)
trajectories, buffer_info = next(iter(_peek_dataset))
timeSteps, actionSteps, nextTimeSteps = to_transition(trajectories)

# Training dataset: batches of 64 sequences of 2 consecutive steps
# (one transition each), sampled with 3 parallel calls and prefetched.
dataset = (
    MemoryBuffer
    .as_dataset(sample_batch_size=64, num_steps=2, num_parallel_calls=3)
    .prefetch(3)
)

# Wrap the per-iteration hot paths with tf_agents' `function` so they run as
# TensorFlow graphs.
CollectDriver.run = function(CollectDriver.run)
PlayerAgent.train = function(PlayerAgent.train)

# ---------------------------------------------------------------------------
# Training function
# ---------------------------------------------------------------------------
def AgentTraining(n_iterations, log_interval=1000):
    """Alternate experience collection and DQN training steps.

    Each iteration runs ``CollectDriver`` once (UpdatePeriod environment
    steps with the collect policy), draws one batch from ``dataset`` and
    performs one ``PlayerAgent.train`` step, printing the loss on a single
    rewritten line.

    Args:
        n_iterations: number of collect/train iterations to run.
        log_interval: log ``TrainingMetrics`` every ``log_interval``
            iterations, starting at iteration 0 (default 1000, the previous
            hard-coded value).

    Raises:
        ValueError: if ``log_interval`` is not positive.
    """
    if log_interval <= 0:
        raise ValueError("log_interval must be positive, got {}".format(log_interval))
    # None: the driver starts from the environment's current time step.
    time_step = None
    policy_state = PlayerAgent.collect_policy.get_initial_state(TensorflowEnv.batch_size)
    iterator = iter(dataset)
    for iteration in range(n_iterations):
        time_step, policy_state = CollectDriver.run(time_step, policy_state)
        trajectories, _ = next(iterator)  # buffer info is not needed for training
        train_loss = PlayerAgent.train(trajectories)
        print("\r{} loss:{:.5f}".format(iteration, train_loss.loss.numpy()), end="")
        if iteration % log_interval == 0:
            log_metrics(TrainingMetrics)


# Training the agent for 100000 iterations
AgentTraining(n_iterations=100000)


# ---------------------------------------------------------------------------
# Rendering after training
# ---------------------------------------------------------------------------
frames = []


def StoreFrames(trajectory):
    """Driver observer: append an RGB render of the first wrapped env to ``frames``.

    ``trajectory`` is required by the observer interface but not used.
    Appending to the module-level list needs no ``global`` declaration.
    """
    first_env = TensorflowEnv.pyenv.envs[0]
    frames.append(first_env.render(mode="rgb_array"))


WatchDriver = DynamicStepDriver(
    TensorflowEnv,
    PlayerAgent.policy,  # the agent's (non-collect) policy
    observers=[StoreFrames, WarmingUp(1000)],
    num_steps=1000)
FinaltimeStep, FinalPolicyState = WatchDriver.run()

displayAnimation(frames)


INFO:absl: 
		 NumberOfEpisodes = 0
		 EnvironmentSteps = 0
		 AverageReturn = 0.0
		 AverageEpisodeLength = 0.0


10000/10000

INFO:absl: 
		 NumberOfEpisodes = 0
		 EnvironmentSteps = 4
		 AverageReturn = 0.0
		 AverageEpisodeLength = 0.0


998 loss:0.00734

INFO:absl: 
		 NumberOfEpisodes = 26
		 EnvironmentSteps = 4004
		 AverageReturn = 0.8999999761581421
		 AverageEpisodeLength = 154.3000030517578


1995 loss:0.00086

INFO:absl: 
		 NumberOfEpisodes = 48
		 EnvironmentSteps = 8004
		 AverageReturn = 1.0
		 AverageEpisodeLength = 174.1999969482422


2995 loss:0.00011

INFO:absl: 
		 NumberOfEpisodes = 73
		 EnvironmentSteps = 12004
		 AverageReturn = 1.5
		 AverageEpisodeLength = 180.3000030517578


3999 loss:0.00012

INFO:absl: 
		 NumberOfEpisodes = 97
		 EnvironmentSteps = 16004
		 AverageReturn = 1.2999999523162842
		 AverageEpisodeLength = 172.3000030517578


4997 loss:0.00012

INFO:absl: 
		 NumberOfEpisodes = 120
		 EnvironmentSteps = 20004
		 AverageReturn = 1.600000023841858
		 AverageEpisodeLength = 196.8000030517578


5999 loss:0.00005

INFO:absl: 
		 NumberOfEpisodes = 146
		 EnvironmentSteps = 24004
		 AverageReturn = 0.6000000238418579
		 AverageEpisodeLength = 139.60000610351562


6998 loss:0.00047

INFO:absl: 
		 NumberOfEpisodes = 169
		 EnvironmentSteps = 28004
		 AverageReturn = 0.8999999761581421
		 AverageEpisodeLength = 176.1999969482422


7995 loss:0.00012

INFO:absl: 
		 NumberOfEpisodes = 192
		 EnvironmentSteps = 32004
		 AverageReturn = 0.6000000238418579
		 AverageEpisodeLength = 196.89999389648438


8998 loss:0.00009

INFO:absl: 
		 NumberOfEpisodes = 215
		 EnvironmentSteps = 36004
		 AverageReturn = 1.0
		 AverageEpisodeLength = 201.3000030517578


9998 loss:0.00024

INFO:absl: 
		 NumberOfEpisodes = 237
		 EnvironmentSteps = 40004
		 AverageReturn = 1.2000000476837158
		 AverageEpisodeLength = 206.6999969482422


10995 loss:0.00030

INFO:absl: 
		 NumberOfEpisodes = 257
		 EnvironmentSteps = 44004
		 AverageReturn = 0.800000011920929
		 AverageEpisodeLength = 197.1999969482422


11997 loss:0.00026

INFO:absl: 
		 NumberOfEpisodes = 280
		 EnvironmentSteps = 48004
		 AverageReturn = 0.800000011920929
		 AverageEpisodeLength = 167.0


12996 loss:0.00009

INFO:absl: 
		 NumberOfEpisodes = 303
		 EnvironmentSteps = 52004
		 AverageReturn = 0.699999988079071
		 AverageEpisodeLength = 167.3000030517578


13999 loss:0.00012

INFO:absl: 
		 NumberOfEpisodes = 325
		 EnvironmentSteps = 56004
		 AverageReturn = 0.8999999761581421
		 AverageEpisodeLength = 179.10000610351562


14997 loss:0.00005

INFO:absl: 
		 NumberOfEpisodes = 346
		 EnvironmentSteps = 60004
		 AverageReturn = 1.5
		 AverageEpisodeLength = 195.5


15997 loss:0.00014

INFO:absl: 
		 NumberOfEpisodes = 370
		 EnvironmentSteps = 64004
		 AverageReturn = 0.8999999761581421
		 AverageEpisodeLength = 186.3000030517578


16995 loss:0.00005

INFO:absl: 
		 NumberOfEpisodes = 397
		 EnvironmentSteps = 68004
		 AverageReturn = 0.8999999761581421
		 AverageEpisodeLength = 149.8000030517578


17999 loss:0.00012

INFO:absl: 
		 NumberOfEpisodes = 418
		 EnvironmentSteps = 72004
		 AverageReturn = 0.800000011920929
		 AverageEpisodeLength = 175.8000030517578


18998 loss:0.00011

INFO:absl: 
		 NumberOfEpisodes = 440
		 EnvironmentSteps = 76004
		 AverageReturn = 0.800000011920929
		 AverageEpisodeLength = 152.8000030517578


19997 loss:0.00004

INFO:absl: 
		 NumberOfEpisodes = 466
		 EnvironmentSteps = 80004
		 AverageReturn = 1.2999999523162842
		 AverageEpisodeLength = 170.0


20995 loss:0.00009

INFO:absl: 
		 NumberOfEpisodes = 490
		 EnvironmentSteps = 84004
		 AverageReturn = 1.0
		 AverageEpisodeLength = 174.8000030517578


21998 loss:0.00010

INFO:absl: 
		 NumberOfEpisodes = 514
		 EnvironmentSteps = 88004
		 AverageReturn = 1.2999999523162842
		 AverageEpisodeLength = 167.5


22997 loss:0.00012

INFO:absl: 
		 NumberOfEpisodes = 539
		 EnvironmentSteps = 92004
		 AverageReturn = 0.8999999761581421
		 AverageEpisodeLength = 150.10000610351562


23997 loss:0.00048

INFO:absl: 
		 NumberOfEpisodes = 559
		 EnvironmentSteps = 96004
		 AverageReturn = 1.100000023841858
		 AverageEpisodeLength = 178.8000030517578


24999 loss:0.00009

INFO:absl: 
		 NumberOfEpisodes = 581
		 EnvironmentSteps = 100004
		 AverageReturn = 2.0999999046325684
		 AverageEpisodeLength = 199.89999389648438


25998 loss:0.00007

INFO:absl: 
		 NumberOfEpisodes = 602
		 EnvironmentSteps = 104004
		 AverageReturn = 0.8999999761581421
		 AverageEpisodeLength = 171.8000030517578


26999 loss:0.00008

INFO:absl: 
		 NumberOfEpisodes = 627
		 EnvironmentSteps = 108004
		 AverageReturn = 1.600000023841858
		 AverageEpisodeLength = 194.8000030517578


27999 loss:0.00014

INFO:absl: 
		 NumberOfEpisodes = 651
		 EnvironmentSteps = 112004
		 AverageReturn = 0.699999988079071
		 AverageEpisodeLength = 161.5


28999 loss:0.00011

INFO:absl: 
		 NumberOfEpisodes = 673
		 EnvironmentSteps = 116004
		 AverageReturn = 1.2000000476837158
		 AverageEpisodeLength = 190.5


29995 loss:0.00010

INFO:absl: 
		 NumberOfEpisodes = 699
		 EnvironmentSteps = 120004
		 AverageReturn = 0.699999988079071
		 AverageEpisodeLength = 155.39999389648438


30998 loss:0.00122

INFO:absl: 
		 NumberOfEpisodes = 719
		 EnvironmentSteps = 124004
		 AverageReturn = 1.399999976158142
		 AverageEpisodeLength = 183.6999969482422


31997 loss:0.00037

INFO:absl: 
		 NumberOfEpisodes = 742
		 EnvironmentSteps = 128004
		 AverageReturn = 0.8999999761581421
		 AverageEpisodeLength = 169.1999969482422


32998 loss:0.00004

INFO:absl: 
		 NumberOfEpisodes = 764
		 EnvironmentSteps = 132004
		 AverageReturn = 1.7000000476837158
		 AverageEpisodeLength = 180.3000030517578


33997 loss:0.00010

INFO:absl: 
		 NumberOfEpisodes = 784
		 EnvironmentSteps = 136004
		 AverageReturn = 2.0999999046325684
		 AverageEpisodeLength = 193.6999969482422


34999 loss:0.00072

INFO:absl: 
		 NumberOfEpisodes = 799
		 EnvironmentSteps = 140004
		 AverageReturn = 3.5
		 AverageEpisodeLength = 245.89999389648438


35997 loss:0.00057

INFO:absl: 
		 NumberOfEpisodes = 818
		 EnvironmentSteps = 144004
		 AverageReturn = 2.4000000953674316
		 AverageEpisodeLength = 210.8000030517578


36996 loss:0.00042

INFO:absl: 
		 NumberOfEpisodes = 835
		 EnvironmentSteps = 148004
		 AverageReturn = 3.0999999046325684
		 AverageEpisodeLength = 210.5


37998 loss:0.00020

INFO:absl: 
		 NumberOfEpisodes = 851
		 EnvironmentSteps = 152004
		 AverageReturn = 3.0999999046325684
		 AverageEpisodeLength = 283.79998779296875


38998 loss:0.00053

INFO:absl: 
		 NumberOfEpisodes = 861
		 EnvironmentSteps = 156004
		 AverageReturn = 5.0
		 AverageEpisodeLength = 402.3999938964844


39999 loss:0.00406

INFO:absl: 
		 NumberOfEpisodes = 873
		 EnvironmentSteps = 160004
		 AverageReturn = 4.0
		 AverageEpisodeLength = 325.0


40997 loss:0.00073

INFO:absl: 
		 NumberOfEpisodes = 889
		 EnvironmentSteps = 164004
		 AverageReturn = 6.0
		 AverageEpisodeLength = 268.70001220703125


41998 loss:0.00034

INFO:absl: 
		 NumberOfEpisodes = 906
		 EnvironmentSteps = 168004
		 AverageReturn = 3.9000000953674316
		 AverageEpisodeLength = 225.0


42999 loss:0.00050

INFO:absl: 
		 NumberOfEpisodes = 918
		 EnvironmentSteps = 172004
		 AverageReturn = 6.699999809265137
		 AverageEpisodeLength = 355.3999938964844


43190 loss:0.00031

KeyboardInterrupt: ignored