# 为long-format文件添加临床信息

In [5]:
import re
import numpy as np
import pandas as pd


In [13]:

# =========================
# 0) EDIT HERE
# =========================
# Input/output locations and optional row filters for this notebook section.
TNB_RANK_LONG_CSV = "/work/longyh/BY/processed/TNB/rank/TNB_rank_long.csv"  # rank-version long-format table (input)
CLINICAL_XLSX     = "/work/longyh/BY/raw/1-s2.0-S0092867417311224-mmc2.xlsx"

OUT_LONG_CSV      = "/work/longyh/BY/processed/TNB/rank/tnb_rank_long_with_clinical.csv"

# Optional: which strategies / metrics / units to keep.
# Setting any of these to None disables that filter in the merge step.
KEEP_STRATEGY = ["s1_rank", "s2_rank", "s3_rank", "s4_rank"]   # add "total" here to keep the total as well
KEEP_WHICH    = ["unique"]                                    # or ["raw","unique"]
KEEP_UNIT     = ["mutation"]                                  # main analysis; for the appendix also add ["peptide","hla_peptide"]


In [10]:
# -------------------------
# Helpers
# -------------------------
def make_response_label(df, response_col="Response"):
    """
    Return a copy of ``df`` with a binary integer ``response_label`` column.

    The response column is stripped and upper-cased, then mapped as
    CR/PR -> 1 and SD/PD -> 0. Rows are removed when the response is
    missing, equals "NE", or is any other unmapped category; unmapped
    categories are reported with a printed warning before being dropped.

    Parameters
    ----------
    df : pandas.DataFrame
        Table containing ``response_col``. It is not modified.
    response_col : str
        Name of the column holding the response category.

    Returns
    -------
    pandas.DataFrame
        Filtered copy (original index kept) with the normalised
        ``response_col`` and an int ``response_label`` column.
    """
    # Copy *after* filtering so the assignments below write into an
    # independent frame instead of a filtered slice (the chained-assignment
    # pattern that triggers SettingWithCopyWarning).
    labelled = df[df[response_col].notna()].copy()
    labelled[response_col] = labelled[response_col].astype(str).str.strip().str.upper()
    labelled = labelled[labelled[response_col] != "NE"].copy()

    response_map = {"CR": 1, "PR": 1, "SD": 0, "PD": 0}
    labelled["response_label"] = labelled[response_col].map(response_map)

    # Anything the map did not cover is NaN here; report it, then drop it.
    unknown = labelled.loc[labelled["response_label"].isna(), response_col].value_counts()
    if len(unknown) > 0:
        print("[WARN] Unmapped Response categories:\n", unknown)
        labelled = labelled.dropna(subset=["response_label"]).copy()

    labelled["response_label"] = labelled["response_label"].astype(int)
    return labelled


In [11]:
# -------------------------
# 1) Load
# -------------------------
# Every column of the rank table is read as text; "value" is converted below.
tnb = pd.read_csv(TNB_RANK_LONG_CSV, dtype=str)
# The first two rows of the clinical sheet are skipped on read.
clinical = pd.read_excel(CLINICAL_XLSX, skiprows=2)

# Fail early when the long table lacks a column used downstream.
required = {"sample", "unit", "strategy", "metric", "value"}
missing = required.difference(tnb.columns)
if missing:
    raise ValueError(f"Missing columns in rank long CSV: {missing}")

# Non-numeric entries in "value" become NaN.
tnb["value"] = pd.to_numeric(tnb["value"], errors="coerce")

# Trim whitespace from the join keys on both sides.
clinical["Patient"] = clinical["Patient"].astype(str).str.strip()
tnb["sample"] = tnb["sample"].astype(str).str.strip()

# Clinical table with a binary response_label; NE / unmapped rows removed.
clinical2 = make_response_label(clinical, response_col="Response")


In [12]:

# -------------------------
# 2) Merge (inner join = keep samples with clinical labels)
# -------------------------
# Rows of the rank table whose sample has no labelled clinical record
# (and vice versa) are discarded by the inner join.
merged = pd.merge(
    tnb,
    clinical2,
    left_on="sample",
    right_on="Patient",
    how="inner"
)

print("Merged rows =", len(merged))
print("Merged unique samples =", merged["sample"].nunique())
print("Responders =", int(merged.drop_duplicates("sample")["response_label"].sum()),
      "/", merged["sample"].nunique())

# -------------------------
# 3) Convert to legacy column names for plotting compatibility
#    (so you can keep your old plotting script almost unchanged)
# -------------------------
merged = merged.rename(columns={"metric": "which", "value": "tnb_value"})
merged["which"] = merged["which"].astype(str).str.strip()
merged["strategy"] = merged["strategy"].astype(str).str.strip()
merged["unit"] = merged["unit"].astype(str).str.strip()

# optional: filter (a KEEP_* setting of None leaves that dimension unfiltered)
if KEEP_STRATEGY is not None:
    merged = merged[merged["strategy"].isin(KEEP_STRATEGY)].copy()
if KEEP_WHICH is not None:
    merged = merged[merged["which"].isin(KEEP_WHICH)].copy()
if KEEP_UNIT is not None:
    merged = merged[merged["unit"].isin(KEEP_UNIT)].copy()

# optional: add strategy_name for rank version
# Strategies without an entry here keep their raw code as the display name.
strategy_map = {
    "s1_rank": "Binding-only (%Rank<2.0)",
    "s2_rank": "%Rank<2.0 & TPM>1",
    "s3_rank": "%Rank<0.5 & TPM>5",
    "s4_rank": "High-quality (%Rank<0.5,TPM>5,VAF>0.1,WT%Rank>2.0)",
    "total": "Total candidates",
}
merged["strategy_name"] = merged["strategy"].map(strategy_map).fillna(merged["strategy"])

# -------------------------
# 4) Keep only clinical cols you care about (optional)
# -------------------------
keep_cols = [
    # core ids / outcome
    "sample", "Patient", "Response", "response_label",
    # covariates / endpoints (use your existing names if present)
    "Mutation Load", "Neo-antigen Load", "Neo-peptide Load", "Cytolytic Score",
    "Dead/Alive\n(Dead = True)", "Time to Death\n(weeks)",
    "Cohort",
]

# keep only those that actually exist
keep_cols = [c for c in keep_cols if c in merged.columns]

# The conditional must be parenthesised: a conditional expression binds more
# loosely than "+", so without the parentheses the *entire* column list was
# replaced by [] whenever "file" was absent, producing an empty output table.
final_cols = (
    ["sample", "unit", "strategy", "strategy_name", "which", "tnb_value"]
    + keep_cols
    + (["file"] if "file" in merged.columns else [])
)

# de-dup any repeats in final_cols while preserving order
# (dict.fromkeys keeps the first occurrence of each name)
final_cols2 = [c for c in dict.fromkeys(final_cols) if c in merged.columns]

out = merged[final_cols2].copy()

print("Final long-format rows =", len(out))
print(out.head(10))




Merged rows = 440
Merged unique samples = 22
Responders = 6 / 22
Final long-format rows = 88
   sample      unit strategy  \
3    Pt10  mutation  s1_rank   
5    Pt10  mutation  s2_rank   
7    Pt10  mutation  s3_rank   
9    Pt10  mutation  s4_rank   
23   Pt11  mutation  s1_rank   
25   Pt11  mutation  s2_rank   
27   Pt11  mutation  s3_rank   
29   Pt11  mutation  s4_rank   
43   Pt18  mutation  s1_rank   
45   Pt18  mutation  s2_rank   

                                        strategy_name   which  tnb_value  \
3                            Binding-only (%Rank<2.0)  unique         47   
5                                   %Rank<2.0 & TPM>1  unique         39   
7                                   %Rank<0.5 & TPM>5  unique         20   
9   High-quality (%Rank<0.5,TPM>5,VAF>0.1,WT%Rank>...  unique          1   
23                           Binding-only (%Rank<2.0)  unique         28   
25                                  %Rank<2.0 & TPM>1  unique         20   
27                    

In [14]:
# -------------------------
# 5) Save
# -------------------------
# Persist the merged long table; the DataFrame index is not written.
out.to_csv(path_or_buf=OUT_LONG_CSV, index=False)
print("Saved:", OUT_LONG_CSV)

Saved: /work/longyh/BY/processed/TNB/rank/tnb_rank_long_with_clinical.csv


# 绘制ROC和KM图

In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import roc_curve, auc

# lifelines for KM/Cox
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.statistics import logrank_test




In [15]:
# =========================
# 0) EDIT HERE
# =========================
# Settings for the ROC / KM section: input table, output folder, row filters
# and bootstrap parameters.
LONG_CSV = "/work/longyh/BY/processed/TNB/rank/tnb_rank_long_with_clinical.csv"  # your long-format file
OUTDIR = "/work/longyh/BY/plots"
COHORT_LABEL = "Ipi-N"  # plain label for titles/filenames; not read from the clinical Cohort column
FILTER_WHICH = ["unique"]  # can be ["unique"] or ["raw","unique"]
FILTER_STRATEGY = ["s1_rank", "s2_rank", "s3_rank", "s4_rank"]  # optional: run a subset only, e.g. ["s1_rank"]
FILTER_UNIT = ["mutation", "peptide"]
N_BOOT = 2000  # number of bootstrap resamples for the AUC confidence interval
SEED = 1  # seed passed to the bootstrap random generator

# Whether to save figures to OUTDIR (otherwise they are shown interactively)
SAVE_PLOTS = True



In [16]:
# =========================
# 1) Helpers
# =========================
def ensure_response_label(df):
    """
    Return a copy of ``df`` that carries an integer ``response_label`` column.

    If ``df`` already has a ``response_label`` column with at least one
    non-missing value, rows with a missing label are dropped and the column
    is cast to int. Otherwise the label is derived from ``Response``:
    CR/PR -> 1, SD/PD -> 0. Rows whose Response is missing, "NE", or any
    other category are removed; other categories are reported with a printed
    warning.

    Raises
    ------
    ValueError
        If there is no usable ``response_label`` and no ``Response`` column.
    """
    out = df.copy()
    if "response_label" in out.columns and out["response_label"].notna().any():
        # Label already present: just drop unlabeled rows and make sure it is int.
        out = out.dropna(subset=["response_label"]).copy()
        out["response_label"] = out["response_label"].astype(int)
        return out

    if "Response" not in out.columns:
        raise ValueError("Neither response_label nor Response column found.")

    # Drop missing responses *before* stringifying; otherwise they turn into
    # the text "NAN"/"NONE" and get reported as a bogus unmapped category.
    out = out[out["Response"].notna()].copy()
    out["Response"] = out["Response"].astype(str).str.strip().str.upper()
    out = out[out["Response"] != "NE"].copy()

    mp = {"CR": 1, "PR": 1, "SD": 0, "PD": 0}
    out["response_label"] = out["Response"].map(mp)

    # Whatever the map did not cover is NaN here: warn, then drop.
    unknown = out.loc[out["response_label"].isna(), "Response"].value_counts()
    if len(unknown) > 0:
        print("[WARN] Unmapped Response categories:\n", unknown)
        out = out.dropna(subset=["response_label"]).copy()

    out["response_label"] = out["response_label"].astype(int)
    return out


def bootstrap_auc_ci(y_true, y_score, n_boot=2000, seed=1):
    """
    Point AUC plus a bootstrap percentile interval.

    Samples are redrawn with replacement ``n_boot`` times; a resample that
    contains only one class is skipped. Returns
    ``(auc, ci_low, ci_high, n_kept)`` where the bounds are the 2.5% and
    97.5% quantiles of the kept resample AUCs, or NaN for both bounds when
    fewer than 50 resamples were kept.
    """
    labels = np.asarray(y_true, dtype=int)
    scores = np.asarray(y_score, dtype=float)
    rng = np.random.default_rng(seed)

    def _roc_auc(lab, sc):
        # Area under the ROC curve for one label/score pairing.
        fpr, tpr, _ = roc_curve(lab, sc)
        return auc(fpr, tpr)

    auc_point = float(_roc_auc(labels, scores))

    size = len(labels)
    kept = []
    for _ in range(n_boot):
        pick = rng.integers(0, size, size=size)
        lab_b = labels[pick]
        sc_b = scores[pick]
        # An AUC needs both classes in the resample.
        if np.unique(lab_b).size >= 2:
            kept.append(_roc_auc(lab_b, sc_b))

    n_kept = len(kept)
    if n_kept < 50:
        return auc_point, np.nan, np.nan, n_kept

    lower, upper = np.quantile(kept, [0.025, 0.975])
    return auc_point, float(lower), float(upper), n_kept


def safe_filename(s: str) -> str:
    """Turn ``s`` into a filename-friendly string.

    Each run of characters other than word characters, "-" and "." is
    replaced by a single "_", and the result is cut to 200 characters.
    """
    cleaned = re.sub(r"[^\w\-.]+", "_", str(s))
    return cleaned[:200]


def find_time_event_cols(df):
    """
    Pick the time and event columns to use for survival analysis.

    A PFS pair is preferred: the first column matching PFS/progression that
    also mentions week/time, together with the first column matching
    event/status. If either is missing, the known OS columns
    ("Time to Death\\n(weeks)" and "Dead/Alive\\n(Dead = True)") are used
    when both exist. Returns ``(time_col, event_col, endpoint)`` with
    endpoint "PFS" or "OS", or ``(None, None, None)`` when nothing fits.
    """
    names = list(df.columns)

    def _first_match(candidates, pattern):
        # First candidate whose name matches the pattern (case-insensitive).
        return next(
            (c for c in candidates if re.search(pattern, str(c), flags=re.I)),
            None,
        )

    pfs_like = [c for c in names if re.search(r"\bPFS\b|progression", str(c), flags=re.I)]
    event_like = [c for c in names if re.search(r"event|status|progress", str(c), flags=re.I)]

    time_col = _first_match(pfs_like, r"week|time")
    event_col = _first_match(event_like, r"event|status")
    if time_col and event_col:
        return time_col, event_col, "PFS"

    # Fall back to the overall-survival columns, which must both be present.
    os_pair = ("Time to Death\n(weeks)", "Dead/Alive\n(Dead = True)")
    if all(name in names for name in os_pair):
        return os_pair[0], os_pair[1], "OS"

    return None, None, None


def to_event01(x):
    """
    Convert one event/status cell to 1 (event), 0 (no event) or NaN.

    - Missing values (per ``pd.isna``) -> NaN.
    - Booleans -> 1 for True, 0 for False.
    - Otherwise the value is stringified, stripped and lower-cased:
      "true"/"1"/"dead"/"yes"/"event" -> 1;
      "false"/"0"/"alive"/"no"/"censored"/"none" -> 0.
    - Any other text is parsed as a number and truncated to an integer:
      non-zero -> 1, zero -> 0.
    - Text that cannot be parsed as a finite number -> NaN.
    """
    if pd.isna(x):
        return np.nan
    if isinstance(x, (bool, np.bool_)):
        return int(bool(x))

    s = str(x).strip().lower()
    if s in {"true", "1", "dead", "yes", "event"}:
        return 1
    if s in {"false", "0", "alive", "no", "censored", "none"}:
        return 0

    try:
        v = int(float(s))
    except (ValueError, OverflowError):
        # ValueError: non-numeric text or "nan"; OverflowError: "inf".
        return np.nan
    return 1 if v != 0 else 0

In [18]:
# =========================
# 2) Load long-format (compat: old vs rank)
# =========================
df = pd.read_csv(LONG_CSV)

print("[INFO] Columns in LONG_CSV:", list(df.columns))

# Work out which of the two accepted layouts the file uses:
#   rank-long: sample,unit,strategy,metric,value
#   legacy:    sample,unit,strategy,which,tnb_value
available = set(df.columns)
base_cols = {"sample", "unit", "strategy"}
if (base_cols | {"metric", "value"}) <= available:
    text_cols = ["sample", "unit", "strategy", "metric"]
    number_col = "value"
    needs_rename = True
elif (base_cols | {"which", "tnb_value"}) <= available:
    text_cols = ["sample", "unit", "strategy", "which"]
    number_col = "tnb_value"
    needs_rename = False
else:
    raise ValueError(
        "Unrecognized long-format schema. Need either "
        "{sample,unit,strategy,metric,value} or {sample,unit,strategy,which,tnb_value}."
    )

# Same clean-up for both layouts: trimmed text keys, numeric value column.
for col in text_cols:
    df[col] = df[col].astype(str).str.strip()
df[number_col] = pd.to_numeric(df[number_col], errors="coerce")

# The plotting code below expects the legacy column names.
if needs_rename:
    df = df.rename(columns={"metric": "which", "value": "tnb_value"})

# Restrict to the s1~s4 rank strategies (drops e.g. "total"), but only when
# at least one of those strategy names is present in the table.
rank_strategies = ("s1_rank", "s2_rank", "s3_rank", "s4_rank")
observed = set(df["strategy"])
keep_strats = [name for name in rank_strategies if name in observed]
if keep_strats:
    df = df[df["strategy"].isin(keep_strats)].copy()

# Guarantee an integer response_label, then discard rows without a TNB value.
df = ensure_response_label(df)
df = df.dropna(subset=["tnb_value"]).copy()

# Apply the user-configured filters in one pass.
selected = (
    df["which"].isin(FILTER_WHICH)
    & df["strategy"].isin(FILTER_STRATEGY)
    & df["unit"].isin(FILTER_UNIT)
)
df = df[selected].copy()

print("Long-format usable rows:", len(df))
print("Unique samples:", df["sample"].nunique())
print("Responders:", int(df.drop_duplicates("sample")["response_label"].sum()), "/", df["sample"].nunique())

os.makedirs(OUTDIR, exist_ok=True)


[INFO] Columns in LONG_CSV: ['sample', 'unit', 'strategy', 'strategy_name', 'which', 'tnb_value', 'Patient', 'Response', 'response_label', 'Mutation Load', 'Neo-antigen Load', 'Neo-peptide Load', 'Cytolytic Score', 'Dead/Alive\n(Dead = True)', 'Time to Death\n(weeks)', 'Cohort', 'file']
Long-format usable rows: 88
Unique samples: 22
Responders: 6 / 22


In [19]:


# =========================
# 3) AUC loop + ROC plots
# =========================
# One summary row (and, when computable, one ROC figure) per
# (unit, strategy, which) combination present in df.
auc_rows = []
# Explicit column order so the summary frame has its columns (and can be
# sorted on "auc") even when no group produced a row.
auc_columns = [
    "unit", "strategy", "which", "n", "responders",
    "auc", "ci_low", "ci_high", "boot_kept", "suggest_flip",
]

group_cols = ["unit", "strategy", "which"]
for (unit, strategy, which), g in df.groupby(group_cols):
    # g is multiple rows across samples; each sample should appear once per (unit,strategy,which)
    # If duplicates exist, keep one row per sample and warn
    n_dup = int(g["sample"].duplicated().sum())
    if n_dup > 0:
        print(f"[WARN] {unit}-{strategy}-{which}: {n_dup} duplicated sample row(s); keeping one row per sample")
    gg = g.sort_values("sample").drop_duplicates(subset=["sample"])
    y = gg["response_label"].astype(int).values
    x = gg["tnb_value"].astype(float).values

    # skip invalid: too few samples, or only one response class present
    if len(gg) < 8 or len(np.unique(y)) < 2:
        auc_rows.append({
            "unit": unit, "strategy": strategy, "which": which,
            "n": int(len(gg)),
            "responders": int(y.sum()),
            "auc": np.nan, "ci_low": np.nan, "ci_high": np.nan,
            "boot_kept": 0,
            "suggest_flip": np.nan,
        })
        continue

    auc_point, ci_low, ci_high, boot_kept = bootstrap_auc_ci(y, x, n_boot=N_BOOT, seed=SEED)

    # quick direction hint (not auto flipping)
    # Only the point AUC of the sign-flipped score is needed, so compute it
    # directly instead of running (and discarding) a second bootstrap.
    fpr_rev, tpr_rev, _ = roc_curve(y, -x)
    auc_rev = float(auc(fpr_rev, tpr_rev))
    suggest_flip = bool(auc_point < 0.5 and auc_rev > auc_point)

    auc_rows.append({
        "unit": unit, "strategy": strategy, "which": which,
        "n": int(len(gg)),
        "responders": int(y.sum()),
        "auc": float(auc_point),
        "ci_low": float(ci_low) if not np.isnan(ci_low) else np.nan,
        "ci_high": float(ci_high) if not np.isnan(ci_high) else np.nan,
        "boot_kept": int(boot_kept),
        "suggest_flip": suggest_flip,
    })

    # ROC plot
    fpr, tpr, _ = roc_curve(y, x)
    plt.figure(figsize=(6, 6))
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")  # chance diagonal
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    ci_text = "CI unavailable" if np.isnan(ci_low) else f"95%CI [{ci_low:.3f}, {ci_high:.3f}]"
    flip_text = " (maybe flip sign)" if suggest_flip else ""
    plt.title(f"[{COHORT_LABEL}] {unit}-{strategy}-{which}\nAUC={auc_point:.3f}, {ci_text}{flip_text}")
    plt.tight_layout()

    if SAVE_PLOTS:
        fn = safe_filename(f"ROC_{COHORT_LABEL}_{unit}_{strategy}_{which}.png")
        plt.savefig(os.path.join(OUTDIR, fn), dpi=200)
        plt.close()
    else:
        plt.show()

# Best AUC first; rows without an AUC (NaN) sort to the end.
auc_summary = pd.DataFrame(auc_rows, columns=auc_columns).sort_values(["auc"], ascending=False)
auc_summary.to_csv("/work/longyh/BY/processed/TNB/auc_summary_long.csv", index=False)
print("Saved: auc_summary_long.csv")
print(auc_summary.head(20))
# =========================
# 3.5) AUC heatmap tables + plots
# =========================
def plot_heatmap(matrix_df, title, xlabel, ylabel, out_png):
    """
    Draw a numeric pivot table as a heatmap and write it to ``out_png``.

    Rows of ``matrix_df`` become the y axis (labelled with its index) and
    columns the x axis (labelled with its column names). The figure size
    grows with the table shape, a colorbar labelled "Value" is added, and
    the figure is closed after saving.
    """
    n_rows, n_cols = matrix_df.shape
    fig_width = 1.2 * max(4, n_cols + 1)
    fig_height = 1.0 * max(3, n_rows + 1)

    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    image = ax.imshow(matrix_df.values, aspect="auto")

    # One tick per cell, labelled from the table itself.
    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(matrix_df.columns, rotation=30, ha="right")
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(matrix_df.index)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    colorbar = fig.colorbar(image, ax=ax)
    colorbar.ax.set_ylabel("Value", rotation=270, labelpad=12)

    fig.tight_layout()
    fig.savefig(out_png, dpi=220)
    plt.close(fig)


# Only rows that actually have an AUC
auc_ok = auc_summary.dropna(subset=["auc"]).copy()

# All heatmap tables go to the same folder as the other summary CSVs.
# (The plain AUC table used to be written to the current working directory
# via a bare relative filename, unlike its two sibling tables.)
table_dir = "/work/longyh/BY/processed/TNB"

# One set of tables/heatmaps per `which` value (a single set if only "unique" was run)
for which in sorted(auc_ok["which"].unique()):
    sub = auc_ok[auc_ok["which"] == which].copy()

    # AUC pivot: strategies as rows, counting units as columns.
    # reindex fixes the row/column order to the configured filters; a
    # configured strategy/unit with no data shows up as an all-NaN row/column.
    auc_pivot = sub.pivot_table(
        index="strategy",
        columns="unit",
        values="auc",
        aggfunc="mean"
    ).reindex(index=FILTER_STRATEGY, columns=FILTER_UNIT)

    auc_pivot.to_csv(os.path.join(table_dir, f"auc_heatmap_table_{which}.csv"))
    print(f"Saved: auc_heatmap_table_{which}.csv")

    # AUC with CI string pivot (nice for tables)
    sub["auc_ci_str"] = sub.apply(
        lambda r: (
            f"{r['auc']:.3f} [{r['ci_low']:.3f}, {r['ci_high']:.3f}]"
            if pd.notna(r["ci_low"]) and pd.notna(r["ci_high"])
            else f"{r['auc']:.3f} [NA]"
        ),
        axis=1
    )

    auc_ci_pivot = sub.pivot_table(
        index="strategy",
        columns="unit",
        values="auc_ci_str",
        aggfunc="first"
    ).reindex(index=FILTER_STRATEGY, columns=FILTER_UNIT)

    auc_ci_pivot.to_csv(os.path.join(table_dir, f"auc_heatmap_table_ci_{which}.csv"))
    print(f"Saved: auc_heatmap_table_ci_{which}.csv")

    # CI width heatmap (stability proxy)
    sub["ci_width"] = sub["ci_high"] - sub["ci_low"]
    ciw_pivot = sub.pivot_table(
        index="strategy",
        columns="unit",
        values="ci_width",
        aggfunc="mean"
    ).reindex(index=FILTER_STRATEGY, columns=FILTER_UNIT)

    ciw_pivot.to_csv(os.path.join(table_dir, f"auc_heatmap_table_ciwidth_{which}.csv"))
    print(f"Saved: auc_heatmap_table_ciwidth_{which}.csv")

    # Plot heatmaps
    if SAVE_PLOTS:
        out_auc = os.path.join(
            OUTDIR,
            safe_filename(f"AUC_heatmap_{COHORT_LABEL}_{which}.png")
        )
        plot_heatmap(
            auc_pivot,
            title=f"[{COHORT_LABEL}] AUC heatmap ({which})",
            xlabel="Counting unit",
            ylabel="Strategy",
            out_png=out_auc
        )
        print("Saved:", out_auc)

        out_ciw = os.path.join(
            OUTDIR,
            safe_filename(f"AUC_heatmap_CIwidth_{COHORT_LABEL}_{which}.png")
        )
        plot_heatmap(
            ciw_pivot,
            title=f"[{COHORT_LABEL}] AUC CI width heatmap ({which})",
            xlabel="Counting unit",
            ylabel="Strategy",
            out_png=out_ciw
        )
        print("Saved:", out_ciw)


# =========================
# 4) KM + Cox loop (optional)
# =========================
# For every (unit, strategy, which) group: split samples at the median TNB
# value, compare the two halves with a log-rank test and a Kaplan-Meier plot,
# and fit a univariate Cox model on log1p(tnb_value). One summary row per
# group is collected in km_rows. The whole section is skipped when no
# time/event column pair is detected.
time_col, event_col, endpoint = find_time_event_cols(df)
km_rows = []

if time_col is None:
    print("[INFO] No time-to-event columns detected; skip KM/Cox.")
else:
    # Normalise the event indicator to 0/1 and the time to numeric, then drop
    # rows where either could not be interpreted.
    km_df = df.copy()
    km_df[event_col] = km_df[event_col].apply(to_event01)
    km_df[time_col] = pd.to_numeric(km_df[time_col], errors="coerce")
    km_df = km_df.dropna(subset=[time_col, event_col]).copy()
    km_df[event_col] = km_df[event_col].astype(int)

    print(f"[INFO] Using endpoint={endpoint} | time_col={time_col} | event_col={event_col}")
    print("KM usable samples:", km_df.drop_duplicates(["sample"]).shape[0])

    for (unit, strategy, which), g in km_df.groupby(group_cols):
        # One row per sample within the group.
        gg = g.sort_values("sample").drop_duplicates(subset=["sample"])
        # Fewer than 10 samples: record counts only, no tests.
        if len(gg) < 10:
            km_rows.append({
                "unit": unit, "strategy": strategy, "which": which,
                "endpoint": endpoint, "n": int(len(gg)),
                "events": int(gg[event_col].sum()),
                "median_cut": np.nan,
                "logrank_p": np.nan,
                "cox_hr_log1p": np.nan, "cox_ci_low": np.nan, "cox_ci_high": np.nan, "cox_p": np.nan,
            })
            continue

        # median split
        # Strict ">" means samples exactly at the median fall in the "Low" group.
        median_val = gg["tnb_value"].median()
        gg = gg.copy()
        gg["group"] = np.where(gg["tnb_value"] > median_val, "High", "Low")

        gH = gg[gg["group"] == "High"]
        gL = gg[gg["group"] == "Low"]
        # Either arm smaller than 3: keep the cut-off in the summary but skip the tests.
        if len(gH) < 3 or len(gL) < 3:
            km_rows.append({
                "unit": unit, "strategy": strategy, "which": which,
                "endpoint": endpoint, "n": int(len(gg)),
                "events": int(gg[event_col].sum()),
                "median_cut": float(median_val),
                "logrank_p": np.nan,
                "cox_hr_log1p": np.nan, "cox_ci_low": np.nan, "cox_ci_high": np.nan, "cox_p": np.nan,
            })
            continue

        # log-rank
        lr = logrank_test(
            gH[time_col], gL[time_col],
            event_observed_A=gH[event_col],
            event_observed_B=gL[event_col],
        )

        # KM plot
        # The same fitter object is refit for the second arm; each curve is
        # drawn right after its fit, onto the same axes.
        kmf = KaplanMeierFitter()
        plt.figure(figsize=(6.8, 5.6))

        kmf.fit(gL[time_col], event_observed=gL[event_col], label=f"Low (n={len(gL)})")
        ax = kmf.plot(ci_show=True)
        kmf.fit(gH[time_col], event_observed=gH[event_col], label=f"High (n={len(gH)})")
        kmf.plot(ax=ax, ci_show=True)

        plt.xlabel(f"Time (weeks) [{endpoint}]")
        plt.ylabel("Survival probability")
        plt.title(f"[{COHORT_LABEL}] {unit}-{strategy}-{which}\nlog-rank p={lr.p_value:.3g} (median={median_val:.3g})")
        plt.tight_layout()

        if SAVE_PLOTS:
            fn = safe_filename(f"KM_{COHORT_LABEL}_{endpoint}_{unit}_{strategy}_{which}.png")
            plt.savefig(os.path.join(OUTDIR, fn), dpi=200)
            plt.close()
        else:
            plt.show()

        # Cox continuous: log1p(tnb)
        # Uses the continuous TNB value (not the High/Low split) as the only covariate.
        cox_df = gg[[time_col, event_col, "tnb_value"]].copy()
        cox_df["x"] = np.log1p(cox_df["tnb_value"].astype(float))

        cph = CoxPHFitter()
        cph.fit(cox_df[[time_col, event_col, "x"]], duration_col=time_col, event_col=event_col)
        # exp(coef) of "x" with its 95% bounds and p-value, taken from the fit summary.
        summ = cph.summary.loc["x", ["exp(coef)", "exp(coef) lower 95%", "exp(coef) upper 95%", "p"]]

        hr = float(summ["exp(coef)"])
        lo = float(summ["exp(coef) lower 95%"])
        hi = float(summ["exp(coef) upper 95%"])
        p = float(summ["p"])

        km_rows.append({
            "unit": unit, "strategy": strategy, "which": which,
            "endpoint": endpoint, "n": int(len(gg)),
            "events": int(gg[event_col].sum()),
            "median_cut": float(median_val),
            "logrank_p": float(lr.p_value),
            "cox_hr_log1p": hr, "cox_ci_low": lo, "cox_ci_high": hi, "cox_p": p,
        })

    # Summary table ordered by Cox p-value (ascending).
    km_summary = pd.DataFrame(km_rows).sort_values(["cox_p"])
    km_summary.to_csv("/work/longyh/BY/processed/TNB/km_cox_summary_long.csv", index=False)
    print("Saved: km_cox_summary_long.csv")
    print(km_summary.head(20))


print("Done. Plots saved to:", OUTDIR)




Saved: auc_summary_long.csv
       unit strategy   which   n  responders       auc    ci_low   ci_high  \
0  mutation  s1_rank  unique  22           6  0.562500  0.199444  0.894787   
1  mutation  s2_rank  unique  22           6  0.546875  0.221930  0.868638   
3  mutation  s4_rank  unique  22           6  0.541667  0.208333  0.846300   
2  mutation  s3_rank  unique  22           6  0.536458  0.210526  0.847115   

   boot_kept  suggest_flip  
0       1997         False  
1       1997         False  
3       1997         False  
2       1997         False  
Saved: auc_heatmap_table_unique.csv
Saved: auc_heatmap_table_ci_unique.csv
Saved: auc_heatmap_table_ciwidth_unique.csv
Saved: /work/longyh/BY/plots/AUC_heatmap_Ipi-N_unique.png
Saved: /work/longyh/BY/plots/AUC_heatmap_CIwidth_Ipi-N_unique.png
[INFO] Using endpoint=OS | time_col=Time to Death
(weeks) | event_col=Dead/Alive
(Dead = True)
KM usable samples: 22
Saved: km_cox_summary_long.csv
       unit strategy   which endpoint   n  ev