In [3]:
import os, glob, random, json
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# ----------------------------
# Reproducibility
# ----------------------------
def seed_everything(seed: int = 42) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

SEED = 24
seed_everything(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# ----------------------------
# Config
# ----------------------------
@dataclass
class CFG:
    DATA_ROOT: str = "/kaggle/input/sen2fire/Sen2Fire/Sen2Fire"
    SCENES: Tuple[str, ...] = ("scene1", "scene2", "scene3", "scene4")
    PATCH_EXT: str = ".npz"

    # Global split (VD default)
    USE_GLOBAL_SPLIT: bool = True
    GLOBAL_TRAIN_RATIO: float = 0.80
    GLOBAL_VAL_RATIO: float = 0.10
    GLOBAL_TEST_RATIO: float = 0.10
    KEEP_VAL_TEST_NATURAL: bool = True

    # Keys inside npz
    X_KEY: str = "image"     # (12,512,512)
    A_KEY: str = "aerosol"   # (512,512) -> becomes 1 channel
    Y_KEY: str = "label"     # (512,512)

    # Fire patch definition
    FIRE_PATCH_MIN_RATIO: float = 0

    # Controlled train pool (VD default)
    USE_CONTROLLED_POOL: bool = True
    POOL_KEEP_FIRE: int = -1
    NONFIRE_PER_FIRE: int = 3

    # Training config (RAM-safe)
    IN_CHANNELS: int = 13
    H: int = 512
    W: int = 512
    BATCH_SIZE: int = 2
    NUM_WORKERS: int = 2
    LR: float = 1e-4
    EPOCHS: int = 20

    VERBOSE: bool = True

cfg = CFG()

assert abs((cfg.GLOBAL_TRAIN_RATIO + cfg.GLOBAL_VAL_RATIO + cfg.GLOBAL_TEST_RATIO) - 1.0) < 1e-6
assert os.path.exists(cfg.DATA_ROOT), f"DATA_ROOT not found: {cfg.DATA_ROOT}"
print("DATA_ROOT OK:", cfg.DATA_ROOT)


Device: cuda
DATA_ROOT OK: /kaggle/input/sen2fire/Sen2Fire/Sen2Fire


# Fire Risk Mapping (Sen2Fire) — Notebook 2 (DeepLabV3+)

Same pipeline as Notebook 1 (UNet++):
- Global stratified split + controlled train pool
- Normalization (13-channel)
- Segmentation training
- t_high selection by validation mean Dice sweep
- Export Streamlit-ready model package

Export folder:
- /kaggle/working/model_packages/deeplabv3p_fire/


## Machine 1 — Patch Intake (NPZ → X, Y)

Inside each patch (.npz):
- `image`: (12,512,512) Sentinel-2 bands
- `aerosol`: (512,512) additional band
- `label`: (512,512) fire mask (0/1)

We construct:
- X: concat(image[12], aerosol[1]) → (13,512,512)
- Y: label → (1,512,512)

Note:
- Y is only used during training/validation to compute loss/metrics.
- During Streamlit inference, we only have X.


In [4]:
def list_scene_files(data_root: str, scene_name: str, ext: str = ".npz") -> List[str]:
    scene_dir = os.path.join(data_root, scene_name)
    return sorted(glob.glob(os.path.join(scene_dir, f"*{ext}")))

scene_to_files = {}
total = 0
for s in cfg.SCENES:
    files = list_scene_files(cfg.DATA_ROOT, s, cfg.PATCH_EXT)
    scene_to_files[s] = files
    total += len(files)
    print(f"{s}: {len(files)} patches")

print("Total patches:", total)
if total == 0:
    raise RuntimeError("No patch files found. Check DATA_ROOT / PATCH_EXT.")


scene1: 864 patches
scene2: 594 patches
scene3: 504 patches
scene4: 504 patches
Total patches: 2466


In [5]:
def inspect_npz(npz_path: str):
    with np.load(npz_path) as data:
        return {k: (data[k].shape, str(data[k].dtype)) for k in data.keys()}

sample_path = scene_to_files[cfg.SCENES[0]][0]
print("Sample file:", sample_path)
info = inspect_npz(sample_path)
for k, (shape, dtype) in info.items():
    print(f"  - {k:>10s}: shape={shape}, dtype={dtype}")

for rk in [cfg.X_KEY, cfg.A_KEY, cfg.Y_KEY]:
    if rk not in info:
        raise KeyError(f"Missing key '{rk}' in npz. Found keys: {list(info.keys())}")


Sample file: /kaggle/input/sen2fire/Sen2Fire/Sen2Fire/scene1/scene_1_patch_10_1.npz
  -      image: shape=(12, 512, 512), dtype=int16
  -    aerosol: shape=(512, 512), dtype=float32
  -      label: shape=(512, 512), dtype=uint8


In [6]:
rows = []
for s, files in scene_to_files.items():
    for p in files:
        rows.append({"scene": s, "path": p})
manifest = pd.DataFrame(rows)

def fire_ratio_from_path(npz_path: str) -> float:
    with np.load(npz_path) as data:
        y = data[cfg.Y_KEY]
        if y.ndim == 3 and y.shape[-1] == 1:
            y = y[..., 0]
        yb = (y > 0).astype(np.uint8)
        return float(yb.mean())

fire_ratios = []
has_fire = []
for p in manifest["path"].tolist():
    r = fire_ratio_from_path(p)
    fire_ratios.append(r)
    has_fire.append(1 if r > cfg.FIRE_PATCH_MIN_RATIO else 0)

manifest["fire_ratio"] = fire_ratios
manifest["has_fire"] = has_fire

print("\nFULL dataset has_fire distribution:")
print(manifest["has_fire"].value_counts().sort_index())
print("\nFULL dataset has_fire ratio:")
print(manifest["has_fire"].value_counts(normalize=True).sort_index())



FULL dataset has_fire distribution:
has_fire
0    2117
1     349
Name: count, dtype: int64

FULL dataset has_fire ratio:
has_fire
0    0.858475
1    0.141525
Name: proportion, dtype: float64


In [7]:
def stratified_split(df, train_ratio, val_ratio, seed=42):
    df = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    parts = []
    for cls in [0, 1]:
        sub = df[df["has_fire"] == cls].copy()
        n = len(sub)
        n_train = int(round(n * train_ratio))
        n_val   = int(round(n * val_ratio))
        sub_train = sub.iloc[:n_train]
        sub_val   = sub.iloc[n_train:n_train+n_val]
        sub_test  = sub.iloc[n_train+n_val:]
        parts.append((sub_train, sub_val, sub_test))

    train_df = pd.concat([parts[0][0], parts[1][0]]).sample(frac=1.0, random_state=seed).reset_index(drop=True)
    val_df   = pd.concat([parts[0][1], parts[1][1]]).sample(frac=1.0, random_state=seed).reset_index(drop=True)
    test_df  = pd.concat([parts[0][2], parts[1][2]]).sample(frac=1.0, random_state=seed).reset_index(drop=True)
    return train_df, val_df, test_df

train_pool_df, val_df, test_df = stratified_split(
    manifest,
    train_ratio=cfg.GLOBAL_TRAIN_RATIO,
    val_ratio=cfg.GLOBAL_VAL_RATIO,
    seed=SEED
)

print("\nGLOBAL split counts:")
print("  Train pool:", len(train_pool_df))
print("  Val      :", len(val_df))
print("  Test     :", len(test_df))

# Controlled pool applies to TRAIN only
if cfg.USE_CONTROLLED_POOL:
    fire_df   = train_pool_df[train_pool_df["has_fire"] == 1].copy()
    nofire_df = train_pool_df[train_pool_df["has_fire"] == 0].copy()

    n_fire_total = len(fire_df)
    n_nofire_total = len(nofire_df)

    n_fire_keep = n_fire_total if cfg.POOL_KEEP_FIRE in (-1, None) else min(cfg.POOL_KEEP_FIRE, n_fire_total)
    fire_keep = fire_df.sample(n=n_fire_keep, random_state=SEED) if n_fire_keep > 0 else fire_df.iloc[:0]

    n_nofire_keep = min(n_nofire_total, n_fire_keep * int(cfg.NONFIRE_PER_FIRE))
    nofire_keep = nofire_df.sample(n=n_nofire_keep, random_state=SEED) if n_nofire_keep > 0 else nofire_df.iloc[:0]

    train_df = pd.concat([fire_keep, nofire_keep]).sample(frac=1.0, random_state=SEED).reset_index(drop=True)

    print("\nTRAIN controlled sampling:")
    print(f"  fire_total_in_pool   : {n_fire_total}")
    print(f"  nofire_total_in_pool : {n_nofire_total}")
    print(f"  fire_kept            : {len(fire_keep)}")
    print(f"  nofire_kept          : {len(nofire_keep)} (NONFIRE_PER_FIRE={cfg.NONFIRE_PER_FIRE})")
    print(f"  train_final          : {len(train_df)}")
else:
    train_df = train_pool_df.copy()

train_paths = train_df["path"].tolist()
val_paths   = val_df["path"].tolist()
test_paths  = test_df["path"].tolist()

# leakage check
assert len(set(train_paths)&set(val_paths))==0
assert len(set(train_paths)&set(test_paths))==0
assert len(set(val_paths)&set(test_paths))==0

print("\nFinal used split sizes:")
print("  train:", len(train_paths))
print("  val  :", len(val_paths))
print("  test :", len(test_paths))



GLOBAL split counts:
  Train pool: 1973
  Val      : 247
  Test     : 246

TRAIN controlled sampling:
  fire_total_in_pool   : 279
  nofire_total_in_pool : 1694
  fire_kept            : 279
  nofire_kept          : 837 (NONFIRE_PER_FIRE=3)
  train_final          : 1116

Final used split sizes:
  train: 1116
  val  : 247
  test : 246


## Machine 2 — Normalization (X only)

We compute mean/std per band (13 values each) from TRAIN ONLY:
X_norm[c] = (X[c] - mean[c]) / std[c]

Mini example:
- raw pixel = 820
- mean = 600
- std = 100
- norm = 2.2


In [8]:
class Sen2FireDataset(Dataset):
    def __init__(self, paths: List[str], with_label: bool = True):
        self.paths = paths
        self.with_label = with_label

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        with np.load(p) as d:
            img12 = d[cfg.X_KEY].astype(np.float32)      # (12,512,512)
            aer   = d[cfg.A_KEY].astype(np.float32)[None, ...]  # (1,512,512)
            x = np.concatenate([img12, aer], axis=0)     # (13,512,512)

            if self.with_label:
                y = d[cfg.Y_KEY]
                if y.ndim == 2:
                    y = y[None, ...]
                y = (y > 0).astype(np.float32)
                return torch.from_numpy(x), torch.from_numpy(y)
            else:
                return torch.from_numpy(x)


In [9]:
train_loader = DataLoader(Sen2FireDataset(train_paths, with_label=True),
                          batch_size=cfg.BATCH_SIZE, shuffle=True,
                          num_workers=cfg.NUM_WORKERS, pin_memory=True)

val_loader = DataLoader(Sen2FireDataset(val_paths, with_label=True),
                        batch_size=cfg.BATCH_SIZE, shuffle=False,
                        num_workers=cfg.NUM_WORKERS, pin_memory=True)

test_loader = DataLoader(Sen2FireDataset(test_paths, with_label=True),
                         batch_size=cfg.BATCH_SIZE, shuffle=False,
                         num_workers=cfg.NUM_WORKERS, pin_memory=True)


In [10]:
@torch.no_grad()
def compute_mean_std(loader, max_batches=50):
    mean = torch.zeros(cfg.IN_CHANNELS, device=DEVICE)
    var  = torch.zeros(cfg.IN_CHANNELS, device=DEVICE)
    n_batches = 0

    for bi, (x, y) in enumerate(loader):
        if bi >= max_batches:
            break
        x = x.to(DEVICE)  # (B,C,H,W)
        x_ = x.view(x.size(0), cfg.IN_CHANNELS, -1)
        mean += x_.mean(dim=(0,2))
        var  += x_.var(dim=(0,2), unbiased=False)
        n_batches += 1

    mean /= max(n_batches, 1)
    var  /= max(n_batches, 1)
    std = torch.sqrt(var + 1e-6)
    return mean.detach().cpu().tolist(), std.detach().cpu().tolist()

MEAN_13, STD_13 = compute_mean_std(train_loader, max_batches=50)
print("MEAN_13 len:", len(MEAN_13), "STD_13 len:", len(STD_13))


MEAN_13 len: 13 STD_13 len: 13


In [11]:
!pip -q install segmentation-models-pytorch


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [12]:
criterion = nn.BCEWithLogitsLoss()

def normalize_batch(x, mean_13, std_13):
    mean_t = torch.tensor(mean_13, device=x.device).view(1, cfg.IN_CHANNELS, 1, 1)
    std_t  = torch.tensor(std_13,  device=x.device).view(1, cfg.IN_CHANNELS, 1, 1)
    return (x - mean_t) / (std_t + 1e-6)

# DEEPLAB

In [13]:
import segmentation_models_pytorch as smp
import torch.nn as nn

criterion = nn.BCEWithLogitsLoss()

def build_deeplab(in_channels=13):
    model = smp.DeepLabV3Plus(
        encoder_name="efficientnet-b4", 
        encoder_weights="imagenet",
        in_channels=in_channels,
        classes=1,
        activation=None
    )
    return model

model = build_deeplab(cfg.IN_CHANNELS).to(DEVICE)
print("Model:", type(model).__name__)



config.json:   0%|          | 0.00/106 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/77.9M [00:00<?, ?B/s]

Model: DeepLabV3Plus


In [13]:
from torch.amp import autocast, GradScaler
from sklearn.metrics import average_precision_score
import time

optimizer = torch.optim.Adam(model.parameters(), lr=cfg.LR)
scaler = GradScaler(enabled=(DEVICE.type == "cuda"))

def train_one_epoch(model, loader, thr=0.5, eps=1e-6):
    model.train()
    total_loss_sum = 0.0
    total_samples = 0

    total_tp = total_fp = total_fn = total_tn = 0.0
    all_probs, all_targets = [], []

    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        x = normalize_batch(x, MEAN_13, STD_13)

        optimizer.zero_grad(set_to_none=True)

        with autocast("cuda", enabled=(DEVICE.type == "cuda")):
            logits = model(x)
            loss = criterion(logits, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        bs = x.size(0)
        total_loss_sum += loss.item() * bs
        total_samples += bs

        probs = torch.sigmoid(logits)
        preds = (probs >= thr).float()

        total_tp += (preds * y).sum().item()
        total_fp += (preds * (1 - y)).sum().item()
        total_fn += ((1 - preds) * y).sum().item()
        total_tn += ((1 - preds) * (1 - y)).sum().item()

        all_probs.append(probs.flatten().detach().cpu())
        all_targets.append(y.flatten().detach().cpu())

    precision = (total_tp + eps) / (total_tp + total_fp + eps)
    recall    = (total_tp + eps) / (total_tp + total_fn + eps)
    f1        = (2 * precision * recall) / (precision + recall + eps)
    acc       = (total_tp + total_tn + eps) / (
        total_tp + total_tn + total_fp + total_fn + eps
    )

    all_probs = torch.cat(all_probs).numpy()
    all_targets = torch.cat(all_targets).numpy()
    auc = average_precision_score(all_targets, all_probs)

    loss_global = total_loss_sum / total_samples

    return loss_global, precision, recall, f1, acc, auc


@torch.no_grad()
def validate_epoch(model, loader, thr=0.5, eps=1e-6):
    model.eval()
    total_loss_sum = 0.0
    total_samples = 0

    total_tp = total_fp = total_fn = total_tn = 0.0
    all_probs, all_targets = [], []

    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        x = normalize_batch(x, MEAN_13, STD_13)

        logits = model(x)
        loss = criterion(logits, y)

        bs = x.size(0)
        total_loss_sum += loss.item() * bs
        total_samples += bs

        probs = torch.sigmoid(logits)
        preds = (probs >= thr).float()

        total_tp += (preds * y).sum().item()
        total_fp += (preds * (1 - y)).sum().item()
        total_fn += ((1 - preds) * y).sum().item()
        total_tn += ((1 - preds) * (1 - y)).sum().item()

        all_probs.append(probs.flatten().cpu())
        all_targets.append(y.flatten().cpu())

    precision = (total_tp + eps) / (total_tp + total_fp + eps)
    recall    = (total_tp + eps) / (total_tp + total_fn + eps)
    f1        = (2 * precision * recall) / (precision + recall + eps)
    acc       = (total_tp + total_tn + eps) / (
        total_tp + total_tn + total_fp + total_fn + eps
    )

    all_probs = torch.cat(all_probs).numpy()
    all_targets = torch.cat(all_targets).numpy()
    auc = average_precision_score(all_targets, all_probs)

    loss_global = total_loss_sum / total_samples

    return loss_global, precision, recall, f1, acc, auc

In [None]:
BEST_PATH = "/kaggle/working/deeplab-b4.pth"
best_val_f1 = -1

import pandas as pd
history = []

for epoch in range(1, cfg.EPOCHS + 1):
    t0 = time.time()

    tr_loss, tr_prec, tr_rec, tr_f1, tr_acc, tr_auc = train_one_epoch(model, train_loader)
    va_loss, va_prec, va_rec, va_f1, va_acc, va_auc = validate_epoch(model, val_loader, thr=0.5)
    elapsed = time.time() - t0

    print(
        f"[unet_basic] Epoch {epoch:03d} | "
        f"tr_loss={tr_loss:.4f} tr_f1={tr_f1:.4f} tr_acc={tr_acc:.4f} tr_PRauc={tr_auc:.4f} || "
        f"va_loss={va_loss:.4f} va_f1={va_f1:.4f} va_acc={va_acc:.4f} va_PRauc={va_auc:.4f} | "
        f"{elapsed:.1f}s"
    )

    history.append({
        "epoch": epoch,
    
        "tr_loss": tr_loss,
        "tr_precision": tr_prec,
        "tr_recall": tr_rec,
        "tr_f1": tr_f1,
        "tr_acc": tr_acc,
        "tr_PRauc": tr_auc,
    
        "va_loss": va_loss,
        "va_precision": va_prec,
        "va_recall": va_rec,
        "va_f1": va_f1,
        "va_acc": va_acc,
        "va_PRauc": va_auc,
    
        "seconds": elapsed
    })

    if va_f1 > best_val_f1:
        torch.save({
            "model_state": model.state_dict(),
            "mean_13": MEAN_13,
            "std_13": STD_13,
            "epoch": epoch,
            "best_val_f1": best_val_f1,
            "va_auc": va_auc,
        }, BEST_PATH)

[unet_basic] Epoch 001 | tr_loss=0.2787 tr_f1=0.0174 tr_acc=0.9346 tr_PRauc=0.0732 || va_loss=0.1667 va_f1=0.0040 va_acc=0.9520 va_PRauc=0.2708 | 400.6s
[unet_basic] Epoch 002 | tr_loss=0.1921 tr_f1=0.0322 tr_acc=0.9413 tr_PRauc=0.2466 || va_loss=0.1187 va_f1=0.3036 va_acc=0.9570 va_PRauc=0.4522 | 398.2s
[unet_basic] Epoch 003 | tr_loss=0.1787 tr_f1=0.1253 tr_acc=0.9422 tr_PRauc=0.3084 || va_loss=0.1315 va_f1=0.2604 va_acc=0.9537 va_PRauc=0.3854 | 397.5s
[unet_basic] Epoch 004 | tr_loss=0.1535 tr_f1=0.3014 tr_acc=0.9461 tr_PRauc=0.4321 || va_loss=0.1246 va_f1=0.3589 va_acc=0.9567 va_PRauc=0.4470 | 397.5s
[unet_basic] Epoch 005 | tr_loss=0.1289 tr_f1=0.4866 tr_acc=0.9543 tr_PRauc=0.5793 || va_loss=0.1167 va_f1=0.3626 va_acc=0.9570 va_PRauc=0.4905 | 395.9s
[unet_basic] Epoch 006 | tr_loss=0.1054 tr_f1=0.6396 tr_acc=0.9621 tr_PRauc=0.6848 || va_loss=0.1271 va_f1=0.4148 va_acc=0.9545 va_PRauc=0.4399 | 395.6s
[unet_basic] Epoch 007 | tr_loss=0.0982 tr_f1=0.6698 tr_acc=0.9665 tr_PRauc=0.7423

In [17]:
df_history = pd.DataFrame(history)

df_history = df_history.round(4)
csv_path = "/kaggle/working/training_metrics_eff1.csv"
df_history.to_csv(csv_path, index=False)
print("Saved to:", csv_path)
df_history

Saved to: /kaggle/working/training_metrics_eff1.csv


Unnamed: 0,epoch,tr_loss,tr_precision,tr_recall,tr_f1,tr_acc,tr_PRauc,va_loss,va_precision,va_recall,va_f1,va_acc,va_PRauc,seconds
0,1,0.2787,0.0753,0.0098,0.0174,0.9346,0.0732,0.1667,0.2716,0.002,0.004,0.952,0.2708,400.5822
1,2,0.1921,0.5447,0.0166,0.0322,0.9413,0.2466,0.1187,0.6742,0.1959,0.3036,0.957,0.4522,398.1763
2,3,0.1787,0.5722,0.0704,0.1253,0.9422,0.3084,0.1315,0.5531,0.1703,0.2604,0.9537,0.3854,397.495
3,4,0.1535,0.6351,0.1975,0.3014,0.9461,0.4321,0.1246,0.6158,0.2533,0.3589,0.9567,0.447,397.5215
4,5,0.1289,0.7196,0.3676,0.4866,0.9543,0.5793,0.1167,0.6239,0.2556,0.3626,0.957,0.4905,395.8892
5,6,0.1054,0.7263,0.5714,0.6396,0.9621,0.6848,0.1271,0.5392,0.3371,0.4148,0.9545,0.4399,395.5895
6,7,0.0982,0.7982,0.577,0.6698,0.9665,0.7423,0.1028,0.5714,0.6553,0.6105,0.96,0.5962,395.6905
7,8,0.0741,0.8252,0.733,0.7764,0.9751,0.8325,0.124,0.7076,0.1581,0.2585,0.9566,0.5015,395.3123
8,9,0.06,0.8587,0.7899,0.8228,0.98,0.8809,0.1388,0.5553,0.122,0.2001,0.9533,0.3915,395.4704
9,10,0.0485,0.8676,0.8366,0.8518,0.9829,0.9145,0.1009,0.8302,0.3444,0.4869,0.9653,0.6664,393.8328


In [15]:
@torch.no_grad()
def find_best_thr_global_f1(model, loader, grid=None, eps=1e-6):
    model.eval()
    if grid is None:
        grid = np.linspace(0.05, 0.95, 19)

    best_thr = None
    best_f1 = -1
    log = []

    for thr in grid:
        total_tp = total_fp = total_fn = 0.0

        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            x = normalize_batch(x, MEAN_13, STD_13)

            logits = model(x)
            probs = torch.sigmoid(logits)
            preds = (probs >= thr).float()

            total_tp += (preds * y).sum().item()
            total_fp += (preds * (1 - y)).sum().item()
            total_fn += ((1 - preds) * y).sum().item()

        precision = (total_tp + eps) / (total_tp + total_fp + eps)
        recall    = (total_tp + eps) / (total_tp + total_fn + eps)
        f1        = (2 * precision * recall) / (precision + recall + eps)

        log.append((float(thr), f1))

        if f1 > best_f1:
            best_f1 = f1
            best_thr = float(thr)

    return best_thr, best_f1, log

t_best, best_f1, thr_log = find_best_thr_global_f1(model, val_loader)

print("Chosen threshold:", t_best)
print("Best global F1:", best_f1)

print("Top-5 thresholds:")
for thr, f1 in sorted(thr_log, key=lambda x: x[1], reverse=True)[:5]:
    print(thr, f1)

Chosen threshold: 0.05
Best global F1: 0.521013327039667
Top-5 thresholds:
0.05 0.521013327039667
0.1 0.4976343547287141
0.15 0.480017006184065
0.2 0.4646291311364695
0.25 0.4503687547885739


In [16]:
test_loss, test_prec, test_rec, test_f1, test_acc, test_auc = validate_epoch(
    model, test_loader, thr=t_best
)

print(f"Test Loss   : {test_loss:.4f}")
print(f"Test Prec   : {test_prec:.4f}")
print(f"Test Recall : {test_rec:.4f}")
print(f"Test F1     : {test_f1:.4f}")
print(f"Test Acc    : {test_acc:.4f}")
print(f"Test PRauc  : {test_auc:.4f}")

Test Loss   : 0.0674
Test Prec   : 0.4625
Test Recall : 0.6308
Test F1     : 0.5337
Test Acc    : 0.9732
Test PRauc    : 0.5996


# RECALL MODEL

In [14]:
ckpt = torch.load(
    "/kaggle/input/deeplab-b4/pytorch/default/1/deeplab-b4.pth",
    map_location=DEVICE,
    weights_only=False
)
model.load_state_dict(ckpt["model_state"])

<All keys matched successfully>

In [None]:
cfg.EPOCHS = 20

In [16]:
from torch.amp import autocast, GradScaler
from sklearn.metrics import average_precision_score
from sklearn.metrics import roc_auc_score
import time

optimizer = torch.optim.Adam(model.parameters(), lr=cfg.LR)
scaler = GradScaler(enabled=(DEVICE.type == "cuda"))

def train_one_epoch(model, loader, thr=0.05, eps=1e-6):
    model.train()
    total_loss_sum = 0.0
    total_samples = 0

    total_tp = total_fp = total_fn = total_tn = 0.0
    all_probs, all_targets = [], []

    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        x = normalize_batch(x, MEAN_13, STD_13)

        optimizer.zero_grad(set_to_none=True)

        with autocast("cuda", enabled=(DEVICE.type == "cuda")):
            logits = model(x)
            loss = criterion(logits, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        bs = x.size(0)
        total_loss_sum += loss.item() * bs
        total_samples += bs

        probs = torch.sigmoid(logits)
        preds = (probs >= thr).float()

        total_tp += (preds * y).sum().item()
        total_fp += (preds * (1 - y)).sum().item()
        total_fn += ((1 - preds) * y).sum().item()
        total_tn += ((1 - preds) * (1 - y)).sum().item()

        all_probs.append(probs.flatten().detach().cpu())
        all_targets.append(y.flatten().detach().cpu())

    precision = (total_tp + eps) / (total_tp + total_fp + eps)
    recall    = (total_tp + eps) / (total_tp + total_fn + eps)
    f1        = (2 * precision * recall) / (precision + recall + eps)
    acc       = (total_tp + total_tn + eps) / (
        total_tp + total_tn + total_fp + total_fn + eps
    )

    all_probs = torch.cat(all_probs).numpy()
    all_targets = torch.cat(all_targets).numpy()
    auc = average_precision_score(all_targets, all_probs)

    loss_global = total_loss_sum / total_samples

    return loss_global, precision, recall, f1, acc, auc


@torch.no_grad()
def validate_epoch(model, loader, thr=0.05, eps=1e-6):
    model.eval()
    total_loss_sum = 0.0
    total_samples = 0

    total_tp = total_fp = total_fn = total_tn = 0.0
    all_probs, all_targets = [], []

    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        x = normalize_batch(x, MEAN_13, STD_13)

        logits = model(x)
        loss = criterion(logits, y)

        bs = x.size(0)
        total_loss_sum += loss.item() * bs
        total_samples += bs

        probs = torch.sigmoid(logits)
        preds = (probs >= thr).float()

        total_tp += (preds * y).sum().item()
        total_fp += (preds * (1 - y)).sum().item()
        total_fn += ((1 - preds) * y).sum().item()
        total_tn += ((1 - preds) * (1 - y)).sum().item()

        all_probs.append(probs.flatten().cpu())
        all_targets.append(y.flatten().cpu())

    precision = (total_tp + eps) / (total_tp + total_fp + eps)
    recall    = (total_tp + eps) / (total_tp + total_fn + eps)
    f1        = (2 * precision * recall) / (precision + recall + eps)
    acc       = (total_tp + total_tn + eps) / (
        total_tp + total_tn + total_fp + total_fn + eps
    )

    all_probs = torch.cat(all_probs).numpy()
    all_targets = torch.cat(all_targets).numpy()
    auc = average_precision_score(all_targets, all_probs)

    loss_global = total_loss_sum / total_samples

    return loss_global, precision, recall, f1, acc, auc

In [17]:
BEST_PATH = "/kaggle/working/deeplab_pretrained.pth"
best_val_f1 = -1

import pandas as pd
history = []

for epoch in range(1, cfg.EPOCHS + 1):
    t0 = time.time()

    tr_loss, tr_prec, tr_rec, tr_f1, tr_acc, tr_auc = train_one_epoch(model, train_loader)
    va_loss, va_prec, va_rec, va_f1, va_acc, va_auc = validate_epoch(model, val_loader)
    elapsed = time.time() - t0

    print(
        f"[unet_basic] Epoch {epoch:03d} | "
        f"tr_loss={tr_loss:.4f} tr_f1={tr_f1:.4f} tr_acc={tr_acc:.4f} tr_PRauc={tr_auc:.4f} || "
        f"va_loss={va_loss:.4f} va_f1={va_f1:.4f} va_acc={va_acc:.4f} va_PRauc={va_auc:.4f} | "
        f"{elapsed:.1f}s"
    )

    history.append({
        "epoch": epoch,
    
        "tr_loss": tr_loss,
        "tr_precision": tr_prec,
        "tr_recall": tr_rec,
        "tr_f1": tr_f1,
        "tr_acc": tr_acc,
        "tr_PRauc": tr_auc,
    
        "va_loss": va_loss,
        "va_precision": va_prec,
        "va_recall": va_rec,
        "va_f1": va_f1,
        "va_acc": va_acc,
        "va_PRauc": va_auc,
    
        "seconds": elapsed
    })

    if va_f1 > best_val_f1:
        torch.save({
            "model_state": model.state_dict(),
            "mean_13": MEAN_13,
            "std_13": STD_13,
            "epoch": epoch,
            "best_val_f1": best_val_f1,
            "va_auc": va_auc,
        }, BEST_PATH)


[unet_basic] Epoch 001 | tr_loss=0.0282 tr_f1=0.8703 tr_acc=0.9830 tr_PRauc=0.9715 || va_loss=0.1165 va_f1=0.6579 va_acc=0.9662 va_PRauc=0.7152 | 392.1s
[unet_basic] Epoch 002 | tr_loss=0.0220 tr_f1=0.8763 tr_acc=0.9836 tr_PRauc=0.9805 || va_loss=0.1282 va_f1=0.6580 va_acc=0.9676 va_PRauc=0.7085 | 391.5s
[unet_basic] Epoch 003 | tr_loss=0.0241 tr_f1=0.8520 tr_acc=0.9799 tr_PRauc=0.9769 || va_loss=0.1134 va_f1=0.6196 va_acc=0.9588 va_PRauc=0.6562 | 394.7s
[unet_basic] Epoch 004 | tr_loss=0.0235 tr_f1=0.8442 tr_acc=0.9785 tr_PRauc=0.9780 || va_loss=0.1180 va_f1=0.6726 va_acc=0.9689 va_PRauc=0.7331 | 393.0s
[unet_basic] Epoch 005 | tr_loss=0.0148 tr_f1=0.9081 tr_acc=0.9882 tr_PRauc=0.9906 || va_loss=0.1761 va_f1=0.5800 va_acc=0.9689 va_PRauc=0.7194 | 393.4s
[unet_basic] Epoch 006 | tr_loss=0.0124 tr_f1=0.9206 tr_acc=0.9899 tr_PRauc=0.9935 || va_loss=0.1515 va_f1=0.6198 va_acc=0.9686 va_PRauc=0.7009 | 392.6s
[unet_basic] Epoch 007 | tr_loss=0.0136 tr_f1=0.9147 tr_acc=0.9891 tr_PRauc=0.9922

In [18]:
df_history = pd.DataFrame(history)

df_history = df_history.round(4)
csv_path = "/kaggle/working/training_metrics_eff1.csv"
df_history.to_csv(csv_path, index=False)
print("Saved to:", csv_path)
df_history

Saved to: /kaggle/working/training_metrics_eff1.csv


Unnamed: 0,epoch,tr_loss,tr_precision,tr_recall,tr_f1,tr_acc,tr_PRauc,va_loss,va_precision,va_recall,va_f1,va_acc,va_PRauc,seconds
0,1,0.0282,0.7888,0.9706,0.8703,0.983,0.9715,0.1165,0.6376,0.6795,0.6579,0.9662,0.7152,392.0918
1,2,0.022,0.7874,0.9877,0.8763,0.9836,0.9805,0.1282,0.6646,0.6515,0.658,0.9676,0.7085,391.4926
2,3,0.0241,0.7516,0.9834,0.852,0.9799,0.9769,0.1134,0.5544,0.7023,0.6196,0.9588,0.6562,394.6945
3,4,0.0235,0.7375,0.9871,0.8442,0.9785,0.978,0.118,0.677,0.6682,0.6726,0.9689,0.7331,393.028
4,5,0.0148,0.836,0.9937,0.9081,0.9882,0.9906,0.1761,0.8169,0.4496,0.58,0.9689,0.7194,393.3645
5,6,0.0124,0.8567,0.9949,0.9206,0.9899,0.9935,0.1515,0.737,0.5347,0.6198,0.9686,0.7009,392.6385
6,7,0.0136,0.8469,0.9943,0.9147,0.9891,0.9922,0.1451,0.8007,0.5604,0.6594,0.9723,0.7504,395.6485
7,8,0.0117,0.8671,0.9962,0.9271,0.9908,0.994,0.242,0.8893,0.2849,0.4315,0.9641,0.7093,396.6417
8,9,0.0112,0.8715,0.9952,0.9292,0.9911,0.9946,0.2883,0.7727,0.1885,0.303,0.9585,0.4987,395.3366
9,10,0.011,0.8748,0.9952,0.9312,0.9913,0.9946,0.1712,0.7476,0.55,0.6337,0.9696,0.6912,397.3984


In [19]:
t_best=0.05
test_loss, test_prec, test_rec, test_f1, test_acc, test_auc = validate_epoch(
    model, test_loader, thr=t_best
)

print(f"Test Loss   : {test_loss:.4f}")
print(f"Test Prec   : {test_prec:.4f}")
print(f"Test Recall : {test_rec:.4f}")
print(f"Test F1     : {test_f1:.4f}")
print(f"Test Acc    : {test_acc:.4f}")
print(f"Test PRauc  : {test_auc:.4f}")

Test Loss   : 0.0693
Test Prec   : 0.5757
Test Recall : 0.6979
Test F1     : 0.6309
Test Acc    : 0.9802
Test PRauc  : 0.6572
