# 生成long-format格式文件

In [1]:
import re
import numpy as np
import pandas as pd

# Input locations (absolute, machine-specific paths).
# Per-sample TNB summary table; loaded further down with pd.read_csv(..., dtype=str).
TNB_CSV = "/work/longyh/BY/processed/TNB/TNB_summary_by_unit.csv"
# Clinical annotation workbook; loaded further down with pd.read_excel(..., skiprows=2).
CLINICAL_XLSX = "/work/longyh/BY/raw/1-s2.0-S0092867417311224-mmc2.xlsx"


In [7]:

# -------------------------
# Helpers (same as before)
# -------------------------
def make_response_label(df, response_col="Response"):
    """
    Return a copy of ``df`` with a binary integer ``response_label`` column.

    The response column is upper-cased and stripped, then mapped as
    CR/PR -> 1 and SD/PD -> 0. Rows whose response is missing or "NE" are
    removed. Any other category is printed as a warning and removed as well.
    The input frame is not modified.
    """
    label_of = {"CR": 1, "PR": 1, "SD": 0, "PD": 0}

    # Missing responses are removed first so that astype(str) never sees NaN.
    out = df.loc[df[response_col].notna()].copy()
    out[response_col] = out[response_col].astype(str).str.strip().str.upper()
    out = out.loc[out[response_col] != "NE"].copy()
    out["response_label"] = out[response_col].map(label_of)

    # Report anything that is neither a known category nor "NE", then drop it.
    unmapped = out["response_label"].isna()
    unknown = out.loc[unmapped, response_col].value_counts()
    if len(unknown) > 0:
        print("[WARN] Unmapped Response categories:\n", unknown)
        out = out.loc[~unmapped].copy()

    out["response_label"] = out["response_label"].astype(int)
    return out


def wide_to_long_tnb(df, id_cols=("sample",), pattern=r"^(mutation|peptide|hla_peptide)_(s[1-4])_(raw|unique)$"):
    """
    Reshape wide TNB columns (e.g. ``mutation_s1_unique``) into long format.

    Columns whose name matches ``pattern`` are melted; the three regex groups
    of the column name become ``unit``, ``strategy`` and ``which``. Values are
    coerced to numeric and rows without a numeric value are dropped.

    Returns a frame with: the id columns, ``tnb_value``, ``unit``,
    ``strategy``, ``which`` and ``strategy_name`` (a readable label for
    s1..s4, falling back to the raw strategy code).

    Raises ValueError when no column matches ``pattern``.
    """
    strategy_labels = {
        "s1": "Binding-only (IC50<500)",
        "s2": "IC50<500 & TPM>1",
        "s3": "IC50<50 & TPM>5",
        "s4": "High-quality (IC50<50,TPM>5,VAF>0.1,WT>1000)",
    }

    key_rx = re.compile(pattern)
    value_cols = [col for col in df.columns if key_rx.match(col)]
    if len(value_cols) == 0:
        raise ValueError("No TNB columns matched the expected pattern. Check your column names.")

    # One row per (id, original column name).
    melted = pd.melt(
        df,
        id_vars=list(id_cols),
        value_vars=value_cols,
        var_name="tnb_key",
        value_name="tnb_value",
    )

    # Split the original column name into its three components.
    parts = melted["tnb_key"].str.extract(pattern)
    parts.columns = ["unit", "strategy", "which"]
    tidy = pd.concat([melted.drop(columns=["tnb_key"]), parts], axis=1)

    # Non-numeric cells become NaN and are then discarded.
    tidy["tnb_value"] = pd.to_numeric(tidy["tnb_value"], errors="coerce")
    tidy = tidy[tidy["tnb_value"].notna()].copy()

    readable = tidy["strategy"].map(strategy_labels)
    tidy["strategy_name"] = readable.fillna(tidy["strategy"])
    return tidy


# -------------------------
# 1) Load data
# -------------------------
# Everything in the summary is read as text; numeric columns are converted explicitly below.
summary = pd.read_csv(TNB_CSV, dtype=str)
clinical = pd.read_excel(CLINICAL_XLSX, skiprows=2)

# Normalise the join keys on both sides (string type, no surrounding whitespace).
summary["sample"] = summary["sample"].astype(str).str.strip()
clinical["Patient"] = clinical["Patient"].astype(str).str.strip()

# Every summary column except the identifiers becomes numeric; unparseable cells turn into NaN.
for c in summary.columns:
    if c not in ("sample", "file"):
        summary[c] = pd.to_numeric(summary[c], errors="coerce")

# Clinical table with the binary response label (missing / NE responses removed).
clinical2 = make_response_label(clinical, response_col="Response")

# Inner join: only samples present in both tables survive.
merged = summary.merge(clinical2, left_on="sample", right_on="Patient", how="inner")
print("Merged n =", len(merged), "| responders =", merged["response_label"].sum(), "/", len(merged))

# -------------------------
# 2) Wide -> Long (TNB)
# -------------------------
# Clinical columns to carry along; any that are absent from the merged table are skipped.
wanted_clinical = [
    "sample", "Patient", "Response", "response_label",
    "Mutation Load", "Neo-antigen Load", "Neo-peptide Load", "Cytolytic Score",
    "Dead/Alive\n(Dead = True)", "Time to Death\n(weeks)",
    "Cohort"
]
keep_cols = [col for col in wanted_clinical if col in merged.columns]

tnb_col_rx = re.compile(r"^(mutation|peptide|hla_peptide)_s[1-4]_(raw|unique)$")
tnb_cols = [col for col in merged.columns if tnb_col_rx.match(col)]
merged_small = merged[keep_cols + tnb_cols].copy()

tnb_long = wide_to_long_tnb(merged_small, id_cols=("sample",))
# wide_to_long_tnb only keeps the id columns, so re-attach the clinical columns per sample.
clinical_per_sample = merged_small[keep_cols].drop_duplicates("sample")
tnb_long = tnb_long.merge(clinical_per_sample, on="sample", how="left")

print("Long-format rows =", len(tnb_long))
print(tnb_long.head(10))




Merged n = 22 | responders = 6 / 22
Long-format rows = 528
  sample  tnb_value      unit strategy which            strategy_name Patient  \
0   Pt10        124  mutation       s1   raw  Binding-only (IC50<500)    Pt10   
1   Pt11         82  mutation       s1   raw  Binding-only (IC50<500)    Pt11   
2   Pt18        241  mutation       s1   raw  Binding-only (IC50<500)    Pt18   
3   Pt24          1  mutation       s1   raw  Binding-only (IC50<500)    Pt24   
4   Pt27        190  mutation       s1   raw  Binding-only (IC50<500)    Pt27   
5   Pt28         17  mutation       s1   raw  Binding-only (IC50<500)    Pt28   
6   Pt29        521  mutation       s1   raw  Binding-only (IC50<500)    Pt29   
7   Pt30         36  mutation       s1   raw  Binding-only (IC50<500)    Pt30   
8   Pt31        343  mutation       s1   raw  Binding-only (IC50<500)    Pt31   
9    Pt3        128  mutation       s1   raw  Binding-only (IC50<500)     Pt3   

  Response  response_label  Mutation Load  Neo-an

In [None]:
# -------------------------
# 3) Save
# -------------------------
# NOTE(review): this export is commented out, yet later cells read
# "/work/longyh/BY/processed/TNB/tnb_long_format.csv" (LONG_CSV and the
# consistency-check section). Re-enable it on a fresh setup or after
# regenerating `tnb_long`; otherwise those cells use whatever file is already
# on disk.
# tnb_long.to_csv("/work/longyh/BY/processed/TNB/tnb_long_format.csv", index=False)
# print("Saved: tnb_long_format.csv")

# 使用long-format 文件生成统计结果

In [25]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import roc_curve, auc

# lifelines for KM/Cox
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.statistics import logrank_test


In [26]:

# =========================
# 0) EDIT HERE
# =========================
LONG_CSV = "/work/longyh/BY/processed/TNB/tnb_long_format.csv"  # long-format TNB table to analyse
OUTDIR = "/work/longyh/BY/plots"
COHORT_LABEL = "Ipi-N"  # display label only; does not depend on the clinical "Cohort" column
FILTER_WHICH = ["unique"]  # e.g. ["unique"] or ["raw","unique"]
FILTER_STRATEGY = ["s1", "s2", "s3", "s4"]  # optional: restrict to e.g. ["s1"]
FILTER_UNIT = ["mutation", "peptide", "hla_peptide"]
N_BOOT = 2000  # number of bootstrap resamples passed to bootstrap_auc_ci
SEED = 1  # seed passed to bootstrap_auc_ci

# Whether figures are written to OUTDIR (ROC/KM plots are shown with plt.show() otherwise)
SAVE_PLOTS = True



In [29]:

# =========================
# 1) Helpers
# =========================
def ensure_response_label(df):
    """
    Return a copy of ``df`` that has an integer ``response_label`` column.

    If ``response_label`` already contains at least one non-missing value,
    rows with a missing label are dropped and the column is cast to int.
    Otherwise the label is derived from ``Response`` (upper-cased, stripped):
    CR/PR -> 1, SD/PD -> 0. Rows with a missing Response or "NE" are removed
    silently; any other category is printed as a warning and removed.

    Raises
    ------
    ValueError
        If there is no usable ``response_label`` and no ``Response`` column.
    """
    if "response_label" in df.columns and df["response_label"].notna().any():
        labelled = df.dropna(subset=["response_label"]).copy()
        labelled["response_label"] = labelled["response_label"].astype(int)
        return labelled

    if "Response" not in df.columns:
        raise ValueError("Neither response_label nor Response column found.")

    # Drop missing responses *before* the string conversion: astype(str) turns
    # NaN into the text "nan", which would otherwise be reported as a bogus
    # unmapped category. This mirrors make_response_label.
    derived = df[df["Response"].notna()].copy()
    derived["Response"] = derived["Response"].astype(str).str.strip().str.upper()
    derived = derived[derived["Response"] != "NE"].copy()

    derived["response_label"] = derived["Response"].map({"CR": 1, "PR": 1, "SD": 0, "PD": 0})
    unmapped = derived["response_label"].isna()
    unknown = derived.loc[unmapped, "Response"].value_counts()
    if len(unknown) > 0:
        print("[WARN] Unmapped Response categories:\n", unknown)
        derived = derived[~unmapped].copy()

    derived["response_label"] = derived["response_label"].astype(int)
    return derived


def bootstrap_auc_ci(y_true, y_score, n_boot=2000, seed=1):
    """
    Point AUC plus a bootstrap percentile 95% interval.

    Samples are redrawn with replacement ``n_boot`` times; resamples that
    contain only one class are skipped. When fewer than 50 resamples remain,
    the interval bounds are returned as NaN.

    Returns ``(auc, ci_low, ci_high, n_resamples_used)``.
    """
    labels = np.asarray(y_true, dtype=int)
    scores = np.asarray(y_score, dtype=float)
    generator = np.random.default_rng(seed)

    def _roc_auc(lbl, scr):
        false_pos, true_pos, _ = roc_curve(lbl, scr)
        return auc(false_pos, true_pos)

    point = _roc_auc(labels, scores)

    size = len(labels)
    resampled = []
    for _ in range(n_boot):
        pick = generator.integers(0, size, size=size)
        drawn_labels = labels[pick]
        # AUC is undefined unless both classes are present in the resample.
        if len(np.unique(drawn_labels)) >= 2:
            resampled.append(_roc_auc(drawn_labels, scores[pick]))

    kept = len(resampled)
    if kept < 50:
        return float(point), np.nan, np.nan, kept

    lower, upper = np.quantile(resampled, [0.025, 0.975])
    return float(point), float(lower), float(upper), kept


def safe_filename(s: str, max_len: int = 200) -> str:
    """
    Turn ``s`` into a file-system friendly name of at most ``max_len`` characters.

    Every run of characters other than word characters, ``-`` and ``.`` is
    replaced by a single underscore. If the result is too long, the stem is
    shortened and the extension (e.g. ``.png``) is kept, so the file type is
    not lost by truncation. Names without an extension are simply cut.
    """
    cleaned = re.sub(r"[^\w\-.]+", "_", str(s))
    if len(cleaned) <= max_len:
        return cleaned

    stem, ext = os.path.splitext(cleaned)
    if not ext or len(ext) >= max_len:
        # No extension to protect, or it would not fit anyway.
        return cleaned[:max_len]
    return stem[: max_len - len(ext)] + ext


def find_time_event_cols(df):
    """
    Pick the time and event columns to use for survival analysis.

    A PFS pair is preferred when both a PFS-like time column and an
    event/status column are found by name; otherwise the known OS columns
    ("Time to Death\\n(weeks)" and "Dead/Alive\\n(Dead = True)") are used.

    Returns ``(time_col, event_col, endpoint)`` where endpoint is "PFS" or
    "OS", or ``(None, None, None)`` when nothing suitable is present.
    """
    names = list(df.columns)

    def _first_match(candidates, pattern):
        # First candidate whose name matches `pattern` (case-insensitive), else None.
        for name in candidates:
            if re.search(pattern, str(name), flags=re.I):
                return name
        return None

    pfs_like = [n for n in names if re.search(r"\bPFS\b|progression", str(n), flags=re.I)]
    # NOTE(review): event candidates are not required to mention PFS, so any
    # column named like "... status" / "... event" can be paired with a PFS time.
    event_like = [n for n in names if re.search(r"event|status|progress", str(n), flags=re.I)]

    # Heuristic: the time column must also mention "week" or "time".
    pfs_time = _first_match(pfs_like, r"week|time")
    pfs_event = _first_match(event_like, r"event|status")
    if pfs_time and pfs_event:
        return pfs_time, pfs_event, "PFS"

    # Fall back to the known overall-survival columns; both must exist.
    os_time = "Time to Death\n(weeks)"
    os_event = "Dead/Alive\n(Dead = True)"
    if os_time in names and os_event in names:
        return os_time, os_event, "OS"

    return None, None, None


def to_event01(x):
    """
    Convert one event-indicator cell to 1 (event), 0 (censored) or NaN.

    Booleans map directly. Strings are compared case-insensitively against a
    small vocabulary; anything else is parsed as a number, truncated to an
    integer, and treated as an event when non-zero. Missing or unparseable
    values give NaN.
    """
    if pd.isna(x):
        return np.nan
    if isinstance(x, (bool, np.bool_)):
        return 1 if x else 0

    token = str(x).strip().lower()
    event_words = ("true", "1", "dead", "yes", "event")
    censor_words = ("false", "0", "alive", "no", "censored", "none")
    if token in event_words:
        return 1
    if token in censor_words:
        return 0

    # Numeric fallback: the value is truncated towards zero before the test.
    try:
        number = int(float(token))
    except Exception:
        return np.nan
    return 0 if number == 0 else 1


In [30]:
# =========================
# 2) Load long-format
# =========================
df = pd.read_csv(LONG_CSV)

required = {"sample", "unit", "strategy", "which", "tnb_value"}
missing = required - set(df.columns)
if missing:
    raise ValueError(f"Missing columns in long CSV: {missing}")

# Identifier columns become stripped strings; the TNB value becomes numeric (unparseable -> NaN).
df = df.assign(
    **{key: df[key].astype(str).str.strip() for key in ("sample", "unit", "strategy", "which")},
    tnb_value=pd.to_numeric(df["tnb_value"], errors="coerce"),
)

# Guarantee an integer response_label, then discard rows without a TNB value.
df = ensure_response_label(df)
df = df[df["tnb_value"].notna()].copy()

# Apply the configured which / strategy / unit filters in one pass.
selected = (
    df["which"].isin(FILTER_WHICH)
    & df["strategy"].isin(FILTER_STRATEGY)
    & df["unit"].isin(FILTER_UNIT)
)
df = df[selected].copy()

n_samples = df["sample"].nunique()
n_responders = int(df.drop_duplicates("sample")["response_label"].sum())
print("Long-format usable rows:", len(df))
print("Unique samples:", n_samples)
print("Responders:", n_responders, "/", n_samples)

# Make sure the plot directory exists before any figure is saved.
os.makedirs(OUTDIR, exist_ok=True)



Long-format usable rows: 264
Unique samples: 22
Responders: 6 / 22


In [31]:

# =========================
# 3) AUC loop + ROC plots
# =========================
auc_rows = []

group_cols = ["unit", "strategy", "which"]
for (unit, strategy, which), g in df.groupby(group_cols):
    # Each sample should appear once per (unit, strategy, which). If it does
    # not, keep one row per sample and say so instead of dropping silently.
    gg = g.sort_values("sample").drop_duplicates(subset=["sample"])
    n_dup = len(g) - len(gg)
    if n_dup > 0:
        print(f"[WARN] {unit}-{strategy}-{which}: dropped {n_dup} duplicate sample row(s)")

    y = gg["response_label"].astype(int).values
    x = gg["tnb_value"].astype(float).values

    # Too few samples or a single class: record the group with empty metrics.
    if len(gg) < 8 or len(np.unique(y)) < 2:
        auc_rows.append({
            "unit": unit, "strategy": strategy, "which": which,
            "n": int(len(gg)),
            "responders": int(y.sum()),
            "auc": np.nan, "ci_low": np.nan, "ci_high": np.nan,
            "boot_kept": 0,
            "suggest_flip": np.nan,
        })
        continue

    auc_point, ci_low, ci_high, boot_kept = bootstrap_auc_ci(y, x, n_boot=N_BOOT, seed=SEED)

    # Direction hint only (nothing is flipped automatically). Only the point
    # AUC of the negated score is needed, so compute it directly rather than
    # running a second bootstrap whose interval would be thrown away.
    fpr_rev, tpr_rev, _ = roc_curve(y, -x)
    auc_rev = auc(fpr_rev, tpr_rev)
    suggest_flip = bool(auc_point < 0.5 and auc_rev > auc_point)

    auc_rows.append({
        "unit": unit, "strategy": strategy, "which": which,
        "n": int(len(gg)),
        "responders": int(y.sum()),
        "auc": float(auc_point),
        "ci_low": float(ci_low) if not np.isnan(ci_low) else np.nan,
        "ci_high": float(ci_high) if not np.isnan(ci_high) else np.nan,
        "boot_kept": int(boot_kept),
        "suggest_flip": suggest_flip,
    })

    # ROC plot
    fpr, tpr, _ = roc_curve(y, x)
    plt.figure(figsize=(6, 6))
    plt.plot(fpr, tpr)
    plt.plot([0, 1], [0, 1], linestyle="--")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    ci_text = "CI unavailable" if np.isnan(ci_low) else f"95%CI [{ci_low:.3f}, {ci_high:.3f}]"
    flip_text = " (maybe flip sign)" if suggest_flip else ""
    plt.title(f"[{COHORT_LABEL}] {unit}-{strategy}-{which}\nAUC={auc_point:.3f}, {ci_text}{flip_text}")
    plt.tight_layout()

    if SAVE_PLOTS:
        fn = safe_filename(f"ROC_{COHORT_LABEL}_{unit}_{strategy}_{which}.png")
        plt.savefig(os.path.join(OUTDIR, fn), dpi=200)
        plt.close()
    else:
        plt.show()

# One row per (unit, strategy, which), best AUC first.
auc_summary = pd.DataFrame(auc_rows).sort_values(["auc"], ascending=False)
auc_summary.to_csv("/work/longyh/BY/processed/TNB/auc_summary_long.csv", index=False)
print("Saved: auc_summary_long.csv")
print(auc_summary.head(20))


Saved: auc_summary_long.csv
           unit strategy   which   n  responders       auc    ci_low  \
4      mutation       s1  unique  22           6  0.583333  0.223498   
9       peptide       s2  unique  22           6  0.578125  0.242279   
1   hla_peptide       s2  unique  22           6  0.567708  0.224881   
0   hla_peptide       s1  unique  22           6  0.562500  0.227616   
8       peptide       s1  unique  22           6  0.562500  0.223529   
5      mutation       s2  unique  22           6  0.557292  0.221875   
10      peptide       s3  unique  22           6  0.552083  0.213445   
6      mutation       s3  unique  22           6  0.552083  0.200000   
2   hla_peptide       s3  unique  22           6  0.546875  0.200850   
3   hla_peptide       s4  unique  22           6  0.500000  0.225000   
11      peptide       s4  unique  22           6  0.500000  0.225000   
7      mutation       s4  unique  22           6  0.494792  0.224722   

     ci_high  boot_kept  suggest_fl

In [32]:
# =========================
# 3.5) AUC heatmap tables + plots
# =========================
def plot_heatmap(matrix_df, title, xlabel, ylabel, out_png):
    """
    Draw a numeric pivot table as a matplotlib heatmap and save it to ``out_png``.

    Rows and columns are labelled with the table's index and column names;
    the figure size grows with the table shape. The figure is closed after
    saving.
    """
    table = matrix_df.copy()
    n_rows, n_cols = table.shape

    width = 1.2 * max(4, n_cols + 1)
    height = 1.0 * max(3, n_rows + 1)
    fig, ax = plt.subplots(figsize=(width, height))

    image = ax.imshow(table.values, aspect="auto")

    ax.set_xticks(range(n_cols))
    ax.set_xticklabels(table.columns, rotation=30, ha="right")
    ax.set_yticks(range(n_rows))
    ax.set_yticklabels(table.index)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)

    colorbar = fig.colorbar(image, ax=ax)
    colorbar.ax.set_ylabel("Value", rotation=270, labelpad=12)

    fig.tight_layout()
    fig.savefig(out_png, dpi=220)
    plt.close(fig)


# 只用“有auc”的行
# Only groups for which an AUC could be computed.
auc_ok = auc_summary.dropna(subset=["auc"]).copy()

# All heatmap tables are written to the same directory. Previously the plain
# AUC table was written to the current working directory by mistake.
TABLE_DIR = "/work/longyh/BY/processed/TNB"


def _strategy_by_unit(frame, values, aggfunc):
    """Pivot `values` into a strategy x unit table ordered by FILTER_STRATEGY / FILTER_UNIT."""
    return frame.pivot_table(
        index="strategy",
        columns="unit",
        values=values,
        aggfunc=aggfunc
    ).reindex(index=FILTER_STRATEGY, columns=FILTER_UNIT)


# One set of tables/heatmaps per `which` value (a single set when only "unique" is analysed).
for which in sorted(auc_ok["which"].unique()):
    sub = auc_ok[auc_ok["which"] == which].copy()

    # AUC pivot
    auc_pivot = _strategy_by_unit(sub, "auc", "mean")
    auc_pivot.to_csv(os.path.join(TABLE_DIR, f"auc_heatmap_table_{which}.csv"))
    print(f"Saved: auc_heatmap_table_{which}.csv")

    # AUC with CI as text (convenient for tables)
    sub["auc_ci_str"] = sub.apply(
        lambda r: (
            f"{r['auc']:.3f} [{r['ci_low']:.3f}, {r['ci_high']:.3f}]"
            if pd.notna(r["ci_low"]) and pd.notna(r["ci_high"])
            else f"{r['auc']:.3f} [NA]"
        ),
        axis=1
    )
    auc_ci_pivot = _strategy_by_unit(sub, "auc_ci_str", "first")
    auc_ci_pivot.to_csv(os.path.join(TABLE_DIR, f"auc_heatmap_table_ci_{which}.csv"))
    print(f"Saved: auc_heatmap_table_ci_{which}.csv")

    # CI width (stability proxy)
    sub["ci_width"] = sub["ci_high"] - sub["ci_low"]
    ciw_pivot = _strategy_by_unit(sub, "ci_width", "mean")
    ciw_pivot.to_csv(os.path.join(TABLE_DIR, f"auc_heatmap_table_ciwidth_{which}.csv"))
    print(f"Saved: auc_heatmap_table_ciwidth_{which}.csv")

    # Plot heatmaps
    if SAVE_PLOTS:
        out_auc = os.path.join(
            OUTDIR,
            safe_filename(f"AUC_heatmap_{COHORT_LABEL}_{which}.png")
        )
        plot_heatmap(
            auc_pivot,
            title=f"[{COHORT_LABEL}] AUC heatmap ({which})",
            xlabel="Counting unit",
            ylabel="Strategy",
            out_png=out_auc
        )
        print("Saved:", out_auc)

        out_ciw = os.path.join(
            OUTDIR,
            safe_filename(f"AUC_heatmap_CIwidth_{COHORT_LABEL}_{which}.png")
        )
        plot_heatmap(
            ciw_pivot,
            title=f"[{COHORT_LABEL}] AUC CI width heatmap ({which})",
            xlabel="Counting unit",
            ylabel="Strategy",
            out_png=out_ciw
        )
        print("Saved:", out_ciw)


Saved: auc_heatmap_table_unique.csv
Saved: auc_heatmap_table_ci_unique.csv
Saved: auc_heatmap_table_ciwidth_unique.csv
Saved: /work/longyh/BY/plots/AUC_heatmap_Ipi-N_unique.png
Saved: /work/longyh/BY/plots/AUC_heatmap_CIwidth_Ipi-N_unique.png


In [19]:


# =========================
# 4) KM + Cox loop (optional)
# =========================
# Survival analysis per (unit, strategy, which) group: a median-split
# Kaplan-Meier plot with log-rank test, plus a univariate Cox model on
# log1p(tnb_value). Relies on `df`, `group_cols` and the helper functions
# defined in earlier cells.
time_col, event_col, endpoint = find_time_event_cols(df)
km_rows = []

if time_col is None:
    print("[INFO] No time-to-event columns detected; skip KM/Cox.")
else:
    # Coerce the event flag to 0/1 and the time to numeric; rows where either
    # is missing or unparseable are dropped.
    km_df = df.copy()
    km_df[event_col] = km_df[event_col].apply(to_event01)
    km_df[time_col] = pd.to_numeric(km_df[time_col], errors="coerce")
    km_df = km_df.dropna(subset=[time_col, event_col]).copy()
    km_df[event_col] = km_df[event_col].astype(int)

    print(f"[INFO] Using endpoint={endpoint} | time_col={time_col} | event_col={event_col}")
    print("KM usable samples:", km_df.drop_duplicates(["sample"]).shape[0])

    for (unit, strategy, which), g in km_df.groupby(group_cols):
        # One row per sample within the group.
        gg = g.sort_values("sample").drop_duplicates(subset=["sample"])
        # Fewer than 10 samples: record the group with empty statistics.
        if len(gg) < 10:
            km_rows.append({
                "unit": unit, "strategy": strategy, "which": which,
                "endpoint": endpoint, "n": int(len(gg)),
                "events": int(gg[event_col].sum()),
                "median_cut": np.nan,
                "logrank_p": np.nan,
                "cox_hr_log1p": np.nan, "cox_ci_low": np.nan, "cox_ci_high": np.nan, "cox_p": np.nan,
            })
            continue

        # median split: strictly above the median is "High", ties at the median go to "Low"
        median_val = gg["tnb_value"].median()
        gg = gg.copy()
        gg["group"] = np.where(gg["tnb_value"] > median_val, "High", "Low")

        gH = gg[gg["group"] == "High"]
        gL = gg[gg["group"] == "Low"]
        # Either arm smaller than 3: keep the cut-off but skip the tests.
        if len(gH) < 3 or len(gL) < 3:
            km_rows.append({
                "unit": unit, "strategy": strategy, "which": which,
                "endpoint": endpoint, "n": int(len(gg)),
                "events": int(gg[event_col].sum()),
                "median_cut": float(median_val),
                "logrank_p": np.nan,
                "cox_hr_log1p": np.nan, "cox_ci_low": np.nan, "cox_ci_high": np.nan, "cox_p": np.nan,
            })
            continue

        # log-rank
        lr = logrank_test(
            gH[time_col], gL[time_col],
            event_observed_A=gH[event_col],
            event_observed_B=gL[event_col],
        )

        # KM plot: both arms are drawn on the same axes
        kmf = KaplanMeierFitter()
        plt.figure(figsize=(6.8, 5.6))

        kmf.fit(gL[time_col], event_observed=gL[event_col], label=f"Low (n={len(gL)})")
        ax = kmf.plot(ci_show=True)
        kmf.fit(gH[time_col], event_observed=gH[event_col], label=f"High (n={len(gH)})")
        kmf.plot(ax=ax, ci_show=True)

        plt.xlabel(f"Time (weeks) [{endpoint}]")
        plt.ylabel("Survival probability")
        plt.title(f"[{COHORT_LABEL}] {unit}-{strategy}-{which}\nlog-rank p={lr.p_value:.3g} (median={median_val:.3g})")
        plt.tight_layout()

        if SAVE_PLOTS:
            fn = safe_filename(f"KM_{COHORT_LABEL}_{endpoint}_{unit}_{strategy}_{which}.png")
            plt.savefig(os.path.join(OUTDIR, fn), dpi=200)
            plt.close()
        else:
            plt.show()

        # Cox continuous: log1p(tnb) as the single covariate "x"
        cox_df = gg[[time_col, event_col, "tnb_value"]].copy()
        cox_df["x"] = np.log1p(cox_df["tnb_value"].astype(float))

        # NOTE(review): the fit is not guarded; an exception raised here aborts the whole loop.
        cph = CoxPHFitter()
        cph.fit(cox_df[[time_col, event_col, "x"]], duration_col=time_col, event_col=event_col)
        summ = cph.summary.loc["x", ["exp(coef)", "exp(coef) lower 95%", "exp(coef) upper 95%", "p"]]

        # Hazard ratio for "x" with its 95% bounds and p-value.
        hr = float(summ["exp(coef)"])
        lo = float(summ["exp(coef) lower 95%"])
        hi = float(summ["exp(coef) upper 95%"])
        p = float(summ["p"])

        km_rows.append({
            "unit": unit, "strategy": strategy, "which": which,
            "endpoint": endpoint, "n": int(len(gg)),
            "events": int(gg[event_col].sum()),
            "median_cut": float(median_val),
            "logrank_p": float(lr.p_value),
            "cox_hr_log1p": hr, "cox_ci_low": lo, "cox_ci_high": hi, "cox_p": p,
        })

    # One row per group, ordered by the Cox p-value.
    km_summary = pd.DataFrame(km_rows).sort_values(["cox_p"])
    km_summary.to_csv("/work/longyh/BY/processed/TNB/km_cox_summary_long.csv", index=False)
    print("Saved: km_cox_summary_long.csv")
    print(km_summary.head(20))


print("Done. Plots saved to:", OUTDIR)


[INFO] Using endpoint=OS | time_col=Time to Death
(weeks) | event_col=Dead/Alive
(Dead = True)
KM usable samples: 22
Saved: km_cox_summary_long.csv
           unit strategy   which endpoint   n  events  median_cut  logrank_p  \
4      mutation       s1  unique       OS  22      15        30.5   0.541620   
3   hla_peptide       s4  unique       OS  22      15         0.0   0.724310   
11      peptide       s4  unique       OS  22      15         0.0   0.724310   
5      mutation       s2  unique       OS  22      15        23.5   0.613692   
1   hla_peptide       s2  unique       OS  22      15        85.0   0.613692   
0   hla_peptide       s1  unique       OS  22      15       127.5   0.461506   
6      mutation       s3  unique       OS  22      15         8.0   0.851040   
8       peptide       s1  unique       OS  22      15       118.0   0.293848   
9       peptide       s2  unique       OS  22      15        74.5   0.757195   
7      mutation       s4  unique       OS  22      1

# 一致性对照：3种计数层级的相关性分析

In [17]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


In [38]:
# df_long: 你的 long-format DataFrame（比如读 tnb_long_format.csv）
# 必须包含列：sample, unit, strategy, which, tnb_value
# Long-format TNB table; must contain: sample, unit, strategy, which, tnb_value.
# NOTE(review): this rebinds `df`, replacing the filtered frame of the previous section.
df = pd.read_csv("/work/longyh/BY/processed/TNB/tnb_long_format.csv")
# =========================
# 1) Choose the levels for the consistency check
# =========================
WHICH = "unique"      # "unique" is strongly recommended
# Display label only; independent of the clinical "Cohort" column. The stray
# trailing space was removed because it ended up in plot titles and file names.
COHORT_LABEL = "Ipi-N"
STRATEGIES = ["s1", "s2", "s3", "s4"]  # may be restricted, e.g. ["s1"]
UNITS = ["mutation", "peptide", "hla_peptide"]
PLOT_DIR = "/work/longyh/BY/plots"

# Keep only the selected which / unit / strategy rows.
df_check = df.copy()
df_check = df_check[df_check["which"] == WHICH].copy()
df_check = df_check[df_check["unit"].isin(UNITS)].copy()
df_check = df_check[df_check["strategy"].isin(STRATEGIES)].copy()

print("Consistency check rows:", len(df_check))
print("Samples:", df_check["sample"].nunique())


Consistency check rows: 264
Samples: 22


In [39]:
# =========================
# 2) 对每个 strategy：做 wide pivot + Spearman 相关
# =========================

def spearman_corr_matrix(wide_df):
    """
    Spearman rank correlation between the columns of ``wide_df``.

    ``wide_df`` has one row per sample and one column per unit; the result is
    a square DataFrame indexed and labelled by those columns.
    """
    correlations = wide_df.corr(method="spearman")
    return correlations

def plot_corr_heatmap(corr_df, title):
    """
    Draw a correlation matrix as a heatmap and save it under PLOT_DIR.

    The file name is built from COHORT_LABEL, WHICH and ``title``. Every run
    of characters other than word characters, ``-`` and ``.`` is replaced by
    an underscore, so brackets, parentheses, ``=`` and commas in the title do
    not end up in the file name. The figure is closed after saving.
    """
    import re  # local import: this section's import cell does not bring in `re`

    fig, ax = plt.subplots(figsize=(5.5, 4.8))
    im = ax.imshow(corr_df.values, aspect="auto")
    ax.set_xticks(range(corr_df.shape[1]))
    ax.set_xticklabels(corr_df.columns, rotation=30, ha="right")
    ax.set_yticks(range(corr_df.shape[0]))
    ax.set_yticklabels(corr_df.index)
    ax.set_title(title)
    cbar = fig.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Spearman ρ", rotation=270, labelpad=12)
    fig.tight_layout()

    fname = re.sub(r"[^\w\-.]+", "_", f"spearman_heatmap_{COHORT_LABEL}_{WHICH}_{title}.png")
    out = f"{PLOT_DIR}/{fname}"
    fig.savefig(out, dpi=200)
    plt.close(fig)


def plot_scatter_pairs(wide_df, strategy, title_prefix=""):
    """
    Save one scatter plot per pair of counting units on a log1p scale.

    Pairs drawn: mutation/peptide, mutation/hla_peptide, peptide/hla_peptide.
    A pair is skipped when either column is absent from ``wide_df``. The
    Spearman correlation of each pair is shown in the title, and every figure
    is written under PLOT_DIR and then closed.
    """
    # log1p keeps zero counts defined and compresses the scale.
    logged = wide_df.copy()
    for col in logged.columns:
        logged[col] = np.log1p(logged[col].astype(float))

    unit_pairs = (
        ("mutation", "peptide"),
        ("mutation", "hla_peptide"),
        ("peptide", "hla_peptide"),
    )
    for a, b in unit_pairs:
        if not (a in logged.columns and b in logged.columns):
            continue

        x = logged[a].values
        y = logged[b].values
        rho = pd.Series(x).corr(pd.Series(y), method="spearman")

        fig, ax = plt.subplots(figsize=(5.5, 5.0))
        ax.scatter(x, y, alpha=0.75)
        ax.set_xlabel(f"log1p({a})")
        ax.set_ylabel(f"log1p({b})")
        ax.set_title(f"{title_prefix}{a} vs {b}\nSpearman ρ={rho:.3f}")
        fig.tight_layout()

        out = f"{PLOT_DIR}/spearman_scatter_{COHORT_LABEL}_{WHICH}_{strategy}_{a}_vs_{b}.png"
        fig.savefig(out, dpi=200)
        plt.close(fig)



In [40]:



# 保存结果表（论文很有用）
corr_rows = []

for s in STRATEGIES:
    g = df_check[df_check["strategy"] == s].copy()

    # samples x unit table. reindex guarantees one column per requested unit,
    # so a unit with no rows for this strategy yields an all-NaN column (and
    # the strategy is skipped below) instead of a KeyError in dropna.
    wide = g.pivot_table(
        index="sample", columns="unit", values="tnb_value", aggfunc="first"
    ).reindex(columns=UNITS)

    # Some samples may lack a unit (e.g. hla_peptide); correlate only the
    # samples that have all units.
    wide_common = wide.dropna(subset=UNITS, how="any").copy()

    print("\n==============================")
    print("Strategy:", s)
    print("Samples with all 3 units:", len(wide_common))

    if len(wide_common) < 8:
        print("[SKIP] too few complete samples for correlation.")
        continue

    corr = spearman_corr_matrix(wide_common[UNITS])
    print(corr)

    # Record the three pairwise coefficients for the summary table.
    corr_rows.append({
        "strategy": s,
        "which": WHICH,
        "n_complete": int(len(wide_common)),
        "rho_mut_pep": float(corr.loc["mutation", "peptide"]),
        "rho_mut_hla": float(corr.loc["mutation", "hla_peptide"]),
        "rho_pep_hla": float(corr.loc["peptide", "hla_peptide"]),
    })

    # Correlation heatmap plus pairwise scatter plots.
    plot_corr_heatmap(corr, title=f"[{COHORT_LABEL}] Spearman correlation (which={WHICH}, {s})")

    plot_scatter_pairs(
        wide_common[UNITS],
        strategy=s,
        title_prefix=f"[{COHORT_LABEL}] which={WHICH}, {s}: "
    )

# Summary table across strategies.
corr_summary = pd.DataFrame(corr_rows)
corr_summary.to_csv("/work/longyh/BY/processed/TNB/tnb_unit_spearman_summary.csv", index=False)
print("\nSaved: tnb_unit_spearman_summary.csv")
print(corr_summary)



Strategy: s1
Samples with all 3 units: 22
unit         mutation   peptide  hla_peptide
unit                                        
mutation     1.000000  0.931299     0.939215
peptide      0.931299  1.000000     0.980802
hla_peptide  0.939215  0.980802     1.000000

Strategy: s2
Samples with all 3 units: 22
unit         mutation   peptide  hla_peptide
unit                                        
mutation     1.000000  0.933824     0.944869
peptide      0.933824  1.000000     0.975982
hla_peptide  0.944869  0.975982     1.000000

Strategy: s3
Samples with all 3 units: 22
unit         mutation   peptide  hla_peptide
unit                                        
mutation     1.000000  0.937646     0.945049
peptide      0.937646  1.000000     0.998296
hla_peptide  0.945049  0.998296     1.000000

Strategy: s4
Samples with all 3 units: 22
unit         mutation   peptide  hla_peptide
unit                                        
mutation     1.000000  0.994284     0.994284
peptide      0.994