# Glandu segmentācija ar U-Net un FPN (segmentation_models_pytorch)

Šajā notebook'ā mēs:

- ielādēsim glandu segmentācijas datu kopu ar `GlandSegmentationDataset`,
- izveidosim treniņa un validācijas `DataLoader`,
- trenēsim divus modeļus no `segmentation_models_pytorch`:
  - U-Net
  - FPN
- salīdzināsim to kvalitāti:
  - loss,
  - IoU / Dice,
  - vizuālie rezultāti (attēls + maska).

In [None]:
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader, random_split

import segmentation_models_pytorch as smp

from gland_dataset import GlandSegmentationDataset  

In [None]:
# Root directory of the gland segmentation data (relative to the working directory).
GLAND_ROOT = Path("data") / "gland_seg"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Build the complete "training" split. No transform is passed here
# (presumably meaning no augmentation — confirm in gland_dataset).
full_train_ds = GlandSegmentationDataset(
    GLAND_ROOT,
    split="training",
    transform=None,
)

len(full_train_ds)

In [None]:
# Split the training set into train/val subsets (e.g. 80% / 20%).
# A fixed seed makes the split reproducible across kernel restarts, so U-Net and
# FPN results from different sessions are evaluated on the same validation images.
val_fraction = 0.2
split_seed = 42
val_size = int(len(full_train_ds) * val_fraction)
train_size = len(full_train_ds) - val_size

train_ds, val_ds = random_split(
    full_train_ds,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

len(train_ds), len(val_ds)

In [None]:
# Both loaders share batch size and worker count; only the training loader
# reshuffles its samples every epoch.
batch_size = 4  # tune to the available GPU/CPU memory
loader_kwargs = {"batch_size": batch_size, "num_workers": 2}

train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

In [None]:
# Training objective for binary segmentation. DiceLoss takes raw logits
# (from_logits=True, which is also its default), so model outputs are passed
# straight in. It could also be combined with a BCE term.
loss_fn = smp.losses.DiceLoss(
    mode="binary",
    from_logits=True,
)

# Metrics for monitoring binary segmentation quality.
# These are pure-torch equivalents of `smp.utils.metrics.IoU(threshold=0.5)` and
# `smp.utils.metrics.Fscore(threshold=0.5)`. `smp.utils` is deprecated in
# segmentation_models_pytorch, so the same formulas (including the 1e-7 smoothing
# term) are reproduced here. Both take probabilities (e.g. sigmoid outputs) and a
# ground-truth mask of the same shape, binarise the prediction at `threshold`, and
# return a 0-dim tensor computed over the whole batch at once (micro average).


def _binarize(probs, threshold):
    """Return ``probs > threshold`` as a tensor with the same dtype as ``probs``."""
    return (probs > threshold).type(probs.dtype)


def iou_metric(probs, masks, threshold: float = 0.5, eps: float = 1e-7):
    """Intersection over Union between the binarised prediction and ``masks``.

    Because of ``eps``, an empty prediction against an empty mask scores 1.0.
    """
    pred = _binarize(probs, threshold)
    intersection = torch.sum(masks * pred)
    union = torch.sum(masks) + torch.sum(pred) - intersection + eps
    return (intersection + eps) / union


def fscore_metric(probs, masks, threshold: float = 0.5, beta: float = 1.0, eps: float = 1e-7):
    """F-beta score between the binarised prediction and ``masks``.

    With the default ``beta=1`` this is the F1 / Dice coefficient.
    """
    pred = _binarize(probs, threshold)
    tp = torch.sum(masks * pred)
    fp = torch.sum(pred) - tp
    fn = torch.sum(masks) - tp
    beta_sq = beta ** 2
    return ((1 + beta_sq) * tp + eps) / ((1 + beta_sq) * tp + beta_sq * fn + fp + eps)

In [None]:
def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass of ``model`` over ``loader``.

    The module-level ``loss_fn`` drives the gradient step. ``iou_metric`` and
    ``fscore_metric`` are evaluated on the sigmoid of the logits for monitoring only.

    Returns:
        ``(mean_loss, mean_iou, mean_dice)``, each averaged over batches.
        An empty ``loader`` raises ZeroDivisionError.
    """
    model.train()
    loss_sum = iou_sum = dice_sum = 0.0
    batch_count = 0

    for batch_count, (imgs, masks, _) in enumerate(loader, start=1):
        imgs, masks = imgs.to(device), masks.to(device)

        optimizer.zero_grad()
        # The smp models in this notebook are built without an `activation`
        # argument, so the forward pass returns raw logits (no sigmoid applied).
        logits = model(imgs)
        loss = loss_fn(logits, masks)
        loss.backward()
        optimizer.step()

        # Metrics are for monitoring only; keep them out of the autograd graph.
        with torch.no_grad():
            probs = torch.sigmoid(logits)
            iou_val = iou_metric(probs, masks)
            dice_val = fscore_metric(probs, masks)

        loss_sum += loss.item()
        iou_sum += iou_val.item()
        dice_sum += dice_val.item()

    return loss_sum / batch_count, iou_sum / batch_count, dice_sum / batch_count

In [None]:
def validate_one_epoch(model, loader, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Uses the module-level ``loss_fn``, ``iou_metric`` and ``fscore_metric``.
    The metrics are computed on sigmoid probabilities.

    Returns:
        ``(mean_loss, mean_iou, mean_dice)``, each averaged over batches.
        An empty ``loader`` raises ZeroDivisionError.
    """
    model.eval()
    per_batch = []  # one (loss, iou, dice) tuple per batch

    with torch.no_grad():
        for imgs, masks, _ in loader:
            imgs = imgs.to(device)
            masks = masks.to(device)

            logits = model(imgs)
            loss = loss_fn(logits, masks)

            probs = torch.sigmoid(logits)
            iou_val = iou_metric(probs, masks)
            dice_val = fscore_metric(probs, masks)

            per_batch.append((loss.item(), iou_val.item(), dice_val.item()))

    n_batches = len(per_batch)
    mean_loss = sum(row[0] for row in per_batch) / n_batches
    mean_iou = sum(row[1] for row in per_batch) / n_batches
    mean_dice = sum(row[2] for row in per_batch) / n_batches
    return mean_loss, mean_iou, mean_dice

In [None]:
def train_model(
    model,
    train_loader,
    val_loader,
    device,
    epochs: int = 10,
    lr: float = 1e-3,
) -> Tuple[torch.nn.Module, Dict[str, List[float]]]:
    """Train ``model`` with Adam and validate it after every epoch.

    Args:
        model: segmentation network; it is moved to ``device`` before training.
        train_loader: loader yielding ``(imgs, masks, _)`` training batches.
        val_loader: loader yielding batches of the same form for validation.
        device: device to train and evaluate on.
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        ``(model, history)``. ``model`` is the trained model, the same object
        that was passed in, now on ``device``. Validation runs last, so when
        ``epochs >= 1`` the model is left in eval mode. ``history`` maps
        ``"train_loss"``, ``"train_iou"``, ``"train_dice"``, ``"val_loss"``,
        ``"val_iou"`` and ``"val_dice"`` to lists with one value per epoch.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    history: Dict[str, List[float]] = {
        "train_loss": [],
        "train_iou": [],
        "train_dice": [],
        "val_loss": [],
        "val_iou": [],
        "val_dice": [],
    }

    for epoch in range(1, epochs + 1):
        train_loss, train_iou, train_dice = train_one_epoch(model, train_loader, optimizer, device)
        val_loss, val_iou, val_dice = validate_one_epoch(model, val_loader, device)

        history["train_loss"].append(train_loss)
        history["train_iou"].append(train_iou)
        history["train_dice"].append(train_dice)
        history["val_loss"].append(val_loss)
        history["val_iou"].append(val_iou)
        history["val_dice"].append(val_dice)

        print(
            f"Epoch {epoch:02d}: "
            f"train_loss={train_loss:.4f}, train_IoU={train_iou:.4f}, train_Dice={train_dice:.4f} | "
            f"val_loss={val_loss:.4f}, val_IoU={val_iou:.4f}, val_Dice={val_dice:.4f}"
        )

    return model, history

In [None]:
# Same encoder for both models, so the comparison isolates the decoder architecture.
ENCODER_NAME = "resnet34"   # alternative: "efficientnet-b0"
ENCODER_WEIGHTS = "imagenet"

# --- U-Net ---
unet_model = smp.Unet(          # smp.FPN or smp.DeepLabV3 can be swapped in here
    encoder_name=ENCODER_NAME,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=3,              # assumes RGB input images
    classes=1,                  # binary segmentation: one output channel
)

unet_epochs = 10  # adjust to the available compute
unet_model, unet_history = train_model(
    unet_model,
    train_loader,
    val_loader,
    device,
    epochs=unet_epochs,
    lr=1e-3,
)

# --- FPN ---
# The plotting, visual-comparison and saving cells below use `fpn_model` and
# `fpn_history`, so the FPN must be trained too. It gets the same encoder,
# epoch budget and learning rate as the U-Net for a like-for-like comparison.
fpn_model = smp.FPN(
    encoder_name=ENCODER_NAME,
    encoder_weights=ENCODER_WEIGHTS,
    in_channels=3,
    classes=1,
)

fpn_epochs = unet_epochs
fpn_model, fpn_history = train_model(
    fpn_model,
    train_loader,
    val_loader,
    device,
    epochs=fpn_epochs,
    lr=1e-3,
)

In [None]:
def plot_history(history_unet, history_fpn, metric_key: str, title: str):
    """Plot one metric's per-epoch curve for the U-Net and FPN runs on shared axes.

    Args:
        history_unet: history dict from ``train_model`` for the U-Net run.
        history_fpn: history dict from ``train_model`` for the FPN run.
        metric_key: history key to plot, e.g. ``"val_loss"``.
        title: figure title.
    """
    curves = [
        ("U-Net", history_unet[metric_key]),
        ("FPN", history_fpn[metric_key]),
    ]

    plt.figure(figsize=(6, 4))
    for model_name, values in curves:
        # Epochs are numbered from 1, matching the training log.
        epoch_axis = range(1, len(values) + 1)
        plt.plot(epoch_axis, values, "-o", label=f"{model_name} {metric_key}")

    plt.xlabel("Epoch")
    plt.ylabel(metric_key)
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

In [None]:
# One figure per validation metric, comparing both models.
for key, fig_title in (
    ("val_loss", "Validācijas loss: U-Net vs FPN"),
    ("val_iou", "Validācijas IoU: U-Net vs FPN"),
    ("val_dice", "Validācijas Dice (F1): U-Net vs FPN"),
):
    plot_history(unet_history, fpn_history, key, fig_title)

In [None]:
# Take one batch from val_loader for a side-by-side comparison.
val_batch = next(iter(val_loader))

imgs, masks, fnames = val_batch
imgs = imgs.to(device)
masks = masks.to(device)

# Explicitly switch both models to inference mode. A model that was built but
# never validated (e.g. epochs=0, or training interrupted mid-epoch) is still
# in train mode. Its normalisation/dropout layers would then behave as in
# training and could update running statistics during this forward pass.
unet_model.eval()
fpn_model.eval()

with torch.no_grad():
    unet_logits = unet_model(imgs)
    fpn_logits = fpn_model(imgs)

# The models output logits; convert them to probabilities.
unet_probs = torch.sigmoid(unet_logits)
fpn_probs = torch.sigmoid(fpn_logits)

# Binarise with threshold 0.5 (the same threshold the metrics use).
unet_pred = (unet_probs > 0.5).float()
fpn_pred = (fpn_probs > 0.5).float()

In [None]:
# Show up to 3 samples (fewer if the batch is smaller): image, ground truth and
# both binarised predictions side by side.
# NOTE(review): if the dataset normalises images, imshow will clip values
# outside [0, 1] — TODO confirm the image range.
n_show = min(3, imgs.size(0))

for i in range(n_show):
    # (array, title, colormap) per panel. The image goes CHW -> HWC for imshow.
    # Masks/predictions assume a leading singleton channel, (1, H, W) -> (H, W).
    panels = [
        (imgs[i].detach().cpu().permute(1, 2, 0).numpy(), f"{fnames[i]} - attēls", None),
        (masks[i].detach().cpu().squeeze(0).numpy(), "Ground truth maska", "gray"),
        (unet_pred[i].detach().cpu().squeeze(0).numpy(), "U-Net prognoze", "gray"),
        (fpn_pred[i].detach().cpu().squeeze(0).numpy(), "FPN prognoze", "gray"),
    ]

    fig, axs = plt.subplots(1, 4, figsize=(16, 4))
    for ax, (data, panel_title, cmap) in zip(axs, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(panel_title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()

In [None]:
# Persist the trained weights (state_dict only) under models/.
MODELS_DIR = Path("models")
MODELS_DIR.mkdir(parents=True, exist_ok=True)

unet_path, fpn_path = MODELS_DIR / "gland_unet.pth", MODELS_DIR / "gland_fpn.pth"

for _model, _path in ((unet_model, unet_path), (fpn_model, fpn_path)):
    torch.save(_model.state_dict(), _path)

unet_path, fpn_path

## Kopsavilkums

- Izmantojot vienu un to pašu datu formātu un `GlandSegmentationDataset`, mēs:

  - uztrenējām **U-Net** modeli glandu segmentācijai,
  - uztrenējām **FPN** modeli ar to pašu encoderi (`resnet34`).

- Salīdzinājām:

  - **loss** (train/val),
  - **IoU** un **Dice (F1)**,
  - vizuālo kvalitāti uz dažiem validācijas attēliem.