# 🧪 PyTorch Lab 10: Solving a Maze with Deep RL

## 0. Setup
We use:
- `gymnasium` for the environment wrapper
- `torch` for the CNN policy/value network
- `imageio` to convert recorded mp4 videos to gifs

If you run locally, you may need ffmpeg installed for mp4 decoding; on Colab it usually works.


In [None]:
!pip -q install gymnasium
!pip -q install imageio imageio-ffmpeg


^C


In [None]:
import os, glob
import numpy as np
import gymnasium as gym
from gymnasium import spaces
import imageio

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm


### Device (CPU/GPU)
We pick CUDA if available. This avoids runtime errors when GPU is not present.

> If you **want to force GPU only** like in your original script, replace this with `device = "cuda"`.


In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device  # last expression of the cell: displays the selected device


## 1. The environment: SimpleMaze10x10_4Actions + FrameStack4

- The agent starts at `(1,1)` and must reach `(8,8)` (for size 10).
- Reward structure:
  - `-0.01` per step (time penalty)
  - extra `-0.1` if bumping into a wall
  - `+1.0` when reaching the goal
- Observation: 84×84 grayscale image with:
  - walls = dark gray
  - goal = medium gray
  - agent = white
- We stack the last 4 frames → observation becomes `(4,84,84)`.

This mimics Atari-style inputs while staying tiny.


In [None]:
class SimpleMaze10x10_4Actions(gym.Env):
    """Fixed-layout grid maze rendered as an 84x84 grayscale frame.

    Actions (Atari-like): 0=UP, 1=DOWN, 2=RIGHT, 3=LEFT.

    Observation: a single ``(84, 84)`` ``uint8`` frame. Stacking consecutive
    frames is left to a wrapper outside this class.

    Rewards: -0.01 per step, an additional -0.1 when the move is blocked by a
    wall, and exactly 1.0 on the step that reaches the goal.
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 30}

    def __init__(self, size=10, max_steps=200, seed=0, render_mode=None):
        super().__init__()
        self.size = size
        self.max_steps = max_steps
        self.rng = np.random.default_rng(seed)
        self.render_mode = render_mode

        self.action_space = spaces.Discrete(4)
        self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84), dtype=np.uint8)

        self._build_fixed_maze()
        # Put the agent on the start cell so the env is usable right away.
        self.reset(seed=seed)

    def get_action_meanings(self):
        """Return the action names, indexed by action id."""
        return ["UP", "DOWN", "RIGHT", "LEFT"]

    def _action_to_delta(self, a: int):
        """Map an action id to a ``(d_row, d_col)`` move; raise ValueError otherwise."""
        deltas = {
            0: (-1, 0),  # UP
            1: (1, 0),   # DOWN
            2: (0, 1),   # RIGHT
            3: (0, -1),  # LEFT
        }
        if a in deltas:
            return deltas[a]
        raise ValueError(a)

    def _build_fixed_maze(self):
        """Create ``self.grid`` (1 = wall, 0 = free) plus the start and goal cells."""
        n = self.size
        layout = np.zeros((n, n), dtype=np.uint8)

        # Outer border: first/last row and first/last column are walls.
        layout[[0, -1], :] = 1
        layout[:, [0, -1]] = 1

        # Interior walls (fixed layout), listed row by row.
        interior = (
            [(2, col) for col in range(3, 9)]
            + [(3, 3)]
            + [(5, 4), (5, 5), (5, 6)]
            + [(6, 4), (6, 6)]
            + [(7, 4), (7, 5), (7, 6)]
        )
        for row, col in interior:
            # Skip wall cells that fall outside a smaller-than-default grid.
            if 0 <= row < n and 0 <= col < n:
                layout[row, col] = 1

        self.grid = layout
        self.start = (1, 1)
        self.goal = (n - 2, n - 2)
        # Start and goal must be walkable whatever the wall list says.
        self.grid[self.start] = 0
        self.grid[self.goal] = 0

    def reset(self, *, seed=None, options=None):
        """Move the agent back to the start cell and return ``(obs, info)``."""
        super().reset(seed=seed)
        self.steps = 0
        self.pos = list(self.start)
        return self._get_obs(), {"pos": tuple(self.pos)}

    def step(self, action):
        """Apply one move and return ``(obs, reward, terminated, truncated, info)``."""
        self.steps += 1

        d_row, d_col = self._action_to_delta(int(action))
        target = (self.pos[0] + d_row, self.pos[1] + d_col)

        # A move into a wall leaves the agent in place and costs extra.
        bumped = bool(self.grid[target] == 1)
        if bumped:
            reward = -0.01 - 0.1
        else:
            reward = -0.01
            self.pos = list(target)

        terminated = tuple(self.pos) == self.goal
        if terminated:
            # Reaching the goal replaces the step penalty instead of adding to it.
            reward = 1.0

        truncated = self.steps >= self.max_steps
        info = {"pos": tuple(self.pos), "bumped": bumped}
        return self._get_obs(), reward, terminated, truncated, info

    def _get_obs(self):
        """Draw walls (60), goal (160) and agent (255) on a black 84x84 frame."""
        side = 84
        frame = np.zeros((side, side), dtype=np.uint8)

        # Integer cell size: pixels left over by the division stay black.
        cell_h = side // self.size
        cell_w = side // self.size

        def paint(row, col, shade):
            frame[row * cell_h:(row + 1) * cell_h, col * cell_w:(col + 1) * cell_w] = shade

        for row, cells in enumerate(self.grid):
            for col, cell in enumerate(cells):
                if cell == 1:
                    paint(row, col, 60)

        # Goal first, agent last, so the agent stays visible when it is on the goal.
        goal_row, goal_col = self.goal
        paint(goal_row, goal_col, 160)

        agent_row, agent_col = self.pos
        paint(agent_row, agent_col, 255)

        return frame

    def render(self):
        """Return the current frame as an ``(84, 84, 3)`` image in rgb_array mode, else None."""
        if self.render_mode != "rgb_array":
            return None
        gray = self._get_obs()
        return np.stack([gray, gray, gray], axis=-1)

    def close(self):
        """Nothing to release."""
        pass


class FrameStack4(gym.Wrapper):
    """Wrapper whose observation is the last ``k`` frames stacked along axis 0."""

    def __init__(self, env, k=4):
        super().__init__(env)
        self.k = k
        self.frames = None  # filled by reset()
        self.observation_space = spaces.Box(low=0, high=255, shape=(k, 84, 84), dtype=np.uint8)

    def reset(self, **kwargs):
        """Reset the wrapped env and fill the stack with copies of the first frame."""
        first_frame, info = self.env.reset(**kwargs)
        self.frames = [first_frame.copy() for _ in range(self.k)]
        return self._get(), info

    def step(self, action):
        """Step the wrapped env, drop the oldest frame and append the newest."""
        frame, reward, terminated, truncated, info = self.env.step(action)
        del self.frames[0]
        self.frames.append(frame.copy())
        return self._get(), reward, terminated, truncated, info

    def _get(self):
        """Return the stacked frames, oldest first: shape (4, 84, 84) for k=4."""
        return np.stack(self.frames, axis=0)


### Helpers: make_env + video helpers


In [None]:
def make_env(record=False, seed=0, max_steps=200):
    """Build the 10x10 maze wrapped in a 4-frame stack.

    With ``record=True`` the maze is created in ``"rgb_array"`` render mode and
    additionally wrapped in ``RecordVideo``, which writes every episode to the
    ``videos`` folder.
    """
    maze = SimpleMaze10x10_4Actions(
        size=10,
        max_steps=max_steps,
        seed=seed,
        render_mode="rgb_array" if record else None,
    )
    stacked = FrameStack4(maze, k=4)

    if not record:
        return stacked

    # The trigger returns True for every episode index, so all episodes are recorded.
    return gym.wrappers.RecordVideo(stacked, video_folder="videos", episode_trigger=lambda i: True)


def newest_mp4(folder):
    """Return the path of the most recently modified ``*.mp4`` directly in ``folder``.

    Raises:
        FileNotFoundError: If ``folder`` contains no mp4 file.
    """
    candidates = glob.glob(os.path.join(folder, "*.mp4"))
    if not candidates:
        raise FileNotFoundError(f"No mp4 found in {folder}")
    # Oldest first, so the newest file ends up last.
    candidates.sort(key=os.path.getmtime)
    return candidates[-1]


def mp4_to_gif(mp4_path, gif_path, fps=30):
    """Decode every frame of an mp4 and write the frames out as a GIF.

    Args:
        mp4_path: Path of the video to read.
        gif_path: Path of the GIF to write.
        fps: Frame rate forwarded to ``imageio.mimsave``.

    All frames are held in memory before the GIF is written.
    """
    reader = imageio.get_reader(mp4_path)
    try:
        frames = [frame for frame in reader]
    finally:
        # Close the reader even when decoding fails part-way through.
        reader.close()
    # NOTE(review): some imageio releases reportedly expect `duration` instead
    # of `fps` for GIF output - confirm against the installed version.
    imageio.mimsave(gif_path, frames, fps=fps)


## 2. Visual sanity check
Reset the environment and display the most recent frame in the stack.


In [None]:
import matplotlib.pyplot as plt

# Build a non-recording env and grab the very first stacked observation.
env = make_env(record=False, seed=0)
obs, info = env.reset()

# Show only the newest frame of the stack (last index along axis 0).
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(obs[-1], cmap="gray", vmin=0, vmax=255)
ax.set_title(f"Start pos: {info['pos']}")
ax.axis("off")
plt.show()

env.close()


## 3. The network (CNN84)
Atari-style CNN trunk + two heads:
- policy logits over 4 actions
- value estimate V(s)


In [None]:
class CNN84(nn.Module):
    """Atari-style conv trunk with a 4-way policy head and a scalar value head.

    ``forward`` takes a float tensor of shape ``(batch, 4, 84, 84)`` and
    returns ``(logits, value)`` with shapes ``(batch, 4)`` and ``(batch,)``.
    """

    def __init__(self):
        super().__init__()
        # Spatial size through the trunk for an 84x84 input: 84 -> 20 -> 9 -> 7.
        self.conv1 = nn.Conv2d(4, 32, kernel_size=8, stride=4)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=4, stride=2)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, stride=1)

        self.fc1 = nn.Linear(64 * 7 * 7, 512)
        self.p = nn.Linear(512, 4)  # policy logits, one per action
        self.v = nn.Linear(512, 1)  # state value

    def forward(self, x):
        feats = x
        for conv in (self.conv1, self.conv2, self.conv3):
            feats = F.relu(conv(feats))

        # Flatten everything but the batch dimension before the dense layer.
        hidden = F.relu(self.fc1(feats.view(feats.size(0), -1)))

        # Drop the trailing singleton dimension so the value is (batch,).
        return self.p(hidden), self.v(hidden).squeeze(1)


## 4. Monte-Carlo returns
Compute discounted returns backwards.


In [None]:
@torch.no_grad()
def discounted_returns(rewards: torch.Tensor, gamma: float) -> torch.Tensor:
    """Return the discounted reward-to-go ``G_t = r_t + gamma * G_{t+1}``.

    ``rewards`` is indexed along its first dimension by time step. The result
    has the same shape, dtype and device as ``rewards``.
    """
    out = torch.empty_like(rewards)
    # Scalar accumulator kept on the same device/dtype as the rewards.
    running = torch.zeros((), device=rewards.device, dtype=rewards.dtype)
    # Walk backwards so each step can reuse the return of the step after it.
    for t in reversed(range(rewards.shape[0])):
        running = rewards[t] + gamma * running
        out[t] = running
    return out


## 5. Collecting one episode
Roll out using the current stochastic policy and store (obs, action, reward).


In [None]:
def run_episode(env, model):
    """Roll out one full episode, sampling actions from the model's policy.

    Args:
        env: Environment exposing ``reset() -> (obs, info)`` and
            ``step(action) -> (obs, reward, terminated, truncated, info)``.
            Observations are divided by 255 before being fed to the model.
        model: Module returning ``(logits, value)`` for a batched float input.

    Returns:
        ``(episode_buffer, ep_return)``. ``episode_buffer`` is a list of
        ``(obs, action, reward)`` tuples, where ``obs`` is the observation the
        action was sampled from. ``ep_return`` is the undiscounted reward sum.

    Note:
        The model is switched to eval mode and left in that mode.
    """
    obs, info = env.reset()
    done = False
    ep_return = 0.0
    episode_buffer = []

    # Run on whatever device the model lives on (CPU for a parameter-free model).
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    with torch.no_grad():
        while not done:
            x = (torch.tensor(obs).unsqueeze(0).float() / 255.0).to(run_device)
            logits, _ = model(x)
            dist = torch.distributions.Categorical(logits=logits)
            action = dist.sample().item()

            next_obs, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated

            # Store the state the action was taken in, not the one it led to;
            # the policy-gradient update needs log pi(a_t | s_t).
            episode_buffer.append((obs, action, reward))
            ep_return += reward
            obs = next_obs

    return episode_buffer, ep_return


## 6. Part A — Simple REINFORCE (baseline)

REINFORCE gradient:
$$
\nabla_\theta J(\theta) \approx \sum_t \nabla_\theta \log \pi_\theta(a_t|s_t)\, G_t
$$

We minimize:
$$
\mathcal{L}_\text{actor} = -\mathbb{E}[\log \pi(a_t|s_t)\, G_t]
$$


In [None]:
def update_batch_reinforce(model, episodes, optimizer, gamma=0.99, normalize_returns=True):
    """Run one REINFORCE gradient step on a batch of complete episodes.

    Args:
        model: Network returning ``(logits, value)``; only the logits are used.
        episodes: Non-empty list of episodes, each a list of
            ``(obs, action, reward)`` transitions.
        optimizer: Optimizer that is zeroed and stepped once here.
        gamma: Discount factor for the Monte-Carlo returns.
        normalize_returns: If true, standardize the returns over the batch.

    Returns:
        ``(actor_loss, mean_entropy)`` as Python floats, both taken from the
        forward pass made before the optimizer step.

    Raises:
        ValueError: If ``episodes`` is empty.
    """
    if not episodes:
        raise ValueError("episodes must contain at least one episode")

    obs_list, act_list, G_list = [], [], []

    for episode_buffer in episodes:
        obs, action, reward = map(np.array, zip(*episode_buffer))
        # Returns are computed per episode so they never mix across episodes.
        G = discounted_returns(torch.tensor(reward, dtype=torch.float32, device=device), gamma=gamma)

        obs_list.append(obs)
        act_list.append(action)
        G_list.append(G)

    # Flatten the batch: one row per transition.
    obs = np.concatenate(obs_list, axis=0)
    action = np.concatenate(act_list, axis=0)
    G = torch.cat(G_list, dim=0)

    obs = torch.tensor(obs, dtype=torch.float32, device=device) / 255.0
    action = torch.tensor(action, dtype=torch.long, device=device)

    logits, _ = model(obs)
    dist = torch.distributions.Categorical(logits=logits)
    logp = dist.log_prob(action)

    # The standard deviation of a single value is NaN and would turn the loss
    # (and then the weights) into NaN, so a one-transition batch is left as is.
    if normalize_returns and G.numel() > 1:
        G = (G - G.mean()) / (G.std() + 1e-8)

    loss_actor = -(logp * G).mean()

    optimizer.zero_grad(set_to_none=True)
    loss_actor.backward()
    optimizer.step()

    return loss_actor.item(), dist.entropy().mean().item()


### Train REINFORCE
We update every `BATCH_EPISODES` episodes.


In [None]:
VIDEO_DIR = "videos"
os.makedirs(VIDEO_DIR, exist_ok=True)

# Hyperparameters
N_EPISODES = 300
BATCH_EPISODES = 5  # episodes collected between two gradient steps
GAMMA = 0.99
LR = 2e-4

model = CNN84().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)

env = make_env(record=True, seed=0)

rewards, losses, entropies = [], [], []
batch_buffer = []

try:
    for i in tqdm(range(N_EPISODES)):
        episode_buffer, r = run_episode(env, model)

        # Keep episodes separate (append, not extend): the update function
        # expects a list of episodes and computes the returns per episode.
        batch_buffer.append(episode_buffer)

        rewards.append(r)

        if (i + 1) % BATCH_EPISODES == 0:
            loss_value, entropy_value = update_batch_reinforce(model, batch_buffer, optimizer, gamma=GAMMA)
            batch_buffer = []
            losses.append(loss_value)
            entropies.append(entropy_value)

        if (i + 1) % 50 == 0:
            print(
                f"Episode {i+1} | "
                f"Avg Reward (50): {np.mean(rewards[-50:]):.2f} | "
                f"Last Entropy: {entropies[-1] if entropies else float('nan'):.2f} | "
                f"Last Loss: {losses[-1] if losses else float('nan'):.4f}"
            )
            # GIF export is best effort: a failed conversion must not stop training.
            try:
                mp4_path = newest_mp4(VIDEO_DIR)
                gif_path = f"reinforce_ep{i+1}.gif"
                mp4_to_gif(mp4_path, gif_path, fps=30)
                print("Saved gif:", gif_path)
            except Exception as ex:
                print("GIF export skipped:", ex)
finally:
    # Close the env even if training fails or the cell is interrupted,
    # so the recording wrapper is shut down.
    env.close()


### Plot learning curves (REINFORCE)


In [None]:
# Per-episode return.
fig, ax = plt.subplots()
ax.plot(rewards)
ax.set_title("Episode return (REINFORCE)")
ax.set_xlabel("Episode")
ax.set_ylabel("Return")
plt.show()

# Actor loss, one point per gradient update.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set_title("Actor loss per update (REINFORCE)")
ax.set_xlabel("Update step")
ax.set_ylabel("Loss")
plt.show()


## 7. Part B — Variance reduction: baseline with a learned V(s)

Your method:
- Advantage: `A = G - V(s)` (with `V` detached in actor term)
- Normalize `A` in the batch
- Add a critic MSE loss to fit `V(s) ≈ G`
- Total: `loss = loss_actor + 0.5 * loss_critic`


In [None]:
def update_batch_actor_critic_baseline(model, episodes, optimizer, gamma=0.99):
    """Run one policy-gradient step using a learned value baseline.

    The actor is trained on the batch-normalized advantage ``A = G - V(s)``
    (with ``V`` detached) and the critic on an MSE fit of ``V(s)`` to the
    Monte-Carlo return ``G``. The total loss is ``actor + 0.5 * critic``.

    Args:
        model: Network returning ``(logits, value)``.
        episodes: Non-empty list of episodes, each a list of
            ``(obs, action, reward)`` transitions.
        optimizer: Optimizer that is zeroed and stepped once here.
        gamma: Discount factor for the Monte-Carlo returns.

    Returns:
        ``(loss, actor_loss, critic_loss, mean_entropy)`` as Python floats,
        all taken from the forward pass made before the optimizer step.

    Raises:
        ValueError: If ``episodes`` is empty.
    """
    if not episodes:
        raise ValueError("episodes must contain at least one episode")

    obs_list, act_list, G_list = [], [], []

    for episode_buffer in episodes:
        obs, action, reward = map(np.array, zip(*episode_buffer))
        # Returns are computed per episode so they never mix across episodes.
        G = discounted_returns(torch.tensor(reward, dtype=torch.float32, device=device), gamma=gamma)

        obs_list.append(obs)
        act_list.append(action)
        G_list.append(G)

    # Flatten the batch: one row per transition.
    obs = np.concatenate(obs_list, axis=0)
    action = np.concatenate(act_list, axis=0)
    G = torch.cat(G_list, dim=0)

    obs = torch.tensor(obs, dtype=torch.float32, device=device) / 255.0
    action = torch.tensor(action, dtype=torch.long, device=device)

    logits, V = model(obs)
    dist = torch.distributions.Categorical(logits=logits)
    logp = dist.log_prob(action)

    # Detach V so the actor term does not push gradients into the critic.
    A = G - V.detach()
    # The standard deviation of a single value is NaN and would turn the loss
    # (and then the weights) into NaN, so a one-transition batch is left as is.
    if A.numel() > 1:
        A = (A - A.mean()) / (A.std() + 1e-8)

    loss_actor = -(logp * A).mean()
    loss_critic = F.mse_loss(V, G)
    loss = loss_actor + 0.5 * loss_critic

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    return loss.item(), loss_actor.item(), loss_critic.item(), dist.entropy().mean().item()


### Train with baseline (re-init model for comparison)


In [None]:
# Fresh network and optimizer so the comparison with REINFORCE starts from scratch.
model2 = CNN84().to(device)
optimizer2 = torch.optim.Adam(model2.parameters(), lr=LR)

N_EPISODES_2 = 300
BATCH_EPISODES_2 = 5  # episodes collected between two gradient steps

env = make_env(record=True, seed=0)

rewards2, losses2, actor_losses2, critic_losses2, entropies2 = [], [], [], [], []
batch_buffer = []

try:
    for i in tqdm(range(N_EPISODES_2)):
        episode_buffer, r = run_episode(env, model2)

        # Keep episodes separate (append, not extend): the update function
        # expects a list of episodes and computes the returns per episode.
        batch_buffer.append(episode_buffer)

        rewards2.append(r)

        if (i + 1) % BATCH_EPISODES_2 == 0:
            loss_value, actor_value, critic_value, entropy_value = update_batch_actor_critic_baseline(
                model2, batch_buffer, optimizer2, gamma=GAMMA
            )
            batch_buffer = []
            losses2.append(loss_value)
            actor_losses2.append(actor_value)
            critic_losses2.append(critic_value)
            entropies2.append(entropy_value)

        if (i + 1) % 50 == 0:
            print(
                f"Episode {i+1} | "
                f"Avg Reward (50): {np.mean(rewards2[-50:]):.2f} | "
                f"Entropy: {entropies2[-1] if entropies2 else float('nan'):.2f} | "
                f"Loss: {losses2[-1] if losses2 else float('nan'):.4f} | "
                f"Actor: {actor_losses2[-1] if actor_losses2 else float('nan'):.4f} | "
                f"Critic: {critic_losses2[-1] if critic_losses2 else float('nan'):.4f}"
            )
            # GIF export is best effort: a failed conversion must not stop training.
            try:
                mp4_path = newest_mp4(VIDEO_DIR)
                gif_path = f"baseline_ep{i+1}.gif"
                mp4_to_gif(mp4_path, gif_path, fps=30)
                print("Saved gif:", gif_path)
            except Exception as ex:
                print("GIF export skipped:", ex)
finally:
    # Close the env even if training fails or the cell is interrupted,
    # so the recording wrapper is shut down.
    env.close()


### Plot learning curves (baseline)


In [None]:
# Per-episode return of the baseline agent.
fig, ax = plt.subplots()
ax.plot(rewards2)
ax.set_title("Episode return (Baseline / Actor-Critic style)")
ax.set_xlabel("Episode")
ax.set_ylabel("Return")
plt.show()

# Total, actor and critic losses on one axis, one point per gradient update.
fig, ax = plt.subplots()
for curve, curve_label in ((losses2, "total"), (actor_losses2, "actor"), (critic_losses2, "critic")):
    ax.plot(curve, label=curve_label)
ax.set_title("Losses per update (Baseline / Actor-Critic style)")
ax.set_xlabel("Update step")
ax.set_ylabel("Loss")
ax.legend()
plt.show()


## 8. Compare REINFORCE vs Baseline


In [None]:
# Overlay the per-episode returns of both training runs.
fig, ax = plt.subplots()
ax.plot(rewards, label="REINFORCE")
ax.plot(rewards2, label="Baseline (A=G-V)")
ax.set_title("Episode return comparison")
ax.set_xlabel("Episode")
ax.set_ylabel("Return")
ax.legend()
plt.show()


## 9. Record a final episode (GIF)


In [None]:
# Roll out one more episode with the baseline-trained model in a recording env.
record_env = make_env(record=True, seed=42)
episode_buffer, r = run_episode(record_env, model2)
# Close the env before looking for the recorded mp4.
record_env.close()
print(f"Recorded episode return: {r}")

# Convert the most recent recording in the video folder to a GIF.
final_gif = "last_episode.gif"
mp4_to_gif(newest_mp4(VIDEO_DIR), final_gif, fps=30)
print("Saved gif:", final_gif)


## Exercises (optional)
1. Reward shaping: change wall penalty or step penalty.
2. Add an entropy bonus term.
3. Replace Monte-Carlo returns with TD(0) targets.
4. Change frame stack size (k=1 vs k=4).
