# Tripleknock sub-test analysis (structured)

This notebook performs **sub-test / stratified** analyses on the **gene-disjoint** 5-fold test sets. The goal is to identify which subsets are predicted more accurately and to provide interpretable evidence for the rebuttal/Discussion.

> Convention: this notebook assumes that `fold-1-test.csv ... fold-5-test.csv` and the corresponding `fold-1-train.csv ... fold-5-train.csv` are in the current directory.
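
The analysis relies on each test fold being gene-disjoint from its training fold. A quick sanity check is to intersect the gene sets of the two files. The sketch below assumes every CSV has the `g1`, `g2`, `g3` columns used throughout this notebook. `gene_overlap` is a hypothetical helper name, and an empty result means the fold is gene-disjoint.

```python
import pandas as pd

GENE_COLS = ["g1", "g2", "g3"]

def gene_overlap(df_train: pd.DataFrame, df_test: pd.DataFrame) -> set:
    """Return the genes that appear in both the training and the test triples."""
    train_genes = set(df_train[GENE_COLS].astype(str).values.ravel())
    test_genes = set(df_test[GENE_COLS].astype(str).values.ravel())
    return train_genes & test_genes

# Usage on the real files (uncomment inside the notebook):
# for k in range(1, 6):
#     overlap = gene_overlap(pd.read_csv(f"fold-{k}-train.csv"), pd.read_csv(f"fold-{k}-test.csv"))
#     print(k, len(overlap))

# Toy example: no shared genes, so the overlap is empty
toy_train = pd.DataFrame({"g1": ["b0001"], "g2": ["b0002"], "g3": ["b0003"]})
toy_test = pd.DataFrame({"g1": ["b0004"], "g2": ["b0005"], "g3": ["b0006"]})
print(gene_overlap(toy_train, toy_test))  # set()
```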


## Part 1 — Data preprocessing (essential genes + FASTA + 2-mer)

Keep the original two-step preprocessing:  
1) parse the protein FASTA into `gene_sequence_dict` and build `two_mer_dict` (400-dim 2-mer frequency vectors)  
2) define the essential gene list `essential_genes`
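
Below is a compact sketch of the same 400-dimensional 2-mer frequency vector, useful for unit-testing the encoding on toy sequences. It differs from the notebook's code in one deliberate way, flagged here as an assumption. The code below deletes non-standard residues and then joins the remaining characters, which creates 2-mers that never occur in the protein (e.g. `AXC` contributes `AC`). This sketch instead skips any pair that touches a non-standard residue. `two_mer_vector` is a hypothetical helper name.

```python
import numpy as np

STANDARD_AA = "ACDEFGHIKLMNPQRSTVWY"
AA_INDEX = {aa: i for i, aa in enumerate(STANDARD_AA)}

def two_mer_vector(sequence: str) -> np.ndarray:
    """400-dim 2-mer frequency vector; pair (a, b) sits at index 20*idx(a) + idx(b)."""
    vec = np.zeros(400, dtype=np.float32)
    for a, b in zip(sequence, sequence[1:]):
        # Skip pairs containing a non-standard residue instead of joining across it
        if a in AA_INDEX and b in AA_INDEX:
            vec[20 * AA_INDEX[a] + AA_INDEX[b]] += 1
    total = vec.sum()
    return vec / total if total > 0 else vec

v = two_mer_vector("ACA")
print(v[1], v[20])  # 0.5 0.5  ("AC" and "CA")
```

The indexing `20*idx(a) + idx(b)` matches the order of `all_2mers` in the cell below, so vectors from both versions are directly comparable.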


In [1]:
# Install dependencies (run once)

! pip -q install tqdm biopython scikit-learn pandas numpy


In [1]:
# Read the FASTA file and build two_mer_dict (original code kept)

from tqdm import tqdm
import numpy as np

# If two_mer_dict is already available, set BUILD_2MER = False to skip this block.
BUILD_2MER = True

if BUILD_2MER:
    from Bio import SeqIO
    from collections import Counter

    fasta_path = '/data1/xpgeng/cross_pathogen/autoencoder/E.coli.tag_seq.fasta'

    def read_fasta(fp):
        gene_sequence_dict = {}
        for record in SeqIO.parse(fp, 'fasta'):
            gene_sequence_dict[record.id] = str(record.seq)
        return gene_sequence_dict

    gene_sequence_dict = read_fasta(fasta_path)
    all_genes = set(gene_sequence_dict.keys())

    print('Total genes in FASTA:', len(all_genes))
    print('Example:', list(gene_sequence_dict.items())[:1])

    standard_amino_acids = 'ACDEFGHIKLMNPQRSTVWY'
    all_2mers = [a + b for a in standard_amino_acids for b in standard_amino_acids]
    two_mer_index = {two_mer: idx for idx, two_mer in enumerate(all_2mers)}

    two_mer_dict = {}

    for gene, sequence in tqdm(gene_sequence_dict.items(), desc='Building two_mer_dict'):
        sequence = ''.join([aa for aa in sequence if aa in standard_amino_acids])

        if len(sequence) < 2:
            two_mer_dict[gene] = np.zeros(400, dtype=np.float32)
            continue

        two_mer_counts = Counter(sequence[i:i+2] for i in range(len(sequence)-1))
        total_two_mers = sum(two_mer_counts.values())

        feature_vector = np.zeros(400, dtype=np.float32)
        for two_mer, count in two_mer_counts.items():
            idx = two_mer_index.get(two_mer)
            if idx is not None:
                feature_vector[idx] = count / total_two_mers

        two_mer_dict[gene] = feature_vector

    for gene, vec in list(two_mer_dict.items())[:3]:
        print(gene, vec[:10])

Total genes in FASTA: 4305
Example: [('b0001', 'MKRISTTITTTITITTGNGAG')]


Building two_mer_dict: 100%|█████████████████████████████████████████████| 4305/4305 [00:00<00:00, 4581.66it/s]

b0001 [0.   0.   0.   0.   0.   0.05 0.   0.   0.   0.  ]
b0002 [0.01587302 0.001221   0.00854701 0.01098901 0.002442   0.00854701
 0.         0.003663   0.00610501 0.00732601]
b0003 [0.01294498 0.00647249 0.00647249 0.01294498 0.         0.00647249
 0.00323625 0.00323625 0.00323625 0.01294498]





In [2]:
# Essential gene list (original code kept)

essential_genes = ['b0003', 'b0004', 'b0025', 'b0029', 'b0031', 'b0052', 'b0054', 'b0071', 'b0072', 'b0074', 'b0084', 'b0085', 'b0086', 'b0087', 'b0088', 'b0089', 'b0090', 'b0091', 'b0096', 'b0103', 'b0109', 'b0131', 'b0133', 'b0134', 'b0142', 'b0154', 'b0159', 'b0166', 'b0173', 'b0174', 'b0175', 'b0179', 'b0180', 'b0181', 'b0182', 'b0185', 'b0242', 'b0243', 'b0369', 'b0386', 'b0414', 'b0415', 'b0417', 'b0420', 'b0421', 'b0423', 'b0522', 'b0523', 'b0524', 'b0635', 'b0639', 'b0641', 'b0720', 'b0750', 'b0774', 'b0775', 'b0776', 'b0777', 'b0778', 'b0908', 'b0914', 'b0915', 'b0918', 'b1062', 'b1069', 'b1091', 'b1092', 'b1093', 'b1094', 'b1098', 'b1131', 'b1136', 'b1208', 'b1210', 'b1215', 'b1260', 'b1261', 'b1262', 'b1263', 'b1264', 'b1277', 'b1281', 'b1288', 'b1662', 'b1693', 'b1740', 'b1812', 'b2019', 'b2020', 'b2021', 'b2022', 'b2023', 'b2024', 'b2025', 'b2026', 'b2103', 'b2153', 'b2312', 'b2315', 'b2316', 'b2323', 'b2329', 'b2400', 'b2472', 'b2476', 'b2478', 'b2499', 'b2507', 'b2515', 'b2530', 'b2557', 'b2564', 'b2574', 'b2585', 'b2599', 'b2600', 'b2615', 'b2687', 'b2746', 'b2747', 'b2750', 'b2751', 'b2752', 'b2762', 'b2763', 'b2764', 'b2780', 'b2818', 'b2827', 'b2838', 'b2942', 'b3018', 'b3040', 'b3041', 'b3058', 'b3172', 'b3176', 'b3177', 'b3187', 'b3189', 'b3196', 'b3198', 'b3199', 'b3200', 'b3201', 'b3255', 'b3256', 'b3360', 'b3368', 'b3389', 'b3412', 'b3433', 'b3607', 'b3633', 'b3634', 'b3639', 'b3642', 'b3648', 'b3729', 'b3730', 'b3770', 'b3771', 'b3774', 'b3804', 'b3805', 'b3809', 'b3843', 'b3850', 'b3870', 'b3939', 'b3941', 'b3957', 'b3958', 'b3959', 'b3960', 'b3967', 'b3972', 'b3974', 'b3990', 'b3991', 'b3992', 'b3993', 'b3994', 'b3997', 'b4005', 'b4006', 'b4013', 'b4040', 'b4160', 'b4177', 'b4214', 'b4245', 'b4261', 'b4262', 'b4407', 's0001']

In [3]:
# Convert to a set for fast membership queries

essential_set = set(essential_genes)
print("Essential genes:", len(essential_set))


Essential genes: 196


## Part 2 — Basic evaluation on 5 folds (overall)

Load the test file of each of the 5 folds, compute the overall metrics per fold, and report mean±sd across folds.

> By default, `y_pred` (threshold 0.5) is used as the binary predicted label; AUC is computed from `y_score`.
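
The labels in `y_pred` come from a fixed 0.5 cut-off. The Part 5 cell also mentions an optional `y_pred_bestthr` column, but how that column was produced is not stated here. One common choice, offered as an assumption and not necessarily the method used, is to pick the threshold that maximizes Youden's J = TPR - FPR on the ROC curve. `youden_threshold` is a hypothetical helper. The `>=` comparison at 0.5 is also assumed. A threshold chosen on the test fold itself would give optimistic estimates, so ideally it should come from held-out training data.

```python
import numpy as np
from sklearn.metrics import roc_curve

def youden_threshold(y_true, y_score):
    """Return the score threshold that maximizes TPR - FPR (Youden's J)."""
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    return float(thresholds[np.argmax(tpr - fpr)])

y_true = np.array([0, 0, 1, 1])
y_score = np.array([0.1, 0.2, 0.3, 0.4])

pred_fixed = (y_score >= 0.5).astype(int)   # fixed 0.5 cut-off
thr = youden_threshold(y_true, y_score)
pred_best = (y_score >= thr).astype(int)
print(thr, pred_fixed, pred_best)  # 0.3 [0 0 0 0] [0 0 1 1]
```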


In [4]:
# Load all fold test files and compute overall metrics per fold

import glob, os, re
import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, accuracy_score

# 1) Discover available folds (a fold is used only if its test file exists)
test_files = sorted(glob.glob("fold-*-test.csv"))
if len(test_files) == 0:
    raise FileNotFoundError("No fold-*-test.csv found in current directory.")

fold_ids = []
for fp in test_files:
    m = re.search(r"fold-(\d+)-test\.csv", os.path.basename(fp))
    if m:
        fold_ids.append(int(m.group(1)))
fold_ids = sorted(set(fold_ids))

print("Detected folds:", fold_ids)

# 2) Load each fold and compute its metrics
rows = []
for k in fold_ids:
    fp = f"fold-{k}-test.csv"
    df = pd.read_csv(fp)

    # canonicalize triple (avoid permutation duplicates)
    genes_sorted = np.sort(df[["g1","g2","g3"]].values.astype(str), axis=1)
    df[["g1","g2","g3"]] = genes_sorted

    y_true = df["y_true"].astype(int).values
    y_pred = df["y_pred"].astype(int).values
    y_score = df["y_score"].astype(float).values

    # metrics
    acc = accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, labels=[0,1], zero_division=0)
    # prec/rec/f1 arrays correspond to class 0 and class 1
    auc = roc_auc_score(y_true, y_score) if len(np.unique(y_true)) > 1 else np.nan

    rows.append({
        "fold": k,
        "N": len(df),
        "pos_rate(y=1)": float(np.mean(y_true)),
        "AUC": auc,
        "Accuracy": acc,
        "F1_0": float(f1[0]),
        "F1_1": float(f1[1]),
        "Precision_1": float(prec[1]),
        "Recall_1": float(rec[1]),
    })

df_overall = pd.DataFrame(rows).sort_values("fold")
display(df_overall)

# 3) mean±sd across folds
def mean_sd(x):
    return f"{np.mean(x):.3f}±{np.std(x, ddof=1):.3f}" if len(x) > 1 else f"{np.mean(x):.3f}±nan"

summary = {
    "AUC": mean_sd(df_overall["AUC"].values),
    "Accuracy": mean_sd(df_overall["Accuracy"].values),
    "F1_0": mean_sd(df_overall["F1_0"].values),
    "F1_1": mean_sd(df_overall["F1_1"].values),
    "Precision_1": mean_sd(df_overall["Precision_1"].values),
    "Recall_1": mean_sd(df_overall["Recall_1"].values),
}
print("\nMean±SD across folds:")
for k,v in summary.items():
    print(f"{k}: {v}")


Detected folds: [1, 2, 3, 4, 5]


|  | fold | N | pos_rate(y=1) | AUC | Accuracy | F1_0 | F1_1 | Precision_1 | Recall_1 |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 600000 | 0.377402 | 0.580763 | 0.551405 | 0.602042 | 0.486003 | 0.428139 | 0.561952 |
| 1 | 2 | 600000 | 0.339545 | 0.634371 | 0.582117 | 0.637142 | 0.50742 | 0.423018 | 0.633897 |
| 2 | 3 | 600000 | 0.308592 | 0.54038 | 0.53186 | 0.610889 | 0.412549 | 0.336632 | 0.532678 |
| 3 | 4 | 600000 | 0.325213 | 0.568603 | 0.565618 | 0.651815 | 0.422704 | 0.372237 | 0.489002 |
| 4 | 5 | 600000 | 0.347988 | 0.589298 | 0.559472 | 0.621056 | 0.473987 | 0.405474 | 0.570359 |



Mean±SD across folds:
AUC: 0.583±0.034
Accuracy: 0.558±0.018
F1_0: 0.625±0.020
F1_1: 0.461±0.041
Precision_1: 0.393±0.038
Recall_1: 0.558±0.053


## Part 3 — Essential gene count (n_essential) vs accuracy

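For each triple, compute `n_essential ∈ {0,1,2,3}` (the number of essential genes it contains). Then report accuracy and sample counts per `n_essential` group for each fold.

In the outputs below, every triple with `n_essential ≥ 1` is labeled positive (`pos_rate(y=1) = 1.0`), and almost all triples with `n_essential = 0` are negative. Accuracy within an `n_essential ≥ 1` group therefore equals class-1 recall.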
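
A cross-tabulation of `y_true` against `n_essential >= 1` is a quick way to see how strongly the label tracks the presence of essential genes. Below is a minimal sketch on a toy frame; the gene IDs and the three-gene `toy_essential` set are illustrative only. `count_essential` is a hypothetical helper that reproduces the `n_essential` column with a single `isin` call over the three gene columns.

```python
import pandas as pd

def count_essential(df: pd.DataFrame, essential_set: set) -> pd.Series:
    """Number of essential genes (0-3) in each triple."""
    return df[["g1", "g2", "g3"]].isin(essential_set).sum(axis=1)

toy = pd.DataFrame({
    "g1": ["b0001", "b0003", "b0003"],
    "g2": ["b0002", "b0005", "b0004"],
    "g3": ["b0006", "b0007", "b0025"],
    "y_true": [0, 1, 1],
})
toy_essential = {"b0003", "b0004", "b0025"}  # toy subset, not the full essential_genes list

toy["n_essential"] = count_essential(toy, toy_essential)
table = pd.crosstab(toy["n_essential"] >= 1, toy["y_true"])
print(table)
```

On the real folds, an off-diagonal count close to zero in this table would mean the label is almost determined by whether the triple contains an essential gene.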


In [5]:
# Accuracy grouped by n_essential (per fold)

import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score

rows = []
for k in fold_ids:
    df = pd.read_csv(f"fold-{k}-test.csv")
    genes_sorted = np.sort(df[["g1","g2","g3"]].values.astype(str), axis=1)
    df[["g1","g2","g3"]] = genes_sorted

    df["y_true"] = df["y_true"].astype(int)
    df["y_pred"] = df["y_pred"].astype(int)

    df["n_essential"] = (
        df["g1"].isin(essential_set).astype(int) +
        df["g2"].isin(essential_set).astype(int) +
        df["g3"].isin(essential_set).astype(int)
    )

    for ne in [0,1,2,3]:
        sub = df[df["n_essential"]==ne]
        if len(sub) == 0:
            continue
        acc = accuracy_score(sub["y_true"], sub["y_pred"])
        rows.append({
            "fold": k,
            "n_essential": ne,
            "N": len(sub),
            "pos_rate(y=1)": float(sub["y_true"].mean()),
            "Accuracy": acc,
        })

df_ne = pd.DataFrame(rows).sort_values(["fold","n_essential"])
display(df_ne)

# mean±sd of accuracy across folds for each n_essential
agg = (
    df_ne.groupby("n_essential")
    .agg(N_total=("N","sum"),
         acc_mean=("Accuracy","mean"),
         acc_sd=("Accuracy", lambda x: np.std(x, ddof=1) if len(x)>1 else np.nan))
    .reset_index()
)
agg["Accuracy(mean±sd)"] = agg.apply(lambda r: f"{r.acc_mean:.3f}±{r.acc_sd:.3f}", axis=1)
display(agg[["n_essential","N_total","Accuracy(mean±sd)"]])


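|  | fold | n_essential | N | pos_rate(y=1) | Accuracy |
|---|---|---|---|---|---|
| 0 | 1 | 0 | 373620 | 0.000163 | 0.544998 |
| 1 | 1 | 1 | 192452 | 1.0 | 0.548163 |
| 2 | 1 | 2 | 32115 | 1.0 | 0.6364 |
| 3 | 1 | 3 | 1813 | 1.0 | 0.710425 |
| 4 | 2 | 0 | 396308 | 8.8e-05 | 0.555505 |
| 5 | 2 | 1 | 176965 | 1.0 | 0.615274 |
| 6 | 2 | 2 | 25565 | 1.0 | 0.752826 |
| 7 | 2 | 3 | 1162 | 1.0 | 0.85284 |
| 8 | 3 | 0 | 415026 | 0.000436 | 0.531494 |
| 9 | 3 | 1 | 163382 | 1.0 | 0.526031 |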


|  | n_essential | N_total | Accuracy(mean±sd) |
|---|---|---|---|
| 0 | 0 | 1981279 | 0.558±0.027 |
| 1 | 1 | 883777 | 0.545±0.050 |
| 2 | 2 | 128758 | 0.639±0.076 |
| 3 | 3 | 6186 | 0.720±0.091 |


## Part 4 — High-confidence subsets by score margin (|y_score−0.5| ≥ t)

Define the high-confidence subset as `abs(y_score-0.5) >= t` for t = 0.1/0.2/0.3/0.4. Report per-fold results and summarize them as mean±sd across folds.

> Purpose: check whether the high-confidence region clearly outperforms the full test set (restricting to it typically raises Accuracy / F1).
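
The `y_score` values in these folds are concentrated far from 0.5: even t = 0.4 keeps over 80% of triples. As a result, fixed margins change coverage only slightly. A coverage-controlled variant picks the margin threshold as a quantile, so that a chosen fraction of triples is retained. It is sketched below as an alternative, not as the notebook's method. `selective_accuracy` is a hypothetical helper, and it assumes `y_pred` is the 0.5-threshold label used in the cells below. With tied margins at the cut, the realized coverage can exceed the target.

```python
import numpy as np

def selective_accuracy(y_true, y_pred, y_score, coverage):
    """Accuracy on the `coverage` fraction of samples with the largest |y_score - 0.5|."""
    margin = np.abs(np.asarray(y_score, dtype=float) - 0.5)
    t = np.quantile(margin, 1.0 - coverage)   # margin threshold for the target coverage
    keep = margin >= t
    acc = float(np.mean(np.asarray(y_true)[keep] == np.asarray(y_pred)[keep]))
    return float(t), float(keep.mean()), acc

y_true = np.array([1, 0, 0, 1])
y_score = np.array([0.95, 0.05, 0.70, 0.45])
y_pred = (y_score >= 0.5).astype(int)

for c in (1.0, 0.5):
    t, cov, acc = selective_accuracy(y_true, y_pred, y_score, c)
    print(f"target={c:.2f} t={t:.3f} coverage={cov:.2f} accuracy={acc:.2f}")
```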


In [20]:
# Performance on high-confidence subsets (t = 0.1, 0.2, 0.3, 0.4)

import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, accuracy_score

t_list = [0.1, 0.2, 0.3, 0.4]

rows = []
for k in fold_ids:
    df = pd.read_csv(f"fold-{k}-test.csv")
    df["y_true"] = df["y_true"].astype(int)
    df["y_pred"] = df["y_pred"].astype(int)
    df["y_score"] = df["y_score"].astype(float)

    df["margin"] = (df["y_score"] - 0.5).abs()
    N_all = len(df)

    for t in t_list:
        sub = df[df["margin"] >= t]
        if len(sub) == 0:
            continue

        y_true = sub["y_true"].values
        y_pred = sub["y_pred"].values
        y_score = sub["y_score"].values

        acc = accuracy_score(y_true, y_pred)
        prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, labels=[0,1], zero_division=0)
        auc = roc_auc_score(y_true, y_score) if len(np.unique(y_true)) > 1 else np.nan

        rows.append({
            "fold": k,
            "t": t,
            "N_subset": len(sub),
            "coverage": len(sub)/N_all,
            "AUC": auc,
            "Accuracy": acc,
            "F1_1": float(f1[1]),
            "Precision_1": float(prec[1]),
            "Recall_1": float(rec[1]),
        })

df_conf = pd.DataFrame(rows).sort_values(["fold","t"])
display(df_conf)

# Aggregate mean±sd across folds for each t
def mean_sd_col(s):
    return f"{np.mean(s):.3f}±{np.std(s, ddof=1):.3f}" if len(s)>1 else f"{np.mean(s):.3f}±nan"

grouped = []
for t in t_list:
    sub = df_conf[df_conf["t"]==t]
    if len(sub)==0:
        continue
    grouped.append({
        "t": t,
        "coverage(mean±sd)": mean_sd_col(sub["coverage"].values),
        "Accuracy(mean±sd)": mean_sd_col(sub["Accuracy"].values),
        "F1_1(mean±sd)": mean_sd_col(sub["F1_1"].values),
        "AUC(mean±sd)": mean_sd_col(sub["AUC"].values),
    })

df_conf_summary = pd.DataFrame(grouped)
display(df_conf_summary)


|  | fold | t | N_subset | coverage | AUC | Accuracy | F1_1 | Precision_1 | Recall_1 |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0.1 | 579817 | 0.966362 | 0.582642 | 0.553062 | 0.487722 | 0.429709 | 0.563844 |
| 1 | 1 | 0.2 | 557735 | 0.929558 | 0.584788 | 0.555018 | 0.489772 | 0.431497 | 0.566246 |
| 2 | 1 | 0.3 | 530737 | 0.884562 | 0.587517 | 0.557442 | 0.492647 | 0.434162 | 0.569342 |
| 3 | 1 | 0.4 | 489124 | 0.815207 | 0.592028 | 0.561778 | 0.497836 | 0.438645 | 0.575495 |
| 4 | 2 | 0.1 | 582151 | 0.970252 | 0.636633 | 0.584718 | 0.510917 | 0.426091 | 0.637914 |
| 5 | 2 | 0.2 | 562919 | 0.938198 | 0.639142 | 0.587397 | 0.514466 | 0.429177 | 0.642061 |
| 6 | 2 | 0.3 | 539221 | 0.898702 | 0.642251 | 0.590912 | 0.51922 | 0.433362 | 0.647505 |
| 7 | 2 | 0.4 | 503013 | 0.838355 | 0.647184 | 0.596173 | 0.526333 | 0.439565 | 0.655781 |
| 8 | 3 | 0.1 | 581589 | 0.969315 | 0.540896 | 0.53283 | 0.412915 | 0.336714 | 0.533695 |
| 9 | 3 | 0.2 | 561618 | 0.93603 | 0.541434 | 0.533713 | 0.413545 | 0.336898 | 0.535339 |


|  | t | coverage(mean±sd) | Accuracy(mean±sd) | F1_1(mean±sd) | AUC(mean±sd) |
|---|---|---|---|---|---|
| 0 | 0.1 | 0.969±0.002 | 0.560±0.019 | 0.462±0.042 | 0.584±0.035 |
| 1 | 0.2 | 0.936±0.004 | 0.562±0.020 | 0.464±0.044 | 0.586±0.036 |
| 2 | 0.3 | 0.895±0.006 | 0.564±0.021 | 0.466±0.045 | 0.588±0.037 |
| 3 | 0.4 | 0.833±0.010 | 0.568±0.022 | 0.469±0.048 | 0.591±0.038 |


## Part 5 — Accuracy by n_essential under different confidence thresholds (t)

Within each high-confidence subset (defined by t), stratify by `n_essential = 0/1/2/3` and compute accuracy.

> Purpose: answer whether, in the region where predictions are more confident, triples containing essential genes still have a higher error rate.
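
The nested `for t` / `for ne` loops in the Part 5 cells can also be written with one `groupby` per threshold. Below is a minimal sketch on a toy frame. It assumes the frame already carries the `margin`, `n_essential`, `y_true` and `y_pred` columns built in this notebook. `accuracy_by_group` is a hypothetical helper that reports only N and accuracy; the other metrics would need their own aggregation.

```python
import pandas as pd

def accuracy_by_group(df: pd.DataFrame, t_list, group_col="n_essential") -> pd.DataFrame:
    """N and accuracy per (t, group), using only rows with margin >= t."""
    correct = (df["y_true"] == df["y_pred"]).astype(float)
    frames = []
    for t in t_list:
        mask = df["margin"] >= t
        g = correct[mask].groupby(df.loc[mask, group_col]).agg(["size", "mean"])
        frames.append(g.rename(columns={"size": "N", "mean": "Accuracy"}).assign(t=t))
    return pd.concat(frames).reset_index()

toy = pd.DataFrame({
    "y_true": [0, 0, 1, 1, 1],
    "y_pred": [0, 1, 1, 0, 1],
    "margin": [0.40, 0.05, 0.30, 0.45, 0.20],
    "n_essential": [0, 0, 1, 1, 2],
})
res = accuracy_by_group(toy, [0.1, 0.3])
print(res)
```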


Variant A: n_essential kept as four separate groups (0/1/2/3)

In [21]:
# Part 5 (enhanced): under each confidence threshold t, stratify by n_essential and report coverage + metrics per fold.

import os, re, glob
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, roc_auc_score, precision_recall_fscore_support

# ---- Parameters ----
t_list = [0.1, 0.2, 0.3, 0.4]
ne_list = [0, 1, 2, 3]

# ---- Detect folds ----
test_files = sorted(glob.glob("fold-*-test.csv"))
if len(test_files) == 0:
    raise FileNotFoundError("No fold-*-test.csv found in current directory.")

fold_ids = []
for fp in test_files:
    m = re.search(r"fold-(\d+)-test\.csv", os.path.basename(fp))
    if m:
        fold_ids.append(int(m.group(1)))
fold_ids = sorted(set(fold_ids))
print("Detected folds:", fold_ids)

rows = []

for k in fold_ids:
    df = pd.read_csv(f"fold-{k}-test.csv")

    # canonicalize gene order (avoids duplicate permutations of the same triple)
    genes_sorted = np.sort(df[["g1","g2","g3"]].values.astype(str), axis=1)
    df[["g1","g2","g3"]] = genes_sorted

    # Base columns
    df["y_true"]  = df["y_true"].astype(int)
    df["y_score"] = df["y_score"].astype(float)

    # Predicted label: y_pred by default (threshold 0.5)
    df["y_pred_use"] = df["y_pred"].astype(int)

    # To use the best-threshold predictions instead: uncomment the next line and comment out the line above
    # df["y_pred_use"] = df["y_pred_bestthr"].astype(int)

    # Confidence margin
    df["margin"] = (df["y_score"] - 0.5).abs()

    # n_essential
    df["n_essential"] = (
        df["g1"].isin(essential_set).astype(int) +
        df["g2"].isin(essential_set).astype(int) +
        df["g3"].isin(essential_set).astype(int)
    )

    N_all = len(df)

    for t in t_list:
        df_t = df[df["margin"] >= t]
        if len(df_t) == 0:
            continue

        for ne in ne_list:
            sub = df_t[df_t["n_essential"] == ne]
            if len(sub) == 0:
                continue

            y_true = sub["y_true"].values
            y_pred = sub["y_pred_use"].values
            y_score = sub["y_score"].values

            acc = accuracy_score(y_true, y_pred)

            # class-1 metrics
            prec, rec, f1, _ = precision_recall_fscore_support(
                y_true, y_pred, labels=[0,1], zero_division=0
            )
            f1_1 = float(f1[1])
            prec_1 = float(prec[1])
            rec_1 = float(rec[1])

            # AUC (roc_auc_score raises an error on single-class subsets, so default to NaN)
            auc = np.nan
            if len(np.unique(y_true)) > 1:
                auc = roc_auc_score(y_true, y_score)

            rows.append({
                "fold": k,
                "t": t,
                "n_essential": ne,
                "N_subset": int(len(sub)),
                "coverage": float(len(sub) / N_all),
                "pos_rate(y=1)": float(np.mean(y_true)),
                "Accuracy": float(acc),
                "F1_1": f1_1,
                "Precision_1": prec_1,
                "Recall_1": rec_1,
                "AUC": float(auc),
            })

df_part5_perfold = pd.DataFrame(rows).sort_values(["t","n_essential","fold"])
display(df_part5_perfold)

# Save per-fold results
df_part5_perfold.to_csv("Part5_perfold_confidence_by_essential.csv", index=False)
print("Saved: Part5_perfold_confidence_by_essential.csv")

Detected folds: [1, 2, 3, 4, 5]


|  | fold | t | n_essential | N_subset | coverage | pos_rate(y=1) | Accuracy | F1_1 | Precision_1 | Recall_1 | AUC |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 0.1 | 0 | 361090 | 0.601817 | 0.000158 | 0.546515 | 0.000317 | 0.000159 | 0.456140 | 0.510337 |
| 16 | 2 | 0.1 | 0 | 384234 | 0.640390 | 0.000086 | 0.557322 | 0.000270 | 0.000135 | 0.696970 | 0.618371 |
| 32 | 3 | 0.1 | 0 | 402735 | 0.671225 | 0.000439 | 0.532444 | 0.000997 | 0.000499 | 0.531073 | 0.536128 |
| 48 | 4 | 0.1 | 0 | 393811 | 0.656352 | 0.000523 | 0.605184 | 0.001297 | 0.000650 | 0.490291 | 0.600143 |
| 64 | 5 | 0.1 | 0 | 379989 | 0.633315 | 0.000087 | 0.555042 | 0.000378 | 0.000189 | 0.969697 | 0.881867 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 15 | 1 | 0.4 | 3 | 1476 | 0.002460 | 1.000000 | 0.756098 | 0.861111 | 1.000000 | 0.756098 | NaN |
| 31 | 2 | 0.4 | 3 | 1056 | 0.001760 | 1.000000 | 0.884470 | 0.938693 | 1.000000 | 0.884470 | NaN |
| 47 | 3 | 0.4 | 3 | 710 | 0.001183 | 1.000000 | 0.656338 | 0.792517 | 1.000000 | 0.656338 | NaN |
| 63 | 4 | 0.4 | 3 | 790 | 0.001317 | 1.000000 | 0.683544 | 0.812030 | 1.000000 | 0.683544 | NaN |


Saved: Part5_perfold_confidence_by_essential.csv


In [22]:
# Summarize the Part 5 per-fold results as mean±sd across folds (paper-ready; can sit next to the overall strict-split results table).

import numpy as np
import pandas as pd

if "df_part5_perfold" not in globals() or len(df_part5_perfold) == 0:
    raise RuntimeError("df_part5_perfold is empty. Please run the previous cell first.")

def mean_sd(series):
    x = np.asarray(series, dtype=float)
    x = x[~np.isnan(x)]
    if len(x) == 0:
        return "nan±nan"
    if len(x) == 1:
        return f"{np.mean(x):.3f}±nan"
    return f"{np.mean(x):.3f}±{np.std(x, ddof=1):.3f}"

group_rows = []
for t in sorted(df_part5_perfold["t"].unique()):
    for ne in sorted(df_part5_perfold["n_essential"].unique()):
        sub = df_part5_perfold[(df_part5_perfold["t"]==t) & (df_part5_perfold["n_essential"]==ne)]
        if len(sub) == 0:
            continue

        group_rows.append({
            "t": t,
            "n_essential": int(ne),
            "N_total": int(sub["N_subset"].sum()),
            "coverage(mean±sd)": mean_sd(sub["coverage"]),
            "pos_rate(mean±sd)": mean_sd(sub["pos_rate(y=1)"]),
            "Accuracy(mean±sd)": mean_sd(sub["Accuracy"]),
            "F1_1(mean±sd)": mean_sd(sub["F1_1"]),
            "Precision_1(mean±sd)": mean_sd(sub["Precision_1"]),
            "Recall_1(mean±sd)": mean_sd(sub["Recall_1"]),
            "AUC(mean±sd)": mean_sd(sub["AUC"]),  # NaN for single-class subsets (expected)
        })

df_part5_summary = pd.DataFrame(group_rows).sort_values(["t","n_essential"])
display(df_part5_summary)

# Save the summary table (ready to paste into main-text or supplementary tables)
df_part5_summary.to_csv("Part5_summary_confidence_by_essential.csv", index=False)
print("Saved: Part5_summary_confidence_by_essential.csv")

|  | t | n_essential | N_total | coverage(mean±sd) | pos_rate(mean±sd) | Accuracy(mean±sd) | F1_1(mean±sd) | Precision_1(mean±sd) | Recall_1(mean±sd) | AUC(mean±sd) |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.1 | 0 | 1921859 | 0.641±0.026 | 0.000±0.000 | 0.559±0.027 | 0.001±0.000 | 0.000±0.000 | 0.629±0.212 | 0.629±0.148 |
| 1 | 0.1 | 1 | 855784 | 0.285±0.017 | 1.000±0.000 | 0.546±0.051 | 0.705±0.043 | 1.000±0.000 | 0.546±0.051 | nan±nan |
| 2 | 0.1 | 2 | 124739 | 0.042±0.007 | 1.000±0.000 | 0.643±0.078 | 0.781±0.057 | 1.000±0.000 | 0.643±0.078 | nan±nan |
| 3 | 0.1 | 3 | 5980 | 0.002±0.001 | 1.000±0.000 | 0.724±0.094 | 0.837±0.062 | 1.000±0.000 | 0.724±0.094 | nan±nan |
| 4 | 0.2 | 0 | 1857364 | 0.619±0.027 | 0.000±0.000 | 0.561±0.028 | 0.001±0.000 | 0.000±0.000 | 0.632±0.215 | 0.629±0.146 |
| 5 | 0.2 | 1 | 825239 | 0.275±0.017 | 1.000±0.000 | 0.548±0.053 | 0.707±0.044 | 1.000±0.000 | 0.548±0.053 | nan±nan |
| 6 | 0.2 | 2 | 120309 | 0.040±0.007 | 1.000±0.000 | 0.648±0.079 | 0.784±0.057 | 1.000±0.000 | 0.648±0.079 | nan±nan |
| 7 | 0.2 | 3 | 5787 | 0.002±0.001 | 1.000±0.000 | 0.730±0.094 | 0.841±0.062 | 1.000±0.000 | 0.730±0.094 | nan±nan |
| 8 | 0.3 | 0 | 1777888 | 0.593±0.027 | 0.000±0.000 | 0.563±0.029 | 0.001±0.000 | 0.000±0.000 | 0.620±0.221 | 0.629±0.145 |
| 9 | 0.3 | 1 | 787967 | 0.263±0.016 | 1.000±0.000 | 0.550±0.055 | 0.708±0.046 | 1.000±0.000 | 0.550±0.055 | nan±nan |


Saved: Part5_summary_confidence_by_essential.csv


Variant B: n_essential collapsed into two groups (0 vs >=1)
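
In this variant, all three coverage columns use the whole fold as the denominator:
- `coverage = N_sub / N_all`
- `pos_coverage = P_sub / P_all`
- `neg_coverage = Nneg_sub / Nneg_all`

Here P and Nneg count label-1 and label-0 triples. Below is a vectorized sketch of the same quantities for a single threshold t, on a toy frame. `coverage_table` is a hypothetical helper, and it assumes `y_true` is coded 0/1.

```python
import numpy as np
import pandas as pd

def coverage_table(df: pd.DataFrame, t: float) -> pd.DataFrame:
    """coverage / pos_coverage / neg_coverage / Accuracy per group (0 vs >=1) at margin >= t."""
    df = df.assign(ne_group=np.where(df["n_essential"] == 0, "0", ">=1"))
    n_all = len(df)                              # N_all: all triples in the fold
    p_all = int((df["y_true"] == 1).sum())       # P_all: all positive triples
    n_neg_all = int((df["y_true"] == 0).sum())   # Nneg_all: all negative triples

    kept = df[df["margin"] >= t]
    g = kept.groupby("ne_group")
    n_sub = g.size()
    p_sub = g["y_true"].sum()                     # positives kept per group (y_true is 0/1)
    acc = (kept["y_true"] == kept["y_pred"]).groupby(kept["ne_group"]).mean()

    return pd.DataFrame({
        "coverage": n_sub / n_all,
        "pos_coverage": p_sub / p_all if p_all > 0 else np.nan,
        "neg_coverage": (n_sub - p_sub) / n_neg_all if n_neg_all > 0 else np.nan,
        "Accuracy": acc,
    })

toy = pd.DataFrame({
    "y_true": [1, 1, 0, 0, 0],
    "y_pred": [1, 0, 0, 1, 0],
    "margin": [0.40, 0.20, 0.45, 0.30, 0.05],
    "n_essential": [1, 2, 0, 0, 0],
})
out = coverage_table(toy, 0.1)
print(out)
```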

In [25]:
# Minimal Part 5: per-fold coverage / pos_coverage / neg_coverage / accuracy for (t, n_essential = 0 vs >=1)

import os, re, glob
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score

# ---- Parameters ----
t_list = [0.1, 0.2, 0.3, 0.4]

# ---- folds ----
test_files = sorted(glob.glob("fold-*-test.csv"))
if len(test_files) == 0:
    raise FileNotFoundError("No fold-*-test.csv found in current directory.")

fold_ids = []
for fp in test_files:
    m = re.search(r"fold-(\d+)-test\.csv", os.path.basename(fp))
    if m:
        fold_ids.append(int(m.group(1)))
fold_ids = sorted(set(fold_ids))
print("Detected folds:", fold_ids)

# ---- Fallback: if essential_set is not defined yet, build it from essential_genes ----
if "essential_set" not in globals():
    try:
        essential_set = set(essential_genes)
    except NameError:
        raise RuntimeError("Please define essential_set (or essential_genes) before running this cell.")

rows = []

for k in fold_ids:
    df = pd.read_csv(f"fold-{k}-test.csv")

    # canonicalize gene order
    genes_sorted = np.sort(df[["g1","g2","g3"]].values.astype(str), axis=1)
    df[["g1","g2","g3"]] = genes_sorted

    # types
    df["y_true"]  = df["y_true"].astype(int)
    df["y_pred"]  = df["y_pred"].astype(int)
    df["y_score"] = df["y_score"].astype(float)

    # margin
    df["margin"] = (df["y_score"] - 0.5).abs()

    # n_essential
    df["n_essential"] = (
        df["g1"].isin(essential_set).astype(int) +
        df["g2"].isin(essential_set).astype(int) +
        df["g3"].isin(essential_set).astype(int)
    )

    # collapse: 0 vs >=1
    df["ne_group"] = np.where(df["n_essential"] == 0, "0", ">=1")

    N_all = len(df)
    P_all = int((df["y_true"] == 1).sum())
    Nneg_all = int((df["y_true"] == 0).sum())

    for t in t_list:
        df_t = df[df["margin"] >= t]
        if len(df_t) == 0:
            continue

        for grp in ["0", ">=1"]:
            sub = df_t[df_t["ne_group"] == grp]
            if len(sub) == 0:
                continue

            N_sub = int(len(sub))
            P_sub = int((sub["y_true"] == 1).sum())
            Nneg_sub = int((sub["y_true"] == 0).sum())

            coverage = N_sub / N_all
            pos_coverage = (P_sub / P_all) if P_all > 0 else np.nan
            neg_coverage = (Nneg_sub / Nneg_all) if Nneg_all > 0 else np.nan

            acc = accuracy_score(sub["y_true"], sub["y_pred"])

            rows.append({
                "fold": k,
                "t": t,
                "n_essential_group": grp,
                "coverage": float(coverage),
                "pos_coverage": float(pos_coverage),
                "neg_coverage": float(neg_coverage),
                "Accuracy": float(acc),
            })

df_min_perfold = pd.DataFrame(rows).sort_values(["t","n_essential_group","fold"])
display(df_min_perfold)

df_min_perfold.to_csv("Part5_minimal_perfold.csv", index=False)
print("Saved: Part5_minimal_perfold.csv")

Detected folds: [1, 2, 3, 4, 5]


|  | fold | t | n_essential_group | coverage | pos_coverage | neg_coverage | Accuracy |
|---|---|---|---|---|---|---|---|
| 0 | 1 | 0.1 | 0 | 0.601817 | 0.000252 | 0.966468 | 0.546515 |
| 8 | 2 | 0.1 | 0 | 0.64039 | 0.000162 | 0.969536 | 0.557322 |
| 16 | 3 | 0.1 | 0 | 0.671225 | 0.000956 | 0.970382 | 0.532444 |
| 24 | 4 | 0.1 | 0 | 0.656352 | 0.001056 | 0.972171 | 0.605184 |
| 32 | 5 | 0.1 | 0 | 0.633315 | 0.000158 | 0.97124 | 0.555042 |
| 1 | 1 | 0.1 | >=1 | 0.364545 | 0.965934 | 0.0 | 0.563872 |
| 9 | 2 | 0.1 | >=1 | 0.329862 | 0.971481 | 0.0 | 0.637904 |
| 17 | 3 | 0.1 | >=1 | 0.29809 | 0.965969 | 0.0 | 0.533698 |
| 25 | 4 | 0.1 | >=1 | 0.314657 | 0.967539 | 0.0 | 0.488649 |
| 33 | 5 | 0.1 | >=1 | 0.337018 | 0.968476 | 0.0 | 0.572684 |


Saved: Part5_minimal_perfold.csv


In [27]:
# Summarize the minimal per-fold results as mean±sd (final table) and save to CSV

import numpy as np
import pandas as pd

if "df_min_perfold" not in globals() or len(df_min_perfold) == 0:
    raise RuntimeError("df_min_perfold is empty. Please run the previous cell first.")

def mean_sd(x):
    x = np.asarray(x, dtype=float)
    x = x[~np.isnan(x)]
    if len(x) == 0:
        return "nan±nan"
    if len(x) == 1:
        return f"{np.mean(x):.3f}±nan"
    return f"{np.mean(x):.3f}±{np.std(x, ddof=1):.3f}"

out_rows = []
for t in sorted(df_min_perfold["t"].unique()):
    for grp in ["0", ">=1"]:
        sub = df_min_perfold[(df_min_perfold["t"]==t) & (df_min_perfold["n_essential_group"]==grp)]
        if len(sub) == 0:
            continue
        out_rows.append({
            "t": t,
            "n_essential_group": grp,
            "coverage(mean±sd)": mean_sd(sub["coverage"]),
            "pos_coverage(mean±sd)": mean_sd(sub["pos_coverage"]),
            "neg_coverage(mean±sd)": mean_sd(sub["neg_coverage"]),
            "Accuracy(mean±sd)": mean_sd(sub["Accuracy"]),
        })

df_min_summary = pd.DataFrame(out_rows).sort_values(["t","n_essential_group"])
display(df_min_summary)

df_min_summary.to_csv("Part5_minimal_summary.csv", index=False)
print("Saved: Part5_minimal_summary.csv")

|  | t | n_essential_group | coverage(mean±sd) | pos_coverage(mean±sd) | neg_coverage(mean±sd) | Accuracy(mean±sd) |
|---|---|---|---|---|---|---|
| 0 | 0.1 | 0 | 0.641±0.026 | 0.001±0.000 | 0.970±0.002 | 0.559±0.027 |
| 1 | 0.1 | >=1 | 0.329±0.025 | 0.968±0.002 | 0.000±0.000 | 0.559±0.055 |
| 2 | 0.2 | 0 | 0.619±0.027 | 0.001±0.000 | 0.937±0.005 | 0.561±0.028 |
| 3 | 0.2 | >=1 | 0.317±0.024 | 0.933±0.005 | 0.000±0.000 | 0.561±0.057 |
| 4 | 0.3 | 0 | 0.593±0.027 | 0.000±0.000 | 0.897±0.008 | 0.563±0.029 |
| 5 | 0.3 | >=1 | 0.303±0.023 | 0.891±0.007 | 0.000±0.000 | 0.564±0.059 |
| 6 | 0.4 | 0 | 0.552±0.028 | 0.000±0.000 | 0.836±0.013 | 0.566±0.031 |
| 7 | 0.4 | >=1 | 0.281±0.021 | 0.827±0.011 | 0.000±0.000 | 0.569±0.062 |


Saved: Part5_minimal_summary.csv


## Part 6 — OOD similarity stratification (max cosine similarity to training genes)

For each fold:  
1) collect the set of genes that appear in the training triples of `fold-k-train.csv`, \(G_{\text{train}}\)  
2) for each test gene g, compute the **maximum cosine similarity** between its 2-mer vector and the vectors of all genes in \(G_{\text{train}}\): `max_sim(g)`  
3) for each test triple, compute `sim_score = max(max_sim(g1), max_sim(g2), max_sim(g3))`  
4) split `sim_score` into quartile groups Q1–Q4 and compute performance within each group

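> Purpose: explain why overall performance is only ~0.6 (Part 2). Expectation: subsets closer to the training distribution (Q4) are easier to predict. In the summary table below, accuracy increases from Q1 to Q4, whereas F1_1 decreases.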
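
Steps 2 and 3 above reduce to L2-normalizing the rows, one matrix product, and two max reductions. Below is a minimal NumPy sketch on toy 2-dimensional vectors; the notebook uses the 400-dimensional 2-mer vectors. `max_cosine_to_train` is a hypothetical helper. One design note: taking the max over the three genes labels a triple "close to training" as soon as one of its genes is. Using the min would be a stricter, alternative definition.

```python
import numpy as np

def max_cosine_to_train(test_mat: np.ndarray, train_mat: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """For each row of test_mat, the maximum cosine similarity to any row of train_mat."""
    test_n = test_mat / (np.linalg.norm(test_mat, axis=1, keepdims=True) + eps)
    train_n = train_mat / (np.linalg.norm(train_mat, axis=1, keepdims=True) + eps)
    return (test_n @ train_n.T).max(axis=1)   # (n_test, n_train) similarities -> row max

train = np.array([[1.0, 0.0], [0.0, 1.0]])
test = np.array([[1.0, 1.0], [2.0, 0.0], [0.0, 3.0]])
max_sim = max_cosine_to_train(test, train)

# sim_score of a triple whose three genes are rows 0, 1 and 2: max over its genes
sim_score = float(max_sim.max())
print(np.round(max_sim, 3), round(sim_score, 3))
```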


In [12]:
# Compute OOD similarity strata (Q1–Q4) per fold and report performance per quartile

import os
import pandas as pd
import numpy as np
from sklearn.metrics import roc_auc_score, precision_recall_fscore_support, accuracy_score

# Tip: assumes two_mer_dict was built in Part 1; if building it is slow, save it once and reload it here.

rows = []
for k in fold_ids:
    train_fp = f"fold-{k}-train.csv"
    test_fp  = f"fold-{k}-test.csv"

    if not os.path.exists(train_fp):
        print(f"[WARN] missing {train_fp}, skip fold {k}")
        continue

    df_train = pd.read_csv(train_fp)
    df_test  = pd.read_csv(test_fp)

    # canonicalize
    for df in [df_train, df_test]:
        genes_sorted = np.sort(df[["g1","g2","g3"]].values.astype(str), axis=1)
        df[["g1","g2","g3"]] = genes_sorted

    # 1) training gene set
    train_genes = pd.unique(df_train[["g1","g2","g3"]].values.ravel()).astype(str).tolist()
    test_genes  = pd.unique(df_test[["g1","g2","g3"]].values.ravel()).astype(str).tolist()

    # 2) build normalized matrices (only genes that exist in two_mer_dict)
    train_vecs = []
    train_keep = []
    for g in train_genes:
        v = two_mer_dict.get(g, None)
        if v is None:
            continue
        train_vecs.append(v.astype(np.float32))
        train_keep.append(g)

    test_vecs = []
    test_keep = []
    for g in test_genes:
        v = two_mer_dict.get(g, None)
        if v is None:
            continue
        test_vecs.append(v.astype(np.float32))
        test_keep.append(g)

    train_mat = np.vstack(train_vecs) if len(train_vecs)>0 else np.zeros((0,400), dtype=np.float32)
    test_mat  = np.vstack(test_vecs)  if len(test_vecs)>0  else np.zeros((0,400), dtype=np.float32)

    if train_mat.shape[0]==0 or test_mat.shape[0]==0:
        print(f"[WARN] fold {k}: empty vectors after filtering, skip similarity analysis")
        continue

    # normalize for cosine similarity
    eps = 1e-8
    train_norm = train_mat / (np.linalg.norm(train_mat, axis=1, keepdims=True) + eps)
    test_norm  = test_mat  / (np.linalg.norm(test_mat, axis=1, keepdims=True) + eps)

    # 3) max cosine similarity for each test gene
    # sim = test_norm @ train_norm.T  (n_test_genes x n_train_genes)
    sim = test_norm @ train_norm.T
    max_sim = sim.max(axis=1)  # per test gene

    gene2maxsim = dict(zip(test_keep, max_sim.tolist()))

    # 4) per triple sim_score
    df_test["maxsim_g1"] = df_test["g1"].map(gene2maxsim)
    df_test["maxsim_g2"] = df_test["g2"].map(gene2maxsim)
    df_test["maxsim_g3"] = df_test["g3"].map(gene2maxsim)
    df_test["sim_score"] = df_test[["maxsim_g1","maxsim_g2","maxsim_g3"]].max(axis=1)

    # drop NaN sim_score (genes missing vectors)
    df_s = df_test.dropna(subset=["sim_score"]).copy()
    if len(df_s) == 0:
        print(f"[WARN] fold {k}: no rows with sim_score, skip")
        continue

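    # quartiles
    try:
        df_s["sim_bin"] = pd.qcut(df_s["sim_score"], 4, labels=["Q1","Q2","Q3","Q4"])
    except ValueError:
        # duplicate bin edges (many tied scores): drop them and label the remaining bins Q1, Q2, ...
        # (four fixed labels would fail here, because fewer than four bins remain)
        codes = pd.qcut(df_s["sim_score"], 4, labels=False, duplicates="drop")
        df_s["sim_bin"] = codes.map(lambda i: f"Q{int(i) + 1}")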

    # metrics per bin
    for q in ["Q1","Q2","Q3","Q4"]:
        sub = df_s[df_s["sim_bin"]==q]
        if len(sub)==0:
            continue
        y_true = sub["y_true"].astype(int).values
        y_pred = sub["y_pred"].astype(int).values
        y_score = sub["y_score"].astype(float).values

        acc = accuracy_score(y_true, y_pred)
        prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, labels=[0,1], zero_division=0)
        auc = roc_auc_score(y_true, y_score) if len(np.unique(y_true)) > 1 else np.nan

        rows.append({
            "fold": k,
            "quartile": q,
            "N": len(sub),
            "sim_score_mean": float(sub["sim_score"].mean()),
            "AUC": auc,
            "Accuracy": acc,
            "F1_1": float(f1[1]),
        })

df_sim = pd.DataFrame(rows).sort_values(["quartile","fold"])
display(df_sim)


|  | fold | quartile | N | sim_score_mean | AUC | Accuracy | F1_1 |
|---|---|---|---|---|---|---|---|
| 0 | 1 | Q1 | 151614 | 0.708173 | 0.551621 | 0.518573 | 0.543712 |
| 4 | 2 | Q1 | 151605 | 0.710142 | 0.631088 | 0.575502 | 0.571588 |
| 8 | 3 | Q1 | 151473 | 0.708019 | 0.50937 | 0.488635 | 0.448038 |
| 12 | 4 | Q1 | 150735 | 0.717202 | 0.547134 | 0.529525 | 0.462273 |
| 16 | 5 | Q1 | 151185 | 0.715623 | 0.604011 | 0.549155 | 0.510151 |
| 1 | 1 | Q2 | 149106 | 0.753419 | 0.569855 | 0.543566 | 0.463709 |
| 5 | 2 | Q2 | 148803 | 0.756218 | 0.625764 | 0.578779 | 0.511103 |
| 9 | 3 | Q2 | 149805 | 0.75759 | 0.550558 | 0.537639 | 0.48053 |
| 13 | 4 | Q2 | 149647 | 0.755489 | 0.580942 | 0.57651 | 0.430653 |
| 17 | 5 | Q2 | 149698 | 0.761285 | 0.576413 | 0.537108 | 0.508407 |
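
`pd.qcut` raises a `ValueError` when many tied `sim_score` values produce duplicate bin edges. Passing `duplicates="drop"` together with four fixed labels also fails once fewer than four bins remain. An alternative is to bin the ranks of the scores instead, sketched here as an option rather than the notebook's method. This always yields four equal-sized groups, at the cost of splitting tied scores across adjacent quartiles by row order. `rank_quartiles` is a hypothetical helper.

```python
import pandas as pd

def rank_quartiles(scores: pd.Series) -> pd.Series:
    """Assign Q1-Q4 by rank, so tied scores never create duplicate bin edges."""
    ranks = scores.rank(method="first")   # ties broken by order of appearance
    return pd.qcut(ranks, 4, labels=["Q1", "Q2", "Q3", "Q4"])

# Six tied scores: pd.qcut on the raw values would hit duplicate bin edges
scores = pd.Series([0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.9, 0.95])
bins = rank_quartiles(scores)
print(bins.value_counts().sort_index())
```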


In [13]:
# Aggregate the Q1–Q4 results across the 5 folds (mean±sd) for tables / supplementary material

import numpy as np
import pandas as pd

if "df_sim" not in globals() or len(df_sim)==0:
    print("df_sim is empty. Please run Part 6 cell above first.")
else:
    def mean_sd(x):
        return f"{np.mean(x):.3f}±{np.std(x, ddof=1):.3f}" if len(x)>1 else f"{np.mean(x):.3f}±nan"

    out_rows = []
    for q in ["Q1","Q2","Q3","Q4"]:
        sub = df_sim[df_sim["quartile"]==q]
        if len(sub)==0:
            continue
        out_rows.append({
            "quartile": q,
            "N_total": int(sub["N"].sum()),
            "AUC(mean±sd)": mean_sd(sub["AUC"].values),
            "Accuracy(mean±sd)": mean_sd(sub["Accuracy"].values),
            "F1_1(mean±sd)": mean_sd(sub["F1_1"].values),
            "sim_score_mean(mean±sd)": mean_sd(sub["sim_score_mean"].values),
        })

    df_sim_summary = pd.DataFrame(out_rows)
    display(df_sim_summary)


|  | quartile | N_total | AUC(mean±sd) | Accuracy(mean±sd) | F1_1(mean±sd) | sim_score_mean(mean±sd) |
|---|---|---|---|---|---|---|
| 0 | Q1 | 756612 | 0.569±0.049 | 0.532±0.033 | 0.507±0.052 | 0.712±0.004 |
| 1 | Q2 | 747059 | 0.581±0.028 | 0.555±0.021 | 0.479±0.033 | 0.757±0.003 |
| 2 | Q3 | 759714 | 0.577±0.035 | 0.561±0.023 | 0.443±0.052 | 0.785±0.004 |
| 3 | Q4 | 736615 | 0.587±0.037 | 0.585±0.015 | 0.387±0.049 | 0.837±0.013 |
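
The fold-to-fold spread in the table above is comparable in size to the Q1 to Q4 differences. Comparing quartiles within each fold (a paired comparison) may therefore be more informative than comparing fold-averaged means. Below is a minimal sketch, assuming a frame with the `fold`, `quartile` and metric columns of `df_sim`. `paired_quartile_diff` is a hypothetical helper, and the numbers are toy values, not results.

```python
import numpy as np
import pandas as pd

def paired_quartile_diff(df_sim: pd.DataFrame, metric: str, hi="Q4", lo="Q1") -> pd.Series:
    """Per-fold difference metric(hi) - metric(lo), pairing the two quartiles within each fold."""
    wide = df_sim.pivot(index="fold", columns="quartile", values=metric)
    return wide[hi] - wide[lo]

toy_sim = pd.DataFrame({
    "fold": [1, 1, 2, 2, 3, 3],
    "quartile": ["Q1", "Q4", "Q1", "Q4", "Q1", "Q4"],
    "Accuracy": [0.50, 0.56, 0.55, 0.60, 0.48, 0.55],   # toy values
})
diff = paired_quartile_diff(toy_sim, "Accuracy")
print(diff.round(3).tolist(), f"{diff.mean():.3f}±{diff.std(ddof=1):.3f}")
```

In the notebook, `paired_quartile_diff(df_sim, "Accuracy")` would give one Q4 minus Q1 difference per fold.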
