In [None]:
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Dict

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split

import os


def set_seed(seed: int = 42):
    """Seed every RNG this notebook touches with the same value.

    Covers Python's ``random`` module, NumPy's legacy global RNG, PyTorch's CPU
    generator and the generators of all CUDA devices (a no-op without CUDA).
    """
    import random

    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)


@dataclass
class TrainConfig:
    """Paths and hyperparameters for this training notebook.

    Instantiated once below as ``cfg`` and read by the data, model and
    training cells.
    """

    # Folder scanned for preprocessed tiles (top level only: the dataset uses a non-recursive glob).
    data_dir: Path = Path("data/preprocessed")
    file_ext: str = ".npz"  # or ".pt"; the dataset class supports exactly these two
    region_filter: str = "RegionA"  # keep only files whose name contains this substring
    input_channels: int = 6  # 5 or 6 (if NDSI used); checked against axis 0 of the first sample
    num_classes: int = 2  # number of segmentation classes
    epochs: int = 50
    batch_size: int = 4
    lr: float = 1e-3  # AdamW learning rate
    weight_decay: float = 1e-4  # AdamW weight decay
    # This default is evaluated once, when the class body runs, not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    val_fraction: float = 0.2  # share of the matching files held out for validation
    # NOTE(review): DataLoader workers must be able to re-import the notebook-defined Dataset
    # class; on spawn-based platforms (Windows/macOS) this can fail — use 0 there if so.
    num_workers: int = 2
    save_dir: Path = Path("checkpoints")  # the best-MCC checkpoint is written here


# Seed every RNG first, then build the config and make sure the checkpoint folder exists.
set_seed(42)
cfg = TrainConfig()
cfg.save_dir.mkdir(exist_ok=True, parents=True)
print(cfg)


In [None]:
class GlacierTensorDataset(Dataset):
    """Map-style dataset over preprocessed tiles stored one per file.

    Each file holds an ``"image"`` and a ``"mask"`` entry, either as an ``.npz``
    archive of NumPy arrays or as a ``torch.save``-d dict of tensors. Items are
    returned as ``(image, mask)`` with image as float32 and mask as int64.

    Args:
        directory: Folder scanned (non-recursively) for ``*<file_ext>`` files;
            a ``Path`` or a plain string.
        file_ext: Extension to match, ``".npz"`` or ``".pt"``.
        region_filter: If given, only files whose name contains this substring
            are kept.

    Raises:
        FileNotFoundError: If no file matches.
    """

    def __init__(self, directory: Path, file_ext: str, region_filter: str | None = None):
        directory = Path(directory)
        self.paths = sorted(
            p for p in directory.glob(f"*{file_ext}")
            if region_filter is None or region_filter in p.name
        )
        if not self.paths:
            raise FileNotFoundError(f"No files found in {directory} with ext {file_ext} and filter {region_filter}")

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int):
        path = self.paths[idx]
        if path.suffix == ".npz":
            # np.load returns an NpzFile that keeps the archive open until closed;
            # the context manager releases the file handle once the arrays are copied out.
            with np.load(path) as data:
                image = data["image"].astype(np.float32)
                mask = data["mask"].astype(np.uint8)
        elif path.suffix == ".pt":
            # map_location="cpu": tensors saved from a GPU would otherwise be restored onto
            # CUDA, where .numpy() fails and which is unsafe inside DataLoader worker processes.
            # NOTE(review): torch.load unpickles the file — only load trusted data.
            data = torch.load(path, map_location="cpu")
            image = data["image"].detach().numpy().astype(np.float32)
            mask = data["mask"].detach().numpy().astype(np.uint8)
        else:
            raise ValueError("Unsupported extension: " + path.suffix)
        # NOTE(review): the uint8 cast assumes mask values are already class indices (e.g. 0/1);
        # a 0/255 mask would need remapping before the loss. Masks become int64 for CrossEntropyLoss.
        return torch.from_numpy(image), torch.from_numpy(mask).long()


In [None]:
# Split into train/val (Region A subset)
# GlacierTensorDataset raises FileNotFoundError when no file matches, so full_ds is non-empty here.
full_ds = GlacierTensorDataset(cfg.data_dir, cfg.file_ext, cfg.region_filter)
sample_x, _ = full_ds[0]
assert sample_x.shape[0] == cfg.input_channels, \
    f"Mismatch: data has {sample_x.shape[0]} channels but cfg.input_channels={cfg.input_channels}. Update config."

if len(full_ds) < 2:
    raise ValueError(f"Need at least 2 samples for a train/val split, found {len(full_ds)} in {cfg.data_dir}")
# int() truncation alone yields 0 validation samples for small datasets (< 5 files at 0.2),
# which would report all-zero validation metrics; keep both splits non-empty.
val_len = min(max(1, int(len(full_ds) * cfg.val_fraction)), len(full_ds) - 1)
train_len = len(full_ds) - val_len
# Dedicated seeded generator: the split does not depend on how often other cells drew from the global torch RNG.
split_generator = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(full_ds, [train_len, val_len], generator=split_generator)

# NOTE(review): with num_workers > 0 on spawn-based platforms (Windows/macOS), workers must be able to
# import GlacierTensorDataset; a notebook-defined class may fail there — set cfg.num_workers = 0 if so.
train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers)
val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)

print(f"Train: {len(train_ds)} | Val: {len(val_ds)}")


In [None]:
# Minimal U-Net adapted to variable input channels
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Two 3x3 same-padding convolutions, each followed by ReLU, with Dropout2d after the first.

    Spatial size is preserved; channels go ``in_ch -> out_ch``.
    """
    def __init__(self, in_ch, out_ch, dropout=0.2):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.dropout = nn.Dropout2d(p=dropout)
    def forward(self, x):
        x = F.relu(self.dropout(self.conv1(x)))
        x = F.relu(self.conv2(x))
        return x


class UpBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, concat with the encoder skip, then a ConvBlock.

    Expects ``in_ch == 2 * out_ch`` and a skip tensor with ``out_ch`` channels, so the
    concatenation has ``in_ch`` channels.
    """
    def __init__(self, in_ch, out_ch, dropout=0.2):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_ch, out_ch, dropout)
    def forward(self, x, skip):
        x = self.up(x)
        # MaxPool2d floors odd sizes, so the upsampled map can be smaller than the skip.
        # Pad it (split evenly, extra on bottom/right) so torch.cat sees matching H, W.
        diff_h = skip.size(-2) - x.size(-2)
        diff_w = skip.size(-1) - x.size(-1)
        if diff_h or diff_w:
            x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([skip, x], dim=1)
        return self.conv(x)


class UNet(nn.Module):
    """U-Net producing per-pixel class logits, with a configurable number of input channels.

    Encoder: ``depth`` ConvBlocks of width base_ch, 2*base_ch, ..., each followed by
    2x2 max-pooling. The bottleneck doubles the width once more. The decoder mirrors
    the encoder with UpBlocks and skip connections, and a 1x1 conv maps the final
    ``base_ch`` features to ``num_classes`` logits.

    Args:
        in_channels: Channels of the input image.
        num_classes: Output logit channels.
        base_ch: Width of the first encoder stage.
        depth: Number of down/up stages.
        dropout: Dropout2d probability used inside every ConvBlock.

    ``forward(x)`` returns logits of shape (N, num_classes, H, W), with the same H, W
    as ``x``. H and W need not be multiples of 2**depth: the decoder pads to the skip size.
    """
    def __init__(self, in_channels: int, num_classes: int = 2, base_ch: int = 32, depth: int = 4, dropout: float = 0.2):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        ch_in = in_channels
        ch_out = base_ch
        for _ in range(depth):
            self.downs.append(ConvBlock(ch_in, ch_out, dropout))
            ch_in, ch_out = ch_out, ch_out * 2
        self.pool = nn.MaxPool2d(2, 2)
        self.mid = ConvBlock(ch_in, ch_out, dropout)
        ch_in, ch_out = ch_out, ch_out // 2
        for _ in range(depth):
            self.ups.append(UpBlock(ch_in, ch_out, dropout))
            ch_in, ch_out = ch_out, ch_out // 2
        # After the decoder loop, ch_in is the width the last stage actually outputs (== base_ch),
        # which is exactly what the 1x1 head receives.
        self.seg = nn.Conv2d(ch_in, num_classes, kernel_size=1)
    def forward(self, x):
        skips = []
        for block in self.downs:
            x = block(x)
            skips.append(x)
            x = self.pool(x)
        x = self.mid(x)
        for block in self.ups:
            x = block(x, skips.pop())
        x = self.seg(x)
        return x


In [None]:
# Losses and metrics

def dice_coefficient(prob: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Batch mean of the per-image *soft* Dice score.

    ``prob`` holds raw logits. With 2 channels, the foreground probability is
    softmax channel 1; otherwise it is the sigmoid of channel 0. ``target`` is
    assumed to be an (N, H, W) 0/1 mask. Probabilities are NOT thresholded here,
    unlike ``iou_score`` / ``matthews_corrcoef``. Each image scores
    (2 * intersection + eps) / (sum(prob) + sum(target) + eps).
    """
    fg = torch.softmax(prob, dim=1)[:, 1] if prob.shape[1] == 2 else torch.sigmoid(prob[:, 0])
    truth = target.float()
    spatial = (1, 2)
    overlap = (fg * truth).sum(dim=spatial)
    total = fg.sum(dim=spatial) + truth.sum(dim=spatial)
    per_image = (2 * overlap + eps) / (total + eps)
    return per_image.mean()


def iou_score(prob: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Batch mean of the per-image IoU (Jaccard) of the thresholded foreground prediction.

    ``prob`` holds raw logits. With 2 channels, the foreground probability is
    softmax channel 1; otherwise it is the sigmoid of channel 0. Pixels above 0.5
    count as predicted foreground. ``target`` is assumed to be an (N, H, W) 0/1
    mask. Each image scores (intersection + eps) / (union + eps), so an image
    with an empty prediction and an empty target scores 1.
    """
    fg_prob = torch.softmax(prob, dim=1)[:, 1] if prob.shape[1] == 2 else torch.sigmoid(prob[:, 0])
    pred_mask = (fg_prob > 0.5).float()
    truth = target.float()
    spatial = (1, 2)
    overlap = (pred_mask * truth).sum(dim=spatial)
    covered = (pred_mask + truth).clamp(0, 1).sum(dim=spatial)
    per_image = (overlap + eps) / (covered + eps)
    return per_image.mean()


def matthews_corrcoef(prob: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Batch mean of the per-image Matthews correlation coefficient.

    ``prob`` holds raw logits. With 2 channels, the foreground probability is
    softmax channel 1; otherwise it is the sigmoid of channel 0. It is
    thresholded at 0.5. ``target`` is assumed to be an (N, H, W) 0/1 mask. When
    a confusion-matrix margin is zero (e.g. an all-background tile predicted as
    all background), the image scores 0.
    """
    fg_prob = torch.softmax(prob, dim=1)[:, 1] if prob.shape[1] == 2 else torch.sigmoid(prob[:, 0])
    pred_pos = (fg_prob > 0.5).float()
    true_pos_mask = target.float()
    pred_neg = 1 - pred_pos
    true_neg_mask = 1 - true_pos_mask
    spatial = (1, 2)
    tp = (pred_pos * true_pos_mask).sum(dim=spatial)
    tn = (pred_neg * true_neg_mask).sum(dim=spatial)
    fp = (pred_pos * true_neg_mask).sum(dim=spatial)
    fn = (pred_neg * true_pos_mask).sum(dim=spatial)
    spread = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + eps)
    per_image = (tp * tn - fp * fn) / (spread + eps)
    return per_image.mean()


class DiceBCELoss(nn.Module):
    """Weighted sum of a soft-Dice loss and a pixel-wise classification loss.

    * 2-channel logits (N, 2, H, W): CrossEntropyLoss, plus soft Dice on the
      softmax probability of channel 1.
    * 1-channel logits (N, 1, H, W): BCEWithLogitsLoss, plus soft Dice on the
      sigmoid probability.

    ``target`` is an (N, H, W) integer mask; presumably 0/1 — confirm against the data.
    The result is ``weight_dice * dice_loss + weight_bce * pixel_loss``.

    Raises:
        ValueError: If the logits have a channel count other than 1 or 2. Neither
            branch is valid for multi-class output, and silently using channel 0
            would train on the wrong objective.
    """
    def __init__(self, weight_dice: float = 0.5, weight_bce: float = 0.5):
        super().__init__()
        self.weight_dice = weight_dice
        self.weight_bce = weight_bce
        self.bce = nn.BCEWithLogitsLoss()
        self.ce = nn.CrossEntropyLoss()

    @staticmethod
    def _soft_dice_loss(probs: torch.Tensor, target_f: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        """1 - batch mean of per-sample soft Dice for (N, 1, H, W) probabilities and float targets."""
        dims = (1, 2, 3)
        intersection = (probs * target_f).sum(dim=dims)
        union = probs.sum(dim=dims) + target_f.sum(dim=dims)
        dice = (2 * intersection + eps) / (union + eps)
        return 1 - dice.mean()

    def forward(self, logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        n_channels = logits.shape[1]
        if n_channels not in (1, 2):
            raise ValueError(f"DiceBCELoss supports 1- or 2-channel logits, got {n_channels} channels")
        target_f = target.float().unsqueeze(1)
        if n_channels == 2:
            # 2-class logits: cross-entropy on class indices, Dice on the foreground softmax prob
            pixel_loss = self.ce(logits, target)
            probs = torch.softmax(logits, dim=1)[:, 1:2]
        else:
            # single-channel logits: BCE on the raw logit, Dice on its sigmoid
            pixel_loss = self.bce(logits[:, 0], target.float())
            probs = torch.sigmoid(logits[:, 0:1])
        dice_loss = self._soft_dice_loss(probs, target_f)
        return self.weight_dice * dice_loss + self.weight_bce * pixel_loss


In [None]:
# Training and validation loops

def run_epoch(model: nn.Module, loader: DataLoader, optimizer: torch.optim.Optimizer | None, device: str, criterion: nn.Module) -> Dict[str, float]:
    """Run one pass over ``loader``: train when an optimizer is given, otherwise evaluate.

    Args:
        model: Network returning logits for a batch ``x``.
        loader: Yields ``(x, y)`` batches.
        optimizer: If not None, the model is set to train mode and updated on every
            batch. If None, it is set to eval mode and the pass runs with autograd disabled.
        device: Device each batch is moved to.
        criterion: Loss module called as ``criterion(logits, y)``.

    Returns:
        Dict with ``"loss"``, ``"dice"``, ``"iou"`` and ``"mcc"``. Each per-batch value is
        weighted by its batch size, so a short final batch is not over-weighted. All values
        are 0.0 if the loader yields no batches.
    """
    train = optimizer is not None
    model.train(train)
    sums = {"loss": 0.0, "dice": 0.0, "iou": 0.0, "mcc": 0.0}
    n_seen = 0
    for x, y in loader:
        x = x.to(device)
        y = y.to(device)
        batch_n = x.size(0)
        # Build an autograd graph only when we will backprop; without this, evaluation
        # would keep activations alive for a graph that is never used.
        with torch.set_grad_enabled(train):
            logits = model(x)
            loss = criterion(logits, y)
            if train:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()
        with torch.no_grad():
            logits_d = logits.detach()
            sums["loss"] += loss.item() * batch_n
            sums["dice"] += dice_coefficient(logits_d, y).item() * batch_n
            sums["iou"] += iou_score(logits_d, y).item() * batch_n
            sums["mcc"] += matthews_corrcoef(logits_d, y).item() * batch_n
        n_seen += batch_n
    denom = max(1, n_seen)
    return {name: total / denom for name, total in sums.items()}


# NOTE(review): the loss and metrics in this notebook only handle 1 or 2 output channels.
model = UNet(in_channels=cfg.input_channels, num_classes=cfg.num_classes).to(cfg.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
criterion = DiceBCELoss()

# -inf so the first finite validation MCC (which can be as low as -1.0) is always saved.
best_mcc = float("-inf")
best_path = cfg.save_dir / "best_mcc.pt"
# Plain-Python copy of the config for the checkpoint. Path objects are not on torch.load's
# weights_only allow-list (the default since PyTorch 2.6), so storing them would make the
# checkpoint unloadable without weights_only=False.
config_snapshot = {key: str(value) if isinstance(value, Path) else value for key, value in vars(cfg).items()}

for epoch in range(1, cfg.epochs + 1):
    train_metrics = run_epoch(model, train_loader, optimizer, cfg.device, criterion)
    val_metrics = run_epoch(model, val_loader, None, cfg.device, criterion)
    print(f"Epoch {epoch:03d} | train loss {train_metrics['loss']:.4f} dice {train_metrics['dice']:.4f} iou {train_metrics['iou']:.4f} mcc {train_metrics['mcc']:.4f} | val loss {val_metrics['loss']:.4f} dice {val_metrics['dice']:.4f} iou {val_metrics['iou']:.4f} mcc {val_metrics['mcc']:.4f}")
    if val_metrics["mcc"] > best_mcc:
        best_mcc = val_metrics["mcc"]
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "val_metrics": val_metrics,
            "config": config_snapshot
        }, best_path)
        print(f"Saved new best checkpoint (MCC={best_mcc:.4f}) -> {best_path}")
