In [None]:
!pip install keras-rl2
!pip install gym[atari]
!pip install atari-py
!pip install tensorflow==2.11.0

In [None]:
import os
import tensorflow as tf

In [None]:
from google.colab import drive

# Google Drive holds the Atari ROMs and is where model weights are persisted.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [None]:
# Import the Atari ROM files stored on the mounted Google Drive into atari-py,
# so the Atari environment used below can find the Breakout ROM.
# NOTE(review): `python` is whichever interpreter is first on the shell's PATH;
# on Colab that is presumably the kernel's interpreter — confirm if running elsewhere.
!python -m atari_py.import_roms /content/drive/MyDrive/dqn/roms/

In [None]:
import gym
from gym.envs.registration import register
from rl.callbacks import FileLogger, ModelIntervalCheckpoint

In [None]:
# (Re)define Breakout-v4 with frameskip=1 and a 10,000-step episode cap.
# gym may already hold a spec under this id (its own Atari registrations, or a
# previous run of this cell). Older gym versions raise "Cannot re-register id" in
# that case, so drop any existing spec first. This keeps the cell idempotent and
# guarantees gym.make('Breakout-v4') uses exactly these settings.
from gym.envs.registration import registry as _gym_registry

ENV_ID = 'Breakout-v4'
# Older gym keeps specs in `registry.env_specs`; newer gym's registry is the dict itself.
_env_specs = getattr(_gym_registry, 'env_specs', _gym_registry)
_env_specs.pop(ENV_ID, None)

register(
    id=ENV_ID,
    entry_point='gym.envs.atari:AtariEnv',
    kwargs={'game': 'breakout', 'obs_type': 'image', 'frameskip': 1},
    max_episode_steps=10000,
    nondeterministic=False,
)

In [None]:
#!/usr/bin/env python3
"""
Imports
"""
import tensorflow as tf
import gym
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Flatten
from keras.optimizers import Adam
from rl.agents import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import EpsGreedyQPolicy
from keras.callbacks import ModelCheckpoint

# Reproducibility: a single seed shared by every RNG this cell seeds.
SEED = 123
# Number of consecutive observations stored per agent input (SequentialMemory's
# window_length); the model's input shape below must start with this same value.
WINDOW_LENGTH = 1

# Breakout Environment
env = gym.make('Breakout-v4')
np.random.seed(SEED)
# Also seed TensorFlow so the Dense-layer weight initialisation below is repeatable.
# NOTE(review): the environment's own RNG is not seeded in this cell.
tf.random.set_seed(SEED)
nb_actions = env.action_space.n

# NN Model: an MLP over the raw, flattened observation.
# NOTE(review): pixels are fed unscaled and full-size; a keras-rl Processor
# (grayscale / downsample / normalise) would presumably help learning — TODO.
model = Sequential()
model.add(Flatten(input_shape=(WINDOW_LENGTH,) + env.observation_space.shape))
model.add(Dense(64, activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(nb_actions, activation='linear'))

# Print Model Summary For Visuals. summary() prints by itself and returns None,
# so wrapping it in print() would only emit a stray "None".
model.summary()

# Compile
# NOTE(review): limit=1000000 is far above the nb_steps=50000 used for training
# later; the replay buffer grows with every stored observation, so watch RAM if
# training runs longer.
memory = SequentialMemory(limit=1000000, window_length=WINDOW_LENGTH)
policy = EpsGreedyQPolicy()
dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=1000,
               target_model_update=1e-2, policy=policy)
# TF 2.11 made a new optimizer implementation the tf.keras default. It only warns
# about the deprecated `lr` keyword (use `learning_rate`). It also drops the legacy
# `get_updates` API, which keras-rl2 presumably uses to add soft target-model updates
# (target_model_update < 1). The legacy Adam keeps that API.
# NOTE(review): verify against your keras-rl2 version.
dqn.compile(tf.keras.optimizers.legacy.Adam(learning_rate=1e-3), metrics=['mae'])

In [None]:
#!/usr/bin/env python3
"""
Imports
"""
from keras.callbacks import Callback


# Optional override for the weights file loaded by the evaluation cell
# (None / empty means "use the default checkpoint path").
weights = None

class CustomModelCheckpoint(Callback):
    """
    Custom Model Checkpoint: callbacks

    Created to work around the error raised on this line:

    dqn.fit(env, callbacks=callbacks_list, nb_steps=50000, visualize=False, verbose=2)

    AttributeError: 'DQNAgent' object has no attribute 'distribute_strategy'

    Saves the model's weights via ``self.model.save_weights`` without using
    ``distribute_strategy``. Under keras-rl2, ``self.model`` is presumably the agent
    passed to ``fit``. ``on_epoch_end`` presumably fires once per episode, because
    keras-rl forwards ``on_episode_end`` to plain Keras callbacks — TODO confirm for
    your keras-rl2 version.

    Args:
        filepath: Destination of the weights file; overwritten on every save.
        monitor: Key looked up in the ``logs`` dict, e.g. ``'episode_reward'``.
        save_best_only: If True, save only when ``monitor`` improves on the best
            value seen so far. If False, save at the end of every epoch.
        mode: ``'max'`` (higher is better), ``'min'`` (lower is better) or
            ``'auto'`` (``'min'`` when the monitored name contains ``'loss'``,
            otherwise ``'max'``).

    Raises:
        ValueError: If ``mode`` is not one of ``'auto'``, ``'min'``, ``'max'``.
    """
    def __init__(self, filepath, monitor='val_loss', save_best_only=False, mode='auto'):
        super(CustomModelCheckpoint, self).__init__()
        if mode not in ('auto', 'min', 'max'):
            raise ValueError(f"mode must be 'auto', 'min' or 'max', got {mode!r}")
        if mode == 'auto':
            # Losses should shrink; rewards/scores (e.g. 'episode_reward') should grow.
            mode = 'min' if 'loss' in monitor else 'max'
        self.filepath = filepath
        self.monitor = monitor
        self.save_best_only = save_best_only
        self.mode = mode
        self.best_value = None

    def _is_improvement(self, current_value):
        """Return True if ``current_value`` beats ``self.best_value`` under ``self.mode``."""
        if self.best_value is None:
            return True
        if self.mode == 'min':
            return current_value < self.best_value
        return current_value > self.best_value

    def on_epoch_end(self, epoch, logs=None):
        """Track the best ``logs[self.monitor]`` and save weights per ``save_best_only``."""
        # The hook's default is logs=None; treat that as "no metrics reported".
        logs = logs or {}
        current_value = logs.get(self.monitor)
        improved = current_value is not None and self._is_improvement(current_value)
        if improved:
            self.best_value = current_value
        if self.save_best_only and not improved:
            return
        self.model.save_weights(self.filepath, overwrite=True)
        print(f"Saved model weights at {self.filepath} - {self.monitor}: {current_value}")

# Usage:
# Persist every weights file on Google Drive. The original relative path
# 'policy_final.h5' resolved against the notebook's working directory, which on
# Colab is the VM's local disk and is lost when the runtime is recycled.
CHECKPOINT_PATH = '/content/drive/MyDrive/policy.h5'
FINAL_WEIGHTS_PATH = '/content/drive/MyDrive/policy_final.h5'

custom_checkpoint = CustomModelCheckpoint(filepath=CHECKPOINT_PATH, monitor='episode_reward', save_best_only=True)
callbacks_list = [custom_checkpoint]

# fit() returns a History of the per-episode logs; keep it for later inspection.
train_history = dqn.fit(env, callbacks=callbacks_list, nb_steps=50000, visualize=False, verbose=2)

# Save the final policy network
dqn.save_weights(FINAL_WEIGHTS_PATH, overwrite=True)

In [None]:
# Evaluate a saved policy. Set `weights` (defined next to the checkpoint callback)
# to a path to override the default best-episode checkpoint on Drive.
weights_filename = weights or '/content/drive/MyDrive/policy.h5'
dqn.load_weights(weights_filename)

# visualize=True makes keras-rl render the env in 'human' mode, which needs a display
# window; a headless Colab VM has none. Enable it only on a machine with a display.
VISUALIZE_TEST = False
dqn.test(env, nb_episodes=10, visualize=VISUALIZE_TEST)