In [None]:
!pip install gymnasium[atari]
!pip install gymnasium[accept-rom-license]
!pip install ale-py
!pip install shimmy
!pip install tensorflow

Collecting shimmy
  Downloading Shimmy-2.0.0-py3-none-any.whl.metadata (3.5 kB)
Downloading Shimmy-2.0.0-py3-none-any.whl (30 kB)
Installing collected packages: shimmy
Successfully installed shimmy-2.0.0


In [None]:
# 1_DQN_Baseline.py
# Standard Deep Q-Network (Nature 2015) optimized for Demo Run

import os
import argparse
import time
import random
import numpy as np
import tensorflow as tf
import gymnasium as gym
from collections import deque
import cv2
import ale_py

gym.register_envs(ale_py)

# --- DEMO SETTINGS ---
# Environment and RNG seed
env_id = 'PongNoFrameskip-v4'
seed = 42

# Replay memory
buffer_size = 50_000        # max transitions kept by ReplayBuffer
batch_size = 32             # transitions drawn per gradient update
warm_start = 500            # env steps collected before the first update

# Training loop
number_timesteps = 2_000    # total env steps (short demo run)
train_freq = 4              # one gradient update every `train_freq` env steps
target_q_update_freq = 200  # gradient updates between target-network syncs
reward_gamma = 0.99         # discount factor in the TD target

# Optimizer (Adam): learning rate and gradient-norm clip
lr = 1e-4
clipnorm = 10.0

# Linear epsilon-greedy schedule consumed by epsilon()
epsilon_start = 1.0
epsilon_end = 0.01
epsilon_decay_steps = 2_000

class FireResetEnv(gym.Wrapper):
    """Press FIRE (action 1) and then action 2 after every reset.

    The wrapped env must expose 'FIRE' as action 1 and have at least three
    actions. If either warm-up press ends the episode, the env is reset again
    and the observation from that fresh reset is returned. It would be wrong
    to hand the caller the terminal frame of the finished episode.
    """

    def __init__(self, env):
        super().__init__(env)
        meanings = env.unwrapped.get_action_meanings()
        assert meanings[1] == 'FIRE'
        assert len(meanings) >= 3

    def reset(self, **kwargs):
        """Reset, play actions 1 then 2, and return ``(obs, {})``."""
        self.env.reset(**kwargs)
        for action in (1, 2):
            obs, _, terminated, truncated, _ = self.env.step(action)
            if terminated or truncated:
                # The step's frame belongs to a finished episode; keep the
                # frame from the recovery reset instead.
                obs, _ = self.env.reset(**kwargs)
        return obs, {}

class ProcessFrame84(gym.ObservationWrapper):
    """Turn raw RGB Atari frames into 84x84x1 uint8 grayscale images."""

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, obs):
        return ProcessFrame84.process(obs)

    @staticmethod
    def process(frame):
        """Grayscale and area-resize one frame to shape (84, 84, 1), dtype uint8.

        Only frames whose element count matches 210x160x3 or 250x160x3 are
        converted; any other frame is returned unchanged.
        """
        rows_by_size = {210 * 160 * 3: 210, 250 * 160 * 3: 250}
        rows = rows_by_size.get(frame.size)
        if rows is None:
            # NOTE(review): this pass-through does not match the declared
            # observation_space.
            return frame
        rgb = np.reshape(frame, [rows, 160, 3]).astype(np.float32)
        # ITU-R BT.601 luma weights.
        gray = rgb[:, :, 0] * 0.299 + rgb[:, :, 1] * 0.587 + rgb[:, :, 2] * 0.114
        small = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)
        return np.reshape(small, [84, 84, 1]).astype(np.uint8)

class FrameStack(gym.Wrapper):
    """Expose the last ``k`` observations concatenated along the channel axis."""

    def __init__(self, env, k):
        super().__init__(env)
        self.k = k
        self.frames = deque([], maxlen=k)
        base = env.observation_space
        stacked_shape = (base.shape[0], base.shape[1], base.shape[2] * k)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=stacked_shape, dtype=base.dtype)

    def reset(self, **kwargs):
        first_ob, info = self.env.reset(**kwargs)
        # Fill the whole history with the first frame so the stack is full.
        self.frames.extend([first_ob] * self.k)
        return self._get_ob(), info

    def step(self, action):
        ob, reward, terminated, truncated, info = self.env.step(action)
        self.frames.append(ob)  # the oldest frame falls off (maxlen=k)
        return self._get_ob(), reward, terminated, truncated, info

    def _get_ob(self):
        assert len(self.frames) == self.k
        return np.concatenate(tuple(self.frames), axis=2)

def build_env(env_id, seed=0):
    """Create ``env_id`` and wrap it in the preprocessing stack used here.

    Wrapping order, innermost first: episode statistics, FIRE-on-reset,
    84x84 grayscale, and a 4-frame stack. Only the action space is seeded
    here.

    NOTE(review): no frame-skip/max-pool, no-op reset or reward-clipping
    wrapper is applied. Confirm this departure from the usual Nature-DQN
    preprocessing is intended for the demo.
    """
    env = gym.make(env_id, render_mode='rgb_array')
    pipeline = (
        gym.wrappers.RecordEpisodeStatistics,
        FireResetEnv,
        ProcessFrame84,
        lambda inner: FrameStack(inner, 4),
    )
    for wrap in pipeline:
        env = wrap(env)
    env.action_space.seed(seed)
    return env

class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions for uniform experience replay.

    Once ``size`` transitions are held, each new ``add`` evicts the oldest.
    """

    def __init__(self, size):
        # deque(maxlen=...) drops from the left automatically when full.
        self.buffer = deque(maxlen=size)

    def add(self, obs, act, rew, next_obs, done):
        """Append one transition; both observations are copied as uint8 arrays."""
        transition = (
            np.array(obs, dtype=np.uint8),
            act,
            rew,
            np.array(next_obs, dtype=np.uint8),
            done,
        )
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns ``(obs, act, rew, next_obs, done)`` as stacked numpy arrays;
        ``rew`` and ``done`` are float32.
        """
        picked = random.sample(self.buffer, batch_size)
        columns = list(zip(*picked))
        obs_batch = np.array(columns[0])
        act_batch = np.array(columns[1])
        rew_batch = np.array(columns[2], dtype=np.float32)
        next_batch = np.array(columns[3])
        done_batch = np.array(columns[4], dtype=np.float32)
        return (obs_batch, act_batch, rew_batch, next_batch, done_batch)

def sync(model, target_model):
    """Hard-copy every weight of ``model`` into ``target_model``."""
    weights = model.get_weights()
    target_model.set_weights(weights)

def huber_loss(x):
    """Huber loss of the residuals ``x`` measured against zero.

    Uses ``tf.keras.losses.Huber`` with its default delta and reduction.
    """
    criterion = tf.keras.losses.Huber()
    return criterion(tf.zeros_like(x), x)

def epsilon(step):
    """Exploration rate at ``step``.

    Falls linearly from ``epsilon_start`` to ``epsilon_end`` over
    ``epsilon_decay_steps`` and stays at ``epsilon_end`` afterwards.
    """
    if step <= epsilon_decay_steps:
        progress = step / epsilon_decay_steps
        return epsilon_start - (epsilon_start - epsilon_end) * progress
    return epsilon_end

# Seed the Python, NumPy and TensorFlow global RNGs with the same value.
for _seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    _seed_fn(seed)
del _seed_fn

# Network Architecture
class QFunc(tf.keras.Model):
    """Convolutional Q-network mapping stacked pixel frames to one value per action.

    Layout: three ReLU convs (32@8x8/4, 64@4x4/2, 64@3x3/1), then a 512-unit
    ReLU dense layer and a linear output head of size ``action_dim``.
    """

    def __init__(self, name, action_dim):
        super().__init__(name=name)
        conv = tf.keras.layers.Conv2D
        self.conv1 = conv(32, (8, 8), strides=(4, 4), padding='valid', activation='relu')
        self.conv2 = conv(64, (4, 4), strides=(2, 2), padding='valid', activation='relu')
        self.conv3 = conv(64, (3, 3), strides=(1, 1), padding='valid', activation='relu')
        self.flat = tf.keras.layers.Flatten()
        self.fc1 = tf.keras.layers.Dense(512, activation='relu')
        self.fc2 = tf.keras.layers.Dense(action_dim, activation='linear')

    @tf.function
    def call(self, pixels, **kwargs):
        # Scale raw 0..255 pixel values into [0, 1].
        x = tf.cast(pixels, tf.float32) / 255.0
        # A rank-4 batch whose axis 1 has size 4 is treated as channels-first
        # and moved to channels-last.
        looks_channels_first = len(x.shape) == 4 and x.shape[1] == 4
        if looks_channels_first:
            x = tf.transpose(x, perm=[0, 2, 3, 1])
        for conv_layer in (self.conv1, self.conv2, self.conv3):
            x = conv_layer(x)
        hidden = self.fc1(self.flat(x))
        return self.fc2(hidden)

class DQN(object):
    """Vanilla DQN agent with an online Q-network and a periodically synced target.

    Attributes:
        niter: number of gradient updates performed. It drives target-network
            syncs every ``target_q_update_freq`` updates.
        nsteps: number of ``get_action`` calls, i.e. environment steps. It
            drives the epsilon-greedy schedule so exploration decays per env
            step, as the training log reports. Driving it by ``niter`` kept
            epsilon near 0.8 for the whole demo.
    """

    def __init__(self, action_dim):
        self.action_dim = action_dim
        self.qnet = QFunc('q', action_dim)
        self.targetqnet = QFunc('targetq', action_dim)
        # One dummy forward pass builds both models' weights so they can be copied.
        dummy_obs = tf.zeros((1, 84, 84, 4))
        self.qnet(dummy_obs)
        self.targetqnet(dummy_obs)
        sync(self.qnet, self.targetqnet)
        self.niter = 0
        self.nsteps = 0
        self.optimizer = tf.optimizers.Adam(learning_rate=lr, epsilon=1e-5, clipnorm=clipnorm)

    def get_action(self, obv):
        """Return an epsilon-greedy action (a Python int) for one observation."""
        self.nsteps += 1
        if random.random() < epsilon(self.nsteps):
            return int(random.random() * self.action_dim)
        batch = np.expand_dims(obv, 0).astype('float32')
        return int(self._qvalues_func(batch).numpy().argmax(1)[0])

    @tf.function
    def _qvalues_func(self, obv): return self.qnet(obv)

    def train(self, b_o, b_a, b_r, b_o_, b_d):
        """Run one gradient step on a replay batch; sync the target periodically."""
        self._train_func(b_o, b_a, b_r, b_o_, b_d)
        self.niter += 1
        if self.niter % target_q_update_freq == 0:
            sync(self.qnet, self.targetqnet)

    @tf.function
    def _train_func(self, b_o, b_a, b_r, b_o_, b_d):
        # Vanilla DQN target: max over the target network directly. It does not
        # depend on qnet's weights, so it is built outside the gradient tape.
        b_q_next = tf.reduce_max(self.targetqnet(b_o_), axis=1)
        target_q = b_r + (1 - b_d) * reward_gamma * b_q_next

        with tf.GradientTape() as tape:
            b_q = tf.reduce_sum(self.qnet(b_o) * tf.one_hot(b_a, self.action_dim), 1)
            loss = tf.reduce_mean(huber_loss(target_q - b_q))

        grad = tape.gradient(loss, self.qnet.trainable_weights)
        self.optimizer.apply_gradients(zip(grad, self.qnet.trainable_weights))
        return loss

if __name__ == '__main__':
    print(f"Creating environment {env_id}...")
    try:
        env = build_env(env_id, seed=seed)
    except Exception as err:  # top-level boundary: report the cause, then fall back
        print(f"Could not create {env_id} ({err!r}); falling back to Pong-v4.")
        env = build_env('Pong-v4', seed=seed)

    dqn = DQN(env.action_space.n)
    buffer = ReplayBuffer(buffer_size)
    # build_env only seeds the action space; seed the env itself once, on the
    # first reset, so runs are reproducible.
    o, _ = env.reset(seed=seed)
    start_time = time.time()

    print("Starting DQN Baseline Training...")
    try:
        for i in range(1, number_timesteps + 1):
            a = dqn.get_action(o)
            o_, r, terminated, truncated, info = env.step(a)
            # Only a true terminal state stops bootstrapping in the TD target;
            # a time-limit truncation still has a valid successor value.
            buffer.add(o, a, r, o_, terminated)

            if i % 100 == 0:
                print(f"Step: {i} / {number_timesteps} - Epsilon: {epsilon(i):.3f}")

            if i >= warm_start and i % train_freq == 0:
                transitions = buffer.sample(batch_size)
                dqn.train(*transitions)

            # Either way the episode is over and the env must be reset.
            if terminated or truncated:
                o, _ = env.reset()
            else:
                o = o_
    finally:
        env.close()

    print("DQN Baseline Finished.")
    print(f"Elapsed: {time.time() - start_time:.1f}s")

Creating environment PongNoFrameskip-v4...
Starting DQN Baseline Training...
Step: 100 / 2000 - Epsilon: 0.951
Step: 200 / 2000 - Epsilon: 0.901
Step: 300 / 2000 - Epsilon: 0.852
Step: 400 / 2000 - Epsilon: 0.802
Step: 500 / 2000 - Epsilon: 0.752
Step: 600 / 2000 - Epsilon: 0.703
Step: 700 / 2000 - Epsilon: 0.653
Step: 800 / 2000 - Epsilon: 0.604
Step: 900 / 2000 - Epsilon: 0.554
Step: 1000 / 2000 - Epsilon: 0.505
Step: 1100 / 2000 - Epsilon: 0.456
Step: 1200 / 2000 - Epsilon: 0.406
Step: 1300 / 2000 - Epsilon: 0.357
Step: 1400 / 2000 - Epsilon: 0.307
Step: 1500 / 2000 - Epsilon: 0.258
Step: 1600 / 2000 - Epsilon: 0.208
Step: 1700 / 2000 - Epsilon: 0.158
Step: 1800 / 2000 - Epsilon: 0.109
Step: 1900 / 2000 - Epsilon: 0.059
Step: 2000 / 2000 - Epsilon: 0.010
DQN Baseline Finished.
