### GPU Sanity Check

In [2]:
# Safer allocator settings (help fragmentation on Windows/WDDM)
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
    "backend:cudaMallocAsync,"
    "expandable_segments:True,"
    "max_split_size_mb:64,"
    "garbage_collection_threshold:0.8"
)

import torch
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))


CUDA available: True
GPU: NVIDIA GeForce RTX 3050 Ti Laptop GPU


### CNN Model Training

In [2]:
# === Recreate classification datasets + loaders (ONLY your Parquet files) ===
from pathlib import Path
from io import BytesIO
from typing import Sequence, Optional, Callable, Union
import numpy as np
import pyarrow.parquet as pq
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset, Dataset as HFDataset
import torch.nn.functional as F

# ---- exact Parquet masks you provided ----
TRAIN_PARQUETS = [
    r"C:\Users\sedya\VScodeProjects\Coral-reefs-DBL4\data_preprocessing\coralscapes_export\parquet\train\train_part001.parquet",
    r"C:\Users\sedya\VScodeProjects\Coral-reefs-DBL4\data_preprocessing\coralscapes_export\parquet\train\train_part002.parquet",
]
VAL_PARQUETS = [
    r"C:\Users\sedya\VScodeProjects\Coral-reefs-DBL4\data_preprocessing\coralscapes_export\parquet\validation\validation_part001.parquet",
]
MASK_COLUMN = "label_health_rgb_png"

# ---- mask indexer: dataset index -> PNG ----
class ParquetMasksByIndex:
    def __init__(self, parquet_paths: Sequence[Union[str, Path]], column_png: str = MASK_COLUMN):
        self._tables = [pq.read_table(Path(p)) for p in parquet_paths]
        for t in self._tables:
            if "index" not in t.column_names or column_png not in t.column_names:
                raise ValueError(f"Parquet must have 'index' and '{column_png}'. Got: {t.column_names}")
        self._col = column_png
        self._map = {}
        for tid, t in enumerate(self._tables):
            for rid, ds_idx in enumerate(t["index"].to_pylist()):
                self._map[int(ds_idx)] = (tid, rid)
        print(f"[masks] mapped {len(self._map)} indices from {len(self._tables)} file(s)")

    def get_mask_pil(self, ds_index: int) -> Image.Image:
        tid, rid = self._map[ds_index]
        cell = self._tables[tid][self._col][rid].as_py()
        if isinstance(cell, memoryview): cell = cell.tobytes()
        elif isinstance(cell, bytearray): cell = bytes(cell)
        return Image.open(BytesIO(cell)).convert("RGB")

# ---- HF images: EPFL-ECEO/coralscapes (train/validation) ----
hf_all = load_dataset("EPFL-ECEO/coralscapes")
hf_train: HFDataset = hf_all["train"]
hf_val:   HFDataset = hf_all["validation"]

# ---- PIL -> tensor (copy() avoids "non-writable" warning) ----
def pil_to_tensor_rgb(img: Image.Image) -> torch.Tensor:
    arr = np.asarray(img.convert("RGB"), dtype=np.uint8).copy()
    return torch.from_numpy(arr).permute(2, 0, 1).float() / 255.0

# ---- bind images (HF) + masks (your Parquets only), keep only covered indices ----
class CoralScapesImagesMasks(Dataset):
    def __init__(self, img_ds: HFDataset, masks: ParquetMasksByIndex,
                 img_transform: Optional[Callable] = None,
                 mask_transform: Optional[Callable] = None):
        self.img_ds = img_ds
        self.masks = masks
        self.img_tf = img_transform
        self.mask_tf = mask_transform
        n = len(self.img_ds)
        self.indices = [i for i in range(n) if i in masks._map]
        print(f"[dataset] kept {len(self.indices)}/{n} indices (mask-covered).")

    def __len__(self): return len(self.indices)

    def __getitem__(self, j: int):
        idx = self.indices[j]
        rec = self.img_ds[idx]
        img = rec["image"].convert("RGB")
        mask = self.masks.get_mask_pil(idx)
        if self.img_tf is not None:  img  = self.img_tf(img)
        if self.mask_tf is not None: mask = self.mask_tf(mask)
        return img, mask

masks_train = ParquetMasksByIndex(TRAIN_PARQUETS, MASK_COLUMN)
masks_val   = ParquetMasksByIndex(VAL_PARQUETS,   MASK_COLUMN)
cs_train = CoralScapesImagesMasks(hf_train, masks_train, pil_to_tensor_rgb, pil_to_tensor_rgb)
cs_val   = CoralScapesImagesMasks(hf_val,   masks_val,   pil_to_tensor_rgb, pil_to_tensor_rgb)

# ---- classification wrapper → (image_128x128, label) ----
class MaskToBinaryLabel128(Dataset):
    def __init__(self, base_ds: Dataset, size=128):
        self.base = base_ds
        self.size = size
    def __len__(self): return len(self.base)
    def __getitem__(self, idx):
        img, mask = self.base[idx]  # tensors (3,H,W)
        if img.shape[-2:] != (self.size, self.size):
            img  = F.interpolate(img.unsqueeze(0),  size=(self.size,self.size), mode="bilinear", align_corners=False).squeeze(0)
            mask = F.interpolate(mask.unsqueeze(0), size=(self.size,self.size), mode="bilinear", align_corners=False).squeeze(0)
        red  = mask[0].sum().item()
        blue = mask[2].sum().item()
        label = 1 if blue > red else 0  # bleached if blue energy > red
        return img, torch.tensor(label, dtype=torch.long)

train_cls = MaskToBinaryLabel128(cs_train, size=128)
val_cls   = MaskToBinaryLabel128(cs_val,   size=128)

print(f"[ready] train_cls len={len(train_cls)} | val_cls len={len(val_cls)}")

# ---- loaders (GPU-friendly: pin_memory). If a global `model` exists, run a 1-batch smoke test.
BATCH = 32  # if OOM: 16 → 8
train_loader = DataLoader(train_cls, batch_size=BATCH, shuffle=True,  num_workers=0, pin_memory=True)
val_loader   = DataLoader(val_cls,   batch_size=BATCH, shuffle=False, num_workers=0, pin_memory=True)
print(f"[loaders] train batches ≈ {len(train_loader)} | val batches ≈ {len(val_loader)}")

if 'model' in globals():
    from torch.cuda.amp import autocast, GradScaler
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    criterion = torch.nn.CrossEntropyLoss()
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = GradScaler(enabled=(DEVICE.type=="cuda"))
    xb, yb = next(iter(train_loader))
    xb = xb.to(DEVICE, non_blocking=True); yb = yb.to(DEVICE, non_blocking=True)
    with autocast(enabled=(DEVICE.type=="cuda")):
        logits = model(xb); loss = criterion(logits, yb)
    scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
    print("[smoke] 1 batch OK on", DEVICE, "| loss:", float(loss))
else:
    print("[note] Loaders ready. Now run your model cell and training loop.")


[masks] mapped 1517 indices from 2 file(s)
[masks] mapped 166 indices from 1 file(s)
[dataset] kept 1517/1517 indices (mask-covered).
[dataset] kept 166/166 indices (mask-covered).
[ready] train_cls len=1517 | val_cls len=166
[loaders] train batches ≈ 48 | val batches ≈ 6
[note] Loaders ready. Now run your model cell and training loop.


In [3]:
import torch
from torch import nn

# Device
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE, "|", torch.cuda.get_device_name(0) if DEVICE.type=="cuda" else "")
torch.backends.cudnn.benchmark = True  # speed boost

# Keras-like CNN (same conv stack you showed; GAP replaces giant Flatten)
class KerasLikeCNN_GAP(nn.Module):
    def __init__(self, p_drop=0.5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=0),  # 128 -> 126
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),                 # 126 -> 63
            nn.Conv2d(32, 64, 3, padding=0), # 63 -> 61
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 32, 1), nn.ReLU(inplace=True),  # Dense(32) over maps
            nn.Conv2d(32, 64, 1), nn.ReLU(inplace=True),  # Dense(64) over maps
            nn.Conv2d(64, 128, 3, padding=0), nn.ReLU(inplace=True),  # 61 -> 59
            nn.Dropout(p=p_drop),
            nn.AdaptiveAvgPool2d(1),         # replaces huge Flatten(59*59*128)
        )
        self.head = nn.Linear(128, 2)        # binary logits

    def forward(self, x):
        x = self.features(x).flatten(1)      # (B,128,1,1) -> (B,128)
        return self.head(x)

model = KerasLikeCNN_GAP(p_drop=0.5).to(DEVICE)
print("Model on:", next(model.parameters()).device)

# Optim, loss, AMP scaler
criterion = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler(enabled=(DEVICE.type=="cuda"))

# (optional) quick smoke test on one batch
xb, yb = next(iter(train_loader))
xb, yb = xb.to(DEVICE, non_blocking=True), yb.to(DEVICE, non_blocking=True)
with autocast(enabled=(DEVICE.type=="cuda")):
    logits = model(xb); loss = criterion(logits, yb)
scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
print("Smoke batch OK | loss:", float(loss))


Using device: cuda | NVIDIA GeForce RTX 3050 Ti Laptop GPU
Model on: cuda:0


  scaler = GradScaler(enabled=(DEVICE.type=="cuda"))
  return data.pin_memory(device)
  with autocast(enabled=(DEVICE.type=="cuda")):


Smoke batch OK | loss: 0.72027587890625


In [7]:
import time
from torch.cuda.amp import autocast

class EarlyStopper:
    def __init__(self, patience=5, min_delta=1e-3):
        self.patience = patience; self.min_delta = min_delta
        self.best = float("inf"); self.count = 0
    def step(self, val_loss):
        if self.best - val_loss > self.min_delta:
            self.best = val_loss; self.count = 0; return False
        self.count += 1; return self.count >= self.patience

def run_epoch(loader, train=True):
    model.train() if train else model.eval()
    loss_sum = correct = n = 0
    for xb, yb in loader:
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)
        if train: opt.zero_grad(set_to_none=True)
        with torch.set_grad_enabled(train), autocast(enabled=(DEVICE.type=="cuda")):
            logits = model(xb)
            loss = criterion(logits, yb)
        if train:
            scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        loss_sum += loss.item() * xb.size(0)
        correct  += (logits.argmax(1) == yb).sum().item()
        n += xb.size(0)
    return loss_sum/max(1,n), correct/max(1,n)

EPOCHS = 5
early = EarlyStopper(patience=5, min_delta=1e-3)
best_state = None

for epoch in range(1, EPOCHS+1):
    t0 = time.time()
    tr_loss, tr_acc = run_epoch(train_loader, train=True)
    va_loss, va_acc = run_epoch(val_loader,   train=False)
    dt = time.time() - t0
    print(f"Epoch {epoch:02d}/{EPOCHS} - loss: {tr_loss:.4f} - acc: {tr_acc:.4f} "
          f"- val_loss: {va_loss:.4f} - val_acc: {va_acc:.4f} - {dt:.1f}s")
    if best_state is None or va_loss < (early.best - 1e-3):
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    if early.step(va_loss):
        print(f"Epoch {epoch}: early stopping (best val_loss={early.best:.4f})")
        break

if best_state:
    model.load_state_dict(best_state)
    print("Loaded best weights.")

# Final validation (like model.evaluate)
model.eval()
val_loss = val_acc = n = 0
with torch.no_grad(), autocast(enabled=(DEVICE.type=="cuda")):
    for xb, yb in val_loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        logits = model(xb)
        loss = criterion(logits, yb)
        val_loss += loss.item() * xb.size(0)
        val_acc  += (logits.argmax(1) == yb).sum().item()
        n += xb.size(0)
val_loss /= max(1, n)
val_acc  /= max(1, n)
print("Validation Loss:", round(val_loss, 4))
print("Validation Accuracy:", round(val_acc, 4))


  with torch.set_grad_enabled(train), autocast(enabled=(DEVICE.type=="cuda")):


KeyboardInterrupt: 

### with class

In [6]:
# === Per-class setup (0=healthy, 1=unhealthy/bleached) ===
import numpy as np
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from PIL import Image
from collections import Counter

CLASS_NAMES = {0: "healthy", 1: "unhealthy"}  # 1=bleached/unhealthy

def label_from_mask_fast(mask_pil: Image.Image) -> int:
    arr = np.asarray(mask_pil.convert("RGB"), dtype=np.uint8)
    red  = int(arr[..., 0].sum())
    blue = int(arr[..., 2].sum())
    return 1 if blue > red else 0  # unhealthy if blue dominates

# Build labels aligned with train_cls order (no need to re-open images)
train_labels = []
for j in range(len(train_cls)):
    idx = cs_train.indices[j]
    m = masks_train.get_mask_pil(idx)
    train_labels.append(label_from_mask_fast(m))

cnt = Counter(train_labels)
n0, n1 = cnt.get(0, 0), cnt.get(1, 0)
print(f"class counts -> healthy(0): {n0} | unhealthy(1): {n1}")

# Class weights inversely proportional to frequency
total = len(train_labels)
w0 = total / (2.0 * max(1, n0))
w1 = total / (2.0 * max(1, n1))
class_weights = torch.tensor([w0, w1], dtype=torch.float32, device=DEVICE)
print("class weights:", [round(float(w0), 4), round(float(w1), 4)])

# Weighted sampler for balanced minibatches
sample_weights = [w0 if y == 0 else w1 for y in train_labels]
sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

# Rebuild loaders (pin_memory for GPU IO). Keep val shuffled=False
BATCH = 32  # drop to 16/8 if VRAM is tight
train_loader = DataLoader(train_cls, batch_size=BATCH, sampler=sampler, num_workers=0, pin_memory=True)
val_loader   = DataLoader(val_cls,   batch_size=BATCH, shuffle=False,  num_workers=0, pin_memory=True)
print(f"[loaders] train batches ≈ {len(train_loader)} | val batches ≈ {len(val_loader)}")

# Re-define criterion with class weights (affects loss per class)
import torch.nn as nn
criterion = nn.CrossEntropyLoss(weight=class_weights)

# (Optional) reset optimizer to start fresh with the new criterion
opt = torch.optim.Adam(model.parameters(), lr=1e-3)


class counts -> healthy(0): 1339 | unhealthy(1): 178
class weights: [0.5665, 4.2612]
[loaders] train batches ≈ 48 | val batches ≈ 6


In [9]:
import torch
import time
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler

# If you didn't define a scaler yet:
try:
    scaler
except NameError:
    scaler = GradScaler(enabled=(DEVICE.type == "cuda"))

class EarlyStopper:
    def __init__(self, patience=5, min_delta=1e-3):
        self.patience = patience; self.min_delta = min_delta
        self.best = float("inf"); self.count = 0
    def step(self, val_loss):
        if self.best - val_loss > self.min_delta:
            self.best = val_loss; self.count = 0; return False
        self.count += 1; return self.count >= self.patience

def run_epoch(loader, train=True):
    model.train() if train else model.eval()
    loss_sum = correct = n = 0
    for xb, yb in loader:
        xb = xb.to(DEVICE, non_blocking=True)
        yb = yb.to(DEVICE, non_blocking=True)
        if train: opt.zero_grad(set_to_none=True)
        with torch.set_grad_enabled(train), autocast(enabled=(DEVICE.type=="cuda")):
            logits = model(xb)
            loss = criterion(logits, yb)
        if train:
            scaler.scale(loss).backward()
            scaler.step(opt); scaler.update()
        loss_sum += loss.item() * xb.size(0)
        correct  += (logits.argmax(1) == yb).sum().item()
        n += xb.size(0)
    return loss_sum/max(1,n), correct/max(1,n)

@torch.no_grad()
def evaluate_with_metrics(model, loader):
    model.eval()
    n = 0; loss_sum = 0
    all_true = []; all_pred = []
    for xb, yb in loader:
        xb = xb.to(DEVICE); yb = yb.to(DEVICE)
        logits = model(xb)
        loss = criterion(logits, yb)
        loss_sum += loss.item() * xb.size(0); n += xb.size(0)
        all_true.append(yb.cpu()); all_pred.append(logits.argmax(1).cpu())
    y_true = torch.cat(all_true)
    y_pred = torch.cat(all_pred)
    # Confusion matrix (2x2): rows=true, cols=pred
    cm = torch.zeros(2, 2, dtype=torch.int64)
    for t, p in zip(y_true, y_pred):
        cm[t, p] += 1
    # Per-class precision/recall/F1
    per_class = {}
    for c in [0, 1]:
        tp = cm[c, c].item()
        fp = cm[:, c].sum().item() - tp
        fn = cm[c, :].sum().item() - tp
        precision = tp / max(1, tp + fp)
        recall    = tp / max(1, tp + fn)
        f1        = 0.0 if (precision + recall) == 0 else 2 * precision * recall / (precision + recall)
        per_class[c] = {
            "name": CLASS_NAMES[c],
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "support": cm[c, :].sum().item(),
        }
    acc = (y_true == y_pred).float().mean().item()
    return {"loss": loss_sum/max(1,n), "acc": acc, "cm": cm, "per_class": per_class}

# ---- Train with early stopping, save best by val loss ----
EPOCHS = 3
early = EarlyStopper(patience=5, min_delta=1e-3)
best_state = None

for epoch in range(1, EPOCHS+1):
    t0 = time.time()
    tr_loss, tr_acc = run_epoch(train_loader, train=True)
    va_loss, va_acc = run_epoch(val_loader,   train=False)
    print(f"Epoch {epoch:02d}/{EPOCHS} - loss: {tr_loss:.4f} - acc: {tr_acc:.4f} "
          f"- val_loss: {va_loss:.4f} - val_acc: {va_acc:.4f} - {time.time()-t0:.1f}s")
    if best_state is None or va_loss < (early.best - 1e-3):
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    if early.step(va_loss):
        print(f"Epoch {epoch}: early stopping (best val_loss={early.best:.4f})")
        break

if best_state:
    model.load_state_dict(best_state)
    print("Loaded best weights.")

# ---- Final evaluation with PER-CLASS metrics ----
res = evaluate_with_metrics(model, val_loader)
print("\n=== Validation (per class) ===")
print(f"Overall: loss={res['loss']:.4f} | acc={res['acc']:.4f}")
cm = res["cm"].numpy()
print("\nConfusion Matrix (rows=true, cols=pred):")
print(f"            pred: healthy   pred: unhealthy")
print(f"true healthy     {cm[0,0]:>6}          {cm[0,1]:>6}")
print(f"true unhealthy   {cm[1,0]:>6}          {cm[1,1]:>6}")
print("\nPer-class metrics:")
for c in [0, 1]:
    m = res["per_class"][c]
    print(f"  {m['name']:<10} | precision={m['precision']:.3f} | recall={m['recall']:.3f} | f1={m['f1']:.3f} | support={m['support']}")


  with torch.set_grad_enabled(train), autocast(enabled=(DEVICE.type=="cuda")):


Epoch 01/3 - loss: 0.3694 - acc: 0.4984 - val_loss: 0.9613 - val_acc: 0.2711 - 354.3s
Epoch 02/3 - loss: 0.3717 - acc: 0.4937 - val_loss: 0.8740 - val_acc: 0.2711 - 316.0s
Epoch 03/3 - loss: 0.3557 - acc: 0.5208 - val_loss: 1.0742 - val_acc: 0.2711 - 315.5s
Loaded best weights.

=== Validation (per class) ===
Overall: loss=0.8739 | acc=0.2711

Confusion Matrix (rows=true, cols=pred):
            pred: healthy   pred: unhealthy
true healthy          0             121
true unhealthy        0              45

Per-class metrics:
  healthy    | precision=0.000 | recall=0.000 | f1=0.000 | support=121
  unhealthy  | precision=0.271 | recall=1.000 | f1=0.427 | support=45


### Test Classes

In [10]:
import numpy as np
from collections import Counter

# 1) Build index lists the same way your datasets filtered them
train_indices = [i for i in range(len(hf_train)) if i in masks_train._map]
val_indices   = [i for i in range(len(hf_val))   if i in masks_val._map]

def label_from_mask_bytes(mask_pil) -> int:
    # unhealthy if blue channel energy > red channel energy
    arr = np.asarray(mask_pil.convert("RGB"), dtype=np.uint8)
    return 1 if int(arr[...,2].sum()) > int(arr[...,0].sum()) else 0

# 2) Precompute labels directly from original mask PNGs (no interpolation)
y_train = []
for idx in train_indices:
    y_train.append(label_from_mask_bytes(masks_train.get_mask_pil(idx)))
y_val = []
for idx in val_indices:
    y_val.append(label_from_mask_bytes(masks_val.get_mask_pil(idx)))

cnt_tr = Counter(y_train); cnt_va = Counter(y_val)
print("train label counts:", cnt_tr, "| val label counts:", cnt_va)


train label counts: Counter({0: 1339, 1: 178}) | val label counts: Counter({0: 121, 1: 45})


In [11]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import numpy as np

def pil_to_tensor_rgb(img):
    arr = np.asarray(img.convert("RGB"), dtype=np.uint8).copy()
    return torch.from_numpy(arr).permute(2,0,1).float()/255.0

class Images128WithLabels(Dataset):
    def __init__(self, hf_ds, indices, labels, size=128):
        assert len(indices) == len(labels)
        self.hf = hf_ds
        self.idx = list(indices)
        self.y = torch.tensor(labels, dtype=torch.long)
        self.size = size
    def __len__(self): return len(self.idx)
    def __getitem__(self, j):
        i = self.idx[j]
        img = self.hf[i]["image"]
        x = pil_to_tensor_rgb(img)
        if x.shape[-2:] != (self.size, self.size):
            x = F.interpolate(x.unsqueeze(0), size=(self.size,self.size), mode="bilinear", align_corners=False).squeeze(0)
        return x, self.y[j]

train_cls = Images128WithLabels(hf_train, train_indices, y_train, size=128)
val_cls   = Images128WithLabels(hf_val,   val_indices,   y_val,   size=128)
print(f"[ready] train={len(train_cls)} | val={len(val_cls)}")


[ready] train=1517 | val=166


In [13]:
from torch.utils.data import DataLoader, WeightedRandomSampler

# weights inversely proportional to class freq (on the precomputed labels)
n = len(y_train); n0 = sum(1 for t in y_train if t==0); n1 = n - n0
w0 = n/(2*max(1,n0)); w1 = n/(2*max(1,n1))
weights = [w0 if t==0 else w1 for t in y_train]
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
print(f"class balance train: 0→{n0}, 1→{n1} | sampler on")

BATCH = 16  # 16/8 if VRAM is tight
train_loader = DataLoader(train_cls, batch_size=BATCH, sampler=sampler, num_workers=0, pin_memory=True)
val_loader   = DataLoader(val_cls,   batch_size=BATCH, shuffle=False,  num_workers=0, pin_memory=True)


class balance train: 0→1339, 1→178 | sampler on


In [14]:
import torch
from torch import nn

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", DEVICE, "|", torch.cuda.get_device_name(0) if DEVICE.type=="cuda" else "")
torch.backends.cudnn.benchmark = True

class KerasLikeCNN_GAP(nn.Module):
    def __init__(self, p_drop=0.3):  # slightly less dropout
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=0), nn.ReLU(True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=0), nn.ReLU(True),
            nn.Conv2d(64, 32, 1), nn.ReLU(True),
            nn.Conv2d(32, 64, 1), nn.ReLU(True),
            nn.Conv2d(64,128, 3, padding=0), nn.ReLU(True),
            nn.Dropout(p_drop),
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Linear(128, 2)

    def forward(self, x):
        x = self.features(x).flatten(1)
        return self.head(x)

model = KerasLikeCNN_GAP().to(DEVICE)

# Bias warm-start to class prior so logits aren't skewed to class 1
p1 = (n1 + 1e-6) / (n + 2e-6)        # prior for class 1
prior_logit = np.log(p1/(1-p1))
with torch.no_grad():
    model.head.bias[:] = torch.tensor([ -prior_logit, prior_logit ], dtype=model.head.bias.dtype, device=DEVICE)

criterion = nn.CrossEntropyLoss()     # ← no class weights now
opt = torch.optim.Adam(model.parameters(), lr=5e-4)  # slightly lower LR
from torch.amp import autocast
from torch.cuda.amp import GradScaler
scaler = GradScaler(enabled=(DEVICE.type=="cuda"))

# quick smoke test
xb, yb = next(iter(train_loader))
xb, yb = xb.to(DEVICE, non_blocking=True), yb.to(DEVICE, non_blocking=True)
with autocast("cuda", enabled=(DEVICE.type=="cuda")):
    logits = model(xb); loss = criterion(logits, yb)
scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
print("smoke batch OK | loss:", float(loss))


Using: cuda | NVIDIA GeForce RTX 3050 Ti Laptop GPU


  scaler = GradScaler(enabled=(DEVICE.type=="cuda"))


smoke batch OK | loss: 2.2841014862060547


In [15]:
import time
import torch.nn.functional as F

class EarlyStopper:
    def __init__(self, patience=6, min_delta=1e-3):
        self.patience=patience; self.min_delta=min_delta
        self.best=float("inf"); self.count=0
    def step(self, v):
        if self.best - v > self.min_delta:
            self.best=v; self.count=0; return False
        self.count+=1; return self.count>=self.patience

def run_epoch(loader, train=True):
    model.train() if train else model.eval()
    loss_sum=correct=n=0
    for xb, yb in loader:
        xb, yb = xb.to(DEVICE, non_blocking=True), yb.to(DEVICE, non_blocking=True)
        if train: opt.zero_grad(set_to_none=True)
        with torch.set_grad_enabled(train), autocast("cuda", enabled=(DEVICE.type=="cuda")):
            logits = model(xb)
            loss = criterion(logits, yb)
        if train:
            scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
        loss_sum += loss.item() * xb.size(0)
        correct  += (logits.argmax(1)==yb).sum().item()
        n += xb.size(0)
    return loss_sum/max(1,n), correct/max(1,n)

@torch.no_grad()
def eval_per_class(loader):
    model.eval()
    y_true=[]; y_pred=[]
    loss_sum=n=0
    for xb, yb in loader:
        xb, yb = xb.to(DEVICE), yb.to(DEVICE)
        with autocast("cuda", enabled=(DEVICE.type=="cuda")):
            logits = model(xb); loss = criterion(logits, yb)
        loss_sum += loss.item()*xb.size(0); n += xb.size(0)
        y_true.append(yb.cpu()); y_pred.append(logits.argmax(1).cpu())
    y_true = torch.cat(y_true); y_pred=torch.cat(y_pred)
    cm = torch.zeros(2,2, dtype=torch.int64)
    for t,p in zip(y_true, y_pred): cm[t,p]+=1
    def prf(c):
        tp=cm[c,c].item(); fp=cm[:,c].sum().item()-tp; fn=cm[c,:].sum().item()-tp
        prec = tp / max(1, tp+fp); rec = tp / max(1, tp+fn)
        f1 = 0.0 if prec+rec==0 else 2*prec*rec/(prec+rec)
        return prec, rec, f1, cm[c,:].sum().item()
    p0,r0,f0,s0 = prf(0); p1,r1,f1,s1 = prf(1)
    return {
        "loss": loss_sum/max(1,n),
        "acc": (y_true==y_pred).float().mean().item(),
        "cm": cm,
        "per": {0:{"precision":p0,"recall":r0,"f1":f0,"support":s0},
                1:{"precision":p1,"recall":r1,"f1":f1,"support":s1}}
    }

EPOCHS = 5
early = EarlyStopper(patience=6, min_delta=1e-3)
best_state=None

for ep in range(1, EPOCHS+1):
    t0=time.time()
    tr_loss,tr_acc = run_epoch(train_loader, True)
    va_loss,va_acc = run_epoch(val_loader,   False)
    print(f"Epoch {ep:02d}/{EPOCHS} - loss:{tr_loss:.4f} - acc:{tr_acc:.4f} "
          f"- val_loss:{va_loss:.4f} - val_acc:{va_acc:.4f} - {time.time()-t0:.1f}s")
    if best_state is None or va_loss < (early.best-1e-3):
        best_state = {k:v.detach().cpu().clone() for k,v in model.state_dict().items()}
    if early.step(va_loss):
        print(f"Epoch {ep}: early stopping (best val_loss={early.best:.4f})")
        break

if best_state: model.load_state_dict(best_state)

# Final per-class report
res = eval_per_class(val_loader)
cm = res["cm"].numpy()
print("\n=== Validation (per class) ===")
print(f"Overall: loss={res['loss']:.4f} | acc={res['acc']:.4f}")
print("\nConfusion Matrix (rows=true, cols=pred):")
print(f"            pred: healthy   pred: unhealthy")
print(f"true healthy     {cm[0,0]:>6}          {cm[0,1]:>6}")
print(f"true unhealthy   {cm[1,0]:>6}          {cm[1,1]:>6}")
for c,name in {0:"healthy",1:"unhealthy"}.items():
    m=res["per"][c]
    print(f"{name:<10} | precision={m['precision']:.3f} | recall={m['recall']:.3f} | f1={m['f1']:.3f} | support={m['support']}")


Epoch 01/5 - loss:0.8971 - acc:0.5359 - val_loss:0.6334 - val_acc:0.6867 - 363.0s
Epoch 02/5 - loss:0.6855 - acc:0.5583 - val_loss:0.6157 - val_acc:0.7229 - 314.3s
Epoch 03/5 - loss:0.6792 - acc:0.5478 - val_loss:0.7407 - val_acc:0.3976 - 278.8s
Epoch 04/5 - loss:0.6877 - acc:0.5603 - val_loss:0.6395 - val_acc:0.6747 - 309.7s
Epoch 05/5 - loss:0.6955 - acc:0.5465 - val_loss:0.6547 - val_acc:0.6446 - 332.3s

=== Validation (per class) ===
Overall: loss=0.6157 | acc=0.7229

Confusion Matrix (rows=true, cols=pred):
            pred: healthy   pred: unhealthy
true healthy        120               1
true unhealthy       45               0
healthy    | precision=0.727 | recall=0.992 | f1=0.839 | support=121
unhealthy  | precision=0.000 | recall=0.000 | f1=0.000 | support=45


### Balanced Class

In [33]:
# --- setup: CB-Focal loss + standard shuffled loader (no WeightedRandomSampler) ---
import math, time, numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torch.amp import autocast
from torch.cuda.amp import GradScaler

# 1) get training label counts (works with Images128WithLabels or any (x,y) dataset)
def get_label_counts(ds):
    if hasattr(ds, "y"):  # our Images128WithLabels exposes .y
        y = ds.y.cpu().tolist()
    else:
        y = [int(ds[i][1]) for i in range(len(ds))]
    n0 = sum(1 for t in y if t == 0); n1 = len(y) - n0
    return n0, n1

n0, n1 = get_label_counts(train_cls)
print(f"[labels] train counts → healthy(0)={n0} | unhealthy(1)={n1}")

# 2) Class-Balanced weights (effective number of samples, Cui et al.)
def class_balanced_weights(n_per_class, beta=0.99):
    w = []
    for n in n_per_class:
        n = max(1, n)
        w_c = (1 - beta) / (1 - beta**n)
        w.append(w_c)
    # normalize so mean weight ≈ 1 (helps keep loss scale stable)
    m = sum(w)/len(w)
    return torch.tensor([wc/m for wc in w], dtype=torch.float32, device=DEVICE)

cb_alpha = class_balanced_weights([n0, n1], beta=0.99)
print("[CB weights] alpha:", cb_alpha.tolist())

# 3) Focal Cross-Entropy (multiclass) with class-balanced alpha
def focal_ce_loss(logits, target, alpha=None, gamma=1.0):
    # CE per-sample
    ce = F.cross_entropy(logits, target, reduction="none")
    # p_t = prob of the true class
    pt = F.softmax(logits, dim=1).gather(1, target.view(-1,1)).squeeze(1).clamp_(1e-6, 1-1e-6)
    loss = ce * ((1 - pt) ** gamma)
    if alpha is not None:
        loss = loss * alpha[target]
    return loss.mean()

# 4) standard loaders (shuffle=True) — remove weighted sampler to avoid double-counting imbalance
BATCH = 32  # drop to 16/8 if needed
train_loader = DataLoader(train_cls, batch_size=BATCH, shuffle=True,  num_workers=0, pin_memory=True)
val_loader   = DataLoader(val_cls,   batch_size=BATCH, shuffle=False, num_workers=0, pin_memory=True)

# 5) optimizer + AMP scaler (keep your existing model / GAP head)
criterion = lambda logits, y: focal_ce_loss(logits, y, alpha=cb_alpha, gamma=1.0)
opt = torch.optim.Adam(model.parameters(), lr=5e-4)   # slightly lower LR for stability
scaler = GradScaler(enabled=(DEVICE.type == "cuda"))
torch.backends.cudnn.benchmark = True

# 6) train with early stopping by val loss
class EarlyStopper:
    def __init__(self, patience=6, min_delta=1e-3):
        self.patience = patience; self.min_delta = min_delta
        self.best = float("inf"); self.count = 0
    def step(self, v):
        if self.best - v > self.min_delta:
            self.best = v; self.count = 0; return False
        self.count += 1; return self.count >= self.patience

def run_epoch(loader, train=True):
    model.train() if train else model.eval()
    loss_sum = correct = n = 0
    for xb, yb in loader:
        xb = xb.to(DEVICE, non_blocking=True); yb = yb.to(DEVICE, non_blocking=True)
        if train: opt.zero_grad(set_to_none=True)
        with torch.set_grad_enabled(train), autocast("cuda", enabled=(DEVICE.type=="cuda")):
            logits = model(xb)
            loss = criterion(logits, yb)
        if train:
            scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
        loss_sum += loss.item() * xb.size(0)
        correct  += (logits.argmax(1) == yb).sum().item()
        n += xb.size(0)
    return loss_sum/max(1,n), correct/max(1,n)

EPOCHS = 6
early = EarlyStopper(patience=6, min_delta=1e-3)
best_state = None

for ep in range(1, EPOCHS+1):
    t0 = time.time()
    tr_loss, tr_acc = run_epoch(train_loader, True)
    va_loss, va_acc = run_epoch(val_loader,   False)
    print(f"Epoch {ep:02d}/{EPOCHS} - loss:{tr_loss:.4f} - acc:{tr_acc:.4f} - val_loss:{va_loss:.4f} - val_acc:{va_acc:.4f} - {time.time()-t0:.1f}s")
    if best_state is None or va_loss < (early.best - 1e-3):
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    if early.step(va_loss):
        print(f"Epoch {ep}: early stopping (best val_loss={early.best:.4f})")
        break

if best_state: 
    model.load_state_dict(best_state)
    print("Loaded best weights.")


[labels] train counts → healthy(0)=1470 | unhealthy(1)=47
[CB weights] alpha: [0.5470129251480103, 1.4529870748519897]


  scaler = GradScaler(enabled=(DEVICE.type == "cuda"))


KeyboardInterrupt: 

In [None]:
import torch

@torch.no_grad()
def collect_val_probs_and_labels(model, loader):
    model.eval()
    probs = []; labels = []
    for xb, yb in loader:
        xb = xb.to(DEVICE); yb = yb.to(DEVICE)
        with autocast("cuda", enabled=(DEVICE.type=="cuda")):
            logits = model(xb)
            p = F.softmax(logits, dim=1)[:, 1]  # P(class=1 = unhealthy)
        probs.append(p.cpu()); labels.append(yb.cpu())
    return torch.cat(probs), torch.cat(labels)

def metrics_from_preds(y_true, y_pred):
    cm = torch.zeros(2,2, dtype=torch.int64)
    for t,p in zip(y_true, y_pred): cm[t,p]+=1
    def prf(c):
        tp=cm[c,c].item(); fp=cm[:,c].sum().item()-tp; fn=cm[c,:].sum().item()-tp
        prec = tp / max(1, tp+fp); rec = tp / max(1, tp+fn)
        f1 = 0.0 if prec+rec==0 else 2*prec*rec/(prec+rec)
        return prec, rec, f1
    p0,r0,f0 = prf(0); p1,r1,f1 = prf(1)
    macro_f1 = 0.5*(f0+f1)
    acc = (y_true == y_pred).float().mean().item()
    return {"acc":acc, "macro_f1":macro_f1, "cm":cm, "per":{"healthy":(p0,r0,f0),"unhealthy":(p1,r1,f1)}}

# 1) collect probabilities for class=1 on the val set
probs, y_true = collect_val_probs_and_labels(model, val_loader)

# 2) sweep thresholds and pick the one that maximizes macro-F1 (balanced performance)
best = {"t":0.5, "macro_f1":-1}
for t in torch.linspace(0.1, 0.9, steps=17):  # 0.1 → 0.9 step 0.05
    y_pred = (probs >= t).long()
    m = metrics_from_preds(y_true, y_pred)
    if m["macro_f1"] > best["macro_f1"]:
        best = {"t": float(t), "macro_f1": m["macro_f1"], "metrics": m}

m = best["metrics"]; cm = m["cm"].numpy()
(p0,r0,f0) = m["per"]["healthy"]; (p1,r1,f1) = m["per"]["unhealthy"]
print(f"\nBest threshold t={best['t']:.2f} (by macro-F1={best['macro_f1']:.3f})")
print(f"Overall: acc={m['acc']:.3f} | macro-F1={m['macro_f1']:.3f}")
print("Confusion Matrix (rows=true, cols=pred):")
#print(f"            pred: healthy   pred: unhealthy")
#print(f"true healthy     {cm[0,0]:>6}          {cm[0,1]:>6}")
#print(f"true unhealthy   {cm[1,0]:>6}          {cm[1,1]:>6}")
print(f"healthy   | precision={p0:.3f} | recall={r0:.3f} | f1={f0:.3f}")
print(f"unhealthy | precision={p1:.3f} | recall={r1:.3f} | f1={f1:.3f}")



Best threshold t=0.30 (by macro-F1=0.499)
Overall: acc=0.518 | macro-F1=0.499
Confusion Matrix (rows=true, cols=pred):
            pred: healthy   pred: unhealthy
true healthy         59              62
true unhealthy       18              27
healthy   | precision=0.766 | recall=0.488 | f1=0.596
unhealthy | precision=0.303 | recall=0.600 | f1=0.403


In [31]:
# ========= Balanced decision rules for TEST =========
import torch
from torch.amp import autocast

@torch.no_grad()
def collect_probs_and_labels(model, loader, device):
    model.eval()
    probs = []; labels = []
    for xb, yb in loader:
        xb = xb.to(device); yb = yb.to(device)
        with autocast("cuda", enabled=(device.type=="cuda")):
            logits = model(xb)
            p = torch.softmax(logits, dim=1)[:, 1]  # P(unhealthy)
        probs.append(p.cpu()); labels.append(yb.cpu())
    return torch.cat(probs), torch.cat(labels)

def metrics_from_preds(y_true, y_pred):
    cm = torch.zeros(2,2, dtype=torch.int64)
    for t,p in zip(y_true, y_pred): cm[t,p]+=1
    def prf(c):
        tp=cm[c,c].item(); fp=cm[:,c].sum().item()-tp; fn=cm[c,:].sum().item()-tp
        prec = tp / max(1, tp+fp); rec = tp / max(1, tp+fn)
        f1 = 0.0 if prec+rec==0 else 2*prec*rec/(prec+rec)
        return prec, rec, f1
    p0,r0,f0 = prf(0); p1,r1,f1 = prf(1)
    macro_f1 = 0.5*(f0+f1)
    acc = (y_true == y_pred).float().mean().item()
    return {"acc":acc, "macro_f1":macro_f1, "cm":cm, "per":{"healthy":(p0,r0,f0),"unhealthy":(p1,r1,f1)}}

def report(name, probs, y_true, t):
    y_pred = (probs >= t).long()
    m = metrics_from_preds(y_true, y_pred)
    cm = m["cm"].numpy()
    (p0,r0,f0) = m["per"]["healthy"]; (p1,r1,f1) = m["per"]["unhealthy"]
    print(f"\n=== {name} ===")
    print(f"threshold t={t:.4f} | acc={m['acc']:.4f} | macro-F1={m['macro_f1']:.4f}")
    print("Confusion Matrix (rows=true, cols=pred):")
    print(f"            pred: healthy   pred: unhealthy")
    print(f"true healthy     {cm[0,0]:>6}          {cm[0,1]:>6}")
    print(f"true unhealthy   {cm[1,0]:>6}          {cm[1,1]:>6}")
    print(f"healthy   | precision={p0:.3f} | recall={r0:.3f} | f1={f0:.3f}")
    print(f"unhealthy | precision={p1:.3f} | recall={r1:.3f} | f1={f1:.3f}")
    return m

# 1) Collect VAL/TEST probs
probs_val, y_val_true   = collect_probs_and_labels(model, val_loader,  DEVICE)
probs_test, y_test_true = collect_probs_and_labels(model, test_loader, DEVICE)

# 2) Best threshold on VAL by macro-F1 (your current approach)
best_val = {"t": 0.5, "macro_f1": -1}
for t in torch.linspace(0.01, 0.99, steps=199):  # wider sweep; catches low-prob regimes
    m = metrics_from_preds(y_val_true, (probs_val >= t).long())
    if m["macro_f1"] > best_val["macro_f1"]:
        best_val = {"t": float(t), "macro_f1": m["macro_f1"], "metrics": m}
t_val = best_val["t"]
report("TEST @ val-tuned threshold", probs_test, y_test_true, t_val)

# 3) (Analysis) Best macro-F1 directly on TEST (diagnostic view)
best_test = {"t": 0.5, "macro_f1": -1}
for t in torch.linspace(0.0, 1.0, steps=401):
    m = metrics_from_preds(y_test_true, (probs_test >= t).long())
    if m["macro_f1"] > best_test["macro_f1"]:
        best_test = {"t": float(t), "macro_f1": m["macro_f1"], "metrics": m}
t_test_best = best_test["t"]
report("TEST @ best macro-F1 (diagnostic)", probs_test, y_test_true, t_test_best)

# 4) Prevalence-matched threshold on TEST: predict ~same # of "unhealthy" as ground truth
pos = int((y_test_true == 1).sum().item())
N   = len(y_test_true)
if 0 < pos < N:
    s, _ = torch.sort(probs_test, descending=True)
    # pick midpoint between the k-th and (k+1)-th largest scores to get exactly 'pos' positives
    k = pos
    if k < N:
        t_prev = float((s[k-1] + s[k]) / 2.0)  # midpoint
    else:
        t_prev = float(s[-1].item() - 1e-8)    # degenerate case
    report("TEST @ prevalence-matched threshold", probs_test, y_test_true, t_prev)
else:
    print("[warn] test set has 0 or N positives; skipping prevalence-matched threshold.")



=== TEST @ val-tuned threshold ===
threshold t=0.0100 | acc=0.9031 | macro-F1=0.4745
Confusion Matrix (rows=true, cols=pred):
            pred: healthy   pred: unhealthy
true healthy        354               0
true unhealthy       38               0
healthy   | precision=0.903 | recall=1.000 | f1=0.949
unhealthy | precision=0.000 | recall=0.000 | f1=0.000

=== TEST @ best macro-F1 (diagnostic) ===
threshold t=0.0025 | acc=0.9031 | macro-F1=0.4745
Confusion Matrix (rows=true, cols=pred):
            pred: healthy   pred: unhealthy
true healthy        354               0
true unhealthy       38               0
healthy   | precision=0.903 | recall=1.000 | f1=0.949
unhealthy | precision=0.000 | recall=0.000 | f1=0.000

=== TEST @ prevalence-matched threshold ===
threshold t=0.0012 | acc=0.6097 | macro-F1=0.5206
Confusion Matrix (rows=true, cols=pred):
            pred: healthy   pred: unhealthy
true healthy        204             150
true unhealthy        3              35
healthy   | pre

In [32]:
import torch
from torch.amp import autocast

@torch.no_grad()
def collect_probs_and_labels(model, loader, device):
    model.eval()
    probs = []; labels = []
    for xb, yb in loader:
        xb = xb.to(device); yb = yb.to(device)
        with autocast("cuda", enabled=(device.type=="cuda")):
            logits = model(xb)
            p = torch.softmax(logits, dim=1)[:, 1]  # P(unhealthy)
        probs.append(p.cpu()); labels.append(yb.cpu())
    return torch.cat(probs), torch.cat(labels)

def metrics_from_preds(y_true, y_pred):
    cm = torch.zeros(2,2, dtype=torch.int64)
    for t,p in zip(y_true, y_pred): cm[t,p]+=1
    def prf(c):
        tp=cm[c,c].item(); fp=cm[:,c].sum().item()-tp; fn=cm[c,:].sum().item()-tp
        prec = tp / max(1, tp+fp); rec = tp / max(1, tp+fn)
        f1 = 0.0 if prec+rec==0 else 2*prec*rec/(prec+rec)
        return prec, rec, f1
    p0,r0,f0 = prf(0); p1,r1,f1 = prf(1)
    macro_f1 = 0.5*(f0+f1)
    acc = (y_true == y_pred).float().mean().item()
    return {"acc":acc, "macro_f1":macro_f1, "cm":cm,
            "per":{"healthy":(p0,r0,f0),"unhealthy":(p1,r1,f1)}}

def report(name, probs, y_true, t):
    y_pred = (probs >= t).long()
    m = metrics_from_preds(y_true, y_pred)
    cm = m["cm"].numpy()
    (p0,r0,f0) = m["per"]["healthy"]; (p1,r1,f1) = m["per"]["unhealthy"]
    print(f"\n=== {name} ===")
    print(f"threshold t={t:.4f} | acc={m['acc']:.4f} | macro-F1={m['macro_f1']:.4f}")
    print("Confusion Matrix (rows=true, cols=pred):")
    print(f"            pred: healthy   pred: unhealthy")
    print(f"true healthy     {cm[0,0]:>6}          {cm[0,1]:>6}")
    print(f"true unhealthy   {cm[1,0]:>6}          {cm[1,1]:>6}")
    print(f"healthy   | precision={p0:.3f} | recall={r0:.3f} | f1={f0:.3f}")
    print(f"unhealthy | precision={p1:.3f} | recall={r1:.3f} | f1={f1:.3f}")
    return m

# 1) collect val/test probabilities
probs_val, y_val = collect_probs_and_labels(model, val_loader,  DEVICE)
probs_tst, y_tst = collect_probs_and_labels(model, test_loader, DEVICE)

# 2) choose threshold on VAL to minimize |recall_healthy - recall_unhealthy|,
#    tie-break by highest macro-F1, then by highest unhealthy recall.
best = None
grid = torch.linspace(0.0, 1.0, steps=1001)
for t in grid:
    m = metrics_from_preds(y_val, (probs_val >= t).long())
    _, r0, _ = m["per"]["healthy"]
    _, r1, _ = m["per"]["unhealthy"]
    gap = abs(r0 - r1)
    score = ( -gap, m["macro_f1"], r1 )  # lexicographic: smallest gap, then best macro-F1, then higher r1
    if (best is None) or (score > best["score"]):
        best = {"t": float(t), "score": score, "metrics": m}

t_eqrec = best["t"]
report("VAL @ equal-recall threshold", probs_val, y_val, t_eqrec)

# 3) apply that threshold on TEST
report("TEST @ equal-recall threshold (val-tuned)", probs_tst, y_tst, t_eqrec)



=== VAL @ equal-recall threshold ===
threshold t=0.0020 | acc=0.9217 | macro-F1=0.4796
Confusion Matrix (rows=true, cols=pred):
            pred: healthy   pred: unhealthy
true healthy        153               0
true unhealthy       13               0
healthy   | precision=0.922 | recall=1.000 | f1=0.959
unhealthy | precision=0.000 | recall=0.000 | f1=0.000

=== TEST @ equal-recall threshold (val-tuned) ===
threshold t=0.0020 | acc=0.9031 | macro-F1=0.4745
Confusion Matrix (rows=true, cols=pred):
            pred: healthy   pred: unhealthy
true healthy        354               0
true unhealthy       38               0
healthy   | precision=0.903 | recall=1.000 | f1=0.949
unhealthy | precision=0.000 | recall=0.000 | f1=0.000


{'acc': 0.9030612111091614,
 'macro_f1': 0.47453083109919575,
 'cm': tensor([[354,   0],
         [ 38,   0]]),
 'per': {'healthy': (0.9030612244897959, 1.0, 0.9490616621983915),
  'unhealthy': (0.0, 0.0, 0.0)}}

### Finale

In [25]:
# ============================================================
# Combine CoralScapes + CoralBleaching + Benthic for classification (128x128)
# Balanced training with Class-Balanced Focal Loss + threshold tuning
# ============================================================
import os, math, time
from pathlib import Path
from io import BytesIO
from typing import Sequence, Optional, Callable, Union, List, Tuple
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset, Subset
from torch.amp import autocast
from torch.cuda.amp import GradScaler
import pyarrow.parquet as pq
from datasets import load_dataset, Dataset as HFDataset

# ---------------- GPU setup (Windows/WDDM fragmentation guard) ----------------
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF",
    "backend:cudaMallocAsync,expandable_segments:True,max_split_size_mb:64,garbage_collection_threshold:0.8")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE, "|", torch.cuda.get_device_name(0) if DEVICE.type=="cuda" else "")
torch.backends.cudnn.benchmark = True

# ---------------- Common utilities ----------------
def pil_to_tensor_rgb(img: Image.Image) -> torch.Tensor:
    arr = np.asarray(img.convert("RGB"), dtype=np.uint8).copy()
    return torch.from_numpy(arr).permute(2,0,1).float()/255.0

def label_from_mask_pil(mask_pil: Image.Image) -> int:
    # 0=healthy, 1=unhealthy (bleached), by blue-vs-red energy
    arr = np.asarray(mask_pil.convert("RGB"), dtype=np.uint8)
    return 1 if int(arr[...,2].sum()) > int(arr[...,0].sum()) else 0

def label_from_mask_tensor(mask_t: torch.Tensor, min_ratio: float = 1.05, min_signal: float = 0.01):
    """
    For benthic (tensor masks). Returns 0/1 or None if ambiguous.
    - min_ratio: require clear dominance of blue vs red (or red vs blue)
    - min_signal: require enough red/blue energy to be meaningful
    """
    # mask_t is (3,H,W) in [0,1]
    red  = float(mask_t[0].sum().item())
    blue = float(mask_t[2].sum().item())
    tot  = float(mask_t.sum().item()) + 1e-8
    # discard if too little colored signal
    if (red + blue) / tot < min_signal:
        return None
    # dominance check
    if blue > red * min_ratio:  # blue dominates => unhealthy
        return 1
    if red > blue * min_ratio:  # red dominates => healthy
        return 0
    return None  # ambiguous

# ---------------- CoralScapes (HF images) + your Parquet masks → classification ----------------
TRAIN_PARQUETS = [
    r"C:\Users\sedya\VScodeProjects\Coral-reefs-DBL4\data_preprocessing\coralscapes_export\parquet\train\train_part001.parquet",
    r"C:\Users\sedya\VScodeProjects\Coral-reefs-DBL4\data_preprocessing\coralscapes_export\parquet\train\train_part002.parquet",
]
VAL_PARQUETS = [
    r"C:\Users\sedya\VScodeProjects\Coral-reefs-DBL4\data_preprocessing\coralscapes_export\parquet\validation\validation_part001.parquet",
]
MASK_COLUMN = "label_health_rgb_png"

class ParquetMasksByIndex:
    def __init__(self, parquet_paths: Sequence[Union[str, Path]], col_png: str = MASK_COLUMN):
        self._tables = [pq.read_table(Path(p)) for p in parquet_paths]
        for t in self._tables:
            if "index" not in t.column_names or col_png not in t.column_names:
                raise ValueError(f"Parquet must have 'index' and '{col_png}'. Got: {t.column_names}")
        self._col = col_png
        self._map = {}
        for tid, t in enumerate(self._tables):
            for rid, ds_idx in enumerate(t["index"].to_pylist()):
                self._map[int(ds_idx)] = (tid, rid)
        print(f"[CS masks] mapped {len(self._map)} indices from {len(self._tables)} file(s)")

    def get_mask_pil(self, ds_index: int) -> Image.Image:
        tid, rid = self._map[ds_index]
        cell = self._tables[tid][self._col][rid].as_py()
        if isinstance(cell, memoryview): cell = cell.tobytes()
        elif isinstance(cell, bytearray): cell = bytes(cell)
        return Image.open(BytesIO(cell)).convert("RGB")

def build_coralscapes_classification(size: int = 128):
    hf = load_dataset("EPFL-ECEO/coralscapes")
    hf_train, hf_val = hf["train"], hf["validation"]
    masks_train = ParquetMasksByIndex(TRAIN_PARQUETS, MASK_COLUMN)
    masks_val   = ParquetMasksByIndex(VAL_PARQUETS,   MASK_COLUMN)

    idx_train = [i for i in range(len(hf_train)) if i in masks_train._map]
    idx_val   = [i for i in range(len(hf_val))   if i in masks_val._map]
    y_train   = [label_from_mask_pil(masks_train.get_mask_pil(i)) for i in idx_train]
    y_val     = [label_from_mask_pil(masks_val.get_mask_pil(i))   for i in idx_val]

    class Images128WithLabels(Dataset):
        def __init__(self, hf_ds: HFDataset, indices: List[int], labels: List[int], size=128):
            assert len(indices)==len(labels)
            self.hf, self.idx = hf_ds, list(indices)
            self.y = torch.tensor(labels, dtype=torch.long)
            self.size = size
        def __len__(self): return len(self.idx)
        def __getitem__(self, j):
            i = self.idx[j]
            img = self.hf[i]["image"]
            x = pil_to_tensor_rgb(img)
            if x.shape[-2:] != (self.size, self.size):
                x = F.interpolate(x.unsqueeze(0), size=(self.size,self.size), mode="bilinear", align_corners=False).squeeze(0)
            return x, self.y[j]

    ds_tr = Images128WithLabels(hf_train, idx_train, y_train, size)
    ds_va = Images128WithLabels(hf_val,   idx_val,   y_val,   size)
    return ds_tr, ds_va

# ---------------- Coral Bleaching local → classification ----------------
BLEACH_IMAGES  = r"g:\.shortcut-targets-by-id\1jGkNA1n0znoxKnQBHTJZuPgvkiu_OBM8\coral_bleaching\reef_support\UNAL_BLEACHING_TAYRONA\images"
BLEACH_COMBINED = r"data_preprocessing/coralbleaching/combined_masks"
BLEACH_SINGLE   = r"data_preprocessing/coralbleaching/single_masks"

class CoralBleachingPairs(Dataset):
    def __init__(self, images_dir: Union[str, Path], combined_dir: Union[str, Path], single_dir: Union[str, Path]):
        self.images_dir = Path(images_dir)
        self.combined_dir = Path(combined_dir)
        self.single_bleached = Path(single_dir) / "bleached_blue"
        self.single_non = Path(single_dir) / "non_bleached_red"
        imgs = []
        for e in ("*.png","*.jpg","*.jpeg","*.bmp","*.tif","*.tiff"):
            imgs += list(self.images_dir.glob(e))
        self.images = sorted(imgs)
        self.pairs = self._match_pairs()

    def _index_dir(self, d: Path) -> dict:
        out = {}
        for e in ("*.png","*.jpg","*.jpeg","*.bmp","*.tif","*.tiff"):
            for p in d.glob(e): out[p.stem.lower()] = p
        return out

    def _match_pairs(self) -> List[Tuple[Path, Path]]:
        cmb = self._index_dir(self.combined_dir)
        ble = self._index_dir(self.single_bleached)
        non = self._index_dir(self.single_non)
        pairs=[]
        for img in self.images:
            key = img.stem.lower()
            k_cmb = f"{key}_combined"
            if k_cmb in cmb:
                pairs.append((img, cmb[k_cmb])); continue
            cand = [p for k,p in ble.items() if k.startswith(key) or key in k]
            if cand: pairs.append((img, cand[0])); continue
            cand = [p for k,p in non.items() if k.startswith(key) or key in k]
            if cand: pairs.append((img, cand[0]))
        return pairs

    def __len__(self): return len(self.pairs)
    def __getitem__(self, i):
        ip, mp = self.pairs[i]
        return Image.open(ip).convert("RGB"), Image.open(mp).convert("RGB")

class PairToImages128WithLabels(Dataset):
    def __init__(self, base_pairs: Dataset, size=128):
        self.base = base_pairs
        self.size = size
    def __len__(self): return len(self.base)
    def __getitem__(self, idx):
        img_pil, mask_pil = self.base[idx]
        y = label_from_mask_pil(mask_pil)
        x = pil_to_tensor_rgb(img_pil)
        if x.shape[-2:] != (self.size, self.size):
            x = F.interpolate(x.unsqueeze(0), size=(self.size,self.size), mode="bilinear", align_corners=False).squeeze(0)
        return x, torch.tensor(y, dtype=torch.long)

def build_bleaching_classification(size: int = 128):
    pairs = CoralBleachingPairs(BLEACH_IMAGES, BLEACH_COMBINED, BLEACH_SINGLE)
    return PairToImages128WithLabels(pairs, size)

# ---------------- Benthic (your SegmentationDataset -> (img, mask) tensors) → classification ----------------
# We’ll collect whichever of your wrapped datasets exist in the namespace.
def collect_benthic_wrapped():
    names = [
        "BOLIVAR_t","COURTOWN_t","PAC_USA_t","IDN_PHL_t","PAC_AUS_t","TETES_t","ATL_t","TAYRONA_t"
    ]
    found = []
    for n in names:
        if n in globals():
            found.append(globals()[n])
    print(f"[benthic] using {len(found)} wrapped datasets: {[n for n in names if n in globals()]}")
    return found

class BenthicToImages128WithLabels(Dataset):
    def __init__(self, base_ds: Dataset, size=128, min_ratio=1.05, min_signal=0.01):
        self.base = base_ds
        self.size = size
        self.keep_idx = []
        self.labels = []
        # pre-scan to keep only confidently labeled samples
        for i in range(len(self.base)):
            img, mask = self.base[i]  # tensors (3,H,W)
            y = label_from_mask_tensor(mask, min_ratio=min_ratio, min_signal=min_signal)
            if y is not None:
                self.keep_idx.append(i)
                self.labels.append(int(y))
        print(f"[benthic filter] kept {len(self.keep_idx)}/{len(self.base)} (min_ratio={min_ratio}, min_signal={min_signal})")

    def __len__(self): return len(self.keep_idx)
    def __getitem__(self, j):
        i = self.keep_idx[j]
        img, mask = self.base[i]
        # resize image to 128
        if img.shape[-2:] != (self.size, self.size):
            img = F.interpolate(img.unsqueeze(0), size=(self.size,self.size), mode="bilinear", align_corners=False).squeeze(0)
        return img, torch.tensor(self.labels[j], dtype=torch.long)

def build_benthic_classification(size: int = 128):
    ds_list = collect_benthic_wrapped()
    out = []
    for ds in ds_list:
        out.append(BenthicToImages128WithLabels(ds, size=size, min_ratio=1.05, min_signal=0.01))
    if not out:
        return None, None
    # Concat all benthic; split 85/15
    benthic_all = ConcatDataset(out)
    n = len(benthic_all)
    n_val = max(1, int(0.15*n))
    idx = torch.randperm(n).tolist()
    val_ids = set(idx[:n_val])
    benthic_val = Subset(benthic_all, list(val_ids))
    benthic_train = Subset(benthic_all, [i for i in range(n) if i not in val_ids])
    return benthic_train, benthic_val

# ---------------- Build datasets & combine ----------------
cs_train, cs_val = build_coralscapes_classification(size=128)
bl_all = build_bleaching_classification(size=128)
# split bleaching 85/15
n_bl = len(bl_all); n_bl_val = max(1, int(0.15*n_bl))
perm = torch.randperm(n_bl).tolist()
bl_val = Subset(bl_all, perm[:n_bl_val])
bl_train = Subset(bl_all, perm[n_bl_val:])

bt_train, bt_val = build_benthic_classification(size=128)

# combine available pieces
train_parts = [cs_train, bl_train] + ([bt_train] if bt_train is not None else [])
val_parts   = [cs_val,   bl_val]   + ([bt_val]   if bt_val   is not None else [])
train_cls = ConcatDataset(train_parts)
val_cls   = ConcatDataset(val_parts)

print(f"[combined] train={len(train_cls)} | val={len(val_cls)}")

# ---------------- Balanced training: Class-Balanced Focal Loss (no sampler) ----------------
def get_label_counts(ds: Dataset):
    n0=n1=0
    for i in range(len(ds)):
        y = ds[i][1]
        y = int(y) if isinstance(y, torch.Tensor) else int(y)
        if y==0: n0+=1
        else: n1+=1
    return n0, n1

n0, n1 = get_label_counts(train_cls)
print(f"[labels] train counts → healthy(0)={n0} | unhealthy(1)={n1}")

def class_balanced_weights(n_per_class, beta=0.99):
    w=[]
    for n in n_per_class:
        n=max(1,n); w_c=(1-beta)/(1-beta**n); w.append(w_c)
    m=sum(w)/len(w)
    return torch.tensor([wc/m for wc in w], dtype=torch.float32, device=DEVICE)

cb_alpha = class_balanced_weights([n0, n1], beta=0.99)
print("[CB weights] alpha:", cb_alpha.tolist())

def focal_ce_loss(logits, target, alpha=None, gamma=1.0):
    ce = F.cross_entropy(logits, target, reduction="none")
    pt = F.softmax(logits, dim=1).gather(1, target.view(-1,1)).squeeze(1).clamp_(1e-6, 1-1e-6)
    loss = ce * ((1 - pt) ** gamma)
    if alpha is not None:
        loss = loss * alpha[target]
    return loss.mean()

BATCH = 8  # 16/8 if VRAM is tight
train_loader = DataLoader(train_cls, batch_size=BATCH, shuffle=True,  num_workers=0, pin_memory=True)
val_loader   = DataLoader(val_cls,   batch_size=BATCH, shuffle=False, num_workers=0, pin_memory=True)

# ---------------- Keras-like CNN (conv stack + GAP head) ----------------
class KerasLikeCNN_GAP(nn.Module):
    def __init__(self, p_drop=0.3):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=0), nn.ReLU(True),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=0), nn.ReLU(True),
            nn.Conv2d(64, 32, 1), nn.ReLU(True),
            nn.Conv2d(32, 64, 1), nn.ReLU(True),
            nn.Conv2d(64,128, 3, padding=0), nn.ReLU(True),
            nn.Dropout(p_drop),
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Linear(128, 2)
    def forward(self, x):
        return self.head(self.features(x).flatten(1))

model = KerasLikeCNN_GAP().to(DEVICE)

# bias warm-start to prior
N = n0+n1
p1 = (n1 + 1e-6) / (N + 2e-6)
prior_logit = math.log(p1/(1-p1))
with torch.no_grad():
    model.head.bias[:] = torch.tensor([-prior_logit, prior_logit], device=DEVICE, dtype=model.head.bias.dtype)

opt = torch.optim.Adam(model.parameters(), lr=5e-4)
scaler = GradScaler(enabled=(DEVICE.type=="cuda"))
criterion = lambda logits, y: focal_ce_loss(logits, y, alpha=cb_alpha, gamma=1.0)

# ---------------- Train with early stopping ----------------
class EarlyStopper:
    def __init__(self, patience=6, min_delta=1e-3):
        self.patience=patience; self.min_delta=min_delta
        self.best=float("inf"); self.count=0
    def step(self, v):
        if self.best - v > self.min_delta:
            self.best=v; self.count=0; return False
        self.count+=1; return self.count>=self.patience

def run_epoch(loader, train=True):
    model.train() if train else model.eval()
    loss_sum=correct=n=0
    for xb, yb in loader:
        xb, yb = xb.to(DEVICE, non_blocking=True), yb.to(DEVICE, non_blocking=True)
        if train: opt.zero_grad(set_to_none=True)
        with torch.set_grad_enabled(train), autocast("cuda", enabled=(DEVICE.type=="cuda")):
            logits = model(xb)
            loss = criterion(logits, yb)
        if train:
            scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
        loss_sum += loss.item()*xb.size(0)
        correct  += (logits.argmax(1)==yb).sum().item()
        n += xb.size(0)
    return loss_sum/max(1,n), correct/max(1,n)

EPOCHS = 2
early = EarlyStopper(patience=6, min_delta=1e-3)
best_state=None
for ep in range(1, EPOCHS+1):
    t0=time.time()
    tr_loss,tr_acc = run_epoch(train_loader, True)
    va_loss,va_acc = run_epoch(val_loader,   False)
    print(f"Epoch {ep:02d}/{EPOCHS} - loss:{tr_loss:.4f} - acc:{tr_acc:.4f} - val_loss:{va_loss:.4f} - val_acc:{va_acc:.4f} - {time.time()-t0:.1f}s")
    if best_state is None or va_loss < (early.best-1e-3):
        best_state = {k:v.detach().cpu().clone() for k,v in model.state_dict().items()}
    if early.step(va_loss):
        print(f"Epoch {ep}: early stopping (best val_loss={early.best:.4f})")
        break
if best_state:
    model.load_state_dict(best_state)
    print("Loaded best weights.")

# ---------------- Threshold tuning (maximize macro-F1) ----------------
@torch.no_grad()
def collect_val_probs_and_labels(model, loader):
    model.eval()
    probs=[]; labels=[]
    for xb, yb in loader:
        xb = xb.to(DEVICE)
        with autocast("cuda", enabled=(DEVICE.type=="cuda")):
            p1 = F.softmax(model(xb), dim=1)[:,1]
        probs.append(p1.cpu()); labels.append(yb.cpu())
    return torch.cat(probs), torch.cat(labels)

def metrics_from_preds(y_true, y_pred):
    cm = torch.zeros(2,2, dtype=torch.int64)
    for t,p in zip(y_true, y_pred): cm[t,p]+=1
    def prf(c):
        tp=cm[c,c].item(); fp=cm[:,c].sum().item()-tp; fn=cm[c,:].sum().item()-tp
        prec = tp / max(1,tp+fp); rec = tp / max(1,tp+fn)
        f1 = 0.0 if prec+rec==0 else 2*prec*rec/(prec+rec)
        return prec, rec, f1
    p0,r0,f0 = prf(0); p1,r1,f1 = prf(1)
    macro = 0.5*(f0+f1)
    acc = (y_true==y_pred).float().mean().item()
    return {"acc":acc,"macro_f1":macro,"cm":cm,"per":{"healthy":(p0,r0,f0),"unhealthy":(p1,r1,f1)}}

probs, y_true = collect_val_probs_and_labels(model, val_loader)
best = {"t":0.5,"macro_f1":-1}
for t in torch.linspace(0.1, 0.9, steps=17):
    y_pred = (probs >= t).long()
    m = metrics_from_preds(y_true, y_pred)
    if m["macro_f1"] > best["macro_f1"]:
        best = {"t": float(t), "macro_f1": m["macro_f1"], "metrics": m}

m = best["metrics"]; cm = m["cm"].numpy()
(p0,r0,f0) = m["per"]["healthy"]; (p1_,r1_,f1_) = m["per"]["unhealthy"]
print(f"\nBest threshold t={best['t']:.2f} (macro-F1={best['macro_f1']:.3f})")
print(f"Overall: acc={m['acc']:.3f} | macro-F1={m['macro_f1']:.3f}")
print("Confusion Matrix (rows=true, cols=pred):")
print(f"            pred: healthy   pred: unhealthy")
print(f"true healthy     {cm[0,0]:>6}          {cm[0,1]:>6}")
print(f"true unhealthy   {cm[1,0]:>6}          {cm[1,1]:>6}")
print(f"healthy   | precision={p0:.3f} | recall={r0:.3f} | f1={f0:.3f}")
print(f"unhealthy | precision={p1_:.3f} | recall={r1_:.3f} | f1={f1_:.3f}")


Using device: cuda | NVIDIA GeForce RTX 3050 Ti Laptop GPU
[CS masks] mapped 1517 indices from 2 file(s)
[CS masks] mapped 166 indices from 1 file(s)
[benthic] using 0 wrapped datasets: []
[combined] train=1517 | val=166
[labels] train counts → healthy(0)=1470 | unhealthy(1)=47
[CB weights] alpha: [0.5470129251480103, 1.4529870748519897]


  scaler = GradScaler(enabled=(DEVICE.type=="cuda"))


KeyboardInterrupt: 