In [None]:
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple, Dict

import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

import os

# Metrics copied from 02_train and redefined here so this notebook runs standalone; keep the two copies in sync.

def dice_coefficient(prob: torch.Tensor, target: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Soft Dice score between the predicted foreground probability and a binary target.

    Args:
        prob: Raw logits of shape (B, C, H, W). With C == 2 the foreground
            probability is softmax channel 1; otherwise it is sigmoid(channel 0).
        target: Ground-truth mask of shape (B, H, W); cast to float.
        eps: Smoothing term added to numerator and denominator.

    Returns:
        0-d tensor: per-sample Dice averaged over the batch. No threshold is
        applied, so this is the *soft* Dice on probabilities.
    """
    fg = torch.softmax(prob, dim=1)[:, 1] if prob.shape[1] == 2 else torch.sigmoid(prob[:, 0])
    tgt = target.float()
    spatial = (1, 2)
    overlap = torch.sum(fg * tgt, dim=spatial)
    total = torch.sum(fg, dim=spatial) + torch.sum(tgt, dim=spatial)
    return ((2 * overlap + eps) / (total + eps)).mean()


def iou_score(prob: torch.Tensor, target: torch.Tensor, eps: float = 1e-6, threshold: float = 0.5) -> torch.Tensor:
    """Intersection-over-Union of the thresholded foreground prediction and a binary target.

    Args:
        prob: Raw logits of shape (B, C, H, W). With C == 2 the foreground
            probability is softmax channel 1; otherwise it is sigmoid(channel 0).
        target: Ground-truth mask of shape (B, H, W); cast to float.
        eps: Smoothing term; a tile that is empty in both prediction and target scores 1.
        threshold: Probability above which a pixel counts as foreground. The default
            of 0.5 reproduces the previous hard-coded behaviour. Passing e.g.
            ``cfg.threshold`` makes the metric agree with the overlays.

    Returns:
        0-d tensor: per-sample IoU averaged over the batch.
    """
    if prob.shape[1] == 2:
        fg_prob = torch.softmax(prob, dim=1)[:, 1]
    else:
        fg_prob = torch.sigmoid(prob[:, 0])
    pred = (fg_prob > threshold).float()
    target = target.float()
    intersection = (pred * target).sum(dim=(1, 2))
    # clamp turns the pixel-wise sum into a logical OR for binary inputs.
    union = (pred + target).clamp(0, 1).sum(dim=(1, 2))
    return ((intersection + eps) / (union + eps)).mean()


def matthews_corrcoef(prob: torch.Tensor, target: torch.Tensor, eps: float = 1e-6, threshold: float = 0.5) -> torch.Tensor:
    """Matthews correlation coefficient of the thresholded foreground prediction.

    Args:
        prob: Raw logits of shape (B, C, H, W). With C == 2 the foreground
            probability is softmax channel 1; otherwise it is sigmoid(channel 0).
        target: Ground-truth mask of shape (B, H, W); cast to float.
        eps: Guard against a zero denominator.
        threshold: Probability above which a pixel counts as foreground. The default
            of 0.5 reproduces the previous hard-coded behaviour.

    Returns:
        0-d tensor: per-sample MCC averaged over the batch. Note that when a tile has
        no positives in either prediction or target, tp = fp = fn = 0, so the
        numerator is 0 and the tile scores 0 even though the prediction is perfect.
    """
    if prob.shape[1] == 2:
        fg_prob = torch.softmax(prob, dim=1)[:, 1]
    else:
        fg_prob = torch.sigmoid(prob[:, 0])
    pred = (fg_prob > threshold).float()
    target = target.float()
    tp = (pred * target).sum(dim=(1, 2))
    tn = ((1 - pred) * (1 - target)).sum(dim=(1, 2))
    fp = (pred * (1 - target)).sum(dim=(1, 2))
    fn = ((1 - pred) * target).sum(dim=(1, 2))
    numerator = tp * tn - fp * fn
    # eps is applied both inside the sqrt and to the result (kept for parity with training).
    denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + eps)
    return (numerator / (denominator + eps)).mean()


@dataclass
class InferConfig:
    """Settings for evaluating a trained UNet on tiles from a held-out region.

    Path-valued fields also accept plain strings; ``__post_init__`` converts them
    to ``pathlib.Path`` so later ``.mkdir()`` and ``/`` joins work.
    """
    data_dir: Path = Path("data/preprocessed")  # folder holding the preprocessed tiles
    region_filter: str = "RegionB"  # unseen region; matched as a substring of tile file names
    file_ext: str = ".npz"  # or ".pt"
    input_channels: int = 6  # must match training
    num_classes: int = 2
    device: str = "cuda" if torch.cuda.is_available() else "cpu"  # evaluated once, when the class is defined
    checkpoint_path: Path = Path("checkpoints/best_mcc.pt")
    results_dir: Path = Path("results")  # overlay PNGs are written here
    threshold: float = 0.5  # foreground probability cut-off used for the overlay masks

    def __post_init__(self) -> None:
        # Coerce str / os.PathLike inputs, e.g. InferConfig(results_dir="out").
        self.data_dir = Path(self.data_dir)
        self.checkpoint_path = Path(self.checkpoint_path)
        self.results_dir = Path(self.results_dir)


# Build the run configuration and create the output folder (and any missing
# parents) before later cells write overlays into it.
cfg = InferConfig()
os.makedirs(cfg.results_dir, exist_ok=True)
print(cfg)


In [None]:
class RegionDataset(Dataset):
    """Tiles for one region, stored as ``.npz`` archives or ``torch.save`` files.

    Each file must contain ``"image"`` and ``"mask"`` entries. Items are
    ``(image: float32 tensor, mask: int64 tensor, file name: str)``.

    Raises:
        FileNotFoundError: if no file in ``directory`` with ``file_ext`` contains
            ``region_filter`` in its name.
    """
    def __init__(self, directory: Path, file_ext: str, region_filter: str):
        directory = Path(directory)  # also accept plain strings
        # Substring match on the file name selects the region; sorting keeps tile order deterministic.
        self.paths = sorted(p for p in directory.glob(f"*{file_ext}") if region_filter in p.name)
        if not self.paths:
            raise FileNotFoundError(f"No files for region {region_filter} in {directory}")

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _to_numpy(value) -> np.ndarray:
        """Return ``value`` as a NumPy array whether it was saved as a tensor or an array."""
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    def __getitem__(self, idx: int):
        path = self.paths[idx]
        if path.suffix == ".npz":
            # np.load on an .npz returns an archive that keeps its file handle open;
            # the context manager closes it once both arrays have been copied out.
            with np.load(path) as data:
                image = data["image"].astype(np.float32)
                mask = data["mask"].astype(np.uint8)
        else:
            # map_location="cpu" lets tensors saved from a GPU load on CPU-only machines.
            # NOTE(review): torch.load unpickles its input; only read trusted files.
            data = torch.load(path, map_location="cpu")
            image = self._to_numpy(data["image"]).astype(np.float32)
            mask = self._to_numpy(data["mask"]).astype(np.uint8)
        x = torch.from_numpy(image)
        y = torch.from_numpy(mask).long()
        return x, y, path.name


In [None]:
# Define UNet to match training
import torch.nn.functional as F

class ConvBlock(nn.Module):
    """Two 3x3 same-padding convolutions with ReLU, with Dropout2d after the first conv.

    Attribute names (``conv1``, ``conv2``, ``dropout``) and op order are kept as-is
    so checkpoint ``state_dict`` keys line up with training.
    """
    def __init__(self, in_ch, out_ch, dropout=0.2):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.dropout = nn.Dropout2d(p=dropout)
    def forward(self, x):
        # Dropout sits between conv1 and its ReLU; kept to match the training definition.
        x = F.relu(self.dropout(self.conv1(x)))
        x = F.relu(self.conv2(x))
        return x

class UpBlock(nn.Module):
    """2x transposed-conv upsample, concat with the encoder skip, then a ConvBlock.

    ``in_ch`` is the channel count of the incoming decoder map; the upsample halves it
    to ``out_ch``, and the concat with a ``out_ch``-channel skip restores ``in_ch``.
    """
    def __init__(self, in_ch, out_ch, dropout=0.2):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)
        self.conv = ConvBlock(in_ch, out_ch, dropout)
    def forward(self, x, skip):
        x = self.up(x)
        # MaxPool2d floors odd sizes, so the upsampled map can be smaller than the skip.
        # Pad it so tiles whose side is not divisible by 2**depth still concatenate.
        dh = skip.shape[-2] - x.shape[-2]
        dw = skip.shape[-1] - x.shape[-1]
        if dh or dw:
            x = F.pad(x, (dw // 2, dw - dw // 2, dh // 2, dh - dh // 2))
        x = torch.cat([skip, x], dim=1)
        return self.conv(x)

class UNet(nn.Module):
    """Encoder/decoder UNet producing ``num_classes`` logits per pixel.

    Channel widths go ``base_ch, 2*base_ch, ...`` for ``depth`` encoder levels, with
    a bottleneck at ``base_ch * 2**depth``, mirrored by ``depth`` UpBlocks.
    """
    def __init__(self, in_channels: int, num_classes: int = 2, base_ch: int = 32, depth: int = 4, dropout: float = 0.2):
        super().__init__()
        self.downs = nn.ModuleList()
        self.ups = nn.ModuleList()
        ch_in = in_channels
        ch_out = base_ch
        for _ in range(depth):
            self.downs.append(ConvBlock(ch_in, ch_out, dropout))
            ch_in, ch_out = ch_out, ch_out * 2
        self.pool = nn.MaxPool2d(2,2)
        self.mid = ConvBlock(ch_in, ch_out, dropout)
        ch_in, ch_out = ch_out, ch_out // 2
        for _ in range(depth):
            self.ups.append(UpBlock(ch_in, ch_out, dropout))
            ch_in, ch_out = ch_out, ch_out // 2
        # After the loop ch_in equals the last UpBlock's output width (base_ch), which is
        # exactly what reaches the head. The previous `ch_in * 2` made forward() fail.
        self.seg = nn.Conv2d(ch_in, num_classes, kernel_size=1)
    def forward(self, x):
        skips = []
        for block in self.downs:
            x = block(x)
            skips.append(x)
            x = self.pool(x)
        x = self.mid(x)
        for block in self.ups:
            x = block(x, skips.pop())
        x = self.seg(x)
        return x


In [None]:
# Load the held-out region's tiles and the trained checkpoint.
# RegionDataset raises FileNotFoundError when no tiles match, so region_ds is non-empty here.
region_ds = RegionDataset(cfg.data_dir, cfg.file_ext, cfg.region_filter)

# Fail fast if the tiles do not have the channel count the model was built for.
sample_x, _, _ = region_ds[0]
assert sample_x.shape[0] == cfg.input_channels, \
    f"Mismatch: data has {sample_x.shape[0]} channels but cfg.input_channels={cfg.input_channels}. Update config."

model = UNet(in_channels=cfg.input_channels, num_classes=cfg.num_classes)
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints
# (consider weights_only=True if the installed torch and checkpoint contents allow it).
state = torch.load(cfg.checkpoint_path, map_location=cfg.device)
# Accept either a bare state_dict or a training checkpoint that wraps it under "model_state".
if isinstance(state, dict) and "model_state" in state:
    model_state = state["model_state"]
else:
    model_state = state
model.load_state_dict(model_state)
model = model.to(cfg.device).eval()

loader = DataLoader(region_ds, batch_size=1, shuffle=False)
print(f"Loaded {len(region_ds)} tiles for {cfg.region_filter}")


In [None]:
# Tile-wise inference and per-tile metrics (each tile is scored independently; no stitching is done)

@torch.no_grad()
def infer_tile(x: torch.Tensor, model: nn.Module, device: str) -> torch.Tensor:
    """Run ``model`` on one batch with autograd disabled and return its output on the CPU.

    Args:
        x: Input batch; moved to ``device`` before the forward pass.
        model: Network to evaluate. This function does not call ``.eval()`` itself.
        device: Device string accepted by ``Tensor.to`` (e.g. ``"cpu"``, ``"cuda"``).
    """
    on_device = x.to(device)
    return model(on_device).cpu()

# `threshold` is read later by the overlay cell; it is not passed to the metric calls below.
threshold = cfg.threshold
all_metrics = dict(dice=[], iou=[], mcc=[])

for x, y, name in loader:
    logits = infer_tile(x, model, cfg.device)
    # Each metric returns a 0-d tensor averaged over the (single-tile) batch.
    dice, iou, mcc = dice_coefficient(logits, y), iou_score(logits, y), matthews_corrcoef(logits, y)
    all_metrics["dice"].append(dice.item())
    all_metrics["iou"].append(iou.item())
    all_metrics["mcc"].append(mcc.item())

# Unweighted mean over tiles for each metric.
print({metric: float(np.mean(scores)) for metric, scores in all_metrics.items()})


In [None]:
# Visualization: RGB overlay with predicted mask

def overlay_rgb_mask(x: torch.Tensor, logits: torch.Tensor, threshold: float = 0.5) -> np.ndarray:
    """Render a tile as an RGB uint8 image with predicted foreground painted red.

    Args:
        x: Input of shape (1, C, H, W); ``squeeze(0)`` drops the batch axis, so a
            batch size of 1 is assumed. Channels 0, 1, 2 are read as blue, green, red.
        logits: Model output of shape (1, 2, H, W) (softmax, channel 1 = foreground)
            or (1, 1, H, W) (sigmoid).
        threshold: Probability above which a pixel is marked as foreground.

    Returns:
        (H, W, 3) uint8 array. Each band is min-max stretched independently for
        display, and masked pixels get their red channel set to full intensity.
    """
    def _stretch(band):
        shifted = band - band.min()
        span = shifted.max() - shifted.min() + 1e-6
        return (shifted / span).clip(0, 1)

    tile = x.squeeze(0).cpu().numpy()
    # Band layout: [B, G, R, SWIR, THERM(, NDSI)] -> stack as R, G, B.
    rgb = np.stack([_stretch(tile[band_idx]) for band_idx in (2, 1, 0)], axis=-1)

    if logits.shape[1] != 2:
        fg = torch.sigmoid(logits[:, 0])
    else:
        fg = torch.softmax(logits, dim=1)[:, 1]
    fg_prob = fg.squeeze(0).cpu().numpy()
    hit = (fg_prob > threshold).astype(np.float32)

    painted = rgb.copy()
    painted[..., 0] = np.maximum(painted[..., 0], hit)
    return (painted * 255).astype(np.uint8)

# Write one overlay PNG per tile into cfg.results_dir.
for x, y, name in loader:
    logits = infer_tile(x, model, cfg.device)
    # Read the threshold from cfg directly rather than from a global set in another cell.
    img = overlay_rgb_mask(x, logits, cfg.threshold)
    # DataLoader's default collate turns the per-sample file name into a list of str
    # (one per batch element; batch_size=1 here). Take that entry and drop its extension,
    # otherwise the file would be named "overlay_['tile.npz'].png".
    tile_stem = Path(name[0]).stem
    out_path = cfg.results_dir / f"overlay_{tile_stem}.png"
    plt.imsave(out_path.as_posix(), img)

print(f"Saved overlays to {cfg.results_dir}")
