<a href="https://colab.research.google.com/github/ayannaavalos/166/blob/main/DQN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q "gymnasium[atari]" autorom stable-baselines3 torch tensorboard imageio imageio-ffmpeg
!AutoROM --accept-license

AutoROM will download the Atari 2600 ROMs.
They will be installed to:
	/usr/local/lib/python3.12/dist-packages/AutoROM/roms

Existing ROMs will be overwritten.


In [None]:
import os, time, glob, collections, typing as tt
from dataclasses import dataclass
from datetime import datetime

import numpy as np
import gymnasium as gym
import ale_py

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.tensorboard.writer import SummaryWriter
from torch.utils.data import Dataset, DataLoader

from gymnasium import spaces
from stable_baselines3.common import atari_wrappers
from IPython.display import Video, display

In [None]:
# Base Configuration
DEFAULT_ENV_NAME = "ALE/SpaceInvaders-v5"
# Training stops once the mean of the last 100 episode rewards exceeds this value.
# NOTE(review): with CLIP_REWARD = True the episode rewards are sums of clipped
# per-step rewards, so 10_000 is probably never reached and training runs until
# interrupted — confirm this is intended.
MEAN_REWARD_BOUND = 10_000

GAMMA = 0.99                 # discount factor in the Bellman target
BATCH_SIZE = 32              # transitions sampled from the replay buffer per update
REPLAY_SIZE = 50_000         # replay buffer capacity (oldest transitions are evicted)
LEARNING_RATE = 1e-4         # Adam learning rate
SYNC_TARGET_FRAMES = 1_000   # copy online weights into the target net every N frames
REPLAY_START_SIZE = 2_000    # no gradient updates until the buffer holds this many transitions

SAVE_EPSILON = 0.5  # Only save if the 100-episode mean beats the best so far by more than this
# Epsilon-greedy exploration schedule, applied per frame as
#   epsilon = max(EPSILON_FINAL, EPSILON_START - frame / EPSILON_DECAY_LAST_FRAME)
EPSILON_DECAY_LAST_FRAME = 50_000
EPSILON_START = 1.0
EPSILON_FINAL = 0.05


# Passed as `clip_reward` to SB3's AtariWrapper by the environment factories.
CLIP_REWARD = True

# Number of episodes in the "early" clip (recorded with a fully random policy).
EARLY_EPISODES_TO_RECORD = 1

# The "late" clip is recorded once the 100-episode mean reward reaches LATE_THRESHOLD
# or once LATE_AFTER_FRAMES frames have been played, whichever comes first.
# NOTE(review): -19.0 reads like a Pong-style threshold; the training cell may
# reassign it — check the value actually in effect before relying on it.
LATE_THRESHOLD = -19.0
LATE_AFTER_FRAMES = 12_000
LATE_EPISODES_TO_RECORD = 1  # number of episodes in the "late" clip

In [None]:
class ImageToPyTorch(gym.ObservationWrapper):
    """Reorder image observations from channels-last (H, W, C) to channels-first (C, H, W).

    PyTorch's Conv2d layers expect channels-first input. The advertised observation
    space is rebuilt with the permuted shape and scalar low/high bounds taken from
    the wrapped space.
    """

    def __init__(self, env):
        super().__init__(env)
        src_space = self.observation_space
        assert isinstance(src_space, gym.spaces.Box) and len(src_space.shape) == 3
        height, width, channels = src_space.shape
        self.observation_space = gym.spaces.Box(
            low=src_space.low.min(),
            high=src_space.high.max(),
            shape=(channels, height, width),
            dtype=src_space.dtype,
        )

    def observation(self, observation):
        # Move the channel axis (index 2) to the front; returns a view, no copy.
        return np.moveaxis(observation, 2, 0)


class BufferWrapper(gym.ObservationWrapper):
    """Frame stacker: each observation is the concatenation, along axis 0, of the
    `n_steps` most recent frames (oldest first).

    When applied after ImageToPyTorch the frames are channels-first, so axis 0 is the
    channel axis and the stack shows up as extra input channels for the network.
    """

    def __init__(self, env, n_steps):
        super().__init__(env)
        inner_space = env.observation_space
        assert isinstance(inner_space, spaces.Box)
        stacked_low = inner_space.low.repeat(n_steps, axis=0)
        stacked_high = inner_space.high.repeat(n_steps, axis=0)
        self.observation_space = gym.spaces.Box(stacked_low, stacked_high, dtype=inner_space.dtype)
        self.buffer = collections.deque(maxlen=n_steps)

    def reset(self, *, seed: tt.Optional[int] = None, options: tt.Optional[dict] = None):
        # Pre-fill with zero frames (this also flushes the previous episode's frames,
        # since the deque is bounded) so the first stacked observation has full depth.
        template = self.env.observation_space.low
        self.buffer.extend(np.zeros_like(template) for _ in range(self.buffer.maxlen))
        first_obs, info = self.env.reset(seed=seed, options=options)
        return self.observation(first_obs), info

    def observation(self, observation: np.ndarray) -> np.ndarray:
        self.buffer.append(observation)
        return np.concatenate(list(self.buffer))


def make_env(env_name: str, n_steps=4, render_mode=None, clip_reward: bool | None = None, **kwargs):
    """Create an Atari env with the DQN preprocessing stack used throughout this notebook.

    Pipeline: gym.make -> SB3 AtariWrapper (noop reset, frame skipping, resize/grayscale,
    optional reward clipping) -> ImageToPyTorch (channels-first) -> BufferWrapper
    (stack of the last `n_steps` frames).

    Args:
        env_name: Gymnasium id, e.g. "ALE/SpaceInvaders-v5".
        n_steps: number of frames to stack.
        render_mode: forwarded to gym.make (use "rgb_array" for video recording).
        clip_reward: forwarded to AtariWrapper. None (default) keeps the previous
            behaviour of reading the module-level CLIP_REWARD flag, falling back to False.
        **kwargs: extra gym.make arguments. `frameskip` defaults to 1 (see below).

    Returns:
        The wrapped environment.
    """
    print(f"Creating environment {env_name}")
    if clip_reward is None:
        clip_reward = globals().get("CLIP_REWARD", False)
    # AtariWrapper already repeats each action for several frames and max-pools the
    # last two. "ALE/*-v5" ids additionally skip 4 frames inside the emulator by
    # default, which would compound the skip (4 x 4 = 16 frames per agent step) and
    # make the max-pooling compare frames far apart. SB3 expects a no-frameskip env,
    # so disable the built-in skip unless the caller explicitly asks for one.
    kwargs.setdefault("frameskip", 1)
    env = gym.make(env_name, render_mode=render_mode, **kwargs)
    env = atari_wrappers.AtariWrapper(
        env,
        clip_reward=bool(clip_reward),
        noop_max=30
    )
    env = ImageToPyTorch(env)
    env = BufferWrapper(env, n_steps=n_steps)
    return env

In [None]:
class DQN(nn.Module):
    """Convolutional Q-network: three conv layers followed by a two-layer MLP head.

    Args:
        input_shape: (channels, height, width) of a single observation.
        n_actions: number of discrete actions, i.e. the size of the output layer.

    The forward pass takes raw pixel tensors (0..255), rescales them to [0, 1] and
    returns one Q-value per action for every item in the batch.
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()
        in_channels = input_shape[0]
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1), nn.ReLU(),
            nn.Flatten(),
        )
        # Push a dummy observation through the conv stack to learn its flattened width.
        probe = torch.zeros(1, *input_shape)
        flat_features = self.conv(probe).shape[-1]
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 512), nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def forward(self, x: torch.ByteTensor):
        scaled = x.float().div(255.0)
        features = self.conv(scaled)
        return self.fc(features)


In [None]:
State = np.ndarray
Action = int

@dataclass
class Experience:
    state: State
    action: Action
    reward: float
    done_trunc: bool
    new_state: State


class ExperienceBuffer:
    """Bounded FIFO replay memory: once `capacity` is reached, appending evicts the oldest transition."""

    def __init__(self, capacity: int):
        self.buffer = collections.deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)

    def append(self, experience: Experience):
        self.buffer.append(experience)

    def sample(self, batch_size: int) -> tt.List[Experience]:
        """Return `batch_size` distinct transitions chosen uniformly at random.

        Draws from the global NumPy RNG; NumPy raises ValueError when
        batch_size exceeds the number of stored transitions.
        """
        picked = np.random.choice(len(self), batch_size, replace=False)
        memory = self.buffer
        return [memory[idx] for idx in picked]


class Agent:
    """Steps an environment with an epsilon-greedy policy and records every transition.

    The agent keeps the current observation and the running reward of the ongoing
    episode between calls, so `play_step` can be called once per training frame.
    """

    def __init__(self, env: gym.Env, exp_buffer: ExperienceBuffer):
        self.env = env
        self.exp_buffer = exp_buffer
        self.state, _ = self.env.reset()
        self.total_reward = 0.0

    @torch.no_grad()
    def play_step(self, net: DQN, device: torch.device, epsilon: float = 0.0) -> tt.Optional[float]:
        """Take one environment step and store the transition in the replay buffer.

        Args:
            net: Q-network used to pick the greedy action.
            device: device the observation tensor is moved to before calling `net`.
            epsilon: probability of taking a uniformly random action instead.

        Returns:
            The accumulated reward of the episode if this step ended it
            (terminated or truncated), otherwise None.
        """
        if np.random.random() < epsilon:
            action = self.env.action_space.sample()
        else:
            state_v = torch.as_tensor(self.state, device=device).unsqueeze(0)
            q_vals_v = net(state_v)
            _, act_v = torch.max(q_vals_v, dim=1)
            action = int(act_v.item())

        new_state, reward, is_done, is_tr, _ = self.env.step(action)
        self.total_reward += reward

        # Only a real termination should stop bootstrapping: the loss zeroes the
        # next-state value wherever this flag is set. A truncation (time limit) is an
        # external cut-off, so the value of new_state is still meaningful and must be
        # bootstrapped. The field keeps its historical name `done_trunc` so the rest
        # of the pipeline is unchanged.
        self.exp_buffer.append(Experience(
            state=self.state, action=action, reward=float(reward),
            done_trunc=bool(is_done), new_state=new_state
        ))
        self.state = new_state

        # Either way the episode is over for the environment, so reset it.
        if is_done or is_tr:
            done_reward = self.total_reward
            self.state, _ = self.env.reset()
            self.total_reward = 0.0
            return done_reward
        return None

In [None]:
def batch_to_tensors(batch: tt.List[Experience], device: torch.device):
    """Collate a list of transitions into five tensors on `device`.

    Returns:
        (states, actions, rewards, dones, new_states). The state tensors keep the
        observation dtype; actions are int64, rewards float32 and dones bool.
    """
    rows = list(batch)  # materialise once so each column can be read independently

    def _column(field: str) -> list:
        return [getattr(exp, field) for exp in rows]

    states_t = torch.as_tensor(np.asarray(_column("state")), device=device)
    actions_t = torch.as_tensor(_column("action"), dtype=torch.long, device=device)
    rewards_t = torch.as_tensor(_column("reward"), dtype=torch.float32, device=device)
    dones_t = torch.as_tensor(_column("done_trunc"), dtype=torch.bool, device=device)
    new_states_t = torch.as_tensor(np.asarray(_column("new_state")), device=device)
    return states_t, actions_t, rewards_t, dones_t, new_states_t


def calc_loss(batch: tt.List[Experience], net: DQN, tgt_net: DQN, device: torch.device) -> torch.Tensor:
    """One-step TD loss: MSE between Q(s, a) and r + GAMMA * max_a' Q_target(s', a').

    The bootstrapped term is computed by `tgt_net` without gradients and is zeroed
    for transitions whose done flag is set.
    """
    states_t, actions_t, rewards_t, dones_t, new_states_t = batch_to_tensors(batch, device)

    q_all = net(states_t)
    q_taken = q_all.gather(1, actions_t.unsqueeze(-1)).squeeze(-1)

    with torch.no_grad():
        q_next = tgt_net(new_states_t).max(dim=1).values
        q_next[dones_t] = 0.0

    targets = rewards_t + GAMMA * q_next
    return nn.MSELoss()(q_taken, targets)

In [None]:
USE_GOOGLE_DRIVE = True

env_name = DEFAULT_ENV_NAME
# Gym ids such as "ALE/SpaceInvaders-v5" contain "/", which cannot appear in a file name.
safe_env_name = env_name.replace("/", "_")

# Checkpoints are always written locally; with USE_GOOGLE_DRIVE they are also written to Drive.
save_dir_local = "saved_models"
save_dir_drive = "/content/drive/MyDrive/PUBLIC/Models"

if USE_GOOGLE_DRIVE:
    from google.colab import drive  # Colab-only package
    drive.mount('/content/drive')
    os.makedirs(save_dir_drive, exist_ok=True)

os.makedirs(save_dir_local, exist_ok=True)

print("Saving to:")
if USE_GOOGLE_DRIVE:
    print(" - Google Drive:", save_dir_drive)
print(" - Local       :", save_dir_local)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Saving to:
 - Google Drive: /content/drive/MyDrive/PUBLIC/Models
 - Local       : saved_models


In [None]:
def make_recording_env(env_name: str, video_folder: str, name_prefix: str, n_steps: int = 4):
    """Build the training environment and wrap it so every episode is saved as MP4.

    Args:
        env_name: Gymnasium id of the Atari game.
        video_folder: directory for the clips (created if missing).
        name_prefix: file-name prefix passed to RecordVideo.
        n_steps: number of stacked frames; must match the network's input channels.

    Returns:
        A RecordVideo-wrapped environment that records every episode.
    """
    os.makedirs(video_folder, exist_ok=True)
    # Reuse make_env so the recorded agent sees exactly the observation pipeline
    # (AtariWrapper settings, reward clipping, channel order, frame stacking) it was
    # trained on, instead of a hand-maintained copy that can drift out of sync.
    # SB3's AtariWrapper ends episodes on life loss by default, as before.
    wrapped = make_env(env_name, n_steps=n_steps, render_mode="rgb_array")
    rec = gym.wrappers.RecordVideo(
        wrapped,
        video_folder=video_folder,
        name_prefix=name_prefix,
        episode_trigger=lambda ep_idx: True,
    )
    return rec

@torch.no_grad()
def record_episodes_with(
    net: DQN,
    env_name: str,
    video_folder: str,
    name_prefix: str,
    episodes: int = 1,
    epsilon: float = 0.05,
    max_steps: int | None = None,   # hard cap per episode (e.g., 1200 frames)
):
    """Play `episodes` epsilon-greedy rollouts of `net` and record them as MP4 clips.

    Args:
        net: Q-network driving the greedy actions; its device is reused for inputs.
        env_name: Gymnasium id of the game to record.
        video_folder: directory the clips are written to.
        name_prefix: clip file-name prefix.
        episodes: number of episodes to play.
        epsilon: probability of a random action (1.0 = fully random policy).
        max_steps: optional per-episode step limit for short clips.

    Returns:
        Sorted paths of the MP4 files in `video_folder` whose names start with
        `name_prefix` (this can include clips from earlier calls with the same prefix).
    """
    env = make_recording_env(env_name, video_folder, name_prefix)
    device = next(net.parameters()).device
    try:
        for ep in range(episodes):
            obs, _ = env.reset()
            done = trunc = False
            ep_ret = 0.0
            steps = 0
            while not (done or trunc):
                # ε-greedy policy (ε=1.0 -> random)
                if np.random.random() < epsilon:
                    action = env.action_space.sample()
                else:
                    q = net(torch.as_tensor(obs, device=device).unsqueeze(0))
                    action = int(torch.argmax(q, dim=1).item())

                obs, r, done, trunc, _ = env.step(action)
                ep_ret += float(r)
                steps += 1

                if max_steps is not None and steps >= max_steps:   # short clips
                    break

            print(f"[{name_prefix}] episode={ep} return={ep_ret:.1f} steps={steps}")
    finally:
        # Close even when interrupted so RecordVideo can finish the clip in progress
        # and release the video writer.
        env.close()

    return sorted(glob.glob(os.path.join(video_folder, f"{name_prefix}*.mp4")))

def record_late_clip_quick(net):
    """Record LATE_EPISODES_TO_RECORD near-greedy episodes of `net` on DEFAULT_ENV_NAME.

    Uses epsilon=0.05 and no per-episode step cap. Clips are written to
    "videos_late_quick" with the "late_quick" prefix; returns the matching MP4 paths.
    """
    return record_episodes_with(
        net,
        DEFAULT_ENV_NAME,
        "videos_late_quick",
        "late_quick",
        episodes=LATE_EPISODES_TO_RECORD,
        epsilon=0.05,  # almost greedy
    )

def show_videos(video_paths):
    """Print each path and embed the corresponding video inline in the notebook output."""
    for path in video_paths:
        print("📹", path)
        player = Video(path, embed=True, html_attributes="controls")
        display(player)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_comment = f"test_epsdec{EPSILON_DECAY_LAST_FRAME}_rs{REPLAY_START_SIZE}_sync{SYNC_TARGET_FRAMES}"

env = make_env(env_name)
net = DQN(env.observation_space.shape, env.action_space.n).to(device)
tgt_net = DQN(env.observation_space.shape, env.action_space.n).to(device)
# Start the target network as an exact copy of the online network. With the current
# constants the first in-loop sync already coincides with the first update
# (REPLAY_START_SIZE is a multiple of SYNC_TARGET_FRAMES), but this keeps early updates
# from bootstrapping off an unrelated random network if either constant is changed.
tgt_net.load_state_dict(net.state_dict())
writer = SummaryWriter(comment=f"-{env_name}-{model_comment}")
print(net)

buffer = ExperienceBuffer(REPLAY_SIZE)
agent = Agent(env, buffer)
epsilon = EPSILON_START

optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE)
total_rewards = []
frame_idx = 0
ts_frame = 0
ts = time.time()
best_m_reward = None
start_time = time.time()

# Flags to auto-capture videos (each clip is recorded at most once per run)
captured_early = False
captured_late = False
# NOTE(review): m_reward is a mean of episode rewards; if those are never negative
# (likely for SpaceInvaders with clipped rewards — confirm), a 0.0 threshold is met
# after the first finished episode, so the "late" clip is recorded almost immediately.
LATE_THRESHOLD = 0.0  # set to -5.0  to record sooner

try:
    while True:
        frame_idx += 1
        # Epsilon drops by 1/EPSILON_DECAY_LAST_FRAME per frame, floored at EPSILON_FINAL.
        epsilon = max(EPSILON_FINAL, EPSILON_START - frame_idx / EPSILON_DECAY_LAST_FRAME)

        reward = agent.play_step(net, device, epsilon)
        if reward is not None:
            # An episode just finished: log progress, maybe checkpoint and record clips.
            total_rewards.append(reward)
            speed = (frame_idx - ts_frame) / (time.time() - ts + 1e-8)
            elapsed = time.time() - start_time  # seconds
            ts_frame = frame_idx
            ts = time.time()
            m_reward = float(np.mean(total_rewards[-100:]))

            print(f"{frame_idx}: done {len(total_rewards)} games, reward {m_reward:.3f}, eps {epsilon:.2f}, speed {speed:.2f} f/s, time {elapsed/60:.1f} min")

            writer.add_scalar("epsilon", epsilon, frame_idx)
            writer.add_scalar("speed", speed, frame_idx)
            writer.add_scalar("reward_100", m_reward, frame_idx)
            writer.add_scalar("reward", reward, frame_idx)

            if best_m_reward is None or m_reward > best_m_reward + SAVE_EPSILON:
                timestamp = datetime.now().strftime("%Y%m%d-%H%M")
                model_filename = f"{safe_env_name}-best_{int(m_reward)}-{timestamp}-{model_comment}.dat"
                model_path_local = os.path.join(save_dir_local, model_filename)
                torch.save(net.state_dict(), model_path_local)

                if USE_GOOGLE_DRIVE:
                    model_path_drive = os.path.join(save_dir_drive, model_filename)
                    torch.save(net.state_dict(), model_path_drive)
                    print("💾 Model saved to:")
                    print(" - Google Drive:", model_path_drive)
                    print(" - Local:        ", model_path_local)
                else:
                    print("💾 Model saved to:")
                    print(" - Local:        ", model_path_local)

                if best_m_reward is not None:
                    print(f"Best reward updated {best_m_reward:.3f} -> {m_reward:.3f}")
                best_m_reward = m_reward

                # Record one short random-policy clip the first time a checkpoint is saved.
                if (not captured_early) and (frame_idx >= 5):
                    early_vids = record_episodes_with(
                        net=net,
                        env_name=DEFAULT_ENV_NAME,
                        video_folder="videos_early",
                        name_prefix=f"early_f{frame_idx:06d}",  # unique + sortable
                        episodes=EARLY_EPISODES_TO_RECORD,
                        epsilon=1.0,                            # very bad / random-ish
                        max_steps=400,                          # small file
                    )
                    if early_vids:
                        print(f"EARLY videos @ frame {frame_idx}:", early_vids)
                    captured_early = True

            # Auto-record LATE once when performance reaches threshold
            if (not captured_late) and ((m_reward >= LATE_THRESHOLD) or (frame_idx >= LATE_AFTER_FRAMES)):
                late_vids = record_late_clip_quick(net)
                if late_vids:
                    print("LATE video:", late_vids[-1])
                captured_late = True

            if m_reward > MEAN_REWARD_BOUND:
                print("Solved in %d frames!" % frame_idx)
                break

        # No gradient updates until the replay buffer has warmed up.
        if len(buffer) < REPLAY_START_SIZE:
            continue
        if frame_idx % SYNC_TARGET_FRAMES == 0:
            tgt_net.load_state_dict(net.state_dict())

        optimizer.zero_grad()
        batch = buffer.sample(BATCH_SIZE)
        loss_t = calc_loss(batch, net, tgt_net, device)
        loss_t.backward()
        optimizer.step()
finally:
    # Runs on normal exit and on KeyboardInterrupt/errors, so the env is released and
    # TensorBoard events are flushed even when training is stopped by hand.
    env.close()
    writer.close()

Creating environment ALE/SpaceInvaders-v5
DQN(
  (conv): Sequential(
    (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): ReLU()
    (6): Flatten(start_dim=1, end_dim=-1)
  )
  (fc): Sequential(
    (0): Linear(in_features=3136, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=6, bias=True)
  )
)
28: done 1 games, reward 1.000, eps 1.00, speed 79.97 f/s, time 0.0 min
💾 Model saved to:
 - Google Drive: /content/drive/MyDrive/PUBLIC/Models/ALE_SpaceInvaders-v5-best_1-20250929-0021-test_epsdec50000_rs2000_sync1000.dat
 - Local:         saved_models/ALE_SpaceInvaders-v5-best_1-20250929-0021-test_epsdec50000_rs2000_sync1000.dat
[early_f000028] episode=0 return=0.0 steps=29
EARLY videos @ frame 28: ['videos_early/early_f000028-episode-0.mp4']
[late_quick] episode=0 return=2.0 steps=41
LATE

KeyboardInterrupt: 

In [None]:
import os, glob
from IPython.display import Video, display

# Gather every directory that might hold recorded clips: anything named videos* in the
# working directory or /content, plus the folder names used by the recording helpers.
candidates = {
    p
    for search_root in (".", "/content")
    for p in glob.glob(os.path.join(search_root, "videos*"))
    if os.path.isdir(p)
}
candidates |= {
    "videos", "videos_early", "videos_late", "videos_late_quick",
    "videos_ckpts", "videos_bc", "videos_compare",
}

# Normalize to absolute paths, dedupe, and keep only folders that actually exist.
folders = sorted({os.path.abspath(p) for p in candidates if os.path.isdir(p)})

if not folders:
    print("No video folders found.")
for folder in folders:
    mp4s = sorted(glob.glob(os.path.join(folder, "*.mp4")))
    print(f"\n=== {folder} ({len(mp4s)} file(s)) ===")
    if not mp4s:
        print("  (no mp4 files)")
    for vp in mp4s:
        print("📹", vp)
        display(Video(vp, embed=True, html_attributes="controls"))


=== /content/videos_early (1 file(s)) ===
📹 /content/videos_early/early_f000028-episode-0.mp4



=== /content/videos_late_quick (1 file(s)) ===
📹 /content/videos_late_quick/late_quick-episode-0.mp4
