In [None]:
import os, glob, random, json
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# ----------------------------
# Reproducibility
# ----------------------------
def seed_everything(seed: int = 42) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

SEED = 24
seed_everything(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# ----------------------------
# Config
# ----------------------------
@dataclass
class CFG:
    DATA_ROOT: str = "/kaggle/input/sen2fire/Sen2Fire/Sen2Fire"
    SCENES: Tuple[str, ...] = ("scene1", "scene2", "scene3", "scene4")
    PATCH_EXT: str = ".npz"

    # Global split (VD default)
    USE_GLOBAL_SPLIT: bool = True
    GLOBAL_TRAIN_RATIO: float = 0.80
    GLOBAL_VAL_RATIO: float = 0.10
    GLOBAL_TEST_RATIO: float = 0.10
    KEEP_VAL_TEST_NATURAL: bool = True

    # Keys inside npz
    X_KEY: str = "image"     # (12,512,512)
    A_KEY: str = "aerosol"   # (512,512) -> becomes 1 channel
    Y_KEY: str = "label"     # (512,512)

    # Fire patch definition
    FIRE_PATCH_MIN_RATIO: float = 0

    # Controlled train pool (VD default)
    USE_CONTROLLED_POOL: bool = True
    POOL_KEEP_FIRE: int = -1
    NONFIRE_PER_FIRE: int = 3

    # Training config (RAM-safe)
    IN_CHANNELS: int = 13
    H: int = 512
    W: int = 512
    BATCH_SIZE: int = 2
    NUM_WORKERS: int = 2
    LR: float = 1e-4
    EPOCHS: int = 10

    VERBOSE: bool = True

cfg = CFG()

assert abs((cfg.GLOBAL_TRAIN_RATIO + cfg.GLOBAL_VAL_RATIO + cfg.GLOBAL_TEST_RATIO) - 1.0) < 1e-6
assert os.path.exists(cfg.DATA_ROOT), f"DATA_ROOT not found: {cfg.DATA_ROOT}"
print("DATA_ROOT OK:", cfg.DATA_ROOT)


Device: cuda
DATA_ROOT OK: /kaggle/input/sen2fire/Sen2Fire/Sen2Fire


# Fire Risk Mapping (Sen2Fire) — Notebook 2 (DeepLabV3+)

Same pipeline as Notebook 1 (UNet++):
- Global stratified split + controlled train pool
- Normalization (13-channel)
- Segmentation training
- t_high selection by validation mean Dice sweep
- Export Streamlit-ready model package

Export folder:
- /kaggle/working/model_packages/deeplabv3p_fire/


## Machine 1 — Patch Intake (NPZ → X, Y)

Inside each patch (.npz):
- `image`: (12,512,512) Sentinel-2 bands
- `aerosol`: (512,512) additional band
- `label`: (512,512) fire mask (0/1)

We construct:
- X: concat(image[12], aerosol[1]) → (13,512,512)
- Y: label → (1,512,512)

Note:
- Y is only used during training/validation to compute loss/metrics.
- During Streamlit inference, we only have X.


In [4]:
def list_scene_files(data_root: str, scene_name: str, ext: str = ".npz") -> List[str]:
    scene_dir = os.path.join(data_root, scene_name)
    return sorted(glob.glob(os.path.join(scene_dir, f"*{ext}")))

scene_to_files = {}
total = 0
for s in cfg.SCENES:
    files = list_scene_files(cfg.DATA_ROOT, s, cfg.PATCH_EXT)
    scene_to_files[s] = files
    total += len(files)
    print(f"{s}: {len(files)} patches")

print("Total patches:", total)
if total == 0:
    raise RuntimeError("No patch files found. Check DATA_ROOT / PATCH_EXT.")


scene1: 864 patches
scene2: 594 patches
scene3: 504 patches
scene4: 504 patches
Total patches: 2466


In [5]:
def inspect_npz(npz_path: str):
    with np.load(npz_path) as data:
        return {k: (data[k].shape, str(data[k].dtype)) for k in data.keys()}

sample_path = scene_to_files[cfg.SCENES[0]][0]
print("Sample file:", sample_path)
info = inspect_npz(sample_path)
for k, (shape, dtype) in info.items():
    print(f"  - {k:>10s}: shape={shape}, dtype={dtype}")

for rk in [cfg.X_KEY, cfg.A_KEY, cfg.Y_KEY]:
    if rk not in info:
        raise KeyError(f"Missing key '{rk}' in npz. Found keys: {list(info.keys())}")


Sample file: /kaggle/input/sen2fire/Sen2Fire/Sen2Fire/scene1/scene_1_patch_10_1.npz
  -      image: shape=(12, 512, 512), dtype=int16
  -    aerosol: shape=(512, 512), dtype=float32
  -      label: shape=(512, 512), dtype=uint8


In [6]:
rows = []
for s, files in scene_to_files.items():
    for p in files:
        rows.append({"scene": s, "path": p})
manifest = pd.DataFrame(rows)

def fire_ratio_from_path(npz_path: str) -> float:
    with np.load(npz_path) as data:
        y = data[cfg.Y_KEY]
        if y.ndim == 3 and y.shape[-1] == 1:
            y = y[..., 0]
        yb = (y > 0).astype(np.uint8)
        return float(yb.mean())

fire_ratios = []
has_fire = []
for p in manifest["path"].tolist():
    r = fire_ratio_from_path(p)
    fire_ratios.append(r)
    has_fire.append(1 if r > cfg.FIRE_PATCH_MIN_RATIO else 0)

manifest["fire_ratio"] = fire_ratios
manifest["has_fire"] = has_fire

print("\nFULL dataset has_fire distribution:")
print(manifest["has_fire"].value_counts().sort_index())
print("\nFULL dataset has_fire ratio:")
print(manifest["has_fire"].value_counts(normalize=True).sort_index())



FULL dataset has_fire distribution:
has_fire
0    2117
1     349
Name: count, dtype: int64

FULL dataset has_fire ratio:
has_fire
0    0.858475
1    0.141525
Name: proportion, dtype: float64


In [7]:
def stratified_split(df, train_ratio, val_ratio, seed=42):
    df = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    parts = []
    for cls in [0, 1]:
        sub = df[df["has_fire"] == cls].copy()
        n = len(sub)
        n_train = int(round(n * train_ratio))
        n_val   = int(round(n * val_ratio))
        sub_train = sub.iloc[:n_train]
        sub_val   = sub.iloc[n_train:n_train+n_val]
        sub_test  = sub.iloc[n_train+n_val:]
        parts.append((sub_train, sub_val, sub_test))

    train_df = pd.concat([parts[0][0], parts[1][0]]).sample(frac=1.0, random_state=seed).reset_index(drop=True)
    val_df   = pd.concat([parts[0][1], parts[1][1]]).sample(frac=1.0, random_state=seed).reset_index(drop=True)
    test_df  = pd.concat([parts[0][2], parts[1][2]]).sample(frac=1.0, random_state=seed).reset_index(drop=True)
    return train_df, val_df, test_df

train_pool_df, val_df, test_df = stratified_split(
    manifest,
    train_ratio=cfg.GLOBAL_TRAIN_RATIO,
    val_ratio=cfg.GLOBAL_VAL_RATIO,
    seed=SEED
)

print("\nGLOBAL split counts:")
print("  Train pool:", len(train_pool_df))
print("  Val      :", len(val_df))
print("  Test     :", len(test_df))

# Controlled pool applies to TRAIN only
if cfg.USE_CONTROLLED_POOL:
    fire_df   = train_pool_df[train_pool_df["has_fire"] == 1].copy()
    nofire_df = train_pool_df[train_pool_df["has_fire"] == 0].copy()

    n_fire_total = len(fire_df)
    n_nofire_total = len(nofire_df)

    n_fire_keep = n_fire_total if cfg.POOL_KEEP_FIRE in (-1, None) else min(cfg.POOL_KEEP_FIRE, n_fire_total)
    fire_keep = fire_df.sample(n=n_fire_keep, random_state=SEED) if n_fire_keep > 0 else fire_df.iloc[:0]

    n_nofire_keep = min(n_nofire_total, n_fire_keep * int(cfg.NONFIRE_PER_FIRE))
    nofire_keep = nofire_df.sample(n=n_nofire_keep, random_state=SEED) if n_nofire_keep > 0 else nofire_df.iloc[:0]

    train_df = pd.concat([fire_keep, nofire_keep]).sample(frac=1.0, random_state=SEED).reset_index(drop=True)

    print("\nTRAIN controlled sampling:")
    print(f"  fire_total_in_pool   : {n_fire_total}")
    print(f"  nofire_total_in_pool : {n_nofire_total}")
    print(f"  fire_kept            : {len(fire_keep)}")
    print(f"  nofire_kept          : {len(nofire_keep)} (NONFIRE_PER_FIRE={cfg.NONFIRE_PER_FIRE})")
    print(f"  train_final          : {len(train_df)}")
else:
    train_df = train_pool_df.copy()

train_paths = train_df["path"].tolist()
val_paths   = val_df["path"].tolist()
test_paths  = test_df["path"].tolist()

# leakage check
assert len(set(train_paths)&set(val_paths))==0
assert len(set(train_paths)&set(test_paths))==0
assert len(set(val_paths)&set(test_paths))==0

print("\nFinal used split sizes:")
print("  train:", len(train_paths))
print("  val  :", len(val_paths))
print("  test :", len(test_paths))



GLOBAL split counts:
  Train pool: 1973
  Val      : 247
  Test     : 246

TRAIN controlled sampling:
  fire_total_in_pool   : 279
  nofire_total_in_pool : 1694
  fire_kept            : 279
  nofire_kept          : 837 (NONFIRE_PER_FIRE=3)
  train_final          : 1116

Final used split sizes:
  train: 1116
  val  : 247
  test : 246


## Machine 2 — Normalization (X only)

We compute mean/std per band (13 values each) from TRAIN ONLY:
X_norm[c] = (X[c] - mean[c]) / std[c]

Mini example:
- raw pixel = 820
- mean = 600
- std = 100
- norm = 2.2


In [8]:
class Sen2FireDataset(Dataset):
    def __init__(self, paths: List[str], with_label: bool = True):
        self.paths = paths
        self.with_label = with_label

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        with np.load(p) as d:
            img12 = d[cfg.X_KEY].astype(np.float32)      # (12,512,512)
            aer   = d[cfg.A_KEY].astype(np.float32)[None, ...]  # (1,512,512)
            x = np.concatenate([img12, aer], axis=0)     # (13,512,512)

            if self.with_label:
                y = d[cfg.Y_KEY]
                if y.ndim == 2:
                    y = y[None, ...]
                y = (y > 0).astype(np.float32)
                return torch.from_numpy(x), torch.from_numpy(y)
            else:
                return torch.from_numpy(x)


In [9]:
train_loader = DataLoader(Sen2FireDataset(train_paths, with_label=True),
                          batch_size=cfg.BATCH_SIZE, shuffle=True,
                          num_workers=cfg.NUM_WORKERS, pin_memory=True)

val_loader = DataLoader(Sen2FireDataset(val_paths, with_label=True),
                        batch_size=cfg.BATCH_SIZE, shuffle=False,
                        num_workers=cfg.NUM_WORKERS, pin_memory=True)

test_loader = DataLoader(Sen2FireDataset(test_paths, with_label=True),
                         batch_size=cfg.BATCH_SIZE, shuffle=False,
                         num_workers=cfg.NUM_WORKERS, pin_memory=True)


In [10]:
@torch.no_grad()
def compute_mean_std(loader, max_batches=50):
    mean = torch.zeros(cfg.IN_CHANNELS, device=DEVICE)
    var  = torch.zeros(cfg.IN_CHANNELS, device=DEVICE)
    n_batches = 0

    for bi, (x, y) in enumerate(loader):
        if bi >= max_batches:
            break
        x = x.to(DEVICE)  # (B,C,H,W)
        x_ = x.view(x.size(0), cfg.IN_CHANNELS, -1)
        mean += x_.mean(dim=(0,2))
        var  += x_.var(dim=(0,2), unbiased=False)
        n_batches += 1

    mean /= max(n_batches, 1)
    var  /= max(n_batches, 1)
    std = torch.sqrt(var + 1e-6)
    return mean.detach().cpu().tolist(), std.detach().cpu().tolist()

MEAN_13, STD_13 = compute_mean_std(train_loader, max_batches=50)
print("MEAN_13 len:", len(MEAN_13), "STD_13 len:", len(STD_13))


MEAN_13 len: 13 STD_13 len: 13


In [11]:
!pip -q install segmentation-models-pytorch


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [12]:
criterion = nn.BCEWithLogitsLoss()

def normalize_batch(x, mean_13, std_13):
    mean_t = torch.tensor(mean_13, device=x.device).view(1, cfg.IN_CHANNELS, 1, 1)
    std_t  = torch.tensor(std_13,  device=x.device).view(1, cfg.IN_CHANNELS, 1, 1)
    return (x - mean_t) / (std_t + 1e-6)

# TES MODEL: TRANSFORMER

In [16]:
from transformers import SegformerForSemanticSegmentation

def build_segformer(in_channels=13, num_classes=1, backbone="nvidia/segformer-b1-finetuned-ade-512-512"): #MiT-B2
    model = SegformerForSemanticSegmentation.from_pretrained(
        backbone,
        num_labels=num_classes,
        ignore_mismatched_sizes=True
    )

    # ganti input channel (default 3 → 13)
    model.segformer.encoder.patch_embeddings[0].proj = \
        torch.nn.Conv2d(
            in_channels,
            model.segformer.encoder.patch_embeddings[0].proj.out_channels,
            kernel_size=model.segformer.encoder.patch_embeddings[0].proj.kernel_size,
            stride=model.segformer.encoder.patch_embeddings[0].proj.stride,
            padding=model.segformer.encoder.patch_embeddings[0].proj.padding,
            bias=False
        )

    return model
    
model = build_segformer(cfg.IN_CHANNELS).to(DEVICE)
print(type(model).__name__)

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/54.9M [00:00<?, ?B/s]

Some weights of SegformerForSemanticSegmentation were not initialized from the model checkpoint at nvidia/segformer-b1-finetuned-ade-512-512 and are newly initialized because the shapes did not match:
- decode_head.classifier.weight: found shape torch.Size([150, 256, 1, 1]) in the checkpoint and torch.Size([1, 256, 1, 1]) in the model instantiated
- decode_head.classifier.bias: found shape torch.Size([150]) in the checkpoint and torch.Size([1]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


SegformerForSemanticSegmentation


model.safetensors:   0%|          | 0.00/54.9M [00:00<?, ?B/s]

In [26]:
from torch.amp import autocast, GradScaler
from sklearn.metrics import average_precision_score
import torch.nn.functional as F
import time

optimizer = torch.optim.Adam(model.parameters(), lr=cfg.LR)
scaler = GradScaler(enabled=(DEVICE.type == "cuda"))

def train_one_epoch(model, loader, thr=0.5, eps=1e-6):
    model.train()
    total_loss_sum = 0.0
    total_samples = 0

    total_tp = total_fp = total_fn = total_tn = 0.0
    all_probs, all_targets = [], []

    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        x = normalize_batch(x, MEAN_13, STD_13)

        optimizer.zero_grad(set_to_none=True)

        with autocast("cuda", enabled=(DEVICE.type == "cuda")):
            out = model(x)
            logits = out.logits
            logits = F.interpolate(
                logits,
                size=y.shape[-2:],
                mode="bilinear",
                align_corners=False
            )
            loss = criterion(logits, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        bs = x.size(0)
        total_loss_sum += loss.item() * bs
        total_samples += bs

        probs = torch.sigmoid(logits)
        preds = (probs >= thr).float()

        total_tp += (preds * y).sum().item()
        total_fp += (preds * (1 - y)).sum().item()
        total_fn += ((1 - preds) * y).sum().item()
        total_tn += ((1 - preds) * (1 - y)).sum().item()

        all_probs.append(probs.flatten().detach().cpu())
        all_targets.append(y.flatten().detach().cpu())

    precision = (total_tp + eps) / (total_tp + total_fp + eps)
    recall    = (total_tp + eps) / (total_tp + total_fn + eps)
    f1        = (2 * precision * recall) / (precision + recall + eps)
    acc       = (total_tp + total_tn + eps) / (
        total_tp + total_tn + total_fp + total_fn + eps
    )

    all_probs = torch.cat(all_probs).numpy()
    all_targets = torch.cat(all_targets).numpy()
    auc = average_precision_score(all_targets, all_probs)

    loss_global = total_loss_sum / total_samples

    return loss_global, precision, recall, f1, acc, auc


@torch.no_grad()
def validate_epoch(model, loader, thr=0.5, eps=1e-6):
    model.eval()
    total_loss_sum = 0.0
    total_samples = 0

    total_tp = total_fp = total_fn = total_tn = 0.0
    all_probs, all_targets = [], []

    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        x = normalize_batch(x, MEAN_13, STD_13)

        out = model(x)
        logits = out.logits
        logits = F.interpolate(
                logits,
                size=y.shape[-2:],
                mode="bilinear",
                align_corners=False
            )
        loss = criterion(logits, y)

        bs = x.size(0)
        total_loss_sum += loss.item() * bs
        total_samples += bs

        probs = torch.sigmoid(logits)
        preds = (probs >= thr).float()

        total_tp += (preds * y).sum().item()
        total_fp += (preds * (1 - y)).sum().item()
        total_fn += ((1 - preds) * y).sum().item()
        total_tn += ((1 - preds) * (1 - y)).sum().item()

        all_probs.append(probs.flatten().cpu())
        all_targets.append(y.flatten().cpu())

    precision = (total_tp + eps) / (total_tp + total_fp + eps)
    recall    = (total_tp + eps) / (total_tp + total_fn + eps)
    f1        = (2 * precision * recall) / (precision + recall + eps)
    acc       = (total_tp + total_tn + eps) / (
        total_tp + total_tn + total_fp + total_fn + eps
    )

    all_probs = torch.cat(all_probs).numpy()
    all_targets = torch.cat(all_targets).numpy()
    auc = average_precision_score(all_targets, all_probs)

    loss_global = total_loss_sum / total_samples

    return loss_global, precision, recall, f1, acc, auc

In [27]:
BEST_PATH = "/kaggle/working/seg-b1.pth"
best_val_f1 = -1

import pandas as pd
history = []

for epoch in range(1, cfg.EPOCHS + 1):
    t0 = time.time()

    tr_loss, tr_prec, tr_rec, tr_f1, tr_acc, tr_auc = train_one_epoch(model, train_loader, thr=0.5)
    va_loss, va_prec, va_rec, va_f1, va_acc, va_auc = validate_epoch(model, val_loader, thr=0.5)
    elapsed = time.time() - t0

    print(
        f"[unet_basic] Epoch {epoch:03d} | "
        f"tr_loss={tr_loss:.4f} tr_f1={tr_f1:.4f} tr_acc={tr_acc:.4f} tr_PRauc={tr_auc:.4f} || "
        f"va_loss={va_loss:.4f} va_f1={va_f1:.4f} va_acc={va_acc:.4f} va_PRauc={va_auc:.4f} | "
        f"{elapsed:.1f}s"
    )

    history.append({
        "epoch": epoch,
    
        "tr_loss": tr_loss,
        "tr_precision": tr_prec,
        "tr_recall": tr_rec,
        "tr_f1": tr_f1,
        "tr_acc": tr_acc,
        "tr_PRauc": tr_auc,
    
        "va_loss": va_loss,
        "va_precision": va_prec,
        "va_recall": va_rec,
        "va_f1": va_f1,
        "va_acc": va_acc,
        "va_PRauc": va_auc,
    
        "seconds": elapsed
    })

    if va_f1 > best_val_f1:
        torch.save({
            "model_state": model.state_dict(),
            "mean_13": MEAN_13,
            "std_13": STD_13,
            "epoch": epoch,
            "best_val_f1": best_val_f1,
            "va_auc": va_auc,
        }, BEST_PATH)


[unet_basic] Epoch 001 | tr_loss=0.2926 tr_f1=0.0010 tr_acc=0.9382 tr_PRauc=0.0551 || va_loss=0.2822 va_f1=0.0000 va_acc=0.9522 va_PRauc=0.1213 | 187.0s
[unet_basic] Epoch 002 | tr_loss=0.2200 tr_f1=0.0000 tr_acc=0.9411 tr_PRauc=0.0902 || va_loss=0.1836 va_f1=0.0000 va_acc=0.9522 va_PRauc=0.1023 | 179.6s
[unet_basic] Epoch 003 | tr_loss=0.2242 tr_f1=0.0000 tr_acc=0.9411 tr_PRauc=0.0789 || va_loss=0.2130 va_f1=0.0000 va_acc=0.9522 va_PRauc=0.1231 | 180.2s
[unet_basic] Epoch 004 | tr_loss=0.2260 tr_f1=0.0000 tr_acc=0.9411 tr_PRauc=0.0791 || va_loss=0.2155 va_f1=0.0000 va_acc=0.9522 va_PRauc=0.0573 | 181.5s
[unet_basic] Epoch 005 | tr_loss=0.2191 tr_f1=0.0000 tr_acc=0.9411 tr_PRauc=0.0905 || va_loss=0.1837 va_f1=0.0000 va_acc=0.9522 va_PRauc=0.0800 | 179.1s
[unet_basic] Epoch 006 | tr_loss=0.2090 tr_f1=0.0000 tr_acc=0.9411 tr_PRauc=0.1330 || va_loss=0.1727 va_f1=0.0000 va_acc=0.9522 va_PRauc=0.0906 | 177.0s


KeyboardInterrupt: 

- Bayangin: gambar mostly background, api cuma 1-2 pixel. Model prediksi 0 api → dice pixel-wise masih lumayan karena background cocok (misal 90% pixel benar). Tapi F1 object-wise → 0 karena dia miss all actual fire objects.

- Liat epoch 14–19: F1 mulai naik (0.38, 0.35, 0.26). Itu karena model mulai ngerasa beberapa fire patch. Dice kadang turun dikit (0.64) karena beberapa prediksi terlalu kecil/lebih besar → overlap ga maksimal, tapi F1 object-wise meningkat karena ada setidaknya beberapa detection bener.

- Misal ground truth fire cuma 10 pixel di gambar.
Model prediksi fire 10 pixel, tepat di lokasi yang sama → overlap 100% → dice tinggi, F1 juga tinggi.
Model prediksi fire 10 pixel tapi di tempat salah → overlap 0% → dice rendah, F1 rendah.
Model prediksi 20 pixel, 10 di tempat benar, 10 di tempat salah → overlap 50% → dice turun, F1 juga turun karena false positives.

In [None]:
df_history = pd.DataFrame(history)

df_history = df_history.round(4)
csv_path = "/kaggle/working/training_metrics_trans.csv"
df_history.to_csv(csv_path, index=False)
print("Saved to:", csv_path)
df_history

In [None]:

test_loss, test_prec, test_rec, test_f1, test_acc, test_auc = validate_epoch(
    model, test_loader
)

print(f"Test Loss   : {test_loss:.4f}")
print(f"Test Prec   : {test_prec:.4f}")
print(f"Test Recall : {test_rec:.4f}")
print(f"Test F1     : {test_f1:.4f}")
print(f"Test Acc    : {test_acc:.4f}")

# pred

In [None]:
class Sen2FireDataset(Dataset):
    def __init__(self, paths, with_label=True):
        self.paths = paths
        self.with_label = with_label

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        patch_name = os.path.basename(p)

        with np.load(p) as d:
            img12 = d[cfg.X_KEY].astype(np.float32)
            aer   = d[cfg.A_KEY].astype(np.float32)[None, ...]
            x = np.concatenate([img12, aer], axis=0)

            if self.with_label:
                y = d[cfg.Y_KEY]
                if y.ndim == 2:
                    y = y[None, ...]
                y = (y > 0).astype(np.float32)

                return torch.from_numpy(x), torch.from_numpy(y), patch_name
            else:
                return torch.from_numpy(x), patch_name


In [None]:
test_loader = DataLoader(Sen2FireDataset(test_paths, with_label=True),
                         batch_size=cfg.BATCH_SIZE, shuffle=False,
                         num_workers=cfg.NUM_WORKERS, pin_memory=True)


In [None]:
import matplotlib.pyplot as plt

def show_selected_patches(loader, selected_paths):
    # kumpulin semua sample dulu
    all_samples = {}
    for x, y, patch_name in loader:
        for i, p in enumerate(patch_name):
            all_samples[p] = {
                "img": x[i].cpu(),
                "mask_gt": y[i,0].cpu(),
            }

    fig, axes = plt.subplots(len(selected_paths), 3, figsize=(15, 5*len(selected_paths)))

    if len(selected_paths) == 1:
        axes = axes[None, :]  # konsisten

    for i, path in enumerate(selected_paths):
        sample = all_samples[path]

        img = sample["img"].permute(1,2,0).numpy()[..., :3]
        img_vis = (img - img.min()) / (img.max() - img.min())

        mask_gt = sample["mask_gt"].numpy()
        # pred bisa pake model langsung
        x_tensor = sample["img"].unsqueeze(0).to(DEVICE)
        x_tensor = normalize_batch(x_tensor, MEAN_13, STD_13)
        with torch.no_grad():
            mask_pred = (torch.sigmoid(model(x_tensor)) >= t_best)[0,0].cpu().numpy()

        # Original
        axes[i,0].imshow(img_vis)
        axes[i,0].set_title(f"{path}\nOriginal")
        axes[i,0].axis('off')

        # GT
        axes[i,1].imshow(mask_gt, cmap='Reds')
        axes[i,1].set_title("Ground Truth")
        axes[i,1].axis('off')

        # Pred overlay
        axes[i,2].imshow(img_vis)
        axes[i,2].imshow(mask_pred, cmap='Reds', alpha=0.4)
        axes[i,2].set_title("Prediction")
        axes[i,2].axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
selected_paths = [
    "scene_3_patch_4_11.npz",
    "scene_3_patch_5_35.npz",
    "scene_4_patch_3_19.npz",
    "scene_3_patch_5_12.npz",
    "scene_4_patch_3_21.npz",
    "scene_4_patch_18_13.npz",
    "scene_4_patch_13_12.npz",
]

show_selected_patches(test_loader, selected_paths)


In [None]:
selected_paths = [
    "scene_4_patch_4_14.npz",
    "scene_1_patch_20_9.npz",
    "scene_3_patch_3_28.npz",
    "scene_4_patch_2_14.npz",
    "scene_4_patch_10_12.npz",
]

show_selected_patches(test_loader, selected_paths)


In [None]:
paths = [
    "scene_4_patch_31_17.npz",
    "scene_1_patch_22_8.npz",
    "scene_1_patch_21_10.npz",
    "scene_3_patch_7_4.npz",
    "scene_4_patch_13_23.npz"
]

show_selected_patches(test_loader, paths)
