In [3]:
# !pip install "moviepy" "gymnasium[atari]" "ale-py" "torch" "numpy" "tqdm" "matplotlib" --quiet

import os
import base64
import random
import time
from collections import deque
from dataclasses import dataclass
from typing import Tuple
import copy

import gymnasium as gym
from gymnasium.wrappers import RecordVideo, ResizeObservation, GrayscaleObservation, FrameStackObservation
import ale_py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from IPython.display import HTML
import matplotlib.pyplot as plt


In [None]:

class CNNDQN(nn.Module):
    """Convolutional Q-network: three conv layers followed by a two-layer MLP head.

    Args:
        input_shape: ``(channels, height, width)`` of one observation.
        num_actions: Number of Q-values (one per action) produced per observation.
    """

    def __init__(self, input_shape: Tuple[int, int, int], num_actions: int):
        super().__init__()
        channels, height, width = input_shape

        # Feature extractor. The layer order (and therefore the Sequential
        # indices used in state_dict keys) must stay as-is for checkpoints to load.
        feature_layers = [
            nn.Conv2d(channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*feature_layers)

        # Push a single all-zero observation through the conv stack to discover
        # how many features the first linear layer has to accept.
        with torch.no_grad():
            probe = self.conv(torch.zeros(1, channels, height, width))
        feature_dim = probe.numel()  # batch size is 1, so this is the per-sample size

        head_layers = [
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-values of shape ``(batch, num_actions)`` for a batch of observations.

        The input is divided by 255 before the convolutions, so callers are
        expected to pass raw pixel intensities.
        """
        features = self.conv(x / 255.0)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)


class ReplayBuffer:
    """Fixed-capacity circular store of transitions backed by preallocated NumPy arrays.

    States and next states are stored as ``uint8``; once the buffer is full the
    oldest transitions are overwritten.

    Args:
        capacity: Maximum number of transitions kept. Coerced with ``int()``.
        state_shape: Shape of a single stored state.

    Raises:
        ValueError: If ``capacity`` is not positive.
    """

    def __init__(self, capacity: int, state_shape: Tuple[int, ...]):
        self.capacity = int(capacity)
        if self.capacity <= 0:
            raise ValueError(f"capacity must be a positive integer, got {capacity!r}")
        self.state_shape = state_shape
        # Allocate with the coerced capacity so that e.g. ``capacity=1e5`` works
        # instead of failing inside np.zeros with a float dimension.
        self.states = np.zeros((self.capacity, *state_shape), dtype=np.uint8)
        self.next_states = np.zeros((self.capacity, *state_shape), dtype=np.uint8)
        self.actions = np.zeros((self.capacity,), dtype=np.int64)
        self.rewards = np.zeros((self.capacity,), dtype=np.float32)
        self.dones = np.zeros((self.capacity,), dtype=np.bool_)
        self._pos = 0        # index the next push will write to
        self._full = False   # becomes True once the write index has wrapped around

    def __len__(self) -> int:
        """Number of transitions currently available for sampling."""
        return self.capacity if self._full else self._pos

    def push(self, state: np.ndarray, action: int, reward: float, next_state: np.ndarray, done: bool) -> None:
        """Store one transition, overwriting the oldest entry when the buffer is full."""
        idx = self._pos
        self.states[idx] = state
        self.next_states[idx] = next_state
        self.actions[idx] = action
        self.rewards[idx] = reward
        self.dones[idx] = done
        self._pos = (self._pos + 1) % self.capacity
        if self._pos == 0:
            self._full = True

    def sample(self, batch_size: int):
        """Draw ``batch_size`` transitions uniformly at random (with replacement).

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` of NumPy arrays
            whose first dimension is ``batch_size``.

        Raises:
            ValueError: If fewer than ``batch_size`` transitions are stored.
        """
        if len(self) < batch_size:
            raise ValueError(f"Not enough elements in the buffer to sample: {len(self)} < {batch_size}")
        indices = np.random.randint(0, len(self), size=batch_size)
        return (
            self.states[indices],
            self.actions[indices],
            self.rewards[indices],
            self.next_states[indices],
            self.dones[indices],
        )


@dataclass
class EpsilonConfig:
    """Settings for the epsilon-greedy exploration schedule handed to ``DQNAgent``.

    ``train_dqn`` fills these from its ``epsilon_start``, ``epsilon_end`` and
    ``epsilon_decay_frames`` arguments. How the values are interpolated is
    decided by the agent's ``epsilon_by_frame`` schedule, not by this class.
    """

    # Exploration rate at the beginning of the schedule (presumably frame 0).
    start: float
    # Exploration rate once the decay is over (presumably the floor value).
    end: float
    # Number of frames over which epsilon goes from `start` to `end` — TODO confirm against the schedule.
    decay_frames: int


class DQNAgent:
    """Epsilon-greedy agent trained with a Double-DQN style update.

    The online network (``model``) chooses the greedy next action and a
    periodically synchronised copy (``target_model``) evaluates it.

    Args:
        model: Online Q-network mapping a batch of states to per-action Q-values.
        num_actions: Size of the discrete action space.
        gamma: Discount factor.
        learning_rate: Adam learning rate for the online network.
        epsilon_config: Start/end/decay settings of the exploration schedule.
        device: Device on which tensors are placed for inference and training.
        target_update_freq: Number of ``update`` calls between hard copies of the
            online weights into the target network.
    """

    def __init__(
        self,
        model: nn.Module,
        num_actions: int,
        gamma: float,
        learning_rate: float,
        epsilon_config: EpsilonConfig,
        device: torch.device,
        target_update_freq: int = 5000,
    ):
        self.model = model
        self.target_model = copy.deepcopy(model)
        # The target network is only ever evaluated under no_grad and refreshed
        # via load_state_dict, so it never needs gradients.
        for param in self.target_model.parameters():
            param.requires_grad_(False)
        self.num_actions = num_actions
        self.gamma = gamma
        self.device = device
        self.epsilon_config = epsilon_config
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
        )
        self.loss_fn = nn.SmoothL1Loss()
        self.target_update_freq = target_update_freq
        self.update_steps = 0

    def epsilon_by_frame(self, frame_idx: int) -> float:
        """Return the exploration rate for ``frame_idx`` using a linear decay.

        Epsilon goes linearly from ``epsilon_config.start`` at frame 0 to
        ``epsilon_config.end`` at ``epsilon_config.decay_frames`` and stays at
        ``end`` afterwards. A non-positive ``decay_frames`` yields ``end``
        immediately.
        """
        cfg = self.epsilon_config
        if cfg.decay_frames <= 0 or frame_idx >= cfg.decay_frames:
            return float(cfg.end)
        fraction = max(frame_idx, 0) / cfg.decay_frames
        return float(cfg.start + fraction * (cfg.end - cfg.start))

    def select_action(self, state: np.ndarray, frame_idx: int) -> int:
        """Pick an action epsilon-greedily for a single (unbatched) state.

        Args:
            state: One observation as a NumPy array; a batch dimension is added here.
            frame_idx: Current frame counter, used to look up epsilon.

        Returns:
            Index of the chosen action.
        """
        epsilon = self.epsilon_by_frame(frame_idx)
        if np.random.rand() < epsilon:
            return int(np.random.randint(0, self.num_actions))
        state_tensor = torch.from_numpy(state).unsqueeze(0).to(self.device).float()
        with torch.no_grad():
            q_values = self.model(state_tensor)
        return int(q_values.argmax(dim=1).item())

    def update(self, replay_buffer: ReplayBuffer, batch_size: int) -> float | None:
        """Run one gradient step on a batch sampled from ``replay_buffer``.

        Returns:
            The scalar loss, or ``None`` when the buffer holds fewer than
            ``batch_size`` transitions (no step is taken in that case).
        """
        if len(replay_buffer) < batch_size:
            return None

        batch_states, batch_actions, batch_rewards, batch_next_states, batch_dones = replay_buffer.sample(
            batch_size
        )

        states = torch.from_numpy(batch_states).to(self.device).float()
        actions = torch.from_numpy(batch_actions).long().to(self.device)
        rewards = torch.from_numpy(batch_rewards).to(self.device)
        next_states = torch.from_numpy(batch_next_states).to(self.device).float()
        dones = torch.from_numpy(batch_dones.astype(np.float32)).to(self.device)

        # Q(s, a) of the actions that were actually taken.
        q_values = self.model(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        with torch.no_grad():
            # Online network selects the next action, target network scores it.
            greedy_actions = self.model(next_states).argmax(dim=1)
            next_q_target = self.target_model(next_states)
            next_q_values = next_q_target.gather(1, greedy_actions.unsqueeze(1)).squeeze(1)
            # No bootstrapping from transitions flagged as done.
            targets = rewards + self.gamma * next_q_values * (1.0 - dones)

        loss = self.loss_fn(q_values, targets)

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=10.0)
        self.optimizer.step()

        self.update_steps += 1
        if self.update_steps % self.target_update_freq == 0:
            # Hard update: overwrite the target weights with the online ones.
            self.target_model.load_state_dict(self.model.state_dict())

        return float(loss.item())

In [None]:

def make_atari_env(env_id: str, num_stack: int = 4, render_mode: str | None = None) -> gym.Env:
    """Create an Atari environment wrapped for DQN-style pixel input.

    The base environment is created with ``repeat_action_probability=0.0`` and
    then wrapped, in this order, to resize observations to 84x84, convert them
    to grayscale and stack the last ``num_stack`` observations.

    Args:
        env_id: Gymnasium environment id passed to ``gym.make``.
        num_stack: Number of observations stacked by ``FrameStackObservation``.
        render_mode: Forwarded to ``gym.make`` unchanged.

    Returns:
        The fully wrapped environment.
    """
    base_env = gym.make(env_id, render_mode=render_mode, repeat_action_probability=0.0)
    resized_env = ResizeObservation(base_env, (84, 84))
    gray_env = GrayscaleObservation(resized_env)
    return FrameStackObservation(gray_env, num_stack)


def preprocess_obs(obs: np.ndarray, num_stack: int) -> np.ndarray:
    """Convert an observation to a ``uint8`` array, moving a trailing stack axis to the front.

    Args:
        obs: Observation as an array or any array-like object.
        num_stack: Number of stacked frames; used to recognise a channel-last layout.

    Returns:
        A new ``uint8`` array. A 3-D input whose last axis has length
        ``num_stack`` is transposed from ``(H, W, C)`` to ``(C, H, W)``; any
        other input keeps its axis order.
    """
    # np.asarray avoids a copy when possible but, unlike np.array(copy=False)
    # under NumPy >= 2, does not raise when a copy is unavoidable (lists, lazy frames).
    arr = np.asarray(obs)
    # NOTE(review): a channel-first stack whose last axis also equals num_stack
    # would be transposed too — fine for 84x84 frames with a small stack.
    if arr.ndim == 3 and arr.shape[-1] == num_stack:
        arr = np.transpose(arr, (2, 0, 1))
    return arr.astype(np.uint8)


def train_dqn(
    env_name: str = "ALE/Breakout-v5",
    num_stack: int = 4,
    total_frames: int = 5_000_000,
    batch_size: int = 32,
    gamma: float = 0.99,
    replay_buffer_size: int = 250_000,
    learning_rate: float = 5e-5,
    epsilon_start: float = 1.0,
    epsilon_end: float = 0.01,
    epsilon_decay_frames: int = 500_000,
    learning_starts: int = 50_000,
    train_freq: int = 4,
    print_freq: int = 50_000,
    save_path: str = "dqn_breakout_colab.pt",
):
    """Train a DQN agent on an Atari environment and save the final weights.

    Args:
        env_name: Gymnasium environment id.
        num_stack: Number of stacked observations fed to the network.
        total_frames: Number of agent steps to run (warm-up steps are not counted).
        batch_size: Mini-batch size for each gradient step.
        gamma: Discount factor.
        replay_buffer_size: Capacity of the replay buffer, in transitions.
        learning_rate: Adam learning rate.
        epsilon_start: Initial exploration rate.
        epsilon_end: Final exploration rate.
        epsilon_decay_frames: Frames over which epsilon is decayed.
        learning_starts: Number of steps collected before the first gradient step.
        train_freq: A gradient step is taken every ``train_freq`` agent steps.
        print_freq: A ``[Progress]`` line is printed every ``print_freq`` agent
            steps; ``0`` disables it.
        save_path: File the final ``state_dict`` is written to.

    Returns:
        Tuple ``(model, episode_rewards, episode_end_frames)`` where the two lists
        hold, per finished episode, the unclipped return and the step counter at
        which the episode ended.
    """
    env = make_atari_env(
        env_id=env_name,
        num_stack=num_stack,
        render_mode=None,
    )

    # try/finally so the emulator is released even if training is interrupted.
    try:
        # One reset up front just to learn the preprocessed observation shape.
        obs, info = env.reset()
        state_shape = preprocess_obs(obs, num_stack).shape

        replay_buffer = ReplayBuffer(
            capacity=replay_buffer_size,
            state_shape=state_shape,
        )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        num_actions = env.action_space.n
        model = CNNDQN(input_shape=state_shape, num_actions=num_actions).to(device)
        eps_cfg = EpsilonConfig(
            start=epsilon_start,
            end=epsilon_end,
            decay_frames=epsilon_decay_frames,
        )
        agent = DQNAgent(
            model=model,
            num_actions=num_actions,
            gamma=gamma,
            learning_rate=learning_rate,
            epsilon_config=eps_cfg,
            device=device,
        )

        frame_idx = 0
        episode_idx = 0
        episode_rewards = []
        episode_end_frames = []
        recent_rewards = deque(maxlen=10)
        start_time = time.time()

        while frame_idx < total_frames:
            episode_over = False
            episode_reward = 0.0
            obs, info = env.reset()
            obs_proc = preprocess_obs(obs, num_stack)

            # Random-length warm-up with action 1 so episodes start from varied
            # states; these steps are neither stored nor counted.
            for _ in range(random.randint(1, 30)):
                obs, _, terminated_warm, truncated_warm, info = env.step(1)
                if terminated_warm or truncated_warm:
                    obs, info = env.reset()
                obs_proc = preprocess_obs(obs, num_stack)

            while not episode_over and frame_idx < total_frames:
                current_lives = info.get("lives", 0)

                action = agent.select_action(obs_proc, frame_idx)
                next_obs, reward, terminated_step, truncated_step, info = env.step(action)
                episode_over = terminated_step or truncated_step
                next_obs_proc = preprocess_obs(next_obs, num_stack)

                next_lives = info.get("lives", current_lives)
                life_lost = next_lives < current_lives

                # Rewards are clipped to {-1, 0, +1} for learning only.
                clipped_reward = float(np.sign(reward))

                # A lost life or a true environment termination ends the return for
                # bootstrapping purposes. Truncation is deliberately not treated as
                # terminal. Including `terminated_step` keeps the flag correct for
                # environments that do not report a "lives" counter.
                replay_buffer.push(
                    state=obs_proc,
                    action=action,
                    reward=clipped_reward,
                    next_state=next_obs_proc,
                    done=bool(life_lost or terminated_step),
                )

                obs_proc = next_obs_proc
                episode_reward += reward
                frame_idx += 1

                if frame_idx >= learning_starts and frame_idx % train_freq == 0:
                    agent.update(replay_buffer, batch_size=batch_size)

                # Checked on every step: when this test only ran between episodes
                # it fired only if an episode happened to end exactly on a multiple
                # of print_freq.
                if print_freq > 0 and frame_idx % print_freq == 0:
                    avg_reward = sum(recent_rewards) / len(recent_rewards) if recent_rewards else 0.0
                    elapsed = time.time() - start_time
                    print(
                        f"[Progress] Frame: {frame_idx}/{total_frames} | "
                        f"AvgReward(10): {avg_reward:.2f} | "
                        f"Elapsed: {elapsed/60:.1f} min"
                    )

                if episode_over:
                    episode_idx += 1
                    episode_rewards.append(episode_reward)
                    episode_end_frames.append(frame_idx)
                    recent_rewards.append(episode_reward)
                    avg_reward = sum(recent_rewards) / len(recent_rewards)
                    elapsed = time.time() - start_time
                    print(
                        f"Frame: {frame_idx}/{total_frames} | "
                        f"Episode: {episode_idx} | "
                        f"Reward: {episode_reward:.2f} | "
                        f"AvgReward(10): {avg_reward:.2f} | "
                        f"Elapsed: {elapsed/60:.1f} min"
                    )
    finally:
        env.close()

    torch.save(model.state_dict(), save_path)
    print(f"Training finished. Model saved to {save_path}")

    return model, episode_rewards, episode_end_frames

In [None]:
# Run the full training, then plot each episode's return against the step
# counter at which that episode ended.
model, rewards, frames = train_dqn()

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(frames, rewards)
ax.set_xlabel("Frames")
ax.set_ylabel("Reward par épisode")
ax.set_title("Évolution de la reward par épisode en fonction des frames")
ax.grid(True)
plt.show()


In [21]:
def record_dqn_video(
    checkpoint_path: str = "dqn_best.pt",
    env_name: str = "ALE/Breakout-v5",
    num_stack: int = 4,
    video_folder: str = "videos",
    name_prefix: str = "dqn_breakout",
    num_episodes: int = 1,
    max_steps_per_episode: int = 5000,
    epsilon_eval: float = 0.05,
):
    """Play episodes with a saved DQN checkpoint, record them, and return the last video.

    Args:
        checkpoint_path: Path of a ``state_dict`` saved with ``torch.save``.
        env_name: Gymnasium environment id.
        num_stack: Number of stacked observations the network was trained with.
        video_folder: Directory the recordings are written to (created if missing).
        name_prefix: File-name prefix passed to ``RecordVideo``.
        num_episodes: Number of episodes to play; every episode is recorded.
        max_steps_per_episode: Upper bound on agent steps per episode.
        epsilon_eval: Probability of taking a uniformly random action at each step.

    Returns:
        An ``IPython.display.HTML`` object embedding the most recently written
        ``.mp4`` in ``video_folder``, or ``None`` if the folder holds no video.
    """
    video_folder = os.path.abspath(video_folder)
    os.makedirs(video_folder, exist_ok=True)

    env = make_atari_env(
        env_id=env_name,
        num_stack=num_stack,
        render_mode="rgb_array",
    )
    env = RecordVideo(
        env,
        video_folder=video_folder,
        name_prefix=name_prefix,
        episode_trigger=lambda ep_id: True,
    )

    # try/finally so the recorder flushes and the emulator closes on any error.
    try:
        # Derive the network input shape from the observation space rather than
        # from an extra env.reset(): with every episode being recorded, that
        # reset would produce a spurious, near-empty video.
        placeholder_obs = np.zeros(env.observation_space.shape, dtype=np.uint8)
        state_shape = preprocess_obs(placeholder_obs, num_stack).shape
        num_actions = env.action_space.n

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = CNNDQN(input_shape=state_shape, num_actions=num_actions).to(device)
        # weights_only=True restricts unpickling to tensors and plain containers,
        # which is all a state_dict needs.
        state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()

        # Action index used for the random warm-up (presumably FIRE in Breakout —
        # confirm with env.unwrapped.get_action_meanings()).
        fire_action = 1

        for ep in range(num_episodes):
            obs, info = env.reset()
            for _ in range(random.randint(1, 30)):
                obs, _, terminated_warm, truncated_warm, info = env.step(fire_action)
                if terminated_warm or truncated_warm:
                    obs, info = env.reset()
            obs_proc = preprocess_obs(obs, num_stack)

            done = False
            ep_reward = 0.0
            steps = 0

            while not done and steps < max_steps_per_episode:
                if np.random.rand() < epsilon_eval:
                    action = np.random.randint(0, num_actions)
                else:
                    state_tensor = torch.from_numpy(obs_proc).unsqueeze(0).to(device).float()
                    with torch.no_grad():
                        q_values = model(state_tensor)
                    action = int(q_values.argmax(dim=1).item())

                obs, reward, terminated, truncated, info = env.step(action)
                obs_proc = preprocess_obs(obs, num_stack)
                done = terminated or truncated
                ep_reward += reward
                steps += 1

            print(f"Episode {ep + 1}/{num_episodes} Reward: {ep_reward}, Steps: {steps}")
    finally:
        env.close()

    mp4_paths = [
        os.path.join(video_folder, file_name)
        for file_name in os.listdir(video_folder)
        if file_name.endswith(".mp4")
    ]
    if not mp4_paths:
        print("Aucune vidéo trouvée.")
        return None

    # Pick the newest file by modification time. A lexicographic sort would rank
    # "...-episode-9.mp4" after "...-episode-49.mp4" and return the wrong video.
    video_path = max(mp4_paths, key=os.path.getmtime)
    print("Vidéo enregistrée dans :", video_path)
    with open(video_path, "rb") as f:
        mp4 = f.read()
    data_url = "data:video/mp4;base64," + base64.b64encode(mp4).decode()
    video_html = HTML(f'<video width="480" height="360" controls><source src="{data_url}" type="video/mp4"></video>')

    return video_html

# The checkpoint file is renamed by hand before running this function;
# the best model so far is the one trained for 2M frames.
video_html = record_dqn_video(
    checkpoint_path="dqn_2M.pt",
    num_episodes=50,
    max_steps_per_episode=3000,
)

video_html

Episode 1/50 Reward: 289.0, Steps: 1872
Episode 2/50 Reward: 181.0, Steps: 1372
Episode 3/50 Reward: 72.0, Steps: 1247
Episode 4/50 Reward: 245.0, Steps: 1823
Episode 5/50 Reward: 68.0, Steps: 1198
Episode 6/50 Reward: 66.0, Steps: 947
Episode 7/50 Reward: 295.0, Steps: 1572
Episode 8/50 Reward: 92.0, Steps: 1297
Episode 9/50 Reward: 112.0, Steps: 1322
Episode 10/50 Reward: 117.0, Steps: 1580
Episode 11/50 Reward: 67.0, Steps: 1050
Episode 12/50 Reward: 89.0, Steps: 1285
Episode 13/50 Reward: 209.0, Steps: 1661
Episode 14/50 Reward: 101.0, Steps: 1398
Episode 15/50 Reward: 87.0, Steps: 1392
Episode 16/50 Reward: 29.0, Steps: 956
Episode 17/50 Reward: 100.0, Steps: 1184
Episode 18/50 Reward: 117.0, Steps: 1331
Episode 19/50 Reward: 135.0, Steps: 1589
Episode 20/50 Reward: 73.0, Steps: 1247
Episode 21/50 Reward: 91.0, Steps: 1481
Episode 22/50 Reward: 82.0, Steps: 1200
Episode 23/50 Reward: 79.0, Steps: 1464
Episode 24/50 Reward: 105.0, Steps: 1383
Episode 25/50 Reward: 100.0, Steps: 138