In [1]:
import tensorflow as tf
from tensorflow import keras
import gymnasium
from gymnasium.wrappers import AtariPreprocessing
from gymnasium.wrappers import FrameStackObservation, TimeLimit
from collections import deque
import ale_py
import matplotlib.pyplot as plt
import numpy as np


2024-11-12 16:10:08.568680: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-11-12 16:10:10.164476: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/nick/miniconda3/envs/tf_env/lib/python3.9/site-packages/nvidia/cudnn/lib:/home/nick/miniconda3/envs/tf_env/lib/
2024-11-12 16:10:10.164573: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: :/home/nick

In [2]:
# Base Breakout environment; render_mode="rgb_array" is requested so frames
# can be obtained as arrays.
env = gymnasium.make(
    "BreakoutNoFrameskip-v4",
    render_mode="rgb_array",
)

A.L.E: Arcade Learning Environment (version 0.10.1+unknown)
[Powered by Stella]


In [3]:
class AtariPreprocessingFire(AtariPreprocessing):
    """AtariPreprocessing that automatically sends action 1 to (re)start play.

    Action 1 is sent right after every reset and right after every step in
    which a life was lost (unless the episode ended), so the agent never has
    to learn to press it itself.
    NOTE(review): action 1 is presumably FIRE in Breakout's action set --
    confirm before reusing this wrapper with another game.

    The observation handed back to the caller is always the one produced
    *after* the automatic step, so the caller's state matches the emulator.
    """

    FIRE_ACTION = 1

    def reset(self, **kwargs):
        """Reset the environment, then send FIRE_ACTION once.

        Returns:
            (obs, info): ``obs`` is the observation produced by the automatic
            step; ``info`` is the info dict returned by the underlying reset.
        """
        _, reset_info = super().reset(**kwargs)
        obs, _, _, _, _ = super().step(self.FIRE_ACTION)
        return obs, reset_info

    def step(self, action):
        """Apply ``action``; if it cost a life, send FIRE_ACTION as well.

        Returns:
            The usual ``(obs, reward, terminated, truncated, info)`` tuple.
            When the automatic step ran, ``obs``, the flags and ``info`` come
            from it and its reward is added to the reward of ``action``.
        """
        lives_before_action = self.ale.lives()
        obs, reward, terminated, truncated, info = super().step(action)
        episode_over = terminated or truncated
        if not episode_over and self.ale.lives() < lives_before_action:
            # A life was lost but the episode goes on: relaunch immediately
            # and report the resulting frame instead of the stale one.
            obs, fire_reward, terminated, truncated, info = super().step(self.FIRE_ACTION)
            reward += fire_reward
        return obs, reward, terminated, truncated, info

In [4]:
# Wrap the base env with the auto-FIRE preprocessing, then stack the last
# 4 observations into a single observation.
env = FrameStackObservation(AtariPreprocessingFire(env), stack_size=4)



In [5]:
def plot_observation(obs):
    """Display a 4-frame observation (frames on axis 0) as one RGB image.

    Frames 0-2 become the three colour channels. The positive part of
    (frame 3 - mean of frames 0-2) is added to channels 0 and 2 so that the
    most recent frame stands out. Values are divided by 150 and clipped to
    [0, 1] before being passed to ``plt.imshow``. The input array is not
    modified (``astype`` works on a copy).
    """
    frames = obs.astype(np.float32)
    channels = frames[:3]
    newest = frames[3]

    # Computed before the channels are touched, so the mean uses raw frames.
    highlight = np.maximum(newest - channels.mean(axis=0), 0.)
    channels[[0, 2]] += highlight

    scaled = np.clip(channels / 150, 0, 1)
    plt.imshow(np.transpose(scaled, (1, 2, 0)))

In [6]:
# Number of environment steps collected per call to agent.collect_step.
update_period = 4

optimizer = keras.optimizers.RMSprop(
    learning_rate=2.5e-4,
    rho=0.95,
    momentum=0.0,
    epsilon=1e-5,
    centered=True,
)

# Exploration schedule: a learning-rate schedule object reused to decay
# epsilon from 1.0 down to 0.01 over 250000 // update_period steps.
epsilon_fn = keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=1.0,
    end_learning_rate=0.01,
    decay_steps=250_000 // update_period,
)

# Bounded buffer of transitions; the oldest entries are dropped when full.
replay_buffer = deque(maxlen=100_000)

In [7]:
# Q-network: maps a stacked observation to one Q-value per action.
# The input shape and the number of outputs are both read from `env`, so the
# network stays consistent with the environment instead of hard-coding them.
q_net = keras.models.Sequential([
    keras.layers.Input(shape=env.observation_space.shape),
    # Cast to float32 and scale by 1/255 inside the model, so observations
    # can be stored and passed around in their original dtype.
    keras.layers.Lambda(lambda obs: tf.cast(obs, np.float32) / 255.),
    keras.layers.Conv2D(32, (8,8), strides=4, activation="relu", data_format="channels_first"),
    keras.layers.Conv2D(64, (4,4), strides=2, activation="relu", data_format="channels_first"),
    keras.layers.Conv2D(64, (3,3), strides=1, activation="relu", data_format="channels_first"),
    keras.layers.Flatten(),
    keras.layers.Dense(512, activation="relu"),
    # One linear output per discrete action.
    keras.layers.Dense(int(env.action_space.n)),
])

# Target network: same architecture, initialised with a copy of q_net's weights.
target_net = keras.models.clone_model(q_net)
target_net.set_weights(q_net.get_weights())

In [8]:
def epsilon_greedy_policy(obs, action_space, epsilon):
    """Pick an action with an epsilon-greedy rule on the global ``q_net``.

    Args:
        obs: a single observation (no batch axis).
        action_space: object whose ``sample()`` returns a random action.
        epsilon: probability of taking a random action instead of the
            greedy one.

    Returns:
        The sampled action when exploring, otherwise the index (as an int)
        of the largest Q-value predicted for ``obs``.
    """
    if np.random.rand() < epsilon:
        return action_space.sample()

    # The model expects a batch: add a leading axis of size 1, otherwise the
    # first axis of the observation would be interpreted as the batch.
    # verbose=0 avoids printing a progress bar for every single action.
    q_values = q_net.predict(obs[np.newaxis], verbose=0)
    return int(np.argmax(q_values[0]))

In [9]:
# Bounded transition buffer (oldest entries are dropped once it is full).
# NOTE(review): the same buffer is already created in the hyperparameter cell;
# re-running this cell replaces it with an empty one.
replay_buffer = deque(maxlen=100_000)

class DqnAgent:
    """Experience collector and metric tracker for a DQN training run.

    The class relies on three notebook-level names: ``env`` (the wrapped
    environment), ``epsilon_fn`` (maps a training-step count to an exploration
    rate) and ``epsilon_greedy_policy`` (chooses an action).

    Attributes:
        n_iterations: number of collect/train iterations the run should do.
        q_network: the online Q-network.
        max_episode_steps: optional episode-length cap supplied by the caller;
            stored for reference only, it is not enforced by this class.
        n_train_step: number of completed ``train_step`` calls.
        episodes: number of episodes finished during ``collect_step``.
        environment_steps: number of environment steps taken by
            ``collect_step`` (warm-up steps are not counted).
    """

    def __init__(self, n_iterations, q_network, max_episode_steps=None, **kwargs):
        # **kwargs is kept so existing callers passing extra options keep working.
        self.n_train_step = 0
        self.n_iterations = n_iterations
        self.q_network = q_network
        self.max_episode_steps = max_episode_steps

        # Last observation seen; None until the environment has been reset.
        # Kept across calls so an episode is not cut short between collect steps.
        self._state = None

        # metrics
        self.episodes = 0
        self.environment_steps = 0

    def initialization(self, num_steps, replay_buffer):
        """Fill ``replay_buffer`` with ``num_steps`` random-action transitions.

        Each stored item is ``(state, action, reward, next_state, done)``.
        The environment is left where the warm-up stopped and the
        corresponding state is remembered, so ``collect_step`` continues the
        same episode.
        """
        state, _ = env.reset()
        for _ in range(num_steps):
            action = env.action_space.sample()
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            replay_buffer.append((state, action, reward, next_state, done))
            if done:
                state, _ = env.reset()
            else:
                state = next_state
        self._state = state

    def train_step(self):
        """Placeholder for the gradient update; currently does nothing."""
        pass

    def collect_step(self, replay_buffer, update_period=4):
        """Collect ``update_period`` transitions, then run one training step.

        Actions come from ``epsilon_greedy_policy`` with the exploration rate
        given by ``epsilon_fn`` at the current training-step count. The
        environment is only reset when an episode ends (or on the very first
        call), never at the start of a collect step.
        """
        if self._state is None:
            self._state, _ = env.reset()

        # The schedule is indexed by the number of training steps done so far.
        epsilon = float(epsilon_fn(self.n_train_step))

        for _ in range(update_period):
            action = epsilon_greedy_policy(self._state, env.action_space, epsilon)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            replay_buffer.append((self._state, action, reward, next_state, done))
            self.environment_steps += 1

            if done:
                self._state, _ = env.reset()
                self.episodes += 1
            else:
                self._state = next_state

        self.train_step()
        self.n_train_step += 1
            

In [10]:
# Create the agent, then pre-fill the replay buffer with 20000 transitions
# gathered by sampling random actions.
agent = DqnAgent(
    q_network=q_net,
    n_iterations=50000,
    max_episode_steps=27000,
)
agent.initialization(num_steps=20000, replay_buffer=replay_buffer)

In [11]:
# Main loop: each iteration collects `update_period` transitions and runs one
# training step. The iteration count is read from the agent, since there is no
# notebook-level `n_iterations` variable.
for iteration in range(agent.n_iterations):
    agent.collect_step(replay_buffer, update_period)

    if iteration % 1000 == 0:
        # TODO: AverageReturn and AverageEpisodeLength are not tracked by
        # DqnAgent yet, so they are reported as "n/a" for now.
        print(
            f"NumberOfEpisodes = {agent.episodes}\n"
            f"EnvironmentSteps = {agent.environment_steps}\n"
            "AverageReturn = n/a\n"
            "AverageEpisodeLength = n/a"
        )
    