# Training already completed

This notebook documents the training setup and logs.


In [None]:
!pip -q install rasterio segmentation-models-pytorch==0.3.3 torchmetrics==1.3.0 albumentations opencv-python scipy tifffile


[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/58.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.8/58.8 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m68.5/68.5 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
Reason for being yanked: <none given>[0m[33m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m106.7/106.7 kB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m840.2/840.2 kB[0m [31m62.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m80.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for efficientnet-pytorch (setup.py) ... [?25l[?25hdone
  Building wheel for pretrainedmodels (setup.py)

In [None]:
# Core
import os, json, time, random
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm

# Imaging / Geo
import rasterio
from rasterio.plot import show
from tqdm import tqdm
import tifffile
from scipy.ndimage import gaussian_filter
from PIL import Image

# DL
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Augmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Models
import segmentation_models_pytorch as smp
import torchvision
from torchvision.models.segmentation import DeepLabV3_ResNet50_Weights

# Metrics
import time
from torchmetrics.classification import MulticlassF1Score, MulticlassJaccardIndex
from torchmetrics.functional.classification import multiclass_accuracy
from torchmetrics.classification import (
    MulticlassF1Score,
    MulticlassJaccardIndex,
    MulticlassPrecision,
    MulticlassRecall,
    MulticlassAccuracy
)

# ---- Reproducibility: one seed shared by every RNG this notebook touches ----
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(SEED)

# cuDNN: no autotuning benchmark, deterministic mode requested.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# ---- Device: use the GPU when one is visible, otherwise fall back to CPU ----
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)


Device: cuda


**Train/Val/Test Split**

In [None]:
# ---- Paths: tile images and their label masks live side by side ----
TILES_BASE = Path("/content/drive/MyDrive/ResearchProject (1)/dataset_tiles_512")
IMG_DIR = TILES_BASE / "images"
MSK_DIR = TILES_BASE / "masks"

# Fail fast if Drive is not mounted or the dataset folders were moved.
for _required_dir in (IMG_DIR, MSK_DIR):
    assert _required_dir.exists(), f"Missing: {_required_dir}"

# ---- Pairs: an image "<stem>.png" is matched with the mask "<stem>_mask.png" ----
img_files = sorted(IMG_DIR.glob("*.png"))

pairs = []
missing = 0
for img_fp in img_files:
    msk_fp = MSK_DIR / f"{img_fp.stem}_mask.png"
    if msk_fp.exists():
        pairs.append((img_fp, msk_fp))
    else:
        # Images without a mask are counted and left out of every split.
        missing += 1

print("Num image tiles found:", len(img_files))
print("Paired tiles:", len(pairs))
print("Missing masks:", missing)

# ---- Split train/val/test ----
# Re-seed right before shuffling so the split depends only on SEED and the
# sorted file list, not on how much randomness earlier cells consumed.
SEED = 42
random.seed(SEED)

idx = list(range(len(pairs)))
random.shuffle(idx)

# 80% train, 10% val, and whatever remains (about 10%) is the test split.
train_ratio, val_ratio = 0.80, 0.10
n = len(idx)
train_end = int(train_ratio * n)
val_end   = int((train_ratio + val_ratio) * n)

train_idx, val_idx, test_idx = idx[:train_end], idx[train_end:val_end], idx[val_end:]

print(f"Split sizes -> train: {len(train_idx)}, val: {len(val_idx)}, test: {len(test_idx)}")


Num image tiles found: 1215
Paired tiles: 1215
Missing masks: 0
Split sizes -> train: 972, val: 121, test: 122


**Dataset Class + Data Augmentations + Data Loader**

In [None]:
# ---- Augmentations ----
# The training pipeline applies random geometric/photometric augmentation;
# validation and test tiles only go through the shared normalize + to-tensor tail.
def _scale_and_tensorize():
    """Return fresh Normalize + ToTensorV2 steps shared by the train and eval pipelines."""
    return [
        # mean 0 / std 1: no per-channel standardisation, keeps the existing [0,1] scale logic simple
        A.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)),
        ToTensorV2(),
    ]


_train_augmentations = [
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.10, rotate_limit=15, p=0.5),
    A.RandomBrightnessContrast(p=0.3),
]

train_tf = A.Compose(_train_augmentations + _scale_and_tensorize())

eval_tf = A.Compose(_scale_and_tensorize())


# ---- Dataset ----
class TileSegDataset(Dataset):
    """Segmentation dataset over (image, mask) PNG tile pairs.

    Args:
        pairs: list of (img_fp, msk_fp) tuples covering every available tile.
        indices: list of ints; positions in `pairs` that belong to this split.
        transform: optional Albumentations Compose, called as
            transform(image=<H,W,3 uint8 array>, mask=<H,W uint8 array>).

    Each item is an (image, mask) tuple of torch tensors; the mask is int64.
    """

    def __init__(self, pairs, indices, transform=None):
        self.transform = transform
        self.indices = indices
        self.pairs = pairs

    def __len__(self):
        # The split size is the number of selected indices, not len(pairs).
        return len(self.indices)

    def __getitem__(self, i):
        img_fp, msk_fp = self.pairs[self.indices[i]]

        # Image is forced to 3-channel RGB, uint8, layout (H,W,3).
        rgb = np.array(Image.open(img_fp).convert("RGB"), dtype=np.uint8)

        # Mask holds integer class labels; a multi-channel mask (e.g. saved as
        # RGB by accident) is reduced to its first channel.
        labels = np.array(Image.open(msk_fp), dtype=np.uint8)
        if labels.ndim == 3:
            labels = labels[..., 0]

        if self.transform is None:
            # No Albumentations pipeline: convert by hand to (3,H,W) float in [0,1].
            image = torch.from_numpy(rgb).permute(2, 0, 1).float() / 255.0
            return image, torch.from_numpy(labels).long()

        # The pipeline transforms image and mask together; only the mask dtype
        # is adjusted afterwards (int64).
        augmented = self.transform(image=rgb, mask=labels)
        return augmented["image"], augmented["mask"].long()


# ---- Datasets + Loaders ----
BATCH_SIZE = 4

# One dataset per split; only the training split gets the augmenting pipeline.
train_ds, val_ds, test_ds = (
    TileSegDataset(pairs, split_idx, transform=split_tf)
    for split_idx, split_tf in (
        (train_idx, train_tf),
        (val_idx, eval_tf),
        (test_idx, eval_tf),
    )
)

# Options common to all three loaders; only `shuffle` differs (training only).
_loader_opts = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_opts)
val_loader   = DataLoader(val_ds, shuffle=False, **_loader_opts)
test_loader  = DataLoader(test_ds, shuffle=False, **_loader_opts)

# ---- Quick sanity check: pull one training batch and inspect it ----
x, y = next(iter(train_loader))
print("Batch images:", x.shape, x.dtype)   # expected (B,3,512,512), float32
print("Batch masks :", y.shape, y.dtype)   # expected (B,512,512), int64
print("Unique labels in batch:", torch.unique(y))


  original_init(self, **validated_kwargs)


Batch images: torch.Size([4, 3, 512, 512]) torch.float32
Batch masks : torch.Size([4, 512, 512]) torch.int64
Unique labels in batch: tensor([0, 1, 2, 3])


**Metrics (F1, IoU, Precision/Recall per class, Pixel Accuracy)**

In [None]:
NUM_CLASSES = 4

def build_metrics(device):
    """Create a fresh set of torchmetrics objects, all moved to `device`.

    Returns:
        dict mapping metric name -> metric object, with keys
        f1_macro, f1_per_class, iou_macro, iou_per_class,
        prec_per_class, rec_per_class, pixel_acc (in that order).
    """
    macro = {"num_classes": NUM_CLASSES, "average": "macro"}
    per_class = {"num_classes": NUM_CLASSES, "average": None}
    micro = {"num_classes": NUM_CLASSES, "average": "micro"}

    # (name, metric class, constructor kwargs) -- order defines the dict order.
    spec = (
        ("f1_macro", MulticlassF1Score, macro),
        ("f1_per_class", MulticlassF1Score, per_class),
        ("iou_macro", MulticlassJaccardIndex, macro),
        ("iou_per_class", MulticlassJaccardIndex, per_class),
        ("prec_per_class", MulticlassPrecision, per_class),
        ("rec_per_class", MulticlassRecall, per_class),
        # micro-averaged accuracy over all pixels
        ("pixel_acc", MulticlassAccuracy, micro),
    )
    return {name: metric_cls(**kwargs).to(device) for name, metric_cls, kwargs in spec}

@torch.no_grad()
def reset_metrics(metrics: dict):
    """Clear the accumulated state of every metric in `metrics` (in place)."""
    for name in metrics:
        metrics[name].reset()

@torch.no_grad()
def update_metrics(metrics: dict, preds: torch.Tensor, targets: torch.Tensor):
    """Feed one batch of predictions and labels to every metric.

    preds:   (B,H,W) int64 predicted class ids
    targets: (B,H,W) int64 ground-truth class ids
    """
    for _name, metric in metrics.items():
        metric.update(preds, targets)

@torch.no_grad()
def compute_metrics(metrics: dict):
    """Finalize every metric and return its value detached and moved to CPU.

    Returns:
        dict with the same keys (and order) as `metrics`.
    """
    return {name: metric.compute().detach().cpu() for name, metric in metrics.items()}

**Training Loop (Epochs)**

In [None]:
def run_one_epoch(model, loader, optimizer, device, train=True):
    """Run one full pass over `loader`, either training or evaluating `model`.

    Args:
        model: segmentation network; may return logits directly or a dict
            with the logits under the key "out".
        loader: iterable yielding (images, masks) batches.
        optimizer: optimizer stepped once per batch; only used when `train` is true.
        device: device the batches and the metrics are moved to.
        train: if true, run in training mode with gradient updates;
            otherwise run in eval mode with gradients disabled.

    Returns:
        (avg_loss, metrics_out, elapsed): mean of the per-batch losses,
        dict of computed metrics (CPU tensors), and the wall-clock seconds
        spent iterating over the loader.
    """
    switch_mode = model.train if train else model.eval
    switch_mode()

    # Fresh metric objects per epoch, so nothing carries over between passes.
    metrics = build_metrics(device)
    reset_metrics(metrics)

    loss_sum, batches_seen = 0.0, 0
    t0 = time.perf_counter()

    for images, labels in loader:
        images = images.to(device, non_blocking=True)   # (B,3,512,512)
        labels = labels.to(device, non_blocking=True)   # (B,512,512)

        if train:
            optimizer.zero_grad(set_to_none=True)

        with torch.set_grad_enabled(train):
            output = model(images)
            # Some torchvision seg models return dict: {"out": ...}
            logits = output["out"] if isinstance(output, dict) else output  # expected (B,C,H,W)

            loss = combined_loss(logits, labels)

            if train:
                loss.backward()
                optimizer.step()

        loss_sum += loss.item()
        batches_seen += 1

        # Hard class prediction per pixel: (B,H,W)
        update_metrics(metrics, torch.argmax(logits, dim=1), labels)

    elapsed = time.perf_counter() - t0

    metrics_out = compute_metrics(metrics)
    # max(1, ...) avoids a division by zero on an empty loader.
    avg_loss = loss_sum / max(1, batches_seen)

    return avg_loss, metrics_out, elapsed


In [None]:
def train_model(
    model,
    train_loader,
    val_loader,
    model_name: str,
    device,
    lr=1e-4,
    weight_decay=1e-4,
    epochs=10,
    results_dir="/content/drive/MyDrive/ResearchProject (1)/model_results"
):
    """Train `model` with AdamW and keep the checkpoint of the best validation epoch.

    "Best" means the highest validation macro F1. Two files are written into
    `results_dir`:
      * "<model_name>_best.pt"           -- state_dict of the best epoch
      * "<model_name>_best_metrics.json" -- {"best": <row of best epoch>,
                                            "history": <rows of all epochs run>}

    The JSON file is rewritten after every epoch, so "history" always covers
    every epoch run so far (previously it was only written when a new best
    was found, which silently dropped all epochs after the best one).

    Args:
        model: network to train (already on `device`).
        train_loader: loader for the training split.
        val_loader: loader for the validation split.
        model_name: prefix used in log lines and output file names.
        device: device passed through to `run_one_epoch`.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        epochs: number of epochs to run.
        results_dir: directory for the checkpoint and metrics file
            (created if missing).

    Returns:
        (best, history): `best` is a dict with keys "epoch", "val_f1_macro",
        "path_ckpt", "path_metrics"; `history` is the list of per-epoch rows.
    """
    results_dir = Path(results_dir)
    results_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = results_dir / f"{model_name}_best.pt"
    metrics_path = results_dir / f"{model_name}_best_metrics.json"

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    history = []
    best_row = None  # history row of the best epoch so far
    best = {
        "epoch": None,
        "val_f1_macro": -1.0,
        "path_ckpt": None,
        "path_metrics": None
    }

    for epoch in range(1, epochs + 1):
        tr_loss, tr_metrics, tr_time = run_one_epoch(model, train_loader, optimizer, device, train=True)
        va_loss, va_metrics, va_time = run_one_epoch(model, val_loader, optimizer, device, train=False)

        # choose best by val mean F1 (macro)
        val_f1 = float(va_metrics["f1_macro"].item())

        row = {
            "epoch": epoch,
            "train_loss": tr_loss,
            "val_loss": va_loss,
            "train_time_sec": tr_time,
            "val_time_sec": va_time,

            "train_f1_macro": float(tr_metrics["f1_macro"].item()),
            "val_f1_macro": val_f1,

            "train_iou_macro": float(tr_metrics["iou_macro"].item()),
            "val_iou_macro": float(va_metrics["iou_macro"].item()),

            "val_pixel_acc": float(va_metrics["pixel_acc"].item()),

            # per-class tensors become plain lists so the row is JSON-serialisable
            "val_f1_per_class": va_metrics["f1_per_class"].tolist(),
            "val_iou_per_class": va_metrics["iou_per_class"].tolist(),
            "val_prec_per_class": va_metrics["prec_per_class"].tolist(),
            "val_rec_per_class": va_metrics["rec_per_class"].tolist(),
        }
        history.append(row)

        print(
            f"[{model_name}] Epoch {epoch:02d}/{epochs} | "
            f"val_f1={row['val_f1_macro']:.4f} val_iou={row['val_iou_macro']:.4f} "
            f"loss={row['val_loss']:.4f} | "
            f"time(train/val)={tr_time:.1f}s/{va_time:.1f}s"
        )

        if val_f1 > best["val_f1_macro"]:
            best["val_f1_macro"] = val_f1
            best["epoch"] = epoch
            best_row = row

            # Only the weights of the best epoch are kept on disk.
            torch.save(model.state_dict(), ckpt_path)
            best["path_ckpt"] = str(ckpt_path)

        # Persist metrics every epoch (not only on improvement) so that the
        # saved history is complete even when the best epoch is not the last.
        if best_row is not None:
            with open(metrics_path, "w") as f:
                json.dump({"best": best_row, "history": history}, f, indent=2)
            best["path_metrics"] = str(metrics_path)

    print("\nBest epoch:", best["epoch"], "| Best val mean F1:", best["val_f1_macro"])
    print("Saved checkpoint:", best["path_ckpt"])
    print("Saved metrics:", best["path_metrics"])

    return best, history


**Loss Function (CE + Dice)**

In [None]:
ce_loss = nn.CrossEntropyLoss()

def dice_loss(logits, targets, num_classes=4, eps=1e-6):
    """Soft Dice loss averaged over classes.

    logits:  (B, C, H, W) raw class scores
    targets: (B, H, W) integer class labels in [0, num_classes)

    Per class, Dice is computed over the whole batch (dims 0, 2, 3) from the
    softmax probabilities and the one-hot labels; `eps` keeps the ratio
    finite when a class is absent. Returns 1 - mean(Dice) as a scalar tensor.
    """
    class_probs = F.softmax(logits, dim=1)                                  # (B,C,H,W)
    one_hot = F.one_hot(targets, num_classes).permute(0, 3, 1, 2).float()   # (B,H,W,C) -> (B,C,H,W)

    reduce_dims = (0, 2, 3)
    overlap = (class_probs * one_hot).sum(reduce_dims)
    total = (class_probs + one_hot).sum(reduce_dims)

    per_class_dice = (2 * overlap + eps) / (total + eps)
    return 1 - per_class_dice.mean()

def combined_loss(logits, targets):
    """Training objective: cross-entropy plus soft Dice, equally weighted."""
    ce_term = ce_loss(logits, targets)
    dice_term = dice_loss(logits, targets)
    return ce_term + dice_term

**U-Net (SMP) Model**

In [None]:
# Baseline: plain U-Net from segmentation_models_pytorch, built and trained in one cell.
# MODEL_NAME is used as the log prefix and as the file-name prefix for the
# saved checkpoint / metrics; it is reassigned by each later model cell.
MODEL_NAME = "FullUNet_23L"

ENCODER = "resnet34"
ENCODER_WEIGHTS = "imagenet"  # good default

model_unet = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=3,   # RGB tiles
    classes=4,       # one output channel per label
    activation=None  # no output activation requested; the losses apply softmax themselves
).to(DEVICE)

print("Model:", MODEL_NAME)
print("Encoder:", ENCODER)

# ---- Train and save BEST epoch (by val mean F1) ----
# Returns the summary of the best epoch and the list of per-epoch metric rows.
best_unet, hist_unet = train_model(
    model=model_unet,
    train_loader=train_loader,
    val_loader=val_loader,
    model_name=MODEL_NAME,
    device=DEVICE,
    lr=1e-4,
    weight_decay=1e-4,
    epochs=10,
    results_dir="/content/drive/MyDrive/ResearchProject (1)/model_results"
)


Model: FullUNet_23L
Encoder: resnet34
[FullUNet_23L] Epoch 01/10 | val_f1=0.6065 val_iou=0.5108 loss=1.0501 | time(train/val)=14.9s/1.8s
[FullUNet_23L] Epoch 02/10 | val_f1=0.5336 val_iou=0.4338 loss=1.2259 | time(train/val)=14.0s/1.7s
[FullUNet_23L] Epoch 03/10 | val_f1=0.5764 val_iou=0.4830 loss=0.9898 | time(train/val)=13.8s/1.8s
[FullUNet_23L] Epoch 04/10 | val_f1=0.7262 val_iou=0.6129 loss=0.8165 | time(train/val)=14.3s/1.8s
[FullUNet_23L] Epoch 05/10 | val_f1=0.7974 val_iou=0.6816 loss=0.7413 | time(train/val)=15.9s/1.8s
[FullUNet_23L] Epoch 06/10 | val_f1=0.7798 val_iou=0.6675 loss=0.7330 | time(train/val)=16.0s/1.8s
[FullUNet_23L] Epoch 07/10 | val_f1=0.7371 val_iou=0.6168 loss=0.7934 | time(train/val)=14.6s/1.7s
[FullUNet_23L] Epoch 08/10 | val_f1=0.7973 val_iou=0.6840 loss=0.6991 | time(train/val)=13.9s/1.7s
[FullUNet_23L] Epoch 09/10 | val_f1=0.7604 val_iou=0.6375 loss=0.7617 | time(train/val)=13.9s/1.8s
[FullUNet_23L] Epoch 10/10 | val_f1=0.7089 val_iou=0.5948 loss=0.8060 |

**DeepLabV3-ResNet50 Model**

In [None]:
# NOTE: this cell reassigns the notebook-wide MODEL_NAME / NUM_CLASSES and
# defines the generic global `model`; the DeepLabV3 training cell reads them.
MODEL_NAME = "DeepLabV3_ResNet50"
NUM_CLASSES = 4

# Load the DEFAULT pretrained weights for the whole DeepLabV3 model
# (the download log shows the COCO checkpoint), not only the backbone.
model = torchvision.models.segmentation.deeplabv3_resnet50(
    weights=torchvision.models.segmentation.DeepLabV3_ResNet50_Weights.DEFAULT
)

# Replace the last layer of the classifier head so output channels = num classes.
# The new 1x1 conv is freshly initialised; everything before it keeps its loaded weights.
in_ch = model.classifier[-1].in_channels
model.classifier[-1] = nn.Conv2d(in_ch, NUM_CLASSES, kernel_size=1)

# If aux classifier exists, match it too (safe).
# The training loop only uses the "out" entry of the model output, so the
# aux head does not contribute to the loss there.
if model.aux_classifier is not None:
    in_ch_aux = model.aux_classifier[-1].in_channels
    model.aux_classifier[-1] = nn.Conv2d(in_ch_aux, NUM_CLASSES, kernel_size=1)

model = model.to(DEVICE)

print("Model:", MODEL_NAME)
print("Output head channels:", NUM_CLASSES)


Downloading: "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth" to /root/.cache/torch/hub/checkpoints/deeplabv3_resnet50_coco-cd0a2569.pth


100%|██████████| 161M/161M [00:01<00:00, 162MB/s]


Model: DeepLabV3_ResNet50
Output head channels: 4


**Train DeepLabV3-ResNet50 Model**

In [None]:
# Train DeepLabV3 with the same hyperparameters as the other models.
# `model` and `MODEL_NAME` are the globals set by the DeepLabV3 model cell,
# so that cell must have been the most recent model cell run before this one.
# Returns the summary of the best epoch (by val macro F1) and the per-epoch history.
best_deeplab, hist_deeplab = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    model_name=MODEL_NAME,
    device=DEVICE,
    lr=1e-4,
    weight_decay=1e-4,
    epochs=10,
    results_dir="/content/drive/MyDrive/ResearchProject (1)/model_results"
)


[DeepLabV3_ResNet50] Epoch 01/10 | val_f1=0.5693 val_iou=0.4734 loss=1.1078 | time(train/val)=630.4s/72.5s
[DeepLabV3_ResNet50] Epoch 02/10 | val_f1=0.6750 val_iou=0.5803 loss=0.8432 | time(train/val)=24.2s/1.7s
[DeepLabV3_ResNet50] Epoch 03/10 | val_f1=0.5671 val_iou=0.4651 loss=1.1047 | time(train/val)=28.1s/1.6s
[DeepLabV3_ResNet50] Epoch 04/10 | val_f1=0.7805 val_iou=0.6594 loss=0.7797 | time(train/val)=24.3s/1.7s
[DeepLabV3_ResNet50] Epoch 05/10 | val_f1=0.7585 val_iou=0.6296 loss=0.8537 | time(train/val)=28.1s/1.7s
[DeepLabV3_ResNet50] Epoch 06/10 | val_f1=0.7083 val_iou=0.5641 loss=0.9607 | time(train/val)=24.3s/1.7s
[DeepLabV3_ResNet50] Epoch 07/10 | val_f1=0.8105 val_iou=0.6943 loss=0.6937 | time(train/val)=24.5s/1.7s
[DeepLabV3_ResNet50] Epoch 08/10 | val_f1=0.7248 val_iou=0.5793 loss=0.9011 | time(train/val)=28.3s/1.7s
[DeepLabV3_ResNet50] Epoch 09/10 | val_f1=0.6324 val_iou=0.5001 loss=1.0434 | time(train/val)=24.4s/1.7s
[DeepLabV3_ResNet50] Epoch 10/10 | val_f1=0.8052 val_

**U-Net + SCSE (ResNet-34 encoder) Model**

In [None]:
# Same U-Net / ResNet-34 configuration as the baseline cell, with SCSE
# attention enabled in the decoder. Reassigns the notebook-wide MODEL_NAME.
MODEL_NAME = "UNet_SCSE_ResNet34"
ENCODER = "resnet34"
ENCODER_WEIGHTS = "imagenet"

# NOTE(review): the variable is spelled `model_scas` although the attention
# type is "scse"; the name is kept because the training cell refers to it.
model_scas = smp.Unet(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=3,
    classes=4,
    activation=None,
    decoder_attention_type="scse"  # attention in the decoder blocks
).to(DEVICE)

print("Model:", MODEL_NAME)
print("Encoder:", ENCODER)
print("Decoder attention:", "scse")


Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth


100%|██████████| 83.3M/83.3M [00:00<00:00, 229MB/s]


Model: UNet_SCSE_ResNet34
Encoder: resnet34
Decoder attention: scse


**Train U-Net + SCSE (ResNet-34 encoder)**

In [None]:
# Train the SCSE U-Net with the same hyperparameters as the other models.
# MODEL_NAME is the global set by the SCSE model cell; run that cell first.
# Returns the summary of the best epoch (by val macro F1) and the per-epoch history.
best_scas, hist_scas = train_model(
    model=model_scas,
    train_loader=train_loader,
    val_loader=val_loader,
    model_name=MODEL_NAME,
    device=DEVICE,
    lr=1e-4,
    weight_decay=1e-4,
    epochs=10,
    results_dir="/content/drive/MyDrive/ResearchProject (1)/model_results"
)


[UNet_SCSE_ResNet34] Epoch 01/10 | val_f1=0.6051 val_iou=0.5125 loss=1.0524 | time(train/val)=15.7s/1.9s
[UNet_SCSE_ResNet34] Epoch 02/10 | val_f1=0.6269 val_iou=0.5411 loss=0.9240 | time(train/val)=15.6s/1.8s
[UNet_SCSE_ResNet34] Epoch 03/10 | val_f1=0.5980 val_iou=0.5056 loss=0.9414 | time(train/val)=17.4s/1.7s
[UNet_SCSE_ResNet34] Epoch 04/10 | val_f1=0.6613 val_iou=0.5924 loss=0.7540 | time(train/val)=15.7s/1.8s
[UNet_SCSE_ResNet34] Epoch 05/10 | val_f1=0.7093 val_iou=0.6050 loss=0.7941 | time(train/val)=17.6s/1.7s
[UNet_SCSE_ResNet34] Epoch 06/10 | val_f1=0.7496 val_iou=0.6386 loss=0.7511 | time(train/val)=17.5s/1.7s
[UNet_SCSE_ResNet34] Epoch 07/10 | val_f1=0.7669 val_iou=0.6460 loss=0.8039 | time(train/val)=17.4s/1.7s
[UNet_SCSE_ResNet34] Epoch 08/10 | val_f1=0.7836 val_iou=0.6633 loss=0.7398 | time(train/val)=17.3s/1.7s
[UNet_SCSE_ResNet34] Epoch 09/10 | val_f1=0.6568 val_iou=0.5417 loss=1.0166 | time(train/val)=17.3s/1.7s
[UNet_SCSE_ResNet34] Epoch 10/10 | val_f1=0.7885 val_io

**SPANetFull (ResNet50 backbone) Model**

In [None]:
# ============================================================
# SPANetFull Model (ResNet50 backbone) -- definition only;
# the model is instantiated and trained in the next cell.
# ============================================================

# Reassigns the notebook-wide MODEL_NAME used for logging and output file names.
MODEL_NAME = "SPANetFull_ResNet50"
print("Model:", MODEL_NAME)

# ----------------------------
# ResNet50 encoder (low + high)
# ----------------------------
class ResNet50Encoder(nn.Module):
    """ResNet-50 feature extractor returning one low-level and one high-level map.

    Args:
        pretrained: if true, load torchvision's DEFAULT ResNet-50 weights;
            otherwise the backbone is randomly initialised.

    forward(x) returns (low, high):
        low  -- output of layer2, (B, 512,  H/8,  W/8)
        high -- output of layer4, (B, 2048, H/32, W/32)
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # Use fully qualified torchvision names: `resnet50` / `ResNet50_Weights`
        # are not imported anywhere in this notebook, so the bare names raise
        # NameError on a fresh kernel. `torchvision` itself is imported at the top.
        weights = torchvision.models.ResNet50_Weights.DEFAULT if pretrained else None
        base = torchvision.models.resnet50(weights=weights)

        # Only the convolutional trunk is reused; avgpool / fc of `base` are dropped.
        self.stem = nn.Sequential(
            base.conv1,
            base.bn1,
            base.relu,
            base.maxpool,   # -> H/4, W/4
        )
        self.layer1 = base.layer1   # -> H/4,  W/4,  C=256
        self.layer2 = base.layer2   # -> H/8,  W/8,  C=512   (low)
        self.layer3 = base.layer3   # -> H/16, W/16, C=1024
        self.layer4 = base.layer4   # -> H/32, W/32, C=2048  (high)

    def forward(self, x):
        x = self.stem(x)
        x = self.layer1(x)
        low = self.layer2(x)         # (B,512,H/8,W/8)
        x = self.layer3(low)
        high = self.layer4(x)        # (B,2048,H/32,W/32)
        return low, high


# ---------------------------------------------
# SPAM block (successive pooling attention)
# ---------------------------------------------
class SPAMBlock(nn.Module):
    """Attention block gated by multi-scale average-pooled context.

    The input is average-pooled to each size in `pool_sizes`, every pooled map
    is resized back to the input resolution, and the stack of these maps is
    turned (1x1 conv -> BN -> ReLU -> 1x1 conv -> sigmoid) into a per-pixel,
    per-channel gate that multiplies the input.

    Args:
        in_channels: channels of the input (and of the output).
        pool_sizes: output sizes of the adaptive average pools.
        reduction: channel reduction factor of the bottleneck (at least 1 channel).
    """

    def __init__(self, in_channels, pool_sizes=(1, 2, 4), reduction=4):
        super().__init__()
        self.pool_sizes = pool_sizes
        stacked_channels = in_channels * len(pool_sizes)
        bottleneck = max(in_channels // reduction, 1)

        self.conv_reduce = nn.Conv2d(stacked_channels, bottleneck, kernel_size=1, bias=False)
        self.bn_reduce = nn.BatchNorm2d(bottleneck)
        self.relu = nn.ReLU(inplace=True)

        self.conv_attn = nn.Conv2d(bottleneck, in_channels, kernel_size=1, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        _, _, height, width = x.shape

        # One context map per pool size, each brought back to (height, width).
        contexts = [
            F.interpolate(
                F.adaptive_avg_pool2d(x, output_size=(k, k)),
                size=(height, width),
                mode="bilinear",
                align_corners=False,
            )
            for k in self.pool_sizes
        ]

        stacked = torch.cat(contexts, dim=1)                     # (B, C*len(pool_sizes), H, W)
        hidden = self.relu(self.bn_reduce(self.conv_reduce(stacked)))
        gate = self.sigmoid(self.conv_attn(hidden))              # (B, C, H, W), values in (0,1)
        return x * gate


# ---------------------------------------------
# Feature Fusion Module (high -> gate low)
# ---------------------------------------------
class FeatureFusionModule(nn.Module):
    """Gate the low-level feature map with a channel vector derived from the high-level map.

    The high-level map is globally average-pooled and passed through
    1x1 conv -> BN -> ReLU -> 1x1 conv -> sigmoid, giving one weight per
    low-level channel; the low-level map is multiplied by these weights.

    Args:
        low_channels: channels of the low-level map (and of the output).
        high_channels: channels of the high-level map.
    """

    def __init__(self, low_channels, high_channels):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        bottleneck = max(low_channels // 4, 1)

        self.conv1 = nn.Conv2d(high_channels, bottleneck, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(bottleneck)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(bottleneck, low_channels, kernel_size=1, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, low, high):
        _, _, low_h, low_w = low.shape

        context = self.pool(high)                                   # (B, C_h, 1, 1)
        context = self.relu(self.bn1(self.conv1(context)))          # (B, mid, 1, 1)
        gate = self.sigmoid(self.conv2(context))                    # (B, C_l, 1, 1)

        # Same weight at every spatial position of the low-level map.
        return low * gate.expand(-1, -1, low_h, low_w)


# ---------------------------------------------
# Full SPANet-style segmentation network
# ---------------------------------------------
class SPANetFull(nn.Module):
    """SPANet-style segmentation network on a ResNet-50 encoder.

    Pipeline: encoder -> SPAM attention on the low (512-ch) and high (2048-ch)
    maps -> high-level map gates the low-level map (FeatureFusionModule) ->
    two 3x3 conv stages with bilinear upsampling -> 1x1 classifier.

    Args:
        num_classes: number of output channels (classes).
        pretrained_backbone: forwarded to ResNet50Encoder(pretrained=...).

    forward(x) returns logits of shape (B, num_classes, H, W), matching the
    spatial size of the input.
    """

    def __init__(self, num_classes=4, pretrained_backbone=True):
        super().__init__()
        self.encoder = ResNet50Encoder(pretrained=pretrained_backbone)

        # Channel counts of the two encoder outputs.
        self.low_ch = 512
        self.high_ch = 2048

        self.spam_low = SPAMBlock(self.low_ch)
        self.spam_high = SPAMBlock(self.high_ch)

        self.ffm = FeatureFusionModule(self.low_ch, self.high_ch)

        # simple decoder: 512 -> 256 -> 128 channels
        self.dec_conv1 = nn.Conv2d(self.low_ch, 256, kernel_size=3, padding=1)
        self.dec_bn1 = nn.BatchNorm2d(256)
        self.dec_conv2 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.dec_bn2 = nn.BatchNorm2d(128)

        self.classifier = nn.Conv2d(128, num_classes, kernel_size=1)

    def forward(self, x):
        height, width = x.shape[-2:]
        low_feat, high_feat = self.encoder(x)

        # Attention on each scale, then the high-level map gates the low-level one.
        gated_low = self.ffm(self.spam_low(low_feat), self.spam_high(high_feat))   # (B,512,H/8,W/8)

        # Stage 1: conv -> ReLU -> BN, then x2 upsampling (H/8 -> H/4).
        feat = self.dec_bn1(F.relu(self.dec_conv1(gated_low)))
        feat = F.interpolate(feat, scale_factor=2, mode="bilinear", align_corners=False)

        # Stage 2: conv -> ReLU -> BN, then resize straight to the input resolution.
        feat = self.dec_bn2(F.relu(self.dec_conv2(feat)))
        feat = F.interpolate(feat, size=(height, width), mode="bilinear", align_corners=False)

        return self.classifier(feat)  # (B, num_classes, H, W)


Model: SPANetFull_ResNet50


**Train SPANetFull (ResNet50 backbone) Model**

In [None]:
# Instantiate SPANetFull with a pretrained backbone and train it with the same
# hyperparameters as the other models. MODEL_NAME is the global set in the
# SPANetFull definition cell; run that cell first.
model_spanet = SPANetFull(num_classes=4, pretrained_backbone=True).to(DEVICE)

# Returns the summary of the best epoch (by val macro F1) and the per-epoch history.
best_spanet, hist_spanet = train_model(
    model=model_spanet,
    train_loader=train_loader,
    val_loader=val_loader,
    model_name=MODEL_NAME,
    device=DEVICE,
    lr=1e-4,
    weight_decay=1e-4,
    epochs=10,
    results_dir="/content/drive/MyDrive/ResearchProject (1)/model_results"
)

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 240MB/s]


[SPANetFull_ResNet50] Epoch 01/10 | val_f1=0.7411 val_iou=0.6224 loss=0.9439 | time(train/val)=26.6s/1.8s
[SPANetFull_ResNet50] Epoch 02/10 | val_f1=0.6191 val_iou=0.5049 loss=1.0190 | time(train/val)=26.4s/1.8s
[SPANetFull_ResNet50] Epoch 03/10 | val_f1=0.7708 val_iou=0.6572 loss=0.8157 | time(train/val)=26.2s/1.8s
[SPANetFull_ResNet50] Epoch 04/10 | val_f1=0.6364 val_iou=0.5196 loss=1.0142 | time(train/val)=28.7s/1.7s
[SPANetFull_ResNet50] Epoch 05/10 | val_f1=0.7829 val_iou=0.6647 loss=0.8126 | time(train/val)=26.4s/1.7s
[SPANetFull_ResNet50] Epoch 06/10 | val_f1=0.6893 val_iou=0.5709 loss=0.9100 | time(train/val)=28.5s/1.7s
[SPANetFull_ResNet50] Epoch 07/10 | val_f1=0.7318 val_iou=0.6152 loss=0.8365 | time(train/val)=26.4s/1.7s
[SPANetFull_ResNet50] Epoch 08/10 | val_f1=0.7837 val_iou=0.6579 loss=0.8530 | time(train/val)=26.2s/1.7s
[SPANetFull_ResNet50] Epoch 09/10 | val_f1=0.7901 val_iou=0.6755 loss=0.7504 | time(train/val)=28.6s/1.7s
[SPANetFull_ResNet50] Epoch 10/10 | val_f1=0.8