In this tutorial, we will use the Wilcoxon signed-rank test to check whether our results are statistically significant.

To prepare, you need the PCC (Pearson correlation coefficient) results of every model you want to test, each saved in `.npy` format.

In [None]:
import numpy as np
from scipy.stats import wilcoxon
import matplotlib.pyplot as plt
import os

# --------------------------
# Core functions
# --------------------------




In [None]:
def _fill_nan_with_row_mean(matrix):
    """Return *matrix* with every NaN replaced by the NaN-ignoring mean of its row."""
    nan_mask = np.isnan(matrix)
    if not nan_mask.any():
        return matrix
    row_means = np.nanmean(matrix, axis=1, keepdims=True)
    return np.where(nan_mask, row_means, matrix)


def _load_pcc_matrix(method, path, expected_shape=(32, 785), drop_last_rows=0):
    """Load one method's PCC matrix from *path*, validate its shape and fill NaNs."""
    if not os.path.exists(path):
        raise FileNotFoundError(f"PCC result of {method} doesn't exist: {path}")
    matrix = np.load(path)
    if drop_last_rows:
        matrix = matrix[:-drop_last_rows]
    if matrix.shape != expected_shape:
        shape_text = "×".join(str(dim) for dim in expected_shape)
        raise ValueError(f"shape of {method} PCC result must be {shape_text}")
    return _fill_nan_with_row_mean(matrix)


def load_and_aggregate_data(reg2st_path, baseline_paths, agg_func=np.median):
    """Load PCC matrices for Reg2ST and the baselines and reduce each row to one score.

    Args:
        reg2st_path: Path to the Reg2ST ``.npy`` file. Its last two rows are
            discarded before validation, so the file must hold 34×785 values.
            NOTE(review): the two trailing rows are presumably summary rows;
            confirm against the script that writes the file.
        baseline_paths: Mapping of baseline method name -> path to its
            ``.npy`` file, each of shape 32×785.
        agg_func: Reduction called as ``agg_func(matrix, axis=1)``; it must
            return one value per row (rows are presumably slides).

    Returns:
        dict mapping ``"Reg2ST"`` and every baseline name to a length-32 array
        of aggregated scores, in insertion order (Reg2ST first).

    Raises:
        FileNotFoundError: If any of the given paths does not exist.
        ValueError: If a loaded matrix does not have shape 32×785.
    """
    # NaN entries are imputed with the row mean for every method (Reg2ST
    # included), so a single missing value cannot turn a whole row's
    # aggregate into NaN.
    reg2st_matrix = _load_pcc_matrix("Reg2ST", reg2st_path, drop_last_rows=2)
    data = {"Reg2ST": agg_func(reg2st_matrix, axis=1)}

    for method, path in baseline_paths.items():
        data[method] = agg_func(_load_pcc_matrix(method, path), axis=1)

    return data

def perform_statistical_tests(data_dict, alpha=0.05, alternative="greater"):
    """Compare Reg2ST against every baseline with a paired Wilcoxon signed-rank test.

    Args:
        data_dict: Mapping of method name -> 1-D sequence of paired scores.
            Must contain the key ``"Reg2ST"``; every other key is treated as
            a baseline.
        alpha: Family-wise significance level before correction.
        alternative: Passed through to ``scipy.stats.wilcoxon``; the tested
            sample is ``Reg2ST - baseline``.

    Returns:
        A list with one dict per successfully tested baseline, holding the
        keys ``method``, ``median_diff``, ``p_value``, ``significant`` and
        ``adjusted_alpha``. Empty when there are no baselines.

    Raises:
        KeyError: If ``"Reg2ST"`` is missing from *data_dict*.
        ValueError: If a baseline's length differs from Reg2ST's.
    """
    reg2st_scores = np.asarray(data_dict["Reg2ST"])
    baseline_methods = [name for name in data_dict if name != "Reg2ST"]
    if not baseline_methods:
        # Nothing to compare against; also avoids dividing alpha by zero.
        return []

    # Bonferroni correction over all requested comparisons (including any
    # that later fail inside wilcoxon, which keeps the threshold conservative).
    adjusted_alpha = alpha / len(baseline_methods)

    results = []
    for method in baseline_methods:
        baseline_scores = np.asarray(data_dict[method])

        if len(reg2st_scores) != len(baseline_scores):
            raise ValueError(f"{method}数据长度不匹配")

        paired_diff = reg2st_scores - baseline_scores

        try:
            _, p_value = wilcoxon(
                paired_diff,
                alternative=alternative,
                zero_method="pratt",
            )
        except ValueError as e:
            # wilcoxon rejects some inputs; skip this baseline but keep
            # testing the others.
            print(f"Warning: {method} Failed: ({str(e)})")
            continue

        results.append({
            "method": method,
            "median_diff": np.median(paired_diff),
            "p_value": p_value,
            "significant": p_value < adjusted_alpha,
            "adjusted_alpha": adjusted_alpha,
        })

    return results

def visualize_results(data_dict, results):
    """Show a box plot of the per-slide scores of Reg2ST and each tested baseline.

    Args:
        data_dict: Mapping of method name -> 1-D sequence of scores. Must
            contain ``"Reg2ST"`` and every method named in *results*.
        results: Either a mapping keyed by baseline method name (values are
            not used), or a sequence of result dicts that each carry a
            ``"method"`` key, as returned by ``perform_statistical_tests``.
    """
    if hasattr(results, "keys"):
        baseline_names = list(results.keys())
    else:
        baseline_names = [res["method"] for res in results]

    methods = ["Reg2ST"] + baseline_names
    scores = [data_dict[name] for name in methods]

    plt.figure(figsize=(10, 6))
    # boxplot's `labels` keyword was renamed to `tick_labels` in Matplotlib
    # 3.9; labelling the ticks separately works on both old and new versions.
    plt.boxplot(scores, showmeans=True, patch_artist=True)
    # Boxes are drawn at x = 1..N by default.
    plt.xticks(range(1, len(methods) + 1), methods)
    plt.ylabel("Median PCC per Slide", fontsize=12)
    plt.title(
        f"Model Performance Comparison (n={len(data_dict['Reg2ST'])} slides)",
        fontsize=14,
    )
    plt.grid(True, linestyle="--", alpha=0.7)
    plt.show()


In [None]:
if __name__ == "__main__":
    # Change these paths to point at your own PCC result files.
    DATA_PATHS = {
        "Reg2ST": "her2st_pcc_final.npy",
        "HisToGene": "../model-abi/histogenepccherst.npy",
        "Hist2ST": "../model-abi/histtostherstpcc.npy",
        "THItoGene": "../model-abi/thi2geneherstpcc.npy",
        "HGGEP": "../HGGEP/her2st_pcc.npy",
    }

    try:
        data = load_and_aggregate_data(
            reg2st_path=DATA_PATHS["Reg2ST"],
            baseline_paths={k: v for k, v in DATA_PATHS.items() if k != "Reg2ST"},
            agg_func=np.median,
        )

        test_results = perform_statistical_tests(data)

        # Significance is judged against the Bonferroni-adjusted threshold
        # stored in each result, so the header reports that value rather than
        # the nominal 0.05.
        sig_header = "Significant"
        if test_results:
            sig_header = f"Sig (α<{test_results[0]['adjusted_alpha']:.4g})"

        # Column widths match the row format below so the table lines up.
        header = (
            f"{'Method':<15} | {'Median Diff':>12} | "
            f"{'Wilcoxon p-value':>16} | {sig_header:>14}"
        )

        print("\nStatistical Significance Tests:")
        print("=" * len(header))
        print(header)
        print("-" * len(header))
        for res in test_results:
            sig_flag = "Yes" if res["significant"] else "No"
            # Scientific notation keeps very small p-values from printing
            # as 0.0000.
            print(
                f"{res['method']:<15} | "
                f"{res['median_diff']:>12.4f} | "
                f"{res['p_value']:>16.3e} | "
                f"{sig_flag:>14}"
            )

        visualize_results(data, {res["method"]: res for res in test_results})

    except Exception as e:
        # Top-level boundary of the script: report the problem instead of
        # dumping a traceback into the notebook output.
        print(f"ERROR: {str(e)}")