# Active Object Localization with Deep Reinforcement Learning  
### Full Reimplementation of Caicedo & Lazebnik (ICCV 2015)

This notebook reproduces the neural agent from:

**Caicedo & Lazebnik — Active Object Localization, ICCV 2015**

It includes:

- Class-specific DQN agent  
- CNN feature extraction (AlexNet or ResNet)  
- Action history vector  
- 9 actions (8 bounding box transformations plus a trigger)  
- Paper reward function  
- Single-object training  
- Episode trajectory visualization  
- GIF and MP4 animations  
- Heatmaps (episode + dataset-level)  
- Quantitative evaluation across VOC 2007  

This notebook **does NOT use CLIP**, transformers, or zero-shot models,  
to remain faithful to the original paper.


In [1]:
!pip install shimmy stable-baselines3[extra] gymnasium torch torchvision tensorflow_datasets opencv-python imageio matplotlib


Collecting shimmy
  Downloading Shimmy-2.0.0-py3-none-any.whl.metadata (3.5 kB)
Collecting stable-baselines3[extra]
  Downloading stable_baselines3-2.7.1-py3-none-any.whl.metadata (4.8 kB)
Downloading Shimmy-2.0.0-py3-none-any.whl (30 kB)
Downloading stable_baselines3-2.7.1-py3-none-any.whl (188 kB)
Installing collected packages: shimmy, stable-baselines3
Successfully installed shimmy-2.0.0 stable-baselines3-2.7.1


## Imports & Environment Setup


In [None]:
import os
import cv2
import gymnasium as gym
import numpy as np
import random
import torch
import torch.nn as nn
import torchvision.transforms as T
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.monitor import Monitor

from gymnasium.vector import AsyncVectorEnv

from IPython.display import Image as IPyImage
from IPython.display import display

# For GIF/MP4
import imageio


## GPU Check


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)


## Set random seeds for reproducibility


In [None]:
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if device == "cuda":
    torch.cuda.manual_seed_all(seed)


## Create output directories (gifs, mp4s, results)


In [None]:
os.makedirs("animations", exist_ok=True)
os.makedirs("results", exist_ok=True)


# **Section 2 — Dataset: PASCAL VOC 2007 (TFDS)**

We load VOC 2007 using TensorFlow Datasets (TFDS).  
This gives us:

- `voc/2007:train`: Training split  
- `voc/2007:validation`: Validation split  
- `voc/2007:test`: Test split  

The ICCV 2015 paper uses:
- **train + validation** for training  
- **test** for evaluation  

We follow that design here.

We also define helper functions to convert TFDS bounding boxes  
(from normalized format) to pixel coordinates.


In [None]:
# Load VOC 2007 data from TFDS
ds_train = tfds.load("voc/2007", split="train", shuffle_files=True)
ds_val   = tfds.load("voc/2007", split="validation", shuffle_files=True)
ds_test  = tfds.load("voc/2007", split="test", shuffle_files=True)

# Convert to list for easier random access
ds_train = list(ds_train)
ds_val = list(ds_val)
ds_test = list(ds_test)

print("Train size:", len(ds_train))
print("Validation size:", len(ds_val))
print("Test size:", len(ds_test))


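## VOC 20-Class Label List (in label-index order)

We hardcode the 20 VOC 2007 class names in alphabetical order.  
The list order must match the integer labels in the dataset, because the code looks classes up with `VOC_CLASSES[int(label)]` (label 0 is the first entry):

1. aeroplane  
2. bicycle  
3. bird  
4. boat  
5. bottle  
6. bus  
7. car  
8. cat  
9. chair  
10. cow  
11. diningtable  
12. dog  
13. horse  
14. motorbike  
15. person  
16. pottedplant  
17. sheep  
18. sofa  
19. train  
20. tvmonitor  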


In [None]:
VOC_CLASSES = [
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor"
]


## Bounding Box conversion utility

TFDS provides VOC bounding boxes in **normalized format**:

- ymin, xmin, ymax, xmax in range [0, 1]

We convert this to pixel coordinates using the image dimensions.


In [None]:
def tfds_box_to_pixel(bbox, img_shape):
    """
    Convert normalized TFDS VOC bbox to absolute pixel coordinates.

    bbox = [ymin, xmin, ymax, xmax]
    img_shape = (H, W, 3)
    """
    H, W = img_shape[:2]
    ymin, xmin, ymax, xmax = bbox
    x1 = int(xmin * W)
    y1 = int(ymin * H)
    x2 = int(xmax * W)
    y2 = int(ymax * H)
    return x1, y1, x2, y2
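
As a quick sanity check of the conversion above, the following standalone sketch applies the same arithmetic to a made-up box. The helper name `to_pixels` and the image size are illustrative assumptions, not part of the notebook.

```python
# Standalone sketch: normalized (ymin, xmin, ymax, xmax) -> pixel (x1, y1, x2, y2).
# `to_pixels` is a hypothetical helper that mirrors tfds_box_to_pixel above.
def to_pixels(bbox, height, width):
    ymin, xmin, ymax, xmax = bbox
    # int() truncates toward zero, so positive coordinates are floored
    return (int(xmin * width), int(ymin * height),
            int(xmax * width), int(ymax * height))

# Made-up example: an image with H = 200 and W = 400
box = to_pixels((0.25, 0.125, 0.75, 0.5), height=200, width=400)
print(box)  # (50, 50, 200, 150)
```

Note that the TFDS order is y-first, while the returned tuple is x-first; mixing the two up is an easy source of transposed boxes.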


In [None]:
def tfds_to_numpy(sample):
    return {
        "image": np.array(sample["image"]),
        "objects": {
            "label": np.array(sample["objects"]["label"]),
            "bbox": np.array(sample["objects"]["bbox"]),
        }
    }


## Extract all GT boxes for one image (for a specific class)

This helps with:
- selecting training samples for a given class
- evaluating multi-object detection
- visualizing ground-truth targets


In [None]:
def get_all_gt_boxes_for_class(sample, class_name, img_shape):
    H, W = img_shape[:2]  # accepts both (H, W) and (H, W, 3)
    gt_boxes = []

    for label, bbox in zip(sample["objects"]["label"], sample["objects"]["bbox"]):
        cls = VOC_CLASSES[int(label)]
        if cls == class_name:
            x1, y1, x2, y2 = tfds_box_to_pixel(bbox, (H, W, 3))
            gt_boxes.append(Box(x1, y1, x2, y2))

    return gt_boxes


## Quick visualization of a raw VOC sample (for debugging)

Useful to verify that bounding boxes and class labels load properly.


In [None]:
def visualize_voc_sample(sample):
    img = sample["image"]
    H, W = img.shape[:2]

    fig, ax = plt.subplots(figsize=(6,6))
    ax.imshow(img)
    ax.set_title("VOC Sample")

    for label, bbox in zip(sample["objects"]["label"], sample["objects"]["bbox"]):
        x1, y1, x2, y2 = tfds_box_to_pixel(bbox, img.shape)
        class_name = VOC_CLASSES[int(label)]
        rect = plt.Rectangle((x1, y1), x2-x1, y2-y1,
                             fill=False, edgecolor='r', linewidth=2)
        ax.add_patch(rect)
        ax.text(x1, y1, class_name, color='yellow')

    plt.axis("off")
    plt.show()


In [None]:
visualize_voc_sample(random.choice(ds_test))


# **Section 3 — CNN Feature Extractor (AlexNet or ResNet)**

The original ICCV 2015 paper extracts region features with a pretrained AlexNet-style CNN.  
To stay close to the paper, AlexNet is supported here, using the output of its convolutional stack (`model.features`).

We also allow using **ResNet-50** as an optional modern backbone. Its expected benefits, which this notebook does not measure, are:

- Faster convergence  
- Higher quality features  
- Better stability in DQN  

The RL agent does not backprop through the CNN —  
features are used *only* for the state representation.

The user may select:

```python
cnn_type = "alexnet"
cnn_type = "resnet50"
```


In [None]:
cnn_type = "resnet50"   # you may change this anytime

## Preprocessing transform for CNN input
- Resize crop to 224×224
- Convert to tensor
- Normalize with ImageNet mean/std


In [None]:
preprocess = T.Compose([
    T.ToPILImage(),
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
])


## FeatureExtractor

This class:

- Loads AlexNet or ResNet-50  
- Removes classifier layers  
- Extracts features from last convolutional layer  
- Flattens the feature map into a 1D vector  

Output feature dimension:

- AlexNet conv5 → **256 × 6 × 6 = 9216**  
- ResNet-50 layer4 → **2048 × 7 × 7 = 100,352**  
- ResNet-50 global average pool → **2048** (we will use this one)

For RL stability, the preferred choice is:

### **ResNet-50 global pooled features (2048-D)**  
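
The sizes listed above can be checked with a few lines of arithmetic. The sketch below also adds the 90-dimensional action history (10 past actions, 9-way one-hot each) that the environment later appends to the features; the variable names are illustrative only.

```python
# Feature sizes quoted above (flattened), plus the action-history size
alexnet_conv_dim = 256 * 6 * 6      # AlexNet conv output for a 224x224 input
resnet_layer4_dim = 2048 * 7 * 7    # ResNet-50 layer4 map, before pooling
resnet_pooled_dim = 2048            # ResNet-50 after global average pooling

history_dim = 10 * 9                # 10 past actions, one-hot over 9 actions

state_dims = {
    "alexnet": alexnet_conv_dim + history_dim,
    "resnet50 (pooled)": resnet_pooled_dim + history_dim,
}
for name, dim in state_dims.items():
    print(name, "state dim:", dim)
```

The pooled ResNet-50 state is roughly four times smaller than the AlexNet one, which keeps the Q-network's input layer and the replay buffer correspondingly smaller.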


In [None]:
class FeatureExtractor(nn.Module):
    def __init__(self, cnn_type="resnet50", device="cpu"):
        super().__init__()
        self.cnn_type = cnn_type
        self.device = device

        if cnn_type == "alexnet":
            from torchvision.models import alexnet, AlexNet_Weights
            model = alexnet(weights=AlexNet_Weights.IMAGENET1K_V1)
            self.features = model.features  # conv layers only
            self.output_dim = 256 * 6 * 6

        elif cnn_type == "resnet50":
            from torchvision.models import resnet50, ResNet50_Weights
            model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
            # all layers except the final FC
            self.features = nn.Sequential(*list(model.children())[:-1])
            self.output_dim = 2048  # after global average pool

        else:
            raise ValueError("cnn_type must be 'alexnet' or 'resnet50'")

        self.to(device)
        self.eval()  # we never train CNN

    def forward(self, img_crop):
        """
        img_crop: numpy image (H, W, 3)
        returns: 1D feature vector (numpy)
        """
        # --- Normalize dtype/shape for ToPILImage ---
        if isinstance(img_crop, torch.Tensor):
            # [C, H, W] -> [H, W, C], assume in [0,1]
            if img_crop.dim() == 3 and img_crop.shape[0] in (1, 3):
                img_crop = img_crop.permute(1, 2, 0).cpu().numpy()
            else:
                img_crop = img_crop.cpu().numpy()

        if isinstance(img_crop, np.ndarray):
            # If float, convert to uint8 [0,255]
            if img_crop.dtype != np.uint8:
                arr = img_crop
                # heuristic: if max <= 1, assume [0,1] and scale
                if arr.size > 0 and arr.max() <= 1.0 + 1e-6:
                    arr = arr * 255.0
                img_crop = np.clip(arr, 0, 255).astype(np.uint8)

        # Handle empty crops robustly
        if img_crop is None or (isinstance(img_crop, np.ndarray) and img_crop.size == 0):
            x = torch.zeros((1, 3, 224, 224), device=self.device)
        else:
            try:
                x = preprocess(img_crop).unsqueeze(0).to(self.device)
            except Exception:
                # Fallback for any weird shape/dtype
                x = torch.zeros((1, 3, 224, 224), device=self.device)

        with torch.no_grad():
            feats = self.features(x)

        feats = feats.view(feats.size(0), -1)
        return feats.cpu().numpy().squeeze()


## Quick test of feature extractor
This verifies that CNN → feature vector works.


In [None]:
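# Convert the TF tensor to numpy first; otherwise preprocessing fails and
# the extractor silently falls back to an all-zero input
test_img = np.array(random.choice(ds_train)["image"])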
fe = FeatureExtractor(cnn_type=cnn_type, device=device)

feat = fe(test_img)
print("Feature dimension:", feat.shape)


# **Section 4 — Core RL Components**

This section implements the foundational components of the Active Localization agent:

### 1. Bounding Box Representation (`Box` class)
Handles:
- clipping to image boundaries  
- integer conversion  
- width/height calculations  
- movement and resizing  

### 2. IoU Computation
Intersection over Union for evaluating progress and triggers.

### 3. Action Transformations
The ICCV 2015 paper uses **9 discrete actions**:

1. move left  
2. move right  
3. move up  
4. move down  
5. scale bigger  
6. scale smaller  
7. make fatter (increase width)  
8. make taller (increase height)  
9. trigger (finalize box)  

### 4. Reward Function (From the paper)
- +1 if IoU increases  
- −1 otherwise (IoU decreases or stays the same)  
- Trigger:
  - +3 if IoU ≥ 0.6  
  - −3 otherwise  


In [None]:
class Box:
    """
    Bounding box for the RL agent.
    Coordinates stored as absolute pixel coords: (x1, y1, x2, y2)
    """

    def __init__(self, x1, y1, x2, y2):
        self.x1 = int(x1)
        self.y1 = int(y1)
        self.x2 = int(x2)
        self.y2 = int(y2)

    def as_int(self):
        return int(self.x1), int(self.y1), int(self.x2), int(self.y2)

    def width(self):
        return max(1, self.x2 - self.x1)

    def height(self):
        return max(1, self.y2 - self.y1)

    def copy(self):
        return Box(self.x1, self.y1, self.x2, self.y2)

    def clip(self, W, H):
        """Ensure box stays within image boundaries."""
        self.x1 = np.clip(self.x1, 0, W-1)
        self.y1 = np.clip(self.y1, 0, H-1)
        self.x2 = np.clip(self.x2, 1, W)
        self.y2 = np.clip(self.y2, 1, H)


In [None]:
def iou(boxA, boxB):
    """
    Compute IoU between two Box objects.
    """

    xA1, yA1, xA2, yA2 = boxA.as_int()
    xB1, yB1, xB2, yB2 = boxB.as_int()

    inter_x1 = max(xA1, xB1)
    inter_y1 = max(yA1, yB1)
    inter_x2 = min(xA2, xB2)
    inter_y2 = min(yA2, yB2)

    inter_w = max(0, inter_x2 - inter_x1)
    inter_h = max(0, inter_y2 - inter_y1)
    inter_area = inter_w * inter_h

    areaA = boxA.width() * boxA.height()
    areaB = boxB.width() * boxB.height()

    union = areaA + areaB - inter_area + 1e-6

    return inter_area / union
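
To make the IoU arithmetic concrete, here is a small standalone sketch on plain `(x1, y1, x2, y2)` tuples. The helper `iou_xyxy` is hypothetical and only mirrors the logic of `iou` above; it does not use the `Box` class or the `1e-6` smoothing term.

```python
def iou_xyxy(a, b):
    """IoU of two (x1, y1, x2, y2) boxes; hypothetical tuple-based helper."""
    inter_w = max(0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

# Two 10x10 boxes overlapping in a 5x5 corner: 25 / (100 + 100 - 25)
partial = iou_xyxy((0, 0, 10, 10), (5, 5, 15, 15))
identical = iou_xyxy((0, 0, 10, 10), (0, 0, 10, 10))
disjoint = iou_xyxy((0, 0, 10, 10), (20, 20, 30, 30))
print(partial, identical, disjoint)
```

The partial overlap comes out at 1/7, well below the 0.6 trigger threshold, even though a quarter of each box is covered.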


## 9 Actions from Caicedo & Lazebnik (ICCV 2015)

We implement:

0. move left  
1. move right  
2. move up  
3. move down  
4. scale bigger  
5. scale smaller  
6. increase width (fatter)  
7. increase height (taller)  
8. trigger (terminate episode)  


In [None]:
def apply_action(box, action, W, H):
    """
    Apply one of the 8 transformation actions to the box.
    The 9th action (trigger) is handled in the environment.
    """
    new_box = box.copy()

    dx = int(0.2 * box.width())
    dy = int(0.2 * box.height())

    # 0: left
    if action == 0:
        new_box.x1 -= dx
        new_box.x2 -= dx

    # 1: right
    elif action == 1:
        new_box.x1 += dx
        new_box.x2 += dx

    # 2: up
    elif action == 2:
        new_box.y1 -= dy
        new_box.y2 -= dy

    # 3: down
    elif action == 3:
        new_box.y1 += dy
        new_box.y2 += dy

    # 4: scale bigger
    elif action == 4:
        new_box.x1 -= dx
        new_box.y1 -= dy
        new_box.x2 += dx
        new_box.y2 += dy

    # 5: scale smaller
    elif action == 5:
        new_box.x1 += dx
        new_box.y1 += dy
        new_box.x2 -= dx
        new_box.y2 -= dy

    # 6: fatter
    elif action == 6:
        new_box.x1 -= dx
        new_box.x2 += dx

    # 7: taller
    elif action == 7:
        new_box.y1 -= dy
        new_box.y2 += dy

    # clip to image bounds
    new_box.clip(W, H)
    return new_box
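
The step sizes above are relative to the current box (20% of its width or height), so moves shrink as the box shrinks. The following standalone sketch reproduces that arithmetic on tuples; `shift_or_scale` is a hypothetical helper and, unlike `apply_action`, it does not clip to the image.

```python
ALPHA = 0.2  # step size relative to the current box, as in apply_action above

# Sign pattern applied to (x1, y1, x2, y2) for actions 0-7
DELTAS = {
    0: (-1, 0, -1, 0),   # left
    1: (1, 0, 1, 0),     # right
    2: (0, -1, 0, -1),   # up
    3: (0, 1, 0, 1),     # down
    4: (-1, -1, 1, 1),   # bigger
    5: (1, 1, -1, -1),   # smaller
    6: (-1, 0, 1, 0),    # fatter
    7: (0, -1, 0, 1),    # taller
}

def shift_or_scale(box, action):
    x1, y1, x2, y2 = box
    dx = int(ALPHA * (x2 - x1))
    dy = int(ALPHA * (y2 - y1))
    sx1, sy1, sx2, sy2 = DELTAS[action]
    return (x1 + sx1 * dx, y1 + sy1 * dy, x2 + sx2 * dx, y2 + sy2 * dy)

start = (100, 100, 300, 200)            # width 200, height 100 -> dx 40, dy 20
moved_right = shift_or_scale(start, 1)  # (140, 100, 340, 200)
shrunk = shift_or_scale(start, 5)       # (140, 120, 260, 180)
print(moved_right, shrunk)
```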


## Reward Function (Paper-Accurate)

Between non-trigger actions:

- +1  if IoU(new) > IoU(old)
- −1  otherwise

For trigger action:

- +3  if IoU(new) ≥ 0.6
- −3  otherwise

This matches the ICCV 2015 methodology exactly.
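
The reward rule can be written as a small pure function. This is a minimal sketch of the rule described in this section; the function name `step_reward` and its parameter names are assumptions made for illustration.

```python
def step_reward(old_iou, new_iou, is_trigger, tau=0.6, eta=3.0):
    """Reward as described above: sign of IoU change, or +/-eta on trigger."""
    if is_trigger:
        # Trigger: judged on the final IoU against the threshold tau
        return eta if new_iou >= tau else -eta
    # Transformation: +1 only for a strict improvement, -1 otherwise
    return 1.0 if new_iou > old_iou else -1.0

print(step_reward(0.30, 0.40, is_trigger=False))  # improvement
print(step_reward(0.40, 0.40, is_trigger=False))  # no change counts as -1
print(step_reward(0.65, 0.65, is_trigger=True))   # good trigger
print(step_reward(0.50, 0.50, is_trigger=True))   # premature trigger
```

Because an unchanged IoU is penalized, actions that are clipped away at the image border (and therefore leave the box untouched) also receive −1.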


# **Section 5 — RL Environment (TFDSVOCEnv)**

This is a faithful reproduction of the environment from:

**Caicedo & Lazebnik — Active Object Localization (ICCV 2015)**

Key features:

- The agent starts with the full image as the initial box.
- The agent performs **9 actions**:
  - 8 transformations
  - 1 trigger
- The agent receives:
  - CNN features of the cropped region
  - A 10-step action history (one-hot)
- Training episodes:
  - Terminate on trigger or max steps
- IoR (Inhibition of Return) is supported for inference.
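
The action-history part of the state can be illustrated in isolation. The sketch below encodes up to 10 past actions as a 90-dimensional one-hot vector, filling slots from the oldest retained action onward; `encode_history` is a hypothetical helper name.

```python
HIST_LEN = 10     # number of past actions kept
NUM_ACTIONS = 9   # 8 transformations + trigger

def encode_history(actions):
    """One-hot encode the most recent HIST_LEN actions into a flat list."""
    recent = actions[-HIST_LEN:]
    vec = [0.0] * (HIST_LEN * NUM_ACTIONS)
    for slot, act in enumerate(recent):
        vec[slot * NUM_ACTIONS + act] = 1.0
    return vec

vec = encode_history([0, 4, 8])   # left, bigger, trigger
hot = [i for i, v in enumerate(vec) if v == 1.0]
print(len(vec), hot)  # 90 [0, 13, 26]
```

Unused slots stay zero, so early in an episode the history vector is mostly empty.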


In [None]:
class TFDSVOCEnv(gym.Env):
    """
    Paper-accurate RL environment for VOC object localization.
    - State = CNN features + action history
    - Reward = IoU-based (from ICCV 2015 paper)
    - Actions = 9 discrete actions
    """

    metadata = {"render_modes": ["human"]}

    def __init__(self, ds, class_name, feature_extractor, max_steps=40, device="cpu"):
        super().__init__()

        self.ds = ds                       # dataset list
        self.class_name = class_name       # target VOC class
        self.fe = feature_extractor        # CNN feature extractor
        self.device = device
        self.max_steps = max_steps

        # action history parameters
        self.hist_len = 10                 # last 10 actions
        self.num_actions = 9               # 8 transforms + trigger
        self.hist_dim = self.hist_len * self.num_actions

        # Final state dimension: CNN features + history
        self.state_dim = self.fe.output_dim + self.hist_dim

        # Gym spaces
        self.action_space = gym.spaces.Discrete(self.num_actions)
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(self.state_dim,), dtype=np.float32
        )

    # -----------------------------------------------------
    def reset(self, seed=None, options=None):
        super().reset(seed=seed)

        while True:
            sample = random.choice(self.ds)

            # Store the original TFDS sample for potential visualization
            self._current_sample = sample

            # ---- CRITICAL FIX: convert all TF tensors to numpy ----
            img = np.array(sample["image"])
            labels = np.array(sample["objects"]["label"])
            bboxes = np.array(sample["objects"]["bbox"])

            H, W = img.shape[:2]

            boxes = []
            for label, bbox in zip(labels, bboxes):
                class_name = VOC_CLASSES[int(label)]
                if class_name == self.class_name:
                    ymin, xmin, ymax, xmax = [float(x) for x in bbox]
                    x1 = int(xmin * W)
                    y1 = int(ymin * H)
                    x2 = int(xmax * W)
                    y2 = int(ymax * H)
                    boxes.append(Box(x1, y1, x2, y2))

            if len(boxes) > 0:
                break

        # Assign fixed numpy image
        self.image = img
        self.gt_boxes = boxes
        self.gt_box = random.choice(boxes)
        self.H, self.W = H, W

        self.ior_mask = np.zeros((self.H, self.W), dtype=np.uint8)

        self.box = Box(0, 0, self.W, self.H)
        self.history = []
        self.steps = 0

        return self._get_state(), {}

    # -----------------------------------------------------
    def step(self, action):
        """
        Apply transformation or trigger.
        Paper reward:
        - +1 if IoU improves
        - -1 otherwise
        - Trigger: +3 if IoU >= 0.6, else -3
        """
        old_iou = iou(self.box, self.gt_box)

        # Apply transform actions (0-7)
        if action != 8:
            self.box = apply_action(self.box, action, self.W, self.H)

        # Compute new IoU
        new_iou = iou(self.box, self.gt_box)

        # Trigger (action 8)
        if action == 8:
            reward = 3 if new_iou >= 0.6 else -3
            terminated = True

            # IoR update for multi-detection inference
            if new_iou >= 0.6:
                x1, y1, x2, y2 = self.box.as_int()
                self.ior_mask[y1:y2, x1:x2] = 1

        else:
            # Non-trigger reward
            reward = 1 if new_iou > old_iou else -1
            terminated = False

        self.steps += 1
        truncated = self.steps >= self.max_steps

        # Update action history
        self.history.append(action)
        if len(self.history) > self.hist_len:
            self.history.pop(0)

        return self._get_state(), reward, terminated, truncated, {}

    # -----------------------------------------------------
    def _get_state(self):
        """
        State representation = CNN features + one-hot history.
        """
        x1, y1, x2, y2 = self.box.as_int()

        # Crop region corresponding to current box
        crop = self.image[y1:y2, x1:x2]
        if crop.size == 0:
            crop = np.zeros((224, 224, 3), dtype=np.uint8)

        # Apply IoR mask (in inference only)
        crop_mask = self.ior_mask[y1:y2, x1:x2]
        if crop_mask.size > 0:
            crop = crop.copy()
            crop[crop_mask == 1] = crop[crop_mask == 1] * 0.2

        # CNN feature extraction
        cnn_feat = self.fe(crop)

        # Action history one-hot encoding
        hist = np.zeros(self.hist_dim)
        for i, act in enumerate(self.history):
            hist[i * self.num_actions + act] = 1

        # Concatenate final state
        state = np.concatenate([cnn_feat, hist])
        return state.astype(np.float32)

    # -----------------------------------------------------
    def reset_search_only(self):
        """
        Reset the agent's search box without changing the image or GT.
        Used for multi-object detection (IoR).
        """
        self.box = Box(0, 0, self.W, self.H)
        self.history = []
        self.steps = 0
        return self._get_state(), {}

# **Section 6 — Training the DQN Agent**

We train a **class-specific DQN** using Stable-Baselines3.

Paper methodology:
- One DQN per object class
- CNN feature extractor is fixed (not trained)
- RL agent learns bounding box transformations
- Reward is based on improvement in IoU

For this demo:
- We train for **30,000 steps**
- Later, you can increase to 200k–1M steps for full training


In [None]:
def make_env(class_name, split="train"):
    """
    Factory that creates a single TFDSVOCEnv instance for a given class.
    Compatible with Stable-Baselines3 VecEnvs.
    """
    def _init():
        # choose dataset by split (training uses train + validation, as in the paper)
        ds = ds_train + ds_val if split == "train" else ds_test

        # new FeatureExtractor (already defined earlier)
        fe = FeatureExtractor(cnn_type=cnn_type, device=device)

        # TFDSVOCEnv (already defined earlier)
        env = TFDSVOCEnv(
            ds=ds,
            class_name=class_name,
            feature_extractor=fe,
            max_steps=40,
            device=device,
        )
        return env

    return _init

In [None]:
def create_training_env(class_name):
    env_fn = make_env(class_name, split="train")
    vec_env = DummyVecEnv([env_fn])
    return vec_env


We use DQN hyperparameters approximating the original setup:

- learning_rate = 1e-4  
- gamma = 0.9  
- exploration_fraction = 0.2  
- exploration_final_eps = 0.1  
- buffer_size = 50k  
- batch_size = 32  

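These values approximate the setup of the ICCV 2015 agent; with the short demo training budget used here, they are not expected to reproduce the paper's reported numbers.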
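
To see what `exploration_fraction` and `exploration_final_eps` imply, the sketch below computes a linearly decaying epsilon. It assumes an initial epsilon of 1.0 and a 30,000-step budget; `epsilon_at` is an illustrative helper, not a Stable-Baselines3 function.

```python
def epsilon_at(step, total_steps=30_000, fraction=0.2,
               eps_start=1.0, eps_final=0.1):
    """Linear decay from eps_start to eps_final over the first
    `fraction` of training, constant afterwards (assumed schedule)."""
    decay_steps = fraction * total_steps
    progress = min(step / decay_steps, 1.0)
    return eps_start + progress * (eps_final - eps_start)

for step in (0, 3_000, 6_000, 30_000):
    print(step, round(epsilon_at(step), 3))
```

Under these assumptions, exploration reaches its floor of 0.1 after 6,000 steps, so the remaining 24,000 steps are mostly greedy.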


In [None]:
CLASS_TO_TRAIN = "person"   # choose VOC class

env_train = make_env(CLASS_TO_TRAIN)()  # Raw env ONLY
model = DQN(
    "MlpPolicy",
    env_train,
    learning_rate=1e-4,
    buffer_size=50000,
    batch_size=32,
    gamma=0.9,
    exploration_fraction=0.2,
    exploration_final_eps=0.1,
    verbose=1,
    tensorboard_log="./tb_logs/",
)


In [None]:
TRAIN_STEPS = 30000  # increase later if needed
model.learn(total_timesteps=TRAIN_STEPS)


In [None]:
model_path = f"results/dqn_{CLASS_TO_TRAIN}_{cnn_type}_30k.zip"
model.save(model_path)
print("Model saved:", model_path)


In [None]:
model = DQN.load(model_path)
print("Model loaded!")


In [None]:
test_env = make_env(CLASS_TO_TRAIN, split="test")()
obs, _ = test_env.reset()
done = False
total_reward = 0

while not done:
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, terminated, truncated, info = test_env.step(action)
    done = terminated or truncated
    total_reward += reward

print("Test episode reward:", total_reward)


# **Section 7 — Episode Visualization Tools**

These tools allow us to inspect:

- How the bounding box moves during an episode  
- Whether IoU consistently improves  
- How many steps the agent takes  
- The quality of the final localization  

We provide:

### 1. `run_episode_collect_data()`
Collects all predicted boxes and IoUs.

### 2. `draw_box()`
Draws a labeled bounding box onto an image.

### 3. `visualize_episode()`
Plots:
- Agent trajectory  
- IoU curve  
- Final detection overlay  

These visualizations help confirm that training is behaving correctly.


In [None]:
def draw_box(image, box, color=(0,255,0), label=None, thickness=2):
    img = image.copy()
    x1, y1, x2, y2 = box.as_int()
    cv2.rectangle(img, (x1,y1), (x2,y2), color, thickness)

    if label is not None:
        cv2.putText(img, label, (x1, max(0, y1-5)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
    return img


In [None]:
def run_episode_collect_data(env, model):
    obs, _ = env.reset()
    done = False

    boxes = []
    ious = []
    actions = []

    # initial box
    boxes.append(env.box.copy())
    ious.append(iou(env.box, env.gt_box))

    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated

        boxes.append(env.box.copy())
        ious.append(iou(env.box, env.gt_box))
        actions.append(action)

    return boxes, ious, actions


In [None]:
def visualize_episode(sample, boxes, ious, class_name):
    """
    Visualize:
    - GT boxes only for class_name
    - Final prediction (if IoU >= 0.6)
    """
    img = sample["image"].copy()
    H, W = img.shape[:2]

    # ---- Safety: empty ious or boxes ----
    if ious is None or len(ious) == 0 or boxes is None or len(boxes) == 0:
        final_iou = 0.0
        final_box = None
    else:
        final_box = boxes[-1]
        final_iou = ious[-1]

    fig, ax = plt.subplots(1, 1, figsize=(8, 6))

    # ----- Draw GT only for target class -----
    for label, bbox in zip(sample["objects"]["label"], sample["objects"]["bbox"]):
        cls = VOC_CLASSES[int(label)]
        if cls == class_name:
            x1, y1, x2, y2 = tfds_box_to_pixel(bbox, img.shape)
            img = draw_box(img, Box(x1, y1, x2, y2),
                           color=(255, 0, 0),
                           label=f"GT: {cls}")

    # ----- Threshold logic for final prediction -----
    if final_iou >= 0.6 and final_box is not None:
        img = draw_box(
            img,
            final_box,
            color=(0,255,0),
            label=f"Pred: {class_name}, IoU={final_iou:.2f}"
        )
        title = f"Final Prediction: {class_name}"
    else:
        title = f"No {class_name} found (IoU={final_iou:.2f} < 0.6)"

    # ----- Show figure -----
    ax.imshow(img)
    ax.set_title(title)
    ax.axis("off")
    plt.show()


In [None]:
env = make_env("person")()
obs, _ = env.reset()

print(type(env.image))
print(type(obs))


In [None]:
# Create a fresh test environment
env_test = TFDSVOCEnv(
    ds=ds_test,
    class_name=CLASS_TO_TRAIN,
    feature_extractor=FeatureExtractor(cnn_type=cnn_type, device=device),
    max_steps=40,
    device=device
)

# Run one episode. The environment will be reset internally by run_episode_collect_data,
# and env_test will retain the state of the episode that just ran.
boxes, ious, actions = run_episode_collect_data(env_test, model)

# Convert the sample stored in env_test to a numpy-safe format for visualization
sample_np = tfds_to_numpy(env_test._current_sample)

# Visualize
visualize_episode(sample_np, boxes, ious, CLASS_TO_TRAIN)

In [None]:
def sample_contains_class(sample, class_name):
    labels = sample["objects"]["label"]
    for l in labels:
        if VOC_CLASSES[int(l)] == class_name:
            return True
    return False


In [None]:
import random

def get_mixed_test_samples(env, class_name, total=10):
    positives = []
    negatives = []

    # First collect positive and negative pools
    for i in range(len(env.ds)):
        sample = env.ds[i]

        if sample_contains_class(sample, class_name):
            positives.append(sample)
        else:
            negatives.append(sample)

    # Ensure random mixing: aim for half positives, half negatives
    half = total // 2
    pos_samples = random.sample(positives, k=min(half, len(positives)))
    neg_samples = random.sample(negatives, k=min(total - half, len(negatives)))

    combined = pos_samples + neg_samples
    random.shuffle(combined)

    return combined[:total]


In [None]:
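def evaluate_10_images(env, model, class_name):
    samples = get_mixed_test_samples(env, class_name)
    full_ds = env.ds  # restored after pinning the env to one sample at a time

    for idx, sample in enumerate(samples):
        print(f"\n============== IMAGE {idx+1} / {len(samples)} ==============")
        sample_np = tfds_to_numpy(sample)
        # GT status
        has_class = sample_contains_class(sample, class_name)
        print(f"GT contains {class_name}: {has_class}")

        if has_class:
            # Pin the env to this sample so reset() runs the episode on it
            # (otherwise reset() would pick an unrelated random image)
            env.ds = [sample]
            boxes, ious, actions = run_episode_collect_data(env, model)
            env.ds = full_ds
        else:
            # reset() only accepts images that contain the class, so no
            # episode can run on a negative sample; show it without a prediction
            boxes, ious = [], []

        # Visualize
        visualize_episode(sample_np, boxes, ious, class_name)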


In [None]:
evaluate_10_images(env_test, model, CLASS_TO_TRAIN)


# **Section 8 — GIF + MP4 Animations (Episode Visualization)**

This section generates:

### • GIF Animation  
### • MP4 Video

Each frame displays:
- Current predicted bounding box (green)
- Ground truth box (red)
- Step number overlay
- Optional IoU value

GIFs are ideal for documentation; MP4 is better for presentations and playback.

We implement:

1. `generate_episode_frames(env, model, class_name)`
2. `save_gif(frames, filename)`
3. `save_mp4(frames, filename)`
4. `animate_episode(env, model, class_name)`

All animations are stored under:
./animations/

In [None]:
def generate_episode_frames(env, model, class_name):
    """
    Run one episode and collect rendered frames.
    Returns:
      - frames: list of RGB numpy arrays
    """
    obs, _ = env.reset()
    done = False

    frames = []

    step = 0

    # save initial frame
    frame = draw_box(env.image, env.box, color=(0,255,0),
                     label=f"step {step}")
    # draw GT in red
    frame = draw_box(frame, env.gt_box, color=(255,0,0), label="GT")
    frames.append(frame)

    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        step += 1

        frame = draw_box(env.image, env.box, color=(0,255,0),
                         label=f"step {step}")
        frame = draw_box(frame, env.gt_box, color=(255,0,0), label="GT")
        frames.append(frame)

    return frames


In [None]:
def save_gif(frames, filename, fps=4):
    """
    Save episode frames as animated GIF.
    """
    path = os.path.join("animations", filename)
    os.makedirs("animations", exist_ok=True)
    imageio.mimsave(path, frames, fps=fps)
    print("Saved GIF:", path)


In [None]:
def save_mp4(frames, filename, fps=8):
    """
    Save episode frames as MP4 video.
    """
    path = os.path.join("animations", filename)
    os.makedirs("animations", exist_ok=True)

    height, width, _ = frames[0].shape
    writer = cv2.VideoWriter(
        path,
        cv2.VideoWriter_fourcc(*"mp4v"),
        fps,
        (width, height)
    )

    for frame in frames:
        writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

    writer.release()
    print("Saved MP4:", path)


In [None]:
def animate_episode(env, model, class_name):
    frames = generate_episode_frames(env, model, class_name)

    gif_name = f"episode_{class_name}.gif"
    mp4_name = f"episode_{class_name}.mp4"

    save_gif(frames, gif_name)
    save_mp4(frames, mp4_name)

    return frames


In [None]:
env_test = TFDSVOCEnv(
    ds=ds_test,
    class_name=CLASS_TO_TRAIN,
    feature_extractor=FeatureExtractor(cnn_type=cnn_type, device=device),
    max_steps=40,
    device=device
)

frames = animate_episode(env_test, model, CLASS_TO_TRAIN)

# Display GIF inline
IPyImage(filename=f"animations/episode_{CLASS_TO_TRAIN}.gif")


# **Section 9 — Attention Heatmaps (Episode + Dataset-Level)**

We visualize the agent's attention using heatmaps:

### Episode Heatmap
- Tracks which image regions the agent cropped throughout an episode.
- Each step updates the heatmap by marking the bounding box region.

### Dataset Heatmap
- Aggregates heatmaps across multiple test images.
- Reveals general search patterns for a class.

These maps help diagnose agent behavior and verify improved policies.
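
The visitation-count idea behind both heatmaps can be shown on a tiny grid. This is a dependency-free sketch using nested lists instead of NumPy arrays; `visitation_heatmap` is a hypothetical helper and the boxes are made up.

```python
def visitation_heatmap(boxes, height, width):
    """Count, per pixel, how many (x1, y1, x2, y2) boxes cover it."""
    heat = [[0] * width for _ in range(height)]
    for x1, y1, x2, y2 in boxes:
        for y in range(max(0, y1), min(height, y2)):
            for x in range(max(0, x1), min(width, x2)):
                heat[y][x] += 1
    return heat

# Full-image start box, then a zoom into the centre
heat = visitation_heatmap([(0, 0, 4, 4), (1, 1, 3, 3)], height=4, width=4)
for row in heat:
    print(row)
```

Since every episode starts from the full image, each pixel has a count of at least one; the informative signal is where counts rise above that baseline.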


In [None]:
def update_heatmap(heatmap, box):
    x1, y1, x2, y2 = box.as_int()
    heatmap[y1:y2, x1:x2] += 1


In [None]:
def run_episode_with_heatmap(env, model):
    obs, _ = env.reset()
    done = False

    H, W = env.H, env.W
    heatmap = np.zeros((H, W), dtype=np.float32)

    boxes = []
    ious = []

    # Record the initial box
    boxes.append(env.box.copy())
    ious.append(iou(env.box, env.gt_box))
    update_heatmap(heatmap, env.box)

    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated

        boxes.append(env.box.copy())
        ious.append(iou(env.box, env.gt_box))
        update_heatmap(heatmap, env.box)

    return boxes, ious, heatmap


In [None]:
import numpy as np
import cv2

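def compute_heatmap_for_sample(sample, boxes, img_shape):
    """Return a visitation heatmap for `boxes`, normalized to [0, 1].

    `sample` is currently unused and is kept only for signature compatibility.
    """
    H, W = img_shape[:2]
    heat = np.zeros((H, W), dtype=np.float32)

    for box in boxes:
        x1, y1, x2, y2 = map(int, [box.x1, box.y1, box.x2, box.y2])
        x1, y1 = max(x1, 0), max(y1, 0)  # negative coordinates would wrap around when slicing
        heat[y1:y2, x1:x2] += 1  # visitation count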

    # Normalize 0–1
    heat -= heat.min()
    if heat.max() > 0:
        heat /= heat.max()

    return heat


In [None]:
def visualize_heatmap(image, heatmap, title="Attention Heatmap"):
    # Normalize to 0–255
    h = heatmap.copy()
    if h.max() > 0:
        h = h / h.max()

    h = (h * 255).astype(np.uint8)

    # Apply JET colormap
    heat = cv2.applyColorMap(h, cv2.COLORMAP_JET)
    heat = cv2.cvtColor(heat, cv2.COLOR_BGR2RGB)

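    # Overlay heatmap on original. addWeighted needs both arrays to share shape
    # and dtype, so `image` is assumed to be an H x W x 3 uint8 RGB array.
    overlay = cv2.addWeighted(image, 0.6, heat, 0.4, 0)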

    plt.figure(figsize=(7,7))
    plt.imshow(overlay)
    plt.title(title)
    plt.axis("off")
    plt.show()


In [None]:
# Create test environment
env_test = TFDSVOCEnv(
    ds=ds_test,
    class_name=CLASS_TO_TRAIN,
    feature_extractor=FeatureExtractor(cnn_type=cnn_type, device=device),
    max_steps=40,
    device=device
)

boxes, ious, heatmap = run_episode_with_heatmap(env_test, model)

visualize_heatmap(env_test.image, heatmap, title=f"Episode Heatmap ({CLASS_TO_TRAIN})")


# **Section 10 — Quantitative Evaluation**

We evaluate trained DQN agents on the VOC 2007 test set.

For each object class:

1. Sample test images containing that class  
2. Run the RL agent until it triggers or reaches the step limit (timeout)  
3. Record the IoU between the final predicted box and the closest GT box  
4. Compute:
   - Mean IoU
   - Median IoU
   - Accuracy@0.5
   - Accuracy@0.6
   - Accuracy@0.7

Finally:
- Aggregate results across all 20 VOC classes  
- Produce a pandas DataFrame of results  
- Optionally export to CSV  
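
The metrics listed above can be sketched in a self-contained way. The snippet below assumes boxes are plain `(x1, y1, x2, y2)` tuples in pixel coordinates rather than the notebook's box objects, interprets "closest GT box" as the ground-truth box with the highest IoU, and uses hypothetical helper names (`box_iou`, `best_gt_iou`, `summarize_ious`).

```python
import numpy as np

def box_iou(a, b):
    """IoU of two (x1, y1, x2, y2) boxes; returns 0.0 when they do not overlap."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

def best_gt_iou(pred, gt_boxes):
    """IoU against the best-matching ground-truth box (assumed meaning of 'closest')."""
    return max((box_iou(pred, gt) for gt in gt_boxes), default=0.0)

def summarize_ious(ious, thresholds=(0.5, 0.6, 0.7)):
    """Mean/median IoU plus the fraction of episodes at or above each threshold."""
    ious = np.asarray(ious, dtype=np.float64)
    summary = {"mean_iou": float(ious.mean()), "median_iou": float(np.median(ious))}
    for t in thresholds:
        summary[f"acc_{int(round(t * 100))}"] = float((ious >= t).mean())
    return summary

# Toy example: one final box per episode, scored against that image's GT boxes
episodes = [
    ((0, 0, 10, 10), [(0, 0, 10, 10)]),                    # perfect match
    ((0, 0, 10, 10), [(5, 0, 15, 10), (50, 50, 60, 60)]),  # partial overlap with the first GT
    ((0, 0, 10, 10), [(20, 20, 30, 30)]),                  # complete miss
]
ious = [best_gt_iou(pred, gts) for pred, gts in episodes]
summary = summarize_ious(ious)
print(summary)
```

Scoring against the best-matching ground-truth box avoids penalizing the agent for localizing a different instance of the same class than the one the environment happened to select.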


In [None]:
def evaluate_single_episode(env, model):
    """
    Runs a single episode and returns:
    - predicted box (the agent's final box)
    - GT box the environment scores against (`env.gt_box`)
    - IoU between the two
    """
    obs, _ = env.reset()
    done = False

    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated

    pred_box = env.box
    gt_box = env.gt_box
    iou_val = iou(pred_box, gt_box)
    return pred_box, gt_box, iou_val


In [None]:
def evaluate_class(model, class_name, num_samples=50):
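    # Build the CNN feature extractor once and reuse it for every episode
    feature_extractor = FeatureExtractor(cnn_type=cnn_type, device=device)
    ious = []

    for _ in range(num_samples):

        # Rejection-sample an image containing the class (with replacement).
        # Assumes ds_test is an indexable sequence and the class occurs in it;
        # otherwise this loop never terminates.
        while True:
            sample = random.choice(ds_test)
            img = sample["image"]
            gt_boxes = get_all_gt_boxes_for_class(sample, class_name, img.shape[:2])
            if len(gt_boxes) > 0:
                break

        # build environment for this image
        env = TFDSVOCEnv(
            ds=[sample],
            class_name=class_name,
            feature_extractor=feature_extractor,
            max_steps=40,
            device=device
        )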

        _, _, iou_val = evaluate_single_episode(env, model)
        ious.append(iou_val)

    ious = np.array(ious)

    results = {
        "class": class_name,
        "mean_iou": float(np.mean(ious)),
        "median_iou": float(np.median(ious)),
        "acc_50": float(np.mean(ious >= 0.5)),
        "acc_60": float(np.mean(ious >= 0.6)),
        "acc_70": float(np.mean(ious >= 0.7)),
        "all_ious": ious,
    }

    return results


In [None]:
def evaluate_all_classes(model, num_samples=50):
    results = []

    for cls in VOC_CLASSES:
        print(f"Evaluating: {cls}")
        res = evaluate_class(model, cls, num_samples=num_samples)
        results.append(res)

    return results


In [None]:
import pandas as pd

def results_to_dataframe(results):
    df = pd.DataFrame([{
        "class": r["class"],
        "mean_iou": r["mean_iou"],
        "median_iou": r["median_iou"],
        "acc_50": r["acc_50"],
        "acc_60": r["acc_60"],
        "acc_70": r["acc_70"]
    } for r in results])

    return df


In [None]:
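def save_results_csv(df, filename="evaluation_results.csv"):
    # Create the output directory if needed; to_csv fails when it is missing
    os.makedirs("results", exist_ok=True)
    path = os.path.join("results", filename)
    df.to_csv(path, index=False)
    print("Saved results to:", path)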


In [None]:
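# Evaluate all classes (reduce num_samples if this is too slow).
# Note: if `model` was trained only on CLASS_TO_TRAIN, rows for the other classes
# measure how that single class-specific agent transfers, not per-class agents.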
results = evaluate_all_classes(model, num_samples=30)
df_results = results_to_dataframe(results)

df_results


In [None]:
save_results_csv(df_results, "voc2007_dqn_eval.csv")
