In [1]:
import os, glob, random, json
from dataclasses import dataclass
from typing import List, Tuple

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# ----------------------------
# Reproducibility
# ----------------------------
def seed_everything(seed: int = 42) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

SEED = 24
seed_everything(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# ----------------------------
# Config
# ----------------------------
@dataclass
class CFG:
    DATA_ROOT: str = "/kaggle/input/sen2fire/Sen2Fire/Sen2Fire"
    SCENES: Tuple[str, ...] = ("scene1", "scene2", "scene3", "scene4")
    PATCH_EXT: str = ".npz"

    # Global split (VD default)
    USE_GLOBAL_SPLIT: bool = True
    GLOBAL_TRAIN_RATIO: float = 0.80
    GLOBAL_VAL_RATIO: float = 0.10
    GLOBAL_TEST_RATIO: float = 0.10
    KEEP_VAL_TEST_NATURAL: bool = True

    # Keys inside npz
    X_KEY: str = "image"     # (12,512,512)
    A_KEY: str = "aerosol"   # (512,512) -> becomes 1 channel
    Y_KEY: str = "label"     # (512,512)

    # Fire patch definition
    FIRE_PATCH_MIN_RATIO: float = 0

    # Controlled train pool (VD default)
    USE_CONTROLLED_POOL: bool = True
    POOL_KEEP_FIRE: int = -1
    NONFIRE_PER_FIRE: int = 3

    # Training config (RAM-safe)
    IN_CHANNELS: int = 13
    H: int = 512
    W: int = 512
    BATCH_SIZE: int = 2
    NUM_WORKERS: int = 2
    LR: float = 1e-4
    EPOCHS: int = 20 

    VERBOSE: bool = True

cfg = CFG()

assert abs((cfg.GLOBAL_TRAIN_RATIO + cfg.GLOBAL_VAL_RATIO + cfg.GLOBAL_TEST_RATIO) - 1.0) < 1e-6
assert os.path.exists(cfg.DATA_ROOT), f"DATA_ROOT not found: {cfg.DATA_ROOT}"
print("DATA_ROOT OK:", cfg.DATA_ROOT)


Device: cuda
DATA_ROOT OK: /kaggle/input/sen2fire/Sen2Fire/Sen2Fire


In [2]:
def list_scene_files(data_root: str, scene_name: str, ext: str = ".npz") -> List[str]:
    scene_dir = os.path.join(data_root, scene_name)
    return sorted(glob.glob(os.path.join(scene_dir, f"*{ext}")))

scene_to_files = {}
total = 0
for s in cfg.SCENES:
    files = list_scene_files(cfg.DATA_ROOT, s, cfg.PATCH_EXT)
    scene_to_files[s] = files
    total += len(files)
    print(f"{s}: {len(files)} patches")

print("Total patches:", total)
if total == 0:
    raise RuntimeError("No patch files found. Check DATA_ROOT / PATCH_EXT.")


scene1: 864 patches
scene2: 594 patches
scene3: 504 patches
scene4: 504 patches
Total patches: 2466


In [3]:
def inspect_npz(npz_path: str):
    with np.load(npz_path) as data:
        return {k: (data[k].shape, str(data[k].dtype)) for k in data.keys()}

sample_path = scene_to_files[cfg.SCENES[0]][0]
print("Sample file:", sample_path)
info = inspect_npz(sample_path)
for k, (shape, dtype) in info.items():
    print(f"  - {k:>10s}: shape={shape}, dtype={dtype}")

for rk in [cfg.X_KEY, cfg.A_KEY, cfg.Y_KEY]:
    if rk not in info:
        raise KeyError(f"Missing key '{rk}' in npz. Found keys: {list(info.keys())}")


Sample file: /kaggle/input/sen2fire/Sen2Fire/Sen2Fire/scene1/scene_1_patch_10_1.npz
  -      image: shape=(12, 512, 512), dtype=int16
  -    aerosol: shape=(512, 512), dtype=float32
  -      label: shape=(512, 512), dtype=uint8


In [4]:
rows = []
for s, files in scene_to_files.items():
    for p in files:
        rows.append({"scene": s, "path": p})
manifest = pd.DataFrame(rows)

def fire_ratio_from_path(npz_path: str) -> float:
    with np.load(npz_path) as data:
        y = data[cfg.Y_KEY]
        if y.ndim == 3 and y.shape[-1] == 1:
            y = y[..., 0]
        yb = (y > 0).astype(np.uint8)
        return float(yb.mean())

fire_ratios = []
has_fire = []
for p in manifest["path"].tolist():
    r = fire_ratio_from_path(p)
    fire_ratios.append(r)
    has_fire.append(1 if r > cfg.FIRE_PATCH_MIN_RATIO else 0)

manifest["fire_ratio"] = fire_ratios
manifest["has_fire"] = has_fire

print("\nFULL dataset has_fire distribution:")
print(manifest["has_fire"].value_counts().sort_index())
print("\nFULL dataset has_fire ratio:")
print(manifest["has_fire"].value_counts(normalize=True).sort_index())



FULL dataset has_fire distribution:
has_fire
0    2117
1     349
Name: count, dtype: int64

FULL dataset has_fire ratio:
has_fire
0    0.858475
1    0.141525
Name: proportion, dtype: float64


# DATASET PREPARATION

In [5]:
def stratified_split(df, train_ratio, val_ratio, seed=42):
    df = df.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    parts = []
    for cls in [0, 1]:
        sub = df[df["has_fire"] == cls].copy()
        n = len(sub)
        n_train = int(round(n * train_ratio))
        n_val   = int(round(n * val_ratio))
        sub_train = sub.iloc[:n_train]
        sub_val   = sub.iloc[n_train:n_train+n_val]
        sub_test  = sub.iloc[n_train+n_val:]
        parts.append((sub_train, sub_val, sub_test))

    train_df = pd.concat([parts[0][0], parts[1][0]]).sample(frac=1.0, random_state=seed).reset_index(drop=True)
    val_df   = pd.concat([parts[0][1], parts[1][1]]).sample(frac=1.0, random_state=seed).reset_index(drop=True)
    test_df  = pd.concat([parts[0][2], parts[1][2]]).sample(frac=1.0, random_state=seed).reset_index(drop=True)
    return train_df, val_df, test_df

train_pool_df, val_df, test_df = stratified_split(
    manifest,
    train_ratio=cfg.GLOBAL_TRAIN_RATIO,
    val_ratio=cfg.GLOBAL_VAL_RATIO,
    seed=SEED
)

print("\nGLOBAL split counts:")
print("  Train pool:", len(train_pool_df))
print("  Val      :", len(val_df))
print("  Test     :", len(test_df))

# Controlled pool applies to TRAIN only
if cfg.USE_CONTROLLED_POOL:
    fire_df   = train_pool_df[train_pool_df["has_fire"] == 1].copy()
    nofire_df = train_pool_df[train_pool_df["has_fire"] == 0].copy()

    n_fire_total = len(fire_df)
    n_nofire_total = len(nofire_df)

    n_fire_keep = n_fire_total if cfg.POOL_KEEP_FIRE in (-1, None) else min(cfg.POOL_KEEP_FIRE, n_fire_total)
    fire_keep = fire_df.sample(n=n_fire_keep, random_state=SEED) if n_fire_keep > 0 else fire_df.iloc[:0]

    n_nofire_keep = min(n_nofire_total, n_fire_keep * int(cfg.NONFIRE_PER_FIRE))
    nofire_keep = nofire_df.sample(n=n_nofire_keep, random_state=SEED) if n_nofire_keep > 0 else nofire_df.iloc[:0]

    train_df = pd.concat([fire_keep, nofire_keep]).sample(frac=1.0, random_state=SEED).reset_index(drop=True)

    print("\nTRAIN controlled sampling:")
    print(f"  fire_total_in_pool   : {n_fire_total}")
    print(f"  nofire_total_in_pool : {n_nofire_total}")
    print(f"  fire_kept            : {len(fire_keep)}")
    print(f"  nofire_kept          : {len(nofire_keep)} (NONFIRE_PER_FIRE={cfg.NONFIRE_PER_FIRE})")
    print(f"  train_final          : {len(train_df)}")
else:
    train_df = train_pool_df.copy()

train_paths = train_df["path"].tolist()
val_paths   = val_df["path"].tolist()
test_paths  = test_df["path"].tolist()

# leakage check
assert len(set(train_paths)&set(val_paths))==0
assert len(set(train_paths)&set(test_paths))==0
assert len(set(val_paths)&set(test_paths))==0

print("\nFinal used split sizes:")
print("  train:", len(train_paths))
print("  val  :", len(val_paths))
print("  test :", len(test_paths))



GLOBAL split counts:
  Train pool: 1973
  Val      : 247
  Test     : 246

TRAIN controlled sampling:
  fire_total_in_pool   : 279
  nofire_total_in_pool : 1694
  fire_kept            : 279
  nofire_kept          : 837 (NONFIRE_PER_FIRE=3)
  train_final          : 1116

Final used split sizes:
  train: 1116
  val  : 247
  test : 246


In [6]:
class Sen2FireDataset(Dataset):
    def __init__(self, paths: List[str], with_label: bool = True):
        self.paths = paths
        self.with_label = with_label

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        p = self.paths[idx]
        with np.load(p) as d:
            img12 = d[cfg.X_KEY].astype(np.float32)      # (12,512,512)
            aer   = d[cfg.A_KEY].astype(np.float32)[None, ...]  # (1,512,512)
            x = np.concatenate([img12, aer], axis=0)     # (13,512,512)

            if self.with_label:
                y = d[cfg.Y_KEY]
                if y.ndim == 2:
                    y = y[None, ...]
                y = (y > 0).astype(np.float32)
                return torch.from_numpy(x), torch.from_numpy(y)
            else:
                return torch.from_numpy(x)


In [7]:
train_loader = DataLoader(Sen2FireDataset(train_paths, with_label=True),
                          batch_size=cfg.BATCH_SIZE, shuffle=True,
                          num_workers=cfg.NUM_WORKERS, pin_memory=True)

val_loader = DataLoader(Sen2FireDataset(val_paths, with_label=True),
                        batch_size=cfg.BATCH_SIZE, shuffle=False,
                        num_workers=cfg.NUM_WORKERS, pin_memory=True)

test_loader = DataLoader(Sen2FireDataset(test_paths, with_label=True),
                         batch_size=cfg.BATCH_SIZE, shuffle=False,
                         num_workers=cfg.NUM_WORKERS, pin_memory=True)


In [8]:
@torch.no_grad()
def compute_mean_std(loader, max_batches=50):
    mean = torch.zeros(cfg.IN_CHANNELS, device=DEVICE)
    var  = torch.zeros(cfg.IN_CHANNELS, device=DEVICE)
    n_batches = 0

    for bi, (x, y) in enumerate(loader):
        if bi >= max_batches:
            break
        x = x.to(DEVICE)  # (B,C,H,W)
        x_ = x.view(x.size(0), cfg.IN_CHANNELS, -1)
        mean += x_.mean(dim=(0,2))
        var  += x_.var(dim=(0,2), unbiased=False)
        n_batches += 1

    mean /= max(n_batches, 1)
    var  /= max(n_batches, 1)
    std = torch.sqrt(var + 1e-6)
    return mean.detach().cpu().tolist(), std.detach().cpu().tolist()

MEAN_13, STD_13 = compute_mean_std(train_loader, max_batches=50)
print("MEAN_13 len:", len(MEAN_13), "STD_13 len:", len(STD_13))


MEAN_13 len: 13 STD_13 len: 13


# MODEL PREPARATION


In [9]:
!pip -q install segmentation-models-pytorch


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [10]:
def normalize_batch(x, mean_13, std_13):
    mean_t = torch.tensor(mean_13, device=x.device).view(1, cfg.IN_CHANNELS, 1, 1)
    std_t  = torch.tensor(std_13,  device=x.device).view(1, cfg.IN_CHANNELS, 1, 1)
    return (x - mean_t) / (std_t + 1e-6)

In [11]:
import segmentation_models_pytorch as smp

criterion = nn.BCEWithLogitsLoss()
def build_unet(in_channels=13):
    model = smp.Unet(
        encoder_name="efficientnet-b4",
        encoder_weights="imagenet",
        in_channels=in_channels,
        classes=1,
        activation=None
    )
    return model
    
model = build_unet(cfg.IN_CHANNELS).to(DEVICE)
print("Model:", type(model).__name__)



config.json:   0%|          | 0.00/106 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/77.9M [00:00<?, ?B/s]

Model: Unet


In [12]:
from torch.amp import autocast, GradScaler
from sklearn.metrics import average_precision_score
from sklearn.metrics import roc_auc_score
import time

optimizer = torch.optim.Adam(model.parameters(), lr=cfg.LR)
scaler = GradScaler(enabled=(DEVICE.type == "cuda"))

def train_one_epoch(model, loader, thr=0.5, eps=1e-6):
    model.train()
    total_loss_sum = 0.0
    total_samples = 0

    total_tp = total_fp = total_fn = total_tn = 0.0
    all_probs, all_targets = [], []

    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        x = normalize_batch(x, MEAN_13, STD_13)

        optimizer.zero_grad(set_to_none=True)

        with autocast("cuda", enabled=(DEVICE.type == "cuda")):
            logits = model(x)
            loss = criterion(logits, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        bs = x.size(0)
        total_loss_sum += loss.item() * bs
        total_samples += bs

        probs = torch.sigmoid(logits)
        preds = (probs >= thr).float()

        total_tp += (preds * y).sum().item()
        total_fp += (preds * (1 - y)).sum().item()
        total_fn += ((1 - preds) * y).sum().item()
        total_tn += ((1 - preds) * (1 - y)).sum().item()

        all_probs.append(probs.flatten().detach().cpu())
        all_targets.append(y.flatten().detach().cpu())

    precision = (total_tp + eps) / (total_tp + total_fp + eps)
    recall    = (total_tp + eps) / (total_tp + total_fn + eps)
    f1        = (2 * precision * recall) / (precision + recall + eps)
    acc       = (total_tp + total_tn + eps) / (
        total_tp + total_tn + total_fp + total_fn + eps
    )

    all_probs = torch.cat(all_probs).numpy()
    all_targets = torch.cat(all_targets).numpy()
    auc = average_precision_score(all_targets, all_probs)

    loss_global = total_loss_sum / total_samples

    return loss_global, precision, recall, f1, acc, auc


@torch.no_grad()
def validate_epoch(model, loader, thr=0.5, eps=1e-6):
    model.eval()
    total_loss_sum = 0.0
    total_samples = 0

    total_tp = total_fp = total_fn = total_tn = 0.0
    all_probs, all_targets = [], []

    for x, y in loader:
        x, y = x.to(DEVICE), y.to(DEVICE)
        x = normalize_batch(x, MEAN_13, STD_13)

        logits = model(x)
        loss = criterion(logits, y)

        bs = x.size(0)
        total_loss_sum += loss.item() * bs
        total_samples += bs

        probs = torch.sigmoid(logits)
        preds = (probs >= thr).float()

        total_tp += (preds * y).sum().item()
        total_fp += (preds * (1 - y)).sum().item()
        total_fn += ((1 - preds) * y).sum().item()
        total_tn += ((1 - preds) * (1 - y)).sum().item()

        all_probs.append(probs.flatten().cpu())
        all_targets.append(y.flatten().cpu())

    precision = (total_tp + eps) / (total_tp + total_fp + eps)
    recall    = (total_tp + eps) / (total_tp + total_fn + eps)
    f1        = (2 * precision * recall) / (precision + recall + eps)
    acc       = (total_tp + total_tn + eps) / (
        total_tp + total_tn + total_fp + total_fn + eps
    )

    all_probs = torch.cat(all_probs).numpy()
    all_targets = torch.cat(all_targets).numpy()
    auc = average_precision_score(all_targets, all_probs)

    loss_global = total_loss_sum / total_samples

    return loss_global, precision, recall, f1, acc, auc

# TRAINING

In [20]:
BEST_PATH = "/kaggle/working/efficientb4.pth"
best_val_f1 = -1

import pandas as pd
history = []

for epoch in range(1, cfg.EPOCHS + 1):
    t0 = time.time()

    tr_loss, tr_prec, tr_rec, tr_f1, tr_acc, tr_auc = train_one_epoch(model, train_loader)
    va_loss, va_prec, va_rec, va_f1, va_acc, va_auc = validate_epoch(model, val_loader, thr=0.5)
    elapsed = time.time() - t0

    print(
        f"[unet_basic] Epoch {epoch:03d} | "
        f"tr_loss={tr_loss:.4f} tr_f1={tr_f1:.4f} tr_acc={tr_acc:.4f} tr_PRauc={tr_auc:.4f} || "
        f"va_loss={va_loss:.4f} va_f1={va_f1:.4f} va_acc={va_acc:.4f} va_PRauc={va_auc:.4f} | "
        f"{elapsed:.1f}s"
    )

    history.append({
        "epoch": epoch,
    
        "tr_loss": tr_loss,
        "tr_precision": tr_prec,
        "tr_recall": tr_rec,
        "tr_f1": tr_f1,
        "tr_acc": tr_acc,
        "tr_PRauc": tr_auc,
    
        "va_loss": va_loss,
        "va_precision": va_prec,
        "va_recall": va_rec,
        "va_f1": va_f1,
        "va_acc": va_acc,
        "va_PRauc": va_auc,
    
        "seconds": elapsed
    })

    if va_f1 > best_val_f1:
        torch.save({
            "model_state": model.state_dict(),
            "mean_13": MEAN_13,
            "std_13": STD_13,
            "epoch": epoch,
            "best_val_f1": best_val_f1,
            "va_auc": va_auc,
        }, BEST_PATH)


[unet_basic] Epoch 001 | tr_loss=0.1889 tr_f1=0.1420 tr_acc=0.9428 tr_PRauc=0.2755 || va_loss=0.1381 va_f1=0.0651 va_acc=0.9519 va_PRauc=0.3289 | 326.0s
[unet_basic] Epoch 002 | tr_loss=0.1639 tr_f1=0.3168 tr_acc=0.9467 tr_PRauc=0.4100 || va_loss=0.1232 va_f1=0.0018 va_acc=0.9522 va_PRauc=0.4055 | 324.8s
[unet_basic] Epoch 003 | tr_loss=0.1356 tr_f1=0.4874 tr_acc=0.9551 tr_PRauc=0.5765 || va_loss=0.1477 va_f1=0.0000 va_acc=0.9522 va_PRauc=0.3395 | 325.0s
[unet_basic] Epoch 004 | tr_loss=0.1011 tr_f1=0.6632 tr_acc=0.9653 tr_PRauc=0.7327 || va_loss=0.1086 va_f1=0.0798 va_acc=0.9539 va_PRauc=0.6033 | 323.2s
[unet_basic] Epoch 005 | tr_loss=0.0822 tr_f1=0.7421 tr_acc=0.9721 tr_PRauc=0.8072 || va_loss=0.1484 va_f1=0.0000 va_acc=0.9522 va_PRauc=0.5431 | 323.6s
[unet_basic] Epoch 006 | tr_loss=0.0665 tr_f1=0.7983 tr_acc=0.9774 tr_PRauc=0.8634 || va_loss=0.1328 va_f1=0.0000 va_acc=0.9522 va_PRauc=0.6381 | 323.4s
[unet_basic] Epoch 007 | tr_loss=0.0569 tr_f1=0.8356 tr_acc=0.9813 tr_PRauc=0.8944

In [21]:
df_history = pd.DataFrame(history)

df_history = df_history.round(4)
csv_path = "/kaggle/working/training_metrics_eff1.csv"
df_history.to_csv(csv_path, index=False)
print("Saved to:", csv_path)
df_history

Saved to: /kaggle/working/training_metrics_eff1.csv


Unnamed: 0,epoch,tr_loss,tr_precision,tr_recall,tr_f1,tr_acc,tr_PRauc,va_loss,va_precision,va_recall,va_f1,va_acc,va_PRauc,seconds
0,1,0.1889,0.6055,0.0805,0.142,0.9428,0.2755,0.1381,0.4598,0.035,0.0651,0.9519,0.3289,326.0277
1,2,0.1639,0.6474,0.2097,0.3168,0.9467,0.41,0.1232,0.7221,0.0009,0.0018,0.9522,0.4055,324.7601
2,3,0.1356,0.7453,0.3621,0.4874,0.9551,0.5765,0.1477,1.0,0.0,0.0,0.9522,0.3395,324.951
3,4,0.1011,0.7751,0.5795,0.6632,0.9653,0.7327,0.1086,0.8705,0.0418,0.0798,0.9539,0.6033,323.1811
4,5,0.0822,0.813,0.6825,0.7421,0.9721,0.8072,0.1484,1.0,0.0,0.0,0.9522,0.5431,323.6256
5,6,0.0665,0.8429,0.7582,0.7983,0.9774,0.8634,0.1328,0.0,0.0,0.0,0.9522,0.6381,323.4198
6,7,0.0569,0.8666,0.8068,0.8356,0.9813,0.8944,0.1106,0.8246,0.1868,0.3046,0.9592,0.6277,322.1071
7,8,0.0497,0.8725,0.8349,0.8533,0.9831,0.9135,0.1192,0.9612,0.0966,0.1756,0.9566,0.7063,324.6199
8,9,0.0411,0.8916,0.869,0.8801,0.9861,0.9351,0.0982,0.8344,0.3501,0.4932,0.9656,0.6908,323.9097
9,10,0.034,0.9098,0.8882,0.8989,0.9882,0.9565,0.1098,0.8758,0.283,0.4278,0.9638,0.7042,320.1268


In [13]:
ckpt = torch.load(
    "/kaggle/input/efficient-b4/pytorch/default/1/efficientb4.pth",
    map_location=DEVICE,
    weights_only=False  # biar load semua, bukan cuma weights
)
model.load_state_dict(ckpt["model_state"])

<All keys matched successfully>

In [16]:
BEST_PATH = "/kaggle/working/efficientb4_pretrained.pth"
best_val_f1 = -1

import pandas as pd
history = []

for epoch in range(1, cfg.EPOCHS + 1):
    t0 = time.time()

    tr_loss, tr_prec, tr_rec, tr_f1, tr_acc, tr_auc = train_one_epoch(model, train_loader,thr=0.5)
    va_loss, va_prec, va_rec, va_f1, va_acc, va_auc = validate_epoch(model, val_loader, thr=0.5)
    elapsed = time.time() - t0

    print(
        f"[unet_basic] Epoch {epoch:03d} | "
        f"tr_loss={tr_loss:.4f} tr_f1={tr_f1:.4f} tr_acc={tr_acc:.4f} tr_PRauc={tr_auc:.4f} || "
        f"va_loss={va_loss:.4f} va_f1={va_f1:.4f} va_acc={va_acc:.4f} va_PRauc={va_auc:.4f} | "
        f"{elapsed:.1f}s"
    )

    history.append({
        "epoch": epoch,
    
        "tr_loss": tr_loss,
        "tr_precision": tr_prec,
        "tr_recall": tr_rec,
        "tr_f1": tr_f1,
        "tr_acc": tr_acc,
        "tr_PRauc": tr_auc,
    
        "va_loss": va_loss,
        "va_precision": va_prec,
        "va_recall": va_rec,
        "va_f1": va_f1,
        "va_acc": va_acc,
        "va_PRauc": va_auc,
    
        "seconds": elapsed
    })

    if va_f1 > best_val_f1:
        best_val_f1 = va_f1
        torch.save({
            "model_state": model.state_dict(),
            "mean_13": MEAN_13,
            "std_13": STD_13,
            "epoch": epoch,
            "best_val_f1": best_val_f1,
            "va_auc": va_auc,
        }, BEST_PATH)


[unet_basic] Epoch 001 | tr_loss=0.0198 tr_f1=0.9372 tr_acc=0.9926 tr_PRauc=0.9830 || va_loss=0.1082 va_f1=0.6155 va_acc=0.9716 va_PRauc=0.7554 | 228.1s
[unet_basic] Epoch 002 | tr_loss=0.0178 tr_f1=0.9434 tr_acc=0.9934 tr_PRauc=0.9857 || va_loss=0.1306 va_f1=0.5201 va_acc=0.9670 va_PRauc=0.7195 | 227.6s
[unet_basic] Epoch 003 | tr_loss=0.0184 tr_f1=0.9446 tr_acc=0.9935 tr_PRauc=0.9851 || va_loss=0.0923 va_f1=0.6457 va_acc=0.9731 va_PRauc=0.7824 | 231.4s
[unet_basic] Epoch 004 | tr_loss=0.0183 tr_f1=0.9443 tr_acc=0.9935 tr_PRauc=0.9856 || va_loss=0.0940 va_f1=0.6545 va_acc=0.9739 va_PRauc=0.7833 | 229.7s
[unet_basic] Epoch 005 | tr_loss=0.0160 tr_f1=0.9507 tr_acc=0.9942 tr_PRauc=0.9884 || va_loss=0.0937 va_f1=0.6320 va_acc=0.9721 va_PRauc=0.7763 | 229.8s
[unet_basic] Epoch 006 | tr_loss=0.0116 tr_f1=0.9626 tr_acc=0.9956 tr_PRauc=0.9941 || va_loss=0.1002 va_f1=0.6977 va_acc=0.9764 va_PRauc=0.8088 | 230.5s
[unet_basic] Epoch 007 | tr_loss=0.0096 tr_f1=0.9682 tr_acc=0.9963 tr_PRauc=0.9960

In [17]:
df_history = pd.DataFrame(history)

df_history = df_history.round(4)
csv_path = "/kaggle/working/training_metrics_eff_pretrained.csv"
df_history.to_csv(csv_path, index=False)
print("Saved to:", csv_path)
df_history

Saved to: /kaggle/working/training_metrics_eff_pretrained.csv


Unnamed: 0,epoch,tr_loss,tr_precision,tr_recall,tr_f1,tr_acc,tr_PRauc,va_loss,va_precision,va_recall,va_f1,va_acc,va_PRauc,seconds
0,1,0.0198,0.9423,0.9321,0.9372,0.9926,0.983,0.1082,0.8711,0.4758,0.6155,0.9716,0.7554,228.1279
1,2,0.0178,0.9456,0.9413,0.9434,0.9934,0.9857,0.1306,0.8531,0.3741,0.5201,0.967,0.7195,227.5578
2,3,0.0184,0.9465,0.9427,0.9446,0.9935,0.9851,0.0923,0.8715,0.5129,0.6457,0.9731,0.7824,231.383
3,4,0.0183,0.9496,0.939,0.9443,0.9935,0.9856,0.094,0.8929,0.5165,0.6545,0.9739,0.7833,229.6803
4,5,0.016,0.9544,0.947,0.9507,0.9942,0.9884,0.0937,0.8538,0.5017,0.632,0.9721,0.7763,229.8185
5,6,0.0116,0.9625,0.9627,0.9626,0.9956,0.9941,0.1002,0.9003,0.5695,0.6977,0.9764,0.8088,230.5392
6,7,0.0096,0.9682,0.9683,0.9682,0.9963,0.996,0.1393,0.9457,0.3756,0.5376,0.9691,0.7855,233.1362
7,8,0.0104,0.9709,0.9636,0.9673,0.9962,0.9953,0.1331,0.9453,0.4099,0.5719,0.9706,0.8041,232.679
8,9,0.0091,0.9706,0.9697,0.9702,0.9965,0.9963,0.1125,0.9062,0.4851,0.6319,0.973,0.8051,230.9003
9,10,0.0107,0.9678,0.9645,0.9661,0.996,0.9943,0.1341,0.8841,0.4684,0.6123,0.9716,0.7679,235.0706


In [21]:
t_best = 0.5
test_loss, test_prec, test_rec, test_f1, test_acc, test_auc = validate_epoch(
    model, test_loader, thr=t_best
)

print(f"Test Loss   : {test_loss:.4f}")
print(f"Test Prec   : {test_prec:.4f}")
print(f"Test Recall : {test_rec:.4f}")
print(f"Test F1     : {test_f1:.4f}")
print(f"Test Acc    : {test_acc:.4f}")
print(f"Test PRauc  : {test_auc:.4f}")
print(f"threshold   : {t_best:.2f}")

Test Loss   : 0.0532
Test Prec   : 0.9031
Test Recall : 0.6518
Test F1     : 0.7571
Test Acc    : 0.9899
Test PRauc  : 0.8185
threshold   : 0.50


In [22]:
@torch.no_grad()
def find_best_thr_global_f1(model, loader, grid=None, eps=1e-6):
    model.eval()
    if grid is None:
        grid = np.linspace(0.05, 0.95, 19)

    best_thr = None
    best_f1 = -1
    log = []

    for thr in grid:
        total_tp = total_fp = total_fn = 0.0

        for x, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            x = normalize_batch(x, MEAN_13, STD_13)

            logits = model(x)
            probs = torch.sigmoid(logits)
            preds = (probs >= thr).float()

            total_tp += (preds * y).sum().item()
            total_fp += (preds * (1 - y)).sum().item()
            total_fn += ((1 - preds) * y).sum().item()

        precision = (total_tp + eps) / (total_tp + total_fp + eps)
        recall    = (total_tp + eps) / (total_tp + total_fn + eps)
        f1        = (2 * precision * recall) / (precision + recall + eps)

        log.append((float(thr), f1))

        if f1 > best_f1:
            best_f1 = f1
            best_thr = float(thr)

    return best_thr, best_f1, log

t_best, best_f1, thr_log = find_best_thr_global_f1(model, val_loader)

print("Chosen threshold:", t_best)
print("Best global F1:", best_f1)

print("Top-5 thresholds:")
for thr, f1 in sorted(thr_log, key=lambda x: x[1], reverse=True)[:5]:
    print(thr, f1)

Chosen threshold: 0.05
Best global F1: 0.7392240273337137
Top-5 thresholds:
0.05 0.7392240273337137
0.1 0.7151314054501176
0.15 0.6990625098157347
0.2 0.686029707327796
0.25 0.6739252907299398


In [20]:
t_best = 0.05
test_loss, test_prec, test_rec, test_f1, test_acc, test_auc = validate_epoch(
    model, test_loader, thr=t_best
)

print(f"Test Loss   : {test_loss:.4f}")
print(f"Test Prec   : {test_prec:.4f}")
print(f"Test Recall : {test_rec:.4f}")
print(f"Test F1     : {test_f1:.4f}")
print(f"Test Acc    : {test_acc:.4f}")
print(f"Test PRauc  : {test_auc:.4f}")
print(f"threshold   : {t_best:.2f}")

Test Loss   : 0.0532
Test Prec   : 0.8229
Test Recall : 0.7718
Test F1     : 0.7965
Test Acc    : 0.9904
Test PRauc  : 0.8185
threshold   : 0.05
