In [None]:
!pip install python-igraph

In [None]:
!wget https://ftp.ncbi.nlm.nih.gov/geo/series/GSE194nnn/GSE194122/suppl/GSE194122%5Fopenproblems%5Fneurips2021%5Fcite%5FBMMC%5Fprocessed.h5ad.gz

In [None]:
import gzip
import os
import shutil

GZ_PATH = "GSE194122_openproblems_neurips2021_cite_BMMC_processed.h5ad.gz"
H5AD_PATH = "GSE194122_openproblems_neurips2021_cite_BMMC_processed.h5ad"

# Decompress only once. The copy goes to a temporary file that is renamed into place
# after it is complete, so an interrupted run never leaves a truncated .h5ad that a
# later run would mistake for a finished one.
if not os.path.exists(H5AD_PATH):
    tmp_path = H5AD_PATH + ".part"
    with gzip.open(GZ_PATH, "rb") as f_in, open(tmp_path, "wb") as f_out:
        shutil.copyfileobj(f_in, f_out)
    os.replace(tmp_path, H5AD_PATH)

In [None]:
!pip -q install igraph plotly scikit-learn

In [None]:
!pip -q install -r requirements_analysis_tools.txt


In [None]:
import scanpy as sc

# Load the processed BMMC CITE-seq object and rename its batch / cell-type columns
# to the keys ("BATCH", "celltype") used by every later cell.
adata = sc.read("GSE194122_openproblems_neurips2021_cite_BMMC_processed.h5ad")
adata.obs = adata.obs.rename(columns={"batch": "BATCH", "cell_type": "celltype"})

In [None]:
# Names of the integration methods compared in this notebook.
methods = [
    "Harmony",
    "mrVI",
    "scDML",
    "scDisInFact",
    "scVI",
    "seurat",
    "scGen",
]

In [None]:
import numpy as np

EMBEDDING_ROOT = "/content/drive/MyDrive/Datasets/GSE194122"

# adata.obsm key -> .npy file (relative to EMBEDDING_ROOT) exported by each method.
# The .npy files carry no cell names, so they are presumably saved in the same
# cell order as `adata` — TODO confirm against the export scripts.
embedding_files = {
    # Harmony
    "X_harmony": "Harmony/X_Harmony.npy",
    "X_pca_harmony": "Harmony/X_pca_harmony.npy",
    # mrVI
    "X_mrvi_u": "mrVI/X_mrvi_u.npy",
    "X_umap_mrvi": "mrVI/X_umap_mrvi.npy",
    # scDML
    "X_scDML": "scDML/X_scDML.npy",
    "X_umap_scDML": "scDML/X_umap_scDML.npy",
    # scDisInFact
    "X_scDisInFact": "scDisInFact/X_scDisInFact.npy",
    "X_umap_scDisInFact": "scDisInFact/X_umap_scDisInFact.npy",
    # scVI
    "X_scVI": "scVI/X_scVI.npy",
    "X_umap_scVI": "scVI/X_umap_scVI.npy",
    # scGen
    "X_scGen": "scGen/X_scGen.npy",
    "X_umap_scGen": "scGen/X_umap_scGen.npy",
}

for obsm_key, rel_path in embedding_files.items():
    adata.obsm[obsm_key] = np.load(f"{EMBEDDING_ROOT}/{rel_path}")

print("All embeddings successfully added to adata.obsm!")

In [None]:
import pandas as pd
import numpy as np

# Seurat's latent space and UMAP are exported as CSVs keyed by cell barcode, so their
# rows must be matched to adata.obs_names explicitly before going into adata.obsm.
LATENT_CSV = "/content/drive/MyDrive/Datasets/GSE194122/seurat/latent_embeddings.csv"
UMAP_CSV   = "/content/drive/MyDrive/Datasets/GSE194122/seurat/umap_embeddings.csv"

latent_df, umap_df = (pd.read_csv(csv_path, index_col=0) for csv_path in (LATENT_CSV, UMAP_CSV))

# Barcodes may be parsed as non-strings on either side; compare everything as str.
latent_df.index = latent_df.index.astype(str)
umap_df.index   = umap_df.index.astype(str)
adata.obs_names = adata.obs_names.astype(str)

# adata cells that have no row in each table.
missing_latent, missing_umap = (
    adata.obs_names.difference(table.index) for table in (latent_df, umap_df)
)

if len(missing_latent) or len(missing_umap):
    print(f"[WARN] Missing latent rows for {len(missing_latent)} adata cells.")
    print(f"[WARN] Missing UMAP rows   for {len(missing_umap)} adata cells.")
    print("Example missing (latent):", list(missing_latent[:5]))
    print("Example missing (umap):  ", list(missing_umap[:5]))

# Subset + reorder both tables to exactly adata's cell order; absent cells become NaN rows.
latent_aligned, umap_aligned = (table.reindex(adata.obs_names) for table in (latent_df, umap_df))

# Any NaN after alignment means a cell (or a value) is missing: fail loudly.
latent_nan_rows = latent_aligned.isna().any(axis=1)
if latent_nan_rows.any():
    bad = latent_aligned.index[latent_nan_rows]
    raise ValueError(f"Latent has NaNs after alignment (likely missing cells). Example: {list(bad[:5])}")

umap_nan_rows = umap_aligned.isna().any(axis=1)
if umap_nan_rows.any():
    bad = umap_aligned.index[umap_nan_rows]
    raise ValueError(f"UMAP has NaNs after alignment (likely missing cells). Example: {list(bad[:5])}")

# Store as float32 matrices, keeping the original column names for interpretability.
adata.obsm["X_seurat"] = latent_aligned.to_numpy(dtype=np.float32)
adata.obsm["X_umap_seurat"] = umap_aligned.to_numpy(dtype=np.float32)
adata.uns["X_seurat_cols"] = latent_aligned.columns.tolist()
adata.uns["X_umap_seurat_cols"] = umap_aligned.columns.tolist()

print("Saved:")
print("  adata.obsm['X_seurat']      =", adata.obsm["X_seurat"].shape)
print("  adata.obsm['X_umap_seurat'] =", adata.obsm["X_umap_seurat"].shape)

In [None]:
# No Harmony UMAP is loaded from disk above, so build one from the Harmony embedding.
# sc.tl.umap writes to the generic "X_umap" slot, hence the method-specific copy.
sc.pp.neighbors(adata, n_neighbors=15, use_rep="X_harmony")
sc.tl.umap(adata)
harmony_umap = adata.obsm["X_umap"]
adata.obsm["X_umap_harmony"] = harmony_umap.copy()

# Step 2: **Clustering & Labeling**

In [None]:
import analysis_tools as at  # project helper module uploaded alongside this notebook

# adata.obsm embeddings to cluster; each gets its own Leiden column in adata.obs.
methods_to_cluster = [
    "X_harmony",
    "X_mrvi_u",
    "X_scDML",
    "X_scDisInFact",
    "X_scVI",
    "X_scGen",
    "X_seurat"
]

# Leiden parameters, shared by all methods so their clusterings are comparable
n_neighbors = 15
resolution = 1.0

for rep in methods_to_cluster:
    print(f"\n===== Clustering for {rep} =====")

    # Method-specific column name, e.g. "X_scVI" -> "leiden_scVI"
    cluster_key = f"leiden_{rep.replace('X_', '')}"

    # This cell relies on at.cluster_adata leaving its labels in adata.obs["leiden"]
    # — confirm against analysis_tools if that helper changes.
    at.cluster_adata(
        adata=adata,
        use_rep=rep,
        method="leiden",
        n_neighbors=n_neighbors,
        resolution=resolution
    )

    if "leiden" not in adata.obs.columns:
        raise KeyError(
            f"Expected at.cluster_adata to write adata.obs['leiden'] for {rep}; "
            f"cannot create '{cluster_key}'."
        )
    # Move the column under its method-specific name. DataFrame.rename would instead
    # add a second, duplicate `cluster_key` column whenever one already exists (e.g.
    # when this cell is re-run), which breaks every later lookup of that key.
    adata.obs[cluster_key] = adata.obs.pop("leiden")

print("\nAll methods have been clustered successfully!")
print("Available cluster keys:", [c for c in adata.obs.columns if c.startswith("leiden_")])


In [None]:
# Leiden column names created by the clustering step, one per integration method.
cluster = [
    "leiden_harmony",
    "leiden_mrvi_u",
    "leiden_scDML",
    "leiden_scDisInFact",
    "leiden_scVI",
    "leiden_seurat",
    "leiden_scGen",
]

In [None]:
import analysis_tools as at  # project helper module uploaded alongside this notebook

# One Leiden column per integration method.
cluster_keys = [
    f"leiden_{suffix}"
    for suffix in ("harmony", "mrvi_u", "scDML", "scDisInFact", "scVI", "seurat", "scGen")
]

true_label_key = "celltype"

# Derive a majority-vote label per cluster (computed by at.assign_majority_vote_labels)
# and store it in adata.obs under "<cluster key>_majority_label".
for ck in cluster_keys:
    print(f"\n===== Majority vote for {ck} =====")
    new_label_key = f"{ck}_majority_label"
    at.assign_majority_vote_labels(
        adata=adata,
        cluster_key=ck,
        true_label_key=true_label_key,
        new_label_key=new_label_key,
    )

print("\nAll majority voting labels added to adata.obs!")


# Step 3: **Calculating the metrics**

In [None]:
import os
import pandas as pd
import analysis_tools as at  # project helper module providing the metric functions used below

def run_all_evaluations(
    adata,
    out_dir="/content/drive/MyDrive/Datasets/GSE194122",
    true_label_key="celltype",
    rare_threshold=0.01,
    rep_map=None,
):
    """Evaluate every clustering method and save the combined results as two CSVs.

    For each cluster key in ``rep_map`` this calls ``at.calculate_prediction_accuracy``
    and ``at.calculate_validation_metrics``. It tags the resulting rows with the method
    name, concatenates across methods, and writes
    ``all_methods_prediction_accuracy.csv`` and ``all_methods_validation_metrics.csv``
    into ``out_dir``.

    Args:
        adata: AnnData with, per method, the cluster column, its
            ``"<cluster_key>_majority_label"`` prediction column, the latent embedding
            in ``adata.obsm`` and the true labels in ``adata.obs[true_label_key]``.
        out_dir: Directory for the two CSV files (created if missing).
        true_label_key: ``adata.obs`` column with ground-truth cell types.
        rare_threshold: Forwarded unchanged to ``at.calculate_validation_metrics``.
        rep_map: Optional ``{cluster_key: obsm_key}`` mapping (or iterable of pairs).
            Defaults to the seven methods evaluated in this notebook. The method name
            is the cluster key with ``"leiden_"`` removed.

    Returns:
        ``(acc_all, val_all)``: the accuracy rows with leading ``method``/``cell_type``
        columns, and the metric rows with leading ``method``/``metric`` columns.

    Raises:
        ValueError: if ``rep_map`` is empty.
    """
    # Map each cluster key to the latent representation (adata.obsm key) it was built from.
    if rep_map is None:
        rep_map = {
            "leiden_harmony": "X_harmony",
            "leiden_mrvi_u": "X_mrvi_u",
            "leiden_scDML": "X_scDML",
            "leiden_scDisInFact": "X_scDisInFact",
            "leiden_scVI": "X_scVI",
            "leiden_seurat": "X_seurat",
            "leiden_scGen": "X_scGen",
        }
    rep_map = dict(rep_map)
    if not rep_map:
        # Otherwise pd.concat fails at the end with the less helpful "No objects to concatenate".
        raise ValueError("rep_map is empty: there are no methods to evaluate.")

    os.makedirs(out_dir, exist_ok=True)

    all_acc_dfs = []
    all_val_dfs = []

    for ck, latent_rep_key in rep_map.items():
        method_name = ck.replace("leiden_", "")  # e.g. 'harmony', 'scVI', 'seurat'
        pred_key = f"{ck}_majority_label"

        print(f"\n===== Evaluating method: {method_name} =====")
        print(f"Cluster key: {ck}")
        print(f"Predicted label key: {pred_key}")
        print(f"Latent rep key: {latent_rep_key}")

        # 1) Per-cell-type accuracy table
        acc_df = at.calculate_prediction_accuracy(
            adata=adata,
            true_label_key=true_label_key,
            predicted_label_key=pred_key,
            cluster_key=ck,
        ).copy()
        acc_df.insert(0, "method", method_name)
        # Assumes the accuracy table is indexed by cell type — confirm in analysis_tools.
        acc_df.insert(1, "cell_type", acc_df.index)
        all_acc_dfs.append(acc_df.reset_index(drop=True))

        # 2) Overall + rare metrics
        val_df = at.calculate_validation_metrics(
            adata=adata,
            true_label_key=true_label_key,
            cluster_key=ck,
            predicted_label_key=pred_key,
            latent_rep_key=latent_rep_key,
            rare_threshold=rare_threshold,
        ).copy()
        val_df.insert(0, "method", method_name)
        # Assumes the metric table is indexed by metric name — confirm in analysis_tools.
        val_df.insert(1, "metric", val_df.index)
        all_val_dfs.append(val_df.reset_index(drop=True))

    # ---- Combine & save ----
    acc_all = pd.concat(all_acc_dfs, ignore_index=True)
    val_all = pd.concat(all_val_dfs, ignore_index=True)

    acc_path = os.path.join(out_dir, "all_methods_prediction_accuracy.csv")
    val_path = os.path.join(out_dir, "all_methods_validation_metrics.csv")

    acc_all.to_csv(acc_path, index=False)
    val_all.to_csv(val_path, index=False)

    print("\n=== Saved CSV files ===")
    print(f"Prediction accuracy (per cell type): {acc_path}")
    print(f"Validation metrics (overall + rare): {val_path}")

    return acc_all, val_all

In [None]:
# Evaluate all methods with the default output directory and thresholds (also writes the CSVs).
acc_all, val_all = run_all_evaluations(adata=adata)

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

# ---- Load accuracy CSV ----
# Cell types below this fraction of all cells count as "rare" (the same 1% cut-off as
# run_all_evaluations' default rare_threshold).
RARE_THRESHOLD = 0.01

# ---- Load accuracy CSV ----
acc_all = pd.read_csv("/content/drive/MyDrive/Datasets/GSE194122/all_methods_prediction_accuracy.csv")
acc_all.columns = acc_all.columns.str.strip()
# Coerce to numbers up front (unparseable entries -> NaN) so the % conversion and the
# mean aggregation below cannot fail on a text column.
acc_all["accuracy"] = pd.to_numeric(acc_all["accuracy"], errors="coerce")

# ---- Compute rarity from adata ----
counts = adata.obs["celltype"].value_counts()
fractions = counts / counts.sum()

rare_types = fractions[fractions < RARE_THRESHOLD].index.tolist()
print(f"Detected rare cell types (<{RARE_THRESHOLD:.0%}):")
print(rare_types)

# ---- Keep rare types only; accuracy is a fraction in [0, 1] -> convert to % ----
rare_acc = acc_all[acc_all["cell_type"].isin(rare_types)].copy()
rare_acc["accuracy_pct"] = 100.0 * rare_acc["accuracy"]

# ---- Pivot: methods × rare cell types, columns ordered rarest first ----
heatmap_df = rare_acc.pivot_table(
    values="accuracy_pct",
    index="method",
    columns="cell_type",
    aggfunc="mean"
)
heatmap_df = heatmap_df.reindex(columns=sorted(rare_types, key=lambda ct: fractions[ct]))

if heatmap_df.empty:
    # seaborn cannot draw an empty matrix (no rare types, or no accuracy rows for them).
    print("No rare-cell-type accuracy values to plot.")
else:
    # 1) Heatmap
    fig, ax = plt.subplots(figsize=(14, 8))
    sns.heatmap(
        heatmap_df,
        annot=True,
        fmt=".1f",
        cmap="viridis",
        linewidths=0.5,
        cbar_kws={"label": "Prediction Accuracy (%)"},
        ax=ax,
    )
    ax.set_title("Rare Cell Type Prediction Accuracy per Method")
    ax.set_ylabel("Method")
    ax.set_xlabel("Rare Cell Type")
    fig.tight_layout()
    plt.show()

    # 2) Barplot: mean rare accuracy per method (DataFrame.mean skips NaN cells)
    rare_method_mean = heatmap_df.mean(axis=1).sort_values(ascending=False)

    fig, ax = plt.subplots(figsize=(10, 6))
    sns.barplot(x=rare_method_mean.index, y=rare_method_mean.values, ax=ax)
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
    ax.set_ylabel("Mean Accuracy (%)")
    ax.set_title("Overall Rare-Cell Prediction Accuracy per Method")
    fig.tight_layout()
    plt.show()

    # 3) Clustermap with variance guards. Missing (method, cell type) pairs are filled
    # with 0, i.e. treated as 0% accuracy. Zero-variance rows/columns carry no
    # information for the clustering and are dropped.
    mat = heatmap_df.fillna(0)

    col_var = mat.var(axis=0)
    mat = mat.loc[:, col_var > 0]

    row_var = mat.var(axis=1)
    mat = mat.loc[row_var > 0, :]

    print("Matrix shape used for clustermap:", mat.shape)

    if mat.shape[0] >= 2 and mat.shape[1] >= 2:
        sns.clustermap(mat, cmap="viridis", figsize=(14, 10))
        plt.suptitle("Clustered Heatmap (Methods × Rare Cell Types)", y=1.02)
        plt.show()
    else:
        print("Not enough variation (or too few methods/rare types) to plot a clustermap.")

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA

# Load the combined metrics table written by run_all_evaluations.
VAL_CSV = "/content/drive/MyDrive/Datasets/GSE194122/all_methods_validation_metrics.csv"
val_all = pd.read_csv(VAL_CSV)
val_all.columns = val_all.columns.str.strip()

print("Columns in val_all:", val_all.columns.tolist())
print(val_all.head(10))

# ensure numeric value
val_all["value"] = pd.to_numeric(val_all["value"], errors="coerce")

# -----------------------------
# Parse metric into scope + name
# -----------------------------
metric_str = val_all["metric"].astype(str).str.strip()

# Names starting with "rare" (any case) are Rare-scope; everything else, including
# names starting with "overall", falls back to Overall.
scope = np.where(metric_str.str.match(r"(?i)^rare"), "Rare", "Overall")

# Strip the scope prefix and leading separators, then normalize separators to "_".
metric_name = (
    metric_str
    .str.replace(r"(?i)^rare", "", regex=True)
    .str.replace(r"(?i)^overall", "", regex=True)
    .str.replace(r"^[\s:_\-()<>%\.]+", "", regex=True)
    .str.replace(r"[\s:_\-]+", "_", regex=True)
    .replace("", np.nan)
)

val_all["scope"] = scope
# A name that became empty after stripping falls back to the original metric string.
val_all["metric_name"] = metric_name.fillna(metric_str)

print("\nParsed scopes:", val_all["scope"].value_counts().to_dict())
print("Example parsed metrics:\n", val_all[["metric", "scope", "metric_name"]].head(10))

# -----------------------------
# Pivot to method × metric_name (columns in sorted order)
# -----------------------------
overall_long, rare_long = (
    val_all[val_all["scope"] == scope_label].copy() for scope_label in ("Overall", "Rare")
)

pivot_kwargs = dict(index="method", columns="metric_name", values="value", aggfunc="mean")
overall_df = overall_long.pivot_table(**pivot_kwargs).sort_index(axis=1)
rare_df = rare_long.pivot_table(**pivot_kwargs).sort_index(axis=1)

print("\noverall_df shape:", overall_df.shape)
print("rare_df shape:", rare_df.shape)

# -----------------------------
# A) Overall, B) Rare, C) Rare − Overall heatmaps
# -----------------------------
def _show_metric_heatmap(matrix, title, **heatmap_kwargs):
    """Render `matrix` as a 14x8 heatmap annotated to 3 decimals, titled `title`."""
    plt.figure(figsize=(14, 8))
    sns.heatmap(matrix, annot=True, fmt=".3f", **heatmap_kwargs)
    plt.title(title)
    plt.tight_layout()
    plt.show()


# A) Overall metrics, excluding rare-related and total/count-type columns
if overall_df.size == 0:
    print("No Overall metrics detected in CSV.")
else:
    drop_pattern = r"(?i)(^rare|rare|total|n_true|n_correct|count|cells|support)"
    keep_cols = ~overall_df.columns.str.contains(drop_pattern, regex=True)
    overall_df_plot = overall_df.loc[:, keep_cols].copy()

    print("Overall heatmap columns kept:", overall_df_plot.columns.tolist())

    if overall_df_plot.shape[1] > 0:
        _show_metric_heatmap(overall_df_plot, "Overall Metrics Heatmap — Methods × Metrics", cmap="viridis")
    else:
        print("No overall metric columns left after filtering rare/total/count columns.")

# B) Rare metrics
if rare_df.size == 0:
    print("No Rare metrics detected in CSV. (Maybe your metric names don't start with 'Rare'.)")
else:
    _show_metric_heatmap(rare_df, "Rare Metrics Heatmap — Methods × Metrics", cmap="viridis")

# C) Rare − Overall on the metrics both scopes share (negative = worse on rare cells)
common_metrics = sorted(set(overall_df.columns) & set(rare_df.columns))
if not common_metrics:
    print("No common metrics between Overall and Rare to compute differences.")
else:
    diff_df = rare_df[common_metrics] - overall_df[common_metrics]
    _show_metric_heatmap(
        diff_df,
        "Rare − Overall Metric Differences (Negative = Drop on Rare Cells)",
        cmap="coolwarm",
        center=0,
    )

# -----------------------------
# D) Combined clustering + PCA (if enough features)
# -----------------------------
# One row per method: Overall_* and Rare_* metric columns side by side, gaps filled with 0.
combined_df = pd.concat(
    [overall_df.add_prefix("Overall_"), rare_df.add_prefix("Rare_")],
    axis=1,
).fillna(0)

# Constant columns cannot separate methods, so drop them.
col_var = combined_df.var(axis=0)
combined_filtered = combined_df.loc[:, col_var > 0]

n_rows, n_cols = combined_filtered.shape
if n_rows < 2 or n_cols < 2:
    print("Not enough variation / features for clustering + PCA.")
else:
    sns.clustermap(combined_filtered, cmap="viridis", figsize=(14, 10))
    plt.suptitle("Clustered Heatmap of Methods (Overall + Rare Metrics)", y=1.02)
    plt.show()

    # Standardize each metric, then project the methods onto the first two PCs.
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(combined_filtered)
    pca = PCA(n_components=2)
    pc = pca.fit_transform(X_scaled)
    pca_df = pd.DataFrame(pc, columns=["PC1", "PC2"], index=combined_filtered.index)

    fig, ax = plt.subplots(figsize=(10, 7))
    sns.scatterplot(data=pca_df, x="PC1", y="PC2", s=200, ax=ax)
    for m, (pc1, pc2) in zip(pca_df.index, pca_df[["PC1", "PC2"]].to_numpy()):
        ax.text(pc1 + 0.02, pc2 + 0.02, m, fontsize=12)
    ax.set_title("PCA of Methods Based on All Metrics")
    ax.set_xlabel(f"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}% Var)")
    ax.set_ylabel(f"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}% Var)")
    fig.tight_layout()
    plt.show()

# -----------------------------
# E) Radar plots (optional, only if a few metrics)
# -----------------------------
def radar_plot(df, title):
    """Draw a radar (spider) chart with one closed, lightly filled polygon per row of `df`.

    Each column becomes a spoke labelled with the column name, and each row a trace
    labelled in the legend by its index value. With fewer than three columns, nothing
    is drawn and a notice is printed instead.
    """
    if df.shape[1] < 3:
        print(f"Radar plot skipped: need >=3 metrics, got {df.shape[1]}.")
        return

    spoke_labels = df.columns.tolist()
    n_spokes = len(spoke_labels)
    spoke_angles = np.linspace(0, 2 * np.pi, n_spokes, endpoint=False).tolist()
    closed_angles = spoke_angles + spoke_angles[:1]  # revisit the first spoke to close the polygon

    fig, ax = plt.subplots(figsize=(7, 7), subplot_kw=dict(polar=True))

    for row_label, row in df.iterrows():
        row_values = row.values.tolist()
        closed_values = row_values + row_values[:1]
        ax.plot(closed_angles, closed_values, label=row_label)
        ax.fill(closed_angles, closed_values, alpha=0.1)

    ax.set_xticks(spoke_angles)
    ax.set_xticklabels(spoke_labels, fontsize=9)
    ax.set_title(title)
    ax.legend(bbox_to_anchor=(1.35, 1.1))
    fig.tight_layout()
    plt.show()

def _minmax_per_metric(df):
    """Rescale every column of `df` to [0, 1] via (x - min) / (max - min).

    Dividing by the column max maps values into [0, 1] only for non-negative metrics.
    Negative values (metrics such as ARI or silhouette scores can be negative) would stay
    negative, and a column whose max is 0 would be divided by zero. Columns with no spread
    (max == min) become NaN here; the radar calls below fill NaN with 0.
    """
    col_min = df.min(numeric_only=True)
    col_range = df.max(numeric_only=True) - col_min
    return (df - col_min) / col_range.where(col_range != 0)


# Normalize each metric to [0,1] for radar comparison
if overall_df.shape[1] >= 3:
    overall_norm = _minmax_per_metric(overall_df)
    radar_plot(overall_norm.fillna(0), "Radar Plot — Overall Metrics (Normalized per Metric)")

if rare_df.shape[1] >= 3:
    rare_norm = _minmax_per_metric(rare_df)
    radar_plot(rare_norm.fillna(0), "Radar Plot — Rare Metrics (Normalized per Metric)")


In [None]:
import scanpy as sc
import analysis_tools as at
import pandas as pd
import os

# -------------------------------------------------------------------
# Configuration
# -------------------------------------------------------------------
TRUE_LABEL_KEY = "celltype"
BATCH_KEY = "BATCH"   # you can change this to your real batch column if needed

# "Before integration" UMAP. It is not computed in this notebook, so it is presumably
# precomputed in the processed h5ad — TODO confirm the key.
BASELINE_UMAP_KEY = "GEX_X_umap"

# Cell types whose overall fraction is below this value get dedicated highlight plots.
RARE_FRACTION = 0.01

# NOTE: this rebinds `methods` (earlier a plain list of method names) to per-method plot configs.
methods = [
    {"name": "harmony",     "cluster_key": "leiden_harmony",     "pred_key": "leiden_harmony_majority_label",     "umap_post": "X_umap_harmony"},
    {"name": "mrvi",        "cluster_key": "leiden_mrvi_u",      "pred_key": "leiden_mrvi_u_majority_label",      "umap_post": "X_umap_mrvi"},
    {"name": "scDML",       "cluster_key": "leiden_scDML",       "pred_key": "leiden_scDML_majority_label",       "umap_post": "X_umap_scDML"},
    {"name": "scDisInFact", "cluster_key": "leiden_scDisInFact", "pred_key": "leiden_scDisInFact_majority_label", "umap_post": "X_umap_scDisInFact"},
    {"name": "scVI",        "cluster_key": "leiden_scVI",        "pred_key": "leiden_scVI_majority_label",        "umap_post": "X_umap_scVI"},
    {"name": "seurat",      "cluster_key": "leiden_seurat",      "pred_key": "leiden_seurat_majority_label",      "umap_post": "X_umap_seurat"},
    {"name": "scGen",       "cluster_key": "leiden_scGen",       "pred_key": "leiden_scGen_majority_label",       "umap_post": "X_umap_scGen"},
]

# -------------------------------------------------------------------
# Loop-invariant checks (done once, before anything is plotted)
# -------------------------------------------------------------------
# The true-label column is read by the rare-type detection right below, so it has to
# be validated before that point to produce a clear error.
if TRUE_LABEL_KEY not in adata.obs.columns:
    raise KeyError(f"TRUE_LABEL_KEY '{TRUE_LABEL_KEY}' not found in adata.obs.")
# The baseline UMAP is the "Before" panel for every method.
if BASELINE_UMAP_KEY not in adata.obsm:
    raise KeyError(f"Baseline UMAP key '{BASELINE_UMAP_KEY}' not found in adata.obsm.")

has_batch = BATCH_KEY in adata.obs.columns
if not has_batch:
    print(f"[WARNING] BATCH_KEY '{BATCH_KEY}' not found in adata.obs. Batch-colored UMAPs will be skipped for all methods.")

# -------------------------------------------------------------------
# Rare cell type detection (global)
# -------------------------------------------------------------------
freq = adata.obs[TRUE_LABEL_KEY].value_counts(normalize=True)
rare_types = freq[freq < RARE_FRACTION].index.tolist()

print(f"=== Rare Cell Types (<{RARE_FRACTION:.0%}) ===")
if rare_types:
    for rt in rare_types:
        print(" -", rt)
else:
    print(f"No rare types found under the {RARE_FRACTION:.0%} threshold.")

# Base directory for Sankey outputs
BASE_SANK_DIR = "./sankey_plots_all_methods"
os.makedirs(BASE_SANK_DIR, exist_ok=True)

# -------------------------------------------------------------------
# MAIN LOOP OVER METHODS
# -------------------------------------------------------------------
for cfg in methods:
    name = cfg["name"]
    cluster_key = cfg["cluster_key"]
    pred_key = cfg["pred_key"]
    umap_post = cfg["umap_post"]

    print("\n" + "=" * 60)
    print(f"METHOD: {name}")
    print("=" * 60)

    # Per-method obs columns created by the clustering / majority-vote steps
    if cluster_key not in adata.obs.columns:
        print(f"[WARNING] cluster_key '{cluster_key}' not in adata.obs. Skipping '{name}'.")
        continue
    if pred_key not in adata.obs.columns:
        print(f"[WARNING] pred_key '{pred_key}' not in adata.obs. Skipping '{name}'.")
        continue

    has_post_umap = umap_post in adata.obsm

    # -------------------------
    # 1) Global UMAPs (all cells)
    # -------------------------
    if has_post_umap:
        # (a) Batch-colored UMAPs
        if has_batch:
            print(f"Plotting Before vs After UMAPs by BATCH for '{name}'...")
            at.plot_side_by_side_umaps(
                adata=adata,
                umap_key1=BASELINE_UMAP_KEY,
                umap_key2=umap_post,
                color_key=BATCH_KEY,
                title1=f"{name}: Before (Batch)",
                title2=f"{name}: After (Batch)",
                show=True,
            )
        else:
            print(f"[INFO] Skipping batch UMAPs for '{name}' (no '{BATCH_KEY}').")

        # (b) Cell-type-colored UMAPs
        print(f"Plotting Before vs After UMAPs by CELLTYPE for '{name}'...")
        at.plot_side_by_side_umaps(
            adata=adata,
            umap_key1=BASELINE_UMAP_KEY,
            umap_key2=umap_post,
            color_key=TRUE_LABEL_KEY,
            title1=f"{name}: Before (Cell Types)",
            title2=f"{name}: After (Cell Types)",
            show=True,
        )
    else:
        print(f"[WARNING] UMAP key '{umap_post}' not found for '{name}'. Skipping UMAP plots.")

    # -------------------------
    # 2) Global Sankey (all cell types)
    # -------------------------
    print(f"Plotting Sankey diagram (all cell types) for '{name}'...")
    sank_dir_method = os.path.join(BASE_SANK_DIR, name)
    os.makedirs(sank_dir_method, exist_ok=True)

    at.plot_sankey_diagram(
        adata=adata,
        true_label_key=TRUE_LABEL_KEY,
        predicted_label_key=pred_key,
        highlight_label=None,
        save_path=os.path.join(sank_dir_method, f"{name}_sankey_all.html"),
        show=True,
    )

    # -------------------------
    # 3) Rare cell types: UMAP + Sankey
    # -------------------------
    if not rare_types:
        print(f"No rare types found; skipping rare-type plots for '{name}'.")
        continue

    print(f"\n--- RARE CELL TYPE PLOTS for '{name}' ---")
    rare_sank_dir = os.path.join(sank_dir_method, "rare_types")
    os.makedirs(rare_sank_dir, exist_ok=True)

    for rt in rare_types:
        print(f"\n[Method: {name}] Rare Type: {rt}")

        # (a) UMAPs highlighting this rare type
        if has_post_umap:
            print(f"  Plotting UMAPs highlighting '{rt}'...")
            at.plot_side_by_side_umaps(
                adata=adata,
                umap_key1=BASELINE_UMAP_KEY,
                umap_key2=umap_post,
                color_key=TRUE_LABEL_KEY,
                title1=f"{name}: Before (Highlight {rt})",
                title2=f"{name}: After (Highlight {rt})",
                highlight_cell_type=rt,
                true_label_key=TRUE_LABEL_KEY,  # REQUIRED by your analysis_tools.py
                show=True,
            )
        else:
            print(f"  [INFO] Skipping UMAP highlight for '{name}' (no '{umap_post}').")

        # (b) Sankey highlighting this rare type
        print(f"  Plotting Sankey highlighting '{rt}'...")
        # Labels are not guaranteed to be str; also avoid path issues like 'MK/E prog'.
        safe_rt = str(rt).replace("/", "_")
        at.plot_sankey_diagram(
            adata=adata,
            true_label_key=TRUE_LABEL_KEY,
            predicted_label_key=pred_key,
            highlight_label=rt,
            save_path=os.path.join(rare_sank_dir, f"{name}_rare_{safe_rt}.html"),
            show=True,
        )

print("\nAll multi-method UMAP and Sankey plots completed!")


In [None]:
def classify_rare_cells(adata, batch_key='BATCH', type_key='celltype', abundance_threshold=0.01, prevalence_threshold=0.25):
    """
    Classifies cell types in an AnnData object into 3 Rare types + Common.

    Two per-cell-type statistics drive the classification:
      * prevalence    -- fraction of batches containing at least one cell of the type;
      * avg abundance -- the type's within-batch fraction of cells, averaged over only
                         the batches where it is present.

    "Low abundance" means avg abundance < abundance_threshold, and "low prevalence"
    means prevalence < prevalence_threshold:
      * neither low          -> "Common"
      * low abundance only   -> "Type 1: Ubiquitous Rare"
      * low prevalence only  -> "Type 2: Context Specific"
      * both low             -> "Type 3: Extremely Rare"

    Args:
        adata: AnnData-like object; only ``adata.obs`` is read.
        batch_key: Column name for batches/samples (e.g., 'BATCH', 'Samplename').
        type_key: Column name for cell labels (e.g., 'celltype').
        abundance_threshold: Fraction below which a cell is "Low Abundance" (0.01 = 1%).
        prevalence_threshold: Fraction of batches below which a cell is "Low Prevalence" (0.25 = 25%).

    Returns:
        A DataFrame with one row per cell type and the columns "Cell Type",
        "Classification", "Avg Abundance (When Present)" (%), "Prevalence" (%),
        "N_Batches" ("present/total") and "Description". Rows are sorted by
        Classification, then Cell Type. Cells missing a batch or a label are ignored.
        If no cells remain, an empty DataFrame with the same columns is returned.
    """
    columns = [
        "Cell Type", "Classification", "Avg Abundance (When Present)",
        "Prevalence", "N_Batches", "Description",
    ]

    # Cells without a batch or label cannot be counted. Dropping them explicitly lets an
    # input with no usable cells return early instead of failing in sort_values below.
    labelled = adata.obs[[batch_key, type_key]].dropna()
    if labelled.empty:
        return pd.DataFrame(columns=columns)

    # 1. Contingency table: rows = batches, columns = cell types.
    df_counts = pd.crosstab(labelled[batch_key], labelled[type_key])
    # Drop all-zero batches / cell types (possible for unused categorical levels,
    # depending on the pandas version). An empty batch would otherwise inflate the
    # prevalence denominator, and a zero-cell type would be reported as rare.
    df_counts = df_counts.loc[df_counts.sum(axis=1) > 0, df_counts.sum(axis=0) > 0]

    # 2. Per-batch fractions: each row divided by that batch's total cell count.
    df_fractions = df_counts.div(df_counts.sum(axis=1), axis=0)
    total_batches = len(df_fractions)

    results = []

    # 3. Metrics and classification per cell type
    for cell_type in df_fractions.columns:
        data_col = df_fractions[cell_type]
        present = data_col > 0

        # Metric A: prevalence -- in what fraction of batches does the type appear?
        batches_with_cell = int(present.sum())
        prevalence = batches_with_cell / total_batches

        # Metric B: abundance when present (mean over non-zero batches only).
        avg_abundance = data_col[present].mean() if batches_with_cell > 0 else 0.0

        is_low_abundance = avg_abundance < abundance_threshold
        is_low_prevalence = prevalence < prevalence_threshold

        if not is_low_abundance and not is_low_prevalence:
            category, desc = "Common", "Standard population"
        elif is_low_abundance and not is_low_prevalence:
            category, desc = "Type 1: Ubiquitous Rare", "Low %, but in most batches"
        elif not is_low_abundance and is_low_prevalence:
            category, desc = "Type 2: Context Specific", "High % (spikes), but few batches"
        else:
            # Both low. A plain `else` guarantees `category`/`desc` are always bound.
            category, desc = "Type 3: Extremely Rare", "Low %, few batches"

        results.append({
            "Cell Type": cell_type,
            "Classification": category,
            "Avg Abundance (When Present)": round(avg_abundance * 100, 3),  # as %
            "Prevalence": round(prevalence * 100, 1),  # as %
            "N_Batches": f"{batches_with_cell}/{total_batches}",
            "Description": desc
        })

    # Convert to DataFrame and sort for readability
    results_df = pd.DataFrame(results, columns=columns).sort_values(by=['Classification', 'Cell Type'])
    return results_df

In [None]:
# Reload the per-cell-type accuracy table and show its schema.
ACC_CSV = "/content/drive/MyDrive/Datasets/GSE194122/all_methods_prediction_accuracy.csv"
acc_all = pd.read_csv(ACC_CSV)

# Normalize column names (strip stray whitespace around CSV headers)
acc_all.columns = [col.strip() for col in acc_all.columns]

column_names = acc_all.columns.tolist()
print(column_names)
print(acc_all.head())


In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# ==============================================================================
# 1. DEFINE THE CLASSIFICATION FUNCTION
# ==============================================================================
def classify_rare_cells(
    adata,
    batch_key="BATCH",
    type_key="celltype",
    abundance_threshold=0.01,
    prevalence_threshold=0.75
):
    """Classify every cell type by how abundant and how widespread it is across batches.

    For each cell type (a column of the batch x cell-type contingency table):

    * prevalence    -- fraction of batches containing at least one cell of that type;
    * avg abundance -- mean per-batch fraction, averaged only over the batches where
      the type is present (0.0 if it is present nowhere).

    The two flags ``avg abundance < abundance_threshold`` (low abundance) and
    ``prevalence < prevalence_threshold`` (low prevalence) select the class:

    * neither flag             -> "Common"
    * low abundance only       -> "Type 1: Ubiquitous Rare"
    * low prevalence only      -> "Type 2: Context Specific"
    * both flags               -> "Type 3: Extremely Rare"

    NOTE(review): an earlier cell in this notebook appears to define a similar
    classifier that returns percent-rounded columns; this definition shadows it.

    Parameters
    ----------
    adata : object exposing an ``.obs`` pandas DataFrame (e.g. AnnData).
    batch_key, type_key : ``adata.obs`` columns with batch and cell-type labels.
    abundance_threshold : per-batch fraction (0-1) below which a type is "low abundance".
    prevalence_threshold : fraction of batches (0-1) below which a type is "low prevalence".

    Returns
    -------
    pandas.DataFrame
        One row per cell type, in contingency-table column order, with columns
        "Cell Type", "Classification", "Avg Abundance" and "Prevalence"
        (both metrics as fractions, not percentages).
    """
    # (low_abundance, low_prevalence) -> class label
    label_for_flags = {
        (False, False): "Common",
        (True, False): "Type 1: Ubiquitous Rare",   # low %, high prevalence
        (False, True): "Type 2: Context Specific",  # high %, low prevalence
        (True, True): "Type 3: Extremely Rare",     # low %, low prevalence
    }

    counts = pd.crosstab(adata.obs[batch_key], adata.obs[type_key])
    # Row-normalise so each batch's counts become fractions of that batch's total.
    fractions = counts.div(counts.sum(axis=1), axis=0)
    n_batches = len(fractions)

    records = []
    for ct in fractions.columns:
        per_batch = fractions[ct]
        present = per_batch[per_batch > 0]
        n_present = len(present)

        prevalence = n_present / n_batches
        avg_abundance = present.mean() if n_present > 0 else 0.0

        flags = (
            bool(avg_abundance < abundance_threshold),
            bool(prevalence < prevalence_threshold),
        )
        records.append({
            "Cell Type": ct,
            "Classification": label_for_flags[flags],
            "Avg Abundance": avg_abundance,
            "Prevalence": prevalence,
        })

    return pd.DataFrame(records)


# ==============================================================================
# 2. LOAD DATA & RUN CLASSIFICATION
# ==============================================================================
# Reload the accuracy table so this cell can be re-run on its own.
acc_all = (
    pd.read_csv("/content/drive/MyDrive/Datasets/GSE194122/all_methods_prediction_accuracy.csv")
    .rename(columns=str.strip)
)

# Accuracy may arrive as text: unparsable entries become NaN; keep a percent copy for plots.
acc_all["accuracy"] = pd.to_numeric(acc_all["accuracy"], errors="coerce")
acc_all["Accuracy (%)"] = 100.0 * acc_all["accuracy"]

# Classify every cell type on the AnnData object, then keep only the rare classes (Types 1-3).
rare_info_df = classify_rare_cells(
    adata,
    batch_key="BATCH",
    type_key="celltype",
    abundance_threshold=0.01,
    prevalence_threshold=0.75,
)
rare_info_df = rare_info_df.loc[rare_info_df["Classification"].ne("Common")].copy()

valid_rare_types = rare_info_df["Cell Type"].tolist()
# Cell Type -> rare class label, used to annotate and group the plots below.
type_to_class = rare_info_df.set_index("Cell Type")["Classification"].to_dict()

print(f"Identified {len(valid_rare_types)} rare cell types using the 3-Type classification.")
print(rare_info_df["Classification"].value_counts())


# ==============================================================================
# 3. PREPARE HEATMAP DATA (USING Accuracy (%))
# ==============================================================================
# Keep only accuracy rows for the rare cell types identified above.
rare_acc = acc_all.loc[acc_all["cell_type"].isin(valid_rare_types)].copy()

# methods (rows) x rare cell types (columns); repeated entries are averaged.
heatmap_df = pd.pivot_table(
    rare_acc,
    index="method",
    columns="cell_type",
    values="Accuracy (%)",
    aggfunc="mean",
)

# Column order: rare class rank (Type 1 -> Type 2 -> Type 3), then cell-type name.
class_order = {
    "Type 1: Ubiquitous Rare": 1,
    "Type 2: Context Specific": 2,
    "Type 3: Extremely Rare": 3,
}
rare_info_idx = rare_info_df.set_index("Cell Type")
ranked_types = (
    rare_info_idx
    .assign(_rank=rare_info_idx["Classification"].map(class_order))
    .sort_values(["_rank", "Cell Type"])
    .index
)
# Drop rare types that have no accuracy column (absent from the CSV).
sorted_cols = [ct for ct in ranked_types if ct in heatmap_df.columns]
heatmap_df = heatmap_df[sorted_cols]


# ==============================================================================
# 4. PLOT 1: SORTED HEATMAP WITH CLASS LABELS
# ==============================================================================
# Prefix each column label with its class tag, e.g. "[Type 1] <cell type>".
x_labels = [f"[{type_to_class[c].split(':')[0]}] {c}" for c in heatmap_df.columns]

fig, ax = plt.subplots(figsize=(16, 8))
sns.heatmap(
    heatmap_df,
    ax=ax,
    annot=True,
    fmt=".0f",
    cmap="viridis",
    xticklabels=x_labels,
    cbar_kws={"label": "Prediction Accuracy (%)"},
)

ax.set_title("Prediction Accuracy by Rare Cell Class\n(T1=Ubiquitous, T2=Context, T3=Extremely Rare)")
ax.set_xlabel("Rare Cell Type (Sorted by Class)")
ax.set_ylabel("Method")
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
fig.tight_layout()
plt.show()


# ==============================================================================
# 5. PLOT 2: CLUSTERMAP WITH CLASS COLORS (FIXED COLOR FALLBACK)
# ==============================================================================
# Colour for each rare class, used in the column-annotation strip and the legend.
class_pal = {
    "Type 1: Ubiquitous Rare": "#FFA500",   # Orange
    "Type 2: Context Specific": "#00CED1",  # Dark Turquoise
    "Type 3: Extremely Rare": "#9370DB"     # Medium Purple
}

# Missing method/cell-type combinations are imputed as 0% so clustering can run.
# NOTE(review): this treats "no data" as "0% accuracy" and biases the clustering.
mat = heatmap_df.fillna(0)
# Drop constant columns, then constant rows: they carry no information for clustering.
mat = mat.loc[:, mat.var(axis=0) > 0]
mat = mat.loc[mat.var(axis=1) > 0, :]

# One annotation colour per remaining column; unknown types fall back to gray.
col_colors_filtered = mat.columns.map(
    lambda ct: class_pal.get(type_to_class.get(ct), "#B0B0B0")
)

if mat.shape[0] >= 2 and mat.shape[1] >= 2:
    g = sns.clustermap(
        mat,
        cmap="viridis",
        col_colors=col_colors_filtered,
        figsize=(14, 10),
        dendrogram_ratio=(0.1, 0.2),
        cbar_pos=(0.02, 0.8, 0.03, 0.15)
    )

    # Legend entries only for classes that survived the filtering above. Attach the
    # legend and title to the clustermap's own figure instead of whichever axes
    # pyplot considers "current" after clustermap built its sub-axes.
    present_classes = {type_to_class.get(ct) for ct in mat.columns}
    patches = [
        mpatches.Patch(color=color, label=label)
        for label, color in class_pal.items()
        if label in present_classes
    ]
    g.figure.legend(handles=patches, title="Rare Class", bbox_to_anchor=(1.0, 0.98), loc="upper left")

    g.figure.suptitle("Clustered Heatmap with Rare Class Annotation", y=1.02, fontsize=16)
    plt.show()
else:
    print("Not enough variation to cluster.")


# ==============================================================================
# 6. PLOT 3: BARPLOT (MEAN Accuracy (%) PER METHOD x RARE CLASS)
# ==============================================================================
# Tag every accuracy row with its rare class, then average per (method, class).
rare_acc_classed = rare_acc.assign(Rare_Class=rare_acc["cell_type"].map(type_to_class))
grouped_acc = (
    rare_acc_classed
    .groupby(["method", "Rare_Class"], as_index=False)["Accuracy (%)"]
    .mean()
)

fig, ax = plt.subplots(figsize=(12, 6))
sns.barplot(
    data=grouped_acc,
    x="method",
    y="Accuracy (%)",
    hue="Rare_Class",
    palette=class_pal,
    ax=ax,
)

ax.set_title("Method Performance across Different Rare Types")
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
ax.legend(title="Rare Class", bbox_to_anchor=(1.05, 1), loc="upper left")
fig.tight_layout()
plt.show()


In [None]:
# Defensive: if an earlier cell left "Cell Type" as the index, turn it back into a column.
if "Cell Type" not in rare_info_df.columns:
    rare_info_df = rare_info_df.reset_index()

# Unique cell types not classified as "Common".
rare_types_list = (
    rare_info_df.loc[rare_info_df["Classification"] != "Common", "Cell Type"]
    .unique()
    .tolist()
)
print(f"Ready to plot {len(rare_types_list)} rare cell types.")

In [None]:
import scanpy as sc
import matplotlib.pyplot as plt
import math

# ==============================================================================
# 1. SETUP & FIX INDEX
# ==============================================================================
# Defensive: restore "Cell Type" as a column if a previous step made it the index.
if "Cell Type" not in rare_info_df.columns:
    rare_info_df = rare_info_df.reset_index()

# Unique rare cell types (everything not classified "Common").
rare_types_list = (
    rare_info_df.loc[rare_info_df["Classification"] != "Common", "Cell Type"]
    .unique()
    .tolist()
)

# Display name -> adata.obsm key holding that method's 2-D embedding.
method_embeddings = dict(
    scVI='X_umap_scVI',
    scDML='X_umap_scDML',
    scDisInFact='X_umap_scDisInFact',
    MRVI='X_umap_mrvi',
    Seurat='X_umap_seurat',
    scGen='X_umap_scGen',
    harmony='X_umap_harmony',
)

print(f"Generating plots for {len(rare_types_list)} rare cell types across {len(method_embeddings)} methods...")

# ==============================================================================
# 2. PLOTTING LOOP
# ==============================================================================

# The subplot grid depends only on the number of methods, so compute it once.
n_methods = len(method_embeddings)
cols = 4
rows = math.ceil(n_methods / cols)

if n_methods == 0:
    # plt.subplots(0, cols) would raise; there is nothing to draw anyway.
    print("No method embeddings configured; nothing to plot.")
else:
    for cell_type in rare_types_list:
        # Rare-class label shown in the figure title (first matching row).
        c_class = rare_info_df.loc[rare_info_df["Cell Type"] == cell_type, "Classification"].values[0]

        # squeeze=False keeps `axes` 2-D even for a single row, so flatten() is always valid.
        fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 4 * rows), squeeze=False)
        axes = axes.flatten()

        fig.suptitle(f"Target: {cell_type}\n({c_class})", fontsize=16, y=1.02, fontweight='bold')

        for ax, (method_name, obsm_key) in zip(axes, method_embeddings.items()):
            if obsm_key in adata.obsm.keys():
                # `basis` is given the full obsm key (e.g. 'X_umap_scVI');
                # `groups` highlights the target type and greys out the others.
                sc.pl.embedding(
                    adata,
                    basis=obsm_key,
                    color='celltype',
                    groups=[cell_type],
                    ax=ax,
                    show=False,
                    title=method_name,
                    frameon=False,
                    legend_loc=None,
                    s=20,
                    na_in_legend=False
                )
            else:
                ax.text(0.5, 0.5, f"{obsm_key}\nnot found", ha='center', va='center')
                ax.axis('off')

        # Hide the grid cells left over when n_methods is not a multiple of `cols`.
        for ax in axes[n_methods:]:
            ax.axis('off')

        plt.tight_layout()
        plt.show()
        # Release the figure explicitly; this loop creates one figure per rare type.
        plt.close(fig)

In [43]:
import pandas as pd
import plotly.graph_objects as go

def plot_method_sankey_highlight(
    adata,
    method_name: str,
    focus_cell_type: str,
    method_to_pred_key: dict,
    true_label_key: str = "celltype",
    min_count: int = 1,
    title: str | None = None,
    save_path: str | None = None,
    show: bool = True,
):
    """
    Sankey: True labels (left) -> Predicted labels (right)
    Link colors:
      - Green: focus true -> focus pred (correct)
      - Red:   focus true -> other pred (wrong)
      - Blue:  other true -> focus pred (incoming)
      - Gray:  everything else
    """
    if true_label_key not in adata.obs.columns:
        raise KeyError(f"true_label_key='{true_label_key}' not found in adata.obs")

    if method_name not in method_to_pred_key:
        raise KeyError(f"method_name='{method_name}' not in method_to_pred_key mapping")

    pred_key = method_to_pred_key[method_name]
    if pred_key not in adata.obs.columns:
        raise KeyError(f"pred_key='{pred_key}' not found in adata.obs for method '{method_name}'")

    df = adata.obs[[true_label_key, pred_key]].copy()
    df[true_label_key] = df[true_label_key].astype(str)
    df[pred_key] = df[pred_key].astype(str)

    # Build flow table
    flows = (
        df.groupby([true_label_key, pred_key])
          .size()
          .reset_index(name="count")
    )
    flows = flows[flows["count"] >= min_count].copy()

    # Keep all node labels (even if focus doesn't appear, you'll get a plot but no highlighted links)
    true_labels = sorted(df[true_label_key].unique().tolist())
    pred_labels = sorted(df[pred_key].unique().tolist())

    # Prefix to keep left/right nodes distinct even if same text
    left_nodes = [f"T: {x}" for x in true_labels]
    right_nodes = [f"P: {x}" for x in pred_labels]
    nodes = left_nodes + right_nodes

    node_index = {n: i for i, n in enumerate(nodes)}

    # Create sankey links
    sources = []
    targets = []
    values = []
    link_colors = []

    # Color constants (RGBA looks nice in Sankey)
    GRAY  = "rgba(180,180,180,0.45)"
    GREEN = "rgba(0,160,0,0.75)"
    RED   = "rgba(220,0,0,0.75)"
    BLUE  = "rgba(0,90,220,0.75)"

    for _, row in flows.iterrows():
        t = row[true_label_key]
        p = row[pred_key]
        c = int(row["count"])

        s_node = f"T: {t}"
        t_node = f"P: {p}"

        sources.append(node_index[s_node])
        targets.append(node_index[t_node])
        values.append(c)

        # Apply your highlighting rules
        if t == focus_cell_type and p == focus_cell_type:
            link_colors.append(GREEN)  # correct match out of focus type
        elif t == focus_cell_type and p != focus_cell_type:
            link_colors.append(RED)    # wrong match out of focus type
        elif t != focus_cell_type and p == focus_cell_type:
            link_colors.append(BLUE)   # incoming to focus type
        else:
            link_colors.append(GRAY)   # everything else

    if title is None:
        title = f"Sankey ({method_name}) — focus: {focus_cell_type}"

    fig = go.Figure(
        data=[
            go.Sankey(
                arrangement="snap",
                node=dict(
                    pad=15,
                    thickness=15,
                    line=dict(color="rgba(100,100,100,0.6)", width=0.5),
                    label=nodes
                ),
                link=dict(
                    source=sources,
                    target=targets,
                    value=values,
                    color=link_colors
                )
            )
        ]
    )
    fig.update_layout(title_text=title, font_size=10)

    if save_path:
        fig.write_html(save_path)

    if show:
        fig.show()

    return fig


In [None]:
# Method name -> adata.obs column with that method's predicted label
# (column names suggest a per-Leiden-cluster majority label — confirm upstream).
method_to_pred_key = dict(
    harmony="leiden_harmony_majority_label",
    mrvi="leiden_mrvi_u_majority_label",
    scDML="leiden_scDML_majority_label",
    scDisInFact="leiden_scDisInFact_majority_label",
    scVI="leiden_scVI_majority_label",
    seurat="leiden_seurat_majority_label",
    scGen="leiden_scGen_majority_label",
)


In [45]:
# Harmony predictions with ILC as the focus type.
# The output file name is distinct from the one written later by
# plot_sankey_for_method_and_celltype for the same method/focus, so neither overwrites the other.
# Trailing ';' stops Jupyter from rendering the returned Figure a second time
# (the function already calls fig.show() because show=True).
plot_method_sankey_highlight(
    adata=adata,
    method_name="harmony",
    focus_cell_type="ILC",
    method_to_pred_key=method_to_pred_key,
    true_label_key="celltype",
    save_path="./harmony_sankey_highlight_focus_ILC.html",
    show=True
);


In [46]:
# Harmony predictions with ILC1 as the focus type.
# Trailing ';' stops Jupyter from rendering the returned Figure a second time
# (the function already calls fig.show() because show=True).
plot_method_sankey_highlight(
    adata=adata,
    method_name="harmony",
    focus_cell_type="ILC1",
    method_to_pred_key=method_to_pred_key,
    true_label_key="celltype",
    save_path="./harmony_sankey_focus_ILC1.html",
    show=True
);


In [47]:
import pandas as pd
import plotly.graph_objects as go

def plot_sankey_for_method_and_celltype(
    adata,
    method_name: str,
    focus_cell_type: str,
    method_to_pred_key: dict,
    true_label_key: str = "celltype",
    min_count: int = 1,
    width: int = 1100,
    height: int = 650,
    title_font_size: int = 14,
    font_size: int = 11,
    save_html: str | None = None,
    save_png: str | None = None,
    png_scale: int = 2,
    show: bool = True,
):
    """
    Sankey: True labels (left) -> Predicted labels (right)
    Styling: gray nodes, mostly gray links; highlight links in green/red/blue.
    """

    if true_label_key not in adata.obs.columns:
        raise KeyError(f"true_label_key '{true_label_key}' not found in adata.obs")

    if method_name not in method_to_pred_key:
        raise KeyError(f"method '{method_name}' not in method_to_pred_key mapping")

    pred_key = method_to_pred_key[method_name]
    if pred_key not in adata.obs.columns:
        raise KeyError(f"pred_key '{pred_key}' not found in adata.obs for method '{method_name}'")

    df = adata.obs[[true_label_key, pred_key]].copy()
    df[true_label_key] = df[true_label_key].astype(str)
    df[pred_key] = df[pred_key].astype(str)

    # Flow counts
    flows = (
        df.groupby([true_label_key, pred_key])
          .size()
          .reset_index(name="count")
    )
    flows = flows[flows["count"] >= min_count].copy()

    # Build separate node lists for left and right (even if names overlap)
    true_labels = sorted(df[true_label_key].unique().tolist())
    pred_labels = sorted(df[pred_key].unique().tolist())

    # Internal unique node ids (avoid merging), but show plain labels like your screenshot
    left_ids  = [f"L::{x}" for x in true_labels]
    right_ids = [f"R::{x}" for x in pred_labels]
    node_ids = left_ids + right_ids

    node_index = {nid: i for i, nid in enumerate(node_ids)}

    sources, targets, values, link_colors = [], [], [], []

    # Colors (tuned to look like your screenshot)
    LINK_GRAY = "rgba(170,170,170,0.35)"
    GREEN = "rgba(0,170,0,0.85)"
    RED   = "rgba(220,0,0,0.85)"
    BLUE  = "rgba(0,90,220,0.85)"

    for _, r in flows.iterrows():
        t = r[true_label_key]
        p = r[pred_key]
        c = int(r["count"])

        s = node_index[f"L::{t}"]
        tg = node_index[f"R::{p}"]

        sources.append(s)
        targets.append(tg)
        values.append(c)

        # Highlight rules
        if t == focus_cell_type and p == focus_cell_type:
            link_colors.append(GREEN)          # correct from focus
        elif t == focus_cell_type and p != focus_cell_type:
            link_colors.append(RED)            # wrong from focus
        elif t != focus_cell_type and p == focus_cell_type:
            link_colors.append(BLUE)           # incoming to focus
        else:
            link_colors.append(LINK_GRAY)      # everything else gray

    # Node labels shown (plain, no prefixes) — like the screenshot
    node_labels = true_labels + pred_labels

    # Gray nodes (both sides)
    node_gray = "rgba(120,120,120,0.85)"

    title_text = f"Sankey: True '{true_label_key}' \u2192  Predicted '{pred_key}' (Highlighting: {focus_cell_type})"

    fig = go.Figure(
        data=[
            go.Sankey(
                arrangement="snap",
                node=dict(
                    pad=10,
                    thickness=12,
                    label=node_labels,
                    color=[node_gray] * len(node_labels),
                    line=dict(color="rgba(90,90,90,0.6)", width=0.5),
                ),
                link=dict(
                    source=sources,
                    target=targets,
                    value=values,
                    color=link_colors,
                ),
            )
        ]
    )

    # Layout to match the “old” look: centered title, white background, controlled size
    fig.update_layout(
        title=dict(
            text=title_text,
            x=0.5,
            xanchor="center",
            font=dict(size=title_font_size),
        ),
        width=width,
        height=height,
        font=dict(size=font_size),
        paper_bgcolor="white",
        plot_bgcolor="white",
        margin=dict(l=10, r=10, t=50, b=10),
    )

    if save_html:
        fig.write_html(save_html)

    # Optional high-res PNG (requires kaleido)
    if save_png:
        # pip install -U kaleido
        fig.write_image(save_png, scale=png_scale)

    if show:
        fig.show()

    return fig


In [None]:
# Method name -> adata.obs column with that method's predicted label
# (column names suggest a per-Leiden-cluster majority label — confirm upstream).
method_to_pred_key = dict(
    seurat="leiden_seurat_majority_label",
    harmony="leiden_harmony_majority_label",
    mrvi="leiden_mrvi_u_majority_label",
    scDML="leiden_scDML_majority_label",
    scDisInFact="leiden_scDisInFact_majority_label",
    scVI="leiden_scVI_majority_label",
    scGen="leiden_scGen_majority_label",
)

# Seurat predictions with ILC as the focus type.
# Trailing ';' stops Jupyter from rendering the returned Figure a second time
# (the function already calls fig.show() because show=True).
plot_sankey_for_method_and_celltype(
    adata=adata,
    method_name="seurat",
    focus_cell_type="ILC",
    method_to_pred_key=method_to_pred_key,
    true_label_key="celltype",
    width=1100,
    height=650,
    save_html="./seurat_sankey_focus_ILC.html",
    # save_png="./seurat_sankey_focus_ILC.png",  # uncomment if you want a high-res PNG (needs kaleido)
    png_scale=3,
    show=True,
);


In [51]:
# Harmony predictions with ILC as the focus type (styled variant).
# Trailing ';' stops Jupyter from rendering the returned Figure a second time
# (the function already calls fig.show() because show=True).
plot_sankey_for_method_and_celltype(
    adata=adata,
    method_name="harmony",
    focus_cell_type="ILC",
    method_to_pred_key=method_to_pred_key,
    true_label_key="celltype",
    width=1100,
    height=650,
    save_html="./harmony_sankey_focus_ILC.html",
    # save_png="./harmony_sankey_focus_ILC.png",  # uncomment if you want a high-res PNG (needs kaleido)
    png_scale=3,
    show=True,
);