In [1]:
# ===========================
# Install & environment
# ===========================
!pip -q install monai nibabel scikit-learn timm einops tqdm

import os
os.environ["PYTORCH_DISABLE_DYNAMO"] = "1"
os.environ["TORCHDYNAMO_DISABLE"] = "1"

# ===========================
# Imports
# ===========================
import torch, numpy as np, random, math, warnings, json
from pathlib import Path
from google.colab import drive
drive.mount('/content/drive')

from monai.data import Dataset
from monai import transforms as T
from monai.config import print_config

import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from collections import defaultdict, Counter
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, classification_report, confusion_matrix

warnings.filterwarnings("ignore")

# ===========================
# Repro, device, paths
# ===========================
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DRIVE = Path("/content/drive/MyDrive")
DATA_ROOT = DRIVE / "OASIS3_Preprocessed_Data"   # expects AD/ and CN/
SAVE_DIR  = DRIVE / "OASIS3_Preprocessed_Data"
SAVE_DIR.mkdir(parents=True, exist_ok=True)

print("Torch:", torch.__version__, "| Device:", device)
print_config()
print("Data root:", DATA_ROOT)

# ====================
# CONFIG (tune here)
# ====================
# Phase-1 (baseline) — stays accuracy-friendly but still robust
DROPOUT_P1      = 0.00
LR_PHASE1       = 5e-5
EPOCHS_PHASE1   = 80
PATIENCE_P1     = 12
USE_CLASS_WEIGHT_P1 = True      # pos_weight for AD to help balance
MIXUP_ALPHA_P1  = 0.20          # mild mixup in P1

# Phase-2 (fine-tune) — *gentler* to recover overall accuracy while keeping BA
DROPOUT_P2      = 0.10
LR_PHASE2       = 5e-6          # smaller LR for stability
EPOCHS_PHASE2   = 20
PATIENCE_P2     = 8
USE_CLASS_WEIGHT_P2 = False     # turn OFF class weight in FT (improves accuracy)
MIXUP_ALPHA_P2  = 0.00          # turn OFF mixup in FT

# Optional: freeze backbone in fine-tune (only train classifier head)
FREEZE_BACKBONE_IN_FT = False   # set True to try even stabler FT (often +acc)

# Loss asymmetry (for Phase-1 when USE_CLASS_WEIGHT=True)
POS_WEIGHT_POWER = 1.25         # pos_weight = (n_cn/n_ad)**power

# Batching / data
BATCH_SIZE   = 4
NUM_WORKERS  = 2
PIN_MEMORY   = True
GRAD_ACCUM_TARGET = 8

# Split
TEST_FRACTION   = 0.15
VAL_FRACTION_OF_TRAINVAL = 0.20

# ===========================
# Data discovery
# ===========================
CLASS_TO_IDX = {"AD": 0, "CN": 1}
IDX_TO_CLASS = {v:k for k,v in CLASS_TO_IDX.items()}

def list_nii(d: Path):
    return sorted([p for p in d.rglob("*") if p.is_file() and p.name.lower().endswith((".nii",".nii.gz"))])

ad_files = list_nii(DATA_ROOT / "AD")
cn_files = list_nii(DATA_ROOT / "CN")
items = [{"image": str(p), "label": CLASS_TO_IDX["AD"]} for p in ad_files] + \
        [{"image": str(p), "label": CLASS_TO_IDX["CN"]} for p in cn_files]
assert len(items) > 0, "No NIfTI files found. Check /MyDrive/OASIS3_Preprocessed_Data/AD and /CN."

labels_all = np.array([it["label"] for it in items])
print(f"Found | AD={len(ad_files)} | CN={len(cn_files)} | Total={len(items)}")
u, c = np.unique(labels_all, return_counts=True)
print("Class counts:", {IDX_TO_CLASS[int(ui)]: int(ci) for ui,ci in zip(u,c)})

# ==================
# Transforms
# ==================
TARGET_SPACING = (1.0, 1.0, 1.0)
TARGET_SIZE    = (128, 160, 160)

train_tf = T.Compose([
    T.LoadImaged(keys=["image"]),
    T.EnsureChannelFirstd(keys=["image"]),
    T.Orientationd(keys=["image"], axcodes="RAS"),
    T.Spacingd(keys=["image"], pixdim=TARGET_SPACING, mode=("bilinear")),
    T.ScaleIntensityRangePercentilesd(keys=["image"], lower=0.5, upper=99.5, b_min=0.0, b_max=1.0, clip=True),
    T.CropForegroundd(keys=["image"], source_key="image"),
    T.ResizeWithPadOrCropd(keys=["image"], spatial_size=TARGET_SIZE),
    # mild augs (safe globally)
    T.RandFlipd(keys=["image"], prob=0.2, spatial_axis=[0]),
    T.RandFlipd(keys=["image"], prob=0.2, spatial_axis=[1]),
    T.RandFlipd(keys=["image"], prob=0.2, spatial_axis=[2]),
    T.RandAffined(keys=["image"], prob=0.10,
                  rotate_range=(np.deg2rad(8),)*3,
                  scale_range=(0.08,0.08,0.08),
                  translate_range=(4,4,4),
                  mode=("bilinear"), padding_mode="border"),
    T.EnsureTyped(keys=["image"], dtype=torch.float32),
    T.ToTensord(keys=["image"]),
])

val_tf = T.Compose([
    T.LoadImaged(keys=["image"]),
    T.EnsureChannelFirstd(keys=["image"]),
    T.Orientationd(keys=["image"], axcodes="RAS"),
    T.Spacingd(keys=["image"], pixdim=TARGET_SPACING, mode=("bilinear")),
    T.ScaleIntensityRangePercentilesd(keys=["image"], lower=0.5, upper=99.5, b_min=0.0, b_max=1.0, clip=True),
    T.CropForegroundd(keys=["image"], source_key="image"),
    T.ResizeWithPadOrCropd(keys=["image"], spatial_size=TARGET_SIZE),
    T.EnsureTyped(keys=["image"], dtype=torch.float32),
    T.ToTensord(keys=["image"]),
])

from torch.utils.data import Sampler, DataLoader

class BalancedBinaryBatchSampler(Sampler):
    """Balanced mini-batches over LOCAL indices: 50% AD (0), 50% CN (1)."""
    def __init__(self, indices_local, labels_local, batch_size):
        assert batch_size % 2 == 0
        self.batch_size = batch_size
        self.half = batch_size // 2
        idx_by_class = defaultdict(list)
        for i in indices_local:
            idx_by_class[int(labels_local[i])].append(i)
        self.idx0 = idx_by_class[0][:]
        self.idx1 = idx_by_class[1][:]
        if len(self.idx0)==0 or len(self.idx1)==0:
            raise ValueError("Both classes needed in the TRAIN split.")
        self.num_batches = math.ceil(max(len(self.idx0), len(self.idx1))*2 / self.batch_size)

    def __len__(self): return self.num_batches
    def __iter__(self):
        for _ in range(self.num_batches):
            b0 = [random.choice(self.idx0) for _ in range(self.half)]
            b1 = [random.choice(self.idx1) for _ in range(self.half)]
            batch = b0 + b1
            random.shuffle(batch)
            yield batch

def build_balanced_loaders_from_indices(idx_train_global, idx_val_global, batch_size=BATCH_SIZE):
    train_items_loc = [items[i] for i in idx_train_global]
    val_items_loc   = [items[i] for i in idx_val_global]
    train_ds_loc = Dataset(train_items_loc, transform=train_tf)
    val_ds_loc   = Dataset(val_items_loc,   transform=val_tf)

    train_local_indices = np.arange(len(train_items_loc))
    train_labels_local  = np.array([it["label"] for it in train_items_loc])
    train_sampler = BalancedBinaryBatchSampler(train_local_indices, train_labels_local, batch_size=batch_size)

    train_loader_loc = DataLoader(train_ds_loc, batch_sampler=train_sampler,
                                  num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, persistent_workers=False)
    val_loader_loc = DataLoader(val_ds_loc, batch_size=batch_size, shuffle=False,
                                num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, persistent_workers=False)
    return train_loader_loc, val_loader_loc, train_items_loc, val_items_loc

def build_plain_loader_from_indices(idx_global, batch_size=BATCH_SIZE):
    ds = Dataset([items[i] for i in idx_global], transform=val_tf)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, persistent_workers=False)
    return loader, ds

# ==========================================
# Model (DenseNet with configurable dropout)
# ==========================================
from monai.networks.nets import DenseNet

class DenseNet3DHead(nn.Module):
    def __init__(self, out_ch=2, dropout_prob=0.0):
        super().__init__()
        self.backbone = DenseNet(
            spatial_dims=3, in_channels=1, out_channels=out_ch,
            init_features=64, growth_rate=32, block_config=(6,12,24,16),
            bn_size=4, act=("relu", {"inplace": False}),
            dropout_prob=dropout_prob
        )
    def forward(self, x): return self.backbone(x)

def build_binary_model(dropout_prob=0.0):
    model = DenseNet3DHead(out_ch=2, dropout_prob=dropout_prob).to(device)
    for m in model.modules():
        if isinstance(m, nn.ReLU): m.inplace = False
    return model

def unfreeze_all(model):
    for p in model.parameters(): p.requires_grad_(True)

def freeze_backbone_keep_classifier(model):
    """Freeze everything except the classifier head (MONAI DenseNet .backbone.class_layers.*)."""
    for n,p in model.named_parameters():
        p.requires_grad_(False)
    for n,p in model.named_parameters():
        if "backbone.class_layers" in n or "classifier" in n:
            p.requires_grad_(True)
    # sanity print
    n_tr = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_all = sum(p.numel() for p in model.parameters())
    print(f"FT trainable params: {n_tr}/{n_all} ({100*n_tr/n_all:.2f}%)")

class ModelEMA:
    def __init__(self, model, decay=0.999):
        import copy
        self.ema = copy.deepcopy(model).to(device)
        for p in self.ema.parameters(): p.requires_grad_(False)
        self.decay = decay
    @torch.no_grad()
    def update(self, model):
        d=self.decay; msd=model.state_dict()
        for k,v in self.ema.state_dict().items():
            if v.dtype.is_floating_point: v.copy_(v*d + msd[k]*(1-d))

def make_optim_sched(model, epochs=80, base_lr=5e-5, weight_decay=1e-5, warmup=5):
    opt = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                      lr=base_lr, weight_decay=weight_decay)
    def lr_lambda(e):
        if e < warmup: return (e+1)/max(1,warmup)
        prog = (e-warmup)/max(1,(epochs-warmup))
        return 0.5*(1+math.cos(math.pi*min(prog,1.0)))
    sch = optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)
    return opt, sch

# MixUp (optional)
def mixup(x, y_ad, alpha=0.2):
    if alpha <= 0: return x, y_ad
    lam = np.random.beta(alpha, alpha)
    perm = torch.randperm(x.size(0), device=x.device)
    x_mix = lam * x + (1-lam) * x[perm]
    y_mix = lam * y_ad + (1-lam) * y_ad[perm]
    return x_mix, y_mix

# ----- TTA & single-logit helpers -----
def tta_logits(x, model_eval):
    outs=[]
    with torch.no_grad(), torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
        outs.append(model_eval(x))
        for dims in [(2,),(3,),(4,)]:
            outs.append(model_eval(torch.flip(x,dims=dims)))
        for a,b in [(2,3),(2,4),(3,4)]:
            for k in (1,2):
                outs.append(model_eval(torch.rot90(x,k=k,dims=(a,b))))
    return torch.stack(outs,0).mean(0)  # [N,2]

def two_to_single_logit(logits2):  # AD vs CN margin
    return logits2[:,0] - logits2[:,1]

@torch.no_grad()
def debug_val_bias_single(model_eval, loader):
    model_eval.eval()
    s=0.0; n=0
    for batch in loader:
        x = batch["image"].to(device, dtype=torch.float32)
        p_ad = torch.sigmoid(two_to_single_logit(tta_logits(x, model_eval)))
        s += p_ad.sum().item(); n += p_ad.numel()
    avg_ad = s/max(1,n)
    print(f"Val avg P(AD) ~= {avg_ad:.4f} | P(CN) ~= {1-avg_ad:.4f}")

@torch.no_grad()
def evaluate_single(model_eval, loader, sweep_thresholds=True, print_quantiles=True):
    model_eval.eval()
    p_ad_all, y_all = [], []
    for batch in loader:
        x = batch["image"].to(device, dtype=torch.float32, non_blocking=True)
        y = batch["label"].long().to(device, non_blocking=True)  # 0=AD, 1=CN
        p_ad = torch.sigmoid(two_to_single_logit(tta_logits(x, model_eval)))
        p_ad_all.append(p_ad.cpu().numpy()); y_all.append(y.cpu().numpy())
    p_ad = np.concatenate(p_ad_all); y = np.concatenate(y_all)

    preds = np.where(p_ad>=0.5, 0, 1)
    acc = accuracy_score(y, preds)
    f1  = f1_score(y, preds)
    auc = roc_auc_score(y, 1.0 - p_ad)

    print("\n=== Eval (BCE single-logit, th=0.50) ===")
    print(f"Acc:{acc:.4f} | F1(CN+):{f1:.4f} | AUC(CN prob):{auc:.4f}")
    print("Confusion:\n", confusion_matrix(y, preds))
    print(classification_report(y, preds, target_names=["AD","CN"]))
    if print_quantiles:
        qs = np.quantile(p_ad, [0.05,0.25,0.5,0.75,0.95]); print("P(AD) q=[5,25,50,75,95]:", np.round(qs,3))

    if sweep_thresholds:
        best, best_t = -1, 0.5
        for t in np.linspace(0.05,0.95,37):
            preds_t = np.where(p_ad>=t, 0, 1)
            f1_t = f1_score(y, preds_t)
            if f1_t > best: best, best_t = f1_t, t
        preds_best = np.where(p_ad>=best_t, 0, 1)
        print(f"--- Threshold sweep on P(AD) ---\nBest F1(CN+)={best:.4f} at th={best_t:.2f}")
        print(confusion_matrix(y, preds_best))
    return acc, f1, auc

def compute_balanced_accuracy(cm):
    tp_ad = cm[0,0]; fn_ad = cm[0,1]
    tn_ad = cm[1,1]; fp_ad = cm[1,0]
    tpr = tp_ad / max(1,(tp_ad+fn_ad))
    tnr = tn_ad / max(1,(tn_ad+fp_ad))
    return 0.5*(tpr+tnr)

@torch.no_grad()
def evaluate_with_threshold(model_eval, loader, threshold: float, title="FINAL"):
    """Evaluate at a chosen P(AD) threshold; prints Confusion + Balanced Accuracy."""
    model_eval.eval()
    p_ad_all, y_all = [], []
    for batch in loader:
        x = batch["image"].to(device, dtype=torch.float32, non_blocking=True)
        y = batch["label"].long().to(device, non_blocking=True)  # 0=AD, 1=CN
        p_ad = torch.sigmoid(two_to_single_logit(tta_logits(x, model_eval)))
        p_ad_all.append(p_ad.cpu().numpy()); y_all.append(y.cpu().numpy())
    p_ad = np.concatenate(p_ad_all); y = np.concatenate(y_all)

    preds = np.where(p_ad>=threshold, 0, 1)
    acc = accuracy_score(y, preds)
    cm  = confusion_matrix(y, preds, labels=[0,1])
    balacc = compute_balanced_accuracy(cm)

    print(f"\n=== {title} @ th={threshold:.2f} ===")
    print(f"Accuracy: {acc:.4f} | Balanced Acc: {balacc:.4f}")
    print("Confusion (rows=GT [AD,CN], cols=Pred [AD,CN]):\n", cm)
    print(classification_report(y, preds, target_names=["AD","CN"]))
    return acc, balacc, cm

# Threshold search: best Overall Accuracy and best Balanced Accuracy
def sweep_thresholds_all(p_ad, y, grid=np.linspace(0.05, 0.95, 37)):
    best = {
        "acc":(-1,None), "balacc":(-1,None), "f1_ad":(-1,None), "f1_cn":(-1,None)
    }
    for t in grid:
        preds = np.where(p_ad>=t, 0, 1)
        acc = accuracy_score(y, preds)
        cm  = confusion_matrix(y, preds, labels=[0,1])
        balacc = compute_balanced_accuracy(cm)
        from sklearn.metrics import f1_score
        f1_ad = f1_score(y==0, preds==0)
        f1_cn = f1_score(y==1, preds==1)
        if acc    > best["acc"][0]   : best["acc"]    = (acc, t)
        if balacc > best["balacc"][0]: best["balacc"] = (balacc, t)
        if f1_ad  > best["f1_ad"][0] : best["f1_ad"]  = (f1_ad, t)
        if f1_cn  > best["f1_cn"][0] : best["f1_cn"]  = (f1_cn, t)
    return best

@torch.no_grad()
def collect_p_ad_and_labels(model_eval, loader):
    p_ad_all, y_all = [], []
    for batch in loader:
        x = batch["image"].to(device, dtype=torch.float32, non_blocking=True)
        y = batch["label"].long().cpu().numpy()
        logits2 = tta_logits(x, model_eval)
        p_ad = torch.sigmoid(two_to_single_logit(logits2)).cpu().numpy()
        p_ad_all.append(p_ad); y_all.append(y)
    return np.concatenate(p_ad_all), np.concatenate(y_all)

# ================================================================
# Train loop (supports resume & configurable dropout / weighting)
# ================================================================
def train_binary_fold_balanced(idx_train, idx_val,
                               epochs=80,
                               base_lr=5e-5,
                               ema_decay=0.999,
                               batch_size=BATCH_SIZE,
                               grad_accum_target=8,
                               mixup_alpha=0.2,
                               use_class_weight=True,
                               pos_weight_power=1.25,
                               patience=12,
                               dropout_prob=0.0,
                               resume_path=None,
                               freeze_backbone=False,
                               save_name="best_binary_ad_cn_ema.pth"):

    # Build loaders (balanced over LOCAL indices)
    train_loader, val_loader, train_items_fold, _ = build_balanced_loaders_from_indices(
        idx_train, idx_val, batch_size=batch_size
    )

    # BCEWithLogits: AD positive (label 0)
    y_train = np.array([it["label"] for it in train_items_fold])  # 0=AD,1=CN
    n_ad = (y_train==0).sum(); n_cn=(y_train==1).sum()
    ratio = n_cn / max(1, n_ad)
    pos_weight_val = max(1.0, ratio**pos_weight_power)

    print(f"\nTrain size: {len(train_items_fold)} | Val size: {len(val_loader.dataset)}")
    print(f"Batch size: {batch_size} (grad accum to ~{grad_accum_target})")
    print(f"Dropout in DenseNet blocks: {dropout_prob}")
    print(f"Class ratio (CN/AD) ~ {ratio:.3f} -> pos_weight(AD) base={pos_weight_val:.3f} (if enabled)")

    model = build_binary_model(dropout_prob=dropout_prob); unfreeze_all(model)

    # Optional resume (useful for FT with new dropout)
    if resume_path is not None and Path(resume_path).exists():
        ckpt = torch.load(resume_path, map_location=device, weights_only=False)
        missing, unexpected = model.load_state_dict(ckpt["model"], strict=False)
        print(f"Resumed from: {resume_path}\n  Missing keys: {len(missing)} | Unexpected: {len(unexpected)}")

    # Optional freezing (FT only)
    if freeze_backbone:
        freeze_backbone_keep_classifier(model)

    # Build loss
    if use_class_weight:
        pos_weight = torch.tensor([pos_weight_val], device=device, dtype=torch.float32)
        bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight).to(device)
        print(f"Using class weight: pos_weight(AD)={pos_weight.item():.3f}")
    else:
        bce = nn.BCEWithLogitsLoss().to(device)
        print("Using unweighted BCE")

    opt, sch = make_optim_sched(model, epochs=epochs, base_lr=base_lr)
    ema = ModelEMA(model, decay=ema_decay)
    scaler = torch.amp.GradScaler('cuda', enabled=torch.cuda.is_available())

    best_key=-1; best_path=SAVE_DIR/save_name; no_improve=0
    accum_steps = max(1, grad_accum_target // batch_size)
    print(f"Grad accumulation steps: {accum_steps}")

    for epoch in range(1, epochs+1):
        model.train(); running=0.0; opt.zero_grad(set_to_none=True)
        print(f"\n--- Epoch {epoch}/{epochs} ---")
        pbar = tqdm(train_loader, desc="Train", leave=True, mininterval=0.2, dynamic_ncols=True)
        step_i = 0

        for batch in pbar:
            x = batch["image"].to(device, dtype=torch.float32, non_blocking=True)
            y = batch["label"].long().to(device, non_blocking=True)   # 0=AD,1=CN
            y_ad = (y==0).float().unsqueeze(1)                        # AD-positive

            # Optional MixUp
            if mixup_alpha > 0:
                x, y_ad = mixup(x, y_ad, alpha=mixup_alpha)

            with torch.amp.autocast('cuda', enabled=torch.cuda.is_available()):
                logits2 = model(x)                    # [N,2]
                single = two_to_single_logit(logits2) # [N]
                loss   = bce(single.unsqueeze(1), y_ad) / accum_steps

            scaler.scale(loss).backward()
            step_i += 1
            if step_i % accum_steps == 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(opt); scaler.update()
                opt.zero_grad(set_to_none=True)
                ema.update(model)

            running += float(loss.detach().cpu()) * x.size(0) * accum_steps
            pbar.set_postfix(loss=float(loss.detach().cpu())*accum_steps)

        sch.step()
        train_loss = running / len(train_loader.dataset)
        print(f"Train loss: {train_loss:.4f}")

        # Validation snapshot
        debug_val_bias_single(ema.ema, val_loader)
        acc,f1,auc = evaluate_single(ema.ema, val_loader, sweep_thresholds=True, print_quantiles=True)

        # Composite key (favor AUC, then F1, then Acc)
        key = (0.0 if np.isnan(auc) else auc) + 0.01*f1 + 0.001*acc
        print(f"Score key: {key:.4f}")

        if key>best_key:
            best_key=key; no_improve=0
            torch.save({"model":ema.ema.state_dict(),"epoch":epoch,
                        "val_acc":float(acc),"val_f1":float(f1),"val_auc":float(auc)}, best_path)
            print("✅ Saved new best.")
        else:
            no_improve+=1
            if no_improve>=patience:
                print("⏹️ Early stopping."); break

    # Final eval (reload best)
    ckpt=torch.load(best_path, map_location=device, weights_only=False)
    ema.ema.load_state_dict(ckpt["model"])
    print("\n=== BEST (reloaded) ===")
    evaluate_single(ema.ema, val_loader, sweep_thresholds=True)
    print("Model saved to:", best_path)
    return best_path

# ===========================
# Split: train / val / test
# ===========================
labels = np.array([it["label"] for it in items])
idx_all = np.arange(len(items))

# held-out TEST
idx_trainval, idx_test = train_test_split(idx_all, test_size=TEST_FRACTION, stratify=labels, random_state=SEED)
labels_trainval = labels[idx_trainval]
# train/val
idx_train, idx_val = train_test_split(idx_trainval, test_size=VAL_FRACTION_OF_TRAINVAL, stratify=labels_trainval, random_state=SEED)

print(f"\nSplit sizes -> Train: {len(idx_train)} | Val: {len(idx_val)} | Test: {len(idx_test)}")

# ===========================
# Phase-1 train (baseline)
# ===========================
PHASE1_CKPT = SAVE_DIR/"best_binary_ad_cn_ema.pth"
PHASE2_CKPT = SAVE_DIR/"best_binary_ad_cn_ema_ft.pth"

best_p1_path = train_binary_fold_balanced(
    idx_train, idx_val,
    epochs=EPOCHS_PHASE1,
    base_lr=LR_PHASE1,
    ema_decay=0.999,
    batch_size=BATCH_SIZE,
    grad_accum_target=GRAD_ACCUM_TARGET,
    mixup_alpha=MIXUP_ALPHA_P1,
    use_class_weight=USE_CLASS_WEIGHT_P1,
    pos_weight_power=POS_WEIGHT_POWER,
    patience=PATIENCE_P1,
    dropout_prob=DROPOUT_P1,
    resume_path=None,
    freeze_backbone=False,
    save_name=PHASE1_CKPT.name
)

# =====================================
# Threshold selection on VAL (Phase-1)
# =====================================
_, val_loader, _, _ = build_balanced_loaders_from_indices(idx_train, idx_val, batch_size=BATCH_SIZE)
model_p1 = build_binary_model(dropout_prob=DROPOUT_P1)
ckpt = torch.load(best_p1_path, map_location=device, weights_only=False)
model_p1.load_state_dict(ckpt["model"]); model_p1.eval()

p_ad_val_p1, y_val_p1 = collect_p_ad_and_labels(model_p1, val_loader)
best_p1 = sweep_thresholds_all(p_ad_val_p1, y_val_p1)
print("\nThreshold suggestions (VALIDATION, Phase-1):")
print(f"- Max Accuracy     -> t={best_p1['acc'][1]:.2f}, Acc={best_p1['acc'][0]:.3f}")
print(f"- Balanced Acc     -> t={best_p1['balacc'][1]:.2f}, BalAcc={best_p1['balacc'][0]:.3f}")

with open(SAVE_DIR/"deploy_threshold_phase1.json","w") as f:
    json.dump({"t_acc": float(best_p1['acc'][1]), "t_balacc": float(best_p1['balacc'][1])}, f)
print("Saved Phase-1 thresholds:", SAVE_DIR/"deploy_threshold_phase1.json")

# ================================================
# FINAL TEST (Phase-1) — evaluate both thresholds
# ================================================
test_loader, _ = build_plain_loader_from_indices(idx_test, batch_size=BATCH_SIZE)
evaluate_with_threshold(model_p1, test_loader, threshold=float(best_p1['acc'][1]), title="FINAL TEST (Phase-1, Max-Acc)")
evaluate_with_threshold(model_p1, test_loader, threshold=float(best_p1['balacc'][1]), title="FINAL TEST (Phase-1, Max-BalAcc)")

# ===========================
# Phase-2 Fine-Tune (gentle)
# ===========================
best_p2_path = train_binary_fold_balanced(
    idx_train, idx_val,
    epochs=EPOCHS_PHASE2,
    base_lr=LR_PHASE2,
    ema_decay=0.999,
    batch_size=BATCH_SIZE,
    grad_accum_target=GRAD_ACCUM_TARGET,
    mixup_alpha=MIXUP_ALPHA_P2,            # OFF in FT
    use_class_weight=USE_CLASS_WEIGHT_P2,  # OFF in FT
    pos_weight_power=POS_WEIGHT_POWER,     # ignored if not using class weight
    patience=PATIENCE_P2,
    dropout_prob=DROPOUT_P2,
    resume_path=best_p1_path,              # start from Phase-1 weights
    freeze_backbone=FREEZE_BACKBONE_IN_FT, # optionally freeze backbone
    save_name=PHASE2_CKPT.name
)

# Threshold selection on VAL (Phase-2) for BOTH metrics
model_p2 = build_binary_model(dropout_prob=DROPOUT_P2)
ckpt_ft = torch.load(best_p2_path, map_location=device, weights_only=False)
model_p2.load_state_dict(ckpt_ft["model"]); model_p2.eval()

p_ad_val_p2, y_val_p2 = collect_p_ad_and_labels(model_p2, val_loader)
best_p2 = sweep_thresholds_all(p_ad_val_p2, y_val_p2)
print("\nThreshold suggestions (VALIDATION, Phase-2/Fine-Tune):")
print(f"- Max Accuracy     -> t={best_p2['acc'][1]:.2f}, Acc={best_p2['acc'][0]:.3f}")
print(f"- Balanced Acc     -> t={best_p2['balacc'][1]:.2f}, BalAcc={best_p2['balacc'][0]:.3f}")

with open(SAVE_DIR/"deploy_threshold_phase2_ft.json","w") as f:
    json.dump({"t_acc": float(best_p2['acc'][1]), "t_balacc": float(best_p2['balacc'][1])}, f)
print("Saved Phase-2 thresholds:", SAVE_DIR/"deploy_threshold_phase2_ft.json")

# FINAL TEST (Phase-2) — evaluate both thresholds
evaluate_with_threshold(model_p2, test_loader, threshold=float(best_p2['acc'][1]), title="FINAL TEST (Phase-2, Max-Acc)")
evaluate_with_threshold(model_p2, test_loader, threshold=float(best_p2['balacc'][1]), title="FINAL TEST (Phase-2, Max-BalAcc)")

# ================================================
# Ensemble Phase-1 + Phase-2 (often best of both)
# ================================================
@torch.no_grad()
def evaluate_ensemble(models, loader, threshold, title="FINAL TEST (Ensemble)"):
    for m in models: m.eval()
    p_ad_all, y_all = [], []
    for batch in loader:
        x = batch["image"].to(device, dtype=torch.float32, non_blocking=True)
        y = batch["label"].long().cpu().numpy()
        # average logits across models with TTA
        logits = []
        for m in models:
            logits.append(tta_logits(x, m))
        logits = torch.stack(logits, 0).mean(0)
        p_ad = torch.sigmoid(two_to_single_logit(logits)).cpu().numpy()
        p_ad_all.append(p_ad); y_all.append(y)
    p_ad = np.concatenate(p_ad_all); y = np.concatenate(y_all)
    preds = np.where(p_ad>=threshold, 0, 1)
    acc = accuracy_score(y, preds)
    cm  = confusion_matrix(y, preds, labels=[0,1])
    balacc = compute_balanced_accuracy(cm)
    print(f"\n=== {title} @ th={threshold:.2f} ===")
    print(f"Accuracy: {acc:.4f} | Balanced Acc: {balacc:.4f}")
    print("Confusion (rows=GT [AD,CN], cols=Pred [AD,CN]):\n", cm)
    print(classification_report(y, preds, target_names=['AD','CN']))
    return acc, balacc, cm

# Build eval models
m1 = build_binary_model(dropout_prob=DROPOUT_P1)
m1.load_state_dict(torch.load(PHASE1_CKPT, map_location=device, weights_only=False)["model"]); m1.eval()
m2 = build_binary_model(dropout_prob=DROPOUT_P2)
m2.load_state_dict(torch.load(PHASE2_CKPT, map_location=device, weights_only=False)["model"]); m2.eval()

# Ensemble thresholds (sweep on VAL using ensemble)
@torch.no_grad()
def ensemble_val_thresholds(m1, m2, loader):
    p_ad_all, y_all = [], []
    for batch in loader:
        x = batch["image"].to(device, dtype=torch.float32, non_blocking=True)
        y = batch["label"].long().cpu().numpy()
        log1 = tta_logits(x, m1); log2 = tta_logits(x, m2)
        logits = (log1 + log2)/2
        p_ad = torch.sigmoid(two_to_single_logit(logits)).cpu().numpy()
        p_ad_all.append(p_ad); y_all.append(y)
    p_ad = np.concatenate(p_ad_all); y = np.concatenate(y_all)
    return sweep_thresholds_all(p_ad, y)

best_ens = ensemble_val_thresholds(m1, m2, val_loader)
print("\nThreshold suggestions (VALIDATION, Ensemble P1+P2):")
print(f"- Max Accuracy     -> t={best_ens['acc'][1]:.2f}, Acc={best_ens['acc'][0]:.3f}")
print(f"- Balanced Acc     -> t={best_ens['balacc'][1]:.2f}, BalAcc={best_ens['balacc'][0]:.3f}")

# FINAL TEST (Ensemble) — evaluate both thresholds
evaluate_ensemble([m1, m2], test_loader, threshold=float(best_ens['acc'][1]), title="FINAL TEST (Ensemble, Max-Acc)")
evaluate_ensemble([m1, m2], test_loader, threshold=float(best_ens['balacc'][1]), title="FINAL TEST (Ensemble, Max-BalAcc)")


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.7/2.7 MB[0m [31m35.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m766.6/766.6 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m125.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m99.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m57.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m9.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Train: 100%|██████████| 291/291 [03:10<00:00,  1.53it/s, loss=0.971]

Train loss: 1.2265





Val avg P(AD) ~= 0.4948 | P(CN) ~= 0.5052

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.6621 | F1(CN+):0.7967 | AUC(CN prob):0.3594
Confusion:
 [[  0  74]
 [  0 145]]
              precision    recall  f1-score   support

          AD       0.00      0.00      0.00        74
          CN       0.66      1.00      0.80       145

    accuracy                           0.66       219
   macro avg       0.33      0.50      0.40       219
weighted avg       0.44      0.66      0.53       219

P(AD) q=[5,25,50,75,95]: [0.493 0.494 0.495 0.496 0.496]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.7967 at th=0.50
[[  0  74]
 [  0 145]]
Score key: 0.3680
✅ Saved new best.

--- Epoch 2/80 ---


Train: 100%|██████████| 291/291 [02:27<00:00,  1.98it/s, loss=0.732]

Train loss: 1.0168





Val avg P(AD) ~= 0.5171 | P(CN) ~= 0.4829

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.4145
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.515 0.517 0.517 0.518 0.518]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.7967 at th=0.53
[[  0  74]
 [  0 145]]
Score key: 0.4148
✅ Saved new best.

--- Epoch 3/80 ---


Train: 100%|██████████| 291/291 [02:18<00:00,  2.10it/s, loss=0.254]

Train loss: 1.0033





Val avg P(AD) ~= 0.5359 | P(CN) ~= 0.4641

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.3912
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.535 0.536 0.536 0.536 0.537]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.7967 at th=0.55
[[  0  74]
 [  0 145]]
Score key: 0.3915

--- Epoch 4/80 ---


Train: 100%|██████████| 291/291 [01:52<00:00,  2.59it/s, loss=1.21]

Train loss: 1.0735





Val avg P(AD) ~= 0.5455 | P(CN) ~= 0.4545

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.6397
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.544 0.545 0.545 0.546 0.546]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.7967 at th=0.55
[[  0  74]
 [  0 145]]
Score key: 0.6401
✅ Saved new best.

--- Epoch 5/80 ---


Train: 100%|██████████| 291/291 [01:52<00:00,  2.59it/s, loss=0.883]

Train loss: 1.0489





Val avg P(AD) ~= 0.5654 | P(CN) ~= 0.4346

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.7025
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.564 0.565 0.565 0.566 0.566]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.7967 at th=0.57
[[  0  74]
 [  0 145]]
Score key: 0.7028
✅ Saved new best.

--- Epoch 6/80 ---


Train: 100%|██████████| 291/291 [01:53<00:00,  2.57it/s, loss=0.116]

Train loss: 0.9941





Val avg P(AD) ~= 0.5803 | P(CN) ~= 0.4197

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.7461
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.579 0.58  0.58  0.581 0.582]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.7967 at th=0.60
[[  0  74]
 [  0 145]]
Score key: 0.7464
✅ Saved new best.

--- Epoch 7/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.61it/s, loss=0.541]

Train loss: 0.9607





Val avg P(AD) ~= 0.6232 | P(CN) ~= 0.3768

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.7710
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.621 0.622 0.623 0.624 0.626]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8193 at th=0.62
[[ 23  51]
 [  9 136]]
Score key: 0.7714
✅ Saved new best.

--- Epoch 8/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.62it/s, loss=0.103]

Train loss: 0.9471





Val avg P(AD) ~= 0.6789 | P(CN) ~= 0.3211

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.7752
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.676 0.677 0.679 0.68  0.683]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.7967 at th=0.70
[[  0  74]
 [  0 145]]
Score key: 0.7755
✅ Saved new best.

--- Epoch 9/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.61it/s, loss=0.176]

Train loss: 0.9074





Val avg P(AD) ~= 0.7085 | P(CN) ~= 0.2915

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.7765
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.704 0.706 0.708 0.71  0.714]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.7967 at th=0.75
[[  0  74]
 [  0 145]]
Score key: 0.7768
✅ Saved new best.

--- Epoch 10/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.61it/s, loss=0.104]

Train loss: 0.9938





Val avg P(AD) ~= 0.7626 | P(CN) ~= 0.2374

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.7746
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.756 0.76  0.762 0.765 0.77 ]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.7967 at th=0.80
[[  0  74]
 [  0 145]]
Score key: 0.7749

--- Epoch 11/80 ---


Train: 100%|██████████| 291/291 [01:49<00:00,  2.67it/s, loss=1.56]

Train loss: 0.9120





Val avg P(AD) ~= 0.7965 | P(CN) ~= 0.2035

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.7669
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.788 0.792 0.796 0.8   0.807]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8026 at th=0.80
[[ 34  40]
 [ 21 124]]
Score key: 0.7673

--- Epoch 12/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.61it/s, loss=1.27]

Train loss: 0.9288





Val avg P(AD) ~= 0.8203 | P(CN) ~= 0.1797

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.7777
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.809 0.815 0.82  0.825 0.834]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8155 at th=0.82
[[ 36  38]
 [ 19 126]]
Score key: 0.7780
✅ Saved new best.

--- Epoch 13/80 ---


Train: 100%|██████████| 291/291 [01:53<00:00,  2.57it/s, loss=1.16]

Train loss: 0.9638





Val avg P(AD) ~= 0.8308 | P(CN) ~= 0.1692

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.7860
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.816 0.824 0.83  0.837 0.849]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.7967 at th=0.88
[[  0  74]
 [  0 145]]
Score key: 0.7863
✅ Saved new best.

--- Epoch 14/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.61it/s, loss=0.316]

Train loss: 0.9499





Val avg P(AD) ~= 0.8329 | P(CN) ~= 0.1671

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.7912
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.812 0.823 0.831 0.841 0.858]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8249 at th=0.85
[[ 21  53]
 [  6 139]]
Score key: 0.7915
✅ Saved new best.

--- Epoch 15/80 ---


Train: 100%|██████████| 291/291 [01:50<00:00,  2.64it/s, loss=0.333]

Train loss: 0.8684





Val avg P(AD) ~= 0.8456 | P(CN) ~= 0.1544

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.7967
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.818 0.833 0.844 0.857 0.878]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8012 at th=0.88
[[ 11  63]
 [  6 139]]
Score key: 0.7971
✅ Saved new best.

--- Epoch 16/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.61it/s, loss=0.108]

Train loss: 0.8801





Val avg P(AD) ~= 0.8456 | P(CN) ~= 0.1544

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.8055
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.806 0.829 0.843 0.862 0.889]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8299 at th=0.88
[[ 23  51]
 [  6 139]]
Score key: 0.8058
✅ Saved new best.

--- Epoch 17/80 ---


Train: 100%|██████████| 291/291 [01:50<00:00,  2.63it/s, loss=0.593]

Train loss: 0.9221





Val avg P(AD) ~= 0.8431 | P(CN) ~= 0.1569

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.8146
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.784 0.82  0.842 0.866 0.901]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8491 at th=0.88
[[ 36  38]
 [ 10 135]]
Score key: 0.8150
✅ Saved new best.

--- Epoch 18/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.60it/s, loss=1.02]

Train loss: 0.8997





Val avg P(AD) ~= 0.8613 | P(CN) ~= 0.1387

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.8186
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.788 0.834 0.863 0.888 0.923]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8438 at th=0.90
[[ 34  40]
 [ 10 135]]
Score key: 0.8189
✅ Saved new best.

--- Epoch 19/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.62it/s, loss=0.0481]

Train loss: 0.8411





Val avg P(AD) ~= 0.8657 | P(CN) ~= 0.1343

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.8218
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.768 0.833 0.872 0.901 0.936]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8487 at th=0.90
[[ 44  30]
 [ 16 129]]
Score key: 0.8221
✅ Saved new best.

--- Epoch 20/80 ---


Train: 100%|██████████| 291/291 [01:50<00:00,  2.63it/s, loss=0.522]

Train loss: 0.8557





Val avg P(AD) ~= 0.8789 | P(CN) ~= 0.1211

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.8270
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.767 0.845 0.89  0.923 0.95 ]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8471 at th=0.92
[[ 38  36]
 [ 12 133]]
Score key: 0.8274
✅ Saved new best.

--- Epoch 21/80 ---


Train: 100%|██████████| 291/291 [01:49<00:00,  2.65it/s, loss=1.07]

Train loss: 0.9409





Val avg P(AD) ~= 0.8779 | P(CN) ~= 0.1221

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3379 | F1(CN+):0.0000 | AUC(CN prob):0.8292
Confusion:
 [[ 74   0]
 [145   0]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       0.00      0.00      0.00       145

    accuracy                           0.34       219
   macro avg       0.17      0.50      0.25       219
weighted avg       0.11      0.34      0.17       219

P(AD) q=[5,25,50,75,95]: [0.729 0.835 0.898 0.936 0.958]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8485 at th=0.92
[[ 48  26]
 [ 19 126]]
Score key: 0.8296
✅ Saved new best.

--- Epoch 22/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.61it/s, loss=1.13]

Train loss: 0.8390





Val avg P(AD) ~= 0.8467 | P(CN) ~= 0.1533

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3425 | F1(CN+):0.0137 | AUC(CN prob):0.8313
Confusion:
 [[ 74   0]
 [144   1]]
              precision    recall  f1-score   support

          AD       0.34      1.00      0.51        74
          CN       1.00      0.01      0.01       145

    accuracy                           0.34       219
   macro avg       0.67      0.50      0.26       219
weighted avg       0.78      0.34      0.18       219

P(AD) q=[5,25,50,75,95]: [0.62  0.776 0.882 0.936 0.96 ]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8505 at th=0.92
[[ 46  28]
 [ 17 128]]
Score key: 0.8317
✅ Saved new best.

--- Epoch 23/80 ---


Train: 100%|██████████| 291/291 [01:53<00:00,  2.57it/s, loss=0.078]

Train loss: 0.7999





Val avg P(AD) ~= 0.8365 | P(CN) ~= 0.1635

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.3699 | F1(CN+):0.0921 | AUC(CN prob):0.8329
Confusion:
 [[ 74   0]
 [138   7]]
              precision    recall  f1-score   support

          AD       0.35      1.00      0.52        74
          CN       1.00      0.05      0.09       145

    accuracy                           0.37       219
   macro avg       0.67      0.52      0.30       219
weighted avg       0.78      0.37      0.24       219

P(AD) q=[5,25,50,75,95]: [0.552 0.752 0.886 0.942 0.964]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8421 at th=0.95
[[ 32  42]
 [  9 136]]
Score key: 0.8342
✅ Saved new best.

--- Epoch 24/80 ---


Train: 100%|██████████| 291/291 [01:50<00:00,  2.63it/s, loss=2.94]

Train loss: 0.8066





Val avg P(AD) ~= 0.7855 | P(CN) ~= 0.2145

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.4475 | F1(CN+):0.2840 | AUC(CN prob):0.8356
Confusion:
 [[ 74   0]
 [121  24]]
              precision    recall  f1-score   support

          AD       0.38      1.00      0.55        74
          CN       1.00      0.17      0.28       145

    accuracy                           0.45       219
   macro avg       0.69      0.58      0.42       219
weighted avg       0.79      0.45      0.37       219

P(AD) q=[5,25,50,75,95]: [0.414 0.654 0.852 0.937 0.962]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8543 at th=0.92
[[ 46  28]
 [ 16 129]]
Score key: 0.8389
✅ Saved new best.

--- Epoch 25/80 ---


Train: 100%|██████████| 291/291 [01:52<00:00,  2.59it/s, loss=0.132]

Train loss: 0.8323





Val avg P(AD) ~= 0.7670 | P(CN) ~= 0.2330

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.4749 | F1(CN+):0.3575 | AUC(CN prob):0.8371
Confusion:
 [[ 72   2]
 [113  32]]
              precision    recall  f1-score   support

          AD       0.39      0.97      0.56        74
          CN       0.94      0.22      0.36       145

    accuracy                           0.47       219
   macro avg       0.67      0.60      0.46       219
weighted avg       0.75      0.47      0.42       219

P(AD) q=[5,25,50,75,95]: [0.349 0.601 0.851 0.94  0.964]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8552 at th=0.92
[[ 49  25]
 [ 18 127]]
Score key: 0.8411
✅ Saved new best.

--- Epoch 26/80 ---


Train: 100%|██████████| 291/291 [01:52<00:00,  2.60it/s, loss=0.0164]

Train loss: 0.7048





Val avg P(AD) ~= 0.7754 | P(CN) ~= 0.2246

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.4840 | F1(CN+):0.3757 | AUC(CN prob):0.8375
Confusion:
 [[ 72   2]
 [111  34]]
              precision    recall  f1-score   support

          AD       0.39      0.97      0.56        74
          CN       0.94      0.23      0.38       145

    accuracy                           0.48       219
   macro avg       0.67      0.60      0.47       219
weighted avg       0.76      0.48      0.44       219

P(AD) q=[5,25,50,75,95]: [0.324 0.623 0.876 0.946 0.969]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8526 at th=0.95
[[ 40  34]
 [ 12 133]]
Score key: 0.8418
✅ Saved new best.

--- Epoch 27/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.60it/s, loss=0.632]

Train loss: 0.8332





Val avg P(AD) ~= 0.7583 | P(CN) ~= 0.2417

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.5342 | F1(CN+):0.4688 | AUC(CN prob):0.8394
Confusion:
 [[ 72   2]
 [100  45]]
              precision    recall  f1-score   support

          AD       0.42      0.97      0.59        74
          CN       0.96      0.31      0.47       145

    accuracy                           0.53       219
   macro avg       0.69      0.64      0.53       219
weighted avg       0.78      0.53      0.51       219

P(AD) q=[5,25,50,75,95]: [0.274 0.58  0.875 0.947 0.972]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8526 at th=0.95
[[ 40  34]
 [ 12 133]]
Score key: 0.8446
✅ Saved new best.

--- Epoch 28/80 ---


Train: 100%|██████████| 291/291 [01:52<00:00,  2.59it/s, loss=0.883]

Train loss: 0.8293





Val avg P(AD) ~= 0.7596 | P(CN) ~= 0.2404

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.5342 | F1(CN+):0.4688 | AUC(CN prob):0.8378
Confusion:
 [[ 72   2]
 [100  45]]
              precision    recall  f1-score   support

          AD       0.42      0.97      0.59        74
          CN       0.96      0.31      0.47       145

    accuracy                           0.53       219
   macro avg       0.69      0.64      0.53       219
weighted avg       0.78      0.53      0.51       219

P(AD) q=[5,25,50,75,95]: [0.255 0.584 0.893 0.95  0.974]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8479 at th=0.95
[[ 41  33]
 [ 14 131]]
Score key: 0.8431

--- Epoch 29/80 ---


Train: 100%|██████████| 291/291 [01:50<00:00,  2.63it/s, loss=0.0592]

Train loss: 0.8510





Val avg P(AD) ~= 0.7575 | P(CN) ~= 0.2425

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.5479 | F1(CN+):0.4923 | AUC(CN prob):0.8379
Confusion:
 [[72  2]
 [97 48]]
              precision    recall  f1-score   support

          AD       0.43      0.97      0.59        74
          CN       0.96      0.33      0.49       145

    accuracy                           0.55       219
   macro avg       0.69      0.65      0.54       219
weighted avg       0.78      0.55      0.53       219

P(AD) q=[5,25,50,75,95]: [0.241 0.578 0.902 0.953 0.976]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8344 at th=0.95
[[ 43  31]
 [ 19 126]]
Score key: 0.8434

--- Epoch 30/80 ---


Train: 100%|██████████| 291/291 [01:49<00:00,  2.65it/s, loss=0.0316]

Train loss: 0.7703





Val avg P(AD) ~= 0.7539 | P(CN) ~= 0.2461

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.5571 | F1(CN+):0.5076 | AUC(CN prob):0.8352
Confusion:
 [[72  2]
 [95 50]]
              precision    recall  f1-score   support

          AD       0.43      0.97      0.60        74
          CN       0.96      0.34      0.51       145

    accuracy                           0.56       219
   macro avg       0.70      0.66      0.55       219
weighted avg       0.78      0.56      0.54       219

P(AD) q=[5,25,50,75,95]: [0.223 0.567 0.903 0.954 0.978]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8400 at th=0.95
[[ 45  29]
 [ 19 126]]
Score key: 0.8409

--- Epoch 31/80 ---


Train: 100%|██████████| 291/291 [01:49<00:00,  2.65it/s, loss=0.0814]

Train loss: 0.8308





Val avg P(AD) ~= 0.7571 | P(CN) ~= 0.2429

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.5616 | F1(CN+):0.5152 | AUC(CN prob):0.8319
Confusion:
 [[72  2]
 [94 51]]
              precision    recall  f1-score   support

          AD       0.43      0.97      0.60        74
          CN       0.96      0.35      0.52       145

    accuracy                           0.56       219
   macro avg       0.70      0.66      0.56       219
weighted avg       0.78      0.56      0.54       219

P(AD) q=[5,25,50,75,95]: [0.216 0.562 0.907 0.958 0.978]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8350 at th=0.95
[[ 46  28]
 [ 21 124]]
Score key: 0.8376

--- Epoch 32/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.60it/s, loss=0.972]

Train loss: 0.7291





Val avg P(AD) ~= 0.7376 | P(CN) ~= 0.2624

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.5708 | F1(CN+):0.5300 | AUC(CN prob):0.8315
Confusion:
 [[72  2]
 [92 53]]
              precision    recall  f1-score   support

          AD       0.44      0.97      0.61        74
          CN       0.96      0.37      0.53       145

    accuracy                           0.57       219
   macro avg       0.70      0.67      0.57       219
weighted avg       0.79      0.57      0.56       219

P(AD) q=[5,25,50,75,95]: [0.19  0.499 0.898 0.958 0.979]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8428 at th=0.95
[[ 46  28]
 [ 19 126]]
Score key: 0.8373

--- Epoch 33/80 ---


Train: 100%|██████████| 291/291 [01:49<00:00,  2.65it/s, loss=1.14]


Train loss: 0.7829
Val avg P(AD) ~= 0.7295 | P(CN) ~= 0.2705

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.5753 | F1(CN+):0.5373 | AUC(CN prob):0.8312
Confusion:
 [[72  2]
 [91 54]]
              precision    recall  f1-score   support

          AD       0.44      0.97      0.61        74
          CN       0.96      0.37      0.54       145

    accuracy                           0.58       219
   macro avg       0.70      0.67      0.57       219
weighted avg       0.79      0.58      0.56       219

P(AD) q=[5,25,50,75,95]: [0.174 0.475 0.9   0.959 0.979]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8428 at th=0.95
[[ 46  28]
 [ 19 126]]
Score key: 0.8372

--- Epoch 34/80 ---


Train: 100%|██████████| 291/291 [01:50<00:00,  2.63it/s, loss=0.537]

Train loss: 0.7011





Val avg P(AD) ~= 0.7136 | P(CN) ~= 0.2864

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.5936 | F1(CN+):0.5659 | AUC(CN prob):0.8346
Confusion:
 [[72  2]
 [87 58]]
              precision    recall  f1-score   support

          AD       0.45      0.97      0.62        74
          CN       0.97      0.40      0.57       145

    accuracy                           0.59       219
   macro avg       0.71      0.69      0.59       219
weighted avg       0.79      0.59      0.58       219

P(AD) q=[5,25,50,75,95]: [0.154 0.425 0.901 0.962 0.98 ]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8378 at th=0.95
[[ 47  27]
 [ 21 124]]
Score key: 0.8409

--- Epoch 35/80 ---


Train: 100%|██████████| 291/291 [01:50<00:00,  2.63it/s, loss=0.459]

Train loss: 0.6528





Val avg P(AD) ~= 0.6912 | P(CN) ~= 0.3088

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.6073 | F1(CN+):0.5865 | AUC(CN prob):0.8399
Confusion:
 [[72  2]
 [84 61]]
              precision    recall  f1-score   support

          AD       0.46      0.97      0.63        74
          CN       0.97      0.42      0.59       145

    accuracy                           0.61       219
   macro avg       0.71      0.70      0.61       219
weighted avg       0.80      0.61      0.60       219

P(AD) q=[5,25,50,75,95]: [0.127 0.372 0.892 0.961 0.982]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8435 at th=0.95
[[ 49  25]
 [ 21 124]]
Score key: 0.8464
✅ Saved new best.

--- Epoch 36/80 ---


Train: 100%|██████████| 291/291 [01:53<00:00,  2.57it/s, loss=0.578]

Train loss: 0.7525





Val avg P(AD) ~= 0.6588 | P(CN) ~= 0.3412

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.6438 | F1(CN+):0.6455 | AUC(CN prob):0.8455
Confusion:
 [[70  4]
 [74 71]]
              precision    recall  f1-score   support

          AD       0.49      0.95      0.64        74
          CN       0.95      0.49      0.65       145

    accuracy                           0.64       219
   macro avg       0.72      0.72      0.64       219
weighted avg       0.79      0.64      0.64       219

P(AD) q=[5,25,50,75,95]: [0.116 0.281 0.865 0.956 0.982]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8439 at th=0.95
[[ 45  29]
 [ 18 127]]
Score key: 0.8526
✅ Saved new best.

--- Epoch 37/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.61it/s, loss=3.3]

Train loss: 0.7487





Val avg P(AD) ~= 0.6323 | P(CN) ~= 0.3677

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.6667 | F1(CN+):0.6784 | AUC(CN prob):0.8486
Confusion:
 [[69  5]
 [68 77]]
              precision    recall  f1-score   support

          AD       0.50      0.93      0.65        74
          CN       0.94      0.53      0.68       145

    accuracy                           0.67       219
   macro avg       0.72      0.73      0.67       219
weighted avg       0.79      0.67      0.67       219

P(AD) q=[5,25,50,75,95]: [0.103 0.233 0.822 0.955 0.982]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8487 at th=0.95
[[ 44  30]
 [ 16 129]]
Score key: 0.8561
✅ Saved new best.

--- Epoch 38/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.60it/s, loss=0.17]

Train loss: 0.7326





Val avg P(AD) ~= 0.6188 | P(CN) ~= 0.3812

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.6758 | F1(CN+):0.6953 | AUC(CN prob):0.8522
Confusion:
 [[67  7]
 [64 81]]
              precision    recall  f1-score   support

          AD       0.51      0.91      0.65        74
          CN       0.92      0.56      0.70       145

    accuracy                           0.68       219
   macro avg       0.72      0.73      0.67       219
weighted avg       0.78      0.68      0.68       219

P(AD) q=[5,25,50,75,95]: [0.097 0.218 0.793 0.955 0.982]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8590 at th=0.95
[[ 45  29]
 [ 14 131]]
Score key: 0.8599
✅ Saved new best.

--- Epoch 39/80 ---


Train: 100%|██████████| 291/291 [01:53<00:00,  2.56it/s, loss=0.658]

Train loss: 0.6107





Val avg P(AD) ~= 0.5977 | P(CN) ~= 0.4023

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.6941 | F1(CN+):0.7197 | AUC(CN prob):0.8555
Confusion:
 [[66  8]
 [59 86]]
              precision    recall  f1-score   support

          AD       0.53      0.89      0.66        74
          CN       0.91      0.59      0.72       145

    accuracy                           0.69       219
   macro avg       0.72      0.74      0.69       219
weighted avg       0.78      0.69      0.70       219

P(AD) q=[5,25,50,75,95]: [0.088 0.181 0.712 0.956 0.983]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8590 at th=0.95
[[ 45  29]
 [ 14 131]]
Score key: 0.8634
✅ Saved new best.

--- Epoch 40/80 ---


Train: 100%|██████████| 291/291 [01:54<00:00,  2.54it/s, loss=0.689]

Train loss: 0.7024





Val avg P(AD) ~= 0.5668 | P(CN) ~= 0.4332

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7078 | F1(CN+):0.7355 | AUC(CN prob):0.8627
Confusion:
 [[66  8]
 [56 89]]
              precision    recall  f1-score   support

          AD       0.54      0.89      0.67        74
          CN       0.92      0.61      0.74       145

    accuracy                           0.71       219
   macro avg       0.73      0.75      0.70       219
weighted avg       0.79      0.71      0.71       219

P(AD) q=[5,25,50,75,95]: [0.073 0.153 0.652 0.95  0.983]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8673 at th=0.95
[[ 44  30]
 [ 11 134]]
Score key: 0.8708
✅ Saved new best.

--- Epoch 41/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.61it/s, loss=0.0544]

Train loss: 0.6076





Val avg P(AD) ~= 0.5347 | P(CN) ~= 0.4653

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7215 | F1(CN+):0.7530 | AUC(CN prob):0.8653
Confusion:
 [[65  9]
 [52 93]]
              precision    recall  f1-score   support

          AD       0.56      0.88      0.68        74
          CN       0.91      0.64      0.75       145

    accuracy                           0.72       219
   macro avg       0.73      0.76      0.72       219
weighted avg       0.79      0.72      0.73       219

P(AD) q=[5,25,50,75,95]: [0.06  0.122 0.572 0.945 0.983]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8630 at th=0.92
[[ 53  21]
 [ 19 126]]
Score key: 0.8735
✅ Saved new best.

--- Epoch 42/80 ---


Train: 100%|██████████| 291/291 [01:50<00:00,  2.63it/s, loss=0.0643]

Train loss: 0.6804





Val avg P(AD) ~= 0.5080 | P(CN) ~= 0.4920

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7489 | F1(CN+):0.7860 | AUC(CN prob):0.8658
Confusion:
 [[ 63  11]
 [ 44 101]]
              precision    recall  f1-score   support

          AD       0.59      0.85      0.70        74
          CN       0.90      0.70      0.79       145

    accuracy                           0.75       219
   macro avg       0.75      0.77      0.74       219
weighted avg       0.80      0.75      0.76       219

P(AD) q=[5,25,50,75,95]: [0.053 0.108 0.48  0.943 0.982]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8754 at th=0.92
[[ 52  22]
 [ 15 130]]
Score key: 0.8744
✅ Saved new best.

--- Epoch 43/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.60it/s, loss=0.709]

Train loss: 0.6029





Val avg P(AD) ~= 0.4893 | P(CN) ~= 0.5107

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7717 | F1(CN+):0.8106 | AUC(CN prob):0.8664
Confusion:
 [[ 62  12]
 [ 38 107]]
              precision    recall  f1-score   support

          AD       0.62      0.84      0.71        74
          CN       0.90      0.74      0.81       145

    accuracy                           0.77       219
   macro avg       0.76      0.79      0.76       219
weighted avg       0.80      0.77      0.78       219

P(AD) q=[5,25,50,75,95]: [0.048 0.102 0.443 0.943 0.982]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8754 at th=0.90
[[ 52  22]
 [ 15 130]]
Score key: 0.8752
✅ Saved new best.

--- Epoch 44/80 ---


Train: 100%|██████████| 291/291 [01:53<00:00,  2.56it/s, loss=0.0091]

Train loss: 0.5912





Val avg P(AD) ~= 0.4500 | P(CN) ~= 0.5500

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7945 | F1(CN+):0.8352 | AUC(CN prob):0.8669
Confusion:
 [[ 60  14]
 [ 31 114]]
              precision    recall  f1-score   support

          AD       0.66      0.81      0.73        74
          CN       0.89      0.79      0.84       145

    accuracy                           0.79       219
   macro avg       0.77      0.80      0.78       219
weighted avg       0.81      0.79      0.80       219

P(AD) q=[5,25,50,75,95]: [0.04  0.081 0.303 0.94  0.98 ]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8816 at th=0.92
[[ 49  25]
 [ 11 134]]
Score key: 0.8761
✅ Saved new best.

--- Epoch 45/80 ---


Train: 100%|██████████| 291/291 [01:52<00:00,  2.58it/s, loss=1.05]

Train loss: 0.6875





Val avg P(AD) ~= 0.4539 | P(CN) ~= 0.5461

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7991 | F1(CN+):0.8394 | AUC(CN prob):0.8676
Confusion:
 [[ 60  14]
 [ 30 115]]
              precision    recall  f1-score   support

          AD       0.67      0.81      0.73        74
          CN       0.89      0.79      0.84       145

    accuracy                           0.80       219
   macro avg       0.78      0.80      0.79       219
weighted avg       0.82      0.80      0.80       219

P(AD) q=[5,25,50,75,95]: [0.036 0.08  0.301 0.944 0.982]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8808 at th=0.92
[[ 50  24]
 [ 12 133]]
Score key: 0.8768
✅ Saved new best.

--- Epoch 46/80 ---


Train: 100%|██████████| 291/291 [01:50<00:00,  2.64it/s, loss=0.00546]

Train loss: 0.5937





Val avg P(AD) ~= 0.4351 | P(CN) ~= 0.5649

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7945 | F1(CN+):0.8364 | AUC(CN prob):0.8687
Confusion:
 [[ 59  15]
 [ 30 115]]
              precision    recall  f1-score   support

          AD       0.66      0.80      0.72        74
          CN       0.88      0.79      0.84       145

    accuracy                           0.79       219
   macro avg       0.77      0.80      0.78       219
weighted avg       0.81      0.79      0.80       219

P(AD) q=[5,25,50,75,95]: [0.032 0.068 0.256 0.938 0.981]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8824 at th=0.92
[[ 48  26]
 [ 10 135]]
Score key: 0.8778
✅ Saved new best.

--- Epoch 47/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.61it/s, loss=0.00891]

Train loss: 0.6058





Val avg P(AD) ~= 0.4278 | P(CN) ~= 0.5722

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7945 | F1(CN+):0.8364 | AUC(CN prob):0.8677
Confusion:
 [[ 59  15]
 [ 30 115]]
              precision    recall  f1-score   support

          AD       0.66      0.80      0.72        74
          CN       0.88      0.79      0.84       145

    accuracy                           0.79       219
   macro avg       0.77      0.80      0.78       219
weighted avg       0.81      0.79      0.80       219

P(AD) q=[5,25,50,75,95]: [0.03  0.062 0.212 0.937 0.981]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8797 at th=0.78
[[ 56  18]
 [ 17 128]]
Score key: 0.8768

--- Epoch 48/80 ---


Train: 100%|██████████| 291/291 [01:50<00:00,  2.62it/s, loss=0.00844]

Train loss: 0.5880





Val avg P(AD) ~= 0.4219 | P(CN) ~= 0.5781

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7854 | F1(CN+):0.8303 | AUC(CN prob):0.8679
Confusion:
 [[ 57  17]
 [ 30 115]]
              precision    recall  f1-score   support

          AD       0.66      0.77      0.71        74
          CN       0.87      0.79      0.83       145

    accuracy                           0.79       219
   macro avg       0.76      0.78      0.77       219
weighted avg       0.80      0.79      0.79       219

P(AD) q=[5,25,50,75,95]: [0.027 0.059 0.212 0.933 0.981]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8797 at th=0.78
[[ 56  18]
 [ 17 128]]
Score key: 0.8770

--- Epoch 49/80 ---


Train: 100%|██████████| 291/291 [01:49<00:00,  2.65it/s, loss=0.00972]

Train loss: 0.6216





Val avg P(AD) ~= 0.4260 | P(CN) ~= 0.5740

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7900 | F1(CN+):0.8333 | AUC(CN prob):0.8685
Confusion:
 [[ 58  16]
 [ 30 115]]
              precision    recall  f1-score   support

          AD       0.66      0.78      0.72        74
          CN       0.88      0.79      0.83       145

    accuracy                           0.79       219
   macro avg       0.77      0.79      0.77       219
weighted avg       0.80      0.79      0.79       219

P(AD) q=[5,25,50,75,95]: [0.026 0.058 0.212 0.935 0.983]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8787 at th=0.90
[[ 48  26]
 [ 11 134]]
Score key: 0.8776

--- Epoch 50/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.60it/s, loss=0.00437]

Train loss: 0.5296





Val avg P(AD) ~= 0.4259 | P(CN) ~= 0.5741

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7900 | F1(CN+):0.8333 | AUC(CN prob):0.8690
Confusion:
 [[ 58  16]
 [ 30 115]]
              precision    recall  f1-score   support

          AD       0.66      0.78      0.72        74
          CN       0.88      0.79      0.83       145

    accuracy                           0.79       219
   macro avg       0.77      0.79      0.77       219
weighted avg       0.80      0.79      0.79       219

P(AD) q=[5,25,50,75,95]: [0.024 0.054 0.22  0.936 0.983]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8816 at th=0.90
[[ 49  25]
 [ 11 134]]
Score key: 0.8781
✅ Saved new best.

--- Epoch 51/80 ---


Train: 100%|██████████| 291/291 [01:52<00:00,  2.59it/s, loss=0.441]

Train loss: 0.5522





Val avg P(AD) ~= 0.4240 | P(CN) ~= 0.5760

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7945 | F1(CN+):0.8375 | AUC(CN prob):0.8694
Confusion:
 [[ 58  16]
 [ 29 116]]
              precision    recall  f1-score   support

          AD       0.67      0.78      0.72        74
          CN       0.88      0.80      0.84       145

    accuracy                           0.79       219
   macro avg       0.77      0.79      0.78       219
weighted avg       0.81      0.79      0.80       219

P(AD) q=[5,25,50,75,95]: [0.023 0.052 0.208 0.938 0.983]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8845 at th=0.90
[[ 50  24]
 [ 11 134]]
Score key: 0.8786
✅ Saved new best.

--- Epoch 52/80 ---


Train: 100%|██████████| 291/291 [01:53<00:00,  2.56it/s, loss=0.0827]

Train loss: 0.5302





Val avg P(AD) ~= 0.4183 | P(CN) ~= 0.5817

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7991 | F1(CN+):0.8417 | AUC(CN prob):0.8692
Confusion:
 [[ 58  16]
 [ 28 117]]
              precision    recall  f1-score   support

          AD       0.67      0.78      0.72        74
          CN       0.88      0.81      0.84       145

    accuracy                           0.80       219
   macro avg       0.78      0.80      0.78       219
weighted avg       0.81      0.80      0.80       219

P(AD) q=[5,25,50,75,95]: [0.023 0.049 0.198 0.938 0.982]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8837 at th=0.88
[[ 51  23]
 [ 12 133]]
Score key: 0.8785

--- Epoch 53/80 ---


Train: 100%|██████████| 291/291 [01:50<00:00,  2.62it/s, loss=0.636]

Train loss: 0.5182





Val avg P(AD) ~= 0.4204 | P(CN) ~= 0.5796

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8082 | F1(CN+):0.8489 | AUC(CN prob):0.8704
Confusion:
 [[ 59  15]
 [ 27 118]]
              precision    recall  f1-score   support

          AD       0.69      0.80      0.74        74
          CN       0.89      0.81      0.85       145

    accuracy                           0.81       219
   macro avg       0.79      0.81      0.79       219
weighted avg       0.82      0.81      0.81       219

P(AD) q=[5,25,50,75,95]: [0.021 0.048 0.206 0.938 0.982]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8837 at th=0.88
[[ 51  23]
 [ 12 133]]
Score key: 0.8797
✅ Saved new best.

--- Epoch 54/80 ---


Train: 100%|██████████| 291/291 [01:50<00:00,  2.62it/s, loss=0.0359]

Train loss: 0.5493





Val avg P(AD) ~= 0.4224 | P(CN) ~= 0.5776

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8082 | F1(CN+):0.8489 | AUC(CN prob):0.8713
Confusion:
 [[ 59  15]
 [ 27 118]]
              precision    recall  f1-score   support

          AD       0.69      0.80      0.74        74
          CN       0.89      0.81      0.85       145

    accuracy                           0.81       219
   macro avg       0.79      0.81      0.79       219
weighted avg       0.82      0.81      0.81       219

P(AD) q=[5,25,50,75,95]: [0.02  0.047 0.213 0.94  0.983]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8816 at th=0.90
[[ 49  25]
 [ 11 134]]
Score key: 0.8806
✅ Saved new best.

--- Epoch 55/80 ---


Train: 100%|██████████| 291/291 [01:53<00:00,  2.57it/s, loss=0.594]

Train loss: 0.4701





Val avg P(AD) ~= 0.4208 | P(CN) ~= 0.5792

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8082 | F1(CN+):0.8489 | AUC(CN prob):0.8731
Confusion:
 [[ 59  15]
 [ 27 118]]
              precision    recall  f1-score   support

          AD       0.69      0.80      0.74        74
          CN       0.89      0.81      0.85       145

    accuracy                           0.81       219
   macro avg       0.79      0.81      0.79       219
weighted avg       0.82      0.81      0.81       219

P(AD) q=[5,25,50,75,95]: [0.019 0.043 0.205 0.941 0.984]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8816 at th=0.90
[[ 49  25]
 [ 11 134]]
Score key: 0.8824
✅ Saved new best.

--- Epoch 56/80 ---


Train: 100%|██████████| 291/291 [01:50<00:00,  2.62it/s, loss=1.11]

Train loss: 0.5457





Val avg P(AD) ~= 0.4134 | P(CN) ~= 0.5866

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7991 | F1(CN+):0.8429 | AUC(CN prob):0.8759
Confusion:
 [[ 57  17]
 [ 27 118]]
              precision    recall  f1-score   support

          AD       0.68      0.77      0.72        74
          CN       0.87      0.81      0.84       145

    accuracy                           0.80       219
   macro avg       0.78      0.79      0.78       219
weighted avg       0.81      0.80      0.80       219

P(AD) q=[5,25,50,75,95]: [0.018 0.04  0.199 0.936 0.984]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8845 at th=0.90
[[ 50  24]
 [ 11 134]]
Score key: 0.8851
✅ Saved new best.

--- Epoch 57/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.62it/s, loss=0.0304]

Train loss: 0.5351





Val avg P(AD) ~= 0.4060 | P(CN) ~= 0.5940

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7991 | F1(CN+):0.8429 | AUC(CN prob):0.8763
Confusion:
 [[ 57  17]
 [ 27 118]]
              precision    recall  f1-score   support

          AD       0.68      0.77      0.72        74
          CN       0.87      0.81      0.84       145

    accuracy                           0.80       219
   macro avg       0.78      0.79      0.78       219
weighted avg       0.81      0.80      0.80       219

P(AD) q=[5,25,50,75,95]: [0.018 0.038 0.163 0.935 0.983]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8814 at th=0.80
[[ 54  20]
 [ 15 130]]
Score key: 0.8856
✅ Saved new best.

--- Epoch 58/80 ---


Train: 100%|██████████| 291/291 [01:52<00:00,  2.59it/s, loss=0.254]

Train loss: 0.5500





Val avg P(AD) ~= 0.3991 | P(CN) ~= 0.6009

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8082 | F1(CN+):0.8500 | AUC(CN prob):0.8774
Confusion:
 [[ 58  16]
 [ 26 119]]
              precision    recall  f1-score   support

          AD       0.69      0.78      0.73        74
          CN       0.88      0.82      0.85       145

    accuracy                           0.81       219
   macro avg       0.79      0.80      0.79       219
weighted avg       0.82      0.81      0.81       219

P(AD) q=[5,25,50,75,95]: [0.018 0.037 0.155 0.924 0.983]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8784 at th=0.78
[[ 53  21]
 [ 15 130]]
Score key: 0.8868
✅ Saved new best.

--- Epoch 59/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.61it/s, loss=0.00471]

Train loss: 0.4423





Val avg P(AD) ~= 0.3929 | P(CN) ~= 0.6071

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8037 | F1(CN+):0.8470 | AUC(CN prob):0.8765
Confusion:
 [[ 57  17]
 [ 26 119]]
              precision    recall  f1-score   support

          AD       0.69      0.77      0.73        74
          CN       0.88      0.82      0.85       145

    accuracy                           0.80       219
   macro avg       0.78      0.80      0.79       219
weighted avg       0.81      0.80      0.81       219

P(AD) q=[5,25,50,75,95]: [0.017 0.034 0.136 0.92  0.983]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8808 at th=0.85
[[ 50  24]
 [ 12 133]]
Score key: 0.8857

--- Epoch 60/80 ---


Train: 100%|██████████| 291/291 [01:50<00:00,  2.63it/s, loss=0.548]

Train loss: 0.5215





Val avg P(AD) ~= 0.3884 | P(CN) ~= 0.6116

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8037 | F1(CN+):0.8481 | AUC(CN prob):0.8756
Confusion:
 [[ 56  18]
 [ 25 120]]
              precision    recall  f1-score   support

          AD       0.69      0.76      0.72        74
          CN       0.87      0.83      0.85       145

    accuracy                           0.80       219
   macro avg       0.78      0.79      0.79       219
weighted avg       0.81      0.80      0.81       219

P(AD) q=[5,25,50,75,95]: [0.016 0.033 0.13  0.909 0.982]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8787 at th=0.88
[[ 48  26]
 [ 11 134]]
Score key: 0.8849

--- Epoch 61/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.60it/s, loss=0.00589]

Train loss: 0.4922





Val avg P(AD) ~= 0.3832 | P(CN) ~= 0.6168

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8037 | F1(CN+):0.8481 | AUC(CN prob):0.8741
Confusion:
 [[ 56  18]
 [ 25 120]]
              precision    recall  f1-score   support

          AD       0.69      0.76      0.72        74
          CN       0.87      0.83      0.85       145

    accuracy                           0.80       219
   macro avg       0.78      0.79      0.79       219
weighted avg       0.81      0.80      0.81       219

P(AD) q=[5,25,50,75,95]: [0.016 0.03  0.118 0.904 0.982]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8787 at th=0.88
[[ 48  26]
 [ 11 134]]
Score key: 0.8834

--- Epoch 62/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.61it/s, loss=0.00811]


Train loss: 0.4646
Val avg P(AD) ~= 0.3797 | P(CN) ~= 0.6203

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8082 | F1(CN+):0.8521 | AUC(CN prob):0.8733
Confusion:
 [[ 56  18]
 [ 24 121]]
              precision    recall  f1-score   support

          AD       0.70      0.76      0.73        74
          CN       0.87      0.83      0.85       145

    accuracy                           0.81       219
   macro avg       0.79      0.80      0.79       219
weighted avg       0.81      0.81      0.81       219

P(AD) q=[5,25,50,75,95]: [0.015 0.029 0.109 0.898 0.981]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8803 at th=0.90
[[ 46  28]
 [  9 136]]
Score key: 0.8826

--- Epoch 63/80 ---


Train: 100%|██████████| 291/291 [01:52<00:00,  2.58it/s, loss=0.164]

Train loss: 0.4861





Val avg P(AD) ~= 0.3733 | P(CN) ~= 0.6267

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8128 | F1(CN+):0.8561 | AUC(CN prob):0.8703
Confusion:
 [[ 56  18]
 [ 23 122]]
              precision    recall  f1-score   support

          AD       0.71      0.76      0.73        74
          CN       0.87      0.84      0.86       145

    accuracy                           0.81       219
   macro avg       0.79      0.80      0.79       219
weighted avg       0.82      0.81      0.81       219

P(AD) q=[5,25,50,75,95]: [0.014 0.028 0.097 0.89  0.98 ]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8860 at th=0.88
[[ 48  26]
 [  9 136]]
Score key: 0.8796

--- Epoch 64/80 ---


Train: 100%|██████████| 291/291 [01:52<00:00,  2.60it/s, loss=0.00831]

Train loss: 0.3675





Val avg P(AD) ~= 0.3691 | P(CN) ~= 0.6309

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8174 | F1(CN+):0.8601 | AUC(CN prob):0.8692
Confusion:
 [[ 56  18]
 [ 22 123]]
              precision    recall  f1-score   support

          AD       0.72      0.76      0.74        74
          CN       0.87      0.85      0.86       145

    accuracy                           0.82       219
   macro avg       0.80      0.80      0.80       219
weighted avg       0.82      0.82      0.82       219

P(AD) q=[5,25,50,75,95]: [0.014 0.026 0.09  0.883 0.979]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8860 at th=0.88
[[ 48  26]
 [  9 136]]
Score key: 0.8786

--- Epoch 65/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.61it/s, loss=0.19]

Train loss: 0.4445





Val avg P(AD) ~= 0.3628 | P(CN) ~= 0.6372

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8219 | F1(CN+):0.8641 | AUC(CN prob):0.8673
Confusion:
 [[ 56  18]
 [ 21 124]]
              precision    recall  f1-score   support

          AD       0.73      0.76      0.74        74
          CN       0.87      0.86      0.86       145

    accuracy                           0.82       219
   macro avg       0.80      0.81      0.80       219
weighted avg       0.82      0.82      0.82       219

P(AD) q=[5,25,50,75,95]: [0.012 0.024 0.073 0.875 0.978]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8852 at th=0.82
[[ 49  25]
 [ 10 135]]
Score key: 0.8768

--- Epoch 66/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.60it/s, loss=0.674]

Train loss: 0.3965





Val avg P(AD) ~= 0.3594 | P(CN) ~= 0.6406

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8219 | F1(CN+):0.8641 | AUC(CN prob):0.8674
Confusion:
 [[ 56  18]
 [ 21 124]]
              precision    recall  f1-score   support

          AD       0.73      0.76      0.74        74
          CN       0.87      0.86      0.86       145

    accuracy                           0.82       219
   macro avg       0.80      0.81      0.80       219
weighted avg       0.82      0.82      0.82       219

P(AD) q=[5,25,50,75,95]: [0.012 0.022 0.069 0.874 0.978]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8852 at th=0.82
[[ 49  25]
 [ 10 135]]
Score key: 0.8769

--- Epoch 67/80 ---


Train: 100%|██████████| 291/291 [01:52<00:00,  2.58it/s, loss=0.00674]

Train loss: 0.4281





Val avg P(AD) ~= 0.3544 | P(CN) ~= 0.6456

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8265 | F1(CN+):0.8681 | AUC(CN prob):0.8672
Confusion:
 [[ 56  18]
 [ 20 125]]
              precision    recall  f1-score   support

          AD       0.74      0.76      0.75        74
          CN       0.87      0.86      0.87       145

    accuracy                           0.83       219
   macro avg       0.81      0.81      0.81       219
weighted avg       0.83      0.83      0.83       219

P(AD) q=[5,25,50,75,95]: [0.011 0.021 0.063 0.874 0.98 ]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8852 at th=0.82
[[ 49  25]
 [ 10 135]]
Score key: 0.8767

--- Epoch 68/80 ---


Train: 100%|██████████| 291/291 [01:54<00:00,  2.55it/s, loss=0.517]

Train loss: 0.4164





Val avg P(AD) ~= 0.3549 | P(CN) ~= 0.6451

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8265 | F1(CN+):0.8681 | AUC(CN prob):0.8672
Confusion:
 [[ 56  18]
 [ 20 125]]
              precision    recall  f1-score   support

          AD       0.74      0.76      0.75        74
          CN       0.87      0.86      0.87       145

    accuracy                           0.83       219
   macro avg       0.81      0.81      0.81       219
weighted avg       0.83      0.83      0.83       219

P(AD) q=[5,25,50,75,95]: [0.01  0.02  0.062 0.884 0.981]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8852 at th=0.82
[[ 49  25]
 [ 10 135]]
Score key: 0.8767

--- Epoch 69/80 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.62it/s, loss=0.00386]

Train loss: 0.3992





Val avg P(AD) ~= 0.3525 | P(CN) ~= 0.6475

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8265 | F1(CN+):0.8681 | AUC(CN prob):0.8668
Confusion:
 [[ 56  18]
 [ 20 125]]
              precision    recall  f1-score   support

          AD       0.74      0.76      0.75        74
          CN       0.87      0.86      0.87       145

    accuracy                           0.83       219
   macro avg       0.81      0.81      0.81       219
weighted avg       0.83      0.83      0.83       219

P(AD) q=[5,25,50,75,95]: [0.01  0.02  0.06  0.889 0.982]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8852 at th=0.82
[[ 49  25]
 [ 10 135]]
Score key: 0.8763

--- Epoch 70/80 ---


Train: 100%|██████████| 291/291 [01:50<00:00,  2.64it/s, loss=0.0518]

Train loss: 0.4367





Val avg P(AD) ~= 0.3509 | P(CN) ~= 0.6491

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8265 | F1(CN+):0.8681 | AUC(CN prob):0.8669
Confusion:
 [[ 56  18]
 [ 20 125]]
              precision    recall  f1-score   support

          AD       0.74      0.76      0.75        74
          CN       0.87      0.86      0.87       145

    accuracy                           0.83       219
   macro avg       0.81      0.81      0.81       219
weighted avg       0.83      0.83      0.83       219

P(AD) q=[5,25,50,75,95]: [0.009 0.019 0.062 0.892 0.981]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8852 at th=0.80
[[ 49  25]
 [ 10 135]]
Score key: 0.8764
⏹️ Early stopping.

=== BEST (reloaded) ===

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8082 | F1(CN+):0.8500 | AUC(CN prob):0.8774
Confusion:
 [[ 58  16]
 [ 26 119]]
              precision    recall  f1-score   support

          AD       0.69      0.78      0.73        74
          CN       0.88      0.82      0.85       145

    acc

Train: 100%|██████████| 291/291 [01:50<00:00,  2.64it/s, loss=2.02]

Train loss: 1.1847





Val avg P(AD) ~= 0.4376 | P(CN) ~= 0.5624

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8128 | F1(CN+):0.8520 | AUC(CN prob):0.8763
Confusion:
 [[ 60  14]
 [ 27 118]]
              precision    recall  f1-score   support

          AD       0.69      0.81      0.75        74
          CN       0.89      0.81      0.85       145

    accuracy                           0.81       219
   macro avg       0.79      0.81      0.80       219
weighted avg       0.82      0.81      0.82       219

P(AD) q=[5,25,50,75,95]: [0.024 0.056 0.285 0.939 0.983]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8808 at th=0.90
[[ 50  24]
 [ 12 133]]
Score key: 0.8857
✅ Saved new best.

--- Epoch 2/20 ---


Train: 100%|██████████| 291/291 [01:50<00:00,  2.63it/s, loss=0.314]


Train loss: 1.0624
Val avg P(AD) ~= 0.4672 | P(CN) ~= 0.5328

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7763 | F1(CN+):0.8165 | AUC(CN prob):0.8719
Confusion:
 [[ 61  13]
 [ 36 109]]
              precision    recall  f1-score   support

          AD       0.63      0.82      0.71        74
          CN       0.89      0.75      0.82       145

    accuracy                           0.78       219
   macro avg       0.76      0.79      0.76       219
weighted avg       0.80      0.78      0.78       219

P(AD) q=[5,25,50,75,95]: [0.032 0.074 0.362 0.943 0.982]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8704 at th=0.92
[[ 49  25]
 [ 14 131]]
Score key: 0.8808

--- Epoch 3/20 ---


Train: 100%|██████████| 291/291 [01:49<00:00,  2.67it/s, loss=0.195]

Train loss: 0.7399





Val avg P(AD) ~= 0.4893 | P(CN) ~= 0.5107

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7671 | F1(CN+):0.8061 | AUC(CN prob):0.8675
Confusion:
 [[ 62  12]
 [ 39 106]]
              precision    recall  f1-score   support

          AD       0.61      0.84      0.71        74
          CN       0.90      0.73      0.81       145

    accuracy                           0.77       219
   macro avg       0.76      0.78      0.76       219
weighted avg       0.80      0.77      0.77       219

P(AD) q=[5,25,50,75,95]: [0.044 0.1   0.441 0.943 0.982]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8649 at th=0.90
[[ 51  23]
 [ 17 128]]
Score key: 0.8763

--- Epoch 4/20 ---


Train: 100%|██████████| 291/291 [01:49<00:00,  2.65it/s, loss=0.0638]

Train loss: 0.7258





Val avg P(AD) ~= 0.5112 | P(CN) ~= 0.4888

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7306 | F1(CN+):0.7686 | AUC(CN prob):0.8659
Confusion:
 [[62 12]
 [47 98]]
              precision    recall  f1-score   support

          AD       0.57      0.84      0.68        74
          CN       0.89      0.68      0.77       145

    accuracy                           0.73       219
   macro avg       0.73      0.76      0.72       219
weighted avg       0.78      0.73      0.74       219

P(AD) q=[5,25,50,75,95]: [0.057 0.128 0.481 0.94  0.981]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8649 at th=0.90
[[ 51  23]
 [ 17 128]]
Score key: 0.8744

--- Epoch 5/20 ---


Train: 100%|██████████| 291/291 [01:49<00:00,  2.66it/s, loss=0.811]

Train loss: 0.5728





Val avg P(AD) ~= 0.5243 | P(CN) ~= 0.4757

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7169 | F1(CN+):0.7540 | AUC(CN prob):0.8665
Confusion:
 [[62 12]
 [50 95]]
              precision    recall  f1-score   support

          AD       0.55      0.84      0.67        74
          CN       0.89      0.66      0.75       145

    accuracy                           0.72       219
   macro avg       0.72      0.75      0.71       219
weighted avg       0.77      0.72      0.72       219

P(AD) q=[5,25,50,75,95]: [0.065 0.151 0.521 0.937 0.981]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8684 at th=0.92
[[ 47  27]
 [ 13 132]]
Score key: 0.8748

--- Epoch 6/20 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.62it/s, loss=0.594]

Train loss: 0.6689





Val avg P(AD) ~= 0.5293 | P(CN) ~= 0.4707

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7215 | F1(CN+):0.7570 | AUC(CN prob):0.8646
Confusion:
 [[63 11]
 [50 95]]
              precision    recall  f1-score   support

          AD       0.56      0.85      0.67        74
          CN       0.90      0.66      0.76       145

    accuracy                           0.72       219
   macro avg       0.73      0.75      0.72       219
weighted avg       0.78      0.72      0.73       219

P(AD) q=[5,25,50,75,95]: [0.069 0.156 0.527 0.933 0.979]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8684 at th=0.92
[[ 47  27]
 [ 13 132]]
Score key: 0.8729

--- Epoch 7/20 ---


Train: 100%|██████████| 291/291 [01:49<00:00,  2.66it/s, loss=0.0765]

Train loss: 0.6507





Val avg P(AD) ~= 0.5236 | P(CN) ~= 0.4764

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7306 | F1(CN+):0.7686 | AUC(CN prob):0.8621
Confusion:
 [[62 12]
 [47 98]]
              precision    recall  f1-score   support

          AD       0.57      0.84      0.68        74
          CN       0.89      0.68      0.77       145

    accuracy                           0.73       219
   macro avg       0.73      0.76      0.72       219
weighted avg       0.78      0.73      0.74       219

P(AD) q=[5,25,50,75,95]: [0.066 0.146 0.498 0.931 0.979]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8713 at th=0.92
[[ 48  26]
 [ 13 132]]
Score key: 0.8705

--- Epoch 8/20 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.61it/s, loss=0.0101]

Train loss: 0.6603





Val avg P(AD) ~= 0.5087 | P(CN) ~= 0.4913

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7397 | F1(CN+):0.7799 | AUC(CN prob):0.8613
Confusion:
 [[ 61  13]
 [ 44 101]]
              precision    recall  f1-score   support

          AD       0.58      0.82      0.68        74
          CN       0.89      0.70      0.78       145

    accuracy                           0.74       219
   macro avg       0.73      0.76      0.73       219
weighted avg       0.78      0.74      0.75       219

P(AD) q=[5,25,50,75,95]: [0.058 0.125 0.458 0.929 0.979]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8656 at th=0.92
[[ 46  28]
 [ 13 132]]
Score key: 0.8699

--- Epoch 9/20 ---


Train: 100%|██████████| 291/291 [01:51<00:00,  2.61it/s, loss=0.0402]


Train loss: 0.6276
Val avg P(AD) ~= 0.4974 | P(CN) ~= 0.5026

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.7489 | F1(CN+):0.7893 | AUC(CN prob):0.8603
Confusion:
 [[ 61  13]
 [ 42 103]]
              precision    recall  f1-score   support

          AD       0.59      0.82      0.69        74
          CN       0.89      0.71      0.79       145

    accuracy                           0.75       219
   macro avg       0.74      0.77      0.74       219
weighted avg       0.79      0.75      0.76       219

P(AD) q=[5,25,50,75,95]: [0.053 0.109 0.423 0.926 0.98 ]
--- Threshold sweep on P(AD) ---
Best F1(CN+)=0.8644 at th=0.95
[[ 39  35]
 [  8 137]]
Score key: 0.8690
⏹️ Early stopping.

=== BEST (reloaded) ===

=== Eval (BCE single-logit, th=0.50) ===
Acc:0.8128 | F1(CN+):0.8520 | AUC(CN prob):0.8763
Confusion:
 [[ 60  14]
 [ 27 118]]
              precision    recall  f1-score   support

          AD       0.69      0.81      0.75        74
          CN       0.89      0.81      0.85

(0.8350515463917526,
 np.float64(0.8416219439475254),
 array([[ 56,   9],
        [ 23, 106]]))