PyTorch + SMP (segmentation_models_pytorch) for deep learning

rasterio for reading image patches

numpy, random, pandas for arrays, randomness, and metrics

torchvision.transforms.functional for augmentations

sklearn.metrics for evaluation metrics

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import segmentation_models_pytorch as smp
import rasterio
from pathlib import Path
import numpy as np
import random
import torchvision.transforms.functional as TF
import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix


Defines input directories for images and masks, and output file paths for model weights and metrics.

Creates output folders if missing.

Sets batch size and number of epochs.

In [None]:
# --- Locations -------------------------------------------------------------
# Every input and output path below is rooted at this project directory.
BASE_PATH = Path("/content/cafo_project")

# Input patches: image tiles and their corresponding masks.
IMG_DIR = BASE_PATH / "patches/images"
MASK_DIR = BASE_PATH / "patches/masks"

# Artifacts written once training finishes.
MODEL_OUTPUT_PATH = BASE_PATH / "weights/cafo_multi_patch.pt"
METRICS_OUTPUT_PATH = BASE_PATH / "results/training_metrics.csv"

# Create the destination folders up front so the final save cannot fail on a
# missing directory (no-op when they already exist).
for _out_file in (MODEL_OUTPUT_PATH, METRICS_OUTPUT_PATH):
    _out_file.parent.mkdir(parents=True, exist_ok=True)
del _out_file

# --- Hyperparameters -------------------------------------------------------
BATCH_SIZE = 16  # patches per optimizer step
EPOCHS = 75  # full passes over the patch set


Reads image and mask patches.

Applies random horizontal/vertical flips and 90° rotations for augmentation (when `augment=True`), identically to image and mask.

Converts images/masks to PyTorch tensors.

In [None]:
class PatchDataset(Dataset):
    """Segmentation dataset pairing image patches with their mask patches.

    Every ``<name>.tif`` in ``img_dir`` is paired with ``<name>_mask.tif`` in
    ``mask_dir``. Each item is an ``(img, mask)`` pair of float32 tensors with
    shapes ``(3, H, W)`` (bands 1-3 of the image) and ``(1, H, W)``.

    Args:
        img_dir: Directory containing the image patches (``str`` or ``Path``).
        mask_dir: Directory containing the mask patches (``str`` or ``Path``).
        augment: When True, apply random horizontal/vertical flips and
            90-degree rotations, applied identically to image and mask.
    """

    def __init__(self, img_dir, mask_dir, augment=True):
        # Accept plain strings as well as Path objects (.glob needs a Path).
        self.img_dir = Path(img_dir)
        self.mask_dir = Path(mask_dir)
        # Sorted so the index -> file mapping is deterministic.
        self.image_files = sorted(self.img_dir.glob("*.tif"))
        self.augment = augment

    def __len__(self):
        return len(self.image_files)

    def _mask_path(self, img_path):
        """Return the mask file paired with ``img_path``: ``<stem>_mask.tif``."""
        # Build from the stem so only the trailing extension is rewritten;
        # str.replace(".tif", ...) would also alter any ".tif" inside the name.
        return self.mask_dir / f"{img_path.stem}_mask.tif"

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        mask_path = self._mask_path(img_path)

        with rasterio.open(img_path) as img_src:
            # NOTE(review): dividing by 255 assumes 8-bit imagery - confirm.
            img = img_src.read([1, 2, 3]) / 255.0
        with rasterio.open(mask_path) as mask_src:
            # NOTE(review): BCEWithLogitsLoss downstream expects targets in
            # [0, 1]; assumes masks are stored as 0/1 - confirm.
            mask = mask_src.read(1)

        img = torch.tensor(img, dtype=torch.float32)
        mask = torch.tensor(mask, dtype=torch.float32).unsqueeze(0)  # add channel

        # Data augmentation: the same random transform is applied to both.
        if self.augment:
            if random.random() > 0.5:
                img = TF.hflip(img)
                mask = TF.hflip(mask)
            if random.random() > 0.5:
                img = TF.vflip(img)
                mask = TF.vflip(mask)
            if random.random() > 0.5:
                # A rotation by 1 or 3 quarter turns swaps H and W; batching
                # assumes square patches - TODO confirm patch shape.
                k = random.choice([1, 2, 3])
                img = torch.rot90(img, k, dims=[1, 2])
                mask = torch.rot90(mask, k, dims=[1, 2])

        return img, mask


Wraps dataset in a DataLoader for batching and shuffling.

In [None]:
# Augmented training set; every *.tif patch in IMG_DIR is paired with its mask.
dataset = PatchDataset(IMG_DIR, MASK_DIR, augment=True)
# Fail early with a clear message instead of DataLoader's generic
# "num_samples should be a positive integer" error on an empty directory.
if len(dataset) == 0:
    raise FileNotFoundError(f"No *.tif image patches found in {IMG_DIR}")
# Reshuffle patch order every epoch.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


UNet with ResNet18 backbone pretrained on ImageNet.

Binary segmentation → use BCEWithLogitsLoss.

Optimizer → Adam.

In [None]:
# Use the GPU when PyTorch can see one; otherwise train on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# U-Net with a ResNet-18 encoder initialised from ImageNet weights:
# 3 input channels in, a single-channel logit map out.
model = smp.Unet(
    encoder_name="resnet18",
    encoder_weights="imagenet",
    in_channels=3,
    classes=1,
).to(device)

# Adam at a fixed learning rate. BCEWithLogitsLoss applies the sigmoid
# internally, so it consumes the model's raw logits directly.
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()


Trains the model and computes pixel-level metrics on the training batches each epoch (not on a held-out validation set).

Stores metrics for later analysis.

In [None]:
def _epoch_metrics(targets, preds, probs):
    """Compute pixel-level binary classification metrics for one epoch.

    Args:
        targets: Flat uint8 array of ground-truth labels (0/1).
        preds: Flat uint8 array of thresholded predictions (0/1).
        probs: Flat float array of predicted foreground probabilities.

    Returns:
        Dict with Accuracy, Precision, Recall, F1, AUC-ROC and TP/FP/FN/TN,
        keyed and ordered like the metrics CSV columns.
    """
    # ROC AUC must rank continuous scores; hard 0/1 predictions would collapse
    # the curve to a single operating point. AUC is also undefined (sklearn
    # raises ValueError) when only one class is present, so report NaN then.
    if targets.min() == targets.max():
        auc = float("nan")
    else:
        auc = roc_auc_score(targets, probs)
    # labels=[0, 1] always yields a 2x2 matrix, so the counts stay correct even
    # when one class is absent from both targets and predictions.
    tn, fp, fn, tp = confusion_matrix(targets, preds, labels=[0, 1]).ravel()
    return {
        "Accuracy": accuracy_score(targets, preds),
        "Precision": precision_score(targets, preds, zero_division=0),
        "Recall": recall_score(targets, preds, zero_division=0),
        "F1": f1_score(targets, preds, zero_division=0),
        "AUC-ROC": auc,
        "TP": int(tp),
        "FP": int(fp),
        "FN": int(fn),
        "TN": int(tn),
    }


print("\n🚀 Starting patch training loop with metrics tracking...")
model.train()
metrics_log = []

for epoch in range(EPOCHS):
    epoch_loss = 0.0
    # Per-pixel values for the whole epoch, held in memory until metrics run.
    all_preds, all_targets, all_probs = [], [], []

    for images, masks in dataloader:
        images, masks = images.to(device), masks.to(device)

        outputs = model(images)
        loss = criterion(outputs, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        # Detach so metric bookkeeping does not extend the autograd graph.
        probs = torch.sigmoid(outputs.detach())
        all_probs.append(probs.cpu().numpy().ravel())
        all_preds.append((probs > 0.5).to(torch.uint8).cpu().numpy().ravel())
        # Binarize targets at the same threshold as predictions; assumes masks
        # are 0/1 (as BCE targets) - TODO confirm for the mask files.
        all_targets.append((masks > 0.5).to(torch.uint8).cpu().numpy().ravel())

    all_preds = np.concatenate(all_preds)
    all_targets = np.concatenate(all_targets)
    all_probs = np.concatenate(all_probs)

    # Metrics reflect augmented training batches seen while weights changed.
    epoch_metrics = _epoch_metrics(all_targets, all_preds, all_probs)

    avg_loss = epoch_loss / len(dataloader)
    print(
        f"\nEpoch {epoch + 1}/{EPOCHS} | Loss: {avg_loss:.4f} "
        f"| Acc: {epoch_metrics['Accuracy']:.4f} | F1: {epoch_metrics['F1']:.4f} "
        f"| AUC: {epoch_metrics['AUC-ROC']:.4f}"
    )

    metrics_log.append({"Epoch": epoch + 1, "Loss": avg_loss, **epoch_metrics})


Saves trained model weights.

Saves metrics CSV for plotting or further analysis.

In [None]:
# Persist only the learned parameters (the state_dict), not the module object.
weights = model.state_dict()
torch.save(weights, MODEL_OUTPUT_PATH)
print(f"\n✅ Training complete. Model saved to {MODEL_OUTPUT_PATH}")

# One row per epoch; index=False keeps the CSV to the metric columns only.
metrics_df = pd.DataFrame(metrics_log)
metrics_df.to_csv(METRICS_OUTPUT_PATH, index=False)
print(f"📄 Training metrics saved to {METRICS_OUTPUT_PATH}")
