# Tripleknock Revision — Pair-disjoint Cross-validation (E. coli iML1515)

本 Notebook **只做双基因 pair 不重复（Pair-disjoint）** 的 5-fold 交叉验证，用来回应 Reviewer 关于 **gene-pair leakage** 的质疑。

✅ 核心思想（Strict pair-disjoint / 严格双基因不重复）  
对每个无序 gene pair（AB）分配一个 fold：`fold(pair) = BLAKE2b(sorted(pair), PAIR_HASH_SEED) % K`（稳定哈希，不依赖 Python 内置 `hash`，跨进程可复现）  
一个三基因组合 (A,B,C) 只在以下条件成立时被允许进入某个 fold：

- `fold(AB) == fold(AC) == fold(BC)`

否则该 triple 会被丢弃（跨 fold 的 triple 直接丢弃）。

这样做的好处：

- ✅ Train/Test 之间 **不会共享任何 gene pair**
- ✅ 不会出现 “ABC 导致 AB/AC/BC 分到不同 fold 的矛盾” —— 因为这种 triple 会直接被过滤掉（代价：只有约 1/K² 的 triple 会被保留，K=5 时约 4%）

---

## 事先准备好的对象
- `two_mer_dict`：每个基因的 400 维 2-mer 特征向量（代码里已生成）
- `ae1_2({g1,g2,g3})`：输出 (3,400) tensor（代码里已定义并测试成功）
- `device`：torch device（有 GPU 时为 `cuda:0`，否则回退到 CPU）


In [1]:
# =========================
# Part 0: Imports + Global Config
# =========================
import os, time, random, traceback, math, gc, hashlib
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

from sklearn.metrics import (
    roc_auc_score, confusion_matrix, classification_report
)
from sklearn.model_selection import train_test_split


# ---- device ----
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device:', device)

# =========================
# Adjustable parameters
# =========================

# ---- Data files ----
FILE_PARTS = [
    '/data1/xpgeng/cross_pathogen/FBA/iML1515_parts/iML1515-1.csv',
    # '/data1/xpgeng/cross_pathogen/FBA/iML1515_parts/iML1515-2.csv',
]

# ---- Reproducibility ----
SEED = 42
N_FOLDS = 5
PAIR_HASH_SEED = 2026   # mixed into the pair hash; changing it re-assigns every pair's fold

# ---- Streaming CSV read ----
CSV_CHUNK_SIZE = 200_000

# ---- Build fold pools (sampling) ----
# Triples collected per fold pool. Pools are filled with the first qualifying
# triples in file-read order until each pool reaches this size.
POOL_TARGET_PER_FOLD = 400_000                             ### change!!! 1

# ---- Sampling sizes inside each CV fold ----
# N_FOLDS pools of POOL_TARGET_PER_FOLD triples each (5 x 400k = 2M with the
# values above). In each CV iteration one pool is Test and the remaining
# N_FOLDS - 1 pools form the TrainPool (4 x 400k = 1.6M), which is then split
# into Train/Val by VAL_FRACTION.
TRAIN_SIZE_PER_FOLD = None      # None => use full TrainPool
VAL_FRACTION        = 0.10      # 10% of TrainPool as validation
# Tied to the pool size so the two cannot drift apart; set a smaller number
# explicitly to stratified-subsample the test pool instead.
TEST_SIZE_PER_FOLD  = POOL_TARGET_PER_FOLD

# NOTE(review): not referenced by any cell visible here.
USE_STRATIFIED_SAMPLE = True

# ---- Chunk training to avoid GPU OOM ----
TRAIN_CHUNK_SIZE = 20_000
BATCH_SIZE = 512

# ---- Optimization ----
LR = 5e-4
WEIGHT_DECAY = 1e-3
MAX_EPOCHS = 12
PATIENCE = 3          # epochs without val-AUC improvement before early stopping
MIN_DELTA = 5e-4      # minimum val-AUC gain that counts as an improvement

# ---- Model architecture (unchanged; only dropout adjustable) ----
DROPOUT_RATE = 0.5

# ---- Feature options ----
REST_SCALE = 0.1      # multiplier applied to the ae1_2 ("rest of genome") block
NORM_MODE = 'block'   # 'block' | 'per_sample' | 'none'

# ---- Threshold search options ----
THRESH_METRIC = 'youden'   # 'youden' | 'balanced_acc' | 'f1_pos'
MIN_PRECISION_POS = None
THRESH_MIN  = 0.05
THRESH_MAX  = 0.95
THRESH_STEP = 0.005

# ---- Plot saving ----
SAVE_PLOTS = True
PLOT_PREFIX = "pair_gene_disjoint_cv_iML1515-1-2000k_data"                             ### change!!! 3

# -------------------------
# Logging
# -------------------------
workdir = os.getcwd()
log_file = os.path.join(workdir, f'{PLOT_PREFIX}_log.txt')
err_file = os.path.join(workdir, f'{PLOT_PREFIX}_err.txt')

def _write_log_line(path: str, msg: str) -> str:
    """Append a timestamped ``msg`` to ``path`` and return the formatted line.

    The file is opened in append mode with explicit UTF-8 encoding. Log
    messages contain non-ASCII text (e.g. the check-mark emoji in per-fold test
    summaries), which would raise UnicodeEncodeError under a non-UTF-8 locale
    default encoding.
    """
    ts = time.strftime('%Y-%m-%d %H:%M:%S')
    line = f'[{ts}] {msg}'
    with open(path, 'a', encoding='utf-8') as f:
        f.write(line + "\n")
    return line

def log_print(msg: str):
    """Timestamp ``msg``, append it to the global ``log_file`` and echo it to stdout."""
    print(_write_log_line(log_file, msg))

def log_error(msg: str):
    """Timestamp ``msg``, append it to the global ``err_file`` and echo it to stdout."""
    print(_write_log_line(err_file, msg))

# ---- Reproducibility: seed Python, NumPy and torch (CPU, then every GPU) ----
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

log_print('Config loaded.')

device: cuda:0
[2026-02-15 15:43:37] Config loaded.


## Helper utilities / 工具函数

包含：

- `stratified_subsample()`：分层抽样保持 0/1 比例  
- `find_best_threshold()`：在 val 上找阈值（youden / balanced_acc / f1_pos）  
- `stable_pair_fold()`：为每个 pair 赋予 fold（稳定、可复现，不依赖 Python 内置 hash）


In [2]:
# =========================
# Part 1: Utilities
# =========================

def stratified_subsample(df: pd.DataFrame, n: int, seed: int = 0):
    """Draw ``n`` distinct rows from ``df`` while keeping the share of ``y == 1`` rows.

    Parameters
    ----------
    df : DataFrame with a label column ``y``.
    n : target number of rows. ``None``, ``n <= 0`` or ``n >= len(df)`` returns
        the whole frame (with its index reset).
    seed : seeds both the NumPy generator and the pandas shuffles.

    Returns
    -------
    A new DataFrame with a fresh RangeIndex; no row is repeated.

    Positives (y == 1) and negatives (y == 0) are drawn separately, in
    proportion to their share of ``len(df)``. If that yields fewer than ``n``
    rows (possible when ``y`` holds values other than 0/1), the shortfall is
    topped up from rows that were not selected yet.
    """
    if n is None or n <= 0 or n >= len(df):
        return df.reset_index(drop=True)

    rng = np.random.default_rng(seed)
    # Row positions (not index labels): robust to non-unique / non-range indexes.
    pos_rows = np.flatnonzero((df['y'] == 1).to_numpy(dtype=bool, na_value=False))
    neg_rows = np.flatnonzero((df['y'] == 0).to_numpy(dtype=bool, na_value=False))

    if len(pos_rows) == 0 or len(neg_rows) == 0:
        return df.sample(n=n, random_state=seed).reset_index(drop=True)

    pos_n = int(n * (len(pos_rows) / len(df)))
    neg_n = n - pos_n

    picked = np.concatenate([
        rng.choice(pos_rows, size=min(pos_n, len(pos_rows)), replace=False),
        rng.choice(neg_rows, size=min(neg_n, len(neg_rows)), replace=False),
    ])
    out = df.iloc[picked].sample(frac=1.0, random_state=seed)

    if len(out) < n:
        # Top up only from rows NOT already picked. (Dropping ``out.index`` after
        # a reset_index would remove labels 0..k-1 instead, and could re-draw
        # rows that are already in ``out``.)
        unpicked = np.ones(len(df), dtype=bool)
        unpicked[picked] = False
        remain = df.iloc[np.flatnonzero(unpicked)]
        extra = remain.sample(n=min(n - len(out), len(remain)), random_state=seed)
        out = pd.concat([out, extra], axis=0).sample(frac=1.0, random_state=seed)

    return out.reset_index(drop=True)

def find_best_threshold(
    y_true,
    y_prob,
    metric="youden",
    min_precision_pos=None,
    t_min=0.01,
    t_max=0.99,
    t_step=0.01,
):
    """Grid-search the probability cut-off that maximises a validation metric.

    Thresholds ``t_min, t_min + t_step, ...`` up to ``t_max`` are tried in
    increasing order; a sample is predicted positive when ``y_prob >= cut``.

    metric:
        "youden"       -> TPR - FPR
        "balanced_acc" -> (TPR + TNR) / 2
        "f1_pos"       -> F1 of the positive class
        anything else  -> treated as "youden"
    min_precision_pos:
        if not None, cut-offs whose positive-class precision is below this
        value are skipped.

    Returns:
        dict with keys threshold, score, f1_pos, precision_pos, recall_pos,
        tn, fp, fn, tp for the first cut-off reaching the highest score, or
        None when every cut-off was filtered out.
    """
    labels = np.asarray(y_true).astype(int)
    probs = np.asarray(y_prob).astype(float)
    eps = 1e-12
    grid = np.arange(t_min, t_max + eps, t_step)

    best = None
    for cut in grid:
        predicted = (probs >= cut).astype(int)
        tn, fp, fn, tp = confusion_matrix(labels, predicted, labels=[0, 1]).ravel()

        tpr = tp / (tp + fn + eps)
        fpr = fp / (fp + tn + eps)
        tnr = tn / (tn + fp + eps)
        prec = tp / (tp + fp + eps)
        f1 = 2 * prec * tpr / (prec + tpr + eps)

        if min_precision_pos is not None and prec < min_precision_pos:
            continue

        if metric == "balanced_acc":
            score = 0.5 * (tpr + tnr)
        elif metric == "f1_pos":
            score = f1
        else:  # "youden" and any unrecognised name
            score = tpr - fpr

        # Strict improvement only: ties keep the earlier (lower) threshold.
        if best is not None and not score > best["score"]:
            continue

        best = {
            "threshold": float(cut),
            "score": float(score),
            "f1_pos": float(f1),
            "precision_pos": float(prec),
            "recall_pos": float(tpr),
            "tn": int(tn),
            "fp": int(fp),
            "fn": int(fn),
            "tp": int(tp),
        }

    return best

# ---- Stable pair -> fold assignment ----
# Memo: (gene_a, gene_b, K, seed) -> fold. K and seed are part of the key so a
# call with different settings can never return a fold computed for other ones.
_pair_cache = {}

def stable_pair_fold(gA: str, gB: str, K: int = 5, seed: int = 2026) -> int:
    """Assign an unordered gene pair to one of ``K`` folds, reproducibly.

    The two ids are stripped and sorted, so (A, B) and (B, A) map to the same
    fold. The fold is ``BLAKE2b("a|b|seed", 8-byte digest)`` read as a
    little-endian unsigned int, modulo ``K``. Unlike the built-in ``hash``, this
    does not depend on PYTHONHASHSEED, so it is identical across processes.

    Returns:
        int in ``[0, K)``.
    """
    a, b = sorted((str(gA).strip(), str(gB).strip()))

    key = (a, b, K, seed)
    cached = _pair_cache.get(key)
    if cached is not None:
        return cached

    digest = hashlib.blake2b(f"{a}|{b}|{seed}".encode("utf-8"), digest_size=8).digest()
    fold = int.from_bytes(digest, byteorder="little", signed=False) % K
    _pair_cache[key] = fold
    return fold

## Feature preparation / 特征准备

这部分直接复用原来的 AE 与 2-mer 构建逻辑（保持不变）。


In [3]:
# =========================
# Part 2: Build two_mer_dict (2-mer features) + load Autoencoders + define ae1_2
# =========================

# ---- (A) Build two_mer_dict from FASTA ----
# If you already have two_mer_dict, set BUILD_2MER=False to skip this step.
# NOTE(review): `all_genes` is also defined only inside this branch, and ae1_2
# reads it, so with BUILD_2MER=False it has to be provided some other way.
BUILD_2MER = True

if BUILD_2MER:
    from Bio import SeqIO
    from collections import Counter

    fasta_path = '/data1/xpgeng/cross_pathogen/autoencoder/E.coli.tag_seq.fasta'

    def read_fasta(fp):
        """Return {record.id: sequence string} for every record in FASTA file `fp`."""
        gene_sequence_dict = {}
        for record in SeqIO.parse(fp, 'fasta'):
            gene_sequence_dict[record.id] = str(record.seq)
        return gene_sequence_dict

    # Gene id -> sequence (FASTA order); all_genes is the id universe used by ae1_2.
    gene_sequence_dict = read_fasta(fasta_path)
    all_genes = set(gene_sequence_dict.keys())

    print('Total genes in FASTA:', len(all_genes))
    print('Example:', list(gene_sequence_dict.items())[:1])

    # 20 standard amino acids -> 20 x 20 = 400 ordered 2-mers, each with a fixed column index.
    standard_amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
    all_2mers = [a + b for a in standard_amino_acids for b in standard_amino_acids]
    two_mer_index = {two_mer: idx for idx, two_mer in enumerate(all_2mers)}

    # Gene id -> float32 (400,) vector of 2-mer relative frequencies.
    two_mer_dict = {}

    for gene, sequence in tqdm(gene_sequence_dict.items(), desc='Building two_mer_dict'):
        # Keep only the 20 standard (uppercase) residue letters. 2-mers are then
        # counted on the filtered string, so a removed residue joins its two
        # neighbours into a new 2-mer.
        sequence = ''.join([aa for aa in sequence if aa in standard_amino_acids])

        # Too short to contain any 2-mer -> all-zero vector.
        if len(sequence) < 2:
            two_mer_dict[gene] = np.zeros(400, dtype=np.float32)
            continue

        # Overlapping 2-mers; the total is len(sequence) - 1.
        two_mer_counts = Counter(sequence[i:i+2] for i in range(len(sequence)-1))
        total_two_mers = sum(two_mer_counts.values())

        feature_vector = np.zeros(400, dtype=np.float32)
        for two_mer, count in two_mer_counts.items():
            # After the filtering above every 2-mer is in the index; the guard is defensive.
            idx = two_mer_index.get(two_mer)
            if idx is not None:
                feature_vector[idx] = count / total_two_mers

        two_mer_dict[gene] = feature_vector

    # Quick visual check of the first few vectors.
    for gene, vec in list(two_mer_dict.items())[:3]:
        print(gene, vec[:10])

# ---- (B) Load Autoencoders and define ae1_2 ----
# Architectures and weight loading are kept identical to the original pipeline.
# ae1_2 (below) uses the global `model` (AE1) and `model2` (AE2) created here.

LOAD_AUTOENCODERS = True

if LOAD_AUTOENCODERS:

    class Autoencoder(torch.nn.Module):
        """AE1: 400-dim per-gene 2-mer vector -> 3-dim code -> 400-dim reconstruction."""
        def __init__(self):
            super(Autoencoder, self).__init__()
            self.encoder = torch.nn.Sequential(
                torch.nn.Linear(400, 256),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.35),
                torch.nn.Linear(256, 128),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.35),
                torch.nn.Linear(128, 3),
            )
            self.decoder = torch.nn.Sequential(
                torch.nn.Linear(3, 128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, 256),
                torch.nn.ReLU(),
                torch.nn.Linear(256, 400),
            )

        def forward(self, x):
            encoded = self.encoder(x)
            decoded = self.decoder(encoded)
            return decoded


    class Autoencoder2(torch.nn.Module):
        """AE2: 4304-dim input -> 400-dim code -> 4304-dim reconstruction.

        In ae1_2 its input rows are one AE1 code dimension across 4304 gene slots.
        """
        def __init__(self):
            super(Autoencoder2, self).__init__()
            self.encoder = torch.nn.Sequential(
                torch.nn.Linear(4304, 3000),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.2),
                torch.nn.Linear(3000, 1000),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.3),
                torch.nn.Linear(1000, 400),
            )
            self.decoder = torch.nn.Sequential(
                torch.nn.Linear(400, 1000),
                torch.nn.ReLU(),
                torch.nn.Linear(1000, 3000),
                torch.nn.ReLU(),
                torch.nn.Linear(3000, 4304),
            )

        def forward(self, x):
            encoded = self.encoder(x)
            decoded = self.decoder(encoded)
            return decoded


    # NOTE: torch.load unpickles the checkpoint file; only load trusted files.
    # eval() switches the encoders' Dropout layers off for inference.
    model = Autoencoder().to(device)
    model.load_state_dict(torch.load('/data1/xpgeng/cross_pathogen/autoencoder/ae1_all_data_training.pth', map_location=device))
    model.eval()

    model2 = Autoencoder2().to(device)
    model2.load_state_dict(torch.load('/data1/xpgeng/cross_pathogen/autoencoder/ae2_all_data_training.pth', map_location=device))
    model2.eval()


@torch.no_grad()
def ae1_2(three_genes):
    """Encode the "rest of the genome" (every gene except the given ones).

    Pipeline:
      1. Stack the 2-mer vectors of all genes in ``all_genes`` not in
         ``three_genes`` and append 2 all-zero rows. With the 4305-gene FASTA
         and 3 excluded genes this gives 4304 rows, the width AE2 expects.
      2. AE1 encoder, applied row-wise             -> (4304, 3)
      3. Transpose                                 -> (3, 4304)
      4. AE2 encoder                               -> (3, 400)

    Args:
        three_genes: iterable of gene ids to exclude (normally a set of 3).
            Any iterable is accepted; previously it had to be a set.

    Returns:
        torch.Tensor of shape (3, 400) on ``device``.
    """
    excluded = set(three_genes)
    # Sorted => fixed, process-independent column order for AE2. The previous
    # ``list(all_genes - three_genes)`` followed set iteration order, which for
    # str keys depends on PYTHONHASHSEED and therefore changes between runs.
    # NOTE(review): confirm this matches the gene order AE2 was trained with.
    rest_genes = sorted(all_genes - excluded)

    rest_kmers = np.vstack([two_mer_dict[gene] for gene in rest_genes]).astype(np.float32)
    padded = np.vstack([rest_kmers, np.zeros((2, 400), dtype=np.float32)])

    per_gene_codes = model.encoder(torch.from_numpy(padded).to(device))
    # Transpose on the device; the old GPU -> CPU -> GPU round trip was redundant.
    return model2.encoder(per_gene_codes.T.contiguous())

log_print('two_mer_dict & ae1_2 ready.')


Total genes in FASTA: 4305
Example: [('b0001', 'MKRISTTITTTITITTGNGAG')]


Building two_mer_dict: 100%|█████████████████████████████████████████████| 4305/4305 [00:00<00:00, 7640.64it/s]


b0001 [0.   0.   0.   0.   0.   0.05 0.   0.   0.   0.  ]
b0002 [0.01587302 0.001221   0.00854701 0.01098901 0.002442   0.00854701
 0.         0.003663   0.00610501 0.00732601]
b0003 [0.01294498 0.00647249 0.00647249 0.01294498 0.         0.00647249
 0.00323625 0.00323625 0.00323625 0.01294498]
[2026-02-15 15:44:09] two_mer_dict & ae1_2 ready.


# Part A — Pair-disjoint CV (Strict unseen gene-pairs)

In [4]:
# =========================
# Part A1: Read CSV(s) and build Pair-disjoint fold pools (streaming)
# =========================

def build_pair_disjoint_fold_pools(
    file_parts,
    K=5,
    pair_seed=2026,
    pool_target=200_000,
    chunk_size=200_000,
):
    """Stream (g1, g2, g3, y) triples from CSV parts into K pair-disjoint fold pools.

    A triple is kept only when its three pairs AB, AC and BC all hash to the
    same fold (see ``stable_pair_fold``); it is then appended to that fold's
    pool unless the pool already holds ``pool_target`` triples. Triples are
    taken in file-read order, and reading stops as soon as every pool is full.

    Returns:
        list of K DataFrames with columns ['g1', 'g2', 'g3', 'y'].
    """
    buckets = [[] for _ in range(K)]
    n_kept = 0
    n_dropped = 0

    def fold_of(x, y):
        return stable_pair_fold(x, y, K=K, seed=pair_seed)

    def all_full():
        return all(len(bucket) >= pool_target for bucket in buckets)

    log_print("Reading CSV parts (streaming) and building pair-disjoint pools...")

    for path in file_parts:
        log_print(f"File: {path}")
        chunks = pd.read_csv(
            path,
            header=None,
            names=['g1','g2','g3','y'],
            chunksize=chunk_size,
            dtype={0:str, 1:str, 2:str, 3:'int8'},
            engine='c',
            low_memory=False
        )

        for chunk_no, chunk in enumerate(chunks, start=1):
            for col in ('g1', 'g2', 'g3'):
                chunk[col] = chunk[col].astype(str).str.strip()
            chunk['y'] = chunk['y'].astype('int8')

            for a, b, c, label in chunk.itertuples(index=False, name=None):
                fold = fold_of(a, b)
                # Short-circuits exactly like before: AC is checked before BC.
                if fold != fold_of(a, c) or fold != fold_of(b, c):
                    n_dropped += 1
                    continue

                bucket = buckets[fold]
                if len(bucket) < pool_target:
                    bucket.append((a, b, c, int(label)))
                    n_kept += 1

                if all_full():
                    break

            if chunk_no % 5 == 0:
                sizes = [len(bucket) for bucket in buckets]
                log_print(f"  chunk={chunk_no} | kept={n_kept:,} dropped={n_dropped:,} | pool_sizes={sizes}")

            if all_full():
                log_print("All fold pools reached target, stop reading more data.")
                break

        if all_full():
            break

    fold_dfs = [pd.DataFrame(bucket, columns=['g1','g2','g3','y']) for bucket in buckets]

    log_print(f"Pair-disjoint pools built. kept={n_kept:,}, dropped={n_dropped:,}")
    log_print("Final fold pool sizes: " + str([len(x) for x in fold_dfs]))
    return fold_dfs

pair_fold_pools = build_pair_disjoint_fold_pools(
    FILE_PARTS,
    K=N_FOLDS,
    pair_seed=PAIR_HASH_SEED,
    pool_target=POOL_TARGET_PER_FOLD,
    chunk_size=CSV_CHUNK_SIZE,
)

# Sanity check: positive-class share of every non-empty fold pool.
for pool_no, pool_df in enumerate(pair_fold_pools):
    if not len(pool_df):
        continue
    log_print(f"[FoldPool {pool_no}] y=1 ratio = {pool_df['y'].mean():.4f}")

[2026-02-15 15:45:55] Reading CSV parts (streaming) and building pair-disjoint pools...
[2026-02-15 15:45:55] File: /data1/xpgeng/cross_pathogen/FBA/iML1515_parts/iML1515-1.csv
[2026-02-15 15:45:59]   chunk=5 | kept=40,301 dropped=959,699 | pool_sizes=[7975, 8166, 8017, 8075, 8068]
[2026-02-15 15:46:03]   chunk=10 | kept=80,126 dropped=1,919,874 | pool_sizes=[15962, 16241, 15987, 15958, 15978]
[2026-02-15 15:46:06]   chunk=15 | kept=119,709 dropped=2,880,291 | pool_sizes=[23889, 24216, 23936, 23764, 23904]
[2026-02-15 15:46:09]   chunk=20 | kept=159,653 dropped=3,840,347 | pool_sizes=[31785, 32450, 31881, 31607, 31930]
[2026-02-15 15:46:12]   chunk=25 | kept=199,725 dropped=4,800,275 | pool_sizes=[39711, 40559, 39897, 39564, 39994]
[2026-02-15 15:46:16]   chunk=30 | kept=239,780 dropped=5,760,220 | pool_sizes=[47631, 48583, 47849, 47653, 48064]
[2026-02-15 15:46:19]   chunk=35 | kept=279,996 dropped=6,720,004 | pool_sizes=[55447, 56712, 56076, 55677, 56084]
[2026-02-15 15:46:22]   chun

## Part A2 — Train MLP + record train/val loss and val AUC

In [None]:
# =========================
# Part A2: Feature builder + MLP + CV training (Pair-disjoint)
# =========================

def zscore(vec, eps=1e-8):
    """Standardise ``vec`` to zero mean and unit (population) std; ``eps`` guards a zero std."""
    centered = vec - vec.mean()
    return centered / (vec.std() + eps)

@torch.no_grad()
def build_feature_vector(g1, g2, g3, rest_scale=0.05, norm_mode='block'):
    """Build the flat float32 feature vector for one gene triple.

    Layout: [2-mer vectors of g1, g2, g3 (3 x 400)] followed by the flattened
    ae1_2 encoding of the remaining genes, multiplied by ``rest_scale``.

    norm_mode:
        'block'      - z-score each of the two parts separately (before scaling);
        'per_sample' - z-score the whole concatenated vector;
        'none'       - no normalisation.

    Raises:
        ValueError: for any other ``norm_mode``. Previously a typo silently
            behaved like 'none'.
    """
    if norm_mode not in ('block', 'per_sample', 'none'):
        raise ValueError(
            f"norm_mode must be 'block', 'per_sample' or 'none', got {norm_mode!r}"
        )

    three = np.array([two_mer_dict[g] for g in [g1, g2, g3]], dtype=np.float32).flatten()
    rest = ae1_2({g1, g2, g3}).detach().cpu().numpy().astype(np.float32).flatten()

    if norm_mode == 'block':
        three = zscore(three)
        rest = zscore(rest)

    feat = np.concatenate([three, rest * rest_scale], axis=0).astype(np.float32)

    if norm_mode == 'per_sample':
        feat = zscore(feat)

    return feat

def build_XY_from_df(dfx: pd.DataFrame, rest_scale=0.05, norm_mode='block', desc='Build XY'):
    """Featurise every (g1, g2, g3) row of ``dfx``.

    Returns:
        X: float32 array of shape (len(dfx), 2400), one build_feature_vector row per triple.
        y: int64 array of shape (len(dfx),) taken from column 'y'.
    """
    rows = dfx[['g1', 'g2', 'g3']].values.tolist()
    labels = dfx['y'].values.astype(np.int64)

    features = np.zeros((len(rows), 2400), dtype=np.float32)
    for row_no, (a, b, c) in enumerate(tqdm(rows, desc=desc, leave=False)):
        features[row_no] = build_feature_vector(a, b, c, rest_scale=rest_scale, norm_mode=norm_mode)

    return features, labels

class MLP(nn.Module):
    """Three-layer perceptron that outputs a probability per sample.

    Linear(input_size -> hidden_size1) -> ReLU -> Dropout
    -> Linear(hidden_size1 -> hidden_size2) -> ReLU -> Dropout
    -> Linear(hidden_size2 -> output_size) -> Sigmoid

    Attribute names (fc1/fc2/fc3, ...) are part of the state_dict keys and are
    kept stable so existing checkpoints stay loadable.
    """

    def __init__(self, input_size=2400, hidden_size1=512, hidden_size2=256, output_size=1, dropout_rate=0.5):
        super().__init__()
        # Layers are created in the order fc1 -> fc2 -> fc3, so parameter
        # initialisation under a fixed torch seed is unchanged.
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(hidden_size2, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        hidden = self.dropout1(torch.relu(self.fc1(x)))
        hidden = self.dropout2(torch.relu(self.fc2(hidden)))
        return self.sigmoid(self.fc3(hidden))

def train_one_epoch_streaming(
    model, optimizer, criterion,
    df_train_pool: pd.DataFrame,
    batch_size=512,
    train_chunk_size=20000,
    rest_scale=0.05,
    norm_mode='block',
    fold=0,
    epoch=1,
):
    """Run one training epoch, building features on the fly chunk by chunk.

    The pool is shuffled with a generator seeded by (SEED, fold, epoch), then cut
    into chunks of ``train_chunk_size`` rows. Each chunk is featurised with
    build_XY_from_df, trained on in shuffled mini-batches and released before
    the next one, so only one chunk's features are held in memory at a time.
    Features are therefore rebuilt on every call (every epoch).

    Returns:
        Sample-weighted mean training loss of the epoch (0.0 for an empty pool).
    """
    model.train()

    order = np.arange(len(df_train_pool))
    np.random.default_rng(SEED + fold * 100000 + epoch).shuffle(order)

    loss_sum = 0.0
    n_seen = 0

    chunk_starts = range(0, len(order), train_chunk_size)
    for chunk_no, start in enumerate(chunk_starts, start=1):
        chunk_df = df_train_pool.iloc[order[start:start+train_chunk_size]].reset_index(drop=True)

        X_chunk, y_chunk = build_XY_from_df(
            chunk_df, rest_scale=rest_scale, norm_mode=norm_mode,
            desc=f"Train chunk {chunk_no}"
        )

        loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(
                torch.from_numpy(X_chunk),
                torch.from_numpy(y_chunk.astype(np.float32))
            ),
            batch_size=batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=0,
            pin_memory=torch.cuda.is_available()
        )

        for xb, yb in loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            loss = criterion(model(xb).view(-1), yb)
            loss.backward()
            optimizer.step()

            loss_sum += loss.item() * len(yb)
            n_seen += len(yb)

        # Drop this chunk's features before building the next one.
        del X_chunk, y_chunk, loader, chunk_df
        torch.cuda.empty_cache()

    return loss_sum / max(n_seen, 1)

@torch.no_grad()
def eval_loss_and_auc(model, X_np, y_np, criterion, batch_size=2048):
    """Evaluate ``model`` on a pre-built feature matrix in mini-batches.

    Args:
        model: module whose output is viewed as one probability per row; it is
            switched to eval mode.
        X_np: feature matrix, one sample per row.
        y_np: 0/1 labels aligned with the rows of X_np.
        criterion: loss taking (probabilities, float labels), e.g. nn.BCELoss.
        batch_size: rows per forward pass.

    Returns:
        (mean_loss, auc, probs): sample-weighted mean loss, ROC-AUC, and a 1-D
        NumPy array of predicted probabilities.
        Empty input returns (nan, nan, empty float32 array) instead of crashing
        in np.concatenate; stratified_train_val_split can produce an empty
        val set. If y_np holds a single class, auc is nan, because ROC-AUC is
        undefined and roc_auc_score would raise.
    """
    model.eval()
    y_np = np.asarray(y_np).astype(np.int64)

    if len(X_np) == 0:
        return float('nan'), float('nan'), np.zeros(0, dtype=np.float32)

    total_loss = 0.0
    seen = 0
    probs_all = []

    for i in range(0, len(X_np), batch_size):
        xb = torch.tensor(X_np[i:i+batch_size], dtype=torch.float32).to(device)
        yb = torch.tensor(y_np[i:i+batch_size], dtype=torch.float32).to(device)

        pb = model(xb).view(-1)
        total_loss += criterion(pb, yb).item() * len(yb)
        seen += len(yb)
        probs_all.append(pb.detach().cpu().numpy())

    probs_all = np.concatenate(probs_all, axis=0)
    auc = roc_auc_score(y_np, probs_all) if np.unique(y_np).size > 1 else float('nan')
    return total_loss / max(seen, 1), auc, probs_all

def stratified_train_val_split(df: pd.DataFrame, val_frac: float = 0.10, seed: int = 0):
    """Split df into train/val with stratification on y and no sample overlap.

    Rows are addressed by position (iloc), so a non-unique index cannot
    duplicate rows in the output. With the previous label-based ``df.loc``, a
    repeated label returned every row carrying it.

    Args:
        df: frame with a label column ``y``.
        val_frac: fraction of rows for validation. ``<= 0`` returns
            (all rows, empty); ``>= 1`` returns (empty, all rows).
        seed: random seed for the split.

    Returns:
        (df_train, df_val). Non-empty parts get a fresh RangeIndex.
    """
    if val_frac <= 0:
        return df.reset_index(drop=True), df.iloc[0:0].copy()
    if val_frac >= 1:
        return df.iloc[0:0].copy(), df.reset_index(drop=True)

    positions = np.arange(len(df))

    if df['y'].nunique(dropna=False) < 2:
        # Only one class: stratification is impossible -> seeded random split.
        rng = np.random.default_rng(seed)
        rng.shuffle(positions)
        n_val = int(round(len(df) * val_frac))
        val_pos = positions[:n_val]
        train_pos = positions[n_val:]
    else:
        train_pos, val_pos = train_test_split(
            positions,
            test_size=val_frac,
            random_state=seed,
            shuffle=True,
            stratify=df['y'].to_numpy()
        )

    df_train = df.iloc[train_pos].reset_index(drop=True)
    df_val = df.iloc[val_pos].reset_index(drop=True)
    return df_train, df_val

def _fold_pools_split(pair_fold_pools, fold):
    """Return (train_pool, test_pool): pool ``fold`` is Test, the other N_FOLDS-1 pools form Train."""
    df_test_pool = pair_fold_pools[fold].reset_index(drop=True)
    df_train_pool = pd.concat(
        [pair_fold_pools[i] for i in range(N_FOLDS) if i != fold],
        axis=0
    ).reset_index(drop=True)
    return df_train_pool, df_test_pool


def _save_fold_curves(history, fold):
    """Save and show one fold's loss / val-AUC curves (x-axis = epoch, 1-based), then close the figures."""
    import matplotlib.pyplot as plt

    epochs = range(1, len(history["train_loss"]) + 1)

    fig = plt.figure()
    plt.plot(epochs, history["train_loss"], label="train_loss")
    plt.plot(epochs, history["val_loss"], label="val_loss")
    plt.xlabel("epoch")
    plt.legend()
    plt.title(f"Pair Fold {fold} Loss Curves")
    plt.savefig(f"{PLOT_PREFIX}_fold{fold}_LOSS.png", dpi=600, bbox_inches="tight")
    plt.show()
    plt.close(fig)

    fig = plt.figure()
    plt.plot(epochs, history["val_auc"], label="val_auc")
    plt.xlabel("epoch")
    plt.legend()
    plt.title(f"Pair Fold {fold} Validation AUC")
    plt.savefig(f"{PLOT_PREFIX}_fold{fold}_AUC.png", dpi=600, bbox_inches="tight")
    plt.show()
    plt.close(fig)


def run_pair_disjoint_cv(pair_fold_pools):
    """Run N_FOLDS-fold pair-disjoint cross-validation and log per-fold test metrics.

    For each fold f:
      * Test  = pool f (stratified-subsampled only if larger than TEST_SIZE_PER_FOLD).
      * Train = the other pools concatenated (optionally subsampled to
        TRAIN_SIZE_PER_FOLD), then split into Train / Val by VAL_FRACTION.
      * A fresh MLP is trained with early stopping on val AUC. The best-epoch
        weights are restored, a threshold is chosen on Val with
        find_best_threshold, and Test is scored with it.

    NOTE: Val is a random split of the training pools, so it shares gene pairs
    with Train. Only Test is pair-disjoint, and val AUC is not a pair-disjoint
    estimate.

    Returns:
        list of per-fold test AUCs (previously the function returned None).
    """
    aucs = []

    for fold in range(N_FOLDS):
        log_print("\n" + "="*18 + f" Pair Fold {fold} " + "="*18)

        df_train_pool, df_test_pool = _fold_pools_split(pair_fold_pools, fold)
        log_print(f"Train pool={len(df_train_pool):,} | Test pool={len(df_test_pool):,}")

        # -----------------------------
        # Train/Val/Test construction
        # -----------------------------
        if TRAIN_SIZE_PER_FOLD is None:
            df_train_all = df_train_pool
        else:
            df_train_all = stratified_subsample(df_train_pool, TRAIN_SIZE_PER_FOLD, seed=SEED + fold)

        # TrainPool -> Train / Val (VAL_FRACTION), no row overlap, stratified on y.
        df_train, df_val = stratified_train_val_split(
            df_train_all,
            val_frac=VAL_FRACTION,
            seed=SEED + fold + 100
        )

        # Test = the held-out pool; subsampled only if it exceeds TEST_SIZE_PER_FOLD.
        df_test = stratified_subsample(
            df_test_pool,
            TEST_SIZE_PER_FOLD,
            seed=SEED + fold + 200
        )

        log_print(f"Train={len(df_train):,}, Val={len(df_val):,}, Test={len(df_test):,}")
        log_print(
            f"y_train mean={df_train.y.mean():.4f}, "
            f"y_val mean={df_val.y.mean():.4f}, "
            f"y_test mean={df_test.y.mean():.4f}"
        )

        # -----------------------------
        # Build features for VAL/TEST once per fold
        # (Train features are rebuilt every epoch inside train_one_epoch_streaming.)
        # -----------------------------
        X_val, y_val = build_XY_from_df(
            df_val, rest_scale=REST_SCALE, norm_mode=NORM_MODE,
            desc=f"Fold {fold} VAL"
        )
        X_test, y_test = build_XY_from_df(
            df_test, rest_scale=REST_SCALE, norm_mode=NORM_MODE,
            desc=f"Fold {fold} TEST"
        )

        # -----------------------------
        # Model / Optimizer / Loss
        # -----------------------------
        model_mlp = MLP(dropout_rate=DROPOUT_RATE).to(device)
        optimizer = optim.AdamW(model_mlp.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
        criterion = nn.BCELoss()

        best_auc = -1
        best_state = None
        best_epoch = 0
        patience_left = PATIENCE

        history = {"train_loss": [], "val_loss": [], "val_auc": []}

        for epoch in range(1, MAX_EPOCHS + 1):
            train_loss = train_one_epoch_streaming(
                model_mlp, optimizer, criterion,
                df_train,
                batch_size=BATCH_SIZE,
                train_chunk_size=TRAIN_CHUNK_SIZE,
                rest_scale=REST_SCALE,
                norm_mode=NORM_MODE,
                fold=fold,
                epoch=epoch,
            )

            val_loss, val_auc, y_val_prob = eval_loss_and_auc(model_mlp, X_val, y_val, criterion)

            history["train_loss"].append(train_loss)
            history["val_loss"].append(val_loss)
            history["val_auc"].append(val_auc)

            log_print(f"[Fold {fold}] Epoch {epoch}: train_loss={train_loss:.6f}, val_loss={val_loss:.6f}, val_auc={val_auc:.6f}")

            # Early stopping on val AUC: keep a CPU copy of the best weights.
            if val_auc > best_auc + MIN_DELTA:
                best_auc = val_auc
                best_epoch = epoch
                best_state = {k: v.detach().cpu().clone() for k, v in model_mlp.state_dict().items()}
                patience_left = PATIENCE
            else:
                patience_left -= 1
                if patience_left <= 0:
                    log_print(f"[Fold {fold}] Early stop at epoch {epoch} (best_val_auc={best_auc:.6f} @epoch {best_epoch})")
                    break

        if best_state is not None:
            model_mlp.load_state_dict(best_state)

        # Re-score Val with the restored best weights before picking the threshold.
        val_loss_best, val_auc_best, y_val_prob = eval_loss_and_auc(model_mlp, X_val, y_val, criterion)

        best_t = find_best_threshold(
            y_true=y_val,
            y_prob=y_val_prob,
            metric=THRESH_METRIC,
            min_precision_pos=MIN_PRECISION_POS,
            t_min=THRESH_MIN,
            t_max=THRESH_MAX,
            t_step=THRESH_STEP,
        )

        if best_t is None:
            best_thr = 0.5
            log_print(f"[Fold {fold}] No valid threshold found, fallback thr=0.5")
        else:
            best_thr = best_t["threshold"]
            log_print(f"[Fold {fold}] Best threshold from VAL = {best_thr:.3f} | metric={THRESH_METRIC}")
            log_print(f"[Fold {fold}] VAL best stats = {best_t}")

        test_loss, test_auc, y_test_prob = eval_loss_and_auc(model_mlp, X_test, y_test, criterion)
        aucs.append(test_auc)

        y_pred = (y_test_prob >= best_thr).astype(int)
        cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
        rep = classification_report(y_test, y_pred, digits=4)

        log_print(f"[Fold {fold}] ✅ TEST AUC = {test_auc:.6f}")
        log_print(f"[Fold {fold}] TEST loss = {test_loss:.6f}")
        log_print(f"[Fold {fold}] Confusion Matrix [[TN,FP],[FN,TP]]:\n{cm}")
        log_print(f"[Fold {fold}] Report:\n{rep}")

        if SAVE_PLOTS:
            _save_fold_curves(history, fold)

        # Release this fold's feature matrices (len x 2400 float32 each), model
        # and optimizer before the next fold allocates its own.
        del X_val, y_val, X_test, y_test, y_val_prob, y_test_prob
        del model_mlp, optimizer, best_state
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    log_print("\n" + "="*20 + " Pair-disjoint CV Summary " + "="*20)
    log_print(f"AUCs per fold: {aucs}")
    log_print(f"Mean AUC = {np.mean(aucs):.6f}, Std = {np.std(aucs):.6f}")
    return aucs

# Run the full pair-disjoint CV on the fold pools built in Part A1; per-fold
# metrics and the mean/std test-AUC summary go to the log file and stdout.
run_pair_disjoint_cv(pair_fold_pools)

[2026-02-15 15:49:30] 
[2026-02-15 15:49:30] Train pool=1,600,000 | Test pool=400,000
[2026-02-15 15:49:31] Train=1,440,000, Val=160,000, Test=400,000
[2026-02-15 15:49:31] y_train mean=0.3399, y_val mean=0.3399, y_test mean=0.3378


Fold 0 VAL:  12%|██████▋                                                | 19496/160000 [04:10<27:37, 84.79it/s]