In [1]:
import os

# Root folders: ground-truth label CSVs live under base_data; each
# "<source>-<target>" transfer run writes its predictions under base_pred.
base_data = r"C:\E\JSU\BIO\file\STransfer\STrafer\datanew\OV"
base_pred = r"C:\E\JSU\BIO\file\STransfer\STrafer\params\OV\STransfer"

# Ground-truth label file name for each platform (relative to base_data).
label_file = {
    "CosMx": "CosMx_labels.csv",
    "Xenium": "Xenium_labels.csv",
    "Stereo-seq": "Ste_labels.csv",
    "Visium HD": "Visium HD_labels.csv",
}

# (source, target) transfer directions to evaluate, in reporting order.
pairs_edges = [
    ("Visium HD", "CosMx"),
    ("Visium HD", "Xenium"),
    ("Visium HD", "Stereo-seq"),
    ("Xenium", "Visium HD"),
    ("Xenium", "CosMx"),
    ("Xenium", "Stereo-seq"),
    ("CosMx", "Visium HD"),
    ("CosMx", "Xenium"),
    ("CosMx", "Stereo-seq"),
    ("Stereo-seq", "Visium HD"),
    ("Stereo-seq", "Xenium"),
    ("Stereo-seq", "CosMx"),
]


def _transfer_paths(src, tgt):
    """Return (target ground-truth label CSV, target prediction CSV) for one src -> tgt run."""
    truth_csv = os.path.join(base_data, label_file[tgt])
    pred_csv = os.path.join(base_pred, f"{src}-{tgt}", f"{tgt}_target_pred_labels.csv")
    return truth_csv, pred_csv


# One 4-tuple per direction: (source, target, target label CSV, target prediction CSV).
pairs = [(src, tgt, *_transfer_paths(src, tgt)) for src, tgt in pairs_edges]


In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import adjusted_rand_score, accuracy_score 
# Global figure styling: Times New Roman text, STIX math glyphs, ASCII minus sign.
mpl.rcParams.update({
    'font.family': 'Times New Roman',
    'mathtext.fontset': 'stix',
    'axes.unicode_minus': False,
})
# seaborn's theme is applied afterwards; its `font` argument keeps Times New Roman.
sns.set_theme(style="white", font="Times New Roman")

# Platform names; used below as both the index and the columns of the score matrices.
all_slicename = ['CosMx', 'Xenium', 'Visium HD', 'Stereo-seq']

def read_pred(csv_path: str) -> pd.Series:
    """Load the ``pred`` column of a label/prediction CSV as a Series.

    If the CSV has an unnamed leading column (pandas names it "Unnamed: 0",
    e.g. an index written by ``DataFrame.to_csv``), that column becomes the
    index so rows can later be aligned by label. Otherwise the default
    RangeIndex is kept.

    Args:
        csv_path: Path (or any buffer ``pd.read_csv`` accepts) of the CSV file.

    Returns:
        The ``pred`` column.

    Raises:
        ValueError: If the file has no ``pred`` column. The message lists up
            to 20 of the columns that were found.
    """
    df = pd.read_csv(csv_path)
    if "Unnamed: 0" in df.columns:
        df = df.set_index("Unnamed: 0")
    if "pred" not in df.columns:
        raise ValueError(
            f"{csv_path} is missing required column 'pred'; "
            f"available columns (first 20): {list(df.columns)[:20]}"
        )
    return df["pred"]

def align_series(s: pd.Series, t: pd.Series, min_overlap: float = 0.2):
    """Pair two label Series element-wise so they can be scored against each other.

    Alignment by shared index labels is preferred. When the indexes share at
    least one label, and at least ``min_overlap * min(len(s), len(t))``
    labels, both Series are restricted to those shared labels in the same
    order. Otherwise the first ``min(len(s), len(t))`` elements of each are
    paired by position.

    Args:
        s: First label Series (e.g. ground truth).
        t: Second label Series (e.g. predictions).
        min_overlap: Minimum fraction of the shorter Series that must be
            covered by shared index labels before label-based alignment is
            trusted. Defaults to 0.2.

    Returns:
        ``(s_aligned, t_aligned)``. Positions in the two Series correspond to
        each other.

    Note:
        If either index contains duplicate labels, ``.loc`` returns every
        duplicate, so the two results can differ in length.
    """
    # A pandas Series always has an index, so no None check is needed.
    common = s.index.intersection(t.index)
    if len(common) > 0 and len(common) >= min(len(s), len(t)) * min_overlap:
        return s.loc[common], t.loc[common]

    # NOTE(review): positional pairing assumes both files list the same items
    # in the same order. Confirm this, otherwise the resulting scores are meaningless.
    n = min(len(s), len(t))
    return s.iloc[:n], t.iloc[:n]

def add_outer_border(ax, n_rows, n_cols, lw=2):
    """Frame an ``n_rows`` x ``n_cols`` grid on ``ax`` with an unfilled black rectangle.

    The rectangle spans data coordinates (0, 0) to (n_cols, n_rows). This
    assumes the axes use one data unit per grid cell, as with a seaborn
    heatmap; confirm for other plots. ``clip_on=False`` stops the border
    from being clipped at the axes edge.

    Args:
        ax: Matplotlib Axes to draw on.
        n_rows: Number of grid rows (rectangle height).
        n_cols: Number of grid columns (rectangle width).
        lw: Border line width.
    """
    frame = plt.Rectangle(
        xy=(0, 0),
        width=n_cols,
        height=n_rows,
        fill=False,
        edgecolor="black",
        linewidth=lw,
        clip_on=False,
    )
    ax.add_patch(frame)

# =======================
# Score every source -> target transfer
# =======================
# Rows are target platforms and columns are source platforms. Cells with no
# evaluated pair (e.g. the diagonal) stay NaN and are masked in the heatmaps.
acc_mat = pd.DataFrame(np.nan, index=all_slicename, columns=all_slicename)
ari_mat = pd.DataFrame(np.nan, index=all_slicename, columns=all_slicename)

for src_name, tgt_name, label_csv, pred_csv in pairs:
    # Each entry in `pairs` is (source, target, target ground-truth CSV, target
    # prediction CSV), so the first file holds the reference labels.
    # NOTE(review): read_pred requires a 'pred' column in the ground-truth
    # label CSV as well. Confirm the label files use that column name.
    y_true, y_pred = align_series(read_pred(label_csv), read_pred(pred_csv))
    if len(y_true) == 0:
        raise ValueError(
            f"{src_name} -> {tgt_name}: no rows left to compare after aligning "
            f"{label_csv} with {pred_csv}"
        )

    # Compare labels as strings so int-coded and str-coded classes match.
    true_lab = y_true.astype(str).to_numpy()
    pred_lab = y_pred.astype(str).to_numpy()

    acc = accuracy_score(true_lab, pred_lab)
    ari = adjusted_rand_score(true_lab, pred_lab)

    acc_mat.loc[tgt_name, src_name] = acc
    ari_mat.loc[tgt_name, src_name] = ari

    print(f"{src_name} -> {tgt_name}:  ACC={acc:.4f}  ARI={ari:.4f}  n={len(true_lab)}")

# =======================
# Heatmaps
# =======================
# Figures go next to the per-pair prediction folders (same path as base_pred).
out_dir = base_pred
SAVE_FIGURES = False  # set True to also write each heatmap as a PNG into out_dir
os.makedirs(out_dir, exist_ok=True)

grid_lw = 0.1
border_lw = 0.8
cbar_ticks = [0, 0.2, 0.4, 0.6, 0.8, 1.0]


def _plot_metric_heatmap(mat, title, filename=None):
    """Draw a target x source score matrix as an annotated heatmap on a fixed 0-1 scale.

    NaN cells (pairs that were not evaluated) are masked. If ``SAVE_FIGURES``
    is True and ``filename`` is given, the figure is also saved into ``out_dir``.
    """
    fig, ax = plt.subplots(figsize=(7, 5))
    sns.heatmap(
        mat, annot=True, fmt=".3f",
        cmap="Blues",
        # Fixed 0-1 scale keeps both panels comparable. ARI can be negative;
        # such cells get the lowest colour but keep their true annotation.
        vmin=0, vmax=1,
        linewidths=grid_lw, linecolor="black",
        mask=mat.isna(),
        cbar_kws={"ticks": cbar_ticks},
        ax=ax,
    )
    add_outer_border(ax, mat.shape[0], mat.shape[1], lw=border_lw)
    ax.set_title(title)
    ax.set_xlabel("Source")
    ax.set_ylabel("Target")
    ax.tick_params(axis="x", labelrotation=0)
    ax.tick_params(axis="y", labelrotation=0)
    fig.tight_layout()
    if SAVE_FIGURES and filename:
        fig.savefig(os.path.join(out_dir, filename), dpi=300, bbox_inches="tight")
    plt.show()


# Classification accuracy heatmap (0~1)
_plot_metric_heatmap(acc_mat, "CA Heatmap", "CA_heatmap.png")
# Adjusted Rand index heatmap
_plot_metric_heatmap(ari_mat, "ARI Heatmap", "ARI_heatmap.png")