In [None]:
import os
import scanpy as sc
import scib
import math
import pandas as pd
from tqdm import TqdmWarning
import warnings

# Silence every TqdmWarning for the rest of the session; no message or
# module pattern is needed, so the simple (category-only) filter suffices.
warnings.simplefilter("ignore", category=TqdmWarning)

# ================================================
# Helper functions
# ================================================
def safe_mean(values):
    """
    Average the entries of `values` that are not NaN.

    NaN entries are discarded before averaging. When nothing remains
    (the input is empty or consists only of NaN), NaN is returned.
    """
    kept = list(filter(lambda item: not math.isnan(item), values))
    if len(kept) == 0:
        return float('nan')
    return sum(kept) / len(kept)


def safe_round(x, d=4):
    """
    Round `x` to `d` decimal places (4 by default).

    A NaN input is passed through as NaN instead of being rounded.
    """
    if math.isnan(x):
        return float('nan')
    return round(x, d)


# ================================================
# Calculate scIB metrics for a single h5ad file
# ================================================
def calculate_metrics_for_h5ad(
    file_path,
    *,
    batch_key="batch",
    label_key="cell_type",
    embed="latent",
):
    """
    Read an h5ad file, compute single-cell integration metrics using scIB,
    and return a list of metrics including a weighted score.

    Parameters
    ----------
    file_path
        Path of the h5ad file to evaluate.
    batch_key
        Column holding batch information (keyword-only, default "batch").
    label_key
        Column holding cell type labels (keyword-only, default "cell_type").
    embed
        Name of the representation used both for the neighbor graph and
        for the scIB metrics (keyword-only, default "latent").

    Returns
    -------
    list
        ``[model name, weighted score, batch correction mean,
        biological conservation mean, ASW_label/batch, iLISI, kBET,
        NMI_cluster/label, ARI_cluster/label, ASW_label]``.
        All numeric entries are rounded with `safe_round`; a metric that
        is missing or null in the scIB result is reported as NaN.
    """
    # Load h5ad file
    adata = sc.read(file_path)

    # Build neighbor graph using the selected representation
    sc.pp.neighbors(adata, use_rep=embed)

    # Get dataset/model name from the file name. Only a trailing ".h5ad"
    # is removed, so a name that merely contains ".h5ad" elsewhere is
    # left intact.
    suffix = '.h5ad'
    model_name = os.path.basename(file_path)
    if model_name.endswith(suffix):
        model_name = model_name[:-len(suffix)]

    # Compute scIB metrics (the same object is passed as both the
    # unintegrated and the integrated data)
    result = scib.metrics.metrics(
        adata=adata,
        adata_int=adata,
        batch_key=batch_key,    # column for batch information
        label_key=label_key,    # column for cell type labels
        embed=embed,            # representation to evaluate
        silhouette_=True,
        nmi_=True,
        ari_=True,
        kBET_=True,
        ilisi_=True
    )

    # Select the first column of the result (as a Series); individual
    # metrics are looked up in it by name below.
    row = result.iloc[:, 0]

    # Define the order of metrics
    metrics_order = [
        "ASW_label/batch",    # batch correction metric
        "iLISI",              # batch correction metric
        "kBET",               # batch correction metric
        "NMI_cluster/label",  # biological conservation metric
        "ARI_cluster/label",  # biological conservation metric
        "ASW_label",          # biological conservation metric
    ]

    # Extract metric values; missing or null entries become NaN
    values = []
    for metric in metrics_order:
        val = row.get(metric, float('nan'))
        val = float(val) if pd.notnull(val) else float('nan')
        values.append(val)

    # Compute mean batch correction and biological conservation scores
    mean1 = safe_mean(values[0:3])  # batch correction average
    mean2 = safe_mean(values[3:6])  # biological conservation average

    # Weighted average: overall score = 0.4 * batch + 0.6 * biology.
    # If either group average is NaN the overall score is NaN as well.
    if math.isnan(mean1) or math.isnan(mean2):
        weighted_avg = float('nan')
    else:
        weighted_avg = mean1 * 0.4 + mean2 * 0.6

    # Construct the final row: model name, weighted score, batch mean,
    # biology mean, individual metrics
    final_row = [model_name, safe_round(weighted_avg), safe_round(mean1), safe_round(mean2)]
    final_row += [safe_round(x) for x in values]

    return final_row


# ================================================
# Evaluate a single h5ad file and save results
# ================================================
def evaluate_single_h5ad(input_file: str, output_file: str):
    """
    Compute scIB metrics for a single h5ad file and save the results to a
    tab-separated text file.

    Parameters
    ----------
    input_file
        Path of the h5ad file passed to `calculate_metrics_for_h5ad`.
    output_file
        Path of the tab-separated file to write. Its parent directory is
        created if it does not exist yet.

    Returns
    -------
    None
        The result is written to `output_file`; progress is printed.
    """
    # Create the destination directory first so that a missing folder
    # cannot make the write fail after the metrics have been computed.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    print(f"Processing: {input_file}")

    # Calculate metrics for the input file
    row = calculate_metrics_for_h5ad(input_file)

    # Define output column names (same order as the returned row)
    columns = [
        "Model", "Score", "Batch Correction", "Biological Conservation",
        "ASW_label/batch", "iLISI", "kBET",
        "NMI_cluster/label", "ARI_cluster/label", "ASW_label"
    ]

    # Save the single result row as a tab-separated file without an index
    df = pd.DataFrame([row], columns=columns)
    df.to_csv(output_file, sep="\t", index=False)

    print(f"Done! File saved to: {output_file}")

In [None]:
# Metrics for the `neurips-multiome-multigai` result file.
input_file, output_file = (
    "./results/neurips-multiome-multigai.h5ad",
    "./data/metric/multiome-multigai.txt",
)
evaluate_single_h5ad(input_file=input_file, output_file=output_file)

In [None]:
# Metrics for the `neurips-cite-multigai` result file.
input_file, output_file = (
    "./results/neurips-cite-multigai.h5ad",
    "./data/metric/cite-multigai.txt",
)
evaluate_single_h5ad(input_file=input_file, output_file=output_file)

In [None]:
# Metrics for the `neurips-multiome-multivi` result file.
input_file, output_file = (
    "./results/neurips-multiome-multivi.h5ad",
    "./data/metric/multiome-multivi.txt",
)
evaluate_single_h5ad(input_file=input_file, output_file=output_file)

In [None]:
# Metrics for the `neurips-cite-totalvi` result file.
input_file, output_file = (
    "./results/neurips-cite-totalvi.h5ad",
    "./data/metric/cite-totalvi.txt",
)
evaluate_single_h5ad(input_file=input_file, output_file=output_file)

In [None]:
# Metrics for the `neurips-multiome-mofa` result file.
input_file, output_file = (
    "./results/neurips-multiome-mofa.h5ad",
    "./data/metric/multiome-mofa.txt",
)
evaluate_single_h5ad(input_file=input_file, output_file=output_file)

In [None]:
# Metrics for the `neurips-cite-mofa` result file.
input_file, output_file = (
    "./results/neurips-cite-mofa.h5ad",
    "./data/metric/cite-mofa.txt",
)
evaluate_single_h5ad(input_file=input_file, output_file=output_file)

In [None]:
# Metrics for the `neurips-multiome-mima` result file.
input_file, output_file = (
    "./results/neurips-multiome-mima.h5ad",
    "./data/metric/multiome-mima.txt",
)
evaluate_single_h5ad(input_file=input_file, output_file=output_file)

In [None]:
# Metrics for the `neurips-cite-mima` result file.
input_file, output_file = (
    "./results/neurips-cite-mima.h5ad",
    "./data/metric/cite-mima.txt",
)
evaluate_single_h5ad(input_file=input_file, output_file=output_file)

In [None]:
# Metrics for the `neurips-multiome-seurat` result file.
input_file, output_file = (
    "./results/neurips-multiome-seurat.h5ad",
    "./data/metric/multiome-seurat.txt",
)
evaluate_single_h5ad(input_file=input_file, output_file=output_file)

In [None]:
# Metrics for the `neurips-cite-seurat` result file.
input_file, output_file = (
    "./results/neurips-cite-seurat.h5ad",
    "./data/metric/cite-seurat.txt",
)
evaluate_single_h5ad(input_file=input_file, output_file=output_file)

In [None]:
# Metrics for the `neurips-multiome-gluer` result file.
input_file, output_file = (
    "./results/neurips-multiome-gluer.h5ad",
    "./data/metric/multiome-gluer.txt",
)
evaluate_single_h5ad(input_file=input_file, output_file=output_file)

# Metrics for the `neurips-multiome-gluea` result file.
input_file, output_file = (
    "./results/neurips-multiome-gluea.h5ad",
    "./data/metric/multiome-gluea.txt",
)
evaluate_single_h5ad(input_file=input_file, output_file=output_file)

In [None]:
# Metrics for the `neurips-multiome-multigain` result file.
input_file, output_file = (
    "./results/neurips-multiome-multigain.h5ad",
    "./data/metric/multiome-multigain.txt",
)
evaluate_single_h5ad(input_file=input_file, output_file=output_file)

In [None]:
# Metrics for the `neurips-cite-multigain` result file.
input_file, output_file = (
    "./results/neurips-cite-multigain.h5ad",
    "./data/metric/cite-multigain.txt",
)
evaluate_single_h5ad(input_file=input_file, output_file=output_file)