In [None]:
# === Cell 1: Installs ===
!pip install torch torchvision albumentations scikit-learn pandas tqdm thop torchinfo matplotlib seaborn opencv-python --quiet

In [None]:
# === Cell 2: Imports, device, seed, helpers ===
import os, random, time, math, copy
from glob import glob
import numpy as np
import pandas as pd
from PIL import Image
from tqdm import tqdm
import cv2
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision.models.detection import MaskRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.ops import MultiScaleRoIAlign
from torchvision.ops import FeaturePyramidNetwork

import matplotlib.pyplot as plt
import seaborn as sns

from thop import profile
from torchinfo import summary

# Device & reproducibility
SEED = 42
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed Python, NumPy and PyTorch with the same value.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

if DEVICE.type == 'cuda':
    # Seed every visible GPU and ask cuDNN for deterministic kernels.
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

print("Device:", DEVICE)

# EarlyStopping (same logic as your UNet++)
class EarlyStopping:
    """Stop training once the monitored loss has stopped improving.

    Args:
        patience: number of consecutive non-improving calls tolerated before stopping.
        min_delta: the loss must drop by more than this amount to count as an improvement.
        restore_best: if True, keep a CPU copy of the best ``state_dict`` and load it
            back into the model at the moment stopping is triggered.
        min_epochs: stopping is never signalled before this many epochs have run.
    """

    def __init__(self, patience=10, min_delta=1e-4, restore_best=True, min_epochs=10):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best = restore_best
        self.min_epochs = min_epochs
        # Running state, updated on every call.
        self.best_loss = None
        self.best_state = None
        self.counter = 0

    def __call__(self, epoch, current_loss, model):
        """Record ``current_loss`` for the 0-based ``epoch``; return True when training should stop."""
        if self.best_loss is None or (self.best_loss - current_loss) > self.min_delta:
            self.best_loss = current_loss
            self.counter = 0
            if self.restore_best:
                # Snapshot on CPU so the copy does not occupy device memory.
                snapshot = {}
                for key, value in model.state_dict().items():
                    snapshot[key] = value.detach().cpu().clone()
                self.best_state = snapshot
        else:
            self.counter += 1

        still_warming_up = (epoch + 1) < self.min_epochs
        if still_warming_up or self.counter < self.patience:
            return False

        if self.restore_best and self.best_state is not None:
            model.load_state_dict(self.best_state)
        return True

# Metric helpers (union predicted instances -> semantic mask; compute IoU/Dice/...)
@torch.no_grad()
def _to_binary_semantic_from_instances(outputs, image_size, score_thresh=0.5, mask_thresh=0.3):
    """Collapse one image's instance predictions into a single binary mask.

    Args:
        outputs: prediction dict with ``"scores"`` and ``"masks"`` (masks indexed as
            (K, 1, H, W)), or None.
        image_size: (H, W) used for the all-zero mask returned when nothing is kept.
        score_thresh: instances scoring below this are discarded.
        mask_thresh: per-pixel threshold applied to the kept masks.

    Returns:
        uint8 tensor that is 1 wherever at least one kept instance is on.
    """
    if outputs is None or len(outputs.get("scores", [])) == 0:
        return torch.zeros(image_size, dtype=torch.uint8)
    keep = outputs["scores"] >= score_thresh
    if keep.sum() == 0:
        return torch.zeros(image_size, dtype=torch.uint8)
    masks = outputs["masks"][keep]  # (K,1,H,W)
    bin_inst = masks.squeeze(1) >= mask_thresh
    return torch.any(bin_inst, dim=0).to(torch.uint8)


@torch.no_grad()
def _target_to_binary_semantic(target, outputs):
    """Turn one ground-truth entry into a binary (H, W) uint8 mask.

    Returns None when the target type is unsupported or its spatial size cannot be
    determined; the caller skips such samples.
    """
    if isinstance(target, dict) and "masks" in target:
        gt_masks = target["masks"]
        if len(gt_masks) > 0:
            return torch.any(gt_masks.to(torch.uint8).bool(), dim=0).to(torch.uint8)
        # No ground-truth instances. An empty (0, H, W) tensor still carries the image
        # size, so the sample is scored as all-background instead of being dropped.
        if isinstance(gt_masks, torch.Tensor) and gt_masks.ndim == 3:
            return torch.zeros(gt_masks.shape[-2:], dtype=torch.uint8, device=gt_masks.device)
        # Otherwise fall back to the size of the predicted masks, if there are any.
        if outputs is not None and "masks" in outputs and outputs["masks"].numel() > 0:
            return torch.zeros_like(outputs["masks"][0, 0]).to(torch.uint8)
        return None
    if isinstance(target, torch.Tensor):
        t = target[0] if (target.ndim == 3 and target.shape[0] == 1) else target
        return (t > 0).to(torch.uint8)
    return None


@torch.no_grad()
def batch_metrics_from_outputs(outputs_list, targets_list, score_thresh=0.5, mask_thresh=0.3):
    """Pixel-level binary segmentation metrics over a list of images.

    Predicted instances are merged into one semantic mask per image and compared
    with the merged ground truth. Confusion counts are summed over all images and
    the metrics are computed from the totals.

    Args:
        outputs_list: per-image prediction dicts (``"scores"``, ``"masks"``).
        targets_list: per-image targets, either dicts with ``"masks"`` or plain mask tensors.
        score_thresh: minimum score for a predicted instance to be used.
        mask_thresh: threshold applied to predicted mask values.

    Returns:
        dict with integer ``tp``/``fp``/``fn``/``tn`` and float ``iou``, ``dice``,
        ``precision``, ``recall``, ``f1``, ``acc``.
    """
    eps = 1e-8
    tp = fp = fn = tn = 0
    for outputs, target in zip(outputs_list, targets_list):
        tgt = _target_to_binary_semantic(target, outputs)
        if tgt is None:
            continue
        H, W = tgt.shape[-2], tgt.shape[-1]
        pred = _to_binary_semantic_from_instances(outputs, (H, W), score_thresh, mask_thresh)
        p = pred.reshape(-1).cpu().bool()
        g = tgt.reshape(-1).cpu().bool()
        tp += int((p & g).sum())
        fp += int((p & ~g).sum())
        fn += int((~p & g).sum())
        tn += int((~p & ~g).sum())

    # eps keeps the ratios finite when a denominator is zero.
    iou = tp / (tp + fp + fn + eps)
    dice = (2*tp) / (2*tp + fp + fn + eps)
    precision = tp / (tp + fp + eps) if (tp+fp)>0 else 0.0
    recall = tp / (tp + fn + eps) if (tp+fn)>0 else 0.0
    f1 = 2*precision*recall/(precision+recall+eps) if (precision+recall)>0 else 0.0
    acc = (tp + tn) / (tp + tn + fp + fn + eps)
    return dict(tp=int(tp), fp=int(fp), fn=int(fn), tn=int(tn),
                iou=float(iou), dice=float(dice),
                precision=float(precision), recall=float(recall),
                f1=float(f1), acc=float(acc))

In [None]:
# === Cell 3: Paths, params, ImageNet normalization ===
# Dataset root. Defaults to the Kaggle input path; set the EARTHQUAKE_DATA_ROOT
# environment variable to point elsewhere without editing the notebook.
DATA_ROOT = os.environ.get('EARTHQUAKE_DATA_ROOT', '/kaggle/input/earthquakedatasetnew/earthquakeDataset')
IMG_SIZE = (256,256)   # (W,H)
TRAIN_BATCH = 6
VAL_BATCH = 2
TEST_BATCH = 1
NUM_CLASSES = 2  # background + damage

# Input mode: we'll use post-only RGB (ImageNet normalization)
INPUT_MODE = 'post'

IMAGE_NET_MEAN = [0.485, 0.456, 0.406]
IMAGE_NET_STD  = [0.229, 0.224, 0.225]

def pil_to_tensor_normalized(img_pil):
    """Convert an H x W x 3 image into a normalized float32 tensor of shape (3, H, W).

    Pixel values are scaled to [0, 1], then each channel is standardized with
    ``IMAGE_NET_MEAN`` / ``IMAGE_NET_STD``.
    """
    scaled = np.asarray(img_pil).astype(np.float32) / 255.0
    channel_mean = np.array(IMAGE_NET_MEAN)
    channel_std = np.array(IMAGE_NET_STD)
    standardized = (scaled - channel_mean) / channel_std
    # HWC -> CHW, the layout the model consumes.
    chw = standardized.transpose(2, 0, 1)
    return torch.from_numpy(chw).float()

In [None]:
# === Cell 4: Dataset that converts semantic mask -> instances (connected components) ===
from scipy import ndimage

class MaskRCNNDataset(Dataset):
    """Dataset yielding Mask R-CNN style ``(image, target)`` pairs.

    Each sample is built from three PNG files sharing one name: an ``A_*`` image, a
    ``B_*`` image and a ``label_*`` semantic mask. The mask is binarized (> 127) and
    split into connected components; every component becomes one instance with label 1.

    Args:
        root: dataset root containing the ``train`` / ``val`` / ``test`` folders.
        split: ``'train'`` or ``'val'``; any other value selects the test folders.
        input_mode: ``'post'`` (normalized B image) or ``'diff'`` (B red, B green and
            mean absolute A/B difference as the three channels).
        img_size: target size as (W, H), the order ``PIL.Image.resize`` expects.
    """

    # split -> (split folder, A dir, B dir, label dir)
    _SPLIT_DIRS = {
        'train': ('train', 'A_train_aug', 'B_train_aug', 'label_train_aug'),
        'val': ('val', 'A_val', 'B_val', 'label_val'),
        'test': ('test', 'A_test', 'B_test', 'label_test'),
    }

    def __init__(self, root, split='train', input_mode='post', img_size=(256,256)):
        self.input_mode = input_mode
        self.img_size = img_size
        # Unknown split names fall back to the test folders.
        folder, a_name, b_name, m_name = self._SPLIT_DIRS.get(split, self._SPLIT_DIRS['test'])
        self.a_dir = os.path.join(root, folder, a_name)
        self.b_dir = os.path.join(root, folder, b_name)
        self.m_dir = os.path.join(root, folder, m_name)
        # File names are taken from the A directory; B and label files are looked up by the same name.
        self.a_files = sorted(f for f in os.listdir(self.a_dir) if f.endswith('.png'))

    def __len__(self):
        return len(self.a_files)

    @staticmethod
    def _open(path, mode):
        """Load an image fully into memory in the given mode and close the file handle."""
        with Image.open(path) as im:
            return im.convert(mode)

    def _make_input(self, a_pil, b_pil):
        """Build the 3-channel float tensor fed to the model from the A/B image pair."""
        b = b_pil.resize(self.img_size, resample=Image.BILINEAR)
        if self.input_mode == 'post':
            return pil_to_tensor_normalized(b)
        if self.input_mode == 'diff':
            a = a_pil.resize(self.img_size, resample=Image.BILINEAR)
            a_np = np.array(a).astype(np.float32)/255.0
            b_np = np.array(b).astype(np.float32)/255.0
            diff = np.abs(b_np - a_np)
            # three channels: post_R, post_G, avg_abs_diff
            ch0 = b_np[:,:,0]; ch1 = b_np[:,:,1]; ch2 = diff.mean(axis=2)
            stacked = np.stack([ch0, ch1, ch2], axis=2)
            # NOTE(review): the difference channel reuses the red-channel mean/std
            # (0.485 / 0.229); confirm this is intended.
            stacked = (stacked - np.array([0.485,0.456,0.485])) / np.array([0.229,0.224,0.229])
            return torch.from_numpy(stacked.transpose(2,0,1)).float()
        raise ValueError("Unsupported input_mode")

    @staticmethod
    def _instances_from_binary(mask_bin):
        """Split a binary (H, W) array into per-component masks and boxes.

        Returns:
            (masks, boxes): ``masks`` is a list of (H, W) uint8 arrays, ``boxes`` a list
            of ``[x1, y1, x2, y2]``. ``x2`` / ``y2`` are exclusive (max index + 1) so that
            even a one-pixel component has positive width and height; zero-size boxes
            are rejected by torchvision's detection models during training.
        """
        labeled, ncomp = ndimage.label(mask_bin)
        masks, boxes = [], []
        for inst_id in range(1, ncomp + 1):
            inst_mask = (labeled == inst_id).astype(np.uint8)
            ys, xs = np.nonzero(inst_mask)
            if ys.size == 0:
                continue
            boxes.append([int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1])
            masks.append(inst_mask)
        return masks, boxes

    def _masks_to_instances(self, mask_pil):
        """Resize and binarize the label image, then extract connected-component instances."""
        # NEAREST keeps the label values intact (no interpolated grey levels).
        mask_np = np.array(mask_pil.resize(self.img_size, resample=Image.NEAREST))
        mask_bin = (mask_np > 127).astype(np.uint8)
        return self._instances_from_binary(mask_bin)

    def __getitem__(self, idx):
        """Return ``(image, target)`` where target holds boxes, labels, masks, area, iscrowd, image_id."""
        name = self.a_files[idx]
        a_pil = self._open(os.path.join(self.a_dir, name), 'RGB')
        b_pil = self._open(os.path.join(self.b_dir, name), 'RGB')
        m_pil = self._open(os.path.join(self.m_dir, name), 'L')
        image = self._make_input(a_pil, b_pil)  # 3,H,W

        masks, boxes = self._masks_to_instances(m_pil)
        target = {}
        if len(masks) == 0:
            # Image without any instance: empty tensors with the expected trailing shapes.
            target['boxes'] = torch.zeros((0,4), dtype=torch.float32)
            target['labels'] = torch.zeros((0,), dtype=torch.int64)
            target['masks'] = torch.zeros((0, self.img_size[1], self.img_size[0]), dtype=torch.uint8)
            target['area'] = torch.tensor([], dtype=torch.float32)
            target['iscrowd'] = torch.tensor([], dtype=torch.int64)
        else:
            masks_t = torch.as_tensor(np.stack(masks, axis=0), dtype=torch.uint8)  # (N,H,W)
            boxes_t = torch.as_tensor(boxes, dtype=torch.float32)  # (N,4)
            n_inst = masks_t.shape[0]
            target['boxes'] = boxes_t
            target['labels'] = torch.ones((n_inst,), dtype=torch.int64)
            target['masks'] = masks_t
            # Box area (not mask area), as before.
            target['area'] = ((boxes_t[:,3] - boxes_t[:,1]) * (boxes_t[:,2] - boxes_t[:,0])).to(torch.float32)
            target['iscrowd'] = torch.zeros((n_inst,), dtype=torch.int64)

        target['image_id'] = torch.tensor([idx])
        return image, target

def collate_fn(batch):
    """Regroup a list of ``(image, target)`` samples into ``(images, targets)`` lists.

    Detection samples have a variable number of instances, so they are kept as
    Python lists rather than stacked into one tensor.
    """
    images, targets = zip(*batch)
    return [*images], [*targets]

# build datasets + loaders
_dataset_kwargs = dict(input_mode=INPUT_MODE, img_size=IMG_SIZE)
train_ds = MaskRCNNDataset(DATA_ROOT, 'train', **_dataset_kwargs)
val_ds = MaskRCNNDataset(DATA_ROOT, 'val', **_dataset_kwargs)
test_ds = MaskRCNNDataset(DATA_ROOT, 'test', **_dataset_kwargs)

# Only the training loader shuffles and pins memory; all three use the list-based collate_fn.
train_loader = DataLoader(
    train_ds,
    batch_size=TRAIN_BATCH,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_ds,
    batch_size=VAL_BATCH,
    shuffle=False,
    num_workers=1,
    collate_fn=collate_fn,
)
test_loader = DataLoader(
    test_ds,
    batch_size=TEST_BATCH,
    shuffle=False,
    num_workers=1,
    collate_fn=collate_fn,
)

print(f"Train {len(train_ds)} | Val {len(val_ds)} | Test {len(test_ds)}")

In [None]:
# === Cell 5: Custom CNN backbone + FPN wrapper (no ResNet) ===
# We'll make a small ResNet-like stack producing 3 feature maps (C3, C4, C5),
# then pass them through torchvision.ops.FeaturePyramidNetwork to produce FPN outputs.

class ConvBlockBN(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> ReLU -> optional Dropout2d.

    The convolution weights are Kaiming-normal initialized (fan_out, relu).
    With ``dropout == 0`` the dropout stage is an ``nn.Identity``.
    """

    def __init__(self, in_ch, out_ch, kernel=3, stride=1, padding=1, dropout=0.0):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel, stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU(inplace=True)
        if dropout > 0:
            self.dropout = nn.Dropout2d(dropout)
        else:
            self.dropout = nn.Identity()
        nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        """Apply the four stages in order and return the result."""
        return self.dropout(self.relu(self.bn(self.conv(x))))

class CustomBackboneWithFPN(nn.Module):
    """Three-stage convolutional backbone followed by a FeaturePyramidNetwork.

    The stage outputs are taken at full, 1/2 and 1/4 of the input resolution
    (two 2x max-pools) and passed to the FPN under the keys ``'0'``, ``'1'``, ``'2'``.
    Dropout is then applied to every FPN output map.

    Args:
        in_channels: channels of the input image.
        base_filters: output channels of stage 1, 2 and 3 (indices 0..2 are used).
        fpn_out_channels: channels of every FPN output; exposed as ``out_channels``.
        dropout_map: dropout probabilities keyed by ``'c1'``, ``'c2'``, ``'c3'``, ``'fpn'``;
            missing keys mean no dropout.
    """

    def __init__(self, in_channels=3, base_filters=(64,128,256), fpn_out_channels=256, dropout_map=None):
        super().__init__()
        if dropout_map is None:
            dropout_map = {'c1':0.15, 'c2':0.25, 'c3':0.35, 'fpn':0.25}

        def double_conv(cin, cout, key):
            # Two ConvBlockBN layers sharing the dropout rate configured for this stage.
            rate = dropout_map.get(key, 0.0)
            return nn.Sequential(
                ConvBlockBN(cin, cout, dropout=rate),
                ConvBlockBN(cout, cout, dropout=rate),
            )

        width1, width2, width3 = base_filters[0], base_filters[1], base_filters[2]

        # Bottom-up pathway; layers are created in the same order as they are applied.
        self.stage1 = double_conv(in_channels, width1, 'c1')
        self.pool1 = nn.MaxPool2d(2)  # /2
        self.stage2 = double_conv(width1, width2, 'c2')
        self.pool2 = nn.MaxPool2d(2)  # /4
        self.stage3 = double_conv(width2, width3, 'c3')

        # Top-down pathway over the three stage outputs.
        self.fpn = FeaturePyramidNetwork(
            in_channels_list=[width1, width2, width3],
            out_channels=fpn_out_channels,
        )
        fpn_rate = dropout_map.get('fpn', 0.0)
        self.fpn_dropout = nn.Dropout2d(fpn_rate) if fpn_rate > 0 else nn.Identity()
        self.out_channels = fpn_out_channels

    def forward(self, x):
        """Return an OrderedDict of FPN feature maps keyed '0', '1', '2' (plus whatever the FPN adds)."""
        level0 = self.stage1(x)                   # (B, C1, H, W)
        level1 = self.stage2(self.pool1(level0))  # (B, C2, H/2, W/2)
        level2 = self.stage3(self.pool2(level1))  # (B, C3, H/4, W/4)
        pyramid = self.fpn(OrderedDict([('0', level0), ('1', level1), ('2', level2)]))
        # Dropout on every FPN output, keeping key order.
        return OrderedDict((name, self.fpn_dropout(fmap)) for name, fmap in pyramid.items())

# Dropout schedule carried over from the UNet++ setup (kept for reference),
# and the per-level rates actually passed to the backbone.
very_aggressive_dropout = dict(
    encoder=[0.0, 0.15, 0.25, 0.4, 0.6],
    decoder=[0.6, 0.5, 0.4, 0.3],
    skip=0.4,
    final=0.4,
)
drop_map = dict(c1=0.15, c2=0.25, c3=0.4, fpn=0.35)

backbone = CustomBackboneWithFPN(
    in_channels=3,
    base_filters=(64,128,256),
    fpn_out_channels=256,
    dropout_map=drop_map,
).to(DEVICE)

# quick param count for backbone+fpn
bp_params = sum(param.numel() for param in backbone.parameters())
print("Backbone+FPN params:", f"{bp_params:,}")

In [None]:
# === Cell 6: Build Mask R-CNN using our backbone (make anchors, RoIAlign, and mask/box head customizations) ===
# Keys of the feature maps emitted by our backbone's FPN.
FPN_LEVELS = ['0', '1', '2']

# Anchor generator (scales and aspect ratios).
# AnchorGenerator needs one sizes tuple and one aspect-ratio tuple PER feature map;
# a single tuple with three feature maps fails on the first forward pass. The same
# anchor set is repeated for each level so the number of anchors per location
# (which sizes the RPN head) stays identical across levels.
anchor_generator = AnchorGenerator(
    sizes=((32, 64, 128, 256, 512),) * len(FPN_LEVELS),
    aspect_ratios=((0.5, 1.0, 2.0),) * len(FPN_LEVELS),
)

# RoI Align / RoI Pooler
roi_pooler = MultiScaleRoIAlign(
    featmap_names=FPN_LEVELS,
    output_size=7,
    sampling_ratio=2
)

mask_roi_pooler = MultiScaleRoIAlign(
    featmap_names=FPN_LEVELS,
    output_size=14,
    sampling_ratio=2
)

# Construct MaskRCNN model.
# The dataset already applies mean/std normalization, so the model's internal
# transform is given an identity normalization (mean 0, std 1); by default it would
# apply ImageNet mean/std and the inputs would be normalized twice.
# NOTE(review): the internal transform also resizes inputs (default min_size=800),
# so the 256x256 images are upscaled before the backbone; pass min_size/max_size
# if training at native resolution is intended.
model = MaskRCNN(
    backbone,
    num_classes=NUM_CLASSES,
    rpn_anchor_generator=anchor_generator,
    box_roi_pool=roi_pooler,
    mask_roi_pool=mask_roi_pooler,
    image_mean=[0.0, 0.0, 0.0],
    image_std=[1.0, 1.0, 1.0],
)

# Replace classification box predictor to insert dropout and match representation size (mapping from earlier)
in_features = model.roi_heads.box_predictor.cls_score.in_features

class CustomFastRCNNPredictor(nn.Module):
    """Box head predictor: Linear -> ReLU -> Dropout, then class scores and box deltas.

    Args:
        in_channels: size of the incoming feature vector.
        num_classes: number of classes (one score and four box deltas per class).
        dropout_p: dropout probability applied to the hidden representation.
    """

    def __init__(self, in_channels, num_classes, dropout_p=0.4):
        super().__init__()
        representation_size = 1024
        self.fc1 = nn.Linear(in_channels, representation_size)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout(dropout_p)
        self.cls_score = nn.Linear(representation_size, num_classes)
        self.bbox_pred = nn.Linear(representation_size, num_classes * 4)

        # Small-std normal weights (smaller still for the box regressor), zero biases.
        weight_std = ((self.fc1, 0.01), (self.cls_score, 0.01), (self.bbox_pred, 0.001))
        for layer, std in weight_std:
            nn.init.normal_(layer.weight, std=std)
        for layer, _ in weight_std:
            nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """Return ``(scores, bbox_deltas)`` for the given feature vectors."""
        hidden = self.dropout(self.relu(self.fc1(x)))
        return self.cls_score(hidden), self.bbox_pred(hidden)

model.roi_heads.box_predictor = CustomFastRCNNPredictor(in_features, NUM_CLASSES, dropout_p=0.4)

# Replace mask predictor: conv -> relu -> dropout -> conv logits
try:
    in_ch_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
except AttributeError:
    # The current predictor has no `conv5_mask` layer to read from; use 256 channels.
    in_ch_mask = 256
hidden_dim = 256
class CustomMaskPredictor(nn.Module):
    """Mask head predictor: 3x3 conv -> ReLU -> Dropout2d -> 1x1 conv producing per-class logits.

    Both convolutions are Kaiming-normal initialized (fan_out, relu) with zero bias.
    """
    def __init__(self, in_channels, dim_reduced=hidden_dim, num_classes=NUM_CLASSES, dropout_p=0.4):
        super().__init__()
        self.conv5_mask = nn.Conv2d(in_channels, dim_reduced, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout2d(dropout_p)
        self.mask_fcn_logits = nn.Conv2d(dim_reduced, num_classes, kernel_size=1)
        for l in [self.conv5_mask, self.mask_fcn_logits]:
            nn.init.kaiming_normal_(l.weight, mode='fan_out', nonlinearity='relu')
            if l.bias is not None:
                nn.init.constant_(l.bias, 0)
    def forward(self, x):
        x = self.conv5_mask(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.mask_fcn_logits(x)
        return x

model.roi_heads.mask_predictor = CustomMaskPredictor(in_ch_mask, dim_reduced=hidden_dim, num_classes=NUM_CLASSES, dropout_p=0.4)

# Move model to device
model.to(DEVICE)

# Parameter count and approximate GFLOPs profiling for the backbone wrapper (thop)
total_params = sum(p.numel() for p in model.parameters())
print("Total Mask R-CNN params:", f"{total_params:,}")

# Profile a deep copy so any bookkeeping thop attaches to modules never reaches the
# model that is trained and checkpointed. A zero input is used so the global RNG
# state (and therefore seeded training) is not disturbed. `gflops` is read by the
# visualization cell; it is None when profiling fails.
try:
    _profile_backbone = copy.deepcopy(model.backbone).eval()
    _dummy_input = torch.zeros(1, 3, IMG_SIZE[1], IMG_SIZE[0], device=DEVICE)
    with torch.no_grad():
        _macs, _ = profile(_profile_backbone, inputs=(_dummy_input,), verbose=False)
    # thop counts multiply-accumulates; one MAC is counted here as two FLOPs.
    gflops = 2.0 * _macs / 1e9
    del _profile_backbone, _dummy_input
except Exception as exc:
    print(f"Backbone GFLOPs profiling failed: {exc}")
    gflops = None

In [None]:
# === Cell 7: Optimizer, scheduler, early stopping (same as UNet++) ===
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-4)
# `verbose` is deprecated in recent PyTorch releases; the current LR is printed every epoch below.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, min_lr=1e-6)
early_stop = EarlyStopping(patience=10, min_delta=1e-4, min_epochs=10)

EPOCHS = 200
history = []
best_val_loss = float('inf')


def _batch_to_device(images, targets, device):
    """Move a list of image tensors and a list of target dicts to ``device``."""
    images = [img.to(device) for img in images]
    targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
    return images, targets


def _tensors_to_cpu(dicts):
    """Detach every tensor value in a list of dicts and move it to the CPU."""
    return [{k: v.detach().cpu() if isinstance(v, torch.Tensor) else v for k, v in d.items()} for d in dicts]


def _set_validation_loss_mode(net):
    """Put ``net`` in the mode needed to compute losses on validation data.

    torchvision detection models only return the loss dict in training mode; in
    eval mode they return detections. The model is therefore switched to train(),
    while BatchNorm and Dropout layers are put back in eval() so running statistics
    are not updated and no activations are dropped.
    """
    net.train()
    for module in net.modules():
        if isinstance(module, (nn.BatchNorm2d, nn.Dropout, nn.Dropout2d)):
            module.eval()


def _metrics_from_counts(tp, fp, fn, tn, eps=1e-8):
    """Compute the metrics from summed confusion counts (same formulas as batch_metrics_from_outputs)."""
    precision = tp / (tp + fp + eps) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn + eps) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall + eps) if (precision + recall) > 0 else 0.0
    return dict(
        iou=tp / (tp + fp + fn + eps),
        dice=(2 * tp) / (2 * tp + fp + fn + eps),
        precision=precision,
        recall=recall,
        f1=f1,
        acc=(tp + tn) / (tp + tn + fp + fn + eps),
    )


for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0.0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} Train", leave=False)
    for images, targets in pbar:
        images, targets = _batch_to_device(images, targets, DEVICE)
        loss_dict = model(images, targets)
        losses = sum(loss for loss in loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        # Weight by batch size so the epoch value is a per-image mean.
        epoch_loss += losses.item() * len(images)
        pbar.set_postfix(dict(loss=losses.item()))
    epoch_loss = epoch_loss / len(train_ds)

    # Validation: compute losses and predictions for metrics
    val_loss = 0.0
    # Confusion counts are summed per batch so raw predictions for the whole
    # validation set never have to be held in memory at once.
    val_counts = dict(tp=0, fp=0, fn=0, tn=0)
    with torch.no_grad():
        for images, targets in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} Val", leave=False):
            images, targets = _batch_to_device(images, targets, DEVICE)

            # Pass 1: losses (requires training mode at the detector level).
            _set_validation_loss_mode(model)
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            val_loss += losses.item() * len(images)

            # Pass 2: detections for the segmentation metrics.
            model.eval()
            preds = model(images)  # inference mode -> list of dicts
            batch_mets = batch_metrics_from_outputs(
                _tensors_to_cpu(preds), _tensors_to_cpu(targets), score_thresh=0.5, mask_thresh=0.3
            )
            for key in val_counts:
                val_counts[key] += batch_mets[key]
    model.eval()

    val_loss = val_loss / len(val_ds)
    mets = _metrics_from_counts(**val_counts)

    scheduler.step(val_loss)
    history.append(dict(epoch=epoch+1, train_loss=epoch_loss, val_loss=val_loss,
                        IoU=mets['iou'], Dice=mets['dice'], Precision=mets['precision'],
                        Recall=mets['recall'], F1=mets['f1'], Accuracy=mets['acc']))

    print(f"Epoch {epoch+1}: TL {epoch_loss:.4f} VL {val_loss:.4f} IoU {mets['iou']:.4f} Dice {mets['dice']:.4f} F1 {mets['f1']:.4f} LR {optimizer.param_groups[0]['lr']:.2e}")

    # save best
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_maskrcnn_custombackbone.pth')

    if early_stop(epoch, val_loss, model):
        print(f"Early stopping at epoch {epoch+1}")
        break

# Save training history
pd.DataFrame(history).to_csv('training_history_maskrcnn_custombackbone.csv', index=False)
print("Training finished.")

In [None]:
# === Cell 8: Test evaluation (load best, compute metrics, save predictions) ===
# Restore the checkpoint with the lowest validation loss.
best_weights = torch.load('best_maskrcnn_custombackbone.pth', map_location=DEVICE)
model.load_state_dict(best_weights)
model.eval()

agg = {count_name: 0 for count_name in ('tp', 'fp', 'fn', 'tn')}
all_preds_semantic = []
with torch.no_grad():
    for images, targets in tqdm(test_loader, desc="Test", leave=False):
        images = [img.to(DEVICE) for img in images]
        targets_cpu = [dict(t) for t in targets]  # keep on CPU for metric fn
        preds = model(images)
        preds_cpu = [
            {k: v.detach().cpu() if isinstance(v, torch.Tensor) else v for k, v in p.items()}
            for p in preds
        ]
        # Sum the per-batch confusion counts; the ratios are computed once at the end.
        mets = batch_metrics_from_outputs(preds_cpu, targets_cpu, score_thresh=0.5, mask_thresh=0.3)
        for k in agg:
            agg[k] += mets[k]
        # Keep one binary semantic mask per test image for saving below.
        for p in preds_cpu:
            H, W = IMG_SIZE[1], IMG_SIZE[0]
            pred_sem = _to_binary_semantic_from_instances(p, (H,W), score_thresh=0.5, mask_thresh=0.3)
            all_preds_semantic.append(pred_sem.unsqueeze(0))

all_preds_tensor = (
    torch.cat(all_preds_semantic, dim=0)
    if len(all_preds_semantic) > 0
    else torch.zeros((0, IMG_SIZE[1], IMG_SIZE[0]))
)

tp, fp, fn, tn = (agg[count_name] for count_name in ('tp', 'fp', 'fn', 'tn'))
eps = 1e-8
iou = tp/(tp+fp+fn+eps)
dice = (2*tp)/(2*tp+fp+fn+eps)
precision = tp/(tp+fp+eps) if (tp+fp)>0 else 0
recall = tp/(tp+fn+eps) if (tp+fn)>0 else 0
f1 = 2*precision*recall/(precision+recall+eps) if (precision+recall)>0 else 0
acc = (tp+tn)/(tp+tn+fp+fn+eps)

# Rows: actual background / damage; columns: predicted background / damage.
cm = np.array([[tn, fp],[fn, tp]])
metrics = dict(IoU=iou, Dice=dice, Precision=precision, Recall=recall, F1=f1, Accuracy=acc, TP=tp, FP=fp, FN=fn, TN=tn)

print("\nTest Metrics (Mask R-CNN custom backbone):")
print(f"IoU: {iou:.4f}, Dice: {dice:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")

pd.DataFrame([metrics]).to_csv('test_metrics_maskrcnn_custombackbone.csv', index=False)
np.savetxt('confusion_matrix_maskrcnn_custombackbone.txt', cm, fmt='%d')

# Save the first (up to) ten predicted masks as 0/255 PNGs.
os.makedirs('test_predictions_maskrcnn_custombackbone', exist_ok=True)
for i, pred_mask in enumerate(all_preds_tensor[:10]):
    img = (pred_mask.numpy()*255).astype('uint8')
    Image.fromarray(img).save(f'test_predictions_maskrcnn_custombackbone/pred_{i}.png')
print("Saved some test prediction masks.")

In [None]:
# === Cell 9: Visualization (training curves) ===
hist_df = pd.read_csv('training_history_maskrcnn_custombackbone.csv')
fig, ((ax1, ax2, ax3),(ax4, ax5, ax6)) = plt.subplots(2,3, figsize=(16,8))

ax1.plot(hist_df['epoch'], hist_df['train_loss'], label='Train Loss')
ax1.plot(hist_df['epoch'], hist_df['val_loss'], label='Val Loss')
ax1.set_title('Loss Curves'); ax1.legend(); ax1.grid(True, alpha=0.3)

ax2.plot(hist_df['epoch'], hist_df['IoU']); ax2.set_title('Validation IoU')
ax3.plot(hist_df['epoch'], hist_df['Dice']); ax3.set_title('Validation Dice')

ax4.plot(hist_df['epoch'], hist_df['Precision'], label='Precision')
ax4.plot(hist_df['epoch'], hist_df['Recall'], label='Recall')
ax4.legend(); ax4.set_title('Precision & Recall')

ax5.plot(hist_df['epoch'], hist_df['F1'], label='F1')
ax5.plot(hist_df['epoch'], hist_df['Accuracy'], label='Accuracy')
ax5.legend(); ax5.set_title('F1 & Accuracy')

for _ax in (ax1, ax2, ax3, ax4, ax5):
    _ax.set_xlabel('Epoch')

# Sixth panel: text summary instead of a plot.
ax6.axis('off')
if len(hist_df)>0:
    best_epoch = hist_df.loc[hist_df['val_loss'].idxmin(), 'epoch']
    best_val = hist_df['val_loss'].min()
    # Note: IoU and Dice are the maxima over all epochs, not the values at best_epoch.
    best_iou = hist_df['IoU'].max()
    best_dice = hist_df['Dice'].max()
    # Real newlines ("\n"); an escaped "\\n" would be drawn literally in the figure.
    summary_text = (
        f"Total epochs: {len(hist_df)}\n"
        f"Best Epoch: {best_epoch}\n"
        f"Val Loss: {best_val:.4f}\n"
        f"IoU: {best_iou:.4f}\n"
        f"Dice: {best_dice:.4f}"
    )
    ax6.text(0.05, 0.5, summary_text, fontsize=10, fontfamily='monospace', bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))

plt.tight_layout()
plt.savefig('training_curves_maskrcnn_custombackbone.png', dpi=150, bbox_inches='tight')
plt.show()
print("Visualization saved.")

print("Total model params:", f"{total_params:,}")
# `gflops` only exists if a profiling step defined it earlier in the session;
# look it up defensively so this cell does not raise NameError otherwise.
gflops = globals().get('gflops')
if gflops is not None:
    print(f"Approx backbone+FPN GFLOPs for input {IMG_SIZE[1]}x{IMG_SIZE[0]}: {gflops:.3f} GFLOPs (approx).")
    print("Total model GFLOPs (including RPN & RoIHeads) will be higher and depends on proposals; this is an approximation for backbone+FPN.")
else:
    print("GFLOPs not available (profiling failed).")