<a href="https://colab.research.google.com/github/EliasNoorzad/XAI_Autonomous-Driving/blob/main/evaluation/04_perturbation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive (remount is forced) so the dataset archive,
# label CSV and checkpoints under MyDrive/XAI_Project can be read by later cells.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)


Mounted at /content/drive


In [None]:
# Install Ultralytics pinned to 8.4.6 (pip output silenced with -q).
!pip -q install ultralytics==8.4.6

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m39.7 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Copy the dataset archive from Drive to the local /content directory.
!cp /content/drive/MyDrive/XAI_Project/BDD100K_640.zip /content/

In [None]:
# Copy the day/night label CSV from Drive to the local /content directory.
!cp /content/drive/MyDrive/XAI_Project/daynight_labels.csv /content/

In [None]:
# Extract the archive into /content/BDD100K_640 (quiet mode).
!unzip -q /content/BDD100K_640.zip -d /content/BDD100K_640

In [None]:
# Dataset description text: root path, per-split image folders (relative to the
# root), number of classes and the class names in index order.
# NOTE(review): the variable is named `yaml`; it would shadow the PyYAML module
# if that were imported in the same namespace.
yaml = """\
path: /content/BDD100K_640/yolo_640
train: train/images
val: val/images
test: test/images

nc: 5
names: [car, truck, bus, person, bike]
"""

# Write the description into the extracted yolo_640 folder.
with open("/content/BDD100K_640/yolo_640/dataset_640.yaml", "w") as f:
    f.write(yaml)

In [None]:
import csv
import numpy as np
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset


class BDDDetDrivableDataset(Dataset):
    """
    Detection + drivable-area + day/night samples from the preprocessed 640 dataset
    (yolo_640 + drivable_masks_640):
      images: <yolo_root>/<split>/images/<stem>.jpg
      labels: <yolo_root>/<split>/labels/<stem>.txt
      masks : <mask_root>/<split>/<stem>.png

    Only images whose stem has a day/night entry for ``split`` in the CSV are kept.

    Returns (per item):
      img   : FloatTensor [3, H, W] in [0,1]
      labels: FloatTensor [N, 5] where each row is [cls, x, y, w, h] (YOLO normalized)
      mask  : FloatTensor [1, H, W] with values 0/1
      dn    : LongTensor scalar (0=day, 1=night)  <-- always valid (unlabeled images are filtered out)

    Raises:
      FileNotFoundError: a required directory, the CSV, or (at item access) a mask is missing,
                         or the image directory contains no .jpg/.jpeg/.png files.
      RuntimeError     : ``dn_csv_path`` is None, or no image of the split has a day/night label.
      ValueError       : the CSV lacks a required column or holds a label other than day/night.
    """
    def __init__(
        self,
        yolo_root: str,
        mask_root: str,
        split: str,
        imgsz: int = 640,
        dn_csv_path: str | None = None,
    ):
        self.yolo_root = Path(yolo_root)
        self.mask_root = Path(mask_root)
        self.split = split
        self.imgsz = int(imgsz)

        self.img_dir = self.yolo_root / split / "images"
        self.lbl_dir = self.yolo_root / split / "labels"
        self.msk_dir = self.mask_root / split

        if not self.img_dir.is_dir():
            raise FileNotFoundError(f"Missing images dir: {self.img_dir}")
        if not self.lbl_dir.is_dir():
            raise FileNotFoundError(f"Missing labels dir: {self.lbl_dir}")
        if not self.msk_dir.is_dir():
            raise FileNotFoundError(f"Missing masks dir:  {self.msk_dir}")

        exts = {".jpg", ".jpeg", ".png"}
        # Sorted so that dataset indices are stable across runs.
        self.img_paths = sorted(p for p in self.img_dir.iterdir() if p.suffix.lower() in exts)
        if len(self.img_paths) == 0:
            raise FileNotFoundError(f"No images found in: {self.img_dir}")

        #  day/night mapping from CSV
        if dn_csv_path is None:
            raise RuntimeError("dn_csv_path is required for this dataset (day/night head training).")

        dn_csv_path = Path(dn_csv_path)
        if not dn_csv_path.exists():
            raise FileNotFoundError(f"Missing day/night CSV: {dn_csv_path}")

        dn_map = {}
        with open(dn_csv_path, "r", newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            required = {"split", "image_id", "label"}
            if not required.issubset(set(reader.fieldnames or [])):
                raise ValueError(f"dn CSV must have columns {required}, got {reader.fieldnames}")

            for row in reader:
                # The CSV covers every split; keep only rows for this one.
                if row["split"] != self.split:
                    continue

                image_id = row["image_id"].strip()   # stem (no extension)
                lab = row["label"].strip().lower()

                if lab == "day":
                    dn = 0
                elif lab == "night":
                    dn = 1
                else:
                    # if CSV contains anything else, it's a data error
                    raise ValueError(f"Invalid dn label in CSV for {image_id}: {row['label']}")

                dn_map[image_id] = dn

        self.dn_map = dn_map

        # FILTER OUT UNLABELED IMAGES
        self.img_paths = [p for p in self.img_paths if p.stem in self.dn_map]
        if len(self.img_paths) == 0:
            raise RuntimeError(f"No labeled (day/night) images found for split='{self.split}'.")

    def __len__(self):
        return len(self.img_paths)

    @staticmethod
    def _read_yolo_labels(label_path: Path) -> torch.Tensor:
        """Parse a YOLO label file into a [N, 5] float tensor.

        A missing file, or a file without any 5-field line, yields an empty
        [0, 5] tensor. Lines that do not have exactly five fields are skipped.
        """
        if not label_path.exists():
            return torch.zeros((0, 5), dtype=torch.float32)

        rows = []
        with open(label_path, "r", encoding="utf-8") as f:
            for line in f:
                parts = line.split()
                if len(parts) != 5:
                    continue
                rows.append([float(v) for v in parts])

        if len(rows) == 0:
            return torch.zeros((0, 5), dtype=torch.float32)
        return torch.tensor(rows, dtype=torch.float32)

    @staticmethod
    def _pil_to_chw_float(img: Image.Image) -> torch.Tensor:
        """Convert an HxWxC PIL image to a CxHxW float32 tensor scaled to [0, 1]."""
        arr = np.array(img, dtype=np.float32) / 255.0
        arr = np.transpose(arr, (2, 0, 1))
        return torch.from_numpy(arr)

    def __getitem__(self, idx: int):
        img_path = self.img_paths[idx]
        stem = img_path.stem

        label_path = self.lbl_dir / f"{stem}.txt"
        mask_path = self.msk_dir / f"{stem}.png"

        if not mask_path.exists():
            raise FileNotFoundError(f"Missing mask for {stem}: {mask_path}")

        # Image.open keeps the underlying file open; convert() returns an
        # independent in-memory copy, so the handle can be closed right away
        # instead of being left to the garbage collector.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert("L")

        labels = self._read_yolo_labels(label_path)

        # Labels are normalized, so resizing the image does not invalidate them.
        if img.size != (self.imgsz, self.imgsz):
            img = img.resize((self.imgsz, self.imgsz), resample=Image.BILINEAR)
        if mask.size != (self.imgsz, self.imgsz):
            # NEAREST keeps the mask binary (no interpolated in-between values).
            mask = mask.resize((self.imgsz, self.imgsz), resample=Image.NEAREST)

        img_t = self._pil_to_chw_float(img)
        mask_np = (np.array(mask, dtype=np.uint8) > 0).astype(np.float32)
        mask_t = torch.from_numpy(mask_np)[None, :, :]

        # dn is ALWAYS valid because we filtered img_paths
        dn = torch.tensor(self.dn_map[stem], dtype=torch.long)

        return img_t, labels, mask_t, dn

In [None]:
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Channel gate: rescales each channel by a sigmoid weight.

    The weight is produced by one shared two-layer 1x1-conv MLP applied to the
    globally average-pooled and the globally max-pooled input; the two MLP
    outputs are summed before the sigmoid.
    """

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        # Bottleneck width never drops below one channel.
        bottleneck = max(1, channels // reduction)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        squeeze = nn.Conv2d(channels, bottleneck, kernel_size=1, bias=False)
        excite = nn.Conv2d(bottleneck, channels, kernel_size=1, bias=False)
        # Same MLP instance is used for both pooled descriptors.
        self.mlp = nn.Sequential(squeeze, nn.ReLU(inplace=True), excite)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        from_avg = self.mlp(self.avg_pool(x))
        from_max = self.mlp(self.max_pool(x))
        gate = self.sigmoid(from_avg + from_max)  # B x C x 1 x 1
        return x * gate


class SpatialAttention(nn.Module):
    """Spatial gate: rescales each location by a sigmoid weight.

    The channel-wise mean and max maps are stacked and passed through a single
    ``kernel_size`` x ``kernel_size`` convolution. The most recent gate is kept
    (detached) in ``last_sa`` so it can be inspected after a forward pass.
    """

    def __init__(self, kernel_size: int = 7):
        super().__init__()
        assert kernel_size in (3, 7)

        self.conv = nn.Conv2d(
            2,
            1,
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,  # keeps the spatial size unchanged
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()
        self.last_sa = None  # populated by forward()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_plane = x.mean(dim=1, keepdim=True)
        peak_plane = x.max(dim=1, keepdim=True).values
        stacked = torch.cat((avg_plane, peak_plane), dim=1)

        gate = self.sigmoid(self.conv(stacked))  # B x 1 x H x W in [0,1]
        self.last_sa = gate.detach()
        return x * gate


class CBAM(nn.Module):
    """Channel attention followed by spatial attention on the same feature map."""

    def __init__(self, channels: int, reduction: int = 16, spatial_kernel: int = 7):
        super().__init__()
        self.ca = ChannelAttention(channels, reduction=reduction)
        self.sa = SpatialAttention(kernel_size=spatial_kernel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Channel gate first, spatial gate on its result.
        return self.sa(self.ca(x))

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics import YOLO


class YOLOv8DetSemSeg(nn.Module):
    """
    Ultralytics detection model with an added single-channel segmentation head.

    A forward pre-hook on the last (Detect) module captures the multi-scale
    features it receives; the highest-resolution one feeds the segmentation
    head. With ``use_cbam=True`` a CBAM block is applied to the output of the
    layer just before the first Upsample, and another to the last Detect input.
    Both CBAM blocks are created on the first forward pass.
    """

    def __init__(self, yolo_weights: str = "yolov8n.pt", use_cbam: bool = False):
        super().__init__()
        self.yolo = YOLO(yolo_weights).model  # nn.Module
        self.use_cbam = use_cbam

        # Built inside the hooks once the channel counts are known.
        self.cbam_backbone = None  # Point 1 (after last backbone block)
        self.cbam_neck = None      # Point 2 (once in neck before heads)

        # Filled by the Detect pre-hook on every forward pass.
        self._neck_feats = None

        self.sem_head = nn.Sequential(
            nn.LazyConv2d(64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1),
        )

        self._register_backbone_hook_point1()
        self._register_detect_input_hook()

    def _register_backbone_hook_point1(self):
        """Hook the layer that precedes the first Upsample and wrap its output with CBAM."""
        first_up = next(
            (
                pos
                for pos, layer in enumerate(self.yolo.model)
                if isinstance(layer, nn.Upsample) or "upsample" in layer.__class__.__name__.lower()
            ),
            None,
        )
        # None -> no Upsample found; 0 -> nothing precedes it.
        if not first_up:
            raise RuntimeError("Could not find neck start (Upsample) to place backbone CBAM.")

        backbone_last = self.yolo.model[first_up - 1]

        # Drop a hook left over from an earlier registration.
        previous = getattr(self, "_bb_hook_handle", None)
        if previous is not None:
            previous.remove()

        def fwd_hook(module, inputs, output):
            if not self.use_cbam:
                return None  # leave the output untouched
            if self.cbam_backbone is None:
                self.cbam_backbone = CBAM(channels=output.shape[1]).to(output.device)
            return self.cbam_backbone(output)

        self._bb_hook_handle = backbone_last.register_forward_hook(fwd_hook)

    def _register_detect_input_hook(self):
        """Hook the Detect module's input to capture (and optionally refine) its features."""
        if not hasattr(self.yolo, "model"):
            raise RuntimeError("Unexpected Ultralytics model: no .model")

        detect_module = self.yolo.model[-1]
        if "detect" not in detect_module.__class__.__name__.lower():
            raise RuntimeError(f"Last module is not Detect (got {detect_module.__class__.__name__}).")

        # remove previous hook if exists
        previous = getattr(self, "_detect_hook_handle", None)
        if previous is not None:
            previous.remove()

        def pre_hook(module, inputs):
            feats = list(inputs[0])  # per-scale tensors, forced into a mutable list

            if not self.use_cbam:
                self._neck_feats = feats
                return None  # Detect receives its original input

            if self.cbam_neck is None:
                last = feats[-1]
                self.cbam_neck = CBAM(last.shape[1]).to(last.device)
            feats[-1] = self.cbam_neck(feats[-1])

            self._neck_feats = feats
            return (feats,)  # pre-hooks return the replacement positional args

        self._detect_hook_handle = detect_module.register_forward_pre_hook(pre_hook)

    @staticmethod
    def _pick_high_res_from_detect_inputs(feats):
        """Return the captured feature map with the largest H*W."""
        # feats: list/tuple of [B,C,H,W]
        if not isinstance(feats, (list, tuple)) or len(feats) == 0:
            raise RuntimeError("Detect input features not captured.")

        def area(t):
            return t.shape[-2] * t.shape[-1]

        return max(feats, key=area)

    def _seg_logits_for(self, out_size):
        """Run the segmentation head on the captured features and resize to ``out_size``."""
        feat = self._pick_high_res_from_detect_inputs(self._neck_feats)
        coarse = self.sem_head(feat)
        return F.interpolate(coarse, size=out_size, mode="bilinear", align_corners=False)

    def forward(self, x):
        # Reset so that stale features from a previous call are never reused.
        self._neck_feats = None

        # TRAIN: x is a batch dict -> YOLO returns (det_loss, loss_items)
        if isinstance(x, dict):
            imgs = x["img"]
            det_loss, det_items = self.yolo(x)
            return det_loss, det_items, self._seg_logits_for(imgs.shape[-2:])

        # INFER: x is an image tensor -> YOLO returns preds
        det_preds = self.yolo(x)
        return det_preds, self._seg_logits_for(x.shape[-2:])

Creating new Ultralytics Settings v0.0.6 file ✅ 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics import YOLO


class YOLOv8DetSemSegDn(nn.Module):
    """
    Ultralytics detection model with a segmentation head and a day/night head.

    A forward pre-hook on the last (Detect) module captures the multi-scale
    features it receives. The highest-resolution map feeds the single-channel
    segmentation head; the lowest-resolution map feeds the 2-logit day/night
    head. With ``use_cbam=True`` a CBAM block is applied to the output of the
    layer just before the first Upsample, and one CBAM per scale is applied to
    every Detect input. All CBAM blocks are created on the first forward pass.
    """

    def __init__(self, yolo_weights: str = "yolov8n.pt", use_cbam: bool = False):
        super().__init__()
        self.yolo = YOLO(yolo_weights).model  # nn.Module
        self.use_cbam = use_cbam

        # Built inside the hooks once the channel counts are known.
        self.cbam_backbone = None  # Point 1 (after last backbone block)
        self.cbam_neck = None      # Point 2 (CBAM on ALL neck feature maps)

        # Filled by the Detect pre-hook on every forward pass.
        self._neck_feats = None

        self.sem_head = nn.Sequential(
            nn.LazyConv2d(64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1),
        )

        self.dn_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # [B,C,1,1]
            nn.Flatten(1),            # [B,C]
            nn.LazyLinear(2),         # [B,2] logits: [day, night]
        )

        self._register_backbone_hook_point1()
        self._register_detect_input_hook()

    def _register_backbone_hook_point1(self):
        """Hook the layer that precedes the first Upsample and wrap its output with CBAM."""
        first_up = next(
            (
                pos
                for pos, layer in enumerate(self.yolo.model)
                if isinstance(layer, nn.Upsample) or "upsample" in layer.__class__.__name__.lower()
            ),
            None,
        )
        # None -> no Upsample found; 0 -> nothing precedes it.
        if not first_up:
            raise RuntimeError("Could not find neck start (Upsample) to place backbone CBAM.")

        backbone_last = self.yolo.model[first_up - 1]

        # Drop a hook left over from an earlier registration.
        previous = getattr(self, "_bb_hook_handle", None)
        if previous is not None:
            previous.remove()

        def fwd_hook(module, inputs, output):
            if not self.use_cbam:
                return None  # leave the output untouched
            if self.cbam_backbone is None:
                self.cbam_backbone = CBAM(channels=output.shape[1]).to(output.device)
            return self.cbam_backbone(output)

        self._bb_hook_handle = backbone_last.register_forward_hook(fwd_hook)

    def _register_detect_input_hook(self):
        """Hook the Detect module's input to capture (and optionally refine) its features."""
        if not hasattr(self.yolo, "model"):
            raise RuntimeError("Unexpected Ultralytics model: no .model")

        detect_module = self.yolo.model[-1]
        if "detect" not in detect_module.__class__.__name__.lower():
            raise RuntimeError(f"Last module is not Detect (got {detect_module.__class__.__name__}).")

        previous = getattr(self, "_detect_hook_handle", None)
        if previous is not None:
            previous.remove()

        def pre_hook(module, inputs):
            feats = list(inputs[0])  # list of multiscale neck features

            if not self.use_cbam:
                self._neck_feats = feats
                return None  # Detect receives its original input

            # one CBAM per scale, created lazily
            if self.cbam_neck is None:
                self.cbam_neck = nn.ModuleList([CBAM(f.shape[1]).to(f.device) for f in feats])
            # apply to ALL scales
            feats = [block(f) for block, f in zip(self.cbam_neck, feats)]

            self._neck_feats = feats
            return (feats,)  # pre-hooks return the replacement positional args

        self._detect_hook_handle = detect_module.register_forward_pre_hook(pre_hook)

    @staticmethod
    def _pick_high_res_from_detect_inputs(feats):
        """Return the captured feature map with the largest H*W."""
        if not isinstance(feats, (list, tuple)) or len(feats) == 0:
            raise RuntimeError("Detect input features not captured.")

        def area(t):
            return t.shape[-2] * t.shape[-1]

        return max(feats, key=area)

    @staticmethod
    def _pick_low_res_from_detect_inputs(feats):
        """Return the captured feature map with the smallest H*W."""
        if not isinstance(feats, (list, tuple)) or len(feats) == 0:
            raise RuntimeError("Detect input features not captured.")

        def area(t):
            return t.shape[-2] * t.shape[-1]

        return min(feats, key=area)

    def _aux_heads(self, out_size):
        """Run the segmentation head (resized to ``out_size``), then the day/night head."""
        feat_seg = self._pick_high_res_from_detect_inputs(self._neck_feats)
        seg_logits = F.interpolate(
            self.sem_head(feat_seg), size=out_size, mode="bilinear", align_corners=False
        )

        feat_dn = self._pick_low_res_from_detect_inputs(self._neck_feats)
        dn_logits = self.dn_head(feat_dn)
        return seg_logits, dn_logits

    def forward(self, x):
        # Reset so that stale features from a previous call are never reused.
        self._neck_feats = None

        # TRAIN: x is a batch dict -> YOLO returns (det_loss, loss_items)
        if isinstance(x, dict):
            imgs = x["img"]
            det_loss, det_items = self.yolo(x)
            seg_logits, dn_logits = self._aux_heads(imgs.shape[-2:])
            return det_loss, det_items, seg_logits, dn_logits

        # INFER: x is an image tensor -> YOLO returns preds
        det_preds = self.yolo(x)
        seg_logits, dn_logits = self._aux_heads(x.shape[-2:])
        return det_preds, seg_logits, dn_logits

In [None]:
import torch

def collate_det_seg(batch):
    """Collate (img, labels, mask, dn) samples into a YOLO batch dict plus masks and dn.

    Returns:
      yolo_batch: dict with "img" [B,3,H,W], "bboxes" [M,4], "cls" [M,1] (long) and
                  "batch_idx" [M] (long), where M is the total number of boxes and
                  batch_idx gives the image each box belongs to
      masks     : [B,1,H,W]
      dn        : [B] long
    """
    images, label_sets, seg_masks, dn_flags = zip(*batch)

    # Images without boxes contribute nothing to the flattened targets.
    with_boxes = [(pos, lab) for pos, lab in enumerate(label_sets) if lab.numel() != 0]

    if with_boxes:
        bboxes = torch.cat([lab[:, 1:5].float() for _, lab in with_boxes], 0)
        cls = torch.cat([lab[:, 0:1].long() for _, lab in with_boxes], 0)
        batch_idx = torch.cat(
            [torch.full((lab.shape[0],), pos, dtype=torch.long) for pos, lab in with_boxes], 0
        )
    else:
        # Keep the expected ranks even when the whole batch is box-free.
        bboxes = torch.zeros((0, 4), dtype=torch.float32)
        cls = torch.zeros((0, 1), dtype=torch.long)
        batch_idx = torch.zeros((0,), dtype=torch.long)

    yolo_batch = {
        "img": torch.stack(images, 0),       # [B,3,H,W]
        "bboxes": bboxes,
        "cls": cls,
        "batch_idx": batch_idx,
    }
    masks = torch.stack(seg_masks, 0)        # [B,1,H,W]
    dn = torch.tensor(dn_flags, dtype=torch.long)  # [B]

    return yolo_batch, masks, dn



In [None]:
import torch

@torch.no_grad()
def val_iou(model, loader, device, max_batches=None):
    """Mean per-image IoU of the thresholded segmentation output over ``loader``.

    ``loader`` yields (det_batch dict, mask, dn); the model is called with the
    dict and must return four values, the third being the segmentation logits.
    Predictions are sigmoid > 0.5, ground truth is mask > 0.5. An image whose
    union is empty contributes an IoU of 0. Returns 0.0 if no image was seen.
    """
    model.eval()
    iou_sum = 0.0
    n_images = 0

    for batch_no, (det_batch, mask, _dn) in enumerate(loader):
        if max_batches is not None and batch_no >= max_batches:
            break

        det_batch = {key: t.to(device, non_blocking=True) for key, t in det_batch.items()}
        target = (mask.to(device, non_blocking=True) > 0.5).float()

        _, _, seg_logits, _ = model(det_batch)  # dict-path returns 4
        predicted = (torch.sigmoid(seg_logits) > 0.5).float()

        overlap = (predicted * target).sum(dim=(1, 2, 3))
        # clamp avoids 0/0 when both prediction and ground truth are empty
        covered = ((predicted + target) > 0).float().sum(dim=(1, 2, 3)).clamp_min(1.0)

        per_image = overlap / covered
        iou_sum += per_image.sum().item()
        n_images += per_image.numel()

    return iou_sum / max(1, n_images)


In [None]:
# Validation split of the preprocessed 640px dataset. The day/night CSV is
# mandatory for this dataset class; images without a day/night label are dropped.
val_ds = BDDDetDrivableDataset(
    yolo_root="/content/BDD100K_640/yolo_640",
    mask_root="/content/BDD100K_640/drivable_masks_640",
    split="val",
    imgsz=640,
    dn_csv_path="/content/daynight_labels.csv"
)

from torch.utils.data import DataLoader
import torch

# Fixed-order loader (shuffle=False); collate_det_seg turns the per-image label
# tensors into the flattened YOLO batch dict. Memory pinning is enabled only
# when CUDA is available.
val_loader = DataLoader(
    val_ds,
    batch_size=8,
    shuffle=False,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
    persistent_workers=False,
    collate_fn=collate_det_seg
)

In [None]:

# PERTURBATION TEST (DET+SEG+DN + CBAM)
# mask top-attended vs random (same area) and compare IoU drop


import os, random
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"

# paths
# cbam_best_pt: checkpoint whose weights are loaded (strict) into the CBAM model below.
# det_base_pt : weights handed to YOLOv8DetSemSegDn as `yolo_weights` to build the model.
cbam_best_pt = "/content/drive/MyDrive/XAI_Project/experiments/det_seg_dn_cbam_640/weights/best.pt"
det_base_pt  = "/content/drive/MyDrive/XAI_Project/experiments/det_baseline/weights/best.pt"

# settings

IMG_SIZE = 640          # side length of the dummy input used to materialise lazy layers
TOP_P    = 0.30         # fraction of non-padding pixels treated as "top-attended"
FILL01   = 114/255.0    # gray fill in [0,1]; NOTE(review): not referenced in this cell (masking writes uint8 114)
N        = 1000  # how many images to test; NOTE(review): the sampling below uses a literal 1000, not N
PATCH_H = 128           # height of the occlusion patch in pixels
PATCH_W = 128           # width of the occlusion patch in pixels
SEED_MASK = 1234
# Generator shared by make_rand_mask / random_patch_mask, so results depend on call order.
rng = np.random.default_rng(SEED_MASK)


# helpers
def make_top_mask(att01, nonpad_mask, top_p=0.15):
    """Select the non-padding pixels whose attention is in the top ``top_p`` fraction.

    The threshold is the (1 - top_p) quantile of the attention values taken
    over non-padding pixels only.
    """
    threshold = np.quantile(att01[nonpad_mask], 1.0 - top_p)
    above = att01 >= threshold
    return above & nonpad_mask

def make_rand_mask(nonpad_mask, k):
    """Mark ``k`` distinct non-padding pixels, drawn with the module-level ``rng``."""
    rows, cols = np.where(nonpad_mask)
    picked = rng.choice(len(rows), size=k, replace=False)
    chosen = np.zeros_like(nonpad_mask, dtype=bool)
    chosen[rows[picked], cols[picked]] = True
    return chosen

def apply_mask_to_x_from_img(x, img_np_u8, mask_hw):
    """Build a gray-occluded model input from the uint8 image.

    x        : [1,3,H,W] float tensor; only its device is used
    img_np_u8: HxWx3 uint8 array of the same image (left unmodified)
    mask_hw  : HxW bool array; True pixels are set to 114 in all channels

    Returns a [1,3,H,W] float tensor in [0,1] on ``x``'s device.
    """
    occluded = img_np_u8.copy()
    occluded[mask_hw] = 114  # same gray value the padding detection looks for
    chw = torch.from_numpy(occluded).permute(2, 0, 1)
    scaled = chw.float() / 255.0
    return scaled.unsqueeze(0).to(x.device)


def iou_bin(pred01, gt01):
    """IoU of two arrays interpreted as boolean masks (1e-6 added to the union)."""
    pred_on = pred01.astype(bool)
    gt_on = gt01.astype(bool)
    overlap = np.logical_and(pred_on, gt_on).sum()
    covered = np.logical_or(pred_on, gt_on).sum()
    return float(overlap) / float(covered + 1e-6)

def bbox_from_mask(m):
    """Tight bounding box (y1, x1, y2, x2) of the True pixels, end-exclusive; None if empty."""
    rows, cols = np.where(m)
    if len(rows) == 0:
        return None
    top, left = rows.min(), cols.min()
    bottom, right = rows.max() + 1, cols.max() + 1
    return top, left, bottom, right

def random_bbox(valid_mask, h, w, tries=200):
    """Draw an h x w box (y1, x1, y2, x2) that is more than 90% valid; None after ``tries`` misses.

    Uses NumPy's global random state (``np.random.randint``).
    """
    H, W = valid_mask.shape
    attempt = 0
    while attempt < tries:
        attempt += 1
        top = np.random.randint(0, H - h + 1)
        left = np.random.randint(0, W - w + 1)
        coverage = valid_mask[top:top + h, left:left + w].mean()
        # require most of the patch to be valid (not padding)
        if coverage > 0.9:
            return top, left, top + h, left + w
    return None

def clamp(v, lo, hi):
    """Limit ``v`` to [lo, hi]; when lo > hi the result is ``lo``."""
    capped = v if v < hi else hi
    return capped if capped > lo else lo

def patch_mask_from_center(nonpad_mask, cy, cx, ph, pw):
    """Mask of a ph x pw patch centred near (cy, cx), shifted to stay inside the image.

    Padding pixels inside the patch are excluded from the returned mask.
    """
    H, W = nonpad_mask.shape
    top = clamp(cy - ph//2, 0, H - ph)
    left = clamp(cx - pw//2, 0, W - pw)
    patch = np.zeros_like(nonpad_mask, dtype=bool)
    patch[top:top + ph, left:left + pw] = True
    return patch & nonpad_mask

def random_patch_mask(nonpad_mask, ph, pw, tries=200):
    """Mask of a random ph x pw patch that is more than 95% non-padding; None after ``tries`` misses.

    Positions are drawn with the module-level ``rng`` (row first, then column).
    Unlike patch_mask_from_center, the returned patch is not intersected with
    ``nonpad_mask``.
    """
    H, W = nonpad_mask.shape
    for _ in range(tries):
        top = int(rng.integers(0, H - ph + 1))
        left = int(rng.integers(0, W - pw + 1))
        window = (slice(top, top + ph), slice(left, left + pw))
        if not (nonpad_mask[window].mean() > 0.95):
            continue
        patch = np.zeros_like(nonpad_mask, dtype=bool)
        patch[window] = True
        return patch
    return None


def get_pad_mask_rgb_u8(img_u8, pad_val=114, tol=3):
    """Flag pixels whose first three channels are all within ``tol`` of ``pad_val``."""
    # int16 so the subtraction cannot wrap around as it would in uint8
    deviation = np.abs(img_u8.astype(np.int16) - pad_val)
    close = deviation <= tol
    return close[..., 0] & close[..., 1] & close[..., 2]

# building CBAM det+seg+dn model
# building CBAM det+seg+dn model
model = YOLOv8DetSemSegDn(yolo_weights=det_base_pt, use_cbam=True).to(device)

# One dummy forward pass: the lazy head layers and the CBAM blocks are only
# created during forward, and they must exist before the strict state-dict load.
with torch.no_grad():
    _ = model(torch.zeros(1,3,IMG_SIZE,IMG_SIZE, device=device))


# The checkpoint is either a raw state dict or a dict holding one under "state_dict".
ckpt = torch.load(cbam_best_pt, map_location="cpu")
sd = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
model.load_state_dict(sd, strict=True)


model.eval()

# hook all 3 cbam_neck spatial conv outputs
# (the hook stores the conv output, i.e. the value BEFORE the sigmoid)
_att = {"logits_list": []}
def _hook_sa_logits(mod, inp, out):
    _att["logits_list"].append(out.detach())

handles = []
for i in range(3):
    handles.append(model.cbam_neck[i].sa.conv.register_forward_hook(_hook_sa_logits))

# choose indices
# Fixed-seed sample of validation indices, drawn without replacement.
# NOTE(review): the sample size is the literal 1000 rather than the setting N.
SEED_POOL = 42
rng_pool = np.random.default_rng(SEED_POOL)
idxs = rng_pool.choice(len(val_ds), size=min(1000, len(val_ds)), replace=False).tolist()


# Per-image IoU drops (clean IoU minus perturbed IoU) for the two masking strategies.
drops_top, drops_rand = [], []

with torch.no_grad():
    for idx in idxs:
        sample = val_ds[idx]
        img_t   = sample[0]          # [3,640,640]
        gt_mask = sample[2]          # <-- if your gt mask is not at [2], change THIS ONE LINE

        if torch.is_tensor(gt_mask):
            gt_mask = gt_mask[0].cpu().numpy().astype(np.uint8)

        img_np = (img_t.permute(1,2,0).cpu().numpy() * 255).astype(np.uint8)
        H, W = img_np.shape[:2]

        # Padding is detected by colour (pixels close to gray 114 in all channels).
        pad_mask = get_pad_mask_rgb_u8(img_np, pad_val=114, tol=3)
        nonpad_mask = ~pad_mask

        x = img_t.unsqueeze(0).to(device)

        # clean forward (collect attention)
        # The list is cleared only here; the two perturbed forwards further down
        # also append to it and those entries are discarded on the next iteration.
        _att["logits_list"].clear()
        det_out, seg_logits, tag_logits = model(x)


        # fuse attention (sigmoid each scale, upsample, weighted avg)
        # Each scale is weighted by its own h*w, so higher-resolution maps count more.
        fused = torch.zeros((H, W), device=device, dtype=torch.float32)
        wsum = 0.0
        for logits in _att["logits_list"]:
            sa = torch.sigmoid(logits)[0,0]  # [h,w]
            h, w = int(sa.shape[-2]), int(sa.shape[-1])
            sa_up = F.interpolate(sa[None,None], size=(H,W), mode="bilinear", align_corners=False)[0,0]
            weight = float(h*w)
            fused += weight * sa_up
            wsum  += weight
        fused = fused / (wsum + 1e-6)
        fused = fused.masked_fill(torch.from_numpy(pad_mask).to(device), 0.0)
        att01 = fused.detach().cpu().numpy().astype(np.float32)

        # Rescale to [0,1] using the 5th/95th percentiles of the non-padding values.
        vals = att01[nonpad_mask]
        lo = np.quantile(vals, 0.05)
        hi = np.quantile(vals, 0.95)
        att01 = np.clip((att01 - lo) / (hi - lo + 1e-6), 0.0, 1.0).astype(np.float32)
        att01[pad_mask] = 0.0


        # segmentation clean
        pred_clean = (torch.sigmoid(seg_logits)[0,0] > 0.5).detach().cpu().numpy().astype(np.uint8)
        iou_clean = iou_bin(pred_clean, gt_mask)


        # masks (fixed-size patches)
        top_pix = make_top_mask(att01, nonpad_mask, TOP_P)   # TOP_P is fine (0.30)

        # The attention patch is centred on the centroid of the top-attended pixels.
        ys, xs = np.where(top_pix)
        if len(ys) == 0:
            continue
        cy = int(np.mean(ys))
        cx = int(np.mean(xs))

        top_mask = patch_mask_from_center(nonpad_mask, cy, cx, PATCH_H, PATCH_W)


        # NOTE(review): top_mask excludes padding pixels while rand_mask is the
        # full patch (accepted when >95% non-padding), so the two masked areas
        # can differ slightly.
        rand_mask = random_patch_mask(nonpad_mask, PATCH_H, PATCH_W)
        if rand_mask is None:
            continue


        # perturbed inputs
        x_top  = apply_mask_to_x_from_img(x, img_np, top_mask)
        x_rand = apply_mask_to_x_from_img(x, img_np, rand_mask)


        # forwards perturbed
        _, seg_top,  _ = model(x_top)
        _, seg_rand, _ = model(x_rand)


        pred_top  = (torch.sigmoid(seg_top)[0,0] > 0.5).detach().cpu().numpy().astype(np.uint8)
        pred_rand = (torch.sigmoid(seg_rand)[0,0] > 0.5).detach().cpu().numpy().astype(np.uint8)

        iou_top  = iou_bin(pred_top, gt_mask)
        iou_rand = iou_bin(pred_rand, gt_mask)

        drops_top.append(iou_clean - iou_top)
        drops_rand.append(iou_clean - iou_rand)

# cleanup hooks
for h in handles:
    h.remove()

# Skipped images (no top pixels / no valid random patch) are not counted.
print("Images used:", len(drops_top))
print("Mean IoU drop (top-att):", float(np.mean(drops_top)) if drops_top else None)
print("Mean IoU drop (random):",  float(np.mean(drops_rand)) if drops_rand else None)


Images used: 1000
Mean IoU drop (top-att): 0.16929898551437883
Mean IoU drop (random): 0.09067801369320716


In [None]:

# PERTURBATION TEST (DET+SEG+DN + CBAM)
# mask top-attended vs random (same area) and compare IoU drop

from ultralytics import YOLO
import os, random
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"

# paths
# NOTE(review): presumably used as in the previous cell (CBAM checkpoint to load /
# baseline detector weights); the code consuming them is further down this cell.
cbam_best_pt = "/content/drive/MyDrive/XAI_Project/experiments/det_seg_dn_cbam_640/weights/best.pt"
det_base_pt  = "/content/drive/MyDrive/XAI_Project/experiments/det_baseline/weights/best.pt"

# settings

IMG_SIZE = 640          # image side length in pixels
TOP_P    = 0.30         # fraction of non-padding pixels treated as "top-attended"
FILL01   = 114/255.0    # gray fill value expressed in [0,1]
N        = 1000  # how many images to test
PATCH_H = 128           # height of the occlusion patch in pixels
PATCH_W = 128           # width of the occlusion patch in pixels
SEED_MASK = 1234
# Re-created here, so the random patch sequence restarts from the seed for this cell.
rng = np.random.default_rng(SEED_MASK)


# helpers
def make_top_mask(att01, nonpad_mask, top_p=0.15):
    """Select the non-padding pixels whose attention is in the top ``top_p`` fraction.

    The threshold is the (1 - top_p) quantile of the attention values taken
    over non-padding pixels only.
    """
    threshold = np.quantile(att01[nonpad_mask], 1.0 - top_p)
    above = att01 >= threshold
    return above & nonpad_mask

def make_rand_mask(nonpad_mask, k):
    """Mark ``k`` distinct non-padding pixels, drawn with the module-level ``rng``."""
    rows, cols = np.where(nonpad_mask)
    picked = rng.choice(len(rows), size=k, replace=False)
    chosen = np.zeros_like(nonpad_mask, dtype=bool)
    chosen[rows[picked], cols[picked]] = True
    return chosen

def apply_mask_to_x_from_img(x, img_np_u8, mask_hw):
    """Gray out the masked pixels and return both tensor and uint8 versions.

    x        : tensor whose device the result is moved to
    img_np_u8: HxWx3 uint8 image (left unmodified)
    mask_hw  : HxW bool array; True pixels are set to 114 in all channels

    Returns (x2, img2): a [1,3,H,W] float tensor in [0,1] on ``x``'s device and
    the occluded uint8 image it was built from.
    """
    occluded = img_np_u8.copy()
    occluded[mask_hw] = 114
    chw = torch.from_numpy(occluded).permute(2, 0, 1)
    batched = (chw.float() / 255.0).unsqueeze(0)
    return batched.to(x.device), occluded   # <- return BOTH



def iou_bin(pred01, gt01):
    """IoU of two arrays interpreted as boolean masks (1e-6 added to the union)."""
    pred_on = pred01.astype(bool)
    gt_on = gt01.astype(bool)
    overlap = np.logical_and(pred_on, gt_on).sum()
    covered = np.logical_or(pred_on, gt_on).sum()
    return float(overlap) / float(covered + 1e-6)

def bbox_from_mask(m):
    """Tight bounding box (y1, x1, y2, x2) of the True pixels, end-exclusive; None if empty."""
    rows, cols = np.where(m)
    if len(rows) == 0:
        return None
    top, left = rows.min(), cols.min()
    bottom, right = rows.max() + 1, cols.max() + 1
    return top, left, bottom, right

def random_bbox(valid_mask, h, w, tries=200):
    """Draw an h x w box (y1, x1, y2, x2) that is more than 90% valid; None after ``tries`` misses.

    Uses NumPy's global random state (``np.random.randint``).
    """
    H, W = valid_mask.shape
    attempt = 0
    while attempt < tries:
        attempt += 1
        top = np.random.randint(0, H - h + 1)
        left = np.random.randint(0, W - w + 1)
        coverage = valid_mask[top:top + h, left:left + w].mean()
        # require most of the patch to be valid (not padding)
        if coverage > 0.9:
            return top, left, top + h, left + w
    return None

def clamp(v, lo, hi):
    """Limit ``v`` to [lo, hi]; when lo > hi the result is ``lo``."""
    capped = v if v < hi else hi
    return capped if capped > lo else lo

def patch_mask_from_center(nonpad_mask, cy, cx, ph, pw):
    """Mask of a ph x pw patch centred near (cy, cx), shifted to stay inside the image.

    Padding pixels inside the patch are excluded from the returned mask.
    """
    H, W = nonpad_mask.shape
    top = clamp(cy - ph//2, 0, H - ph)
    left = clamp(cx - pw//2, 0, W - pw)
    patch = np.zeros_like(nonpad_mask, dtype=bool)
    patch[top:top + ph, left:left + pw] = True
    return patch & nonpad_mask

def random_patch_mask(nonpad_mask, ph, pw, tries=200):
    """Mask of a random ph x pw patch that is more than 95% non-padding; None after ``tries`` misses.

    Positions are drawn with the module-level ``rng`` (row first, then column).
    Unlike patch_mask_from_center, the returned patch is not intersected with
    ``nonpad_mask``.
    """
    H, W = nonpad_mask.shape
    for _ in range(tries):
        top = int(rng.integers(0, H - ph + 1))
        left = int(rng.integers(0, W - pw + 1))
        window = (slice(top, top + ph), slice(left, left + pw))
        if not (nonpad_mask[window].mean() > 0.95):
            continue
        patch = np.zeros_like(nonpad_mask, dtype=bool)
        patch[window] = True
        return patch
    return None

def yolo_xywhn_to_xyxy_px(labels_xywhn, H, W):
    """
    Convert normalized YOLO labels to pixel-space corner boxes.

    labels_xywhn: tensor / array / nested list of rows [cls, x, y, w, h]
                  (center x/y and width/height normalized to [0,1]); a single
                  row may be given as a 1-D sequence. None or empty input
                  yields an empty list.
    H, W        : image height and width in pixels.
    returns: list of (cls, x1,y1,x2,y2) with cls as int and coordinates as float.
             Input with fewer than 5 columns yields an empty list.
    """
    if labels_xywhn is None:
        return []

    # Tensors are converted first: np.asarray cannot read a CUDA tensor, and
    # the 2-D reshape below has to apply to tensor input as well (a 1-D tensor
    # would otherwise be iterated element by element).
    if torch.is_tensor(labels_xywhn):
        labels_xywhn = labels_xywhn.detach().cpu().numpy()

    arr = np.asarray(labels_xywhn)
    if arr.size == 0:
        return []
    arr = arr.reshape(-1, arr.shape[-1])
    if arr.shape[1] < 5:
        return []

    out = []
    for cls, x, y, w, h in arr[:, :5]:
        x1 = (x - w/2) * W
        y1 = (y - h/2) * H
        x2 = (x + w/2) * W
        y2 = (y + h/2) * H
        out.append((int(cls), float(x1), float(y1), float(x2), float(y2)))
    return out


def get_gt_boxes_from_sample(sample, H, W):
    """
    Best-effort extraction of GT detection boxes from a val_ds sample.

    Supported case:
      - sample[1] is an Nx5 tensor/array/list: [cls, x, y, w, h] normalized (YOLO labels)

    Returns a list of (cls, x1, y1, x2, y2) in pixel coordinates, or an empty
    list when the sample has fewer than two entries, sample[1] has an
    unsupported type, or the conversion fails.
    """
    if len(sample) < 2:
        return []
    gt = sample[1]

    # case: tensor/ndarray Nx5
    if not (torch.is_tensor(gt) or isinstance(gt, (np.ndarray, list, tuple))):
        return []

    try:
        return yolo_xywhn_to_xyxy_px(gt, H, W)
    except Exception:
        # Malformed labels are treated as "no boxes". Exception (rather than a
        # bare except) lets KeyboardInterrupt and SystemExit propagate.
        return []


def iou_xyxy(a, b):
    """Intersection-over-union of two boxes given as (x1,y1,x2,y2)."""
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b

    # overlap rectangle; a negative extent means the boxes do not overlap
    overlap_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = overlap_w * overlap_h

    area_a = max(0.0, ax2-ax1) * max(0.0, ay2-ay1)
    area_b = max(0.0, bx2-bx1) * max(0.0, by2-by1)

    # the small epsilon keeps the division defined for two empty boxes
    union = area_a + area_b - inter + 1e-6
    return inter / union


def ap50_single_image(preds, gts, iou_thr=0.50):
    """
    preds: list of (cls, conf, x1,y1,x2,y2)
    gts:   list of (cls, x1,y1,x2,y2)

    Returns AP@iou_thr computed on THIS image, across all classes (micro),
    using 11-point interpolation.
    This is "mAP-like" (not dataset mAP), good for perturbation comparison.

    Matching is greedy: predictions are visited by descending confidence and
    each one claims the still-unused GT box of the same class with the highest
    IoU; it is a true positive when that IoU is >= iou_thr.
    An image without GT boxes scores 0.0.
    """
    n_gt = len(gts)
    if n_gt == 0:
        return 0.0

    # sort predictions by confidence desc
    preds = sorted(preds, key=lambda t: t[1], reverse=True)

    gt_used = [False] * n_gt
    is_tp = np.zeros(len(preds), dtype=np.int64)

    for i, (pc, _conf, px1, py1, px2, py2) in enumerate(preds):
        best_j = -1
        best_iou = 0.0
        for j, (gc, gx1, gy1, gx2, gy2) in enumerate(gts):
            if gt_used[j] or pc != gc:
                continue
            iou = iou_xyxy((px1, py1, px2, py2), (gx1, gy1, gx2, gy2))
            if iou > best_iou:
                best_iou = iou
                best_j = j

        if best_j >= 0 and best_iou >= iou_thr:
            is_tp[i] = 1
            gt_used[best_j] = True

    cum_tp = np.cumsum(is_tp)
    # Every prediction is either TP or FP, so the k-th denominator is simply k.
    # No epsilon is needed (n_gt > 0 and k >= 1); an epsilon here would keep
    # recall/precision strictly below 1.0 and cap a perfect image at 10/11.
    precision = cum_tp / np.arange(1, len(preds) + 1)

    # AP by 11-point interpolation. "recall >= k/10" is evaluated with integers
    # (cum_tp * 10 >= k * n_gt) so that thresholds such as 0.3 are hit exactly.
    total = 0.0
    for k in range(11):
        reached = cum_tp * 10 >= k * n_gt
        if reached.any():
            total += float(precision[reached].max())
    return float(total / 11.0)


def det_predict_boxes(det_eval, img_np_u8, conf=0.25, iou=0.6):
    """
    Run det_eval.predict on a single image and flatten the first result.

    Returns list of (cls, conf, x1,y1,x2,y2) in pixel coords; an empty list
    when the result carries no boxes.
    """
    results = det_eval.predict(
        source=img_np_u8,
        imgsz=img_np_u8.shape[0],   # inference size is taken from the image height
        conf=conf,
        iou=iou,
        verbose=False,
    )
    res = results[0]
    if res.boxes is None or len(res.boxes) == 0:
        return []

    corners = res.boxes.xyxy.detach().cpu().numpy()
    labels = res.boxes.cls.detach().cpu().numpy().astype(int)
    scores = res.boxes.conf.detach().cpu().numpy()

    return [
        (int(label), float(score), float(x1), float(y1), float(x2), float(y2))
        for (x1, y1, x2, y2), label, score in zip(corners, labels, scores)
    ]



def get_pad_mask_rgb_u8(img_u8, pad_val=114, tol=3):
    """
    Flag pixels whose first three channels all lie within `tol` of `pad_val`.

    Returns a boolean array over the leading (non-channel) dimensions.
    """
    # widen to int16 first so the subtraction cannot wrap around in uint8
    close = np.abs(img_u8.astype(np.int16) - pad_val) <= tol
    ch0, ch1, ch2 = close[..., 0], close[..., 1], close[..., 2]
    return ch0 & ch1 & ch2

# building CBAM det+seg+dn model
model = YOLOv8DetSemSegDn(yolo_weights=det_base_pt, use_cbam=True).to(device)

# one dummy forward pass before the checkpoint is loaded
# NOTE(review): presumably this builds lazily-created layers so that the strict
# load below finds every parameter -- confirm against YOLOv8DetSemSegDn
with torch.no_grad():
    _ = model(torch.zeros(1,3,IMG_SIZE,IMG_SIZE, device=device))


# the checkpoint is either a raw state_dict or a dict holding it under "state_dict"
ckpt = torch.load(cbam_best_pt, map_location="cpu")
sd = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
model.load_state_dict(sd, strict=True)


model.eval()


# a YOLO "shell" just to host the loaded weights (any yolov8n.pt / your det baseline pt works)
det_shell = det_base_pt
det_eval = YOLO(det_shell)

# class names (your 5 classes)
det_eval.model.names = {0:"car", 1:"truck", 2:"bus", 3:"person", 4:"bike"}

# load YOLO submodule weights from your tri-task checkpoint (sd)
# Only the LEADING "yolo." is stripped; str.replace would also rewrite any
# later "yolo." that happens to appear inside a key.
yolo_prefix = "yolo."
yolo_sd = {k[len(yolo_prefix):]: v for k, v in sd.items() if k.startswith(yolo_prefix)}
load_report = det_eval.model.load_state_dict(yolo_sd, strict=False)

# strict=False never raises on mismatches, so report them: if the transfer
# failed, det_eval would silently keep the weights it was created with.
if load_report.missing_keys or load_report.unexpected_keys:
    print(
        "WARNING: partial weight transfer into det_eval:",
        len(load_report.missing_keys), "missing,",
        len(load_report.unexpected_keys), "unexpected keys;",
        "first missing:", list(load_report.missing_keys)[:5],
        "first unexpected:", list(load_report.unexpected_keys)[:5],
    )
det_eval.fuse()


# hook all 3 cbam_neck spatial conv outputs
# _att is a shared buffer: every forward pass of `model` appends one tensor per
# hooked conv, so the list has to be cleared before a pass whose outputs are wanted.
_att = {"logits_list": []}
def _hook_sa_logits(mod, inp, out):
    """Forward hook: store the detached output of the hooked module in _att."""
    _att["logits_list"].append(out.detach())

# keep the handles so the hooks can be removed once the evaluation is done
handles = []
for i in range(3):
    handles.append(model.cbam_neck[i].sa.conv.register_forward_hook(_hook_sa_logits))

# choose indices
# a fixed seed makes the sampled subset of val_ds the same on every run
SEED_POOL = 42
rng_pool = np.random.default_rng(SEED_POOL)
idxs = rng_pool.choice(len(val_ds), size=min(1000, len(val_ds)), replace=False).tolist()


# per-image metric drops (clean - perturbed); one entry per image that was not skipped
drops_top, drops_rand = [], []
ap_drops_top  = []
ap_drops_rand = []


with torch.no_grad():
    for idx in idxs:
        sample = val_ds[idx]
        img_t   = sample[0]          # [3,640,640]
        gt_mask = sample[2]          # <-- if your gt mask is not at [2], change THIS ONE LINE

        # keep only the first channel of the mask and move it to numpy
        if torch.is_tensor(gt_mask):
            gt_mask = gt_mask[0].cpu().numpy().astype(np.uint8)

        # CHW float image -> HWC uint8 image (astype truncates, it does not round)
        img_np = (img_t.permute(1,2,0).cpu().numpy() * 255).astype(np.uint8)
        H, W = img_np.shape[:2]

        gt_boxes = get_gt_boxes_from_sample(sample, H, W)  # list of (cls,x1,y1,x2,y2)


        # pixels within +-3 of gray 114 in all three channels are treated as padding
        pad_mask = get_pad_mask_rgb_u8(img_np, pad_val=114, tol=3)
        nonpad_mask = ~pad_mask

        x = img_t.unsqueeze(0).to(device)

        # clean forward (collect attention)
        # the buffer is cleared first because the perturbed forwards of the
        # previous iteration also went through the hooks
        _att["logits_list"].clear()
        det_out, seg_logits, tag_logits = model(x)

        # detection is scored through the separate det_eval wrapper, on the uint8 image
        pred_det_clean = det_predict_boxes(det_eval, img_np, conf=0.25, iou=0.6)
        ap_clean = ap50_single_image(pred_det_clean, gt_boxes, iou_thr=0.50)



        # fuse attention (sigmoid each scale, upsample, weighted avg)
        # each scale is weighted by its number of cells (h*w), so finer maps count more
        fused = torch.zeros((H, W), device=device, dtype=torch.float32)
        wsum = 0.0
        for logits in _att["logits_list"]:
            sa = torch.sigmoid(logits)[0,0]  # [h,w]
            h, w = int(sa.shape[-2]), int(sa.shape[-1])
            sa_up = F.interpolate(sa[None,None], size=(H,W), mode="bilinear", align_corners=False)[0,0]
            weight = float(h*w)
            fused += weight * sa_up
            wsum  += weight
        fused = fused / (wsum + 1e-6)
        fused = fused.masked_fill(torch.from_numpy(pad_mask).to(device), 0.0)
        att01 = fused.detach().cpu().numpy().astype(np.float32)

        # rescale to [0,1] using the 5%/95% quantiles of the non-pad values
        vals = att01[nonpad_mask]
        lo = np.quantile(vals, 0.05)
        hi = np.quantile(vals, 0.95)
        att01 = np.clip((att01 - lo) / (hi - lo + 1e-6), 0.0, 1.0).astype(np.float32)
        att01[pad_mask] = 0.0


        # segmentation clean
        pred_clean = (torch.sigmoid(seg_logits)[0,0] > 0.5).detach().cpu().numpy().astype(np.uint8)
        iou_clean = iou_bin(pred_clean, gt_mask)


        # masks (fixed-size patches)
        top_pix = make_top_mask(att01, nonpad_mask, TOP_P)   # TOP_P is fine (0.30)

        # the "top" patch is centred on the centroid of the top-attended pixels
        ys, xs = np.where(top_pix)
        if len(ys) == 0:
            continue
        cy = int(np.mean(ys))
        cx = int(np.mean(xs))

        top_mask = patch_mask_from_center(nonpad_mask, cy, cx, PATCH_H, PATCH_W)


        # control: a random patch of the same nominal size; skip the image if none is found
        rand_mask = random_patch_mask(nonpad_mask, PATCH_H, PATCH_W)
        if rand_mask is None:
            continue


        # perturbed inputs
        x_top,  img_top_np  = apply_mask_to_x_from_img(x, img_np, top_mask)
        x_rand, img_rand_np = apply_mask_to_x_from_img(x, img_np, rand_mask)



        # forwards perturbed
        _, seg_top,  _ = model(x_top)
        _, seg_rand, _ = model(x_rand)

        pred_det_top  = det_predict_boxes(det_eval, img_top_np,  conf=0.25, iou=0.6)
        pred_det_rand = det_predict_boxes(det_eval, img_rand_np, conf=0.25, iou=0.6)

        ap_top  = ap50_single_image(pred_det_top,  gt_boxes, iou_thr=0.50)
        ap_rand = ap50_single_image(pred_det_rand, gt_boxes, iou_thr=0.50)

        # positive drop = the perturbation hurt detection
        ap_drops_top.append(ap_clean - ap_top)
        ap_drops_rand.append(ap_clean - ap_rand)



        pred_top  = (torch.sigmoid(seg_top)[0,0] > 0.5).detach().cpu().numpy().astype(np.uint8)
        pred_rand = (torch.sigmoid(seg_rand)[0,0] > 0.5).detach().cpu().numpy().astype(np.uint8)

        iou_top  = iou_bin(pred_top, gt_mask)
        iou_rand = iou_bin(pred_rand, gt_mask)

        # positive drop = the perturbation hurt segmentation
        drops_top.append(iou_clean - iou_top)
        drops_rand.append(iou_clean - iou_rand)

# cleanup hooks
for h in handles:
    h.remove()

# summary over all images that were not skipped (None when no image was used)
print("Images used:", len(drops_top))
print("Mean IoU drop (top-att):", float(np.mean(drops_top)) if drops_top else None)
print("Mean IoU drop (random):",  float(np.mean(drops_rand)) if drops_rand else None)
print("Mean AP50 drop (top-att):",  float(np.mean(ap_drops_top)) if ap_drops_top else None)
print("Mean AP50 drop (random):",   float(np.mean(ap_drops_rand)) if ap_drops_rand else None)


Model summary (fused): 73 layers, 3,006,623 parameters, 0 gradients, 8.1 GFLOPs
Images used: 1000
Mean IoU drop (top-att): 0.16929898551437883
Mean IoU drop (random): 0.09067801369320716
Mean AP50 drop (top-att): 0.019582364522153513
Mean AP50 drop (random): 0.009320265210873913


In [None]:
# PERTURBATION TEST (DET+SEG+DN + CBAM)
# mask top-attended vs random (same area) and compare IoU drop

from ultralytics import YOLO
import os, random
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image

# run on the GPU when one is available, otherwise on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# paths
# checkpoint of the CBAM det+seg+dn model (loaded into `model` further down)
cbam_best_pt = "/content/drive/MyDrive/XAI_Project/experiments/det_seg_dn_cbam_640/weights/best.pt"
# detection baseline weights; used as yolo_weights and as the YOLO "shell" further down
det_base_pt  = "/content/drive/MyDrive/XAI_Project/experiments/det_baseline/weights/best.pt"

# settings

IMG_SIZE = 640          # side length of the dummy input used for the first forward pass
TOP_P    = 0.30         # fraction passed to make_top_mask as top_p
FILL01   = 114/255.0    # gray level 114 expressed on a 0..1 scale
N        = 1000  # how many images to test
PATCH_H = 128           # occlusion patch height in pixels
PATCH_W = 128           # occlusion patch width in pixels
SEED_MASK = 1234        # seed of `rng`
DN_INDEX = 3            # default position of the day/night label in a sample (see get_dn_gt)
rng = np.random.default_rng(SEED_MASK)  # generator used by make_rand_mask / random_patch_mask


# helpers
def make_top_mask(att01, nonpad_mask, top_p=0.15):
    """
    Boolean mask of the most-attended pixels.

    The cut-off is the (1 - top_p) quantile of the attention values taken at
    nonpad_mask pixels only; pixels outside nonpad_mask are never selected.
    """
    cutoff = np.quantile(att01[nonpad_mask], 1.0 - top_p)
    above = att01 >= cutoff
    return above & nonpad_mask

def make_rand_mask(nonpad_mask, k):
    """
    Mark k distinct pixels of nonpad_mask, drawn without replacement with the
    module-level rng, and return them as a boolean mask of the same shape.
    """
    rows, cols = np.nonzero(nonpad_mask)
    chosen = rng.choice(len(rows), size=k, replace=False)

    out = np.zeros_like(nonpad_mask, dtype=bool)
    out[rows[chosen], cols[chosen]] = True
    return out

def apply_mask_to_x_from_img(x, img_np_u8, mask_hw):
    """
    Gray out (value 114) the masked pixels of an HWC uint8 image.

    Returns (tensor, image): the perturbed image as a [1,C,H,W] float tensor
    scaled by 1/255 on x's device, and the perturbed uint8 image itself.
    The input image is not modified; x is only used for its device.
    """
    perturbed = img_np_u8.copy()
    perturbed[mask_hw] = 114

    chw = torch.from_numpy(perturbed).permute(2, 0, 1)
    batch = (chw.float() / 255.0).unsqueeze(0)
    return batch.to(x.device), perturbed   # <- return BOTH



def iou_bin(pred01, gt01):
    """IoU of two masks after casting both to bool (non-zero = foreground)."""
    pred_fg = pred01.astype(bool)
    gt_fg = gt01.astype(bool)

    overlap = np.logical_and(pred_fg, gt_fg).sum()
    combined = np.logical_or(pred_fg, gt_fg).sum()

    # the epsilon makes two empty masks score 0.0 instead of dividing by zero
    return float(overlap) / float(combined + 1e-6)

def bbox_from_mask(m):
    """
    Tight bounding box of the non-zero pixels of m.

    Returns (y1, x1, y2, x2) with exclusive end coordinates, or None when the
    mask has no non-zero pixel.
    """
    rows, cols = np.nonzero(m)
    if len(rows) == 0:
        return None

    top, bottom = rows.min(), rows.max() + 1
    left, right = cols.min(), cols.max() + 1
    return top, left, bottom, right

def random_bbox(valid_mask, h, w, tries=200):
    """
    Sample an h x w box that lies mostly on valid pixels.

    Returns (y1, x1, y2, x2) for the first candidate whose valid fraction is
    above 0.9, or None after `tries` failed candidates.
    NOTE: positions are drawn from NumPy's global RNG (np.random), not from
    the module-level `rng`.
    """
    img_h, img_w = valid_mask.shape
    top_limit = img_h - h + 1
    left_limit = img_w - w + 1

    for _attempt in range(tries):
        top = np.random.randint(0, top_limit)
        left = np.random.randint(0, left_limit)
        # require most of the patch to be valid (not padding)
        if valid_mask[top:top + h, left:left + w].mean() > 0.9:
            return top, left, top + h, left + w
    return None

def clamp(v, lo, hi):
    """Limit v to the range [lo, hi]; when lo > hi the result is lo."""
    capped = v if v < hi else hi          # upper bound first
    return capped if capped > lo else lo  # then the lower bound

def patch_mask_from_center(nonpad_mask, cy, cx, ph, pw):
    """
    Build a ph x pw patch mask positioned around the centre (cy, cx).

    The top-left corner is clamped to [0, H - ph] / [0, W - pw], and the
    rectangle is finally intersected with nonpad_mask.
    """
    img_h, img_w = nonpad_mask.shape

    # top-left corner: centre minus half the patch, limited by clamp()
    top = clamp(cy - ph // 2, 0, img_h - ph)
    left = clamp(cx - pw // 2, 0, img_w - pw)

    patch = np.zeros_like(nonpad_mask, dtype=bool)
    patch[top:top + ph, left:left + pw] = True
    return patch & nonpad_mask

def random_patch_mask(nonpad_mask, ph, pw, tries=200):
    """
    Draw a random ph x pw patch (positions come from the module-level rng).

    A candidate is accepted when more than 95% of its pixels are set in
    nonpad_mask. Returns the boolean patch mask (the full rectangle, not
    intersected with nonpad_mask), or None if no candidate is accepted
    within `tries` attempts.
    """
    img_h, img_w = nonpad_mask.shape
    for _attempt in range(tries):
        top = int(rng.integers(0, img_h - ph + 1))
        left = int(rng.integers(0, img_w - pw + 1))
        window = (slice(top, top + ph), slice(left, left + pw))
        if nonpad_mask[window].mean() > 0.95:
            out = np.zeros_like(nonpad_mask, dtype=bool)
            out[window] = True
            return out
    return None

def yolo_xywhn_to_xyxy_px(labels_xywhn, H, W):
    """
    Convert YOLO-normalized labels to pixel-space corner boxes.

    labels_xywhn: tensor / ndarray / nested sequence of rows
                  [cls, x, y, w, h] (box centre + size, normalized).
                  A single flat row of 5 values is accepted as well.
    H, W:         image height / width in pixels used to scale the boxes.
    returns:      list of (cls, x1,y1,x2,y2) with int cls and float pixels.
                  None, empty input, or rows with fewer than 5 values give [].
    """
    if labels_xywhn is None:
        return []

    # Tensors must be converted BEFORE touching numpy: np.array() cannot read
    # CUDA or grad-tracking tensors, so detach and move to the CPU first.
    if torch.is_tensor(labels_xywhn):
        labels_xywhn = labels_xywhn.detach().cpu().double().numpy()

    arr = np.asarray(labels_xywhn, dtype=np.float64)
    if arr.ndim == 0 or arr.size == 0:
        return []

    # Same 2-D layout for every input type, so a flat [cls, x, y, w, h] row
    # (tensor or list) is treated as one box instead of five scalars.
    arr = arr.reshape(-1, arr.shape[-1])
    if arr.shape[1] < 5:
        return []

    cls_ids, xc, yc, bw, bh = (arr[:, k] for k in range(5))
    x1 = (xc - bw / 2) * W
    y1 = (yc - bh / 2) * H
    x2 = (xc + bw / 2) * W
    y2 = (yc + bh / 2) * H

    return [
        (int(c), float(a), float(b), float(d), float(e))
        for c, a, b, d, e in zip(cls_ids, x1, y1, x2, y2)
    ]


def get_gt_boxes_from_sample(sample, H, W):
    """
    Extract GT detection boxes from a val_ds sample, in pixel coordinates.

    Expected layout:
      - sample[1] is an Nx5 tensor/array/list/tuple: [cls, x, y, w, h]
        normalized (YOLO labels)

    Returns a list of (cls, x1,y1,x2,y2). Returns [] when the sample has no
    label slot, when the labels have an unsupported type, or when conversion
    fails (a RuntimeWarning is emitted in that last case).
    """
    import warnings

    if len(sample) < 2:
        return []
    gt = sample[1]

    if not (torch.is_tensor(gt) or isinstance(gt, (np.ndarray, list, tuple))):
        return []

    try:
        return yolo_xywhn_to_xyxy_px(gt, H, W)
    except Exception as err:
        # Stay best-effort (one malformed label must not abort the run), but
        # do not hide the problem: an empty GT list silently changes the AP
        # of this image. KeyboardInterrupt / SystemExit are not caught here.
        warnings.warn(
            f"Could not parse GT boxes ({type(err).__name__}: {err}); "
            "treating the image as having no boxes.",
            RuntimeWarning,
        )
        return []


def iou_xyxy(a, b):
    """Intersection-over-union of two boxes given as (x1,y1,x2,y2)."""
    ax1, ay1, ax2, ay2 = a
    bx1, by1, bx2, by2 = b

    # overlap rectangle; a negative extent means the boxes do not overlap
    overlap_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = overlap_w * overlap_h

    area_a = max(0.0, ax2-ax1) * max(0.0, ay2-ay1)
    area_b = max(0.0, bx2-bx1) * max(0.0, by2-by1)

    # the small epsilon keeps the division defined for two empty boxes
    union = area_a + area_b - inter + 1e-6
    return inter / union


def ap50_single_image(preds, gts, iou_thr=0.50):
    """
    preds: list of (cls, conf, x1,y1,x2,y2)
    gts:   list of (cls, x1,y1,x2,y2)

    Returns AP@iou_thr computed on THIS image, across all classes (micro),
    using 11-point interpolation.
    This is "mAP-like" (not dataset mAP), good for perturbation comparison.

    Matching is greedy: predictions are visited by descending confidence and
    each one claims the still-unused GT box of the same class with the highest
    IoU; it is a true positive when that IoU is >= iou_thr.
    An image without GT boxes scores 0.0.
    """
    n_gt = len(gts)
    if n_gt == 0:
        return 0.0

    # sort predictions by confidence desc
    preds = sorted(preds, key=lambda t: t[1], reverse=True)

    gt_used = [False] * n_gt
    is_tp = np.zeros(len(preds), dtype=np.int64)

    for i, (pc, _conf, px1, py1, px2, py2) in enumerate(preds):
        best_j = -1
        best_iou = 0.0
        for j, (gc, gx1, gy1, gx2, gy2) in enumerate(gts):
            if gt_used[j] or pc != gc:
                continue
            iou = iou_xyxy((px1, py1, px2, py2), (gx1, gy1, gx2, gy2))
            if iou > best_iou:
                best_iou = iou
                best_j = j

        if best_j >= 0 and best_iou >= iou_thr:
            is_tp[i] = 1
            gt_used[best_j] = True

    cum_tp = np.cumsum(is_tp)
    # Every prediction is either TP or FP, so the k-th denominator is simply k.
    # No epsilon is needed (n_gt > 0 and k >= 1); an epsilon here would keep
    # recall/precision strictly below 1.0 and cap a perfect image at 10/11.
    precision = cum_tp / np.arange(1, len(preds) + 1)

    # AP by 11-point interpolation. "recall >= k/10" is evaluated with integers
    # (cum_tp * 10 >= k * n_gt) so that thresholds such as 0.3 are hit exactly.
    total = 0.0
    for k in range(11):
        reached = cum_tp * 10 >= k * n_gt
        if reached.any():
            total += float(precision[reached].max())
    return float(total / 11.0)


def det_predict_boxes(det_eval, img_np_u8, conf=0.25, iou=0.6):
    """
    Run det_eval.predict on a single image and flatten the first result.

    Returns list of (cls, conf, x1,y1,x2,y2) in pixel coords; an empty list
    when the result carries no boxes.
    """
    results = det_eval.predict(
        source=img_np_u8,
        imgsz=img_np_u8.shape[0],   # inference size is taken from the image height
        conf=conf,
        iou=iou,
        verbose=False,
    )
    res = results[0]
    if res.boxes is None or len(res.boxes) == 0:
        return []

    corners = res.boxes.xyxy.detach().cpu().numpy()
    labels = res.boxes.cls.detach().cpu().numpy().astype(int)
    scores = res.boxes.conf.detach().cpu().numpy()

    return [
        (int(label), float(score), float(x1), float(y1), float(x2), float(y2))
        for (x1, y1, x2, y2), label, score in zip(corners, labels, scores)
    ]


def get_dn_gt(sample, dn_index=DN_INDEX):
    """
    Read the day/night label stored at sample[dn_index] and binarize it.

    Tensors contribute their first element; anything else is converted with
    float(). Returns 1 when the value is >= 0.5, otherwise 0.
    """
    raw = sample[dn_index]
    if torch.is_tensor(raw):
        # flatten so that scalar and 1-element tensors are handled alike
        value = raw.detach().cpu().view(-1)[0].item()
    else:
        value = float(raw)
    return int(value >= 0.5)

def tag_prob_night_and_pred(tag_logits):
    """
    Turn the day/night head output of the FIRST batch item into
    (p_night, pred), where p_night is a float in [0,1] and pred is
    0 (day) or 1 (night).

    Two head layouts are handled:
      - [B,2] logits: softmax over the two classes, pred = argmax
      - anything else (e.g. [B,1] or [B]): the first value is a sigmoid
        logit, pred = 1 when p_night >= 0.5
    """
    logits = tag_logits.detach() if torch.is_tensor(tag_logits) else tag_logits
    logits = logits.view(logits.shape[0], -1)  # [B, K]

    if logits.shape[1] != 2:
        # single sigmoid logit
        p_night = torch.sigmoid(logits[0, 0]).item()
        return p_night, int(p_night >= 0.5)

    # two-class logits
    probs = torch.softmax(logits, dim=1)      # [B,2]
    p_night = probs[0, 1].item()
    pred = int(torch.argmax(probs[0]).item())
    return p_night, pred

def prob_of_correct_class(p_night, gt_dn):
    """Probability the model assigned to the true class (gt_dn: 0=day, 1=night)."""
    if gt_dn == 1:
        return p_night
    return 1.0 - p_night




def get_pad_mask_rgb_u8(img_u8, pad_val=114, tol=3):
    """
    Flag pixels whose first three channels all lie within `tol` of `pad_val`.

    Returns a boolean array over the leading (non-channel) dimensions.
    """
    # widen to int16 first so the subtraction cannot wrap around in uint8
    close = np.abs(img_u8.astype(np.int16) - pad_val) <= tol
    ch0, ch1, ch2 = close[..., 0], close[..., 1], close[..., 2]
    return ch0 & ch1 & ch2

# building CBAM det+seg+dn model
model = YOLOv8DetSemSegDn(yolo_weights=det_base_pt, use_cbam=True).to(device)

# one dummy forward pass before the checkpoint is loaded
# NOTE(review): presumably this builds lazily-created layers so that the strict
# load below finds every parameter -- confirm against YOLOv8DetSemSegDn
with torch.no_grad():
    _ = model(torch.zeros(1,3,IMG_SIZE,IMG_SIZE, device=device))


# the checkpoint is either a raw state_dict or a dict holding it under "state_dict"
ckpt = torch.load(cbam_best_pt, map_location="cpu")
sd = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
model.load_state_dict(sd, strict=True)


model.eval()


# a YOLO "shell" just to host the loaded weights (any yolov8n.pt / your det baseline pt works)
det_shell = det_base_pt
det_eval = YOLO(det_shell)

# class names (your 5 classes)
det_eval.model.names = {0:"car", 1:"truck", 2:"bus", 3:"person", 4:"bike"}

# load YOLO submodule weights from your tri-task checkpoint (sd)
# Only the LEADING "yolo." is stripped; str.replace would also rewrite any
# later "yolo." that happens to appear inside a key.
yolo_prefix = "yolo."
yolo_sd = {k[len(yolo_prefix):]: v for k, v in sd.items() if k.startswith(yolo_prefix)}
load_report = det_eval.model.load_state_dict(yolo_sd, strict=False)

# strict=False never raises on mismatches, so report them: if the transfer
# failed, det_eval would silently keep the weights it was created with.
if load_report.missing_keys or load_report.unexpected_keys:
    print(
        "WARNING: partial weight transfer into det_eval:",
        len(load_report.missing_keys), "missing,",
        len(load_report.unexpected_keys), "unexpected keys;",
        "first missing:", list(load_report.missing_keys)[:5],
        "first unexpected:", list(load_report.unexpected_keys)[:5],
    )
det_eval.fuse()


# hook all 3 cbam_neck spatial conv outputs
# _att is a shared buffer: every forward pass of `model` appends one tensor per
# hooked conv, so the list has to be cleared before a pass whose outputs are wanted.
_att = {"logits_list": []}
def _hook_sa_logits(mod, inp, out):
    """Forward hook: store the detached output of the hooked module in _att."""
    _att["logits_list"].append(out.detach())

# keep the handles so the hooks can be removed once the evaluation is done
handles = []
for i in range(3):
    handles.append(model.cbam_neck[i].sa.conv.register_forward_hook(_hook_sa_logits))

# choose indices
# a fixed seed makes the sampled subset of val_ds the same on every run;
# the pool size comes from the N setting instead of a second hard-coded 1000
SEED_POOL = 42
rng_pool = np.random.default_rng(SEED_POOL)
idxs = rng_pool.choice(len(val_ds), size=min(N, len(val_ds)), replace=False).tolist()


# per-image metric drops (clean - perturbed); one entry per image that was not skipped
drops_top, drops_rand = [], []
ap_drops_top  = []
ap_drops_rand = []

# tagging accumulators
# *_ok_*  : 1 when the day/night prediction was correct, else 0
# *_pok_* : probability assigned to the correct day/night class
# All six lists are filled together, so they always describe the same images.
tag_ok_clean = []
tag_ok_top   = []
tag_ok_rand  = []

tag_pok_clean = []
tag_pok_top   = []
tag_pok_rand  = []



try:
    with torch.no_grad():
        for idx in idxs:
            sample = val_ds[idx]
            img_t   = sample[0]          # [3,640,640]
            gt_mask = sample[2]          # <-- if your gt mask is not at [2], change THIS ONE LINE
            gt_dn = get_dn_gt(sample)   # 0=day, 1=night

            # keep only the first channel of the mask and move it to numpy
            if torch.is_tensor(gt_mask):
                gt_mask = gt_mask[0].cpu().numpy().astype(np.uint8)

            # CHW float image -> HWC uint8 image (astype truncates, it does not round)
            img_np = (img_t.permute(1,2,0).cpu().numpy() * 255).astype(np.uint8)
            H, W = img_np.shape[:2]

            gt_boxes = get_gt_boxes_from_sample(sample, H, W)  # list of (cls,x1,y1,x2,y2)

            # pixels within +-3 of gray 114 in all three channels are treated as padding
            pad_mask = get_pad_mask_rgb_u8(img_np, pad_val=114, tol=3)
            nonpad_mask = ~pad_mask

            x = img_t.unsqueeze(0).to(device)

            # clean forward (collect attention)
            # the buffer is cleared first because the perturbed forwards of the
            # previous iteration also went through the hooks
            _att["logits_list"].clear()
            det_out, seg_logits, tag_logits = model(x)

            # detection is scored through the separate det_eval wrapper, on the uint8 image
            pred_det_clean = det_predict_boxes(det_eval, img_np, conf=0.25, iou=0.6)
            ap_clean = ap50_single_image(pred_det_clean, gt_boxes, iou_thr=0.50)

            # clean tagging result; it is stored further down, once it is known
            # that this image is not skipped
            p_clean, pred_clean_dn = tag_prob_night_and_pred(tag_logits)

            # fuse attention (sigmoid each scale, upsample, weighted avg)
            # each scale is weighted by its number of cells (h*w), so finer maps count more
            fused = torch.zeros((H, W), device=device, dtype=torch.float32)
            wsum = 0.0
            for logits in _att["logits_list"]:
                sa = torch.sigmoid(logits)[0,0]  # [h,w]
                h, w = int(sa.shape[-2]), int(sa.shape[-1])
                sa_up = F.interpolate(sa[None,None], size=(H,W), mode="bilinear", align_corners=False)[0,0]
                weight = float(h*w)
                fused += weight * sa_up
                wsum  += weight
            fused = fused / (wsum + 1e-6)
            fused = fused.masked_fill(torch.from_numpy(pad_mask).to(device), 0.0)
            att01 = fused.detach().cpu().numpy().astype(np.float32)

            # rescale to [0,1] using the 5%/95% quantiles of the non-pad values
            vals = att01[nonpad_mask]
            if vals.size == 0:
                # image detected as padding everywhere: no quantiles, nothing to perturb
                continue
            lo = np.quantile(vals, 0.05)
            hi = np.quantile(vals, 0.95)
            att01 = np.clip((att01 - lo) / (hi - lo + 1e-6), 0.0, 1.0).astype(np.float32)
            att01[pad_mask] = 0.0

            # segmentation clean
            pred_clean = (torch.sigmoid(seg_logits)[0,0] > 0.5).detach().cpu().numpy().astype(np.uint8)
            iou_clean = iou_bin(pred_clean, gt_mask)

            # masks (fixed-size patches)
            top_pix = make_top_mask(att01, nonpad_mask, TOP_P)   # TOP_P is fine (0.30)

            # the "top" patch is centred on the centroid of the top-attended pixels
            ys, xs = np.where(top_pix)
            if len(ys) == 0:
                continue
            cy = int(np.mean(ys))
            cx = int(np.mean(xs))

            top_mask = patch_mask_from_center(nonpad_mask, cy, cx, PATCH_H, PATCH_W)

            # control: a random patch of the same nominal size; skip the image if none is found
            rand_mask = random_patch_mask(nonpad_mask, PATCH_H, PATCH_W)
            if rand_mask is None:
                continue

            # From here on the image contributes to EVERY list. Storing the clean
            # tagging stats only now keeps clean / top / random aligned; appending
            # them before the skips above would let the clean accuracy cover
            # images that the perturbed accuracies never see.
            tag_ok_clean.append(int(pred_clean_dn == gt_dn))
            tag_pok_clean.append(prob_of_correct_class(p_clean, gt_dn))

            # perturbed inputs
            x_top,  img_top_np  = apply_mask_to_x_from_img(x, img_np, top_mask)
            x_rand, img_rand_np = apply_mask_to_x_from_img(x, img_np, rand_mask)

            # forwards perturbed
            _, seg_top,  tag_top  = model(x_top)
            _, seg_rand, tag_rand = model(x_rand)

            p_top,  pred_top_dn  = tag_prob_night_and_pred(tag_top)
            p_rand, pred_rand_dn = tag_prob_night_and_pred(tag_rand)

            tag_ok_top.append(int(pred_top_dn == gt_dn))
            tag_ok_rand.append(int(pred_rand_dn == gt_dn))

            tag_pok_top.append(prob_of_correct_class(p_top, gt_dn))
            tag_pok_rand.append(prob_of_correct_class(p_rand, gt_dn))

            pred_det_top  = det_predict_boxes(det_eval, img_top_np,  conf=0.25, iou=0.6)
            pred_det_rand = det_predict_boxes(det_eval, img_rand_np, conf=0.25, iou=0.6)

            ap_top  = ap50_single_image(pred_det_top,  gt_boxes, iou_thr=0.50)
            ap_rand = ap50_single_image(pred_det_rand, gt_boxes, iou_thr=0.50)

            # positive drop = the perturbation hurt detection
            ap_drops_top.append(ap_clean - ap_top)
            ap_drops_rand.append(ap_clean - ap_rand)

            pred_top  = (torch.sigmoid(seg_top)[0,0] > 0.5).detach().cpu().numpy().astype(np.uint8)
            pred_rand = (torch.sigmoid(seg_rand)[0,0] > 0.5).detach().cpu().numpy().astype(np.uint8)

            iou_top  = iou_bin(pred_top, gt_mask)
            iou_rand = iou_bin(pred_rand, gt_mask)

            # positive drop = the perturbation hurt segmentation
            drops_top.append(iou_clean - iou_top)
            drops_rand.append(iou_clean - iou_rand)
finally:
    # cleanup hooks -- also when the loop fails, so no stale hook keeps
    # appending tensors to _att on later forward passes of `model`
    for h in handles:
        h.remove()

# summary over all images that were not skipped (None when no image was used)
print("Images used:", len(drops_top))
print("Mean IoU drop (top-att):", float(np.mean(drops_top)) if drops_top else None)
print("Mean IoU drop (random):",  float(np.mean(drops_rand)) if drops_rand else None)
print("Mean AP50 drop (top-att):",  float(np.mean(ap_drops_top)) if ap_drops_top else None)
print("Mean AP50 drop (random):",   float(np.mean(ap_drops_rand)) if ap_drops_rand else None)

def mean(x):
    """Arithmetic mean of x as a Python float, or None when x is empty."""
    if not len(x):
        return None
    return float(np.mean(x))

def _drop(clean, perturbed):
    """Return clean - perturbed, or None when either mean is unavailable."""
    if clean is None or perturbed is None:
        return None
    return clean - perturbed


# day/night tagging accuracy: fraction of images whose prediction was correct
acc_clean = mean(tag_ok_clean)
acc_top   = mean(tag_ok_top)
acc_rand  = mean(tag_ok_rand)

# the drop is only computed when BOTH means exist; checking the clean mean
# alone would raise a TypeError whenever a perturbed list is empty
print("Tag Acc clean:", acc_clean)
print("Tag Acc top:  ", acc_top,  "  drop:", _drop(acc_clean, acc_top))
print("Tag Acc rand: ", acc_rand, "  drop:", _drop(acc_clean, acc_rand))

# mean probability assigned to the correct day/night class
pok_clean = mean(tag_pok_clean)
pok_top   = mean(tag_pok_top)
pok_rand  = mean(tag_pok_rand)

print("Tag P(correct) clean:", pok_clean)
print("Tag P(correct) top:  ", pok_top,  "  drop:", _drop(pok_clean, pok_top))
print("Tag P(correct) rand: ", pok_rand, "  drop:", _drop(pok_clean, pok_rand))



Model summary (fused): 73 layers, 3,006,623 parameters, 0 gradients, 8.1 GFLOPs
Images used: 1000
Mean IoU drop (top-att): 0.16929898551437883
Mean IoU drop (random): 0.09067801369320716
Mean AP50 drop (top-att): 0.019582364522153513
Mean AP50 drop (random): 0.009320265210873913
Tag Acc clean: 0.908
Tag Acc top:   0.87   drop: 0.038000000000000034
Tag Acc rand:  0.876   drop: 0.03200000000000003
Tag P(correct) clean: 0.7318826522640884
Tag P(correct) top:   0.7159924722909927   drop: 0.01589017997309572
Tag P(correct) rand:  0.7138034129329026   drop: 0.018079239331185826
