<a href="https://colab.research.google.com/github/Toneejake/studybuddy-prototype/blob/main/black_hole_sun.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Cell 1**

In [None]:
# This cell installs all necessary libraries.
# It uses a robust method to check for existing packages before installing.
print("⏳ Installing dependencies...")
import importlib
import subprocess
import sys


def pip_install_many(args):
    """Run ``pip install <args> -q`` with the kernel's own interpreter.

    ``args`` may mix package names and pip options (e.g. ``--index-url URL``).
    Raises subprocess.CalledProcessError if pip exits with a non-zero status.
    """
    command = [sys.executable, "-m", "pip", "install"]
    command.extend(args)
    command.append("-q")
    subprocess.check_call(command)

# PyTorch is checked first; any import failure triggers an install from the CUDA 11.8 wheel index.
try:
    import torch, torchvision, torchaudio
except Exception:
    pip_install_many(["torch", "torchvision", "torchaudio", "--index-url", "https://download.pytorch.org/whl/cu118"])
    # Let the import system notice the freshly written site-packages entries.
    importlib.invalidate_caches()

# (pip distribution name, importable module name) -- the two differ for several packages.
REQUIRED_PACKAGES = [
    ("datasets", "datasets"), ("Pillow", "PIL"), ("matplotlib", "matplotlib"),
    ("scikit-learn", "sklearn"), ("tqdm", "tqdm"), ("gymnasium", "gymnasium"),
    ("stable-baselines3", "stable_baselines3"), ("sb3-contrib", "sb3_contrib"),
    ("opencv-python-headless", "cv2"),
]

missing_packages = []
for pkg_name, import_name in REQUIRED_PACKAGES:
    try:
        importlib.import_module(import_name)
    except Exception:
        missing_packages.append((pkg_name, import_name))

if missing_packages:
    # One pip call lets pip resolve all missing packages' dependencies together.
    pip_install_many([pkg for pkg, _ in missing_packages])
    # Modules created after interpreter start-up may be invisible until the finder caches are cleared.
    importlib.invalidate_caches()
    # Verify instead of assuming: fail here with a clear message rather than in a later cell.
    for pkg_name, import_name in missing_packages:
        try:
            importlib.import_module(import_name)
        except Exception as exc:
            raise ImportError(
                f"'{pkg_name}' was installed but '{import_name}' still fails to import; "
                "restart the runtime and re-run this cell."
            ) from exc

print("✅ All dependencies are ready.")

⏳ Installing dependencies...
✅ All dependencies are ready.


Gym has been unmaintained since 2022 and does not support NumPy 2.0 amongst other critical functionality.
Please upgrade to Gymnasium, the maintained drop-in replacement of Gym, or contact the authors of your software and request that they upgrade.
See the migration guide at https://gymnasium.farama.org/introduction/migration_guide/ for additional information.
  return datetime.utcnow().replace(tzinfo=utc)


**Cell 2**

In [None]:
import numpy as np
import torch
from torch import nn
from torchvision import transforms
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
from datasets import load_dataset
import gymnasium as gym
from gymnasium import spaces
from sb3_contrib import MaskablePPO
from sb3_contrib.common.maskable.utils import get_action_masks
from sb3_contrib.common.maskable.evaluation import evaluate_policy
from google.colab import files
import os

# Core configuration dictionary: the tunable constants read by the later cells.
CFG = dict(
    image_size=256,    # side length of the square grid / U-Net input, in pixels
    obs_fire_ds=64,    # the fire map is downsampled to obs_fire_ds x obs_fire_ds in observations
    max_agents=10,     # maximum evacuees per episode; fixes the observation length
    max_steps=500,     # episode step limit (truncation)
    seed=42,
    device=("cuda" if torch.cuda.is_available() else "cpu"),
)

# Seed the global PyTorch and NumPy RNGs of this process for reproducibility.
torch.manual_seed(CFG["seed"])
np.random.seed(CFG["seed"])

print(f"Configuration loaded. Using device: {CFG['device']}")

Configuration loaded. Using device: cuda


**Cell 3**

In [None]:
# U-Net architecture of the perception model (unchanged from V1.5).
# It must be defined here so the pre-trained weights can be loaded, and its layer
# attribute names must match the checkpoint's state_dict keys exactly.

class UNet(nn.Module):
    """Three-level U-Net mapping an RGB image to a single-channel logit map of the same size.

    Input:  (N, 3, H, W) float tensor.
    Output: (N, 1, H, W) raw logits (apply sigmoid for probabilities).

    The attribute names (dconv_down1..4, dconv_up1..3, conv_last) define the
    state_dict keys and must match the saved checkpoint; do not rename them.
    """

    def __init__(self):
        super(UNet, self).__init__()

        def double_conv(in_channels, out_channels):
            # Two 3x3 padded convolutions with ReLU; spatial size is preserved.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True)
            )

        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Decoder inputs are upsampled features concatenated with the encoder skip features.
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)

        self.conv_last = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x):
        # Encoder: keep each level's output for the skip connections.
        conv1 = self.dconv_down1(x)
        conv2 = self.dconv_down2(self.maxpool(conv1))
        conv3 = self.dconv_down3(self.maxpool(conv2))
        x = self.dconv_down4(self.maxpool(conv3))

        # Decoder: upsample to the skip tensor's exact size rather than a fixed 2x, so H/W
        # that are not divisible by 8 still line up. For divisible sizes this is identical
        # to scale_factor=2 (with align_corners=True the sampling depends only on sizes).
        for up_block, skip in ((self.dconv_up3, conv3), (self.dconv_up2, conv2), (self.dconv_up1, conv1)):
            x = nn.functional.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=True)
            x = up_block(torch.cat([x, skip], dim=1))
        return self.conv_last(x)

print("✅ U-Net architecture defined with original layer names.")

✅ U-Net architecture defined with original layer names.


**Cell 4**

In [None]:
# This cell loads your pre-trained U-Net model.
# It will prompt you to upload the file.

perception_model = UNet().to(CFG['device'])
MODEL_PATH = "unet_floorplan_model.pth"

print(f"Please upload your trained U-Net model file: '{MODEL_PATH}'")
uploaded = files.upload()

# Colab may store the upload under a different name (e.g. 'unet_floorplan_model (1).pth')
# when a file with that name already exists. The env workers load weights from MODEL_PATH,
# so the fresh upload is moved there; otherwise they would silently read stale weights.
_model_stem = os.path.splitext(MODEL_PATH)[0]
_candidates = [name for name in uploaded
               if name == MODEL_PATH or (name.startswith(_model_stem) and name.endswith(".pth"))]
if not _candidates:
    raise FileNotFoundError(f"'{MODEL_PATH}' not found. Please upload the correct file.")
_saved_name = MODEL_PATH if MODEL_PATH in uploaded else _candidates[0]
if _saved_name != MODEL_PATH:
    os.replace(_saved_name, MODEL_PATH)

# The checkpoint is a plain state_dict, so the restricted (weights_only) unpickler suffices.
perception_model.load_state_dict(torch.load(MODEL_PATH, map_location=CFG['device'], weights_only=True))
perception_model.eval()
print(f"\n✅ Perception AI model '{MODEL_PATH}' loaded successfully!")

# Inference helper: turns a floor-plan image into the binary grid the simulator uses.
def create_grid_from_image(model, pil_image, image_size, device):
    """Segment `pil_image` with `model` and threshold the result into a 0/1 grid.

    The image is converted to RGB, resized to (image_size, image_size) and scaled
    to [0, 1] by ToTensor, then run through the model as a batch of one. Pixels
    whose sigmoid probability exceeds 0.5 become 1.0, all others 0.0.

    Side effect: puts `model` into eval mode.

    Returns:
        float32 numpy array with every size-1 axis squeezed away; for a model that
        returns (1, 1, H, W) logits (like UNet) this is an (H, W) grid.
    """
    model.eval()
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])
    batch = preprocess(pil_image.convert("RGB"))[None].to(device)
    with torch.no_grad():
        logits = model(batch)
    binary = (logits.sigmoid() > 0.5).to(torch.float32)
    return binary.squeeze().cpu().numpy()


print("Inference function is ready.")

Please upload your trained U-Net model file: 'unet_floorplan_model.pth'


Saving unet_floorplan_model.pth to unet_floorplan_model.pth

✅ Perception AI model 'unet_floorplan_model.pth' loaded successfully!
Inference function is ready.


**Cell 5**

In [None]:
import heapq

# Core simulation building blocks: A* path planning (a_star_search),
# stochastic fire spread (FireSimulator) and evacuee agents (Person).

def a_star_search(grid, start, goal, fire_map=None):
    """Shortest 4-connected path on a grid using A* with a Manhattan heuristic.

    Coordinates are (x, y) tuples and cells are indexed as grid[y][x]. A cell is
    impassable if it is outside the grid, if grid == 1 there, or (when fire_map
    is given) if fire_map == 1 there.

    Returns:
        List of (x, y) cells from the first step after `start` up to and
        including `goal`; [] if start == goal or the goal is unreachable.
    """
    start, goal = tuple(start), tuple(goal)
    height, width = grid.shape[0], grid.shape[1]

    def heuristic(cell):
        return abs(cell[0] - goal[0]) + abs(cell[1] - goal[1])

    def passable(x, y):
        if not (0 <= x < width and 0 <= y < height):
            return False
        if grid[y][x] == 1:
            return False
        return fire_map is None or fire_map[y][x] != 1

    g_score = {start: 0}
    came_from = {}
    closed = set()
    # Heap entries are (f, cell); ties are broken by comparing the (x, y) tuples.
    open_heap = [(heuristic(start), start)]
    while open_heap:
        _, current = heapq.heappop(open_heap)
        if current in closed:
            continue  # stale entry: a cheaper route to this cell was already expanded
        if current == goal:
            path = []
            while current in came_from:
                path.append(current)
                current = came_from[current]
            return path[::-1]
        closed.add(current)
        step_cost = g_score[current] + 1
        for dx, dy in ((0, 1), (0, -1), (1, 0), (-1, 0)):
            neighbor = (current[0] + dx, current[1] + dy)
            # Manhattan distance is consistent on a unit-cost 4-grid, so closed cells are final.
            if neighbor in closed or not passable(*neighbor):
                continue
            if step_cost < g_score.get(neighbor, float("inf")):
                g_score[neighbor] = step_cost
                came_from[neighbor] = current
                heapq.heappush(open_heap, (step_cost + heuristic(neighbor), neighbor))
    return []

class FireSimulator:
    """Stochastic fire spread on an occupancy grid.

    `map` has the grid's shape: 1.0 marks burning cells, 0.0 unburnt ones. Each
    step, every burning cell independently ignites each 4-neighbour that is
    unburnt and free (grid == 0) with probability `p`. Burning cells never go out.
    """

    def __init__(self, grid, p=0.25):
        self.grid = grid
        self.p = p
        self.map = np.zeros_like(grid, dtype=float)
        self.dirs = [(0, 1), (0, -1), (1, 0), (-1, 0)]  # (d_row, d_col)

    def start(self, points):
        """Ignite the given (row, col) cells."""
        for y, x in points:
            self.map[y, x] = 1

    def step(self):
        """Advance one tick (vectorized: one Bernoulli(p) draw per cell per direction)."""
        burning = self.map == 1
        ignitable = (self.map == 0) & (self.grid == 0)
        rows, cols = self.map.shape
        new_map = self.map.copy()
        for dr, dc in self.dirs:
            # exposed[r, c] is True when the cell (r - dr, c - dc) is burning.
            exposed = np.zeros_like(burning)
            exposed[max(dr, 0):rows + min(dr, 0), max(dc, 0):cols + min(dc, 0)] = \
                burning[max(-dr, 0):rows + min(-dr, 0), max(-dc, 0):cols + min(-dc, 0)]
            catches = exposed & ignitable & (np.random.rand(rows, cols) < self.p)
            new_map[catches] = 1
        self.map = new_map

    def reset(self, points=None):
        """Extinguish everything, then ignite `points` (None means no ignition)."""
        self.map.fill(0)
        # Explicit None check: `points or []` raises on NumPy arrays.
        self.start([] if points is None else points)

class Person:
    """One evacuee: position, planned path, outcome status and behavioural state.

    Positions are (x, y) = (column, row); fire maps are indexed [row, col].
    status goes 'evacuating' -> 'escaped' | 'burned'; state is 'CALM', 'ALERT' or 'PANICKED'.
    """

    def __init__(self, pos):
        self.ipos = tuple(pos)  # spawn point, restored by reset()
        self.reset()

    def reset(self):
        """Return to the spawn point with a fresh, calm, evacuating state."""
        self.pos = list(self.ipos)
        self.path = []
        self.status = 'evacuating'
        self.state = 'CALM'
        self.speed = 1.0
        self.trip = 0.0   # per-move chance of tripping (only consulted while PANICKED)
        self.trip_t = 0   # remaining frozen moves after a trip

    def update_state(self, fire):
        """Set state/speed/trip from the Euclidean distance to the nearest burning cell."""
        if self.trip_t > 0:
            return
        burning = np.argwhere(fire == 1)
        if len(burning) > 0:
            here = np.array([self.pos[1], self.pos[0]])
            nearest = np.min(np.linalg.norm(burning - here, axis=1))
        else:
            nearest = float('inf')
        # (distance limit, (state, speed, trip probability)); first matching limit wins.
        for limit, profile in ((25, ('PANICKED', 1.5, 0.1)), (50, ('ALERT', 1.2, 0.0))):
            if nearest < limit:
                self.state, self.speed, self.trip = profile
                break
        else:
            self.state, self.speed, self.trip = 'CALM', 1.0, 0.0

    def move(self):
        """Advance along the path by round(speed) cells, unless tripped or tripping now."""
        if self.trip_t > 0:
            self.trip_t -= 1
            return
        if self.state == 'PANICKED' and np.random.rand() < self.trip:
            self.trip_t = 5
            return
        for _ in range(int(round(self.speed))):
            if not self.path:
                break
            self.pos = self.path.pop(0)

    def check_status(self, fire, exits):
        """Mark the agent 'burned' if on fire, else 'escaped' if within 5 cells of any exit."""
        if self.status != 'evacuating':
            return
        col, row = int(self.pos[0]), int(self.pos[1])
        if fire[row, col] == 1:
            self.status = 'burned'
            return
        here = np.array(self.pos)
        if any(np.linalg.norm(here - np.array(ex)) < 5 for ex in exits):
            self.status = 'escaped'

    def compute_path(self, grid, goal, fire):
        """Plan an A* path from the current (integer) cell to `goal`, avoiding fire."""
        here = (int(self.pos[0]), int(self.pos[1]))
        self.path = a_star_search(grid, here, goal, fire)

print("✅ Core simulation logic is ready.")

✅ Core simulation logic is ready.


**Cell 6**

In [None]:
# Cell 6: The V2.1 Hybrid Gymnasium Environment (Fully Spawn-Safe)
# This version is designed to work robustly with the 'spawn' multiprocessing method by
# handling all asset loading internally within each worker process.

class EvacuationEnv_v2_Hybrid(gym.Env):
    """Evacuation-commander environment: each step the policy picks one exit for all evacuees.

    On every reset a floor plan is sampled from a Hugging Face dataset and turned
    into an occupancy grid (1 = blocked, 0 = free): either the ground-truth 'walls'
    mask (fast path) or the U-Net prediction on the 'plans' image
    (use_perception_net=True). Fire starts on a random free cell and spreads via
    FireSimulator; 1..max_agents Person agents re-plan toward the chosen exit with A*.

    Observation (Dict):
        "observation": flattened downsampled fire map, then normalized (x, y, state)
            per agent slot (zero-padded to max_agents), then step / max_steps.
        "action_mask": int8 vector, 1 for exit slots usable on this map.
    Action: Discrete(MAX_EXITS) index into the border exit slots.
    Reward: -0.01 per step, +10 per escaped and -10 per burned agent,
        -1 for choosing a masked action.

    Spawn-safe: the dataset and optional U-Net are loaded inside __init__, so each
    vectorized worker process builds its own copies.
    """

    def __init__(self, hf_dataset_name, max_agents, max_steps, use_perception_net=False, perception_model_path=None, device=None):
        super().__init__()

        # --- SPAWN-SAFE INITIALIZATION ---
        # Each worker process loads its own copy of the dataset and model.
        self.hf_dataset = load_dataset(hf_dataset_name, split="train")

        self.use_perception_net = use_perception_net
        if self.use_perception_net:
            if perception_model_path is None or device is None:
                raise ValueError("perception_model_path and device must be provided if use_perception_net is True.")
            self.perception_model = UNet().to(device)
            # The checkpoint is a plain state_dict, so the restricted unpickler suffices.
            self.perception_model.load_state_dict(torch.load(perception_model_path, map_location=device, weights_only=True))
            self.perception_model.eval()
            self.device = device
        else:
            self.perception_model = None
            self.device = None

        self.image_size, self.max_agents, self.max_steps = CFG["image_size"], max_agents, max_steps
        # Candidate exits lie on the ring one cell inside the image border: (size - 2) per side.
        self.MAX_EXITS = (self.image_size - 2) * 4
        self.action_space = spaces.Discrete(self.MAX_EXITS)
        self._exit_slot_map = self._create_exit_slot_map()
        obs_shape = (CFG["obs_fire_ds"]**2) + (self.max_agents * 3) + 1
        self.observation_space = spaces.Dict({
            "observation": spaces.Box(low=0, high=1, shape=(obs_shape,), dtype=np.float32),
            "action_mask": spaces.Box(low=0, high=1, shape=(self.MAX_EXITS,), dtype=np.int8)
        })
        self.gt_transform = transforms.Compose([transforms.Resize((self.image_size, self.image_size)), transforms.ToTensor()])
        # Per-episode map from enabled action index to the (x, y) cell agents steer to.
        self._exit_targets = {}

    def _create_exit_slot_map(self):
        """Map action index -> (x, y) slot: top row, bottom row, left column, right column.

        Corner cells appear twice (e.g. (1, 1) is in both the top row and the left column).
        """
        slots, idx, w, h = {}, 0, self.image_size, self.image_size
        for x in range(1, w - 1): slots[idx] = (x, 1); idx += 1
        for x in range(1, w - 1): slots[idx] = (x, h - 2); idx += 1
        for y in range(1, h - 1): slots[idx] = (1, y); idx += 1
        for y in range(1, h - 1): slots[idx] = (w - 2, y); idx += 1
        return slots

    def action_masks(self):
        """Boolean mask of currently valid actions.

        sb3_contrib's MaskablePPO, get_action_masks and maskable evaluate_policy look
        this method up on the env; without it masking is unsupported.
        """
        return self.action_mask.astype(bool)

    def _get_observation(self):
        """Assemble the Dict observation from the fire map, agents and step count."""
        fire_obs = cv2.resize(self.fire_sim.map.astype(np.float32), (CFG["obs_fire_ds"], CFG["obs_fire_ds"])).flatten()
        agent_obs = np.zeros(self.max_agents * 3, dtype=np.float32)
        state_map = {'CALM': 0.0, 'ALERT': 0.5, 'PANICKED': 1.0}
        for i, agent in enumerate(self.agents):
            agent_obs[i*3] = agent.pos[0] / self.image_size
            agent_obs[i*3+1] = agent.pos[1] / self.image_size
            agent_obs[i*3+2] = state_map.get(agent.state, 0.0)
        # Clamp so the value never leaves the declared Box bounds [0, 1].
        time_obs = np.array([min(self.current_step / self.max_steps, 1.0)])
        obs = np.concatenate([fire_obs, agent_obs, time_obs]).astype(np.float32)
        return {"observation": obs, "action_mask": self.action_mask}

    def _grid_from_item(self, item):
        """Occupancy grid (float, 1 = blocked) for one dataset row."""
        if self.use_perception_net:
            return create_grid_from_image(self.perception_model, item['plans'], self.image_size, self.device)
        gt_tensor = self.gt_transform(item['walls'].convert("L"))
        return (gt_tensor > 0.5).squeeze().numpy().astype(float)

    def _place_exits(self):
        """Fill self.exits, self.action_mask and self._exit_targets from self.base_grid."""
        self.exits, self.action_mask = [], np.zeros(self.MAX_EXITS, dtype=np.int8)
        self._exit_targets = {}
        for idx, (x, y) in self._exit_slot_map.items():
            if self.base_grid[y, x] == 0:
                self.exits.append((x, y))
                self.action_mask[idx] = 1
                self._exit_targets[idx] = (x, y)
        if self.exits:
            return

        # Fallback for maps where every border slot is blocked.
        valid_points = np.argwhere(self.base_grid == 0)
        if len(valid_points) == 0:
            # Fully blocked map: enable one action so the mask is never all-zero.
            self.action_mask[0] = 1
            self.exits.append(self._exit_slot_map[0])
            self._exit_targets[0] = self._exit_slot_map[0]
            return
        y, x = (int(v) for v in valid_points[0])
        fallback_exit = (x, y)
        self.exits.append(fallback_exit)
        # Enable the border slot closest to the fallback exit and point it AT the fallback
        # exit, so agents path to a free cell instead of the blocked slot itself.
        closest_idx = min(self._exit_slot_map, key=lambda k: np.hypot(self._exit_slot_map[k][0] - x, self._exit_slot_map[k][1] - y))
        self.action_mask[closest_idx] = 1
        self._exit_targets[closest_idx] = fallback_exit

    def _place_fire_and_agents(self, rng):
        """Create the fire simulator, ignite one free cell and spawn agents on free cells."""
        self.fire_sim = FireSimulator(self.base_grid)
        self.current_step = 0
        valid_spawns = np.argwhere(self.base_grid == 0)
        if len(valid_spawns) == 0:
            self.agents = []  # no free cell to place anyone on
            return
        fire_y, fire_x = valid_spawns[rng.integers(len(valid_spawns))]
        self.fire_sim.reset(points=[(fire_y, fire_x)])
        num_agents = int(rng.integers(1, self.max_agents + 1))
        spawn_indices = rng.choice(len(valid_spawns), min(num_agents, len(valid_spawns)), replace=False)
        # Agents that would start inside the initial fire are dropped.
        self.agents = [Person(pos=(x, y)) for y, x in valid_spawns[spawn_indices] if self.fire_sim.map[y, x] == 0]

    def reset(self, seed=None, options=None):
        """Sample a floor plan, place exits, fire and agents; return (observation, info)."""
        super().reset(seed=seed)
        # self.np_random is seeded by super().reset(seed=...), so the map choice, fire origin
        # and spawns are reproducible for a given seed (fire spread and agent tripping
        # still draw from the global np.random).
        rng = self.np_random
        item = self.hf_dataset[int(rng.integers(0, len(self.hf_dataset)))]
        self.base_grid = self._grid_from_item(item)
        self._place_exits()
        self._place_fire_and_agents(rng)
        return self._get_observation(), {}

    def step(self, action):
        """Send every evacuating agent toward the chosen exit and advance the fire one tick."""
        self.current_step += 1
        truncated = self.current_step >= self.max_steps
        if not self.action_mask[action]:
            # Masked action (should not happen under MaskablePPO): penalize without advancing
            # the simulation, but still honour the step limit so the episode can end.
            return self._get_observation(), -1.0, False, truncated, {}
        reward = -0.01  # per-step time penalty
        target_exit = self._exit_targets[int(action)]
        self.fire_sim.step()
        for agent in self.agents:
            if agent.status != 'evacuating':
                continue
            agent.update_state(self.fire_sim.map)
            # Re-plan when the path is exhausted and every 10 steps against the current fire map.
            if not agent.path or self.current_step % 10 == 0:
                agent.compute_path(self.base_grid, target_exit, self.fire_sim.map)
            agent.move()
            agent.check_status(self.fire_sim.map, self.exits)
            if agent.status == 'escaped':
                reward += 10
            elif agent.status == 'burned':
                reward -= 10
        terminated = all(a.status != 'evacuating' for a in self.agents)
        return self._get_observation(), reward, terminated, truncated, {}

print("✅ V2.1 Hybrid Curriculum Environment 'EvacuationEnv_v2_Hybrid' (Fully Spawn-Safe) is defined.")

✅ V2.1 Hybrid Curriculum Environment 'EvacuationEnv_v2_Hybrid' (Fully Spawn-Safe) is defined.


**Cell 7**

In [None]:
# Cell 7: Environment Factory and Dataset Loader (Final Version)
import multiprocessing as mp
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import SubprocVecEnv

# Start worker processes with 'spawn' so they are not forked from this kernel, which may
# already hold CUDA state (the U-Net was moved to CFG['device']). Inside a notebook
# __name__ is always '__main__', so this guard only matters if the code moves to a script.
if __name__ == '__main__':
    try:
        mp.set_start_method('spawn', force=True)
    except RuntimeError:
        pass  # the start method could not be changed; keep the one already configured
    else:
        print("✅ Multiprocessing start method set to 'spawn'.")

# The factory passes only strings/flags (dataset name, model path); each env builds its own heavy objects.
def make_env_v2_1(use_perception_net: bool = False, n_envs=1, vec_env_cls=None):
    """Build a vectorized EvacuationEnv_v2_Hybrid.

    Args:
        use_perception_net: build grids with the U-Net (loaded from MODEL_PATH inside
            each env) instead of the ground-truth wall masks.
        n_envs: number of parallel environments.
        vec_env_cls: VecEnv class to use. None (default) selects DummyVecEnv for a
            single env and SubprocVecEnv (start_method='spawn') otherwise.

    Returns:
        A stable-baselines3 VecEnv.
    """
    from stable_baselines3.common.vec_env import DummyVecEnv

    env_kwargs = {
        'hf_dataset_name': "zimhe/pseudo-floor-plan-12k",
        'max_agents': CFG["max_agents"],
        'max_steps': CFG["max_steps"],
        'use_perception_net': use_perception_net,
        'perception_model_path': MODEL_PATH if use_perception_net else None,
        'device': CFG['device']
    }

    if vec_env_cls is None:
        # A single env gains nothing from a subprocess. Running it in-process also raises
        # env exceptions directly instead of surfacing a dead worker as an opaque EOFError.
        vec_env_cls = SubprocVecEnv if n_envs > 1 else DummyVecEnv
    # start_method is a SubprocVecEnv constructor argument; other VecEnv classes reject it.
    vec_env_kwargs = dict(start_method='spawn') if issubclass(vec_env_cls, SubprocVecEnv) else None

    return make_vec_env(
        EvacuationEnv_v2_Hybrid,
        n_envs=n_envs,
        vec_env_cls=vec_env_cls,
        env_kwargs=env_kwargs,
        vec_env_kwargs=vec_env_kwargs,
    )


print("✅ V2.1 Hybrid environment factory (Final Version) is ready.")

✅ Multiprocessing start method set to 'spawn'.
✅ V2.1 Hybrid environment factory (Final Version) is ready.


**Cell 8**

In [None]:
print("--- Running Smoke Test ---")
# A single environment on the FAST path (ground-truth wall masks).
test_env = make_env_v2_1(use_perception_net=False, n_envs=1)
try:
    obs = test_env.reset()
    # VecEnv observations carry a leading n_envs axis.
    expected_obs_shape = (1, *test_env.observation_space["observation"].shape)
    assert obs["observation"].shape == expected_obs_shape, \
        f"Observation shape {obs['observation'].shape} != {expected_obs_shape}"
    print("Initial observation shapes are OK.")

    # Masking requires the env to expose action_masks().
    mask = get_action_masks(test_env)
    print(f"Action mask shape: {mask.shape}, Number of valid actions: {mask.sum()}")
    # VecEnv.action_space is the single-env Discrete space, which is not indexable.
    assert mask.shape[1] == test_env.action_space.n, "Mask shape mismatch!"

    # Take a few random (but valid) steps.
    for _ in range(5):
        valid_actions = np.flatnonzero(get_action_masks(test_env)[0])
        assert valid_actions.size > 0, "No valid actions in mask!"
        action = np.array([np.random.choice(valid_actions)])
        obs, _, _, _ = test_env.step(action)
    print("Stepping through the environment works.")
finally:
    # Always release the env (and any worker process), even when a check fails.
    test_env.close()
print("✅ Smoke test passed!")

--- Running Smoke Test ---
Initial observation shapes are OK.


EOFError: 

**Cell 9**

In [None]:
# Phase 1: Train a baseline model quickly using the perfect ground-truth data.
print("--- Phase 1: Starting Baseline Model Training ---")

# Training parameters
N_ENVS = 4  # A T4 can handle 4-8 parallel environments
TRAIN_STEPS_BASELINE = 50000  # Increase to 200k+ for a production model
MODEL_BASELINE_PATH = "ppo_commander_v2.0_baseline.zip"

# Vectorized environments on the fast ground-truth path.
train_env_baseline = make_env_v2_1(use_perception_net=False, n_envs=N_ENVS)
try:
    # MaskablePPO only samples actions allowed by the env's action mask.
    model_baseline = MaskablePPO(
        "MultiInputPolicy",  # the observation space is a Dict
        train_env_baseline,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        verbose=1,
        device=CFG['device'],
        tensorboard_log="./tensorboard_logs/baseline/"
    )

    print(f"🚀 Training baseline model for {TRAIN_STEPS_BASELINE} steps...")
    model_baseline.learn(total_timesteps=TRAIN_STEPS_BASELINE, progress_bar=True)
    model_baseline.save(MODEL_BASELINE_PATH)
finally:
    # Shut the environments down even if training fails or is interrupted.
    train_env_baseline.close()

print(f"\n✅ Baseline model training complete. Model saved to '{MODEL_BASELINE_PATH}'")

**Cell 10**

In [None]:
import pandas as pd
# Phase 2: Measure how well the baseline model performs on imperfect, U-Net generated maps.
print("\n--- Phase 2: Evaluating Domain Gap ---")

if not os.path.exists(MODEL_BASELINE_PATH):
    raise FileNotFoundError("Baseline model not found. Please run Phase 1 first.")

N_EVAL_EPISODES = 20  # Use more episodes (e.g., 100) for a more stable estimate

# Load the trained baseline model
model_to_eval = MaskablePPO.load(MODEL_BASELINE_PATH)

# Two evaluation environments over the same dataset: ground-truth grids vs U-Net grids.
# Each env samples its floor plans independently, so the two estimates come from
# different maps; more episodes reduce that noise.
eval_env_gt = eval_env_unet = None
try:
    eval_env_gt = make_env_v2_1(use_perception_net=False, n_envs=1)
    eval_env_unet = make_env_v2_1(use_perception_net=True, n_envs=1)

    print(f"Running {N_EVAL_EPISODES} evaluation episodes on Ground-Truth maps...")
    mean_reward_gt, std_reward_gt = evaluate_policy(model_to_eval, eval_env_gt, n_eval_episodes=N_EVAL_EPISODES, warn=False)

    print(f"Running {N_EVAL_EPISODES} evaluation episodes on U-Net maps...")
    mean_reward_unet, std_reward_unet = evaluate_policy(model_to_eval, eval_env_unet, n_eval_episodes=N_EVAL_EPISODES, warn=False)
finally:
    # Release both environments even if construction or evaluation fails part-way.
    for _env in (eval_env_gt, eval_env_unet):
        if _env is not None:
            _env.close()

# Display results
results = {
    "Environment": ["Ground-Truth (Ideal)", "U-Net Predicted (Real-World)"],
    "Mean Reward": [f"{mean_reward_gt:.2f}", f"{mean_reward_unet:.2f}"],
    "Std Reward": [f"{std_reward_gt:.2f}", f"{std_reward_unet:.2f}"]
}
df = pd.DataFrame(results)
print("\n--- Domain Gap Analysis ---")
print(df.to_string(index=False))

reward_diff = (mean_reward_gt - mean_reward_unet)
print(f"\nPerformance drop (Domain Gap): {reward_diff:.2f} reward points.")
print("A significant drop indicates that fine-tuning is necessary.")

**Cell 11**

In [None]:
# Phase 3: Close the domain gap by fine-tuning the model on U-Net data.
print("\n--- Phase 3: Starting Fine-Tuning ---")

if not os.path.exists(MODEL_BASELINE_PATH):
    raise FileNotFoundError("Baseline model not found. Please run Phase 1 first.")

TRAIN_STEPS_FINETUNE = 25000  # Usually fewer steps are needed for fine-tuning
LEARNING_RATE_FINETUNE = 3e-5  # Use a lower learning rate!
MODEL_FINETUNED_PATH = "ppo_commander_v2.0_finetuned.zip"

# SLOW path: every reset runs the U-Net inside each environment.
train_env_finetune = make_env_v2_1(use_perception_net=True, n_envs=N_ENVS)
try:
    # Extra keyword arguments to load() override the saved model attributes (here the learning rate).
    model_finetune = MaskablePPO.load(
        MODEL_BASELINE_PATH,
        env=train_env_finetune,
        learning_rate=LEARNING_RATE_FINETUNE,
        device=CFG['device']
    )

    print(f"🚀 Fine-tuning model for {TRAIN_STEPS_FINETUNE} steps on U-Net data...")
    # reset_num_timesteps=False continues the timestep counter from Phase 1.
    model_finetune.learn(total_timesteps=TRAIN_STEPS_FINETUNE, progress_bar=True, reset_num_timesteps=False)
    model_finetune.save(MODEL_FINETUNED_PATH)
finally:
    # Shut the environments down even if fine-tuning fails or is interrupted.
    train_env_finetune.close()

print(f"\n✅ Fine-tuning complete. Final model saved to '{MODEL_FINETUNED_PATH}'")

**Cell 12**

In [None]:
from IPython.display import display, HTML
import matplotlib.patches as patches
from matplotlib.animation import FuncAnimation

print("\n--- Grand Finale: Testing the Final, Fine-Tuned Model ---")

if not os.path.exists(MODEL_FINETUNED_PATH):
    raise FileNotFoundError("Fine-tuned model not found. Please run all phases first.")

# Terminal statuses ('escaped'/'burned') take precedence over the behavioural state colour.
AGENT_COLORS = {'CALM': 'blue', 'ALERT': 'yellow', 'PANICKED': 'orange', 'escaped': 'lime', 'burned': 'black'}

print("Please upload a NEW, UNSEEN floor plan for the final test.")
uploaded = files.upload()

if not uploaded:
    print("No file uploaded; skipping the final test.")
else:
    test_image_path = list(uploaded.keys())[0]
    test_image_pil = Image.open(test_image_path)
    size = CFG['image_size']

    # 1. Analyze with Perception AI
    print("Step 1: Analyzing floor plan...")
    simulation_grid = create_grid_from_image(perception_model, test_image_pil, size, CFG['device'])

    # 2. A single environment that always simulates on `simulation_grid`.
    class TestEnv(EvacuationEnv_v2_Hybrid):
        """EvacuationEnv_v2_Hybrid whose every reset uses one fixed occupancy grid."""

        def __init__(self, grid):
            # The parent constructor calls load_dataset(name), so it needs a real dataset name.
            super().__init__("zimhe/pseudo-floor-plan-12k", CFG['max_agents'], CFG['max_steps'])
            # Swap in a one-row "dataset" whose 'walls' image encodes `grid` as 0/255. The
            # parent's ground-truth path (Resize to image_size, ToTensor, > 0.5) decodes it back
            # to the same grid, so reset() keeps its exit/fire/agent logic but uses this grid.
            walls = Image.fromarray((np.asarray(grid) > 0.5).astype(np.uint8) * 255)
            self.hf_dataset = [{'plans': test_image_pil, 'walls': walls}]

    final_test_env = TestEnv(simulation_grid)
    model_final = MaskablePPO.load(MODEL_FINETUNED_PATH)
    print("Step 2: Environment and AI are ready.")

    # 3. Run simulation
    print("Step 3: Running AI-controlled evacuation...")
    obs, _ = final_test_env.reset()
    history = []
    while True:
        action, _ = model_final.predict(obs, action_masks=obs["action_mask"], deterministic=True)
        # For a single, non-vectorized observation predict() may return a 0-d array, where
        # action[0] raises; .item() works for both 0-d and one-element arrays.
        obs, _, terminated, truncated, _ = final_test_env.step(np.asarray(action).item())
        history.append({
            'fire': final_test_env.fire_sim.map.copy(),
            'agents': [{'p': tuple(a.pos), 's': a.status, 'st': a.state, 't': a.trip_t > 0} for a in final_test_env.agents],
            'step': final_test_env.current_step
        })
        if terminated or truncated:
            break
    print("✅ Simulation complete.")

    # 4. Animate. Every layer uses extent [0, size, size, 0] (row 0 at the top) so the floor
    # plan, the fire map and the (x, y) = (column, row) agent positions line up.
    extent = [0, size, size, 0]
    fig, ax = plt.subplots(figsize=(8, 8))
    bg_img = test_image_pil.resize((size, size))

    def animate(i):
        ax.clear()
        frame = history[i]
        ax.imshow(bg_img, extent=extent)
        ax.imshow(frame['fire'], cmap='Reds', alpha=0.5, extent=extent)
        for a in frame['agents']:
            # 'evacuating' has no colour of its own, so fall back to the agent's state colour.
            color = AGENT_COLORS.get(a['s']) or AGENT_COLORS.get(a['st'], 'gray')
            if a['t']:
                shape = patches.Ellipse(a['p'], 6, 3, color=color, alpha=.9)  # tripped agent
            else:
                shape = patches.Circle(a['p'], 3, color=color, alpha=.9)
            ax.add_patch(shape)
        ax.set_title(f"SafeScape V2.1 Evacuation | Step: {frame['step']}")
        ax.axis('off')

    anim = FuncAnimation(fig, animate, frames=len(history), interval=100)
    plt.close(fig)  # show only the HTML animation, not an extra static figure
    display(HTML(anim.to_jshtml()))

    # 5. Risk Assessment Dashboard
    final_agents = history[-1]['agents']
    escaped = sum(1 for a in final_agents if a['s'] == 'escaped')
    burned = sum(1 for a in final_agents if a['s'] == 'burned')
    total = len(final_agents)
    print("\n--- SafeScape V2.1 Risk Assessment ---")
    print(f"Outcome for '{test_image_path}':")
    if total == 0:
        print("  - No agents could be spawned on this floor plan.")
    else:
        print(f"  - Agents Escaped: {escaped}/{total} ({escaped/total:.1%})")
        print(f"  - Agents Burned:  {burned}/{total} ({burned/total:.1%})")