# Video Pinball

This project aims to teach a reinforcement learning agent to play the game [Atari Video Pinball](https://gymnasium.farama.org/environments/atari/video_pinball/).  
To this end, it uses Deep Q-Learning, in which a neural network (a Deep Q-Network, DQN) approximates the action-value function.

## Information

General information:

| Property          | Value                             |
| ----------------- | --------------------------------- |
| Action Space      | Discrete(18)                      |
| Observation Space | (210, 160, 3)                     |
| Observation High  | 255                               |
| Observation Low   | 0                                 |
| Import            | `gym.make("ALE/VideoPinball-v5")` |


### Actions

The `Video Pinball` game from the Atari 2600 environment has the following actions, which are described in the [manual of the game](https://atariage.com/manual_html_page.php?SoftwareLabelID=588):
**This is the reduced action space, which is available when choosing `v0` or `v4`, or when specifying `full_action_space=False` during initialization. Otherwise, more actions are available.**

| Num | Action    | Description                                                                                  |
| --- | --------- | -------------------------------------------------------------------------------------------- |
| 0   | NOOP      | No Operation                                                                                 |
| 1   | FIRE      | Press the red controller button to release the spring and shoot the ball into the playfield. |
| 2   | UP        | Move the Joystick up to move both flippers at the same time.                                 |
| 3   | RIGHT     | Move the Joystick to the right to move the right flipper up.                                 |
| 4   | LEFT      | Move the Joystick to the left to move the left flipper up.                                   |
| 5   | DOWN      | Pull the Joystick down (towards you) to bring the plunger back.                              |
| 6   | UPFIRE    | "Nudge" the ball into upwards direction.                                                     |
| 7   | RIGHTFIRE | "Nudge" the ball to the right.                                                               |
| 8   | LEFTFIRE  | "Nudge" the ball to the left.                                                                |

Furthermore, it might be interesting to try different modes and difficulties of the game.


### Difficulties

There are two available difficulties:

- `a` (a.k.a. pinball wizards) is for expert players and has two additional drain holes at the bottom
- `b` is for beginners/novice players

### Observations

By default, the environment returns the RGB image that is displayed to human players as the observation.  
However, it is possible to observe one of the following instead:
- The 128 bytes of RAM of the console (`Box([0 ... 0], [255 ... 255], (128,), uint8)`)
- A grayscale image (`Box([[0 ... 0] ... [0  ... 0]], [[255 ... 255] ... [255  ... 255]], (210, 160), uint8)`)

## Preparation

### Installs

In [1]:
!apt-get update
!apt-get install -y xvfb python-opengl

Hit:1 http://ppa.launchpad.net/c2d4u.team/c2d4u4.0+/ubuntu bionic InRelease
Hit:2 http://ppa.launchpad.net/cran/libgit2/ubuntu bionic InRelease
Hit:3 http://ppa.launchpad.net/deadsnakes/ppa/ubuntu bionic InRelease
Get:4 https://cloud.r-project.org/bin/linux/ubuntu bionic-cran40/ InRelease [3,626 B]

In [2]:
%pip install -q --upgrade pip
%pip install -q gym==0.21.0
%pip install -q 'gym[atari]==0.12.5'
%pip install -q matplotlib
%pip install -q pyvirtualdisplay

     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 50.8 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.5/1.5 MB 47.0 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
  Building wheel for gym (setup.py) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.5/1.5 MB 20.5 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 968.6/968.6 kB 26.6 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 760.8/760.8 kB 36.9 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
  Building wheel for atari_py (setup.py) ... done
  Building wheel for gym (setup.py) ... done

### Download required files from Drive

In [3]:
import sys

if 'google.colab' in sys.modules:
    %pip install -q --upgrade gdown
    from gdown import download_folder

    download_folder("https://drive.google.com/drive/folders/1SW56nbccfHJtC6oGBIcp7XCeJDkKehGK")


Retrieving folder list


Retrieving folder 1sLRDYhYwlE9uNs5beDdNyrO80oBs3nzI __pycache__
Processing file 14ST9RPnaFI-rgaHgmwQrG7LDEy3Z-OPs abstract_agent.py
Processing file 1uOowZtnqB-Df3n5nhoSCoXNPBod40ZKk atari_helpers.py
Processing file 1xi6hZIOR_R_eQO9het6sy2m7h2GRheu2 check_test.py
Processing file 1Lw5_R_Y0Gk1nNSeG264I9M1w1N1WCyMF loggers.py
Processing file 1HQHDLpU7p3YI1Av_jFunScokmy10jgTX plot_utils.py
Building directory structure completed


Retrieving folder list completed
Building directory structure
Downloading...
From: https://drive.google.com/uc?id=14ST9RPnaFI-rgaHgmwQrG7LDEy3Z-OPs
To: /content/external/abstract_agent.py
100%|██████████| 835/835 [00:00<00:00, 1.94MB/s]
Downloading...
From: https://drive.google.com/uc?id=1uOowZtnqB-Df3n5nhoSCoXNPBod40ZKk
To: /content/external/atari_helpers.py
100%|██████████| 9.11k/9.11k [00:00<00:00, 18.5MB/s]
Downloading...
From: https://drive.google.com/uc?id=1xi6hZIOR_R_eQO9het6sy2m7h2GRheu2
To: /content/external/check_test.py
100%|██████████| 1.56k/1.56k [00:00<00:00, 3.23MB/s]
Downloading...
From: https://drive.google.com/uc?id=1Lw5_R_Y0Gk1nNSeG264I9M1w1N1WCyMF
To: /content/external/loggers.py
100%|██████████| 982/982 [00:00<00:00, 2.25MB/s]
Downloading...
From: https://drive.google.com/uc?id=1HQHDLpU7p3YI1Av_jFunScokmy10jgTX
To: /content/external/plot_utils.py
100%|██████████| 2.28k/2.28k [00:00<00:00, 5.28MB/s]
Download completed


### Imports

In [4]:
import json
import os
import time
from abc import ABC, abstractmethod
from collections import deque
from contextlib import suppress
from datetime import datetime
from random import sample
from typing import Any, Tuple

import gym
import numpy as np
import tensorflow as tf
from keras import Model
from keras.layers import Conv2D, Dense, Flatten, Input, Lambda, multiply
from keras.losses import huber_loss
from keras.optimizers import Adam
from keras.utils import to_categorical
from tensorflow import keras
from tensorflow.compat.v1.keras.backend import set_session
from pyvirtualdisplay import Display
import matplotlib.pyplot as plt
%matplotlib inline

# local files
from external.abstract_agent import AbstractAgent
from external.atari_helpers import LazyFrames, make_atari, wrap_deepmind
from external.loggers import TensorBoardLogger, tf_summary_image
from external.plot_utils import plot_statistics

Instructions for updating:
non-resource variables are not supported in the long term


In [5]:
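# Start a virtual X display so that the environment can be rendered on a headless machine.
# Named `virtual_display` so that it is not shadowed by `IPython.display` below.
virtual_display = Display(visible=0, size=(1400, 900))
virtual_display.start()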

is_ipython = 'inline' in plt.get_backend()
if is_ipython:
    from IPython import display
    from IPython.display import SVG

plt.ion()

### Extended DQN-Agent

In [6]:
class AbstractDQNAgent(AbstractAgent):
    __slots__ = [
        "action_size",
        "state_size",
        "gamma",
        "epsilon",
        "epsilon_decay",
        "epsilon_min",
        "alpha",
        "batch_size",
        "memory_size",
        "start_replay_step",
        "target_model_update_interval",
        "train_freq",
    ]

    def __init__(self,
                 action_size: int,
                 state_size: int,
                 gamma: float,
                 epsilon: float,
                 epsilon_decay: float,
                 epsilon_min: float,
                 alpha: float,
                 batch_size: int,
                 memory_size: int,
                 start_replay_step: int,
                 target_model_update_interval: int,
                 train_freq: int,
                 ):
        self.action_size = action_size
        self.state_size = state_size

        self.replay_has_started = False

        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.alpha = alpha

        self.memory_size = memory_size
        self.memory = deque(maxlen=self.memory_size)
        self.batch_size = batch_size

        self.step = 0
        self.start_replay_step = start_replay_step

        self.target_model_update_interval = target_model_update_interval

        self.train_freq = train_freq

        assert self.start_replay_step >= self.batch_size, "The number of steps to start replay must be at least as large as the batch size"

        self.action_mask = np.ones((1, self.action_size))
        self.action_mask_batch = np.ones((self.batch_size, self.action_size))

        self.tf_config_intra_threads = 8
        self.tf_config_inter_threads = 4
        self.tf_config_soft_placement = True
        self.tf_config_allow_growth = True

        config = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=self.tf_config_intra_threads,
                                inter_op_parallelism_threads=self.tf_config_inter_threads,
                                allow_soft_placement=self.tf_config_soft_placement
                                )

        config.gpu_options.allow_growth = self.tf_config_allow_growth
        session = tf.compat.v1.Session(config=config)
        set_session(session)  # set this TensorFlow session as the default session for Keras

        self.model = self._build_model()
        self.target_model = self._build_model()

    def save(self, target_path: str) -> None:
        """
        Saves the current state of the DQNAgent to some output files.
        Together with `load` this serves as a very rudimentary checkpointing.
        The replay memory is not saved: it holds `LazyFrames` objects, which are
        not JSON-serializable (and would be very large).
        """
        agent_dict = {
            "agent_init": {},
            "agent_params": {},
            "tf_config": {}
        }

        os.makedirs(target_path, exist_ok=True)

        # Constructor arguments (gamma, epsilon, batch_size, memory_size, ...)
        for slot in self.__slots__:
            agent_dict["agent_init"][slot] = getattr(self, slot)

        # NumPy arrays have to be converted to lists to be JSON-serializable
        for attr in ["action_mask", "action_mask_batch"]:
            agent_dict["agent_params"][attr] = getattr(self, attr).tolist()

        for tf_config in [
            "tf_config_intra_threads",
            "tf_config_inter_threads",
            "tf_config_soft_placement",
            "tf_config_allow_growth",
        ]:
            agent_dict["tf_config"][tf_config] = getattr(self, tf_config)

        with open(os.path.join(target_path, "agent.json"), "w") as f:
            json.dump(agent_dict, f)

        self.model.save_weights(os.path.join(target_path, "model.h5"))
        self.target_model.save_weights(os.path.join(target_path, "target_model.h5"))

    @classmethod
    def load(cls, path: str) -> "AbstractDQNAgent":
        """
        Loads the serialized state of a DQNAgent and returns an instance of it.
        The replay memory starts empty, since `save` does not store it.
        """
        with open(os.path.join(path, "agent.json"), "r") as f:
            agent_dict = json.load(f)

        agent = cls(**agent_dict["agent_init"])

        agent.action_mask = np.array(agent_dict["agent_params"]["action_mask"])
        agent.action_mask_batch = np.array(agent_dict["agent_params"]["action_mask_batch"])

        config = tf.compat.v1.ConfigProto(
            intra_op_parallelism_threads=agent_dict["tf_config"]["tf_config_intra_threads"],
            inter_op_parallelism_threads=agent_dict["tf_config"]["tf_config_inter_threads"],
            allow_soft_placement=agent_dict["tf_config"]["tf_config_soft_placement"])

        config.gpu_options.allow_growth = agent_dict["tf_config"]["tf_config_allow_growth"]
        session = tf.compat.v1.Session(config=config)
        set_session(session)

        # Load the weights from the given checkpoint directory, not from the working directory
        agent.model.load_weights(os.path.join(path, "model.h5"))
        agent.target_model.load_weights(os.path.join(path, "target_model.h5"))

        return agent

    @abstractmethod
    def train(self, experience):
        raise NotImplementedError

    @abstractmethod
    def act(self, state):
        raise NotImplementedError

    @abstractmethod
    def _build_model(self) -> Model:
        raise NotImplementedError
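`save` and `load` rely on a JSON round trip of the constructor arguments: `save` writes them to `agent.json`, and `load` passes them back via `cls(**agent_dict["agent_init"])`. A minimal, standard-library-only sketch of that round trip; `save_params` and `load_params` are hypothetical helpers, the values are example values, and no model weights are involved.

```python
import json
import os
import tempfile


def save_params(params: dict, target_path: str) -> None:
    """Writes constructor arguments to <target_path>/agent.json (hypothetical helper)."""
    os.makedirs(target_path, exist_ok=True)
    with open(os.path.join(target_path, "agent.json"), "w") as f:
        json.dump({"agent_init": params}, f)


def load_params(path: str) -> dict:
    """Reads the constructor arguments back, ready to be passed as cls(**params)."""
    with open(os.path.join(path, "agent.json"), "r") as f:
        return json.load(f)["agent_init"]


# Example values only
params = {"action_size": 9, "gamma": 0.99, "epsilon": 0.938, "batch_size": 64}
with tempfile.TemporaryDirectory() as tmp:
    save_params(params, tmp)
    restored = load_params(tmp)

print(restored == params)  # True
```

Only JSON-native values (int, float, str, bool, lists and dicts of these) survive such a round trip unchanged. NumPy arrays have to go through `.tolist()`, as `save` does for the action masks, and objects such as `LazyFrames` cannot be stored this way at all.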

## Deep Q-Learning Network (DQN)

In [7]:
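# Note: in gym, the plain `-v4` Atari ids already repeat each action for a random 2 to 4 frames;
# DeepMind-style preprocessing is usually applied to the `NoFrameskip-v4` variant instead.
env = make_atari("VideoPinball-v4")
env = wrap_deepmind(env, frame_stack=True)  # maps frames to 84x84x4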

NoopResetEnv (max 30) wrapper is used.
MaxAndSkipEnv (skip 4) wrapper is used.
EpisodicLifeEnv wrapper is used.
FireResetEnv wrapper is used.
ClipRewardEnv wrapper is used.
FrameStack (4) wrapper is used.
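The `FrameStack (4)` wrapper turns single 84x84 grayscale frames into the 84x84x4 state that the network receives, so that the agent can perceive motion. A minimal NumPy sketch of the idea, assuming frames of shape `(84, 84, 1)` and that the first frame of an episode is repeated to fill the stack. `FrameStackSketch` is a hypothetical class, not the wrapper from `external/atari_helpers.py`; that wrapper returns `LazyFrames` instead of a plain array, presumably to avoid copying frames.

```python
from collections import deque

import numpy as np


class FrameStackSketch:
    """Keeps the k most recent frames and stacks them along the channel axis (illustration only)."""

    def __init__(self, k: int = 4):
        self.k = k
        self.frames = deque(maxlen=k)

    def reset(self, first_frame: np.ndarray) -> np.ndarray:
        # At the start of an episode, the first frame fills the whole stack
        for _ in range(self.k):
            self.frames.append(first_frame)
        return self.observation()

    def step(self, frame: np.ndarray) -> np.ndarray:
        # The deque drops the oldest frame automatically
        self.frames.append(frame)
        return self.observation()

    def observation(self) -> np.ndarray:
        # k frames of shape (84, 84, 1) -> one array of shape (84, 84, k), oldest frame first
        return np.concatenate(list(self.frames), axis=-1)


stack = FrameStackSketch(k=4)
obs = stack.reset(np.zeros((84, 84, 1), dtype=np.uint8))
obs = stack.step(np.full((84, 84, 1), 255, dtype=np.uint8))
print(obs.shape)           # (84, 84, 4)
print(obs[0, 0].tolist())  # [0, 0, 0, 255]
```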


### Create the DQN Agent

Subclass the `AbstractDQNAgent` defined above (previously called `DQNAgent`) and implement its missing methods `_build_model`, `act` and `train`, plus the `_replay` helper used by `train`.
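The output shapes of the convolutional layers in `_build_model` follow from the size formula for "valid" padding (the Keras default): an input of size `n`, a kernel of size `k` and a stride `s` give `(n - k) // s + 1`. A short check in plain Python; `conv_output_size` is an illustrative helper, and the layer list mirrors the kernel sizes, strides and filter counts used in the code.

```python
def conv_output_size(size: int, kernel: int, stride: int) -> int:
    """Spatial output size of a convolution with 'valid' padding."""
    return (size - kernel) // stride + 1


# (kernel, stride, filters) of the three convolutional layers in `_build_model`
conv_layers = [(8, 4, 32), (4, 2, 64), (4, 1, 64)]

size = 84
for kernel, stride, filters in conv_layers:
    size = conv_output_size(size, kernel, stride)
    print((size, size, filters))
# (20, 20, 32)
# (9, 9, 64)
# (6, 6, 64)

print(size * size * 64)           # 2304 units after Flatten
print(conv_output_size(9, 3, 1))  # 7 with the 3x3 kernel of the Nature paper
```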

In [8]:
class DQNAgent(AbstractDQNAgent):
    def _build_model(self) -> Model:
        """Deep Q-network as defined in the DeepMind article on Nature
        
        Returns:
            Model: Keras model which will be used as internal deep neural network
        """

        atari_shape = (84, 84, 4)

        # Frames from the observation
        frames_input = Input(atari_shape, name="frames")

        # Actions as input
        action_mask = Input((self.action_size,), name="action_mask")

        # Normalize the frames from [0, 255] to [0, 1]
        normalized = Lambda(lambda x: x / 255.0, name="normalization")(frames_input)

        # Nature DQN, first hidden layer: 32 filters of 8x8 with stride 4, followed by ReLU.
        # Results in an output shape of (20, 20, 32)
        conv1 = Conv2D(
            filters=32,
            kernel_size=(8, 8),
            strides=(4, 4),
            activation="relu"
        )(normalized)

        # Nature DQN, second hidden layer: 64 filters of 4x4 with stride 2, followed by ReLU.
        # Results in an output shape of (9, 9, 64)
        conv2 = Conv2D(
            filters=64,
            kernel_size=(4, 4),
            strides=(2, 2),
            activation="relu"
        )(conv1)

        # Third convolutional layer: 64 filters with stride 1, followed by ReLU.
        # The Nature paper uses 3x3 filters here; with the 4x4 kernel used below,
        # the output shape is (6, 6, 64) instead of (7, 7, 64)
        conv3 = Conv2D(
            filters=64,
            kernel_size=(4, 4),
            strides=(1, 1),
            activation="relu"
        )(conv2)

        # Flattening the last convolutional layer.
        conv_flattened = Flatten()(conv3)

        # Fully-connected hidden layer with 512 ReLU units.
        hidden = Dense(units=512, activation='relu')(conv_flattened)

        # Linear output layer with one Q value per action.
        output = Dense(self.action_size)(hidden)

        # Multiply the output with the action mask: the Q values of all actions outside
        # the mask become 0 (a mask of ones keeps all of them)
        filtered_output = multiply([output, action_mask])

        model = Model(inputs=[frames_input, action_mask], outputs=filtered_output)
        model.compile(loss=huber_loss, optimizer=Adam(learning_rate=self.alpha), metrics=None)

        return model


    def act(self, state: LazyFrames) -> int:
        """Selects the action to be executed based on the given state.

        Implements epsilon greedy exploration strategy, i.e. with a probability of
        epsilon, a random action is selected.

        Args:
            state [LazyFrames]: LazyFrames object representing the state based on 4 stacked observations (images)

        Returns:
            action [int]
        """

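        if np.random.rand() <= self.epsilon:
            # Explore: pick a uniformly random action
            action = np.random.randint(self.action_size)
        else:
            # Exploit: DQN acts greedily with respect to the online model;
            # the target model is only used to compute the training targets
            q_values = self.model.predict([np.array(state)[np.newaxis], self.action_mask], verbose=0)
            action = int(np.argmax(q_values))
        return action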

        
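    def train(self, experience: Tuple[LazyFrames, int, LazyFrames, float, bool]):
        """Stores the experience in memory and, after `start_replay_step` steps, trains the network.

        Once `start_replay_step` steps have been collected, every call
          - decays epsilon linearly by `epsilon_decay` until it reaches `epsilon_min`,
          - copies the weights of the model to the target model every `target_model_update_interval` steps,
          - trains the model by replay every `train_freq` steps.

        Args:
            experience [tuple]: Tuple of state, action, next state, reward, done.

        Returns:
            None
        """
        self.memory.append(experience)

        if self.step >= self.start_replay_step:
            # Update epsilon as long as it is not minimal (never going below epsilon_min)
            if self.epsilon > self.epsilon_min:
                self.epsilon = max(self.epsilon_min, self.epsilon - self.epsilon_decay)
            # Sync the weights of the target model with the model
            if self.step % self.target_model_update_interval == 0:
                self.target_model.set_weights(self.model.get_weights())
            # Execute replay
            if self.step % self.train_freq == 0:
                self._replay()

        self.step += 1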


    def _replay(self) -> None:
        """Gets random experiences from memory for batch update of Q-function.

        Returns:
            None
        """

        # Sample a random minibatch and turn each experience component into an array
        states, actions, next_states, rewards, dones = [
            np.array(component) for component in zip(*sample(self.memory, self.batch_size))
        ]

        # Sanity checks on the batch shapes
        assert all(isinstance(x, np.ndarray) for x in (states, actions, rewards, next_states, dones)), \
            "All experience batches should be of type np.ndarray."
        assert states.shape == (self.batch_size, 84, 84, 4), \
            f"States shape should be: {(self.batch_size, 84, 84, 4)}"
        assert actions.shape == (self.batch_size,), f"Actions shape should be: {(self.batch_size,)}"
        assert rewards.shape == (self.batch_size,), f"Rewards shape should be: {(self.batch_size,)}"
        assert next_states.shape == (self.batch_size, 84, 84, 4), \
            f"Next states shape should be: {(self.batch_size, 84, 84, 4)}"
        assert dones.shape == (self.batch_size,), f"Dones shape should be: {(self.batch_size,)}"

        # Predict the Q values of the next states with the target model, passing ones as the action mask
        next_q_values = self.target_model.predict([next_states, self.action_mask_batch], verbose=0)

        # Calculate the target Q values (Bellman equation):
        # - terminal states get the reward
        # - non-terminal states get reward + gamma * max next-state Q value
        q_values = rewards + (1.0 - dones.astype(np.float32)) * self.gamma * np.max(next_q_values, axis=1)

        # Create a one-hot encoding of the actions (the selected action is 1, all others 0)
        one_hot_actions = to_categorical(actions, num_classes=self.action_size)

        # Place each target Q value in the column of its action by broadcasting the column
        # vector of targets against the one-hot matrix (element-wise, scaling each row):
        # q_values = [0.5, 0.7, 0.9]
        # actions  [[1. 0. 0. 0.]
        #           [0. 0. 1. 0.]
        #           [0. 0. 0. 1.]]
        # output   [[0.5 0.  0.  0. ]
        #           [0.  0.  0.7 0. ]
        #           [0.  0.  0.  0.9]]
        target_q_values = q_values[:, np.newaxis] * one_hot_actions

        # Fit the model with the states and the one-hot actions as mask (x) and the target Q values (y),
        # so that only the Q value of the taken action contributes to the loss
        self.model.fit(
            x=[states, one_hot_actions],  # states and mask
            y=target_q_values,  # target Q values
            batch_size=self.batch_size,
            verbose=0
        )
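`_replay` builds its regression targets from the Bellman equation: for a transition `(s, a, r, s')` the target is `r` if the episode ended and `r + gamma * max_a' Q_target(s', a')` otherwise, placed in the column of the taken action so that the masked network output only has to match that single entry. A NumPy sketch of just this step with made-up numbers; `bellman_targets` is a hypothetical helper that mirrors the logic of `_replay` without Keras.

```python
import numpy as np


def bellman_targets(rewards, dones, next_q_values, actions, gamma, action_size):
    """Masked DQN regression targets for one minibatch (illustrative helper)."""
    rewards = np.asarray(rewards, dtype=np.float64)
    not_done = 1.0 - np.asarray(dones, dtype=np.float64)
    # y = r + gamma * max_a' Q_target(s', a') for non-terminal transitions, y = r otherwise
    targets = rewards + not_done * gamma * np.max(next_q_values, axis=1)
    # One-hot rows select the column of the action that was taken
    one_hot = np.eye(action_size)[np.asarray(actions)]
    # Broadcasting the (batch, 1) targets over the one-hot matrix scales each row
    return targets[:, np.newaxis] * one_hot


next_q = np.array([[0.1, 0.5, 0.2, 0.0],
                   [3.0, 0.0, 0.0, 0.0],
                   [0.0, 1.0, 0.3, 0.4]])
y = bellman_targets(rewards=[1.0, 5.0, 2.0], dones=[False, True, False],
                    next_q_values=next_q, actions=[0, 2, 3], gamma=0.9, action_size=4)
# Row 0: 1 + 0.9 * 0.5 = 1.45 in column 0
# Row 1: terminal, so just the reward 5.0 in column 2 (next_q is ignored)
# Row 2: 2 + 0.9 * 1.0 = 2.9 in column 3
print(y)
```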
        


In [9]:
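def interact_with_environment(env, agent, n_episodes=600, max_steps=1000000, train=True, verbose=True):
    """Runs `agent` in `env` for up to `n_episodes` episodes or about `max_steps` environment steps.

    The step limit is checked at the end of each episode. If `train` is True, every transition is
    passed to `agent.train`. Score, epsilon and speed are logged to TensorBoard every 10 episodes,
    and the states of the first episode are logged as images. A KeyboardInterrupt stops the run early.

    Returns:
        list[dict]: One entry per episode with its index, score and number of steps.
    """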
    statistics = []
    tb_logger = TensorBoardLogger(f'./logs/run-{datetime.now().strftime("%Y-%m-%d_%H:%M:%S")}')
    
    with suppress(KeyboardInterrupt):
        total_step = 0
        for episode in range(n_episodes):
            done = False
            episode_reward = 0
            state = env.reset()
            episode_start_time = time.time()
            episode_step = 0

            while not done:
                action = agent.act(state)
                next_state, reward, done, _ = env.step(action)

                if train:
                    agent.train((state, action, next_state, reward, done))

                if episode == 0:
                    # for debug purpose log every state of first episode
                    for obs in state:
                        tb_logger.log_image(f'state_t{episode_step}:', tf_summary_image(np.array(obs, copy=False)),
                                            global_step=total_step)
                state = next_state
                episode_reward += reward
                episode_step += 1
            
            total_step += episode_step

            if episode % 10 == 0:
                speed = episode_step / (time.time() - episode_start_time)
                tb_logger.log_scalar('score', episode_reward, global_step=total_step)
                tb_logger.log_scalar('epsilon', agent.epsilon, global_step=total_step)
                tb_logger.log_scalar('speed', speed, global_step=total_step)
                if verbose:
                    print(f'episode: {episode}/{n_episodes}, score: {episode_reward}, steps: {episode_step}, '
                          f'total steps: {total_step}, e: {agent.epsilon:.3f}, speed: {speed:.2f} steps/s')

            statistics.append({
                'episode': episode,
                'score': episode_reward,
                'steps': episode_step
            })
                                  
            if total_step >= max_steps:
                break
        
    return statistics
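`interact_with_environment` follows the classic gym loop: `reset`, then `act`, `step` and `train` until `done`. The sketch below reproduces that loop without TensorFlow or an Atari ROM. `ToyEnv`, `RandomAgent` and `run_episodes` are hypothetical stand-ins; `ToyEnv.step` uses the old four-value return `(state, reward, done, info)` that the notebook relies on.

```python
import random


class ToyEnv:
    """Hypothetical stand-in for a gym environment with the old 4-tuple `step` API."""

    def __init__(self, episode_length: int = 5):
        self.episode_length = episode_length
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t  # the "state" is just the time step

    def step(self, action):
        self.t += 1
        reward = 1.0 if action == 1 else 0.0
        done = self.t >= self.episode_length
        return self.t, reward, done, {}


class RandomAgent:
    """Hypothetical agent that acts randomly and records the experiences passed to `train`."""

    def __init__(self, n_actions: int):
        self.n_actions = n_actions
        self.experiences = []

    def act(self, state):
        return random.randrange(self.n_actions)

    def train(self, experience):
        self.experiences.append(experience)


def run_episodes(env, agent, n_episodes):
    statistics = []
    for episode in range(n_episodes):
        state, done, episode_reward, episode_step = env.reset(), False, 0.0, 0
        while not done:
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            # Same tuple order as in the notebook: (state, action, next_state, reward, done)
            agent.train((state, action, next_state, reward, done))
            state = next_state
            episode_reward += reward
            episode_step += 1
        statistics.append({"episode": episode, "score": episode_reward, "steps": episode_step})
    return statistics


random.seed(0)
toy_agent = RandomAgent(n_actions=2)
stats = run_episodes(ToyEnv(episode_length=5), toy_agent, n_episodes=3)
print([s["steps"] for s in stats])  # [5, 5, 5]
print(len(toy_agent.experiences))   # 15
```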

In [None]:
action_size = env.action_space.n
state_size = env.observation_space.shape[0]

# Hyperparams (should be sufficient)
episodes = 20000
annealing_steps = 100000  # not episodes! (currently unused: epsilon decays by epsilon_decay per step)
gamma = 0.99
epsilon = 1.0
epsilon_min = 0.01
epsilon_decay = 0.000004
alpha = 0.0001
batch_size = 64
memory_size = 100000
start_replay_step = 100000
target_model_update_interval = 1000
train_freq = 4

agent = DQNAgent(action_size=action_size, state_size=state_size, gamma=gamma, 
                 epsilon=epsilon, epsilon_decay=epsilon_decay, epsilon_min=epsilon_min, 
                 alpha=alpha, batch_size=batch_size, memory_size=memory_size,
                 start_replay_step=start_replay_step, 
                 target_model_update_interval=target_model_update_interval, train_freq=train_freq)

statistics = interact_with_environment(env, agent, n_episodes=episodes, verbose=True)
env.close()
plot_statistics(statistics)

episode: 0/20000, score: 10.0, steps: 91, total steps: 91, e: 1.000, speed: 163.80 steps/s
episode: 10/20000, score: 52.0, steps: 387, total steps: 2117, e: 1.000, speed: 124.81 steps/s
episode: 20/20000, score: 23.0, steps: 200, total steps: 3330, e: 1.000, speed: 288.12 steps/s
episode: 30/20000, score: 22.0, steps: 99, total steps: 5477, e: 1.000, speed: 310.10 steps/s
episode: 40/20000, score: 11.0, steps: 163, total steps: 7900, e: 1.000, speed: 305.14 steps/s
episode: 50/20000, score: 1.0, steps: 88, total steps: 9544, e: 1.000, speed: 306.25 steps/s
episode: 60/20000, score: 1.0, steps: 45, total steps: 11113, e: 1.000, speed: 306.62 steps/s
episode: 70/20000, score: 7.0, steps: 135, total steps: 13366, e: 1.000, speed: 306.62 steps/s
episode: 80/20000, score: 6.0, steps: 69, total steps: 14750, e: 1.000, speed: 290.98 steps/s
episode: 90/20000, score: 45.0, steps: 199, total steps: 16199, e: 1.000, speed: 299.54 steps/s
episode: 100/20000, score: 10.0, steps: 130, total steps: 

episode: 420/20000, score: 80.0, steps: 513, total steps: 100139, e: 0.999, speed: 45.10 steps/s
episode: 430/20000, score: 0.0, steps: 132, total steps: 101162, e: 0.995, speed: 111.38 steps/s
episode: 440/20000, score: 1.0, steps: 54, total steps: 102829, e: 0.989, speed: 76.52 steps/s
episode: 450/20000, score: 16.0, steps: 64, total steps: 104875, e: 0.980, speed: 121.41 steps/s
episode: 460/20000, score: 13.0, steps: 94, total steps: 106765, e: 0.973, speed: 115.79 steps/s
episode: 470/20000, score: 117.0, steps: 628, total steps: 108922, e: 0.964, speed: 123.64 steps/s
episode: 480/20000, score: 31.0, steps: 226, total steps: 110530, e: 0.958, speed: 121.27 steps/s
episode: 490/20000, score: 21.0, steps: 125, total steps: 112445, e: 0.950, speed: 124.67 steps/s
episode: 500/20000, score: 6.0, steps: 54, total steps: 113772, e: 0.945, speed: 127.08 steps/s
episode: 510/20000, score: 19.0, steps: 149, total steps: 115575, e: 0.938, speed: 121.31 steps/s
episode: 520/20000, score: 1
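The `e` column in the log above follows the linear schedule in `train`: once `start_replay_step` steps have passed, epsilon drops by `epsilon_decay` per environment step. The sketch below reproduces that arithmetic in plain Python and adds an epsilon-greedy choice over a list of Q values. `epsilon_after` and `epsilon_greedy` are illustrative helpers, not part of the notebook; their defaults are the hyperparameters used above.

```python
import random


def epsilon_after(steps_since_start: int, epsilon: float = 1.0,
                  epsilon_decay: float = 0.000004, epsilon_min: float = 0.01) -> float:
    """Epsilon after a number of decay steps under linear decay, clamped at epsilon_min."""
    return max(epsilon_min, epsilon - steps_since_start * epsilon_decay)


def epsilon_greedy(q_values, epsilon: float) -> int:
    """With probability epsilon pick a random action, otherwise the one with the highest Q value."""
    if random.random() < epsilon:
        return random.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


# Decay starts at start_replay_step = 100000; the log reports e: 0.938 at 115575 total steps
print(round(epsilon_after(115575 - 100000), 3))  # 0.938
# Number of decay steps needed to go from 1.0 down to 0.01
print(round((1.0 - 0.01) / 0.000004))  # 247500
print(epsilon_greedy([0.1, 0.7, 0.3], epsilon=0.0))  # 1 (always greedy)
```

With these settings, epsilon only reaches `epsilon_min` after 247,500 decay steps, i.e. at roughly 347,500 total environment steps.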

In [None]:
for i in range(3):
    state = env.reset()
    img = plt.imshow(env.render(mode='rgb_array'))
    for j in range(200):
        action = agent.act(state)
        img.set_data(env.render(mode='rgb_array')) 
        plt.axis('off')
        display.display(plt.gcf())
        display.clear_output(wait=True)
        state, reward, done, _ = env.step(action)
        if done:
            break 
            
env.close()

In [None]:
tf.keras.utils.plot_model(agent.model, to_file='keras_plot_model_2.png', show_shapes=True)
display.Image('keras_plot_model_2.png')

In [None]:
import pickle
from os import path

save_dir = "./saved_model"

In [None]:
agent.model.save(path.join(save_dir, "model.tf"))
agent.target_model.save(path.join(save_dir, "target_model.tf"))

In [None]:
agent.model = None
agent.target_model = None

In [None]:
# The `with` block closes the file automatically
with open(path.join(save_dir, "agent.pkl"), "wb") as f:
    pickle.dump(agent, f)

Load the model

In [None]:
with open(path.join(save_dir, "agent.pkl"), "rb") as f:
    agent = pickle.load(f)

agent.model = tf.keras.models.load_model(path.join(save_dir, "model.tf"))
agent.target_model = tf.keras.models.load_model(path.join(save_dir, "target_model.tf"))

## To-Dos

- Create one place for the hyperparameters used inside the model and pass them on
  - e.g. optimizer, different metrics, checkpointing for the internal model, etc.