# Tripleknock Revision Notebook (E. coli iML1515)




### Data preparation: random subsample of 2000 rows

In [1]:
'''
import pandas as pd
from pathlib import Path

infile = Path("/data1/xpgeng/cross_pathogen/sci_rep_revision_20260113/Baseline/triples_no_essential_123630-0_123630-1.csv")
outfile = infile.with_name(infile.stem + ".sample2000.csv")

SEED = 2026  # 固定随机种子，保证可复现；想每次不同就删掉 random_state

df = pd.read_csv(infile)
df_small = df.sample(n=2000, random_state=SEED)

df_small.to_csv(outfile, index=False)
print("Wrote:", outfile, "rows:", len(df_small))
'''

Wrote: /data1/xpgeng/cross_pathogen/sci_rep_revision_20260113/Baseline/triples_no_essential_123630-0_123630-1.sample2000.csv rows: 2000


In [1]:
# =========================
# Part 0: Imports + Global Config
# =========================
import os, time, random, traceback
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

from sklearn.metrics import (
    roc_auc_score, confusion_matrix, classification_report,
    f1_score, precision_score, recall_score
)

# ---- device ----
# NOTE(review): the GPU index is hard-coded to 'cuda:2'; presumably this targets a
# specific multi-GPU host — confirm before running on another machine.
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
print('device:', device)

# =========================
# Adjustable parameters
# =========================

# ---- Data files ----
# Candidate datasets:
#  - triples_no_essential_*.csv   (active, DATA_FILE_0)
#  - triples_with_essential_*.csv (currently commented out, DATA_FILE_1)
# Every file listed in FILE_PARTS is concatenated and then split into random
# 5-fold CV pools (gene repetition across folds is allowed).
# With DATA_FILE_1 disabled, FILE_PARTS holds only the "no essential" file.
DATA_FILE_0 = '/data1/xpgeng/cross_pathogen/sci_rep_revision_20260113/Baseline/triples_no_essential_123630-0_123630-1.csv'
#DATA_FILE_1 = '/data1/xpgeng/cross_pathogen/sci_rep_revision_20260113/Baseline/triples_with_essential_0-0_123630-1.csv'

#FILE_PARTS = [DATA_FILE_0, DATA_FILE_1]
FILE_PARTS = [DATA_FILE_0]

# ---- Reproducibility ----
SEED = 42
N_FOLDS = 5

# ---- Sampling sizes (fold pools) ----
# All size knobs below are None, i.e. no fixed sample counts are configured here;
# fold sizes follow from however many rows FILE_PARTS provides.
# NOTE(review): an earlier version of this comment described a 1,000,000-row /
# 200k-per-fold setup with a VAL_FRACTION split of the train pool; that no longer
# matches these values (VAL_FRACTION is None) — confirm which knobs are still read.
FOLD_SIZE_PER_FOLD = None     # unused in RANDOM split (we use full fold size)
TRAIN_SIZE_PER_FOLD = None       # None => use the full train pool (no subsampling)
VAL_FRACTION        = None #0.10       # disabled; 0.10 would mean 10% of the train pool
TEST_SIZE_PER_FOLD  = None     # unused in RANDOM split

# (Deprecated sizes kept for backward-compat; not used when VAL_FRACTION is enabled)
VAL_SIZE_PER_FOLD   = None

# ---- Chunk training to avoid GPU OOM ----
# If the training set has more than 20,000 rows, train chunk by chunk: at most
# TRAIN_CHUNK_SIZE samples are featurised and trained on before moving to the next chunk.
TRAIN_CHUNK_SIZE = 20000
BATCH_SIZE = 512

# ---- Optimization ----
LR = 5e-4
WEIGHT_DECAY = 1e-3
MAX_EPOCHS = 3
PATIENCE = 2
MIN_DELTA = 5e-4
DROPOUT_RATE = 0.5 

# ---- Feature options ----
REST_SCALE = 0.1   # try: 0.0 / 0.02 / 0.05 / 0.1 / 0.2
NORM_MODE = 'block' # 'block' (recommended) | 'per_sample' | 'none'

# ---- Threshold search options ----
THRESH_METRIC = 'youden'  # 'youden' | 'balanced_acc' | 'f1_pos'
MIN_PRECISION_POS = None  # e.g. 0.45 or 0.5 to force better precision for y=1
THRESH_MIN  = 0.05
THRESH_MAX  = 0.95
THRESH_STEP = 0.005

# ---- Baseline experiment sizes ----
BASELINE_TRAIN_SIZE = 50000

# ---- Pair-disjoint switch ----
RUN_PAIR_DISJOINT = False

# -------------------------
# Logging
# -------------------------
# Results directory and both log files are created relative to the current
# working directory, so the notebook's launch directory decides where they land.
workdir = os.getcwd()
OUTPUT_DIR = os.path.join(workdir, 'cv_results_baseline_compare_no_essential_data')
os.makedirs(OUTPUT_DIR, exist_ok=True)
log_file = os.path.join(workdir, 'baseline_tripleknock_lethalrules_no_essential_data_random_5fold_mlp_512_256_cv_log.txt')          ### 3
err_file = os.path.join(workdir, 'baseline_tripleknock_lethalrules_no_essential_data_random_5fold_mlp_512_256_cv_err.txt')          ### 4

def log_print(msg: str):
    """Append a timestamped message to ``log_file`` and echo it to stdout.

    Args:
        msg: Text to log. A ``[YYYY-mm-dd HH:MM:SS]`` prefix is prepended.
    """
    ts = time.strftime('%Y-%m-%d %H:%M:%S')
    line = f'[{ts}] {msg}'
    # Explicit UTF-8: messages logged by this notebook contain non-ASCII
    # characters (e.g. "✅"), which raise UnicodeEncodeError when the platform
    # default encoding is not UTF-8.
    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(line + "\n")
    print(line)

def log_error(msg: str):
    """Append a timestamped message to ``err_file`` and echo it to stdout.

    Args:
        msg: Error text to log. A ``[YYYY-mm-dd HH:MM:SS]`` prefix is prepended.
    """
    ts = time.strftime('%Y-%m-%d %H:%M:%S')
    line = f'[{ts}] {msg}'
    # Explicit UTF-8 so non-ASCII text (tracebacks, Chinese messages, emoji)
    # cannot fail under a non-UTF-8 platform default encoding.
    with open(err_file, 'a', encoding='utf-8') as f:
        f.write(line + "\n")
    print(line)

# ---- Reproducibility ----
# Seed Python's `random`, NumPy's legacy global RNG and PyTorch (CPU, plus every
# CUDA device when available) with the same SEED.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# First write to log_file; confirms the config cell ran to completion.
log_print('Config loaded.')


device: cuda:2
[2026-02-20 21:20:16] Config loaded.


In [2]:
# =========================
# Baseline: Essential-gene presence rule
# pred_baseline = 1 if any of (g1,g2,g3) is an essential gene, else 0
# Reviewer baseline: single-gene lethality -> triple lethality
# =========================

essential_gene = ['b0003', 'b0004', 'b0025', 'b0029', 'b0031', 'b0052', 'b0054', 'b0071', 'b0072', 'b0074', 'b0084', 'b0085', 'b0086', 'b0087', 'b0088', 'b0089', 'b0090', 'b0091', 'b0096', 'b0103', 'b0109', 'b0131', 'b0133', 'b0134', 'b0142', 'b0154', 'b0159', 'b0166', 'b0173', 'b0174', 'b0175', 'b0179', 'b0180', 'b0181', 'b0182', 'b0185', 'b0242', 'b0243', 'b0369', 'b0386', 'b0414', 'b0415', 'b0417', 'b0420', 'b0421', 'b0423', 'b0522', 'b0523', 'b0524', 'b0635', 'b0639', 'b0641', 'b0720', 'b0750', 'b0774', 'b0775', 'b0776', 'b0777', 'b0778', 'b0908', 'b0914', 'b0915', 'b0918', 'b1062', 'b1069', 'b1091', 'b1092', 'b1093', 'b1094', 'b1098', 'b1131', 'b1136', 'b1208', 'b1210', 'b1215', 'b1260', 'b1261', 'b1262', 'b1263', 'b1264', 'b1277', 'b1281', 'b1288', 'b1662', 'b1693', 'b1740', 'b1812', 'b2019', 'b2020', 'b2021', 'b2022', 'b2023', 'b2024', 'b2025', 'b2026', 'b2103', 'b2153', 'b2312', 'b2315', 'b2316', 'b2323', 'b2329', 'b2400', 'b2472', 'b2476', 'b2478', 'b2499', 'b2507', 'b2515', 'b2530', 'b2557', 'b2564', 'b2574', 'b2585', 'b2599', 'b2600', 'b2615', 'b2687', 'b2746', 'b2747', 'b2750', 'b2751', 'b2752', 'b2762', 'b2763', 'b2764', 'b2780', 'b2818', 'b2827', 'b2838', 'b2942', 'b3018', 'b3040', 'b3041', 'b3058', 'b3172', 'b3176', 'b3177', 'b3187', 'b3189', 'b3196', 'b3198', 'b3199', 'b3200', 'b3201', 'b3255', 'b3256', 'b3360', 'b3368', 'b3389', 'b3412', 'b3433', 'b3607', 'b3633', 'b3634', 'b3639', 'b3642', 'b3648', 'b3729', 'b3730', 'b3770', 'b3771', 'b3774', 'b3804', 'b3805', 'b3809', 'b3843', 'b3850', 'b3870', 'b3939', 'b3941', 'b3957', 'b3958', 'b3959', 'b3960', 'b3967', 'b3972', 'b3974', 'b3990', 'b3991', 'b3992', 'b3993', 'b3994', 'b3997', 'b4005', 'b4006', 'b4013', 'b4040', 'b4160', 'b4177', 'b4214', 'b4245', 'b4261', 'b4262', 'b4407', 's0001']

# Set form of the list above for O(1) membership tests in the baseline rule.
essential_set = set(essential_gene)
print('Essential genes loaded:', len(essential_set))


Essential genes loaded: 196


## Helper utilities / 工具函数

这一节包含：
- gene / triple 排序（保证 `g1 < g2 < g3`，避免重复）
- stratified sampling（可选）
- threshold search（Youden / balanced accuracy / f1_pos）
- evaluation report（AUC + confusion matrix + classification report）

This section includes sorting helpers, sampling helpers, threshold search, and evaluation utilities.


In [3]:
# =========================
# Part 1: Utilities
# =========================

def gene_key(g: str):
    """Return a ``(prefix, number)`` sort key so that b0002 < b0003 < ... .

    The prefix is every non-digit character of the stripped ID (in order); the
    number is the integer formed by all digit characters, or -1 if there are none.
    """
    letters = []
    numerals = []
    for ch in str(g).strip():
        # Route each character to the digit or non-digit bucket.
        (numerals if ch.isdigit() else letters).append(ch)
    number_text = ''.join(numerals)
    number = int(number_text) if number_text else -1
    return (''.join(letters), number)


def sort_triple(g1, g2, g3):
    """Return the three gene IDs (stringified, stripped) as a tuple ordered by ``gene_key``."""
    cleaned = [str(gene).strip() for gene in (g1, g2, g3)]
    cleaned.sort(key=gene_key)
    return tuple(cleaned)


def stratified_subsample(df: pd.DataFrame, n: int, seed: int = 42):
    """Subsample ``n`` rows while roughly preserving the y=1 / y=0 ratio.

    Args:
        df: Frame with a binary label column ``'y'`` (rows with other label
            values are never selected).
        n: Requested number of rows.
        seed: Seed for both the row selection and the final shuffle.

    Returns:
        A copy of ``df`` when ``n >= len(df)``; otherwise a shuffled frame with a
        fresh 0..k-1 index containing ``int(n * pos / total)`` positives and the
        remainder negatives (each capped at the number available).
    """
    if n >= len(df):
        return df.copy()

    rng = np.random.default_rng(seed)

    # Select by POSITION rather than by index label: label-based ``df.loc`` returns
    # every row sharing a label, so a frame with duplicate index values (e.g. after
    # ``pd.concat`` without ``ignore_index``) used to yield more than ``n`` rows.
    # For a unique index the chosen rows are the same as with label selection.
    labels = df['y'].to_numpy()
    pos_rows = np.flatnonzero(labels == 1)
    neg_rows = np.flatnonzero(labels == 0)

    pos_n = int(n * len(pos_rows) / len(df))
    neg_n = n - pos_n

    pos_pick = rng.choice(pos_rows, size=min(pos_n, len(pos_rows)), replace=False)
    neg_pick = rng.choice(neg_rows, size=min(neg_n, len(neg_rows)), replace=False)

    out = pd.concat([df.iloc[pos_pick], df.iloc[neg_pick]], axis=0)
    # Shuffle so positives and negatives are interleaved, then renumber.
    out = out.sample(frac=1.0, random_state=seed).reset_index(drop=True)
    return out


def find_best_threshold(y_true, y_prob,
                        metric='youden',
                        min_precision_pos=None,
                        t_min=0.05, t_max=0.95, t_step=0.005):
    """Grid-search a decision threshold (intended for the validation set).

    Args:
        y_true: Binary labels (0/1); any array-like, including plain lists.
        y_prob: Predicted probabilities for the positive class; any array-like.
        metric: Objective to maximise.
            - 'youden': TPR - FPR
            - 'balanced_acc': (TPR + TNR) / 2
            - 'f1_pos': F1 of the positive class
            Any other value falls back to 'youden'.
        min_precision_pos: If not None, thresholds whose positive-class
            precision is below this value are skipped.
        t_min, t_max, t_step: Inclusive threshold grid.

    Returns:
        Dict with keys 'threshold', 'score', 'f1_pos', 'precision_pos',
        'recall_pos', 'tn', 'fp', 'fn', 'tp' for the first threshold reaching
        the best score, or None when every threshold was skipped.

    NOTE: a later cell in this notebook defines ``find_best_threshold`` again
    with different default grid values; whichever cell ran last is the one in effect.
    """
    # Coerce once so list inputs work (``list >= float`` raises TypeError) and
    # behaviour matches the later definition of this function.
    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob).astype(float)

    # Loop-invariant class masks/counts; per-threshold work is then two counts.
    pos_mask = y_true == 1
    neg_mask = y_true == 0
    n_pos = int(np.count_nonzero(pos_mask))
    n_neg = int(np.count_nonzero(neg_mask))

    best = None
    t_values = np.arange(t_min, t_max + 1e-12, t_step)

    for t in t_values:
        pred_pos = y_prob >= t
        tp = int(np.count_nonzero(pred_pos & pos_mask))
        fp = int(np.count_nonzero(pred_pos & neg_mask))
        fn = n_pos - tp
        tn = n_neg - fp

        # 1e-12 keeps the ratios finite when a denominator would be zero.
        tpr = tp / (tp + fn + 1e-12)
        tnr = tn / (tn + fp + 1e-12)
        fpr = fp / (fp + tn + 1e-12)

        precision_pos = tp / (tp + fp + 1e-12)
        recall_pos = tpr
        f1_pos = 2 * precision_pos * recall_pos / (precision_pos + recall_pos + 1e-12)

        if min_precision_pos is not None and precision_pos < min_precision_pos:
            continue

        if metric == 'youden':
            score = tpr - fpr
        elif metric == 'balanced_acc':
            score = 0.5 * (tpr + tnr)
        elif metric == 'f1_pos':
            score = f1_pos
        else:
            score = tpr - fpr

        cand = {
            'threshold': float(t),
            'score': float(score),
            'f1_pos': float(f1_pos),
            'precision_pos': float(precision_pos),
            'recall_pos': float(recall_pos),
            'tn': int(tn), 'fp': int(fp), 'fn': int(fn), 'tp': int(tp),
        }

        # Strict '>' keeps the lowest threshold among ties.
        if (best is None) or (cand['score'] > best['score']):
            best = cand

    return best


def eval_binary(y_true, y_prob, threshold=0.5, prefix=''):
    """Compute ROC-AUC, the 2x2 confusion matrix and a text classification report.

    Args:
        y_true: Binary labels; any array-like.
        y_prob: Predicted positive-class probabilities; any array-like.
        threshold: Probabilities >= threshold are predicted as class 1.
        prefix: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(auc, cm, report)`` where ``cm`` is ordered by labels [0, 1].
    """
    # Coerce so list inputs work: ``list >= float`` raises TypeError.
    y_true = np.asarray(y_true)
    y_prob = np.asarray(y_prob)

    auc = roc_auc_score(y_true, y_prob)
    y_pred = (y_prob >= threshold).astype(int)
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
    report = classification_report(y_true, y_pred, digits=4)
    return auc, cm, report


## Feature preparation / 特征准备

已经有的 `two_mer_dict` 和 `ae1_2`。

- `two_mer_dict[gene]` gives a 400-d vector for each gene (2-mer frequency).
- `ae1_2({g1,g2,g3})` returns (3,400) tensor for the rest-of-genome compressed features.

下面是**可直接运行**的实现（与当前版本保持一致），只需要修改路径。



In [4]:
# =========================
# Part 2: Build two_mer_dict (2-mer features) + load Autoencoders + define ae1_2
# =========================

# ---- (A) Build two_mer_dict from FASTA ----
# If you already have two_mer_dict, set BUILD_2MER=False to skip this step.
# NOTE(review): `all_genes` is only defined inside this branch, and ae1_2 below
# reads it — skipping the branch requires `all_genes` to exist already.
BUILD_2MER = True

if BUILD_2MER:
    from Bio import SeqIO
    from collections import Counter

    fasta_path = '/data1/xpgeng/cross_pathogen/autoencoder/E.coli.tag_seq.fasta'

    def read_fasta(fp):
        # Map each FASTA record id to its sequence as a plain string.
        gene_sequence_dict = {}
        for record in SeqIO.parse(fp, 'fasta'):
            gene_sequence_dict[record.id] = str(record.seq)
        return gene_sequence_dict

    gene_sequence_dict = read_fasta(fasta_path)
    all_genes = set(gene_sequence_dict.keys())

    print('Total genes in FASTA:', len(all_genes))
    print('Example:', list(gene_sequence_dict.items())[:1])

    # 20 standard amino acids -> 20 * 20 = 400 ordered 2-mers, each with a fixed column.
    standard_amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
    all_2mers = [a + b for a in standard_amino_acids for b in standard_amino_acids]
    two_mer_index = {two_mer: idx for idx, two_mer in enumerate(all_2mers)}

    two_mer_dict = {}

    for gene, sequence in tqdm(gene_sequence_dict.items(), desc='Building two_mer_dict'):
        # Drop non-standard residues first; residues on either side of a dropped
        # character therefore become adjacent and are counted as a 2-mer.
        sequence = ''.join([aa for aa in sequence if aa in standard_amino_acids])

        # Too short to contain any 2-mer: use an all-zero vector.
        if len(sequence) < 2:
            two_mer_dict[gene] = np.zeros(400, dtype=np.float32)
            continue

        two_mer_counts = Counter(sequence[i:i+2] for i in range(len(sequence)-1))
        total_two_mers = sum(two_mer_counts.values())

        # Relative frequency of each 2-mer (entries sum to 1 over the 400 columns).
        feature_vector = np.zeros(400, dtype=np.float32)
        for two_mer, count in two_mer_counts.items():
            idx = two_mer_index.get(two_mer)
            if idx is not None:
                feature_vector[idx] = count / total_two_mers

        two_mer_dict[gene] = feature_vector

    # Quick sanity print of the first three vectors.
    for gene, vec in list(two_mer_dict.items())[:3]:
        print(gene, vec[:10])

# ---- (B) Load Autoencoders and define ae1_2 ----
# Architecture and weight loading are kept as in the original pipeline; the layer
# order inside each Sequential fixes the state_dict keys the checkpoints must match.

LOAD_AUTOENCODERS = True

if LOAD_AUTOENCODERS:

    class Autoencoder(torch.nn.Module):
        """Autoencoder over a 400-d vector with a 3-d bottleneck (400 -> 256 -> 128 -> 3 -> ... -> 400)."""

        def __init__(self):
            super(Autoencoder, self).__init__()
            self.encoder = torch.nn.Sequential(
                torch.nn.Linear(400, 256),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.35),
                torch.nn.Linear(256, 128),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.35),
                torch.nn.Linear(128, 3),
            )
            self.decoder = torch.nn.Sequential(
                torch.nn.Linear(3, 128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, 256),
                torch.nn.ReLU(),
                torch.nn.Linear(256, 400),
            )

        def forward(self, x):
            # Full reconstruction path; ae1_2 below uses only `.encoder`.
            encoded = self.encoder(x)
            decoded = self.decoder(encoded)
            return decoded


    class Autoencoder2(torch.nn.Module):
        """Autoencoder over a 4304-d vector with a 400-d bottleneck (4304 -> 3000 -> 1000 -> 400 -> ... -> 4304)."""

        def __init__(self):
            super(Autoencoder2, self).__init__()
            self.encoder = torch.nn.Sequential(
                torch.nn.Linear(4304, 3000),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.2),
                torch.nn.Linear(3000, 1000),
                torch.nn.ReLU(),
                torch.nn.Dropout(0.3),
                torch.nn.Linear(1000, 400),
            )
            self.decoder = torch.nn.Sequential(
                torch.nn.Linear(400, 1000),
                torch.nn.ReLU(),
                torch.nn.Linear(1000, 3000),
                torch.nn.ReLU(),
                torch.nn.Linear(3000, 4304),
            )

        def forward(self, x):
            # Full reconstruction path; ae1_2 below uses only `.encoder`.
            encoded = self.encoder(x)
            decoded = self.decoder(encoded)
            return decoded


    # NOTE: torch.load unpickles the checkpoint file — only load trusted files.
    # eval() switches the Dropout layers off so the encoders are deterministic.
    model = Autoencoder().to(device)
    model.load_state_dict(torch.load('/data1/xpgeng/cross_pathogen/autoencoder/ae1_all_data_training.pth', map_location=device))
    model.eval()

    model2 = Autoencoder2().to(device)
    model2.load_state_dict(torch.load('/data1/xpgeng/cross_pathogen/autoencoder/ae2_all_data_training.pth', map_location=device))
    model2.eval()


@torch.no_grad()
def ae1_2(three_genes):
    """Encode the genome without the given genes via the two autoencoder encoders.

    Args:
        three_genes: set of gene IDs to exclude, e.g. ``{g1, g2, g3}``.

    Returns:
        Tensor on ``device`` produced by ``model2.encoder``; (3, 400) when the
        padded input has 4304 rows.
    """
    # Input: set({g1, g2, g3})
    # Output: (3, 400) tensor
    # NOTE(review): a set difference has no defined gene order, so the row order
    # (and hence which of the 4304 positions each gene lands in) may vary between
    # interpreter runs — confirm it matches the order used to train the autoencoders.
    rest_genes = list(all_genes - three_genes)
    inputs = np.vstack([two_mer_dict[gene] for gene in rest_genes]).astype(np.float32)

    # Pad with two all-zero rows. With the 4305 FASTA genes reported above minus
    # three distinct knocked-out genes this gives the 4304 rows Autoencoder2 expects.
    # NOTE(review): the padding is fixed at 2, so a triple with a repeated gene or a
    # gene missing from all_genes would presumably produce a shape mismatch — verify.
    zeros_400 = np.zeros((2, 400), dtype=np.float32)
    inputs = np.vstack([inputs, zeros_400])

    inputs = torch.tensor(inputs).to(device)
    # First encoder: each 400-d row -> 3-d code, i.e. (rows, 3).
    inputs = model.encoder(inputs)
    # Transpose to (3, rows) via a CPU round trip, then back to the device.
    inputs = inputs.cpu().detach().numpy().T
    inputs = torch.tensor(inputs).to(device)
    # Second encoder: each of the 3 rows (length 4304) -> 400-d.
    outputs = model2.encoder(inputs)

    return outputs

log_print('two_mer_dict & ae1_2 ready.')


Total genes in FASTA: 4305
Example: [('b0001', 'MKRISTTITTTITITTGNGAG')]


Building two_mer_dict: 100%|█████████████████████████████████████████████| 4305/4305 [00:00<00:00, 7354.44it/s]


b0001 [0.   0.   0.   0.   0.   0.05 0.   0.   0.   0.  ]
b0002 [0.01587302 0.001221   0.00854701 0.01098901 0.002442   0.00854701
 0.         0.003663   0.00610501 0.00732601]
b0003 [0.01294498 0.00647249 0.00647249 0.01294498 0.         0.00647249
 0.00323625 0.00323625 0.00323625 0.01294498]
[2026-02-20 21:21:32] two_mer_dict & ae1_2 ready.


# Part A — Random 5-Fold CV + Baseline Compare / 随机五折 + baseline对比

**Goal / 目标**

- 合并两批数据（不含必需基因 + 含必需基因），**随机**分成 5 份（**允许单基因在不同折中重复出现**）。
- 每一折同时输出：
  1) **Tripleknock**（你的 MLP pipeline）在 TEST 上的性能；
  2) **Baseline**（基于单基因必需性：只要三基因中包含任意必需基因就判定致死）在同一 TEST 上的性能。

> 注意：之前严格的 *gene-disjoint* 划分逻辑已被替换/注释（reviewer 这次要求的是 baseline 对比 + 随机五折）。


In [5]:
# =========================
# Part A1 (FAST): Read CSV(s) and build RANDOM 5 folds (ALLOW gene repetition)
# ✅ This replaces the previous strict gene-disjoint split.
#  - We DO NOT enforce gene-level exclusivity.
#  - We ONLY do label-stratified random splitting into 5 folds.
# =========================

import time
from tqdm import tqdm
from sklearn.model_selection import StratifiedKFold

# Report the real number of input files: FILE_PARTS is configurable (currently a
# single file), so the count must not be hard-coded in the message.
log_print(f"Reading {len(FILE_PARTS)} baseline CSV file(s) and concatenating...")

t0 = time.time()
frames = []

for fp in tqdm(FILE_PARTS, desc="Reading CSV files"):
    t1 = time.time()

    # Files have no header row; the four columns are g1, g2, g3, y.
    dfp = pd.read_csv(
        fp,
        header=None,
        names=["g1", "g2", "g3", "y"],
        dtype={"g1": "string", "g2": "string", "g3": "string", "y": "int8"},
        engine="c",
        low_memory=False
    )
    frames.append(dfp)
    log_print(f"Loaded {os.path.basename(fp)} rows={len(dfp):,} time={time.time()-t1:.1f}s")

# ignore_index=True gives the merged frame a fresh 0..N-1 index.
df = pd.concat(frames, axis=0, ignore_index=True)

# Strip stray whitespace from gene IDs (vectorised) and widen y from int8.
for c in ["g1", "g2", "g3"]:
    df[c] = df[c].str.strip()
df["y"] = df["y"].astype(int)

log_print(f"Concat done. Total rows={len(df):,}, time={time.time()-t0:.1f}s")
log_print(f"Dataset y=1 ratio={df.y.mean():.4f} | y counts={df.y.value_counts().to_dict()}")

# -----------------------------------------------------
# Build RANDOM 5 folds (label-stratified)
# fold_pools[i] is the i-th fold dataframe (used as TEST/VAL/Train parts later)
# -----------------------------------------------------
skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)

fold_pools = []
fold_indices = []  # keep indices for optional debugging / saving

y = df["y"].values
# Only the held-out part of each split is kept: the N_FOLDS held-out parts are
# disjoint and together cover every row, so they serve as the fold pools.
# StratifiedKFold only needs the labels, hence the dummy all-zero X.
for fold_id, (_, test_idx) in enumerate(skf.split(np.zeros(len(df)), y)):
    dff = df.iloc[test_idx].reset_index(drop=True)
    fold_pools.append(dff)
    fold_indices.append(test_idx)
    log_print(f"Fold {fold_id} pool ready: {len(dff):,} (y_mean={dff['y'].mean():.4f})")

log_print("✅ Random 5-fold split ready (gene repetition allowed).")


[2026-02-20 21:22:06] Reading 2 baseline CSV files and concatenating...


Reading CSV files: 100%|█████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 12.33it/s]

[2026-02-20 21:22:06] Loaded triples_no_essential_123630-0_123630-1.csv rows=247,260 time=0.1s





[2026-02-20 21:22:06] Concat done. Total rows=247,260, time=0.2s
[2026-02-20 21:22:06] Dataset y=1 ratio=0.5000 | y counts={0: 123630, 1: 123630}
[2026-02-20 21:22:06] Fold 0 pool ready: 49,452 (y_mean=0.5000)
[2026-02-20 21:22:06] Fold 1 pool ready: 49,452 (y_mean=0.5000)
[2026-02-20 21:22:06] Fold 2 pool ready: 49,452 (y_mean=0.5000)
[2026-02-20 21:22:06] Fold 3 pool ready: 49,452 (y_mean=0.5000)
[2026-02-20 21:22:06] Fold 4 pool ready: 49,452 (y_mean=0.5000)
[2026-02-20 21:22:06] ✅ Random 5-fold split ready (gene repetition allowed).


## Part A2 — Training MLP (streaming + AUC + confusion matrix)

这里我们只**整理代码**，不改变你的 MLP 架构。

### Why streaming chunk training? / 为什么要分块训练？
如果你把 60万样本一次性做成 `X_train` 再搬到 GPU，很容易 OOM。

这里采取三步策略：
1. **Feature building on CPU (numpy)**
2. **Mini-batch training on GPU (DataLoader)**
3. 如果训练样本数 > 20000：按 `TRAIN_CHUNK_SIZE` 分批构建特征并训练（继续更新同一个模型参数）

> 你可以先用 `TRAIN_SIZE_PER_FOLD=20000` 快速跑通，再逐步放大。


In [None]:
# =========================
# Part A2 (FULL): Feature builder + MLP + CV training
# - Add validation loss logging (reviewer requirement)
# - Fix shuffle per epoch (avoid same order every epoch)
# - Keep MLP architecture unchanged
# - Streaming training by chunks to avoid GPU OOM
# =========================

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import time

from sklearn.model_selection import train_test_split

from sklearn.metrics import (
    roc_auc_score,
    confusion_matrix,
    classification_report,
    f1_score,
    precision_score,
    recall_score,
)

# -------------------------
# Helper: zscore normalize
# -------------------------
def zscore(vec, eps=1e-8):
    """Standardise ``vec`` to zero mean / unit std; ``eps`` guards a zero std."""
    centre = vec.mean()
    spread = vec.std()
    return (vec - centre) / (spread + eps)


# -------------------------
# Feature builder
# -------------------------
@torch.no_grad()
def build_feature_vector(g1, g2, g3, rest_scale=0.05, norm_mode='block'):
    """Assemble the float32 feature vector for one gene triple.

    Layout: the three genes' 2-mer vectors (flattened) followed by the flattened
    ``ae1_2`` rest-of-genome embedding multiplied by ``rest_scale``.

    norm_mode:
      - 'block': z-score the two halves separately before scaling/concatenation
      - 'per_sample': z-score the concatenated vector as a whole
      - anything else (e.g. 'none'): no normalisation
    """
    # Knocked-out genes: stacked 2-mer frequency vectors, flattened.
    ko_block = np.array(
        [two_mer_dict[gene] for gene in (g1, g2, g3)], dtype=np.float32
    ).ravel()

    # Rest-of-genome embedding from the autoencoders, flattened.
    embedding = ae1_2({g1, g2, g3})
    bg_block = embedding.detach().cpu().numpy().astype(np.float32).ravel()

    if norm_mode == 'block':
        ko_block, bg_block = zscore(ko_block), zscore(bg_block)

    combined = np.concatenate([ko_block, bg_block * rest_scale], axis=0).astype(np.float32)

    return zscore(combined) if norm_mode == 'per_sample' else combined


def build_XY_from_df(dfx: pd.DataFrame, rest_scale=0.05, norm_mode='block', desc='Build XY'):
    """Featurise every row of ``dfx`` (columns g1, g2, g3, y).

    Returns:
        ``(X, y)`` with X of shape (N, 2400) float32 and y of shape (N,) int64.
    """
    labels = dfx['y'].values.astype(np.int64)
    gene_rows = dfx[['g1', 'g2', 'g3']].values.tolist()

    features = np.zeros((len(gene_rows), 2400), dtype=np.float32)
    progress = tqdm(gene_rows, desc=desc, leave=False)
    for row, (gene_a, gene_b, gene_c) in enumerate(progress):
        features[row] = build_feature_vector(
            gene_a, gene_b, gene_c,
            rest_scale=rest_scale,
            norm_mode=norm_mode,
        )

    return features, labels


# -------------------------
# MLP model (UNCHANGED)
# -------------------------

class MLP(nn.Module):
    """Two-hidden-layer perceptron with dropout and a sigmoid output.

    Architecture is intentionally unchanged: Linear -> ReLU -> Dropout twice,
    then Linear -> Sigmoid. Attribute names and creation order are kept so
    state_dict keys and seeded initialisation stay the same.
    """

    def __init__(self, input_size=2400, hidden_size1=512, hidden_size2=256, output_size=1, dropout_rate=0.5):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.fc3 = nn.Linear(hidden_size2, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return probabilities of shape (batch, output_size)."""
        hidden = self.dropout1(torch.relu(self.fc1(x)))
        hidden = self.dropout2(torch.relu(self.fc2(hidden)))
        return self.sigmoid(self.fc3(hidden))
'''
class MLP(nn.Module):
    def __init__(self, input_size=2400, hidden_dims=(512, 256), output_size=1, dropout_rate=0.5):
        super().__init__()
        layers = []
        prev = input_size

        for h in hidden_dims:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_rate))
            prev = h

        layers.append(nn.Linear(prev, output_size))
        self.net = nn.Sequential(*layers)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.net(x)
        return self.sigmoid(x)
'''

# -------------------------
# Training (streaming by chunks to avoid OOM)
# -------------------------
def train_one_epoch_streaming(
    model, optimizer, criterion,
    df_train_pool: pd.DataFrame,
    batch_size=512,
    train_chunk_size=20000,
    rest_scale=0.05,
    norm_mode='block',
    fold=0,
    epoch=1,
):
    """
    Train ``model`` for one epoch over ``df_train_pool``, one chunk at a time.

    Rows are shuffled once per call, split into chunks of ``train_chunk_size``,
    and each chunk is featurised on the CPU and then fed to the model in
    mini-batches, so only one chunk's features are held in memory at a time.

    Args:
        model, optimizer, criterion: network, its optimizer and the loss; the
            loss is called as ``criterion(pred, yb)`` on flattened predictions.
        df_train_pool: frame with columns g1, g2, g3, y.
        batch_size: mini-batch size inside a chunk.
        train_chunk_size: maximum rows featurised at once.
        rest_scale, norm_mode: forwarded to ``build_XY_from_df``.
        fold, epoch: only used to derive the shuffle seed.

    Returns:
        Sample-weighted mean training loss over the epoch (0.0 if no rows).

    Important fix:
    the shuffle seed changes with fold/epoch, it is NOT constant every epoch.
    """
    model.train()

    idx = np.arange(len(df_train_pool))
    rng = np.random.default_rng(SEED + fold * 100000 + epoch)  # vary per fold & epoch
    rng.shuffle(idx)

    total_loss = 0.0
    seen = 0

    n_chunks = int(np.ceil(len(idx) / train_chunk_size))
    chunk_id = 0

    for start in range(0, len(idx), train_chunk_size):
        chunk_id += 1
        chunk_idx = idx[start:start+train_chunk_size]
        dfx = df_train_pool.iloc[chunk_idx].reset_index(drop=True)

        # build features on CPU
        Xc, yc = build_XY_from_df(
            dfx,
            rest_scale=rest_scale,
            norm_mode=norm_mode,
            desc=f"Train chunk {chunk_id}/{n_chunks}"
        )

        # DataLoader (CPU -> GPU in minibatches)
        ds = torch.utils.data.TensorDataset(
            torch.from_numpy(Xc),
            torch.from_numpy(yc.astype(np.float32))
        )
        # shuffle=True reshuffles within the chunk using torch's global RNG.
        dl = torch.utils.data.DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=True,
            drop_last=False,
            num_workers=0,
            pin_memory=torch.cuda.is_available()
        )

        for xb, yb in dl:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            # Flatten (batch, 1) predictions to (batch,) to match yb.
            pred = model(xb).view(-1)
            loss = criterion(pred, yb)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the final mean is per-sample.
            total_loss += loss.item() * len(yb)
            seen += len(yb)

        # free memory
        del Xc, yc, ds, dl, dfx
        torch.cuda.empty_cache()

    return total_loss / max(seen, 1)


# -------------------------
# Prediction helper
# -------------------------
@torch.no_grad()
def predict_proba(model, X_np, batch_size=2048):
    """Run ``model`` in eval mode over ``X_np`` in batches; return a 1-D numpy array of outputs."""
    model.eval()

    def _score(start):
        # Move one slice to the device, score it, and bring the result back to CPU.
        batch = torch.tensor(X_np[start:start + batch_size], dtype=torch.float32).to(device)
        return model(batch).view(-1).detach().cpu().numpy()

    chunks = [_score(start) for start in range(0, len(X_np), batch_size)]
    return np.concatenate(chunks, axis=0)


# ✅ NEW: validation loss + auc (reviewer wants val loss curve)
@torch.no_grad()
def eval_loss_and_auc(model, X_np, y_np, criterion, batch_size=2048):
    """
    Compute the mean loss and ROC-AUC of ``model`` on a dataset (val/test).

    Returns:
        ``(mean_loss, auc, probs)`` where ``mean_loss`` is weighted by batch
        size and ``probs`` is the 1-D array of model outputs for every row.
    """
    model.eval()

    labels = np.asarray(y_np).astype(np.int64)
    loss_sum, n_seen = 0.0, 0
    batch_probs = []

    for lo in range(0, len(X_np), batch_size):
        hi = lo + batch_size
        xb = torch.tensor(X_np[lo:hi], dtype=torch.float32).to(device)
        yb = torch.tensor(labels[lo:hi], dtype=torch.float32).to(device)

        prob = model(xb).view(-1)   # sigmoid output, flattened to (batch,)
        n_batch = len(yb)

        loss_sum += criterion(prob, yb).item() * n_batch
        n_seen += n_batch
        batch_probs.append(prob.detach().cpu().numpy())

    probs = np.concatenate(batch_probs, axis=0)
    auc = roc_auc_score(labels, probs)
    return loss_sum / max(n_seen, 1), auc, probs


# -------------------------
# Threshold search helper
# -------------------------
def find_best_threshold(
    y_true,
    y_prob,
    metric="youden",
    min_precision_pos=None,
    t_min=0.01,
    t_max=0.99,
    t_step=0.01,
):
    """
    Scan thresholds in [t_min, t_max] and return the first one with the best score.

    metric:
      - "youden": TPR - FPR (also used for any unrecognised value)
      - "balanced_acc": (TPR + TNR) / 2
      - "f1_pos": F1 of the positive class

    Thresholds whose positive-class precision is below ``min_precision_pos``
    (when given) are skipped. Returns a dict with threshold, score, f1_pos,
    precision_pos, recall_pos and the confusion counts, or None if every
    threshold was skipped.
    """
    truth = np.asarray(y_true).astype(int)
    probs = np.asarray(y_prob).astype(float)

    tiny = 1e-12  # keeps ratios finite when a denominator would be zero
    winner = None

    for cut in np.arange(t_min, t_max + 1e-12, t_step):
        decided = (probs >= cut).astype(int)
        tn, fp, fn, tp = confusion_matrix(truth, decided, labels=[0, 1]).ravel()

        tpr = tp / (tp + fn + tiny)
        fpr = fp / (fp + tn + tiny)
        tnr = tn / (tn + fp + tiny)

        precision_pos = tp / (tp + fp + tiny)
        recall_pos = tpr
        f1_pos = 2 * precision_pos * recall_pos / (precision_pos + recall_pos + tiny)

        if min_precision_pos is not None and precision_pos < min_precision_pos:
            continue

        if metric == "balanced_acc":
            score = 0.5 * (tpr + tnr)
        elif metric == "f1_pos":
            score = f1_pos
        else:
            # "youden" and every unrecognised metric name
            score = tpr - fpr

        # Strict '>' keeps the lowest threshold among ties.
        if winner is not None and not (score > winner["score"]):
            continue

        winner = {
            "threshold": float(cut),
            "score": float(score),
            "f1_pos": float(f1_pos),
            "precision_pos": float(precision_pos),
            "recall_pos": float(recall_pos),
            "tn": int(tn),
            "fp": int(fp),
            "fn": int(fn),
            "tp": int(tp),
        }

    return winner


# -------------------------
# Main CV runner (RANDOM 5-fold, gene repetition allowed)
# -------------------------
def run_random_5fold_cv(df: pd.DataFrame, fold_pools: list) -> pd.DataFrame:
    """Run random N_FOLDS-fold CV on the merged dataset (no gene-disjoint constraint).

    Split strategy for CV iteration k (folds are RANDOM, genes may repeat across folds):
      - Test  = fold k
      - Val   = fold (k+1) mod N_FOLDS
      - Train = all remaining folds (subsampled when TRAIN_SIZE_PER_FOLD is set)

    In each fold we:
      1) train a fresh MLP with early stopping on validation AUC,
      2) pick the decision threshold on VAL with ``find_best_threshold``,
      3) report Tripleknock (MLP) metrics on TEST,
      4) report Baseline metrics on the SAME TEST (a triple is predicted
         positive when any of g1/g2/g3 is in ``essential_set``),
      5) write TEST predictions and the loss / val-AUC curves into OUTPUT_DIR.

    Args:
        df: Not read by this function; kept so existing callers keep working.
        fold_pools: One DataFrame per fold (at least N_FOLDS of them). The code
            reads the columns ``g1``, ``g2``, ``g3`` and ``y`` from each.

    Returns:
        DataFrame with one row of metrics per fold; the same table is saved as
        ``cv_metrics_summary.csv`` in OUTPUT_DIR.

    Raises:
        ValueError: if N_FOLDS < 3 (no fold would be left for training) or
            fewer than N_FOLDS pools are supplied.
    """
    # Imported locally because only the per-fold curve plots need it.
    import matplotlib.pyplot as plt

    def genes_in_df(dfx):
        # Distinct gene ids appearing in any of the three knockout columns.
        return set(pd.unique(dfx[["g1","g2","g3"]].to_numpy().ravel()))

    # collect fold-wise results
    rows = []
    aucs_triple = []
    aucs_base   = []

    log_print("Fold pool sizes: " + str([len(x) for x in fold_pools]))

    # Fail fast with a clear message instead of an IndexError / empty concat mid-run.
    if N_FOLDS < 3:
        raise ValueError(f"N_FOLDS must be >= 3 for the train/val/test rotation, got {N_FOLDS}")
    if len(fold_pools) < N_FOLDS:
        raise ValueError(f"Expected at least {N_FOLDS} fold pools, got {len(fold_pools)}")
    if len(fold_pools) > N_FOLDS:
        log_print(f"[WARN] {len(fold_pools)} fold pools given but N_FOLDS={N_FOLDS}; extra pools are ignored")

    for fold in range(N_FOLDS):
        log_print("\n" + "="*18 + f" Fold {fold} " + "="*18)

        # -----------------------------------------------------
        # RANDOM folds (no gene exclusivity):
        #  - remaining folds train, 1 fold val, 1 fold test
        # -----------------------------------------------------
        test_fold = fold
        val_fold  = (fold + 1) % N_FOLDS
        train_folds = [i for i in range(N_FOLDS) if i not in (test_fold, val_fold)]

        df_test = fold_pools[test_fold].reset_index(drop=True)
        df_val  = fold_pools[val_fold].reset_index(drop=True)

        df_train_pool = pd.concat(
            [fold_pools[i] for i in train_folds],
            ignore_index=True
        )

        # Optional train subsample (only train is subsampled; val/test keep full)
        if TRAIN_SIZE_PER_FOLD is None:
            df_train = df_train_pool.reset_index(drop=True)
        else:
            df_train = stratified_subsample(df_train_pool, TRAIN_SIZE_PER_FOLD, seed=SEED+fold).reset_index(drop=True)

        log_print(f"Train folds: {train_folds} | "
                  f"Val fold: {val_fold} | Test fold: {test_fold}")
        log_print(f"Train={len(df_train):,}, Val={len(df_val):,}, Test={len(df_test):,}")
        log_print(f"y_train mean={df_train.y.mean():.4f}, y_val mean={df_val.y.mean():.4f}, y_test mean={df_test.y.mean():.4f}")

        # (Optional) just log gene overlaps; do NOT enforce anything.
        g_tr  = genes_in_df(df_train)
        g_val = genes_in_df(df_val)
        g_te  = genes_in_df(df_test)
        log_print(f"[INFO] Gene overlap train∩val={len(g_tr & g_val)}, train∩test={len(g_tr & g_te)}, val∩test={len(g_val & g_te)} (allowed)")

        # -----------------------------------------------------
        # Build VAL/TEST features once; they are reused every epoch
        # -----------------------------------------------------
        X_val, y_val = build_XY_from_df(df_val, rest_scale=REST_SCALE, norm_mode=NORM_MODE, desc=f"Fold {fold} VAL")
        X_test, y_test = build_XY_from_df(df_test, rest_scale=REST_SCALE, norm_mode=NORM_MODE, desc=f"Fold {fold} TEST")

        # fresh model / optimizer per fold so no weights leak between folds
        model_mlp = MLP(dropout_rate=DROPOUT_RATE).to(device)
        optimizer = optim.AdamW(model_mlp.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
        criterion = nn.BCELoss()

        best_auc = -1
        best_state = None
        best_epoch = 0
        patience_left = PATIENCE

        # store curves (reviewer request)
        history = {"train_loss": [], "val_loss": [], "val_auc": []}

        for epoch in range(1, MAX_EPOCHS+1):
            train_loss = train_one_epoch_streaming(
                model_mlp, optimizer, criterion,
                df_train,
                batch_size=BATCH_SIZE,
                train_chunk_size=TRAIN_CHUNK_SIZE,
                rest_scale=REST_SCALE,
                norm_mode=NORM_MODE,
                fold=fold,
                epoch=epoch,
            )

            # val loss + val auc
            val_loss, val_auc, _ = eval_loss_and_auc(model_mlp, X_val, y_val, criterion)

            history["train_loss"].append(train_loss)
            history["val_loss"].append(val_loss)
            history["val_auc"].append(val_auc)

            log_print(f"[Fold {fold}] Epoch {epoch}: train_loss={train_loss:.6f}, val_loss={val_loss:.6f}, val_auc={val_auc:.6f}")

            # early stopping on AUC: an epoch only counts as an improvement
            # when it beats the best AUC by more than MIN_DELTA
            if val_auc > best_auc + MIN_DELTA:
                best_auc = val_auc
                best_epoch = epoch
                # snapshot on CPU so the checkpoint does not hold device memory
                best_state = {k: v.detach().cpu().clone() for k, v in model_mlp.state_dict().items()}
                patience_left = PATIENCE
            else:
                patience_left -= 1
                if patience_left <= 0:
                    log_print(f"[Fold {fold}] Early stop at epoch {epoch} (best_val_auc={best_auc:.6f} @epoch {best_epoch})")
                    break

        # restore best
        if best_state is not None:
            model_mlp.load_state_dict(best_state)

        # -----------------------------------------------------
        # Threshold search on VAL (re-evaluated with the restored best weights)
        # -----------------------------------------------------
        _, _, y_val_prob = eval_loss_and_auc(model_mlp, X_val, y_val, criterion)

        best_t = find_best_threshold(
            y_true=y_val,
            y_prob=y_val_prob,
            metric=THRESH_METRIC,
            min_precision_pos=MIN_PRECISION_POS,
            t_min=THRESH_MIN,
            t_max=THRESH_MAX,
            t_step=THRESH_STEP,
        )

        if best_t is None:
            best_thr = 0.5
            log_print(f"[Fold {fold}] No valid threshold found, fallback thr=0.5")
        else:
            best_thr = best_t["threshold"]
            log_print(f"[Fold {fold}] Best threshold from VAL = {best_thr:.3f} | metric={THRESH_METRIC}")
            log_print(f"[Fold {fold}] VAL best stats = {best_t}")

        # -----------------------------------------------------
        # TEST (Tripleknock)
        # -----------------------------------------------------
        test_loss, test_auc, y_test_prob = eval_loss_and_auc(model_mlp, X_test, y_test, criterion)
        aucs_triple.append(test_auc)

        y_pred = (y_test_prob >= best_thr).astype(int)
        cm = confusion_matrix(y_test, y_pred, labels=[0, 1])
        rep = classification_report(y_test, y_pred, digits=4)

        # additional metrics (zero_division=0 on all three for consistency)
        f1_pos = f1_score(y_test, y_pred, pos_label=1, zero_division=0)
        p_pos  = precision_score(y_test, y_pred, pos_label=1, zero_division=0)
        r_pos  = recall_score(y_test, y_pred, pos_label=1, zero_division=0)

        log_print(f"[Fold {fold}] ✅ Tripleknock TEST AUC = {test_auc:.6f}")
        log_print(f"[Fold {fold}] Tripleknock TEST loss = {test_loss:.6f}")
        log_print(f"[Fold {fold}] Tripleknock Confusion Matrix [[TN,FP],[FN,TP]]:\n{cm}")
        log_print(f"[Fold {fold}] Tripleknock Report:\n{rep}")

        # -----------------------------------------------------
        # TEST (Baseline: essential-gene presence rule)
        # -----------------------------------------------------
        y_base_pred = df_test[["g1","g2","g3"]].isin(essential_set).any(axis=1).astype(int).to_numpy()
        # The rule has no continuous score, so its hard 0/1 output doubles as the AUC score.
        y_base_score = y_base_pred.astype(float)

        base_cm = confusion_matrix(y_test, y_base_pred, labels=[0, 1])
        base_rep = classification_report(y_test, y_base_pred, digits=4)

        try:
            base_auc = roc_auc_score(y_test, y_base_score)
        except ValueError as exc:
            # e.g. only one class present in y_test: AUC is undefined
            log_print(f"[Fold {fold}] Baseline AUC undefined ({exc}); recording NaN")
            base_auc = float('nan')

        aucs_base.append(base_auc)

        base_f1 = f1_score(y_test, y_base_pred, pos_label=1, zero_division=0)
        base_p  = precision_score(y_test, y_base_pred, pos_label=1, zero_division=0)
        base_r  = recall_score(y_test, y_base_pred, pos_label=1, zero_division=0)

        log_print(f"[Fold {fold}] ✅ Baseline (essential-rule) TEST AUC = {base_auc}")
        log_print(f"[Fold {fold}] Baseline Confusion Matrix [[TN,FP],[FN,TP]]:\n{base_cm}")
        log_print(f"[Fold {fold}] Baseline Report:\n{base_rep}")

        # -----------------------------------------------------
        # Save per-fold predictions (same TEST set, both methods)
        # -----------------------------------------------------
        pred_df = df_test.copy()
        pred_df["y_true"] = y_test
        pred_df["tripleknock_prob"] = y_test_prob.astype(float)
        pred_df["tripleknock_pred"] = y_pred.astype(int)
        pred_df["baseline_pred"] = y_base_pred.astype(int)

        pred_out = os.path.join(OUTPUT_DIR, f"fold_{fold}_test_predictions.csv")
        pred_df.to_csv(pred_out, index=False)
        log_print(f"[Fold {fold}] Saved TEST predictions -> {pred_out}")

        # Save curves (optional but useful for reviewer)
        fig = plt.figure()
        plt.plot(history["train_loss"], label="train_loss")
        plt.plot(history["val_loss"], label="val_loss")
        plt.legend()
        plt.title(f"Fold {fold} Loss Curves")
        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, f"fold_{fold}_loss_curve.png"), dpi=300, bbox_inches="tight")
        plt.show()
        plt.close(fig)  # release the figure so figures do not pile up across folds

        fig = plt.figure()
        plt.plot(history["val_auc"], label="val_auc")
        plt.legend()
        plt.title(f"Fold {fold} Validation AUC")
        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, f"fold_{fold}_val_auc_curve.png"), dpi=300, bbox_inches="tight")
        plt.show()
        plt.close(fig)

        # -----------------------------------------------------
        # Collect fold-wise metrics
        # -----------------------------------------------------
        rows.append({
            "fold": fold,
            "n_train": int(len(df_train)),
            "n_val": int(len(df_val)),
            "n_test": int(len(df_test)),
            "threshold": float(best_thr),

            "triple_auc": float(test_auc),
            "triple_f1_pos": float(f1_pos),
            "triple_precision_pos": float(p_pos),
            "triple_recall_pos": float(r_pos),

            # float() keeps NaN as NaN, so no special-casing is needed
            "baseline_auc": float(base_auc),
            "baseline_f1_pos": float(base_f1),
            "baseline_precision_pos": float(base_p),
            "baseline_recall_pos": float(base_r),
        })

    # -----------------------------------------------------
    # Summary + save
    # -----------------------------------------------------
    res_df = pd.DataFrame(rows)
    out_csv = os.path.join(OUTPUT_DIR, "cv_metrics_summary.csv")
    res_df.to_csv(out_csv, index=False)

    log_print("\n" + "="*20 + " CV Summary (Tripleknock vs Baseline) " + "="*20)
    log_print(f"Saved summary -> {out_csv}")
    log_print(f"Tripleknock AUCs per fold: {aucs_triple}")
    log_print(f"Baseline   AUCs per fold: {aucs_base}")
    # nan-aware aggregates: folds with an undefined AUC are skipped, not propagated
    log_print(f"Tripleknock Mean AUC = {np.nanmean(aucs_triple):.6f}, Std = {np.nanstd(aucs_triple):.6f}")
    log_print(f"Baseline   Mean AUC = {np.nanmean(aucs_base):.6f}, Std = {np.nanstd(aucs_base):.6f}")

    return res_df


# ---- Launch the RANDOM 5-fold CV and display the per-fold metrics table ----
results_df = run_random_5fold_cv(df=df, fold_pools=fold_pools)
print(results_df)


[2026-02-20 21:22:47] Fold pool sizes: [49452, 49452, 49452, 49452, 49452]
[2026-02-20 21:22:47] 
[2026-02-20 21:22:47] Train folds: [2, 3, 4] | Val fold: 1 | Test fold: 0
[2026-02-20 21:22:47] Train=148,356, Val=49,452, Test=49,452
[2026-02-20 21:22:47] y_train mean=0.5000, y_val mean=0.5000, y_test mean=0.5000
[2026-02-20 21:22:47] [INFO] Gene overlap train∩val=1318, train∩test=1318, val∩test=1318 (allowed)


Fold 0 VAL:  50%|███████████████████████████▉                            | 24668/49452 [04:58<04:36, 89.60it/s]