# U-Net Baseline

Fair-comparison setup: this notebook uses the same data splits, patch size (64×64), normalization, and random seeds as the FCEF baseline.

In [1]:
import importlib
import sys
# Reload src.config on re-runs so edits to it take effect without a kernel
# restart; on a fresh kernel it has not been imported yet, so nothing happens.
_config_module_name = 'src.config'
if _config_module_name in sys.modules:
    importlib.reload(sys.modules[_config_module_name])

In [2]:
import sys
import random

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from pathlib import Path
import torch.nn.functional as F
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
import wandb
from tqdm import tqdm

# Project root = parent of the kernel's working directory (assumes the notebook
# is run from a sub-folder of the repo, e.g. notebooks/ — TODO confirm).
root = Path().resolve().parent
# Guard the append so re-running this cell does not keep growing sys.path.
if str(root) not in sys.path:
    sys.path.append(str(root))

from src.config import SENTINEL_DIR, MASK_DIR
from src.data.sentinel_habloss_dataset import SentinelHablossPatchDataset
from src.data.splits import get_splits, get_ref_ids_from_directory
from src.data.transform import compute_normalization_stats

KeyboardInterrupt: 

## Set Random Seeds

In [None]:
RANDOM_SEED = 42

# Seed every RNG the notebook touches: Python, NumPy, PyTorch CPU and all CUDA devices.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(RANDOM_SEED)
del _seed_fn

# Give up cuDNN autotuning in exchange for run-to-run determinism.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

print(f"All random seeds set to {RANDOM_SEED}")

## Setup Device

In [None]:
# Use the default CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cpu


## Train/Val/Test Splits

Splits are shared with FCEF: tiles (reference IDs) are divided 70/15/15 into train/val/test with `random_state=42`.

In [None]:
all_ref_ids = get_ref_ids_from_directory(SENTINEL_DIR)
print(f"Total reference IDs found: {len(all_ref_ids)}")
# Fail fast with a readable message: an empty directory would otherwise surface
# as a ZeroDivisionError in the percentage prints or an IndexError below.
assert len(all_ref_ids) > 0, f"No reference IDs found in {SENTINEL_DIR}"

# Same ratios and seed as the FCEF notebook so both models see identical tiles.
train_ref_ids, val_ref_ids, test_ref_ids = get_splits(
    all_ref_ids,
    train_ratio=0.7,
    val_ratio=0.15,
    test_ratio=0.15,
    random_state=RANDOM_SEED,
)
assert len(train_ref_ids) > 0, "get_splits returned an empty training split"
# Tiles must not appear in more than one split, otherwise val/test metrics leak.
_train_set, _val_set, _test_set = set(train_ref_ids), set(val_ref_ids), set(test_ref_ids)
assert not (_train_set & _val_set or _train_set & _test_set or _val_set & _test_set), \
    "train/val/test splits overlap"

print(f"\nTrain tiles: {len(train_ref_ids)} (~{100*len(train_ref_ids)/len(all_ref_ids):.0f}%)")
print(f"Val tiles: {len(val_ref_ids)} (~{100*len(val_ref_ids)/len(all_ref_ids):.0f}%)")
print(f"Test tiles: {len(test_ref_ids)} (~{100*len(test_ref_ids)/len(all_ref_ids):.0f}%)")
print(f"\nExample train ID: {train_ref_ids[0]}")
print(f"\n✓ Using SHARED splits with FCEF baseline (random_state={RANDOM_SEED})")

Total Sentinel files found: 34

Train files: 23 (~68%)
Val files: 5 (~15%)
Test files: 6 (~18%)

Example train file: a16-15317118163363_45-7859705066069_RGBNIRRSWIRQ_Mosaic.tif


## Create Datasets

- Shared normalization: inputs are scaled by 10000, then standardized per channel with mean/std estimated on the training split only (no val/test leakage).
- Patch size: 64×64

In [None]:
PATCH_SIZE = 64
# Patches sampled per tile: a small pass for normalization stats, more for training.
NORM_STATS_PATCHES_PER_IMAGE = 5
TRAIN_PATCHES_PER_IMAGE = 20
EVAL_PATCHES_PER_IMAGE = 10
BATCH_SIZE = 8
NUM_WORKERS = 0


def make_patch_dataset(ref_ids, patches_per_image, mean, std, augment):
    """Build a SentinelHablossPatchDataset over ``ref_ids`` with the shared PATCH_SIZE.

    ``mean``/``std`` are passed straight to the dataset; the statistics pass
    below uses ``None`` for both because the stats are not known yet.
    """
    return SentinelHablossPatchDataset(
        SENTINEL_DIR, MASK_DIR,
        patch_size=PATCH_SIZE,
        patches_per_image=patches_per_image,
        mean=mean,
        std=std,
        augment=augment,
        ref_ids=ref_ids,
    )


# Normalization statistics are estimated on the TRAINING split only.
temp_train_ds = make_patch_dataset(
    train_ref_ids, NORM_STATS_PATCHES_PER_IMAGE, mean=None, std=None, augment=False
)

print("Estimating per-channel mean and std from training data...")
mean, std = compute_normalization_stats(temp_train_ds, num_samples=2000)
print(f"✓ Computed normalization stats: {len(mean)} channels")
print(f"  Mean (first 5): {[f'{m:.4f}' for m in mean[:5]]}")
print(f"  Std (first 5): {[f'{s:.4f}' for s in std[:5]]}")

# Actual datasets, all standardized with the shared training-set statistics.
train_ds = make_patch_dataset(train_ref_ids, TRAIN_PATCHES_PER_IMAGE, mean, std, augment=True)
val_ds = make_patch_dataset(val_ref_ids, EVAL_PATCHES_PER_IMAGE, mean, std, augment=False)
test_ds = make_patch_dataset(test_ref_ids, EVAL_PATCHES_PER_IMAGE, mean, std, augment=False)


def worker_init_fn(worker_id):
    """Seed Python's and NumPy's global RNGs inside a DataLoader worker.

    PyTorch hands each worker a base seed that differs per worker and per epoch
    and is itself derived from the loader's seeded ``generator``. Reusing it keeps
    runs reproducible while letting random augmentations vary between epochs.
    The previous fixed ``RANDOM_SEED + worker_id`` reseeded each worker
    identically every epoch, so np.random/random-based augmentation would repeat.
    Only called when ``num_workers > 0``.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# Shuffling order is reproducible via the explicitly seeded generator.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    worker_init_fn=worker_init_fn,
    generator=torch.Generator().manual_seed(RANDOM_SEED),
)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

print(f"\n✓ Datasets created with SHARED normalization and patch_size={PATCH_SIZE}")
print(f"Training patches: {len(train_ds)} (from {len(train_ref_ids)} tiles)")
print(f"Validation patches: {len(val_ds)} (from {len(val_ref_ids)} tiles)")
print(f"Test patches: {len(test_ds)} (from {len(test_ref_ids)} tiles)")
print(f"Number of input channels: {train_ds.num_bands}")

Estimating per-channel mean and std from training data...
Computed mean: 126 values
Computed std: 126 values

Training patches: 460 (from 23 files)
Validation patches: 50 (from 5 files)
Test patches: 60 (from 6 files)
Number of input channels: 126


## Build Model

In [None]:
num_input_channels = train_ds.num_bands

# in_channels comes from the dataset instead of a hard-coded 126, so the model
# always matches the number of bands the loaders actually yield.
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    in_channels=num_input_channels,
    classes=2,  # background vs. land-take
).to(device)

print(f"Model created with {num_input_channels} input channels and moved to {device}")

Model created with 126 input channels and moved to cpu


## Loss and Optimizer

In [None]:
# Adam over all model parameters at a fixed learning rate, paired with
# pixel-wise cross-entropy on the model's 2-class logits.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

## Metrics Functions

In [None]:
def compute_confusion_binary(y_pred, y_true, positive_class=1):
    """
    Count binary confusion-matrix cells over every element of the inputs.
    y_pred, y_true: label tensors of equal shape, e.g. (B, H, W)
    positive_class: label value treated as "positive"; everything else is negative
    returns (TP, FP, TN, FN) as Python scalars
    """
    pred_pos = y_pred == positive_class
    true_pos = y_true == positive_class
    pred_neg = ~pred_pos
    true_neg = ~true_pos

    def _count(hits):
        return hits.sum().item()

    return (
        _count(pred_pos & true_pos),
        _count(pred_pos & true_neg),
        _count(pred_neg & true_neg),
        _count(pred_neg & true_pos),
    )

def compute_metrics_from_confusion(tp, fp, tn, fn, eps=1e-8):
    """
    Derive scalar metrics from confusion-matrix counts.
    eps keeps every ratio finite when its denominator is zero (the ratio then ~0).
    Returns: dict with keys accuracy, precision, recall, f1, iou (IoU of the positive class)
    """
    total = tp + tn + fp + fn
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    return dict(
        accuracy=(tp + tn) / (total + eps),
        precision=precision,
        recall=recall,
        f1=2 * precision * recall / (precision + recall + eps),
        iou=tp / (tp + fp + fn + eps),
    )

## Training Functions

In [None]:
# Mixed precision is enabled only on CUDA. There autocast runs in float16, which
# needs a GradScaler so small gradients do not underflow to zero. On CPU, autocast
# would switch to bfloat16, so it is left disabled. The scaler lives at module
# level so its loss scale carries over between epochs.
_USE_AMP = device.type == "cuda"
_AMP_DEVICE_TYPE = "cuda" if _USE_AMP else "cpu"
try:
    _grad_scaler = torch.amp.GradScaler("cuda", enabled=_USE_AMP)
except AttributeError:  # torch < 2.3 only provides the CUDA-namespaced class
    _grad_scaler = torch.cuda.amp.GradScaler(enabled=_USE_AMP)


def train_one_epoch(loader):
    """Run one optimisation pass of the global ``model`` over ``loader``.

    Uses the notebook globals ``model``, ``optimizer``, ``loss_fn`` and ``device``.
    Returns the sample-weighted mean of the batch losses (0.0 for an empty dataset).
    """
    model.train()
    total_loss = 0.0

    for imgs, masks in loader:
        imgs = imgs.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()

        with torch.amp.autocast(_AMP_DEVICE_TYPE, enabled=_USE_AMP):
            logits = model(imgs)
            loss = loss_fn(logits, masks)

        # With the scaler disabled (CPU), scale/step/update reduce to a plain
        # loss.backward() followed by optimizer.step().
        _grad_scaler.scale(loss).backward()
        _grad_scaler.step(optimizer)
        _grad_scaler.update()

        # loss is a per-batch mean; weight by batch size for an epoch-level mean.
        total_loss += loss.item() * imgs.size(0)

    return total_loss / max(len(loader.dataset), 1)


def validate(loader):
    """Evaluate the global ``model`` on ``loader`` without gradient tracking.

    Uses the notebook globals ``model``, ``loss_fn``, ``device`` and
    ``compute_confusion_binary`` / ``compute_metrics_from_confusion``.
    Returns ``(avg_loss, metrics)``. ``avg_loss`` is the sample-weighted mean of
    the batch losses (0.0 for an empty dataset). ``metrics`` is computed from one
    confusion matrix accumulated over all pixels of all batches, with class 1
    as the positive class.
    """
    model.eval()
    total_loss = 0.0
    sum_tp = sum_fp = sum_tn = sum_fn = 0
    # Same AMP policy as training: float16 autocast on CUDA only; on CPU autocast
    # would switch to bfloat16, so it stays disabled there.
    use_amp = device.type == "cuda"
    amp_device_type = "cuda" if use_amp else "cpu"

    with torch.no_grad():
        for imgs, masks in loader:
            imgs = imgs.to(device)
            masks = masks.to(device)

            with torch.amp.autocast(amp_device_type, enabled=use_amp):
                logits = model(imgs)
                loss = loss_fn(logits, masks)

            total_loss += loss.item() * imgs.size(0)

            pred = torch.argmax(logits, dim=1)
            tp, fp, tn, fn = compute_confusion_binary(pred, masks, positive_class=1)
            sum_tp += tp
            sum_fp += fp
            sum_tn += tn
            sum_fn += fn

    avg_loss = total_loss / max(len(loader.dataset), 1)
    metrics = compute_metrics_from_confusion(sum_tp, sum_fp, sum_tn, sum_fn)

    return avg_loss, metrics

## Initialize WandB

In [None]:
# Everything needed to reproduce this run and compare it with the FCEF baseline.
# Note: the *_ref_ids entries record split sizes (tile counts), not the IDs.
wandb.init(
    entity="nina_prosjektoppgave",
    project="smp_unet",
    config=dict(
        model="Unet",
        encoder="resnet34",
        encoder_weights="imagenet",
        in_channels=train_ds.num_bands,
        classes=2,
        learning_rate=1e-3,
        batch_size=8,
        patch_size=PATCH_SIZE,
        epochs=10,
        train_patches_per_image=20,
        val_patches_per_image=10,
        test_patches_per_image=10,
        train_ref_ids=len(train_ref_ids),
        val_ref_ids=len(val_ref_ids),
        test_ref_ids=len(test_ref_ids),
        augmentation=True,
        normalization="scale_10000_plus_standardize",
        random_seed=RANDOM_SEED,
        train_ratio=0.7,
        val_ratio=0.15,
        test_ratio=0.15,
        fair_comparison="shared_splits_normalization_patch_size_with_FCEF",
    ),
)

# Track gradients and parameters of the model (log="all") with log_freq=100.
wandb.watch(model, log_freq=100, log="all")

[34m[1mwandb[0m: Currently logged in as: [33mceciliamoller[0m ([33mnina_prosjektoppgave[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Training Loop

In [None]:
def log_examples(images, masks, preds, step, phase="train"):
    """Log up to four image / ground-truth / prediction overlays to W&B.

    images: batch tensor; channels 0-2 are shown as the picture
        (NOTE(review): assumes those bands are meaningful to display — confirm band order).
    masks: per-image label maps (presumably 0/1 integers — verify against the dataset).
    preds: raw model logits; the predicted class is the argmax over dim 1.
    step: W&B step the images are logged under.
    phase: prefix for the logged key, e.g. "val" -> "val_examples".
    """
    predicted = preds.argmax(dim=1)

    # Min-max stretch each of the three display channels to [0, 1], using that
    # channel's extremes over the whole batch; constant channels stay as they are.
    preview = images[:, :3, :, :].clone()
    for ch in range(3):
        band = preview[:, ch, :, :]
        lo = band.min()
        hi = band.max()
        if hi > lo:
            preview[:, ch, :, :] = (band - lo) / (hi - lo)

    def overlay(label_map):
        # One W&B mask entry; a fresh class-label dict is built per mask.
        return {"mask_data": label_map.cpu().numpy(), "class_labels": {0: "background", 1: "land-take"}}

    shown = min(4, images.size(0))
    wandb_images = [
        wandb.Image(
            preview[k].cpu(),
            masks={
                "ground_truth": overlay(masks[k]),
                "prediction": overlay(predicted[k]),
            },
        )
        for k in range(shown)
    ]

    wandb.log({f"{phase}_examples": wandb_images}, step=step)

In [None]:
# Number of epochs; the W&B config cell records the same value under "epochs".
NUM_EPOCHS = 10
# Log example predictions every this many epochs.
EXAMPLE_LOG_EVERY = 2

train_losses = []
val_losses = []
val_ious = []
val_f1s = []

try:
    for epoch in range(NUM_EPOCHS):
        # 1-based epoch, used as the explicit W&B step for EVERY log call below, so
        # an epoch's metrics and its example images share one W&B row. Mixing
        # auto-incremented steps (metrics) with explicit steps (images) put the
        # images in the following epoch's row.
        step = epoch + 1

        train_loss = train_one_epoch(train_loader)
        train_losses.append(train_loss)

        # Validation with full metrics
        val_loss, val_metrics = validate(val_loader)
        val_losses.append(val_loss)
        val_ious.append(val_metrics['iou'])
        val_f1s.append(val_metrics['f1'])

        # Log to W&B
        wandb.log({
            "epoch": step,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "val_iou": val_metrics['iou'],
            "val_f1": val_metrics['f1'],
            "val_precision": val_metrics['precision'],
            "val_recall": val_metrics['recall'],
            "val_accuracy": val_metrics['accuracy'],
        }, step=step)

        # Log example predictions periodically
        if step % EXAMPLE_LOG_EVERY == 0:
            model.eval()
            with torch.no_grad():
                # val_loader uses shuffle=False, so this is always its first batch.
                val_imgs, val_masks = next(iter(val_loader))
                val_imgs = val_imgs.to(device)
                val_preds = model(val_imgs)
                log_examples(val_imgs, val_masks, val_preds, step=step, phase="val")

        # Print concise epoch summary
        print(
            f"Epoch {epoch+1}: "
            f"train_loss={train_loss:.4f} "
            f"val_loss={val_loss:.4f} | "
            f"IoU={val_metrics['iou']:.4f} "
            f"F1={val_metrics['f1']:.4f} "
            f"Prec={val_metrics['precision']:.4f} "
            f"Rec={val_metrics['recall']:.4f} "
            f"Acc={val_metrics['accuracy']:.4f}"
        )
except BaseException:
    # Interrupted (e.g. KeyboardInterrupt) or crashed: still close the W&B run,
    # but mark it as failed, then re-raise.
    wandb.finish(exit_code=1)
    raise
else:
    # Finish WandB run
    wandb.finish()

print(f"\nTraining Complete!")
print(f"Final Validation Metrics:")
print(f"  Loss: {val_losses[-1]:.4f}")
print(f"  IoU: {val_ious[-1]:.4f}")
print(f"  F1: {val_f1s[-1]:.4f}")

KeyboardInterrupt: 

Exception ignored in: 'rasterio._env.log_error'
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.12/lib/python3.12/logging/__init__.py", line 1529, in info
    def info(self, msg, *args, **kwargs):

KeyboardInterrupt: 


## Visualize Training

In [None]:
def _epoch_axis(series):
    """1-based epoch numbers matching ``series`` element-for-element."""
    # Each series gets its own x-range. If training stopped after train_one_epoch
    # but before validate, train_losses is one entry longer than the validation
    # lists, and a shared x-range would make matplotlib reject the shorter series.
    return range(1, len(series) + 1)


epochs = range(1, len(train_losses) + 1)

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# (axis, [(series, matplotlib format, legend label), ...], y-label, title)
panels = [
    (axes[0],
     [(train_losses, 'b-o', 'Training Loss'), (val_losses, 'r-s', 'Validation Loss')],
     'Loss', 'Training and Validation Loss'),
    (axes[1], [(val_ious, 'g-^', 'Validation IoU')], 'IoU', 'Validation IoU'),
    (axes[2], [(val_f1s, 'm-d', 'Validation F1')], 'F1 Score', 'Validation F1 Score'),
]

for ax, curves, ylabel, title in panels:
    for series, fmt, label in curves:
        ax.plot(_epoch_axis(series), series, fmt, label=label, linewidth=2, markersize=6)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Guard the summary: indexing [-1] on an empty list would raise IndexError.
if train_losses and val_losses:
    print(f"\nFinal Results:")
    print(f"Training Loss: {train_losses[-1]:.4f}")
    print(f"Validation Loss: {val_losses[-1]:.4f}")
    print(f"Validation IoU: {val_ious[-1]:.4f}")
    print(f"Validation F1: {val_f1s[-1]:.4f}")
else:
    print("\nNo completed epochs to report.")

## Test Set Evaluation

In [None]:
# Score the final (last-epoch) model on the held-out test tiles.
test_loss, test_metrics = validate(test_loader)

print("Test Set Results:")
print(f"  Loss: {test_loss:.4f}")
for _label, _key in (
    ("IoU", "iou"),
    ("F1", "f1"),
    ("Precision", "precision"),
    ("Recall", "recall"),
    ("Accuracy", "accuracy"),
):
    print(f"  {_label}: {test_metrics[_key]:.4f}")
del _label, _key