# Baseline CAM Inference and Evaluation

In [None]:
# Run the baseline repository's CAM inference.
# NOTE(review): DATA_ROOT must already be defined (IPython expands $DATA_ROOT from a
# Python variable if one exists, otherwise the shell reads it from the environment).
%cd /content/ai-competition-baselines
!python -m src.stage3_localization.weak_cam.infer \
  --config configs/stage3_weak_cam.yaml \
  --data_root "$DATA_ROOT"

In [None]:
# Evaluate the baseline CAMs with the baseline's pointing-game script.
# --radius 15 is forwarded to that script; its exact meaning is defined there.
%cd /content/ai-competition-baselines
!python -m src.stage3_localization.weak_cam.eval_pointing_game \
  --config configs/stage3_weak_cam.yaml \
  --data_root "$DATA_ROOT" \
  --radius 15

# Custom Stage 3 CAM Implementation

In [None]:
# CELL 4: Prepare Pointing Game Annotations CSV

import os
import pandas as pd
import cv2
import yaml
import shutil


# Root of the Roboflow YOLO export. `dataset` comes from an earlier download cell.
# NOTE(review): os.path.join discards the first argument when dataset.location is an
# absolute path, so in that case ROBOFLOW_ROOT is simply dataset.location.
ROBOFLOW_ROOT = os.path.join(
    "/content/drive/MyDrive/stage3_cam",
    dataset.location
)

OUTPUT_ROOT = "/content/drive/MyDrive/stage3_cam/data_stage3"

IMG_DST = os.path.join(OUTPUT_ROOT, "images")
ANN_DST = os.path.join(OUTPUT_ROOT, "annotations")

os.makedirs(IMG_DST, exist_ok=True)
os.makedirs(ANN_DST, exist_ok=True)

data_yaml_path = os.path.join(ROBOFLOW_ROOT, "data.yaml")
with open(data_yaml_path, "r") as f:
    yolo_data = yaml.safe_load(f)

class_names = yolo_data["names"]
print(f"Classes: {class_names}")

val_img_src = os.path.join(ROBOFLOW_ROOT, "valid", "images")
val_label_src = os.path.join(ROBOFLOW_ROOT, "valid", "labels")

# Column layout of pointing_game.csv; fixed explicitly so the header survives even
# when no points are found.
POINT_COLUMNS = ["image_id", "class_id", "class_name", "x", "y"]
rows = []

# Sorted so the CSV row order is reproducible (os.listdir order is arbitrary).
img_files = sorted(
    f for f in os.listdir(val_img_src)
    if f.lower().endswith((".jpg", ".jpeg", ".png"))
)

n_labelled = 0
for img_file in img_files:
    src_img_path = os.path.join(val_img_src, img_file)
    dst_img_path = os.path.join(IMG_DST, img_file)

    # Every validation image is copied, whether or not it has a label file.
    shutil.copy2(src_img_path, dst_img_path)

    label_file = os.path.splitext(img_file)[0] + ".txt"
    label_path = os.path.join(val_label_src, label_file)

    if not os.path.exists(label_path):
        continue

    # Only the image size is needed, to convert normalised label coordinates to pixels.
    img = cv2.imread(src_img_path)
    if img is None:
        continue

    img_h, img_w = img.shape[:2]
    n_labelled += 1

    with open(label_path, "r") as f:
        for line in f:
            parts = line.strip().split()
            # Expected YOLO detection line: "class cx cy w h", coordinates in [0, 1].
            # NOTE(review): with a segmentation (polygon) export, parts[1:3] would be the
            # first polygon vertex rather than the object centre.
            if len(parts) < 5:
                continue

            class_id = int(parts[0])
            x_center_norm = float(parts[1])
            y_center_norm = float(parts[2])

            # Pointing-game target = box centre in original-image pixel coordinates.
            x = x_center_norm * img_w
            y = y_center_norm * img_h

            rows.append({
                "image_id": img_file,
                "class_id": class_id,
                "class_name": class_names[class_id],
                "x": x,
                "y": y
            })

points_df = pd.DataFrame(rows, columns=POINT_COLUMNS)
csv_path = os.path.join(ANN_DST, "pointing_game.csv")
points_df.to_csv(csv_path, index=False)

print("\nPointing game CSV created")
print(f"Images processed : {len(img_files)}")
print(f"Images w/ labels : {n_labelled}")
print(f"Total points     : {len(points_df)}")
print(f"Saved to         : {csv_path}")

points_df.head()

In [None]:
# CELL 11: Model Definitions

%%writefile /content/drive/MyDrive/stage3_cam/project_stage3/src/stage1_binary/model.py
import timm
import torch.nn as nn

def build_model(backbone: str, num_classes: int = 2, pretrained: bool = True, dropout: float = 0.2) -> nn.Module:
    """Instantiate the Stage 1 (binary) classifier through ``timm``.

    Args:
        backbone: timm model name (the Stage 1 config uses ``resnet50``).
        num_classes: Number of outputs of the classification head.
        pretrained: Whether timm should load its pretrained weights.
        dropout: Forwarded to timm as ``drop_rate``.

    Returns:
        The constructed ``nn.Module``.
    """
    timm_kwargs = dict(pretrained=pretrained, num_classes=num_classes, drop_rate=dropout)
    return timm.create_model(backbone, **timm_kwargs)

In [None]:
%%writefile /content/drive/MyDrive/stage3_cam/project_stage3/src/stage2_multilabel/model.py
import timm
import torch.nn as nn

def build_model(backbone: str, num_labels: int = 3, pretrained: bool = True, dropout: float = 0.2) -> nn.Module:
    """Instantiate the Stage 2 (multi-label) classifier through ``timm``.

    Args:
        backbone: timm model name.
        num_labels: Number of independent labels; becomes timm's ``num_classes``.
        pretrained: Whether timm should load its pretrained weights.
        dropout: Forwarded to timm as ``drop_rate``.

    Returns:
        The constructed ``nn.Module``.
    """
    timm_kwargs = dict(pretrained=pretrained, num_classes=num_labels, drop_rate=dropout)
    return timm.create_model(backbone, **timm_kwargs)

In [None]:
# CELL 12: CAM Utilities (Revised)
%%writefile /content/drive/MyDrive/stage3_cam/project_stage3/src/stage3_localization/weak_cam/cam.py
from dataclasses import dataclass
import torch
import torch.nn as nn

@dataclass
class CamTargetConfig:
    """Selects which model output Grad-CAM should explain.

    Consumed by ``build_cam_target``: ``class_index`` is used for the
    ``"stage1_binary"`` task and ``label_index`` for ``"stage2_multilabel"``.
    """
    task: str  # "stage1_binary" or "stage2_multilabel"; anything else makes build_cam_target raise
    label_index: int = 0  # output column explained for the multi-label task
    class_index: int = 1  # class explained for the binary task

class MultiLabelOutputTarget:
    """Grad-CAM target that selects a single logit of a multi-label head.

    Args:
        label_index: Output column to explain (coerced to ``int``).
    """

    def __init__(self, label_index: int):
        self.label_index = int(label_index)

    def __call__(self, model_output: torch.Tensor) -> torch.Tensor:
        """Return the chosen logit for one sample (1-D) or a whole batch (2-D).

        Raises:
            ValueError: If ``model_output`` is neither 1-D nor 2-D.
        """
        ndim = model_output.dim()
        if ndim == 2:
            # Batch: one value per row.
            return model_output[:, self.label_index]
        if ndim == 1:
            # Single sample.
            return model_output[self.label_index]
        raise ValueError(f"Unexpected model output shape: {model_output.shape}")

def find_last_conv_layer(model: nn.Module) -> nn.Module:
    """Return the last ``nn.Conv2d`` in ``model.modules()`` traversal order.

    Raises:
        RuntimeError: If the model contains no ``nn.Conv2d`` at all.
    """
    conv_layers = [module for module in model.modules() if isinstance(module, nn.Conv2d)]
    if not conv_layers:
        raise RuntimeError("No Conv2d layer found")
    return conv_layers[-1]

def build_cam_target(cfg: CamTargetConfig):
    """
    Builds the appropriate GradCAM target object depending on the task.
    """
    if cfg.task == "stage1_binary":
        from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
        return ClassifierOutputTarget(int(cfg.class_index))
    if cfg.task == "stage2_multilabel":
        return MultiLabelOutputTarget(int(cfg.label_index))
    raise ValueError(f"Unknown task: {cfg.task}")

In [None]:
# CELL: Full Stage 3 CAM Inference (Stage 1 + Stage 2)
%%writefile /content/drive/MyDrive/stage3_cam/project_stage3/src/stage3_localization/weak_cam/infer.py
import os
import argparse
import yaml
import cv2
import numpy as np
import torch
from tqdm import tqdm

from src.common.io import ensure_dir, list_images_from_dir
from src.common.utils import get_device
from src.common.transforms import get_eval_transforms

from src.stage1_binary.model import build_model as build_stage1_model
from src.stage2_multilabel.model import build_model as build_stage2_model
from src.stage3_localization.weak_cam.cam import CamTargetConfig, build_cam_target, find_last_conv_layer
from pytorch_grad_cam import GradCAM

# ---------------------------
# Helper Functions
# ---------------------------
def load_yaml(path: str):
    """Parse the YAML file at ``path`` with ``yaml.safe_load`` and return the result."""
    with open(path, "r") as handle:
        parsed = yaml.safe_load(handle)
    return parsed

def load_model(task: str, stage_cfg_path: str, ckpt_path: str, device):
    """Rebuild a trained Stage 1 / Stage 2 classifier and load its weights.

    Args:
        task: ``"stage1_binary"`` or ``"stage2_multilabel"``.
        stage_cfg_path: Training YAML of that stage. Its ``model`` section supplies
            ``backbone``, ``dropout`` and ``num_classes`` (stage 1) / ``num_labels`` (stage 2).
        ckpt_path: Checkpoint file. Either a dict holding the weights under ``"model"``
            or a bare state dict.
        device: Device used for ``map_location`` and the returned model.

    Returns:
        The model on ``device`` in eval mode.

    Raises:
        ValueError: If ``task`` is not one of the two known tasks.
    """
    model_cfg = load_yaml(stage_cfg_path)["model"]
    if task == "stage1_binary":
        model = build_stage1_model(
            backbone=model_cfg["backbone"],
            num_classes=int(model_cfg["num_classes"]),
            pretrained=False,  # weights come from the checkpoint below
            dropout=float(model_cfg["dropout"])
        )
    elif task == "stage2_multilabel":
        model = build_stage2_model(
            backbone=model_cfg["backbone"],
            num_labels=int(model_cfg["num_labels"]),
            pretrained=False,  # weights come from the checkpoint below
            dropout=float(model_cfg["dropout"])
        )
    else:
        raise ValueError(f"Unknown task: {task}")

    ckpt = torch.load(ckpt_path, map_location=device)
    # Accept both {"model": state_dict, ...} training checkpoints and bare state dicts.
    state_dict = ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model

def overlay_cam(img_rgb, cam_mask, alpha=0.45):
    """Blend a JET-coloured CAM onto an RGB image.

    Args:
        img_rgb: RGB image array. In ``main`` this is the decoded input image.
        cam_mask: Float mask with the same height/width as ``img_rgb``, expected in
            [0, 1]. Values outside that range are clipped so the uint8 cast cannot wrap.
        alpha: Weight of the heatmap in the blend (``1 - alpha`` for the image).

    Returns:
        The blended image as ``uint8`` RGB.
    """
    heatmap = (np.clip(cam_mask, 0.0, 1.0) * 255.0).astype(np.uint8)
    heatmap_bgr = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    # applyColorMap yields BGR; convert so it matches the RGB input image.
    heatmap_rgb = cv2.cvtColor(heatmap_bgr, cv2.COLOR_BGR2RGB)
    return (alpha * heatmap_rgb + (1 - alpha) * img_rgb).astype(np.uint8)

# ---------------------------
# Main Inference
# ---------------------------
def _save_cam_outputs(cam_mask, img_rgb, img_id, heat_dir, ov_dir, out_cfg):
    """Write the heatmap and/or overlay for one CAM mask, as enabled in ``out_cfg``.

    The heatmap PNG keeps the resolution GradCAM returned (presumably the transformed
    input size — confirm). The overlay is resized to the original image resolution,
    so annotation points in original-pixel coordinates line up with it.
    """
    stem = os.path.splitext(img_id)[0]
    if out_cfg["save_heatmaps"]:
        # Clip before the uint8 cast so any out-of-range value cannot wrap around.
        heat = (np.clip(cam_mask, 0.0, 1.0) * 255.0).astype(np.uint8)
        cv2.imwrite(os.path.join(heat_dir, f"{stem}_heat.png"), heat)
    if out_cfg["save_overlays"]:
        img_h, img_w = img_rgb.shape[:2]
        cam_resized = cv2.resize(cam_mask, (img_w, img_h))  # cv2 expects (width, height)
        ov = overlay_cam(img_rgb, cam_resized, float(out_cfg.get("overlay_alpha", 0.45)))
        cv2.imwrite(os.path.join(ov_dir, f"{stem}_ov.jpg"),
                    cv2.cvtColor(ov, cv2.COLOR_RGB2BGR))


def main():
    """Run Grad-CAM for the Stage 1 and Stage 2 models over ``data.img_dir``.

    Config keys read (from ``--config``):
        data.img_dir; output.out_dir, save_heatmaps, save_overlays,
        overlay_alpha (default 0.45), max_images (default 9999);
        model.stage1 / model.stage2 with task, config_path, ckpt_path;
        cam.input_size; target.class_index, target.label_index (default 0).

    Writes heatmaps_stage{1,2}/ and overlays_stage{1,2}/ under output.out_dir.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="Path to stage3_weak_cam.yaml")
    args = parser.parse_args()

    cfg = load_yaml(args.config)
    device = get_device()
    out_cfg = cfg["output"]

    # Direct paths from YAML
    img_dir = cfg["data"]["img_dir"]
    out_dir = out_cfg["out_dir"]

    heat_dir_stage1 = os.path.join(out_dir, "heatmaps_stage1")
    heat_dir_stage2 = os.path.join(out_dir, "heatmaps_stage2")
    ov_dir_stage1 = os.path.join(out_dir, "overlays_stage1")
    ov_dir_stage2 = os.path.join(out_dir, "overlays_stage2")

    for d in [heat_dir_stage1, heat_dir_stage2, ov_dir_stage1, ov_dir_stage2]:
        ensure_dir(d)

    # ---------------------------
    # Load Stage 1 & Stage 2 models
    # ---------------------------
    stage1_cfg = cfg["model"]["stage1"]
    stage2_cfg = cfg["model"]["stage2"]

    model_stage1 = load_model(stage1_cfg["task"], stage1_cfg["config_path"], stage1_cfg["ckpt_path"], device)
    model_stage2 = load_model(stage2_cfg["task"], stage2_cfg["config_path"], stage2_cfg["ckpt_path"], device)

    cam_stage1 = GradCAM(model=model_stage1, target_layers=[find_last_conv_layer(model_stage1)])
    cam_stage2 = GradCAM(model=model_stage2, target_layers=[find_last_conv_layer(model_stage2)])

    # Transforms
    tf = get_eval_transforms(int(cfg["cam"]["input_size"]))

    # Targets. build_cam_target uses class_index for stage1_binary and label_index for
    # stage2_multilabel, so both are forwarded. Without label_index, Stage 2 would
    # always fall back to the dataclass default of 0.
    class_index = int(cfg["target"]["class_index"])
    label_index = int(cfg["target"].get("label_index", 0))
    target_stage1 = build_cam_target(CamTargetConfig(
        task=stage1_cfg["task"], label_index=label_index, class_index=class_index
    ))
    target_stage2 = build_cam_target(CamTargetConfig(
        task=stage2_cfg["task"], label_index=label_index, class_index=class_index
    ))

    # Images
    images = list_images_from_dir(img_dir, ".jpg")
    images = images[:out_cfg.get("max_images", 9999)]

    # Processed in this order for each image: Stage 1, then Stage 2.
    stages = (
        (cam_stage1, target_stage1, heat_dir_stage1, ov_dir_stage1),
        (cam_stage2, target_stage2, heat_dir_stage2, ov_dir_stage2),
    )

    skipped = []
    for img_id in tqdm(images, desc="Stage3 CAM"):
        img_bgr = cv2.imread(os.path.join(img_dir, img_id))
        if img_bgr is None:
            # Unreadable or corrupt file: cvtColor would raise on None.
            skipped.append(img_id)
            continue
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        x = tf(image=img_rgb)["image"].unsqueeze(0).to(device)

        for cam, target, heat_dir, ov_dir in stages:
            # Batch of one image, so take the single mask.
            cam_mask = cam(input_tensor=x, targets=[target])[0]
            _save_cam_outputs(cam_mask, img_rgb, img_id, heat_dir, ov_dir, out_cfg)

    if skipped:
        print(f"Skipped {len(skipped)} unreadable image(s): {skipped[:5]}")
    print(f"✅ Stage 3 CAM outputs saved to {out_dir}")

if __name__ == "__main__":
    main()

In [None]:
# CELL 18: Run CAM Inference
# NOTE(review): configs/stage3_weak_cam.yaml is not written anywhere in this section
# (only the *_eval_stage1/2.yaml files are). infer.py needs it to have model.stage1 and
# model.stage2 sub-sections. Confirm the file exists with that layout.

%cd /content/drive/MyDrive/stage3_cam/project_stage3

!PYTHONPATH=$PWD python -m src.stage3_localization.weak_cam.infer \
    --config configs/stage3_weak_cam.yaml

In [None]:
# CELL 16: Copy Stage 1 Config to Drive

%%writefile /content/drive/MyDrive/stage1_results/stage1_binary.yaml
# Presumably consumed by the Stage 1 training script (not part of this section).
data:
  train_csv: train.csv
  val_csv: val.csv
  img_dir_train: /content/data_stage1/images/train
  img_dir_val: /content/data_stage1/images/val
  image_id_col: image_id
  label_col: label

# backbone, num_classes and dropout are read by load_model() in weak_cam/infer.py to
# rebuild the network (which forces pretrained=False before loading the checkpoint).
model:
  backbone: resnet50
  num_classes: 2
  img_size: 224
  pretrained: true
  dropout: 0.2

# Not read by infer.py; presumably training-only settings.
train:
  epochs: 10
  batch_size: 32
  lr: 0.0001
  weight_decay: 0.0001
  num_workers: 2
  seed: 42
  mixed_precision: true

augment:
  enabled: true
  preset: stage1_safe

eval:
  primary_metric: macro_f1

output:
  # Presumably the checkpoint file name; the Stage 1 eval config loads .../best_stage1.pt.
  best_ckpt_name: best_stage1.pt

In [None]:
%%writefile /content/drive/MyDrive/stage3_cam/project_stage3/configs/stage3_weak_cam_eval_stage1.yaml
# Pointing-game evaluation config for the Stage 1 (binary) model.
cam:
  input_size: 224  # same value as model.img_size in stage1_binary.yaml
  # NOTE(review): infer.py ignores cam.target_layer and always hooks the last Conv2d.
  # Confirm how eval_pointing_game treats null here.
  target_layer: null

data:
  img_dir: /content/drive/MyDrive/stage3_cam/data_stage3/images  # images copied by Cell 4
  points_csv: /content/drive/MyDrive/stage3_cam/data_stage3/annotations/pointing_game.csv  # written by Cell 4
  image_id_col: image_id
  x_col: x
  y_col: y

model:
  task: stage1_binary
  config_path: /content/drive/MyDrive/stage1_results/stage1_binary.yaml
  ckpt_path: /content/drive/MyDrive/project/runs/stage1_binary/best_stage1.pt

target:
  class_index: 1  # class explained for stage1_binary (see build_cam_target)
  label_index: 0  # only used for stage2_multilabel

output:
  # NOTE(review): same out_dir as the Stage 2 eval config. Confirm the two evaluation
  # runs do not overwrite each other's outputs.
  out_dir: /content/drive/MyDrive/stage3_cam/project_stage3/outputs/cam_results
  save_heatmaps: true
  save_overlays: true
  overlay_alpha: 0.45
  max_images: 50

In [None]:
%%writefile /content/drive/MyDrive/stage3_cam/project_stage3/configs/stage3_weak_cam_eval_stage2.yaml
# Pointing-game evaluation config for the Stage 2 (multi-label) model.
cam:
  input_size: 224
  # NOTE(review): infer.py ignores cam.target_layer and always hooks the last Conv2d.
  # Confirm how eval_pointing_game treats null here.
  target_layer: null

data:
  img_dir: /content/drive/MyDrive/stage3_cam/data_stage3/images  # images copied by Cell 4
  points_csv: /content/drive/MyDrive/stage3_cam/data_stage3/annotations/pointing_game.csv  # written by Cell 4
  image_id_col: image_id
  x_col: x
  y_col: y

model:
  task: stage2_multilabel
  config_path: /content/drive/MyDrive/project_stage2/configs/stage2_multilabel.yaml
  ckpt_path: /content/drive/MyDrive/project_stage2/runs/stage2_multilabel/best_stage2.pt

target:
  class_index: 1  # only used for stage1_binary
  label_index: 0  # label column explained for stage2_multilabel (see build_cam_target)

output:
  # NOTE(review): same out_dir as the Stage 1 eval config. Confirm the two evaluation
  # runs do not overwrite each other's outputs.
  out_dir: /content/drive/MyDrive/stage3_cam/project_stage3/outputs/cam_results
  save_heatmaps: true
  save_overlays: true
  overlay_alpha: 0.45
  max_images: 50

In [None]:
# Stage 1
!PYTHONPATH=$PWD python -m src.stage3_localization.weak_cam.eval_pointing_game \
    --config configs/stage3_weak_cam_eval_stage1.yaml \
    --data_root /content/drive/MyDrive/stage3_cam/data_stage3 \
    --radius 15

# Stage 2
!PYTHONPATH=$PWD python -m src.stage3_localization.weak_cam.eval_pointing_game \
    --config configs/stage3_weak_cam_eval_stage2.yaml \
    --data_root /content/drive/MyDrive/stage3_cam/data_stage3 \
    --radius 15

In [None]:
# CELL 20: Visualize Stage-specific CAM Results with Pointing Game Points

import matplotlib.pyplot as plt
from PIL import Image
import os
import pandas as pd


stage = "stage1"  # "stage1" or "stage2": selects overlays_<stage>/ written by infer.py

output_dir = "/content/drive/MyDrive/stage3_cam/project_stage3/outputs/cam_results"
overlay_dir = os.path.join(output_dir, f"overlays_{stage}")
points_csv = "/content/drive/MyDrive/stage3_cam/data_stage3/annotations/pointing_game.csv"

# infer.py names overlays "<image stem>_ov.jpg".
OVERLAY_SUFFIX = "_ov.jpg"

points_df = pd.read_csv(points_csv)
# The CSV's image_id keeps the original file name (with extension), so it never equals
# an overlay file name. Map stem -> image_id to find each overlay's points.
stem_to_image_id = {
    os.path.splitext(image_id)[0]: image_id
    for image_id in points_df["image_id"].astype(str).unique()
}

if not os.path.exists(overlay_dir):
    raise FileNotFoundError(f"No overlays folder found for {stage}: {overlay_dir}")

overlay_files = sorted(f for f in os.listdir(overlay_dir) if f.endswith('.jpg'))
if not overlay_files:
    print("No overlay images found in:", overlay_dir)
else:
    overlay_files = overlay_files[:6]

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    axes = axes.ravel()

    for ax, img_file in zip(axes, overlay_files):
        # convert() returns a loaded copy, so the file handle can be closed right away.
        with Image.open(os.path.join(overlay_dir, img_file)) as img:
            ax.imshow(img.convert("RGB"))

        stem = (
            img_file[:-len(OVERLAY_SUFFIX)]
            if img_file.endswith(OVERLAY_SUFFIX)
            else os.path.splitext(img_file)[0]
        )
        image_id = stem_to_image_id.get(stem, img_file)
        points = points_df[points_df['image_id'] == image_id]
        if not points.empty:
            # Points are in original-image pixels; infer.py saves overlays at that size.
            ax.scatter(points['x'], points['y'], c='lime', s=50, marker='x', linewidths=2)

        ax.set_title(f"{stage.upper()} CAM + Points: {img_file}", fontsize=10)
        ax.axis('off')

    # Hide grid cells left empty when fewer than six overlays exist.
    for ax in axes[len(overlay_files):]:
        ax.axis('off')

    plt.tight_layout()

    save_path = os.path.join(output_dir, f"cam_visualization_{stage}.png")
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    print(f"{stage.upper()} CAM visualization saved to: {save_path}")

    plt.show()