<a href="https://colab.research.google.com/github/EliasNoorzad/XAI_Autonomous-Driving/blob/main/train/Det%2BSeg%2Btag__(with_attention).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Install and use Ultralytics YOLO version 8.4.6 (used throughout the project)

In [None]:
!pip -q install ultralytics==8.4.6


[?25l   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m0.0/1.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m[91m‚ï∏[0m [32m1.2/1.2 MB[0m [31m42.0 MB/s[0m eta [36m0:00:01[0m[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m1.2/1.2 MB[0m [31m27.6 MB/s[0m eta [36m0:00:00[0m
[?25h

Mount Google Drive in Colab

In [None]:
# Mount Google Drive under /content/drive so the project folder (MyDrive/XAI_Project) is readable.
# force_remount=True requests a fresh mount even if one is already active.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)


Mounted at /content/drive


Copy the dataset zip from Drive to the Colab workspace

In [None]:
# Copy the dataset archive from the mounted Drive folder into /content, where it is unzipped next.
!cp /content/drive/MyDrive/XAI_Project/BDD100K_640.zip /content/

Copy the day/night label file from Drive to the Colab workspace.

In [None]:
# Copy the day/night label CSV into /content; the dataset cells read it from /content/daynight_labels.csv.
!cp /content/drive/MyDrive/XAI_Project/daynight_labels.csv /content/

Unzip the dataset into /content/BDD100K_640

In [None]:
!unzip -q /content/BDD100K_640.zip -d /content/BDD100K_640

Create and save the YOLO dataset config file (dataset.yaml).

In [None]:
# Dataset config for the YOLO trainer; the trainer cell later passes this file as `data`.
yaml = "\n".join([
    "path: /content/BDD100K_640/yolo_640",
    "train: train/images",
    "val: val/images",
    "test: test/images",
    "",
    "nc: 5",
    "names: [car, truck, bus, person, bike]",
    "",  # keeps the trailing newline at the end of the file
])

with open("/content/BDD100K_640/yolo_640/dataset_640.yaml", "w") as f:
    f.write(yaml)

Extend the dataset to support the tri-task setting by loading day/night labels from a CSV, filtering out unlabeled images, and returning each sample as (image, YOLO labels, drivable mask, day/night class).

In [None]:
import csv
import numpy as np
from pathlib import Path
from PIL import Image
import torch
from torch.utils.data import Dataset


class BDDDetDrivableDataset(Dataset):
    """
    For preprocessed 640 dataset (yolo_640 + drivable_masks_640):
      images: <yolo_root>/<split>/images/<stem>.jpg
      labels: <yolo_root>/<split>/labels/<stem>.txt
      masks : <mask_root>/<split>/<stem>.png

    Returns:
      img   : FloatTensor [3, H, W] in [0,1]
      labels: FloatTensor [N, 5] where each row is [cls, x, y, w, h] (YOLO normalized)
      mask  : FloatTensor [1, H, W] with values 0/1
      dn    : LongTensor scalar (0=day, 1=night) - unlabeled images are filtered out
    """
    def __init__(
        self,
        yolo_root: str,
        mask_root: str,
        split: str,
        imgsz: int = 640,
        dn_csv_path: str | None = None,
    ):
        self.yolo_root = Path(yolo_root)
        self.mask_root = Path(mask_root)
        self.split = split
        self.imgsz = int(imgsz)

        self.img_dir = self.yolo_root / split / "images"
        self.lbl_dir = self.yolo_root / split / "labels"
        self.msk_dir = self.mask_root / split

        if not self.img_dir.is_dir():
            raise FileNotFoundError(f"Missing images dir: {self.img_dir}")
        if not self.lbl_dir.is_dir():
            raise FileNotFoundError(f"Missing labels dir: {self.lbl_dir}")
        if not self.msk_dir.is_dir():
            raise FileNotFoundError(f"Missing masks dir:  {self.msk_dir}")

        exts = {".jpg", ".jpeg", ".png"}
        self.img_paths = sorted([p for p in self.img_dir.iterdir() if p.suffix.lower() in exts])
        if len(self.img_paths) == 0:
            raise FileNotFoundError(f"No images found in: {self.img_dir}")

        # day/night mapping from CSV
        if dn_csv_path is None:
            raise RuntimeError("dn_csv_path is required for this dataset (day/night head training).")

        dn_csv_path = Path(dn_csv_path)
        if not dn_csv_path.exists():
            raise FileNotFoundError(f"Missing day/night CSV: {dn_csv_path}")

        dn_map = {}
        with open(dn_csv_path, "r", newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            required = {"split", "image_id", "label"}
            if not required.issubset(set(reader.fieldnames or [])):
                raise ValueError(f"dn CSV must have columns {required}, got {reader.fieldnames}")

            for row in reader:
                if row["split"] != self.split:
                    continue

                image_id = row["image_id"].strip()   # stem (no extension)
                lab = row["label"].strip().lower()

                if lab == "day":
                    dn = 0
                elif lab == "night":
                    dn = 1
                else:
                    # if CSV contains anything else, it's a data error
                    raise ValueError(f"Invalid dn label in CSV for {image_id}: {row['label']}")

                dn_map[image_id] = dn

        self.dn_map = dn_map

        # filtering out unlabeled images
        before = len(self.img_paths)
        self.img_paths = [p for p in self.img_paths if p.stem in self.dn_map]
        after = len(self.img_paths)
        if after == 0:
            raise RuntimeError(f"No labeled (day/night) images found for split='{self.split}'.")


    def __len__(self):
        return len(self.img_paths)

    @staticmethod
    def _read_yolo_labels(label_path: Path) -> torch.Tensor:
        if not label_path.exists():
            return torch.zeros((0, 5), dtype=torch.float32)

        rows = []
        with open(label_path, "r") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                parts = line.split()
                if len(parts) != 5:
                    continue
                cls, x, y, w, h = parts
                rows.append([float(cls), float(x), float(y), float(w), float(h)])

        if len(rows) == 0:
            return torch.zeros((0, 5), dtype=torch.float32)
        return torch.tensor(rows, dtype=torch.float32)

    @staticmethod
    def _pil_to_chw_float(img: Image.Image) -> torch.Tensor:
        arr = np.array(img, dtype=np.float32) / 255.0
        arr = np.transpose(arr, (2, 0, 1))
        return torch.from_numpy(arr)

    def __getitem__(self, idx: int):
        img_path = self.img_paths[idx]
        stem = img_path.stem

        label_path = self.lbl_dir / f"{stem}.txt"
        mask_path = self.msk_dir / f"{stem}.png"

        if not mask_path.exists():
            raise FileNotFoundError(f"Missing mask for {stem}: {mask_path}")

        img = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        labels = self._read_yolo_labels(label_path)

        if img.size != (self.imgsz, self.imgsz):
            img = img.resize((self.imgsz, self.imgsz), resample=Image.BILINEAR)
        if mask.size != (self.imgsz, self.imgsz):
            mask = mask.resize((self.imgsz, self.imgsz), resample=Image.NEAREST)

        img_t = self._pil_to_chw_float(img)
        mask_np = (np.array(mask, dtype=np.uint8) > 0).astype(np.float32)
        mask_t = torch.from_numpy(mask_np)[None, :, :]

        # dn is always valid because we filtered img_paths
        dn = torch.tensor(self.dn_map[stem], dtype=torch.long)

        return img_t, labels, mask_t, dn



Sanity-check the tri-task dataset by loading one training sample and printing tensor shapes, value ranges, and its day/night label.

In [None]:
# Build the training split and look at its first sample.
ds = BDDDetDrivableDataset(
    "/content/BDD100K_640/yolo_640",
    "/content/BDD100K_640/drivable_masks_640",
    "train",
    imgsz=640,
    dn_csv_path="/content/daynight_labels.csv",
)

img, labels, mask, dn = ds[0]

# Shapes first, then value ranges: the image should span [0, 1] and the mask be binary.
print(img.shape, labels.shape, mask.shape, dn)
print(img.min().item(), img.max().item(), mask.unique())


torch.Size([3, 640, 640]) torch.Size([2, 5]) torch.Size([1, 640, 640]) tensor(0)
0.0 1.0 tensor([0., 1.])


Sanity-check the day/night labels: confirm the training split size (64,828 samples) and verify that labels in the first 200 samples are valid ({0,1} only).

In [None]:
print("len(ds) =", len(ds))

# Collect the distinct day/night values over the first 200 samples only (fast).
vals = {int(ds[i][3]) for i in range(min(200, len(ds)))}

print("dn values in first batch =", vals)
assert vals.issubset({0, 1}), f"Found invalid dn values: {vals}"


len(ds) = 64828
dn values in first batch = {0, 1}


Define the CBAM attention block (channel attention followed by spatial attention) and store the most recent spatial attention map (`last_sa`) for later attention overlays.

In [None]:
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Channel gate of CBAM.

    Global average-pooled and max-pooled descriptors (B x C x 1 x 1) are passed through one
    shared bottleneck MLP built from 1x1 convolutions
    (channels -> max(channels // reduction, 1) -> channels). The two results are summed,
    squashed by a sigmoid and multiplied onto the input, so the output has the input's shape.
    """

    def __init__(self, channels: int, reduction: int = 16):
        super().__init__()
        bottleneck = max(1, channels // reduction)  # never shrink to zero channels

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # The same MLP instance is applied to both pooled descriptors (shared weights).
        squeeze = nn.Conv2d(channels, bottleneck, kernel_size=1, bias=False)
        excite = nn.Conv2d(bottleneck, channels, kernel_size=1, bias=False)
        self.mlp = nn.Sequential(squeeze, nn.ReLU(inplace=True), excite)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        from_avg = self.mlp(self.avg_pool(x))
        from_max = self.mlp(self.max_pool(x))
        gate = self.sigmoid(from_avg + from_max)  # B x C x 1 x 1, broadcast over H and W
        return x * gate


class SpatialAttention(nn.Module):
    """Spatial gate of CBAM.

    The per-pixel mean and max over the channel axis are stacked into a 2-channel map,
    convolved down to one channel (kernel 3 or 7, size-preserving padding) and squashed by a
    sigmoid. The resulting B x 1 x H x W map scales the input and is also kept, detached,
    in `last_sa` so it can be visualised after a forward pass.
    """

    def __init__(self, kernel_size: int = 7):
        super().__init__()
        assert kernel_size in (3, 7)

        # For an odd kernel, kernel_size // 2 keeps H and W unchanged.
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()
        self.last_sa = None  # filled by forward()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        stats = torch.cat((channel_mean, channel_max), dim=1)

        attn = self.sigmoid(self.conv(stats))  # B x 1 x H x W in [0, 1]
        self.last_sa = attn.detach()           # snapshot without autograd history
        return x * attn


class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention followed by spatial attention.

    Args:
        channels: number of channels of the feature map the block is applied to.
        reduction: bottleneck ratio forwarded to ChannelAttention.
        spatial_kernel: kernel size forwarded to SpatialAttention.

    The output has the same shape as the input. The spatial map of the latest forward pass
    is available as `self.sa.last_sa`.
    """

    def __init__(self, channels: int, reduction: int = 16, spatial_kernel: int = 7):
        super().__init__()
        self.ca = ChannelAttention(channels, reduction=reduction)
        self.sa = SpatialAttention(kernel_size=spatial_kernel)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Channel gate first, spatial gate on its result.
        return self.sa(self.ca(x))


Sanity-check CBAM on a random tensor to confirm the output shape matches the input and contains no NaNs ([2, 128, 80, 80]).

In [None]:
import torch


# Smoke test on a random B x C x H x W feature map: shape must be preserved, no NaNs.
x = torch.randn(2, 128, 80, 80)
m = CBAM(128, reduction=16, spatial_kernel=7)
y = m(x)

assert x.shape == y.shape, (x.shape, y.shape)
assert not y.isnan().any(), "NaNs found in CBAM output"
print("CBAM sanity OK:", y.shape)


CBAM sanity OK: torch.Size([2, 128, 80, 80])


Define the tri-task YOLOv8 model (detection + drivable-area segmentation + day/night classification) with optional CBAM: apply CBAM at the backbone–neck boundary and on all neck feature scales, use the highest-resolution neck feature for segmentation logits, and the lowest-resolution feature for the day/night head.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from ultralytics import YOLO


class YOLOv8DetSemSeg(nn.Module):
    """
    Ultralytics YOLOv8 detector wrapped with two extra heads and optional CBAM.

    The multi-scale features fed into the Detect head are captured with a forward pre-hook:
      * sem_head runs on the highest-resolution feature and is upsampled to the input size
        (one-channel segmentation logits),
      * dn_head runs on the lowest-resolution feature (two logits: [day, night]).

    With use_cbam=True, CBAM is applied to the output of the layer right before the first
    Upsample layer and to every feature entering Detect. The CBAM modules are created on the
    first forward pass, when the channel counts are known.

    forward(x):
      * x is a dict  -> (det_loss, det_items, seg_logits, dn_logits)
      * x is a tensor -> (det_preds, seg_logits, dn_logits)
    """

    def __init__(self, yolo_weights: str = "yolov8n.pt", use_cbam: bool = False):
        super().__init__()
        self.yolo = YOLO(yolo_weights).model  # the underlying nn.Module
        self.use_cbam = use_cbam

        # Lazily created by the hooks below.
        self.cbam_backbone = None  # Point 1: after the last layer before the first Upsample
        self.cbam_neck = None      # Point 2: one CBAM per feature map entering Detect

        # Features seen by the Detect head during the most recent forward pass.
        self._neck_feats = None

        seg_layers = (
            nn.LazyConv2d(64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 1, kernel_size=1),
        )
        self.sem_head = nn.Sequential(*seg_layers)

        dn_layers = (
            nn.AdaptiveAvgPool2d(1),  # [B,C,1,1]
            nn.Flatten(1),            # [B,C]
            nn.LazyLinear(2),         # [B,2] logits: [day, night]
        )
        self.dn_head = nn.Sequential(*dn_layers)

        self._register_backbone_hook_point1()
        self._register_detect_input_hook()

    def _register_backbone_hook_point1(self):
        """Attach the backbone CBAM hook to the layer preceding the first Upsample layer."""
        layers = self.yolo.model
        first_up = next(
            (
                i
                for i, layer in enumerate(layers)
                if isinstance(layer, nn.Upsample) or "upsample" in type(layer).__name__.lower()
            ),
            None,
        )
        if first_up in (None, 0):
            raise RuntimeError("Could not find neck start (Upsample) to place backbone CBAM.")

        backbone_last = layers[first_up - 1]

        # Drop a hook left over from an earlier registration (e.g. after self.yolo was swapped).
        stale = getattr(self, "_bb_hook_handle", None)
        if stale is not None:
            stale.remove()

        def fwd_hook(module, inputs, output):
            # Returning None leaves the layer output untouched.
            if not self.use_cbam:
                return None
            if self.cbam_backbone is None:
                self.cbam_backbone = CBAM(channels=output.shape[1]).to(output.device)
            return self.cbam_backbone(output)

        self._bb_hook_handle = backbone_last.register_forward_hook(fwd_hook)

    def _register_detect_input_hook(self):
        """Attach a pre-hook to the Detect head that records (and optionally gates) its inputs."""
        if not hasattr(self.yolo, "model"):
            raise RuntimeError("Unexpected Ultralytics model: no .model")

        detect_module = self.yolo.model[-1]
        if "detect" not in detect_module.__class__.__name__.lower():
            raise RuntimeError(f"Last module is not Detect (got {detect_module.__class__.__name__}).")

        stale = getattr(self, "_detect_hook_handle", None)
        if stale is not None:
            stale.remove()

        def pre_hook(module, inputs):
            feats = list(inputs[0])  # list of multiscale neck features

            if not self.use_cbam:
                # Record only; returning None keeps the original inputs.
                self._neck_feats = feats
                return None

            if self.cbam_neck is None:
                # one CBAM per scale, sized from the feature's channel count
                self.cbam_neck = nn.ModuleList([CBAM(f.shape[1]).to(f.device) for f in feats])

            feats = [block(f) for block, f in zip(self.cbam_neck, feats)]
            self._neck_feats = feats
            return (feats,)  # Detect receives the gated features

        self._detect_hook_handle = detect_module.register_forward_pre_hook(pre_hook)

    @staticmethod
    def _spatial_area(t):
        """Number of spatial positions (H * W) of a feature map."""
        return t.shape[-2] * t.shape[-1]

    @staticmethod
    def _pick_high_res_from_detect_inputs(feats):
        """Return the captured feature map with the largest H * W."""
        if not isinstance(feats, (list, tuple)) or len(feats) == 0:
            raise RuntimeError("Detect input features not captured.")
        return max(feats, key=YOLOv8DetSemSeg._spatial_area)

    @staticmethod
    def _pick_low_res_from_detect_inputs(feats):
        """Return the captured feature map with the smallest H * W."""
        if not isinstance(feats, (list, tuple)) or len(feats) == 0:
            raise RuntimeError("Detect input features not captured.")
        return min(feats, key=YOLOv8DetSemSeg._spatial_area)

    def _heads_from_captured_feats(self, out_size):
        """Run both extra heads on the captured features; seg logits are resized to out_size."""
        captured = self._neck_feats

        seg_logits = self.sem_head(self._pick_high_res_from_detect_inputs(captured))
        seg_logits = F.interpolate(seg_logits, size=out_size, mode="bilinear", align_corners=False)

        dn_logits = self.dn_head(self._pick_low_res_from_detect_inputs(captured))
        return seg_logits, dn_logits

    def forward(self, x):
        # Clear the capture so a forward pass that never reaches Detect is detected.
        self._neck_feats = None

        if isinstance(x, dict):
            # TRAIN: x is a batch dict -> YOLO returns (det_loss, loss_items)
            imgs = x["img"]
            det_loss, det_items = self.yolo(x)
            seg_logits, dn_logits = self._heads_from_captured_feats(imgs.shape[-2:])
            return det_loss, det_items, seg_logits, dn_logits

        # x is an image tensor -> YOLO returns preds
        det_preds = self.yolo(x)
        seg_logits, dn_logits = self._heads_from_captured_feats(x.shape[-2:])
        return det_preds, seg_logits, dn_logits



Creating new Ultralytics Settings v0.0.6 file ‚úÖ 
View Ultralytics Settings with 'yolo settings' or at '/root/.config/Ultralytics/settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.


Sanity check for the tri-task CBAM model: load the best det+seg(+CBAM) checkpoint, run one labeled sample, and confirm the segmentation and day/night heads produce the expected shapes ([1,1,640,640] and [1,2]) and a valid day/night prediction for that example.

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint of the det+seg(+CBAM) experiment, not the detection-only baseline.
cbam_best = "/content/drive/MyDrive/XAI_Project/experiments/det_seg_cbam_640/weights/best.pt"

model = YOLOv8DetSemSeg(
    yolo_weights=cbam_best,
    use_cbam=True
).to(device).eval()

# One dummy forward materializes the Lazy layers and creates the CBAM modules,
# so their parameters exist before load_state_dict(strict=True) below.
with torch.no_grad():
    _ = model(torch.zeros(1, 3, 640, 640, device=device))

# Pick the weights container: for a dict checkpoint the first truthy of "ema", "model",
# "state_dict", otherwise the loaded object itself. If the result is an nn.Module,
# its state_dict() is used.
ckpt = torch.load(cbam_best, map_location=device)
src = (ckpt.get("ema") or ckpt.get("model") or ckpt.get("state_dict") or ckpt) if isinstance(ckpt, dict) else ckpt
sd = src.state_dict() if isinstance(src, torch.nn.Module) else src
# strict=True: any missing or unexpected key raises instead of being ignored.
model.load_state_dict(sd, strict=True)
model.eval()

# Run one labeled sample from `ds` (built in the dataset sanity cell) through the model.
img, labels, mask, dn = ds[0]
x = img.unsqueeze(0).to(device)  # add the batch dimension

with torch.no_grad():
    det_out, seg_logits, dn_logits = model(x)

print(seg_logits.shape)            # [1, 1, 640, 640]
print(dn_logits.shape)             # [1, 2]
# argmax over the two logits: index 0 = day, index 1 = night
print("gt:", dn, "pred:", torch.argmax(dn_logits, dim=1).item())



torch.Size([1, 1, 640, 640])
torch.Size([1, 2])
gt: 0 pred: 0


Collate function for tri-task training: batch images and masks, pack variable-length YOLO boxes into a YOLO-compatible batch dict, and return day/night labels separately so they don't interfere with YOLO's loss computation.

In [None]:
import torch

def collate_det_seg(batch):
    """Collate (img, labels, mask, dn) samples into a YOLO batch dict plus extras.

    Args:
        batch: sequence of (img [3,H,W], labels [N,5] as [cls, x, y, w, h], mask [1,H,W], dn).

    Returns:
        yolo_batch: dict with
            "img"       [B,3,H,W],
            "bboxes"    [M,4] float (all boxes of the batch concatenated),
            "cls"       [M,1] long,
            "batch_idx" [M]   long (index of the image each box belongs to);
            M is 0 when no sample has boxes.
        masks: [B,1,H,W]
        dn   : [B] long
    """
    imgs, labels_list, masks, dns = zip(*batch)

    imgs = torch.stack(imgs, 0)
    masks = torch.stack(masks, 0)
    dn = torch.tensor(dns, dtype=torch.long)

    # Images without boxes contribute nothing; their index simply never appears in batch_idx.
    with_boxes = [(i, lab) for i, lab in enumerate(labels_list) if lab.numel() != 0]

    if with_boxes:
        bboxes = torch.cat([lab[:, 1:5].float() for _, lab in with_boxes], 0)
        cls = torch.cat([lab[:, 0:1].long() for _, lab in with_boxes], 0)
        batch_idx = torch.cat(
            [torch.full((lab.shape[0],), i, dtype=torch.long) for i, lab in with_boxes], 0
        )
    else:
        bboxes = torch.zeros((0, 4), dtype=torch.float32)
        cls = torch.zeros((0, 1), dtype=torch.long)
        batch_idx = torch.zeros((0,), dtype=torch.long)

    yolo_batch = {"img": imgs, "bboxes": bboxes, "cls": cls, "batch_idx": batch_idx}

    # Extras are returned outside the dict so the YOLO loss never sees them.
    return yolo_batch, masks, dn



In [None]:
import torch
import torch.nn.functional as F

def train_one_epoch(model, loader, optimizer, device, scaler, lambda_seg=1, lambda_dn=0.1):
    """Run one joint training epoch (detection + segmentation + day/night).

    Args:
        model: called with the YOLO batch dict; must return
            (det_loss, det_items, seg_logits, dn_logits).
        loader: yields (det_batch dict of tensors, mask, dn).
        optimizer, scaler: optimizer and AMP GradScaler used for the update.
        device: device the batch tensors are moved to.
        lambda_seg, lambda_dn: weights of the segmentation / day-night losses.

    Returns:
        Tuple (total, det, seg, dn) of loss sums divided by max(1, len(loader)).

    Total loss per batch: mean(det_loss) + lambda_seg * BCE-with-logits(seg) + lambda_dn * CE(dn).
    Samples with dn < 0 are excluded from the day/night loss; if none is left it counts as 0.
    """
    model.train()
    sums = {"loss": 0.0, "det": 0.0, "seg": 0.0, "dn": 0.0}
    amp_on = torch.cuda.is_available()

    for det_batch, mask, dn in loader:
        det_batch = {name: t.to(device, non_blocking=True) for name, t in det_batch.items()}
        mask = mask.to(device, non_blocking=True)
        dn = dn.to(device, non_blocking=True).long()

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type="cuda", enabled=amp_on):
            det_loss, _, seg_logits, dn_logits = model(det_batch)
            det_loss = det_loss.mean()

            seg_loss = F.binary_cross_entropy_with_logits(seg_logits, mask)

            labeled = dn >= 0
            if labeled.any():
                dn_loss = F.cross_entropy(dn_logits[labeled], dn[labeled])
            else:
                dn_loss = torch.zeros((), device=device)

            loss = det_loss + lambda_seg * seg_loss + lambda_dn * dn_loss

        # Scaled backward + step + scale update: the usual GradScaler sequence.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        for key, value in (("loss", loss), ("det", det_loss), ("seg", seg_loss), ("dn", dn_loss)):
            sums[key] += value.item()

    n = max(1, len(loader))  # avoid dividing by zero for an empty loader
    return (sums["loss"] / n), (sums["det"] / n), (sums["seg"] / n), (sums["dn"] / n)



Compute mean validation IoU for drivable-area segmentation in the tri-task setup by running the model on YOLO-style batches and comparing thresholded segmentation predictions against ground-truth masks (optionally on a limited number of batches).

In [None]:
import torch

@torch.no_grad()
def val_iou(model, loader, device, max_batches=None):
    """Mean per-image IoU of the thresholded segmentation output.

    Args:
        model: called with the YOLO batch dict; the third element of its return value is
            taken as the segmentation logits.
        loader: yields (det_batch dict of tensors, mask, dn).
        device: device the batch tensors are moved to.
        max_batches: if given, only the first `max_batches` batches are evaluated.

    Returns:
        Sum of per-image IoUs divided by max(1, number of images). Predictions and ground
        truth are both thresholded at 0.5; an image with an empty union contributes IoU 0.
    """
    model.eval()
    iou_sum = 0.0
    n_imgs = 0

    for batch_no, (det_batch, mask, _dn) in enumerate(loader):
        if max_batches is not None and batch_no >= max_batches:
            break

        det_batch = {name: t.to(device, non_blocking=True) for name, t in det_batch.items()}
        target = mask.to(device, non_blocking=True) > 0.5

        _, _, seg_logits, _ = model(det_batch)  # dict-path returns 4 values
        pred = torch.sigmoid(seg_logits) > 0.5

        # Pixel counts per image; the union is clamped so empty masks do not divide by zero.
        inter = (pred & target).float().sum(dim=(1, 2, 3))
        union = (pred | target).float().sum(dim=(1, 2, 3)).clamp_min(1.0)

        per_image = inter / union
        iou_sum += per_image.sum().item()
        n_imgs += per_image.numel()

    return iou_sum / max(1, n_imgs)


Initialize the tri-task training dataset (images, YOLO labels, drivable masks, and day/night labels from CSV).

In [None]:
# Training split: images, YOLO labels, drivable masks and day/night labels from the CSV.
train_ds = BDDDetDrivableDataset(
    "/content/BDD100K_640/yolo_640",
    "/content/BDD100K_640/drivable_masks_640",
    "train",
    imgsz=640,
    dn_csv_path="/content/daynight_labels.csv",
)


Create the tri-task training DataLoader using the custom collate function that returns a YOLO batch dict plus segmentation masks and day/night labels.

In [None]:
from torch.utils.data import DataLoader
import torch

# Shuffled training batches; collate_det_seg turns samples into (yolo_batch, masks, dn).
train_loader = DataLoader(
    dataset=train_ds,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_det_seg,
    num_workers=2,
    persistent_workers=False,
    pin_memory=torch.cuda.is_available(),  # pinned host memory only when a GPU is present
)


Initialize the tri-task CBAM wrapper with the Ultralytics trainer model, re-register hooks, then load the pretrained det+seg checkpoint while skipping incompatible keys (new day/night head and the redesigned multi-scale cbam_neck) so training can continue with the new heads/modules.

In [None]:

import torch, copy
from ultralytics.models.yolo.detect.train import DetectionTrainer

device = "cuda" if torch.cuda.is_available() else "cpu"

# stageA_pt  : det+seg(+CBAM) checkpoint whose weights are loaded into the wrapper below
# det_base_pt: detection baseline used to build both the wrapper and the trainer model
stageA_pt   = "/content/drive/MyDrive/XAI_Project/experiments/det_seg_cbam_640/weights/best.pt"
det_base_pt = "/content/drive/MyDrive/XAI_Project/experiments/det_baseline/weights/best.pt"

data_yaml = "/content/BDD100K_640/yolo_640/dataset_640.yaml"
imgsz = 640

# Wrapper with CBAM enabled; its hooks are initially registered on the YOLO model it loads itself.
model = YOLOv8DetSemSeg(yolo_weights=det_base_pt, use_cbam=True).to(device)

# Build an Ultralytics DetectionTrainer only to obtain a model with args and a loss criterion
# attached. NOTE(review): "device": 0 is hard-coded, so this cell presumably needs a GPU
# even though `device` above can fall back to "cpu" - confirm.
trainer = DetectionTrainer(overrides={
    "model": det_base_pt,
    "data":  data_yaml,
    "imgsz": imgsz,
    "device": 0,
})
trainer.setup_model()
trainer.model.args = trainer.args
trainer.model.init_criterion()

# Swap the wrapper's detector for the trainer's model.
model.yolo = trainer.model.to(device)

# The hooks set up in __init__ belong to the previous detector, so register them again on the
# new one. The CBAM modules are reset to None so the hooks recreate them on the next forward.
model._neck_feats = None
model.cbam_backbone = None
model.cbam_neck = None
model._register_backbone_hook_point1()
model._register_detect_input_hook()

# One dummy forward materializes the Lazy layers and creates the CBAM modules,
# so every parameter exists before the state dict is loaded.
with torch.no_grad():
    _ = model(torch.zeros(1, 3, imgsz, imgsz, device=device))

# Load the Stage-A weights, dropping keys that are incompatible with the new design.
# The checkpoint is either a dict holding "state_dict" or the state dict itself.
ckpt = torch.load(stageA_pt, map_location=device)
sd = ckpt["state_dict"] if isinstance(ckpt, dict) and "state_dict" in ckpt else ckpt
sd2 = copy.deepcopy(sd)  # work on a copy so `sd` / `ckpt` stay untouched

# Drop the day/night head weights (index 2 is the linear layer inside dn_head).
sd2.pop("dn_head.2.weight", None)
sd2.pop("dn_head.2.bias", None)

# drop ALL old cbam_neck.* weights (they don't match new per-scale CBAM channels)
for k in list(sd2.keys()):
    if k.startswith("cbam_neck."):
        sd2.pop(k)

# strict=False: keys missing from sd2 keep their current values; both lists are printed for review.
missing, unexpected = model.load_state_dict(sd2, strict=False)
print("missing:", missing)
print("unexpected:", unexpected)

print("OK: loaded Stage-A (det+seg) weights. cbam_neck will be randomly init for new 3-scale design.")



Ultralytics 8.4.6 üöÄ Python-3.12.12 torch-2.9.0+cu126 CUDA:0 (Tesla T4, 15095MiB)
[34m[1mengine/trainer: [0magnostic_nms=False, amp=True, angle=1.0, augment=False, auto_augment=randaugment, batch=16, bgr=0.0, box=7.5, cache=False, cfg=None, classes=None, close_mosaic=10, cls=0.5, compile=False, conf=None, copy_paste=0.0, copy_paste_mode=flip, cos_lr=False, cutmix=0.0, data=/content/BDD100K_640/yolo_640/dataset_640.yaml, degrees=0.0, deterministic=True, device=0, dfl=1.5, dnn=False, dropout=0.0, dynamic=False, embed=None, epochs=100, erasing=0.4, exist_ok=False, fliplr=0.5, flipud=0.0, format=torchscript, fraction=1.0, freeze=None, half=False, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, imgsz=640, int8=False, iou=0.7, keras=False, kobj=1.0, line_width=None, lr0=0.01, lrf=0.01, mask_ratio=4, max_det=300, mixup=0.0, mode=train, model=/content/drive/MyDrive/XAI_Project/experiments/det_baseline/weights/best.pt, momentum=0.937, mosaic=1.0, multi_scale=0.0, name=train, nbs=64, nms=False, opset=No

Sanity-check the tri-task model inference interface: confirm it returns three outputs (detection predictions, segmentation logits [1,1,640,640], and day/night logits [1,2]).

In [None]:
# test: inference outputs count
# A tensor input takes the inference-style path, which returns three outputs.
x = torch.randn(1, 3, 640, 640, device=device)
out = model(x)

# Show the shape of tensor outputs and the type of everything else.
out_summary = [o.shape if isinstance(o, torch.Tensor) else type(o) for o in out]
print(len(out), out_summary)
# expected: 3 outputs, with dn_logits of shape [1, 2]


3 [<class 'dict'>, torch.Size([1, 1, 640, 640]), torch.Size([1, 2])]


Create an optimizer that updates only the day/night classification head (dn_head).

In [None]:
# Only the parameters of the day/night head are handed to the optimizer.
dn_head_params = [p for p in model.dn_head.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(dn_head_params, lr=1e-4, weight_decay=5e-4)


Build the validation dataset and DataLoader for the tri-task setup (including day/night labels) using the same collate function as training.

In [None]:
from torch.utils.data import DataLoader
import torch

# Validation split, built like the training split (day/night labels included).
val_ds = BDDDetDrivableDataset(
    "/content/BDD100K_640/yolo_640",
    "/content/BDD100K_640/drivable_masks_640",
    "val",
    imgsz=640,
    dn_csv_path="/content/daynight_labels.csv",
)

# Same collate function as training; no shuffling for evaluation.
val_loader = DataLoader(
    dataset=val_ds,
    batch_size=8,
    shuffle=False,
    collate_fn=collate_det_seg,
    num_workers=2,
    persistent_workers=False,
    pin_memory=torch.cuda.is_available(),
)

Sanity-check the tri-task wrapper with CBAM disabled vs enabled, confirming both configurations produce segmentation logits [1,1,640,640] and day/night logits [1,2] in inference.

In [None]:
best_pt = "/content/drive/MyDrive/XAI_Project/experiments/det_baseline/weights/best.pt"
device = "cuda" if torch.cuda.is_available() else "cpu"

img, labels, mask, dn = ds[0]
x = img.unsqueeze(0).to(device)


def _build_and_probe(use_cbam):
    """Build an eval-mode wrapper with the given CBAM setting and run `x` through it once."""
    net = YOLOv8DetSemSeg(best_pt, use_cbam=use_cbam).to(device).eval()
    with torch.no_grad():
        det, seg, dn_logits = net(x)
    return net, det, seg, dn_logits


# CBAM OFF
m0, det0, seg0, dn0 = _build_and_probe(False)
print("OFF seg:", seg0.shape, "dn:", dn0.shape)

# CBAM ON
m1, det1, seg1, dn1 = _build_and_probe(True)
print("ON  seg:", seg1.shape, "dn:", dn1.shape)


OFF seg: torch.Size([1, 1, 640, 640]) dn: torch.Size([1, 2])
ON  seg: torch.Size([1, 1, 640, 640]) dn: torch.Size([1, 2])


One-batch sanity run for day/night training only: freeze all parameters except dn_head, run a single training batch with AMP, and print the cross-entropy loss and accuracy (loss=0.706, acc=0.375, logits shape (8,2)) to confirm the head-only update works.

In [None]:
# One-batch sanity test for the day/night (dn) head only

import torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Freeze every parameter of the wrapper ...
for p in model.parameters():
    p.requires_grad = False
# ... then re-enable gradients for the day/night head only.
for p in model.dn_head.parameters():
    p.requires_grad = True

# The optimizer only receives the (trainable) dn_head parameters.
optimizer = torch.optim.AdamW(
    [p for p in model.dn_head.parameters() if p.requires_grad],
    lr=1e-4, weight_decay=5e-4
)
# AMP gradient scaler; disabled (pass-through) when CUDA is not available.
scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())

model.train()
tot_loss, tot_acc, nb = 0.0, 0.0, 0

for det_batch, mask, dn in train_loader:  # collate_det_seg -> (det_batch, mask, dn)
    img = det_batch["img"].to(device, non_blocking=True)
    dn  = dn.to(device, non_blocking=True).long()

    # Only samples with a non-negative label take part; skip the batch if there is none.
    valid = (dn >= 0)
    if not valid.any():
        continue

    optimizer.zero_grad(set_to_none=True)

    with torch.amp.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
        out = model(img)             # tensor input -> (det, seg, dn)
        dn_logits = out[-1]          # day/night logits are the last output
        loss = F.cross_entropy(dn_logits[valid], dn[valid])

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # Accuracy of this batch, computed on the logits from before the update.
    with torch.no_grad():
        pred = torch.argmax(dn_logits[valid], dim=1)
        acc = (pred == dn[valid]).float().mean().item()

    tot_loss += loss.item()
    tot_acc += acc
    nb += 1

    print(f"ONE BATCH OK | loss={loss.item():.4f} acc={acc:.4f} | dn_logits={tuple(dn_logits.shape)}")
    break  # stop after 1 batch

# nb is 1 here (or 0 if no batch had a valid label), so these "averages" are one batch's values.
print(f"Sanity 1-epoch (1 batch) done | avg_loss={tot_loss/max(1,nb):.4f} avg_acc={tot_acc/max(1,nb):.4f}")


ONE BATCH OK | loss=0.7060 acc=0.3750 | dn_logits=(8, 2)
Sanity 1-epoch (1 batch) done | avg_loss=0.7060 avg_acc=0.3750


Train the tri-task model's day/night tagging head only for 30 epochs: freeze all other parameters, optimize cross-entropy on day/night labels with AMP, evaluate validation accuracy every 2 epochs, and save snapshots plus best/last checkpoints for resume.




In [None]:
import os
import torch
import torch.nn.functional as F


# train day/night only

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoints are written to Drive: best.pt for the best model, last.ckpt for resuming.
save_root = "/content/drive/MyDrive/XAI_Project/experiments/det_seg_dn_cbam_640/weights"
os.makedirs(save_root, exist_ok=True)
best_path = os.path.join(save_root, "best.pt")
ckpt_path = os.path.join(save_root, "last.ckpt")

# Freeze every parameter of the wrapper ...
# NOTE(review): requires_grad=False alone does not stop BatchNorm layers from updating their
# running statistics while the model is in train() mode - check the training function.
for p in model.parameters():
    p.requires_grad = False

# ... then unfreeze only the day/night tagging head.
for p in model.dn_head.parameters():
    p.requires_grad = True

# Optimizer (only dn_head)
optimizer = torch.optim.AdamW(
    [p for p in model.dn_head.parameters() if p.requires_grad],
    lr=1e-4,
    weight_decay=5e-4
)

# AMP gradient scaler; disabled (pass-through) when CUDA is not available.
scaler = torch.amp.GradScaler(enabled=torch.cuda.is_available())

# Defaults for a fresh run; overwritten below when a last.ckpt exists.
start_epoch = 1
best_dn_acc = -1.0

if os.path.exists(ckpt_path):
    ckpt = torch.load(ckpt_path, map_location=device)
    # strict=False: missing or unexpected keys are ignored instead of raising.
    model.load_state_dict(ckpt["model"], strict=False)
    optimizer.load_state_dict(ckpt["optimizer"])
    best_dn_acc = ckpt.get("best_dn_acc", -1.0)
    if "scaler" in ckpt:
        scaler.load_state_dict(ckpt["scaler"])
    # The checkpoint stores the last epoch that was run; continue with the next one.
    start_epoch = int(ckpt["epoch"]) + 1
    print(f"Resuming from epoch {start_epoch} | best_dn_acc={best_dn_acc:.4f}", flush=True)

# train Oone epoch (DN only, no det/seg loss)
def train_one_epoch_dn_only(model, loader, optimizer, device, scaler):
    """Run one epoch that optimizes only the day/night (DN) cross-entropy loss.

    Args:
        model: network whose forward output is a sequence with the DN logits last.
        loader: iterable yielding ``(det_batch, mask, dn)``; ``det_batch["img"]`` is
            the image batch and ``dn`` holds integer labels, where negative values
            mark samples without a DN label (they are ignored).
        optimizer: optimizer over the trainable parameters.
        device: device images and labels are moved to.
        scaler: AMP gradient scaler exposing ``scale`` / ``step`` / ``update``.

    Returns:
        ``(mean_loss, mean_acc)`` averaged over every labeled sample seen in the
        epoch (sample-weighted, matching how ``val_dn_acc`` counts accuracy), or
        ``(0.0, 0.0)`` when no labeled sample was seen.
    """
    model.train()

    # requires_grad=False only stops gradient updates: a BatchNorm layer left in
    # training mode still rewrites its running mean/var on every forward pass.
    # Put BatchNorm layers whose parameters are all frozen into eval mode so the
    # frozen part of the network really stays unchanged. The model itself remains
    # in training mode, so its forward path is not altered.
    for module in model.modules():
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            own_params = list(module.parameters(recurse=False))
            if own_params and not any(p.requires_grad for p in own_params):
                module.eval()

    loss_sum = 0.0   # sum of per-sample losses
    n_correct = 0    # correctly classified labeled samples
    n_samples = 0    # labeled samples seen

    for det_batch, mask, dn in loader:   # collate_det_seg returns (det_batch, mask, dn)
        img = det_batch["img"].to(device, non_blocking=True)
        dn = dn.to(device, non_blocking=True).long()

        valid = (dn >= 0)
        n_valid = int(valid.sum().item())
        if n_valid == 0:
            # Nothing to learn from a batch without any DN label.
            continue

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast(device_type="cuda", enabled=torch.cuda.is_available()):
            out = model(img)
            dn_logits = out[-1]  # the model returns dn_logits last
            dn_loss = F.cross_entropy(dn_logits[valid], dn[valid])

        scaler.scale(dn_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        with torch.no_grad():
            pred = torch.argmax(dn_logits[valid], dim=1)
            n_correct += int((pred == dn[valid]).sum().item())

        # dn_loss is the mean over the valid samples; weight it by their count so
        # batches with few labeled samples do not dominate the epoch average.
        loss_sum += dn_loss.item() * n_valid
        n_samples += n_valid

    n = max(1, n_samples)
    return loss_sum / n, n_correct / n

# val DN accuracy
@torch.no_grad()
def val_dn_acc(model, loader, device, max_batches=None):
    """Return day/night classification accuracy over the labeled samples of ``loader``.

    The model is switched to eval mode. Each batch is ``(det_batch, mask, dn)``;
    samples with a negative ``dn`` label are ignored. When ``max_batches`` is
    given, only the first ``max_batches`` batches are visited. Returns 0.0 if no
    labeled sample was seen.
    """
    model.eval()
    n_hit, n_seen = 0, 0

    for batch_idx, (det_batch, mask, dn) in enumerate(loader):
        if max_batches is not None and batch_idx >= max_batches:
            break

        images = det_batch["img"].to(device, non_blocking=True)
        labels = dn.to(device, non_blocking=True).long()

        keep = labels >= 0
        if keep.any():
            # DN logits are the last element of the model output.
            logits = model(images)[-1]
            hits = logits[keep].argmax(dim=1).eq(labels[keep])
            n_hit += hits.sum().item()
            n_seen += keep.sum().item()

    return n_hit / max(1, n_seen)

# Main loop for the DN-only stage: train every epoch, validate on even epochs.
epochs = 30

for epoch in range(start_epoch, epochs + 1):
    train_dn_loss, train_dn_acc = train_one_epoch_dn_only(
        model, train_loader, optimizer, device, scaler
    )

    # Shared beginning of the per-epoch log line; the validation part is appended below.
    epoch_summary = (
        f"epoch {epoch:02d}/{epochs} | "
        f"train_dn_loss={train_dn_loss:.4f} train_dn_acc={train_dn_acc:.4f} | "
    )

    if epoch % 2 != 0:
        # Odd epochs: no validation, no snapshot.
        print(epoch_summary + "val_dn_acc=skip", flush=True)
    else:
        # Even epochs: validate on at most the first 125 validation batches.
        dn_acc = val_dn_acc(model, val_loader, device, max_batches=125)

        # Keep a per-epoch snapshot of the full model weights.
        snap_path = os.path.join(save_root, f"epoch{epoch:02d}.pt")
        torch.save(model.state_dict(), snap_path)
        print(f"  saved SNAPSHOT: {snap_path}", flush=True)

        print(epoch_summary + f"val_dn_acc={dn_acc:.4f}", flush=True)

        # Track the best weights by validation DN accuracy.
        if dn_acc > best_dn_acc:
            best_dn_acc = dn_acc
            torch.save(model.state_dict(), best_path)
            print(f"  saved BEST: {best_path} (best_dn_acc={best_dn_acc:.4f})", flush=True)

    # Always refresh the resume checkpoint with everything needed to continue.
    resume_state = {
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "best_dn_acc": best_dn_acc,
        "scaler": scaler.state_dict(),
    }
    torch.save(resume_state, ckpt_path)


epoch 01/30 | train_dn_loss=0.6710 train_dn_acc=0.5677 | val_dn_acc=skip
  saved SNAPSHOT: /content/drive/MyDrive/XAI_Project/experiments/det_seg_dn_cbam_640/weights/epoch02.pt
epoch 02/30 | train_dn_loss=0.6441 train_dn_acc=0.5919 | val_dn_acc=0.6470
  saved BEST: /content/drive/MyDrive/XAI_Project/experiments/det_seg_dn_cbam_640/weights/best.pt (best_dn_acc=0.6470)
epoch 03/30 | train_dn_loss=0.6200 train_dn_acc=0.6796 | val_dn_acc=skip
  saved SNAPSHOT: /content/drive/MyDrive/XAI_Project/experiments/det_seg_dn_cbam_640/weights/epoch04.pt
epoch 04/30 | train_dn_loss=0.5990 train_dn_acc=0.7419 | val_dn_acc=0.7430
  saved BEST: /content/drive/MyDrive/XAI_Project/experiments/det_seg_dn_cbam_640/weights/best.pt (best_dn_acc=0.7430)
epoch 05/30 | train_dn_loss=0.5796 train_dn_acc=0.7752 | val_dn_acc=skip
  saved SNAPSHOT: /content/drive/MyDrive/XAI_Project/experiments/det_seg_dn_cbam_640/weights/epoch06.pt
epoch 06/30 | train_dn_loss=0.5613 train_dn_acc=0.8005 | val_dn_acc=0.8390
  saved 

Final tri-task evaluation script: load the best CBAM tri-task checkpoint into the YOLOv8 wrapper (with trainer swap and hooks), evaluate detection with Ultralytics val() on validation and test splits, and compute custom segmentation mean IoU plus day/night accuracy, balanced accuracy, F1, and confusion counts on validation and test loaders.

In [None]:
# FINAL TRI-TASK EVALUATION
# Detection: Ultralytics .val() on val + test
# Segmentation: mean IoU on val + test (binary)
# Day/Night: acc + balanced acc + F1 on val + test


import os
import torch
import torch.nn.functional as F
from ultralytics import YOLO
from ultralytics.models.yolo.detect.train import DetectionTrainer

device = "cuda" if torch.cuda.is_available() else "cpu"
# Device value handed to Ultralytics: GPU index 0 when CUDA is available, else "cpu".
dev_override = 0 if device == "cuda" else "cpu"

# paths
data_yaml = "/content/BDD100K_640/yolo_640/dataset_640.yaml"

# Wrapper checkpoint to evaluate.
best_wrapper_pt = "/content/drive/MyDrive/XAI_Project/experiments/det_seg_dn_cbam_640/weights/best.pt"

# Detection baseline weights used to construct both the wrapper and the trainer model.
det_base_pt = "/content/drive/MyDrive/XAI_Project/experiments/det_baseline/weights/best.pt"

imgsz = 640
batch = 8


# 1) build model (wrapper + trainer swap) + load best_wrapper_pt
model = YOLOv8DetSemSeg(yolo_weights=det_base_pt, use_cbam=True).to(device)

trainer = DetectionTrainer(overrides={
    "model": det_base_pt,
    "data":  data_yaml,
    "imgsz": imgsz,
    "device": dev_override,
    "batch": batch,
})
trainer.setup_model()
trainer.model.args = trainer.args
trainer.model.init_criterion()

# swap YOLO model into wrapper
model.yolo = trainer.model.to(device)

# re-hook: clear cached features / CBAM modules, then register the hooks again
# so they attach to the model that was just swapped in
model._neck_feats = None
model.cbam_backbone = None
model.cbam_neck = None
model._register_backbone_hook_point1()
model._register_detect_input_hook()

# materialize Lazy + CBAM modules with one dummy forward pass, so their
# parameters exist before the state_dict is loaded
with torch.no_grad():
    _ = model(torch.zeros(1, 3, imgsz, imgsz, device=device))

# load wrapper state_dict (non-strict because cbam_neck.0/1/2 + dn_head may differ)
sd = torch.load(best_wrapper_pt, map_location=device)
if isinstance(sd, dict) and "state_dict" in sd:
    sd = sd["state_dict"]
# strict=False never raises on key mismatches, so keep the report and surface it:
# otherwise a checkpoint that matches (almost) nothing would still print "OK".
load_report = model.load_state_dict(sd, strict=False)
model.eval()

print("OK: loaded wrapper best:", best_wrapper_pt)
if load_report.missing_keys or load_report.unexpected_keys:
    print(
        f"WARNING: non-strict load | missing_keys={len(load_report.missing_keys)} "
        f"unexpected_keys={len(load_report.unexpected_keys)}"
    )
    print("  missing (first 5):", list(load_report.missing_keys)[:5])
    print("  unexpected (first 5):", list(load_report.unexpected_keys)[:5])


# 2) Detection eval (Ultralytics)

# Wrap the swapped, trained detection model in a YOLO() object so .val() can run on it.
yolo = YOLO(det_base_pt)
yolo.model = model.yolo

# Arguments shared by the validation-split and test-split runs.
det_val_kwargs = dict(data=data_yaml, imgsz=imgsz, device=dev_override, batch=batch)

print("\n=== DETECTION: VAL ===")
det_val = yolo.val(split="val", **det_val_kwargs)

print("\n=== DETECTION: TEST ===")
# The test run is best-effort: on any failure, report it and leave det_test as None.
det_test = None
try:
    det_test = yolo.val(split="test", **det_val_kwargs)
except Exception as err:
    print("No test split or error running test val():", repr(err))

# 3) seg + day/night eval

@torch.no_grad()
def eval_seg_dn(model, loader, device, max_batches=None):
    """Evaluate binary segmentation IoU and day/night (DN) classification metrics.

    Args:
        model: network whose forward output holds the segmentation logits at
            index 1 and the DN logits last.
        loader: iterable yielding ``(det_batch, mask, dn)``; ``det_batch["img"]`` is
            the image batch, ``mask`` the ground-truth mask (binarized at 0.5) and
            ``dn`` the integer DN labels, where negative values mean "unlabeled".
        device: device the tensors are moved to.
        max_batches: if given, only the first ``max_batches`` batches are used.

    Returns:
        ``(miou, acc, bal_acc, f1, cm)`` where
        - ``miou`` is total intersection / total union accumulated over all
          evaluated pixels (an aggregate IoU, not a mean of per-image IoUs),
        - ``acc``, ``bal_acc`` and ``f1`` are DN metrics with class 1 (night) as
          the positive class, computed over labeled samples only,
        - ``cm`` is the confusion-count dict with keys ``tn``, ``fp``, ``fn``, ``tp``.
        Every ratio falls back to 0 when its denominator would be 0.
    """
    model.eval()

    # seg IoU accumulators
    inter_sum = 0.0
    union_sum = 0.0

    # dn confusion counts (samples with dn < 0 are ignored)
    tp = tn = fp = fn = 0

    for bi, (det_batch, mask, dn) in enumerate(loader):
        if (max_batches is not None) and (bi >= max_batches):
            break

        img = det_batch["img"].to(device, non_blocking=True)
        gt = mask.to(device, non_blocking=True) > 0.5
        dn = dn.to(device, non_blocking=True).long()

        out = model(img)            # (det, seg, dn)
        seg_logits = out[1]
        dn_logits = out[-1]

        pred = torch.sigmoid(seg_logits) > 0.5

        # Align a missing channel dimension. Without this, (B,1,H,W) against
        # (B,H,W) would silently broadcast to (B,B,H,W) and pair every prediction
        # with every mask in the batch, corrupting the IoU.
        if gt.dim() == pred.dim() - 1:
            gt = gt.unsqueeze(1)
        elif pred.dim() == gt.dim() - 1:
            pred = pred.unsqueeze(1)

        inter_sum += (pred & gt).sum().item()
        union_sum += (pred | gt).sum().item()

        valid = (dn >= 0)
        if valid.any():
            pred_dn = torch.argmax(dn_logits[valid], dim=1)
            true_dn = dn[valid]

            # positive class = 1 (night), negative = 0 (day)
            tp += int(((pred_dn == 1) & (true_dn == 1)).sum().item())
            tn += int(((pred_dn == 0) & (true_dn == 0)).sum().item())
            fp += int(((pred_dn == 1) & (true_dn == 0)).sum().item())
            fn += int(((pred_dn == 0) & (true_dn == 1)).sum().item())

    miou = (inter_sum / max(1.0, union_sum))

    # dn acc/balanced acc/f1
    total = tp + tn + fp + fn
    acc = (tp + tn) / max(1, total)

    tpr = tp / max(1, (tp + fn))  # recall night
    tnr = tn / max(1, (tn + fp))  # recall day
    bal_acc = 0.5 * (tpr + tnr)

    prec = tp / max(1, (tp + fp))
    rec = tpr
    f1 = (2 * prec * rec) / max(1e-12, (prec + rec))

    cm = {"tn": tn, "fp": fp, "fn": fn, "tp": tp}
    return miou, acc, bal_acc, f1, cm


# Report segmentation + day/night metrics on the validation split.
print("\n=== SEG + DAY/NIGHT: VAL (custom) ===")
val_miou, val_acc, val_bal, val_f1, val_cm = eval_seg_dn(model, val_loader, device)
print(f"SEG mean IoU: {val_miou:.4f}")
print(f"DN acc: {val_acc:.4f} | balanced acc: {val_bal:.4f} | F1(night=1): {val_f1:.4f}")
print("DN confusion (tn,fp,fn,tp):", val_cm)

# Same report on the test split, if a test loader was built earlier in the notebook.
print("\n=== SEG + DAY/NIGHT: TEST (custom) ===")
# Check for the loader explicitly instead of wrapping the whole evaluation in
# `except NameError`: that would also swallow any NameError raised inside the
# evaluation itself and misreport it as a missing test_loader.
if "test_loader" in globals():
    test_miou, test_acc, test_bal, test_f1, test_cm = eval_seg_dn(model, test_loader, device)
    print(f"SEG mean IoU: {test_miou:.4f}")
    print(f"DN acc: {test_acc:.4f} | balanced acc: {test_bal:.4f} | F1(night=1): {test_f1:.4f}")
    print("DN confusion (tn,fp,fn,tp):", test_cm)
else:
    print("No test_loader defined. (If you have a test split, build test_ds/test_loader like val_loader.)")



Ultralytics 8.4.6 üöÄ Python-3.12.12 torch-2.9.0+cu126 CUDA:0 (Tesla T4, 15095MiB)
[34m[1mengine/trainer: [0magnostic_nms=False, amp=True, angle=1.0, augment=False, auto_augment=randaugment, batch=8, bgr=0.0, box=7.5, cache=False, cfg=None, classes=None, close_mosaic=10, cls=0.5, compile=False, conf=None, copy_paste=0.0, copy_paste_mode=flip, cos_lr=False, cutmix=0.0, data=/content/BDD100K_640/yolo_640/dataset_640.yaml, degrees=0.0, deterministic=True, device=0, dfl=1.5, dnn=False, dropout=0.0, dynamic=False, embed=None, epochs=100, erasing=0.4, exist_ok=False, fliplr=0.5, flipud=0.0, format=torchscript, fraction=1.0, freeze=None, half=False, hsv_h=0.015, hsv_s=0.7, hsv_v=0.4, imgsz=640, int8=False, iou=0.7, keras=False, kobj=1.0, line_width=None, lr0=0.01, lrf=0.01, mask_ratio=4, max_det=300, mixup=0.0, mode=train, model=/content/drive/MyDrive/XAI_Project/experiments/det_baseline/weights/best.pt, momentum=0.937, mosaic=1.0, multi_scale=0.0, name=train2, nbs=64, nms=False, opset=No