In [None]:
%%writefile mask_rcnn_vgg13_wbc_colab.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os, sys, gc, time, math, csv, random
from pathlib import Path
from typing import Dict, List, Any

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch import amp

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from PIL import Image
from tqdm.auto import tqdm
import albumentations as A
import cv2

from pycocotools.coco import COCO

import torchvision
from torchvision.transforms import functional as F
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models import ResNet101_Weights
from torch.cuda.amp import autocast, GradScaler

from collections import OrderedDict
from torchvision.models._utils import IntermediateLayerGetter
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool
from torchvision.models.detection.backbone_utils import BackboneWithFPN
from torchvision.models import (
    VGG11_Weights, VGG13_Weights, VGG16_Weights, VGG19_Weights
)


# Enable cuDNN's benchmark mode for the whole process.
torch.backends.cudnn.benchmark = True

# Debug switches. In the code visible in this file only DBG_USE_GT_PROPOSALS
# and DBG_SCORE_THR are read (by forward_with_switches); DBG_OVERFIT_TINY and
# DBG_SKIP_NMS are defined but not referenced here.
DBG_OVERFIT_TINY   = False
DBG_USE_GT_PROPOSALS = False
DBG_SKIP_NMS       = False
DBG_SCORE_THR      = 0.0


def set_torch_threads(n=4):
    """Limit CPU threading to ``n`` threads and relax float32 matmul precision.

    The OMP/MKL environment variables are only set when not already present,
    so an explicit setting made by the user wins. Because those variables are
    normally read when the numeric libraries are loaded (torch is already
    imported at this point), the intra-op thread count is also set directly
    on torch.

    Args:
        n: desired number of threads (values below 1 are treated as 1).
    """
    os.environ.setdefault("OMP_NUM_THREADS", str(n))
    os.environ.setdefault("MKL_NUM_THREADS", str(n))
    # Apply the limit to the already-imported torch runtime as well.
    torch.set_num_threads(max(1, int(n)))
    try:
        torch.set_float32_matmul_precision("medium")
    except (AttributeError, RuntimeError):
        # Best effort: the setter may be missing or unsupported on this build.
        pass


def collate_fn(batch):
    """Transpose a batch of per-sample tuples into a tuple of per-field tuples.

    Nothing is stacked, so every field keeps one entry per sample.
    """
    columns = zip(*batch)
    return tuple(columns)


def coco_xywh_to_xyxy(box):
    """Convert an ``[x, y, width, height]`` box to ``[x1, y1, x2, y2]``."""
    left, top, width, height = box
    right = left + width
    bottom = top + height
    return [left, top, right, bottom]


def clamp_box_xyxy(box, W, H):
    """Clamp an ``[x1, y1, x2, y2]`` box to integer pixel coordinates.

    Coordinates are truncated with ``int`` and limited to ``[0, W-1]`` /
    ``[0, H-1]``. A box that would end up with zero extent along an axis is
    widened to one pixel; when it already touches the far edge the near side
    is moved inward instead, so the result has positive width and height
    whenever the image is at least two pixels along that axis.

    Args:
        box: four numbers ``x1, y1, x2, y2``.
        W: image width in pixels.
        H: image height in pixels.

    Returns:
        ``[x1, y1, x2, y2]`` as a list of ints.
    """

    def _clamp_axis(lo, hi, size):
        last = size - 1
        lo = max(0, min(int(lo), last))
        hi = max(0, min(int(hi), last))
        if hi <= lo:
            hi = min(last, lo + 1)
            if hi <= lo:
                # Already at the far edge: grow towards the origin instead.
                lo = max(0, hi - 1)
        return lo, hi

    x1, y1, x2, y2 = box
    x1, x2 = _clamp_axis(x1, x2, W)
    y1, y2 = _clamp_axis(y1, y2, H)
    return [x1, y1, x2, y2]


def box_iou_xyxy(boxA, boxB):
    """Return the intersection-over-union of two ``[x1, y1, x2, y2]`` boxes.

    Returns ``0.0`` when the boxes do not overlap; a small epsilon in the
    denominator guards against division by zero.
    """
    overlap_w = max(0, min(boxA[2], boxB[2]) - max(boxA[0], boxB[0]))
    overlap_h = max(0, min(boxA[3], boxB[3]) - max(boxA[1], boxB[1]))
    inter = overlap_w * overlap_h
    if inter <= 0:
        return 0.0

    area_a = max(0, boxA[2] - boxA[0]) * max(0, boxA[3] - boxA[1])
    area_b = max(0, boxB[2] - boxB[0]) * max(0, boxB[3] - boxB[1])
    union = area_a + area_b - inter + 1e-6
    return inter / union


def greedy_match_accuracy(pred, gt, iou_thr=0.5):
    """Fraction of ground-truth boxes hit by a correctly-labelled prediction.

    Predictions are visited in descending score order. Each one is paired with
    the not-yet-matched ground-truth box of highest IoU and counts as a hit
    when that IoU reaches ``iou_thr`` and the labels agree.

    Args:
        pred: dict with tensor entries ``"boxes"``, ``"labels"``, ``"scores"``.
        gt: dict with tensor entries ``"boxes"`` and ``"labels"``.
        iou_thr: minimum IoU for a match.

    Returns:
        ``hits / max(1, number of ground-truth boxes)``.
    """
    ranking = pred["scores"].cpu().numpy().argsort()[::-1]
    pred_boxes = pred["boxes"].cpu().numpy()[ranking]
    pred_labels = pred["labels"].cpu().numpy()[ranking]

    gt_boxes = gt["boxes"].cpu().numpy()
    gt_labels = gt["labels"].cpu().numpy()

    taken = set()
    n_hits = 0
    for cand_box, cand_label in zip(pred_boxes, pred_labels):
        top_iou = 0.0
        top_idx = -1
        for idx, (ref_box, _ref_label) in enumerate(zip(gt_boxes, gt_labels)):
            if idx in taken:
                continue
            overlap = box_iou_xyxy(cand_box, ref_box)
            if overlap > top_iou:
                top_iou = overlap
                top_idx = idx
        if top_idx >= 0 and top_iou >= iou_thr and cand_label == gt_labels[top_idx]:
            taken.add(top_idx)
            n_hits += 1

    return n_hits / max(1, len(gt_boxes))

def get_gpu_temp(gpu_index: int = 0):
    """Query ``nvidia-smi`` for the temperature of one GPU.

    Args:
        gpu_index: value passed to ``nvidia-smi --id``.

    Returns:
        The reported temperature as an ``int``, or ``None`` when the query
        fails for any reason (tool missing, non-numeric output, ...).
    """
    try:
        import subprocess
        command = [
            'nvidia-smi',
            f'--id={gpu_index}',
            '--query-gpu=temperature.gpu',
            '--format=csv,noheader,nounits',
        ]
        raw = subprocess.check_output(command, stderr=subprocess.DEVNULL)
        return int(raw.decode().strip())
    except Exception:
        return None

def cool_if_hot(threshold: int = 85, resume: int = 80, sleep_s: int = 10, gpu_index: int = 0):
    """Pause the caller while the GPU is too hot.

    When the temperature reported by ``get_gpu_temp`` is at or above
    ``threshold`` the function sleeps in ``sleep_s`` second steps until the
    temperature has dropped to ``resume`` or lower (hysteresis), so that work
    does not restart the moment the reading dips just under the trip point.
    If the temperature cannot be read, the wait ends.

    Args:
        threshold: temperature at which the pause is triggered.
        resume: temperature at or below which the pause ends; values not
            below ``threshold`` are treated as ``threshold - 1``.
        sleep_s: seconds to sleep between readings.
        gpu_index: GPU to monitor.

    Returns:
        ``True`` if a pause happened, otherwise ``False``.
    """
    temp = get_gpu_temp(gpu_index)
    if temp is None or temp < threshold:
        return False

    # Keep the resume point strictly below the trip point so the loop ends
    # only after a real drop in temperature.
    resume_at = min(resume, threshold - 1)
    while temp is not None and temp > resume_at:
        if temp >= threshold:
            print(f"GPU {gpu_index} = {temp}°C ≥ {threshold}°C", flush=True)
        else:
            print(f"GPU {gpu_index} = {temp}°C > {resume_at}°C (cooling down)", flush=True)
        try:
            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()
        except Exception:
            # Best effort only: freeing cached memory must never abort the wait.
            pass
        time.sleep(sleep_s)
        temp = get_gpu_temp(gpu_index)

    print(f"Continue Training", flush=True)
    return True


# ====== Dataset ======
class CocoInstanceDataset(Dataset):
    """Instance-segmentation dataset backed by a COCO-format annotation file.

    ``root_split`` must contain an ``images`` folder and an
    ``annotations.json`` file. Each item is ``(image, target)`` where
    ``target`` holds ``boxes`` (xyxy, float32), ``labels`` (int64), ``masks``
    (uint8), ``image_id``, ``area`` and ``iscrowd``.
    """

    def __init__(self, root_split: str, transforms=None, category_id_mapping: Dict[int, int] = None):
        """Load the annotation index and build the category lookup tables.

        Args:
            root_split: directory of one split.
            transforms: optional callable ``(image, target) -> (image, target)``.
            category_id_mapping: optional map from COCO category id to
                contiguous label; when omitted, sorted category ids are
                numbered from 1.
        """
        self.root = Path(root_split)
        self.img_dir = self.root / "images"
        self.ann_file = self.root / "annotations.json"
        assert self.img_dir.is_dir(), f"images folder not found: {self.img_dir}"
        assert self.ann_file.is_file(), f"annotations.json not found: {self.ann_file}"

        self.coco = COCO(str(self.ann_file))
        self.ids = sorted(self.coco.imgs.keys())
        self.transforms = transforms

        self.catid_to_name = {}
        for category in self.coco.loadCats(self.coco.getCatIds()):
            self.catid_to_name[category["id"]] = category["name"]

        if category_id_mapping is not None:
            self.catid_to_contig = dict(category_id_mapping)
        else:
            # Label 0 is left for the background class.
            self.catid_to_contig = {
                cid: label for label, cid in enumerate(sorted(self.catid_to_name.keys()), start=1)
            }

        self.contig_to_name = {}
        for cid, label in self.catid_to_contig.items():
            self.contig_to_name[label] = self.catid_to_name[cid]

    @property
    def classes(self) -> List[str]:
        """Class names indexed by contiguous label, background first."""
        names = ['__background__']
        for label in range(1, len(self.catid_to_contig) + 1):
            names.append(self.contig_to_name[label])
        return names

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index: int):
        """Return the image and target dict for the ``index``-th image id."""
        img_id = self.ids[index]
        info = self.coco.loadImgs(img_id)[0]

        image = Image.open(self.img_dir / info["file_name"]).convert("RGB")
        W, H = image.size

        annotations = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id, iscrowd=None))

        box_list, label_list, mask_list, area_list, crowd_list = [], [], [], [], []
        for ann in annotations:
            # Annotations without a bounding box are ignored.
            if "bbox" not in ann:
                continue
            x1, y1, x2, y2 = clamp_box_xyxy(coco_xywh_to_xyxy(ann["bbox"]), W, H)
            box_list.append([x1, y1, x2, y2])
            label_list.append(self.catid_to_contig[ann["category_id"]])
            mask_list.append(self.coco.annToMask(ann))

            # Fall back to the clamped box area when the annotation has none.
            box_area = float((x2 - x1) * (y2 - y1))
            area_list.append(ann["area"] if "area" in ann else box_area)
            crowd_list.append(ann["iscrowd"] if "iscrowd" in ann else 0)

        if box_list:
            boxes = torch.as_tensor(box_list, dtype=torch.float32)
            labels = torch.as_tensor(label_list, dtype=torch.int64)
            masks = torch.as_tensor(np.stack(mask_list, axis=0), dtype=torch.uint8)
            area = torch.as_tensor(area_list, dtype=torch.float32)
            iscrowd = torch.as_tensor(crowd_list, dtype=torch.uint8)
        else:
            # Image without usable annotations: emit correctly-shaped empties.
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros((0,), dtype=torch.int64)
            masks = torch.zeros((0, H, W), dtype=torch.uint8)
            area = torch.zeros((0,), dtype=torch.float32)
            iscrowd = torch.zeros((0,), dtype=torch.uint8)

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": torch.tensor([img_id], dtype=torch.int64),
            "area": area,
            "iscrowd": iscrowd,
        }

        if self.transforms is None:
            return image, target
        image, target = self.transforms(image, target)
        return image, target


# ====== Transforms ======
class ToTensor:
    """Detection transform that converts the image with ``F.to_tensor``.

    The target is passed through untouched.
    """

    def __call__(self, image, target):
        tensor = F.to_tensor(image)
        return tensor, target

class NormalizeDet:
    """Detection transform applying per-channel ``(image - mean) / std``.

    The target is passed through untouched.
    """

    def __init__(self, mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)):
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        # Add two trailing axes so the statistics broadcast over H and W.
        offset = torch.tensor(self.mean)[:, None, None]
        scale = torch.tensor(self.std)[:, None, None]
        normalized = (image - offset) / scale
        return normalized, target

class AlbumentationsDet:
    """Random photometric + geometric augmentation for detection targets.

    With probability ``p`` the PIL image is passed through the image-only
    pipeline and, when the target has boxes, through the geometric pipeline
    together with its boxes and masks. Boxes that the geometric pipeline
    drops (``min_visibility``) or that become empty are removed together with
    their label, mask and ``iscrowd`` entry, so all per-instance fields stay
    aligned.
    """

    def __init__(self, p=0.5):
        """Build both pipelines.

        Args:
            p: probability of augmenting a sample at all.
        """
        self.p = p
        self.geom = A.Compose([
            A.RandomRotate90(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.02, scale_limit=0.1, rotate_limit=10,
                               border_mode=cv2.BORDER_REFLECT_101, p=0.5),
        ], bbox_params=A.BboxParams(
            format="pascal_voc",
            label_fields=["labels"],
            min_visibility=0.2
        ))
        self.img_only = A.Compose([
            A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.5),
            A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=10, val_shift_limit=10, p=0.3),
            A.GaussNoise(var_limit=(5.0, 20.0), p=0.2),
            A.GaussianBlur(blur_limit=3, p=0.2),
        ])

    def __call__(self, image, target):
        """Augment one sample.

        Args:
            image: PIL image.
            target: dict with tensor entries ``boxes``, ``labels``, ``masks``
                and optionally ``iscrowd`` / ``area``; updated in place.

        Returns:
            ``(image, target)`` with ``image`` as a PIL image.
        """
        if random.random() > self.p:
            return image, target

        from PIL import Image

        img = self.img_only(image=np.array(image))["image"]

        n_boxes = int(target["boxes"].shape[0])
        if n_boxes == 0:
            return Image.fromarray(img), target

        # The label field carries the instance INDEX rather than the class id,
        # so after augmentation we know exactly which instances survived and
        # can select the matching masks/labels instead of guessing by position.
        out = self.geom(
            image=img,
            bboxes=target["boxes"].numpy().tolist(),
            labels=list(range(n_boxes)),
            masks=list(target["masks"].numpy()),
        )
        img_aug = out["image"]
        H, W = img_aug.shape[:2]

        kept = np.asarray([int(i) for i in out["labels"]], dtype=np.int64)
        boxes_aug = np.asarray(out["bboxes"], dtype=np.float32).reshape(-1, 4)

        masks_out = out["masks"]
        if isinstance(masks_out, np.ndarray) and masks_out.ndim == 2:
            masks_all = masks_out[None, ...]
        elif len(masks_out) > 0:
            masks_all = np.stack(masks_out, axis=0)
        else:
            masks_all = np.zeros((0, H, W))
        masks_all = masks_all.astype(np.uint8)

        # Clip x to the image width and y to the image height separately.
        boxes_aug[:, [0, 2]] = np.clip(boxes_aug[:, [0, 2]], 0, W)
        boxes_aug[:, [1, 3]] = np.clip(boxes_aug[:, [1, 3]], 0, H)

        # Keep only non-empty boxes that also have a mask available.
        valid = (
            (boxes_aug[:, 2] > boxes_aug[:, 0])
            & (boxes_aug[:, 3] > boxes_aug[:, 1])
            & (kept < masks_all.shape[0])
        )
        kept = kept[valid]
        boxes_aug = boxes_aug[valid]
        keep_t = torch.as_tensor(kept, dtype=torch.int64)

        target["labels"] = target["labels"][keep_t].to(torch.int64)
        target["boxes"] = torch.as_tensor(boxes_aug, dtype=torch.float32)
        target["masks"] = torch.as_tensor(masks_all[kept], dtype=torch.uint8)
        if "iscrowd" in target and target["iscrowd"].shape[0] == n_boxes:
            target["iscrowd"] = target["iscrowd"][keep_t]

        # Area is recomputed from the augmented boxes.
        extent = (target["boxes"][:, 2:] - target["boxes"][:, :2]).clamp(min=0)
        target["area"] = (extent[:, 0] * extent[:, 1]).to(torch.float32)

        return Image.fromarray(img_aug), target


class ComposeDet:
    """Chain detection transforms, each mapping ``(image, target)`` to a new pair."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, target):
        # Feed the output of every step into the next one, in list order.
        for step in self.transforms:
            result = step(image, target)
            image, target = result
        return image, target


def get_transform(train: bool, aug_prob=0.0):
    """Build the (image, target) preprocessing pipeline for one split.

    The pipeline deliberately ends at ``ToTensor`` and yields images in the
    0-1 range: torchvision's ``MaskRCNN`` applies ImageNet mean/std
    normalization itself inside ``model.transform`` (its default
    ``image_mean`` / ``image_std``), so normalizing here as well would
    normalize every image twice. Checkpoints trained with the previous
    double-normalized inputs should be retrained or evaluated accordingly.

    Args:
        train: when true, ``AlbumentationsDet`` is placed first in the
            pipeline (position 0) so its probability can be adjusted later.
        aug_prob: initial augmentation probability for the training pipeline.

    Returns:
        A ``ComposeDet`` instance.
    """
    steps = []
    if train:
        steps.append(AlbumentationsDet(p=aug_prob))
    steps.append(ToTensor())
    return ComposeDet(steps)


# ====== model ======
# Supported VGG variants (lower-case torchvision constructor names) mapped to
# the weight enum loaded when a pretrained backbone is requested.
_VGG_WEIGHTS = {
    "vgg11": VGG11_Weights.IMAGENET1K_V1,
    "vgg13": VGG13_Weights.IMAGENET1K_V1,
    "vgg16": VGG16_Weights.IMAGENET1K_V1,
    "vgg19": VGG19_Weights.IMAGENET1K_V1,
}

def _build_vgg_with_fpn(vgg_name: str,
                        pretrained: bool = True,
                        trainable_layers: int = 3,
                        out_channels: int = 256) -> nn.Module:
    """Wrap the convolutional part of a VGG network in a feature pyramid.

    The feature extractor is split into five stages, each ending at a
    max-pool layer. The outputs of the 2nd to 5th pooling layers feed the FPN
    as levels "0" to "3".

    Args:
        vgg_name: key of ``_VGG_WEIGHTS`` (case-insensitive).
        pretrained: load the weights listed in ``_VGG_WEIGHTS``.
        trainable_layers: number of stages, counted from the last one, whose
            parameters keep ``requires_grad=True``; clamped to ``[0, 5]``.
        out_channels: channel count of every FPN output.

    Returns:
        A ``BackboneWithFPN`` module.

    Raises:
        ValueError: if ``vgg_name`` is not supported.
    """
    vgg_name = vgg_name.lower()
    if vgg_name not in _VGG_WEIGHTS:
        raise ValueError(f"Unsupported VGG backbone: {vgg_name}")

    constructor = getattr(torchvision.models, vgg_name)
    vgg = constructor(weights=_VGG_WEIGHTS[vgg_name] if pretrained else None)

    pool_idx = [i for i, m in enumerate(vgg.features) if isinstance(m, nn.MaxPool2d)]

    assert len(pool_idx) == 5, f"Unexpected number of MaxPool in {vgg_name}: {pool_idx}"

    # Freeze or unfreeze stage by stage; a stage spans the layers from just
    # after the previous pooling layer up to and including its own.
    num_blocks = 5
    n_trainable = max(0, min(trainable_layers, num_blocks))
    first_trainable = num_blocks - n_trainable
    stage_start = 0
    for stage, stage_end in enumerate(pool_idx):
        trainable = stage >= first_trainable
        for layer_pos in range(stage_start, stage_end + 1):
            for param in vgg.features[layer_pos].parameters():
                param.requires_grad = trainable
        stage_start = stage_end + 1

    # Pool layers 2..5 become FPN inputs "0".."3".
    return_layers = {str(layer_pos): str(level) for level, layer_pos in enumerate(pool_idx[1:])}

    backbone = BackboneWithFPN(
        backbone=vgg.features,
        return_layers=return_layers,
        in_channels_list=[128, 256, 512, 512],
        out_channels=out_channels,
        extra_blocks=LastLevelMaxPool()
    )

    backbone.out_channels = out_channels
    return backbone

def get_mask_rcnn_vgg(vgg_name: str,
                      num_classes: int,
                      trainable_layers: int = 3,
                      pretrained_backbone: bool = True,
                      fpn_out_channels: int = 256) -> MaskRCNN:
    """Create a ``MaskRCNN`` whose backbone is a VGG network with an FPN.

    Args:
        vgg_name: VGG variant passed to ``_build_vgg_with_fpn``.
        num_classes: number of output classes given to ``MaskRCNN``.
        trainable_layers: number of trailing VGG stages left trainable.
        pretrained_backbone: whether the backbone loads pretrained weights.
        fpn_out_channels: channel count of the FPN outputs.
    """
    backbone_options = {
        "vgg_name": vgg_name,
        "pretrained": pretrained_backbone,
        "trainable_layers": trainable_layers,
        "out_channels": fpn_out_channels,
    }
    fpn_backbone = _build_vgg_with_fpn(**backbone_options)
    return MaskRCNN(fpn_backbone, num_classes=num_classes)


def forward_with_switches(model, images, targets=None):
    """Run the detector stage by stage, honouring the module debug switches.

    Steps: ``model.transform`` -> ``model.backbone`` -> ``model.rpn`` (or the
    ground-truth boxes as proposals when ``DBG_USE_GT_PROPOSALS`` is set and
    targets are given) -> ``model.roi_heads`` -> ``transform.postprocess``.
    When ``DBG_SCORE_THR`` is positive, detections scoring below it are
    removed from every per-detection tensor (boxes, scores, labels, masks,
    ...), so all fields keep the same length.

    Args:
        model: detector exposing ``transform``, ``backbone``, ``rpn`` and
            ``roi_heads``.
        images: iterable of image tensors.
        targets: optional list of target dicts.

    Returns:
        ``(detections, losses)`` where ``losses`` merges the proposal and
        detector loss dicts.
    """
    images = list(images)
    original_image_sizes = [im.shape[-2:] for im in images]
    images, targets = model.transform(images, targets)
    features = model.backbone(images.tensors)

    if DBG_USE_GT_PROPOSALS and targets is not None:
        proposals = [t["boxes"] for t in targets]
        proposal_losses = {}
    else:
        proposals, proposal_losses = model.rpn(images, features, targets)

    detections, detector_losses = model.roi_heads(
        features, proposals, images.image_sizes, targets
    )
    detections = model.transform.postprocess(detections, images.image_sizes, original_image_sizes)

    if DBG_SCORE_THR > 0.0:
        for d in detections:
            keep = d["scores"] >= DBG_SCORE_THR
            n_det = keep.shape[0]
            # Filter every tensor that has one row per detection, not just
            # boxes/scores/labels, otherwise masks fall out of sync.
            for k, v in list(d.items()):
                if torch.is_tensor(v) and v.dim() > 0 and v.shape[0] == n_det:
                    d[k] = v[keep]

    losses = {}
    losses.update(proposal_losses)
    losses.update(detector_losses)
    return detections, losses


def enhance_classification_head(model, num_classes):
    """Replace the box predictor with a deeper two-hidden-layer head.

    The new head takes the same input width as the existing predictor's
    ``cls_score`` layer and returns ``(class scores, box deltas)`` with
    ``num_classes`` and ``num_classes * 4`` outputs respectively.

    Args:
        model: detector exposing ``roi_heads.box_predictor.cls_score``.
        num_classes: number of classes the new head predicts.

    Returns:
        The same ``model`` object, modified in place.
    """

    class StrongerPredictor(nn.Module):
        """Two 1024-wide fully connected layers followed by both output layers."""

        def __init__(self, in_channels, num_classes):
            super().__init__()
            hidden = 1024
            self.fc1 = nn.Linear(in_channels, hidden)
            self.relu1 = nn.ReLU(inplace=True)
            self.fc2 = nn.Linear(hidden, hidden)
            self.relu2 = nn.ReLU(inplace=True)
            self.cls_score = nn.Linear(hidden, num_classes)
            self.bbox_pred = nn.Linear(hidden, num_classes * 4)

        def forward(self, x):
            hidden = self.relu1(self.fc1(x))
            hidden = self.relu2(self.fc2(hidden))
            return self.cls_score(hidden), self.bbox_pred(hidden)

    head_inputs = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = StrongerPredictor(head_inputs, num_classes)
    return model


def optimize_model_parameters(model):
    """Tune RPN proposal counts and RoI-head sampling/inference thresholds.

    torchvision's RPN reads its proposal limits from the dicts
    ``_pre_nms_top_n`` / ``_post_nms_top_n`` (keys ``"training"`` and
    ``"testing"``), and the RoI heads read the sampling batch size and
    positive fraction from ``roi_heads.fg_bg_sampler``. Those are therefore
    the attributes updated here; the plain attributes previously assigned
    (``pre_nms_top_n_train`` etc.) are still set for anything that inspects
    them, but on their own they do not influence the model.

    Args:
        model: detector exposing ``rpn`` and ``roi_heads``.

    Returns:
        The same ``model`` object, modified in place.
    """
    rpn = model.rpn
    rpn._pre_nms_top_n = {"training": 2000, "testing": 1000}
    rpn._post_nms_top_n = {"training": 1000, "testing": 500}
    rpn.nms_thresh = 0.7

    # Informational mirrors of the values above (kept for compatibility).
    rpn.pre_nms_top_n_train = 2000
    rpn.post_nms_top_n_train = 1000
    rpn.pre_nms_top_n_test = 1000
    rpn.post_nms_top_n_test = 500

    roi_heads = model.roi_heads
    sampler = getattr(roi_heads, "fg_bg_sampler", None)
    if sampler is not None:
        sampler.batch_size_per_image = 256
        sampler.positive_fraction = 0.5
    roi_heads.batch_size_per_image = 256
    roi_heads.positive_fraction = 0.5
    roi_heads.score_thresh = 0.05
    roi_heads.nms_thresh = 0.5
    return model


def get_optimizer_and_scheduler(model, lr=0.002):
    """Create the SGD optimizer and its step-wise learning-rate schedule.

    Only parameters with ``requires_grad=True`` are optimized. The learning
    rate is halved at scheduler steps 15, 30 and 45.

    Args:
        model: module whose parameters are optimized.
        lr: initial learning rate.

    Returns:
        ``(optimizer, scheduler)``.
    """
    trainable = list(filter(lambda param: param.requires_grad, model.parameters()))
    sgd = optim.SGD(trainable, lr=lr, momentum=0.9, weight_decay=1e-4)
    schedule = lr_scheduler.MultiStepLR(sgd, milestones=[15, 30, 45], gamma=0.5)
    return sgd, schedule


def train_one_epoch(model, optimizer, data_loader, device, scaler, accum_steps: int = 2):
    """Train for one pass over ``data_loader`` with gradient accumulation.

    Each batch's loss is divided by ``accum_steps`` and back-propagated; the
    optimizer steps every ``accum_steps`` batches. If the number of batches is
    not a multiple of ``accum_steps``, the gradients of the final partial
    window are applied after the loop instead of being thrown away.

    Args:
        model: detector to train (put into training mode here).
        optimizer: optimizer stepped through ``scaler``.
        data_loader: iterable of ``(images, targets)`` batches.
        device: ``torch.device`` the batch is moved to; autocast is enabled
            only when its type is ``"cuda"``.
        scaler: gradient scaler used for backward/step/update.
        accum_steps: number of batches per optimizer step.

    Returns:
        Dict with the per-batch mean of the total loss (``"loss"``) and of the
        individual terms (``"cls"``, ``"box_reg"``, ``"mask"``, ``"obj"``,
        ``"rpn_box"``).
    """
    model.train()
    running = {"loss": 0.0, "cls": 0.0, "box_reg": 0.0, "mask": 0.0, "obj": 0.0, "rpn_box": 0.0}
    n = 0
    pending_grads = False
    optimizer.zero_grad(set_to_none=True)

    pbar = tqdm(data_loader, desc="Train", dynamic_ncols=True, leave=False, position=0)
    use_amp = (device.type == "cuda")
    for step, (images, targets) in enumerate(pbar, 1):
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        cool_if_hot(threshold=83, resume=78, sleep_s=10, gpu_index=0)

        with torch.cuda.amp.autocast(enabled=use_amp):
            _, loss_dict = forward_with_switches(model, images, targets)

            losses = (
                loss_dict.get("loss_classifier",    0.0) +
                loss_dict.get("loss_box_reg",       0.0) +
                loss_dict.get("loss_mask",          0.0) +
                loss_dict.get("loss_objectness",    0.0) +
                loss_dict.get("loss_rpn_box_reg",   0.0)
            )

        # Logged total covers every entry the model returned.
        loss_value = sum(
            (float(v.item()) if torch.is_tensor(v) else float(v))
            for v in loss_dict.values()
        )
        running["loss"] += loss_value

        loss_scaled = losses / accum_steps
        scaler.scale(loss_scaled).backward()
        pending_grads = True
        if step % accum_steps == 0:
            scaler.step(optimizer); scaler.update()
            optimizer.zero_grad(set_to_none=True)
            pending_grads = False
            cool_if_hot(threshold=83, resume=78, sleep_s=10, gpu_index=0)

        n += 1
        running["cls"]     += loss_dict.get("loss_classifier",    torch.tensor(0.0, device=device)).item()
        running["box_reg"] += loss_dict.get("loss_box_reg",       torch.tensor(0.0, device=device)).item()
        running["mask"]    += loss_dict.get("loss_mask",          torch.tensor(0.0, device=device)).item()
        running["obj"]     += loss_dict.get("loss_objectness",    torch.tensor(0.0, device=device)).item()
        running["rpn_box"] += loss_dict.get("loss_rpn_box_reg",   torch.tensor(0.0, device=device)).item()

        pbar.set_postfix(loss=f"{running['loss']/n:.4f}")

    # Apply gradients left over from an incomplete accumulation window.
    if pending_grads:
        scaler.step(optimizer); scaler.update()
        optimizer.zero_grad(set_to_none=True)

    for k in running: running[k] /= max(1, n)
    return running


@torch.inference_mode()
def evaluate_losses(model, loader, device):
    """Average the summed detection losses over ``loader`` without gradients.

    The model is switched to training mode for the duration of the call
    (``forward_with_switches`` is run with targets to obtain the loss dict)
    and put back into its previous mode afterwards.

    Args:
        model: detector to evaluate.
        loader: iterable of ``(images, targets)`` batches.
        device: device the batch is moved to.

    Returns:
        ``{"loss": mean of the per-batch loss sums}`` (``0.0`` for an empty
        loader).
    """
    loss_keys = ("loss_classifier", "loss_box_reg", "loss_mask", "loss_objectness", "loss_rpn_box_reg")

    previous_mode = model.training
    model.train(True)

    total = 0.0
    batches = 0
    for images, targets in loader:
        images = [img.to(device) for img in images]
        targets = [{name: value.to(device) for name, value in t.items()} for t in targets]
        _, loss_dict = forward_with_switches(model, images, targets)

        batch_loss = 0.0
        for key in loss_keys:
            term = loss_dict.get(key, 0.0)
            if isinstance(term, (float, int)):
                batch_loss += float(term)
            else:
                batch_loss += float(term.item())
        total += batch_loss
        batches += 1

    model.train(previous_mode)
    return {"loss": total / max(batches, 1)}


@torch.inference_mode()
def evaluate_fg_top1(model, data_loader, device,
                     iou_thr=0.5, score_thr=0.3, max_batches=None, desc="FG Top-1"):
    """Label accuracy of detections that were matched to a ground-truth box.

    For every image, detections scoring at least ``score_thr`` (at most the
    300 best) are visited in descending score order. Each is matched to the
    still-unmatched ground-truth box with the highest IoU, provided that IoU
    reaches ``iou_thr``; every ground-truth box is matched at most once.

    Args:
        model: detector; evaluated in eval mode, previous mode restored.
        data_loader: iterable of ``(images, targets)`` batches.
        device: device the images are moved to.
        iou_thr: minimum IoU for a match.
        score_thr: minimum detection score.
        max_batches: optional cap on the number of batches processed.
        desc: progress-bar label.

    Returns:
        ``matched_correct / matched_total``, or ``0.0`` when nothing matched.
    """
    from torchvision.ops import box_iou

    was_training = model.training
    model.eval()
    on_cuda = torch.device(device).type == "cuda"

    matched_total = 0
    matched_correct = 0
    TOPK = 300

    for seen, (images, targets) in enumerate(tqdm(data_loader, desc=desc, dynamic_ncols=True), 1):
        if (max_batches is not None) and (seen > max_batches):
            break

        images = [img.to(device) for img in images]
        cool_if_hot(threshold=83, resume=78, sleep_s=10, gpu_index=0)
        outs = model(images)
        if on_cuda:
            # CUDA-only housekeeping; calling these on a CPU run would fail.
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

        for out, gt in zip(outs, targets):
            if gt["boxes"].numel() == 0 or out["boxes"].numel() == 0:
                continue

            scores = out["scores"].detach().cpu()
            keep = scores >= score_thr
            if keep.sum().item() == 0:
                continue

            boxes_p  = out["boxes"].detach().cpu()[keep]
            labels_p = out["labels"].detach().cpu()[keep]
            scores_p = scores[keep]

            if scores_p.numel() > TOPK:
                vals, idx = torch.topk(scores_p, TOPK)
                boxes_p, labels_p, scores_p = boxes_p[idx], labels_p[idx], vals

            boxes_g  = gt["boxes"].detach().cpu()
            labels_g = gt["labels"].detach().cpu()
            ious = box_iou(boxes_p, boxes_g)
            n_gt = ious.size(1)

            order = torch.argsort(scores_p, descending=True)
            used_g = set()
            for pi in order.tolist():
                if len(used_g) >= n_gt:
                    break
                cand_iou = ious[pi].clone()
                if used_g:
                    cand_iou[list(used_g)] = -1.0
                gi = int(torch.argmax(cand_iou).item())
                # Read the IoU from the masked row so an already-used
                # ground-truth box can never be matched a second time.
                best_iou = float(cand_iou[gi].item())
                if best_iou < iou_thr:
                    continue
                used_g.add(gi)
                matched_total += 1
                if int(labels_p[pi]) == int(labels_g[gi]):
                    matched_correct += 1

    acc = (matched_correct / matched_total) if matched_total > 0 else 0.0

    if was_training:
        model.train()
    return acc


@torch.inference_mode()
def evaluate_box_head_top1_acc(model, loader, device, desc="BoxHead Top-1"):
    """Classification accuracy of the box head on ground-truth boxes.

    The (transformed) ground-truth boxes are used as proposals, pooled and
    passed through the box head and predictor. The predicted class is the
    arg-max over the non-background logits. The model is put into eval mode
    and left there.

    Args:
        model: detector exposing ``transform``, ``backbone`` and ``roi_heads``.
        loader: iterable of ``(images, targets)`` batches.
        device: device the batch is moved to.
        desc: progress-bar label.

    Returns:
        Fraction of ground-truth boxes classified correctly, ``0.0`` if there
        were none.
    """
    model.eval()
    n_boxes = 0
    n_right = 0

    for images, targets in tqdm(loader, desc=desc, leave=False):
        images = [img.to(device) for img in images]
        targets = [{name: value.to(device) for name, value in t.items()} for t in targets]

        img_list, targets_t = model.transform(images, targets)
        features = model.backbone(img_list.tensors)

        proposals = [t["boxes"] for t in targets_t]
        if not any(p.numel() != 0 for p in proposals):
            continue

        pooled = model.roi_heads.box_roi_pool(features, proposals, img_list.image_sizes)
        if pooled.numel() == 0:
            continue
        class_logits, _ = model.roi_heads.box_predictor(model.roi_heads.box_head(pooled))

        gt_labels = torch.cat([t["labels"] for t in targets_t if t["boxes"].numel() > 0], dim=0)

        # Column 0 is skipped, hence the +1 to get back to label ids.
        predicted = class_logits[:, 1:].argmax(dim=1) + 1

        n_right += (predicted == gt_labels).sum().item()
        n_boxes += gt_labels.numel()

    return 0.0 if n_boxes == 0 else n_right / n_boxes


def diagnose_classification_issue(model, val_loader, device, max_batches=10):
    """Print and return the box-head top-1 accuracy on a few validation batches.

    Args:
        model: detector to diagnose.
        val_loader: iterable of ``(images, targets)`` batches.
        device: device used for evaluation.
        max_batches: number of batches to evaluate; ``None`` evaluates the
            whole loader.

    Returns:
        The accuracy reported by ``evaluate_box_head_top1_acc``.
    """
    from itertools import islice

    if max_batches is None:
        batches = val_loader
    else:
        # Honour the cap by only handing the first batches to the evaluator.
        batches = islice(val_loader, max(0, int(max_batches)))

    acc = evaluate_box_head_top1_acc(model, batches, device, desc="BoxHead Top-1 (diagnose)")
    print(f"Box-head top1: {acc:.3f}")
    return acc


def _save_curve(series, title, out_path):
    """Plot labelled series on a fresh figure, save it and close the figure."""
    fig = plt.figure()
    for label, values in series:
        plt.plot(values, label=label)
    plt.title(title); plt.legend()
    plt.savefig(out_path, dpi=200, bbox_inches="tight")
    # Closing releases the figure; otherwise every call keeps one alive.
    plt.close(fig)


def fit(model, optimizer, lr_sch, loaders, device,
        epochs=10, accum_steps=2, save_path="mask_rcnn_wbc_best.pth",
        logs_dir: Path | None = None):
    """Train ``model`` and log losses/accuracies per epoch.

    Per epoch: set the augmentation probability (0.0 for the first 10 epochs,
    0.3 for the next 10, 0.5 afterwards), train, compute the validation loss
    and both accuracy metrics on the train and val loaders, step the LR
    scheduler, append a row to ``history.csv`` and save the weights whenever
    the validation loss improves. Three curves are written as PNG files at
    the end.

    Args:
        model: detector to train.
        optimizer: optimizer used by ``train_one_epoch``.
        lr_sch: scheduler stepped once per epoch.
        loaders: dict with ``"train"`` and ``"val"`` data loaders.
        device: training device.
        epochs: number of epochs.
        accum_steps: gradient-accumulation steps.
        save_path: where the best ``state_dict`` is stored.
        logs_dir: output directory for the CSV and plots (default ``logs``).

    Returns:
        The trained ``model``.
    """
    if logs_dir is None:
        logs_dir = Path("logs")
    logs_dir.mkdir(parents=True, exist_ok=True)

    scaler = GradScaler(enabled=(device.type == "cuda"))
    best_val = float("inf")

    hist = {"train_loss": [], "val_loss": [], "train_fg_top1": [], "val_fg_top1": [], "train_top1": [], "val_top1": []}

    csv_path = logs_dir / "history.csv"
    with open(csv_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["epoch","train_loss","val_loss","train_boxhead_top1","val_boxhead_top1","train_fg_top1","val_fg_top1"])

    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}\n" + "-"*12)

        if epoch < 10:
            aug_p = 0.0
        elif epoch < 20:
            aug_p = 0.3
        else:
            aug_p = 0.5

        # The augmentation transform is expected at position 0 of the
        # training pipeline; its probability is updated in place.
        train_loader = loaders["train"]
        if hasattr(train_loader.dataset.transforms.transforms[0], "p"):
            train_loader.dataset.transforms.transforms[0].p = aug_p
        print(f"[Epoch {epoch+1}] Data Augmentation Probability = {aug_p}")

        train_logs = train_one_epoch(model, optimizer, loaders["train"], device, scaler, accum_steps)
        val_logs   = evaluate_losses(model, loaders["val"], device)

        train_fg_top1 = evaluate_fg_top1(model, loaders["train"], device, iou_thr=0.5, score_thr=0.8, desc="Train(FG Top1)")
        val_fg_top1   = evaluate_fg_top1(model, loaders["val"],   device, iou_thr=0.5, score_thr=0.8, desc="Val(FG Top1)")

        train_top1 = evaluate_box_head_top1_acc(model, loaders["train"], device, desc="Train(BoxHead Top1)")
        val_top1   = evaluate_box_head_top1_acc(model, loaders["val"],   device, desc="Val(BoxHead Top1)")

        lr_sch.step()

        hist["train_loss"].append(train_logs["loss"])
        hist["val_loss"].append(val_logs["loss"])
        hist["train_fg_top1"].append(train_fg_top1)
        hist["val_fg_top1"].append(val_fg_top1)
        hist["train_top1"].append(train_top1)
        hist["val_top1"].append(val_top1)

        print(f"[Epoch {epoch+1}] loss: {train_logs['loss']:.4f} (train)  {val_logs['loss']:.4f} (val)")
        print(f"[Epoch {epoch+1}] box-head top1: {train_top1:.3f} (train)  {val_top1:.3f} (val)")
        print(f"[Epoch {epoch+1}] FG top-1:      {train_fg_top1:.3f} (train)  {val_fg_top1:.3f} (val)")

        with open(csv_path, "a", newline="") as f:
            writer = csv.writer(f)
            writer.writerow([epoch+1, train_logs["loss"], val_logs["loss"], train_top1, val_top1, train_fg_top1, val_fg_top1])

        if val_logs["loss"] < best_val:
            best_val = val_logs["loss"]
            torch.save(model.state_dict(), save_path)
            print(f"Saved best model to {save_path}")

        torch.cuda.empty_cache(); gc.collect(); print()

    _save_curve(
        [("train", hist["train_loss"]), ("val", hist["val_loss"])],
        "Total Loss", logs_dir / "loss_curve.png",
    )
    _save_curve(
        [("train top-1", hist["train_top1"]), ("val top-1", hist["val_top1"])],
        "Box-Head Top-1", logs_dir / "box_head_top1_curve.png",
    )
    _save_curve(
        [("train fg top-1", hist["train_fg_top1"]), ("val fg top-1", hist["val_fg_top1"])],
        "FG Top-1", logs_dir / "fg_top1_curve.png",
    )

    return model


def main():
    """Build datasets, loaders and the VGG13 Mask R-CNN, then run training.

    Data is read from the ``train`` and ``val`` sub-folders of a fixed Google
    Drive directory; logs and the best checkpoint go to a time-stamped folder
    under ``/content/drive/MyDrive/maskrcnn_logs``.
    """
    set_torch_threads(4)

    data_dir = "/content/drive/MyDrive/Thesis/MaskRCNN/TestData_new"
    current_aug_p = 0.0

    train_ds = CocoInstanceDataset(
        os.path.join(data_dir, "train"),
        transforms=get_transform(train=True, aug_prob=current_aug_p),
    )
    val_ds = CocoInstanceDataset(
        os.path.join(data_dir, "val"),
        transforms=get_transform(train=False),
    )

    num_classes = len(train_ds.classes)
    print("Classes:", train_ds.classes, " -> num_classes =", num_classes)

    # Options shared by both loaders; only the dataset and shuffling differ.
    shared = dict(batch_size=2, num_workers=2, collate_fn=collate_fn)
    loaders = {
        "train": DataLoader(train_ds, shuffle=True, pin_memory=torch.cuda.is_available(), **shared),
        "val": DataLoader(val_ds, shuffle=False, pin_memory=torch.cuda.is_available(), **shared),
    }

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print("Device:", device)

    model = get_mask_rcnn_vgg("vgg13", num_classes=num_classes, trainable_layers=5)
    model = optimize_model_parameters(enhance_classification_head(model, num_classes))
    model.to(device)

    optimizer, lr_sch = get_optimizer_and_scheduler(model, lr=0.003)

    from datetime import datetime
    timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
    logs_dir = Path("/content/drive/MyDrive/maskrcnn_logs") / f"{timestamp}_maskrcnn_VGG13"
    logs_dir.mkdir(parents=True, exist_ok=True)
    save_path = logs_dir / "mask_rcnn_VGG13_best.pth"
    print(f"[Logger] logs dir: {logs_dir}")

    # Baseline box-head accuracy before any training.
    diagnose_classification_issue(model, loaders["val"], device)

    model = fit(
        model, optimizer, lr_sch, loaders, device,
        epochs=50,
        accum_steps=4,
        save_path=save_path,
        logs_dir=logs_dir,
    )

    print("Training done.")


# Run the training pipeline only when the file is executed as a script.
if __name__ == "__main__":
    main()


Writing mask_rcnn_vgg13_wbc_colab.py
