<a href="https://colab.research.google.com/github/EliasNoorzad/XAI_Autonomous-Driving/blob/main/evaluation/05_overlays.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip -q install ultralytics==8.4.6

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m47.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)


Mounted at /content/drive


In [3]:
!cp /content/drive/MyDrive/XAI_Project/BDD100K_640.zip /content/

In [None]:
!cp /content/drive/MyDrive/XAI_Project/daynight_labels.csv /content/

In [None]:
!unzip -q /content/BDD100K_640.zip -d /content/BDD100K_640

In [None]:
yaml = """\
path: /content/BDD100K_640/yolo_640
train: train/images
val: val/images
test: test/images

nc: 5
names: [car, truck, bus, person, bike]
"""

with open("/content/BDD100K_640/yolo_640/dataset_640.yaml", "w") as f:
    f.write(yaml)

In [None]:
import csv
import numpy as np
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset


class BDDDetDrivableDataset(Dataset):
    """
    For preprocessed 640 dataset (yolo_640 + drivable_masks_640):
      images: <yolo_root>/<split>/images/<stem>.jpg
      labels: <yolo_root>/<split>/labels/<stem>.txt
      masks : <mask_root>/<split>/<stem>.png

    Returns:
      img   : FloatTensor [3, H, W] in [0,1]
      labels: FloatTensor [N, 5] where each row is [cls, x, y, w, h] (YOLO normalized)
      mask  : FloatTensor [1, H, W] with values 0/1
      dn    : LongTensor scalar (0=day, 1=night)  <-- always valid (unlabeled images are filtered out)
    """
    def __init__(
        self,
        yolo_root: str,
        mask_root: str,
        split: str,
        imgsz: int = 640,
        dn_csv_path: str | None = None,
    ):
        self.yolo_root = Path(yolo_root)
        self.mask_root = Path(mask_root)
        self.split = split
        self.imgsz = int(imgsz)

        self.img_dir = self.yolo_root / split / "images"
        self.lbl_dir = self.yolo_root / split / "labels"
        self.msk_dir = self.mask_root / split

        if not self.img_dir.is_dir():
            raise FileNotFoundError(f"Missing images dir: {self.img_dir}")
        if not self.lbl_dir.is_dir():
            raise FileNotFoundError(f"Missing labels dir: {self.lbl_dir}")
        if not self.msk_dir.is_dir():
            raise FileNotFoundError(f"Missing masks dir:  {self.msk_dir}")

        exts = {".jpg", ".jpeg", ".png"}
        self.img_paths = sorted([p for p in self.img_dir.iterdir() if p.suffix.lower() in exts])
        if len(self.img_paths) == 0:
            raise FileNotFoundError(f"No images found in: {self.img_dir}")

        # day/night mapping from CSV
        if dn_csv_path is None:
            raise RuntimeError("dn_csv_path is required for this dataset (day/night head training).")

        dn_csv_path = Path(dn_csv_path)
        if not dn_csv_path.exists():
            raise FileNotFoundError(f"Missing day/night CSV: {dn_csv_path}")

        dn_map = {}
        with open(dn_csv_path, "r", newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            required = {"split", "image_id", "label"}
            if not required.issubset(set(reader.fieldnames or [])):
                raise ValueError(f"dn CSV must have columns {required}, got {reader.fieldnames}")

            for row in reader:
                if row["split"] != self.split:
                    continue

                image_id = row["image_id"].strip()   # stem (no extension)
                lab = row["label"].strip().lower()

                if lab == "day":
                    dn = 0
                elif lab == "night":
                    dn = 1
                else:
                    # if CSV contains anything else, it's a data error
                    raise ValueError(f"Invalid dn label in CSV for {image_id}: {row['label']}")

                dn_map[image_id] = dn

        self.dn_map = dn_map

        #  FILTER OUT UNLABELED IMAGES
        before = len(self.img_paths)
        self.img_paths = [p for p in self.img_paths if p.stem in self.dn_map]
        after = len(self.img_paths)
        if after == 0:
            raise RuntimeError(f"No labeled (day/night) images found for split='{self.split}'.")


    def __len__(self):
        return len(self.img_paths)

    @staticmethod
    def _read_yolo_labels(label_path: Path) -> torch.Tensor:
        if not label_path.exists():
            return torch.zeros((0, 5), dtype=torch.float32)

        rows = []
        with open(label_path, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split()
                if len(parts) != 5:
                    continue
                cls, x, y, w, h = parts
                rows.append([float(cls), float(x), float(y), float(w), float(h)])

        if len(rows) == 0:
            return torch.zeros((0, 5), dtype=torch.float32)
        return torch.tensor(rows, dtype=torch.float32)

    @staticmethod
    def _pil_to_chw_float(img: Image.Image) -> torch.Tensor:
        arr = np.array(img, dtype=np.float32) / 255.0
        arr = np.transpose(arr, (2, 0, 1))
        return torch.from_numpy(arr)

    def __getitem__(self, idx: int):
        img_path = self.img_paths[idx]
        stem = img_path.stem

        label_path = self.lbl_dir / f"{stem}.txt"
        mask_path = self.msk_dir / f"{stem}.png"

        if not mask_path.exists():
            raise FileNotFoundError(f"Missing mask for {stem}: {mask_path}")

        img = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        labels = self._read_yolo_labels(label_path)

        if img.size != (self.imgsz, self.imgsz):
            img = img.resize((self.imgsz, self.imgsz), resample=Image.BILINEAR)
        if mask.size != (self.imgsz, self.imgsz):
            mask = mask.resize((self.imgsz, self.imgsz), resample=Image.NEAREST)

        img_t = self._pil_to_chw_float(img)
        mask_np = (np.array(mask, dtype=np.uint8) > 0).astype(np.float32)
        mask_t = torch.from_numpy(mask_np)[None, :, :]

        # dn is ALWAYS valid because we filtered img_paths
        dn = torch.tensor(self.dn_map[stem], dtype=torch.long)

        return img_t, labels, mask_t, dn


In [None]:
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        hidden = max(channels // reduction, 1)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # shared MLP (implemented with 1x1 convs)
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, hidden, kernel_size=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, kernel_size=1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_out = self.mlp(self.avg_pool(x))
        max_out = self.mlp(self.max_pool(x))
        w = self.sigmoid(avg_out + max_out)  # BxCx1x1
        return x * w


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size: int = 7):
        super().__init__()
        assert kernel_size in (3, 7)
        padding = (kernel_size - 1) // 2

        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.last_sa = None  # <--- add this

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean_map = torch.mean(x, dim=1, keepdim=True)
        max_map, _ = torch.max(x, dim=1, keepdim=True)
        m = torch.cat([mean_map, max_map], dim=1)

        w = self.sigmoid(self.conv(m))  # Bx1xHxW in [0,1]
        self.last_sa = w.detach()
        return x * w


class CBAM(nn.Module):
    def __init__(self, channels: int, reduction: int = 16, spatial_kernel: int = 7):
        super().__init__()
        self.ca = ChannelAttention(channels, reduction=reduction)
        self.sa = SpatialAttention(kernel_size=spatial_kernel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.ca(x)
        x = self.sa(x)
        return x

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from ultralytics import YOLO


class YOLOv8DetSemSeg(nn.Module):
    """
    YOLOv8 detection model + tiny semantic seg head.
    Captures NECK features by hooking the Detect head INPUT (multi-scale features).
    """
    def __init__(self, yolo_weights: str = "yolov8n.pt", use_cbam: bool = False):
        super().__init__()
        self.yolo = YOLO(yolo_weights).model  # nn.Module
        self.use_cbam = use_cbam

        self.cbam_backbone = None  # Point 1 (after last backbone block)
        self.cbam_neck = None      # Point 2 (once in neck before heads)

        self._neck_feats = None

        self.sem_head = nn.Sequential(
            nn.LazyConv2d(64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1)
        )

        self._register_backbone_hook_point1()
        self._register_detect_input_hook()

    def _register_backbone_hook_point1(self):
        # find first upsample -> previous layer output is your last backbone high-level feature
        idx_up = None
        for i, m in enumerate(self.yolo.model):
            if isinstance(m, nn.Upsample) or "upsample" in m.__class__.__name__.lower():
                idx_up = i
                break
        if idx_up is None or idx_up == 0:
            raise RuntimeError("Could not find neck start (Upsample) to place backbone CBAM.")

        backbone_last = self.yolo.model[idx_up - 1]

        if hasattr(self, "_bb_hook_handle") and self._bb_hook_handle is not None:
            self._bb_hook_handle.remove()

        def fwd_hook(module, inputs, output):
            if not self.use_cbam:
                return None
            if self.cbam_backbone is None:
                self.cbam_backbone = CBAM(channels=output.shape[1]).to(output.device)
            return self.cbam_backbone(output)

        self._bb_hook_handle = backbone_last.register_forward_hook(fwd_hook)

    def _register_detect_input_hook(self):
        if not hasattr(self.yolo, "model"):
            raise RuntimeError("Unexpected Ultralytics model: no .model")

        detect_module = self.yolo.model[-1]
        name = detect_module.__class__.__name__.lower()
        if "detect" not in name:
            raise RuntimeError(f"Last module is not Detect (got {detect_module.__class__.__name__}).")

        # remove previous hook if exists
        if hasattr(self, "_detect_hook_handle") and self._detect_hook_handle is not None:
            self._detect_hook_handle.remove()

        def pre_hook(module, inputs):
            feats = inputs[0]              # should be list of tensors
            feats = list(feats)            # FORCE list

            if self.use_cbam:
                if self.cbam_neck is None:
                    self.cbam_neck = CBAM(feats[-1].shape[1]).to(feats[-1].device)
                feats[-1] = self.cbam_neck(feats[-1])

            self._neck_feats = feats
            return (feats,) if self.use_cbam else None
          # return list inside the tuple wrapper


        self._detect_hook_handle = detect_module.register_forward_pre_hook(pre_hook)


    @staticmethod
    def _pick_high_res_from_detect_inputs(feats):
        # feats: list/tuple of [B,C,H,W]
        if not isinstance(feats, (list, tuple)) or len(feats) == 0:
            raise RuntimeError("Detect input features not captured.")
        return max(feats, key=lambda t: t.shape[-2] * t.shape[-1])  # highest H*W (usually P3)

    def forward(self, x):
      # TRAIN: x is a batch dict -> YOLO returns (det_loss, loss_items)
      if isinstance(x, dict):
          self._neck_feats = None
          imgs = x["img"]
          det_loss, det_items = self.yolo(x)

          feat = self._pick_high_res_from_detect_inputs(self._neck_feats)
          seg_logits = self.sem_head(feat)
          seg_logits = F.interpolate(seg_logits, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
          return det_loss, det_items, seg_logits

      # INFER: x is an image tensor -> YOLO returns preds
      self._neck_feats = None
      det_preds = self.yolo(x)

      feat = self._pick_high_res_from_detect_inputs(self._neck_feats)
      seg_logits = self.sem_head(feat)
      seg_logits = F.interpolate(seg_logits, size=x.shape[-2:], mode="bilinear", align_corners=False)
      return det_preds, seg_logits

Creating new Ultralytics Settings v0.0.6 file ✅ 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics import YOLO


class YOLOv8DetSemSegDn(nn.Module):
    """
    YOLOv8 detection model + tiny semantic seg head.
    Captures NECK features by hooking the Detect head INPUT (multi-scale features).
    """
    def __init__(self, yolo_weights: str = "yolov8n.pt", use_cbam: bool = False):
        super().__init__()
        self.yolo = YOLO(yolo_weights).model  # nn.Module
        self.use_cbam = use_cbam

        self.cbam_backbone = None  # Point 1 (after last backbone block)
        self.cbam_neck = None      # Point 2 (CBAM on ALL neck feature maps)

        self._neck_feats = None

        self.sem_head = nn.Sequential(
            nn.LazyConv2d(64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1)
        )

        self.dn_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # [B,C,1,1]
            nn.Flatten(1),            # [B,C]
            nn.LazyLinear(2)          # [B,2] logits: [day, night]
        )

        self._register_backbone_hook_point1()
        self._register_detect_input_hook()

    def _register_backbone_hook_point1(self):
        idx_up = None
        for i, m in enumerate(self.yolo.model):
            if isinstance(m, nn.Upsample) or "upsample" in m.__class__.__name__.lower():
                idx_up = i
                break
        if idx_up is None or idx_up == 0:
            raise RuntimeError("Could not find neck start (Upsample) to place backbone CBAM.")

        backbone_last = self.yolo.model[idx_up - 1]

        if hasattr(self, "_bb_hook_handle") and self._bb_hook_handle is not None:
            self._bb_hook_handle.remove()

        def fwd_hook(module, inputs, output):
            if not self.use_cbam:
                return None
            if self.cbam_backbone is None:
                self.cbam_backbone = CBAM(channels=output.shape[1]).to(output.device)
            return self.cbam_backbone(output)

        self._bb_hook_handle = backbone_last.register_forward_hook(fwd_hook)

    def _register_detect_input_hook(self):
        if not hasattr(self.yolo, "model"):
            raise RuntimeError("Unexpected Ultralytics model: no .model")

        detect_module = self.yolo.model[-1]
        name = detect_module.__class__.__name__.lower()
        if "detect" not in name:
            raise RuntimeError(f"Last module is not Detect (got {detect_module.__class__.__name__}).")

        if hasattr(self, "_detect_hook_handle") and self._detect_hook_handle is not None:
            self._detect_hook_handle.remove()

        def pre_hook(module, inputs):
            feats = list(inputs[0])  # list of multiscale neck features

            if self.use_cbam:
                # one CBAM per scale
                if self.cbam_neck is None:
                    self.cbam_neck = nn.ModuleList([CBAM(f.shape[1]).to(f.device) for f in feats])
                # apply to ALL scales
                feats = [m(f) for m, f in zip(self.cbam_neck, feats)]

            self._neck_feats = feats
            return (feats,) if self.use_cbam else None

        self._detect_hook_handle = detect_module.register_forward_pre_hook(pre_hook)

    @staticmethod
    def _pick_high_res_from_detect_inputs(feats):
        if not isinstance(feats, (list, tuple)) or len(feats) == 0:
            raise RuntimeError("Detect input features not captured.")
        return max(feats, key=lambda t: t.shape[-2] * t.shape[-1])

    @staticmethod
    def _pick_low_res_from_detect_inputs(feats):
        if not isinstance(feats, (list, tuple)) or len(feats) == 0:
            raise RuntimeError("Detect input features not captured.")
        return min(feats, key=lambda t: t.shape[-2] * t.shape[-1])

    def forward(self, x):
        # TRAIN: x is a batch dict -> YOLO returns (det_loss, loss_items)
        if isinstance(x, dict):
            self._neck_feats = None
            imgs = x["img"]
            det_loss, det_items = self.yolo(x)

            feat_seg = self._pick_high_res_from_detect_inputs(self._neck_feats)
            seg_logits = self.sem_head(feat_seg)
            seg_logits = F.interpolate(seg_logits, size=imgs.shape[-2:], mode="bilinear", align_corners=False)

            feat_dn = self._pick_low_res_from_detect_inputs(self._neck_feats)
            dn_logits = self.dn_head(feat_dn)

            return det_loss, det_items, seg_logits, dn_logits

        # INFER: x is an image tensor -> YOLO returns preds
        self._neck_feats = None
        det_preds = self.yolo(x)

        feat_seg = self._pick_high_res_from_detect_inputs(self._neck_feats)
        seg_logits = self.sem_head(feat_seg)
        seg_logits = F.interpolate(seg_logits, size=x.shape[-2:], mode="bilinear", align_corners=False)

        feat_dn = self._pick_low_res_from_detect_inputs(self._neck_feats)
        dn_logits = self.dn_head(feat_dn)

        return det_preds, seg_logits, dn_logits

In [None]:
import torch

def collate_det_seg(batch):
    # batch items: (img, labels, mask, dn)
    imgs, labels_list, masks, dns = zip(*batch)

    imgs = torch.stack(imgs, 0)         # [B,3,H,W]
    masks = torch.stack(masks, 0)       # [B,1,H,W]
    dn = torch.tensor(dns, dtype=torch.long)  # [B]

    bboxes_all, cls_all, batch_idx_all = [], [], []
    for i, lab in enumerate(labels_list):
        if lab.numel() == 0:
            continue
        cls = lab[:, 0:1].long()
        bboxes = lab[:, 1:5].float()
        bboxes_all.append(bboxes)
        cls_all.append(cls)
        batch_idx_all.append(torch.full((lab.shape[0],), i, dtype=torch.long))

    if len(bboxes_all):
        bboxes = torch.cat(bboxes_all, 0)
        cls = torch.cat(cls_all, 0)
        batch_idx = torch.cat(batch_idx_all, 0)
    else:
        bboxes = torch.zeros((0, 4), dtype=torch.float32)
        cls = torch.zeros((0, 1), dtype=torch.long)
        batch_idx = torch.zeros((0,), dtype=torch.long)

    # IMPORTANT: YOLO batch dict must contain ONLY what YOLO needs
    yolo_batch = {"img": imgs, "bboxes": bboxes, "cls": cls, "batch_idx": batch_idx}

    # return extras separately (so YOLO loss doesn't see them)
    return yolo_batch, masks, dn



In [None]:
import torch

@torch.no_grad()
def val_iou(model, loader, device, max_batches=None):
    model.eval()
    total_iou = 0.0
    total_imgs = 0

    for bi, (det_batch, mask, dn) in enumerate(loader):
        if (max_batches is not None) and (bi >= max_batches):
            break

        det_batch = {k: v.to(device, non_blocking=True) for k, v in det_batch.items()}
        gt = (mask.to(device, non_blocking=True) > 0.5).float()

        det_loss, det_items, seg_logits, dn_logits = model(det_batch)  # dict-path returns 4
        pred = (torch.sigmoid(seg_logits) > 0.5).float()

        inter = (pred * gt).sum(dim=(1, 2, 3))
        union = ((pred + gt) > 0).float().sum(dim=(1, 2, 3)).clamp_min(1.0)

        iou = inter / union
        total_iou += iou.sum().item()
        total_imgs += iou.numel()

    return total_iou / max(1, total_imgs)


In [None]:
val_ds = BDDDetDrivableDataset(
    yolo_root="/content/BDD100K_640/yolo_640",
    mask_root="/content/BDD100K_640/drivable_masks_640",
    split="val",
    imgsz=640,
    dn_csv_path="/content/daynight_labels.csv"
)

from torch.utils.data import DataLoader
import torch

val_loader = DataLoader(
    val_ds,
    batch_size=8,
    shuffle=False,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=False,
    collate_fn=collate_det_seg
)

In [None]:

# CBAM HEATMAP

import os, random, glob
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.cm as cm

from ultralytics import YOLO
det_yolo = YOLO("/content/drive/MyDrive/XAI_Project/experiments/det_baseline/weights/best.pt")
det_yolo.fuse()

det_yolo.model.names = {0:"car", 1:"truck", 2:"bus", 3:"person", 4:"bike"}


# CONFIG
device = "cuda" if torch.cuda.is_available() else "cpu"

best_pt   = "/content/drive/MyDrive/XAI_Project/experiments/det_seg_cbam_640/weights/best.pt"
last_ckpt = "/content/drive/MyDrive/XAI_Project/experiments/det_seg_cbam_640/weights/last.ckpt"

out_dir = "/content/drive/MyDrive/XAI_Project/experiments/det_seg_dn_cbam_640/att_seg_dn_overlays_9"
N = 20

alpha_att = 0.55   # max heatmap opacity (per-pixel scaled by attention)
alpha_seg = 0.35

PAD_VAL = 114
PAD_TOL = 3

# global norm quantiles
GLOBAL_Q_LO = 0.05
GLOBAL_Q_HI = 0.95

IMG_SIZE = 640

os.makedirs(out_dir, exist_ok=True)


def overlay_attention_fit(im_rgb_u8: np.ndarray, att01: np.ndarray, alpha=0.55):
    # att01 in [0,1], red=high (jet)
    heat = cm.jet(att01)[..., :3]                 # [H,W,3] float [0,1]
    heat_u8 = (heat * 255).astype(np.float32)

    im = im_rgb_u8.astype(np.float32)
    a = (alpha * att01).astype(np.float32)[..., None]   # per-pixel alpha
    out = (1 - a) * im + a * heat_u8
    return np.clip(out, 0, 255).astype(np.uint8)

def overlay_seg_red(im_rgb_u8: np.ndarray, mask01: np.ndarray, alpha=0.35):
    im = im_rgb_u8.astype(np.float32)
    ov = im.copy()
    red = np.array([255, 0, 0], dtype=np.float32)
    m = (mask01 > 0)
    ov[m] = (1 - alpha) * ov[m] + alpha * red
    return np.clip(ov, 0, 255).astype(np.uint8)

def get_pad_mask_rgb_u8(img_u8: np.ndarray, pad_val=114, tol=3):
    diff = np.abs(img_u8.astype(np.int16) - pad_val)
    return (diff[..., 0] <= tol) & (diff[..., 1] <= tol) & (diff[..., 2] <= tol)

def keep_top_percent(att01: np.ndarray, nonpad_mask: np.ndarray, top_p: float = 0.15):
    vals = att01[nonpad_mask]
    if vals.size == 0:
        return att01
    thr = np.quantile(vals, 1.0 - top_p)   # keep only top top_p
    out = att01.copy()
    out[out < thr] = 0.0
    return out

def extract_xyxy_conf(det_out, conf_thres=0.25, max_det=50):
    """
    Returns a list of boxes: [(x1,y1,x2,y2,conf), ...] in 640x640 coords.
    Works for common formats:
      - Tensor/ndarray Nx6: [x1,y1,x2,y2,conf,cls]
      - Tensor Nx5: [x1,y1,x2,y2,conf]
      - Ultralytics Results-like: det_out.boxes.xyxy + det_out.boxes.conf
    """
    # unwrap list/tuple
    if isinstance(det_out, (list, tuple)) and len(det_out) > 0:
        det_out = det_out[0]

    # Ultralytics Results-like
    if hasattr(det_out, "boxes") and hasattr(det_out.boxes, "xyxy"):
        xyxy = det_out.boxes.xyxy.detach().cpu().numpy()
        conf = det_out.boxes.conf.detach().cpu().numpy()
        out = []
        for b, c in zip(xyxy, conf):
            if c >= conf_thres:
                out.append((float(b[0]), float(b[1]), float(b[2]), float(b[3]), float(c)))
        out.sort(key=lambda t: t[4], reverse=True)
        return out[:max_det]

    # tensor/ndarray
    if torch.is_tensor(det_out):
        arr = det_out.detach().cpu().numpy()
    else:
        arr = np.array(det_out)

    if arr.ndim != 2 or arr.shape[0] == 0:
        return []

    if arr.shape[1] >= 6:
        arr = arr[:, :5]  # keep xyxy + conf
    if arr.shape[1] == 5:
        out = []
        for x1, y1, x2, y2, c in arr:
            if c >= conf_thres:
                out.append((float(x1), float(y1), float(x2), float(y2), float(c)))
        out.sort(key=lambda t: t[4], reverse=True)
        return out[:max_det]

    return []

def attention_in_boxes(att01, boxes, H, W):
    """Keep attention only inside boxes, everything else 0.
       boxes can be (x1,y1,x2,y2,conf) OR (x1,y1,x2,y2,conf,cls)
    """
    m = np.zeros((H, W), dtype=np.float32)
    for b in boxes:
        x1, y1, x2, y2, conf = b[:5]   # take first 5, ignore cls if present

        x1 = int(max(0, min(W-1, round(x1))))
        y1 = int(max(0, min(H-1, round(y1))))
        x2 = int(max(0, min(W,   round(x2))))
        y2 = int(max(0, min(H,   round(y2))))

        if x2 > x1 and y2 > y1:
            m[y1:y2, x1:x2] = 1.0

    return att01 * m

def draw_boxes(img_u8: np.ndarray, boxes, color=(0, 255, 0), thickness=2):
    out = img_u8.copy()
    H, W = out.shape[:2]
    for (x1, y1, x2, y2, conf) in boxes:
        x1 = int(max(0, min(W-1, round(x1))))
        y1 = int(max(0, min(H-1, round(y1))))
        x2 = int(max(0, min(W-1, round(x2))))
        y2 = int(max(0, min(H-1, round(y2))))

        # top/bottom
        out[max(0,y1-thickness):y1+thickness, x1:x2] = color
        out[max(0,y2-thickness):y2+thickness, x1:x2] = color
        # left/right
        out[y1:y2, max(0,x1-thickness):x1+thickness] = color
        out[y1:y2, max(0,x2-thickness):x2+thickness] = color
    return out

from PIL import Image, ImageDraw, ImageFont

def draw_boxes_with_text(img_u8: np.ndarray, boxes, cls_names=None, color=(0,255,0), thickness=2):
    """
    boxes: list of tuples either (x1,y1,x2,y2,conf) or (x1,y1,x2,y2,conf,cls)
    """
    im = Image.fromarray(img_u8)
    draw = ImageDraw.Draw(im)

    # try default font
    try:
        font = ImageFont.load_default()
    except:
        font = None

    for b in boxes:
        if len(b) == 5:
            x1,y1,x2,y2,conf = b
            cls = None
        else:
            x1,y1,x2,y2,conf,cls = b

        x1,y1,x2,y2 = map(int, [x1,y1,x2,y2])

        # rectangle
        for t in range(thickness):
            draw.rectangle([x1-t, y1-t, x2+t, y2+t], outline=color)

        # text
        label = f"{conf:.2f}"
        if cls is not None and cls_names is not None:
            if isinstance(cls_names, dict):
                name = cls_names.get(int(cls), str(int(cls)))
            else:
                name = cls_names[int(cls)] if int(cls) < len(cls_names) else str(int(cls))
            label = f"{name} {conf:.2f}"

        tx, ty = x1, max(0, y1 - 12)
        draw.text((tx, ty), label, fill=color, font=font)

    return np.array(im)

import matplotlib.cm as cm

def overlay_attention_full(im_rgb_u8, att01, nonpad_mask, alpha=0.55, cmap_name="plasma"):
    cmap = cm.get_cmap(cmap_name)
    heat = cmap(att01)[..., :3]
    heat_u8 = (heat * 255).astype(np.float32)

    im = im_rgb_u8.astype(np.float32)
    out = im.copy()

    m = nonpad_mask
    out[m] = (1 - alpha) * im[m] + alpha * heat_u8[m]
    return np.clip(out, 0, 255).astype(np.uint8)


# 1) Build model + load weights
det_base = "/content/drive/MyDrive/XAI_Project/experiments/det_baseline/weights/best.pt"
model = YOLOv8DetSemSeg(yolo_weights=det_base, use_cbam=True).to(device)

# force CBAM modules to exist (if lazily created)
with torch.no_grad():
    _ = model(torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=device))

# load weights
if os.path.exists(best_pt):
    sd = torch.load(best_pt, map_location="cpu")
elif os.path.exists(last_ckpt):
    ck = torch.load(last_ckpt, map_location="cpu")
    sd = ck["model"]
else:
    raise FileNotFoundError("Neither best.pt nor last.ckpt found.")

model.load_state_dict(sd, strict=True)
model.eval()

tag_model = YOLOv8DetSemSegDn(yolo_weights=det_base, use_cbam=True).to(device)   # <- your tri-task class name
with torch.no_grad():
    _ = tag_model(torch.randn(1,3,IMG_SIZE,IMG_SIZE, device=device))
tag_sd = torch.load("/content/drive/MyDrive/XAI_Project/experiments/det_seg_dn_cbam_640/weights/best.pt", map_location="cpu")
tag_model.load_state_dict(tag_sd, strict=True)
tag_model.eval()

tag_cache = {}


# 2) Hook: store ALL spatial logits calls
cb = model.cbam_neck if getattr(model, "cbam_neck", None) is not None else getattr(model, "cbam_backbone", None)
if cb is None or (not hasattr(cb, "sa")) or (not hasattr(cb.sa, "conv")):
    raise RuntimeError("CBAM.sa.conv not found.")

_att = {"logits_list": []}

def _hook_sa_logits(mod, inp, out):
    _att["logits_list"].append(out.detach())  # [B,1,h,w] pre-sigmoid

if hasattr(cb.sa, "_sa_hook_handle") and cb.sa._sa_hook_handle is not None:
    cb.sa._sa_hook_handle.remove()
cb.sa._sa_hook_handle = cb.sa.conv.register_forward_hook(_hook_sa_logits)

#3) Pick random samples
idxs = random.sample(range(len(val_ds)), k=min(N, len(val_ds)))

# 4) PASS 1: compute fused attention + collect values for GLOBAL norm
att_cache = {}   # idx -> att_map (H,W) float32 (already pad-masked, not normalized)
seg_cache = {}   # idx -> pred_mask uint8
img_cache = {}   # idx -> img_rgb uint8
pad_cache = {}   # idx -> pad_mask bool
det_cache = {}   # idx -> list of boxes


vals_all = []    # collect attention values (non-pad) for global quantiles

with torch.no_grad():
    for idx in idxs:
        img_t = val_ds[idx][0]# [3,640,640] in [0,1]
        img_np = (img_t.permute(1,2,0).numpy() * 255).astype(np.uint8)
        H, W = img_np.shape[:2]

        # img_np is 640x640 uint8 RGB already
        res = det_yolo.predict(source=img_np, imgsz=640, conf=0.50, iou=0.6, verbose=False)[0]
        xyxy = res.boxes.xyxy.cpu().numpy()
        conf = res.boxes.conf.cpu().numpy()
        cls  = res.boxes.cls.cpu().numpy()

        boxes = [(float(b[0]), float(b[1]), float(b[2]), float(b[3]), float(c), int(k))
                for b, c, k in zip(xyxy, conf, cls)]
        det_cache[idx] = boxes
        print("idx", idx, "boxes:", len(boxes))


        pad_mask = get_pad_mask_rgb_u8(img_np, PAD_VAL, PAD_TOL)
        nonpad = ~pad_mask
        if nonpad.sum() == 0:
            continue

        x = img_t.unsqueeze(0).to(device)

        _, _, tag_logits = tag_model(x)          # tag_logits shape [B,1] or [B]
        cls = torch.softmax(tag_logits, dim=1).argmax(dim=1).item()   # 0 or 1
        tag_cache[idx] = "night" if cls == 1 else "day"               # use class order



        _att["logits_list"].clear()
        _det_preds, seg_logits = model(x)


        det_cache[idx] = boxes


        pred_mask = (torch.sigmoid(seg_logits)[0, 0] > 0.5).detach().cpu().numpy().astype(np.uint8)

        if len(_att["logits_list"]) == 0:
            continue

        # For each logits_i: sigmoid -> upsample to (H,W), then weighted average.
        fused = torch.zeros((H, W), device=device, dtype=torch.float32)
        wsum = 0.0

        for logits in _att["logits_list"]:
            # logits: [B,1,h,w]
            sa = torch.sigmoid(logits)[0, 0]  # [h,w]
            h, w = int(sa.shape[-2]), int(sa.shape[-1])

            sa_up = F.interpolate(sa[None, None], size=(H, W), mode="bilinear", align_corners=False)[0, 0]

            # weight finer maps more (proportional to resolution)
            weight = float(h * w)
            fused += weight * sa_up
            wsum += weight

        fused = fused / (wsum + 1e-6)

        # pad masking (keep padding at 0 so it doesn't affect global stats)
        fused = fused.masked_fill(torch.from_numpy(pad_mask).to(device), 0.0)

        fused_np = fused.detach().cpu().numpy().astype(np.float32)

        # collect non-pad values for global quantiles
        vals_all.append(fused_np[nonpad])

        att_cache[idx] = fused_np
        seg_cache[idx] = pred_mask
        img_cache[idx] = img_np
        pad_cache[idx] = pad_mask

# build global quantiles
if len(vals_all) == 0:
    cb.sa._sa_hook_handle.remove()
    raise RuntimeError("No attention values collected (hook not firing or dataset issue).")

vals_concat = np.concatenate(vals_all, axis=0)
g_lo = float(np.quantile(vals_concat, GLOBAL_Q_LO))
g_hi = float(np.quantile(vals_concat, GLOBAL_Q_HI))
print("Global lo/hi:", g_lo, g_hi)

#5) PASS 2: global normalize + save overlays
saved = 0
for idx in idxs:
    if idx not in att_cache:
        continue

    img_np    = img_cache[idx]
    pred_mask = seg_cache[idx]
    pad_mask  = pad_cache[idx]
    att       = att_cache[idx]  # float32, pad already 0

    H, W = img_np.shape[:2]
    boxes = det_cache.get(idx, [])

    # 1) GLOBAL NORMALIZATION
    att01 = np.clip((att - g_lo) / (g_hi - g_lo + 1e-6), 0.0, 1.0).astype(np.float32)
    att01[pad_mask] = 0.0

    att_full = att01  # raw fused attention (normalized), no boxing, no thresholding
    nonpad_mask = ~pad_mask

    # ONE heatmap overlay (only on non-pad so bars stay gray)
    over = overlay_attention_full(img_np, att_full, nonpad_mask, alpha=alpha_att, cmap_name="plasma")

    # add seg + boxes for context (doesn't change attention)
    over = overlay_seg_red(over, pred_mask, alpha=alpha_seg)
    over = draw_boxes_with_text(over, boxes, cls_names=det_yolo.model.names, color=(0,255,0), thickness=2)


    out_path = os.path.join(out_dir, f"val_{idx:05d}_attFUSED_globalNORM_seg.png")
    Image.fromarray(img_np).save(os.path.join(out_dir, f"val_{idx:05d}_BASE.png"))


    label = tag_cache.get(idx, "unknown")
    im = Image.fromarray(over)
    draw = ImageDraw.Draw(im)
    y0 = int(np.argmax(nonpad_mask.any(axis=1)))
    draw.text((10, y0 + 10), f"pred: {label}", fill=(255,255,255))


    im.save(out_path)

    saved += 1

# cleanup hook
cb.sa._sa_hook_handle.remove()


print("Saved:", saved, "images to:", out_dir)
print("Files now:", len(glob.glob(out_dir + '/*.png')))
print("Note: fused all CBAM.sa.conv calls + global norm quantiles.")

Model summary (fused): 73 layers, 3,006,623 parameters, 0 gradients, 8.1 GFLOPs
idx 3446 boxes: 5
idx 8132 boxes: 3
idx 7335 boxes: 2
idx 7767 boxes: 7
idx 7838 boxes: 6
idx 5347 boxes: 5
idx 1196 boxes: 3
idx 3172 boxes: 6
idx 6982 boxes: 3
idx 3685 boxes: 11
idx 5263 boxes: 8
idx 1557 boxes: 4
idx 2435 boxes: 4
idx 3294 boxes: 5
idx 4920 boxes: 6
idx 7470 boxes: 10
idx 7562 boxes: 7
idx 5178 boxes: 4
idx 4310 boxes: 2
idx 6574 boxes: 13
Global lo/hi: 0.5769175291061401 0.7334437966346741


  cmap = cm.get_cmap(cmap_name)


Saved: 20 images to: /content/drive/MyDrive/XAI_Project/experiments/det_seg_dn_cbam_640/att_seg_dn_overlays_9
Files now: 40
Note: fused all CBAM.sa.conv calls + global norm quantiles.
