In [2]:
import pandas as pd
import os
import numpy as np
import scanpy as sc
from scipy.spatial.distance import cdist
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import wasserstein_distance, pearsonr
from warnings import catch_warnings, simplefilter
np.random.seed(42)

def compute_mmd(pred_data, true_data, kernel='rbf', gamma=1.0):
    if kernel == 'rbf':
        dist_pred = cdist(pred_data, pred_data, metric='sqeuclidean')
        dist_truth = cdist(true_data, true_data, metric='sqeuclidean')
        dist_cross = cdist(pred_data, true_data, metric='sqeuclidean')

        Kxx = np.exp(-gamma * dist_pred)
        Kyy = np.exp(-gamma * dist_truth)
        Kxy = np.exp(-gamma * dist_cross)

        return np.mean(Kxx) + np.mean(Kyy) - 2 * np.mean(Kxy)
    else:
        raise ValueError("Unsupported kernel type. Use 'rbf'.")

def compute_wasserstein(pred_data, true_data):
    return wasserstein_distance(pred_data.flatten(), true_data.flatten())

def calculate_metrics_all(pred_data, true_data, ctrl_data):

    if isinstance(true_data, pd.DataFrame):
        true_data = true_data.values
    if isinstance(pred_data, pd.DataFrame):
        pred_data = pred_data.values
    if isinstance(ctrl_data, pd.DataFrame):
        ctrl_data = ctrl_data.values

    metrics = {}

    mean_true = np.mean(true_data, axis=0)
    mean_pred = np.mean(pred_data, axis=0)
    mean_ctrl = np.mean(ctrl_data, axis=0)

    with catch_warnings():
        simplefilter("ignore")
        try:
            metrics['R_squared'] = r2_score(mean_true, mean_pred)
            metrics['R_squared_delta'] = r2_score(mean_true - mean_ctrl, mean_pred - mean_ctrl)
        except Exception:
            metrics['R_squared'] = np.nan
            metrics['R_squared_delta'] = np.nan

        try:
            corr, _ = pearsonr(mean_true, mean_pred)
            metrics['Pearson_Correlation'] = corr
            corr_delta, _ = pearsonr(mean_true - mean_ctrl, mean_pred - mean_ctrl)
            metrics['Pearson_Correlation_delta'] = corr_delta
        except Exception:
            metrics['Pearson_Correlation'] = np.nan
            metrics['Pearson_Correlation_delta'] = np.nan

    mse = mean_squared_error(mean_true, mean_pred)
    mse_delta = mean_squared_error(mean_true - mean_ctrl, mean_pred - mean_ctrl)
    metrics.update({
        'MSE': mse,
        'RMSE': np.sqrt(mse),
        'MAE': mean_absolute_error(mean_true, mean_pred),
        'Cosine_Similarity': cosine_similarity([mean_true], [mean_pred])[0, 0],
        'L2': np.linalg.norm(mean_true - mean_pred),
        'MSE_delta': mse_delta,
        'RMSE_delta': np.sqrt(mse_delta),
        'MAE_delta': mean_absolute_error(mean_true - mean_ctrl, mean_pred - mean_ctrl),
        'Cosine_Similarity_delta': cosine_similarity(
            [(mean_true - mean_ctrl)], [(mean_pred - mean_ctrl)]
        )[0, 0],
        'L2_delta': np.linalg.norm((mean_true - mean_ctrl) - (mean_pred - mean_ctrl)),
    })

    return metrics

results_df = pd.DataFrame(columns=[
    'dataset', 'condition', 'model',
    'R_squared', 'R_squared_delta',
    'Pearson_Correlation', 'Pearson_Correlation_delta',
    'MSE', 'MSE_delta',
    'RMSE', 'RMSE_delta',
    'MAE', 'MAE_delta',
    'Cosine_Similarity', 'Cosine_Similarity_delta',
    'L2', 'L2_delta',
    'MMD','Wasserstein'
])

datasets = [
    'dixitregev2016',
    "normanweissman2019_filtered",
    'sunshine2023_crispri_sarscov2',
]


scFoundation_result = {}

for dataset in datasets:
    scFoundation_pred_paths = [
        f'/data2/lanxiang/perturb_benchmarking/Task_pred_out/scfoundation/{dataset}',
        f'/data2/lanxiang/perturb_benchmarking/tidy_data/task3_data/seen_1/scfoundation/{dataset}',
        f'/data2/lanxiang/perturb_benchmarking/tidy_data/task3_data/seen_2/scfoundation/{dataset}',
    ]
    true_path = f'/data2/lanxiang/perturb_benchmarking/Task_pred_out/True_csv/{dataset}'
    ctrl_path = f'/data2/lanxiang/perturb_benchmarking/Task_pred_out/Ctrl_csv/{dataset}/ctrl_data.csv'
    index_path = '/data/yy_data/sc_model/scFoundation-main/OS_scRNA_gene_index.19264.tsv'
    index_data = pd.read_csv(index_path, sep='\t')
    scfoundation_index = index_data['gene_name'].tolist()

    ctrl_data = pd.read_csv(ctrl_path, index_col=0)

    dataset_results = []

    for scFoundation_pred_path in scFoundation_pred_paths:
        if not os.path.exists(scFoundation_pred_path):
            print(f"路径 {scFoundation_pred_path} 不存在，跳过。")
            continue

        for file_name in os.listdir(scFoundation_pred_path):
            if file_name.endswith('.npz') and '+' in file_name:
           
                npz_file_path = os.path.join(scFoundation_pred_path, file_name)
                with np.load(npz_file_path) as npz:
                    if 'pred' in npz:
                        pred_data = npz['pred']
                        pred_data = pd.DataFrame(pred_data, columns=scfoundation_index)

                        true_file_path = os.path.join(true_path, file_name.replace('.npz', '.csv'))
                        true_data = pd.read_csv(true_file_path, index_col=0)
                        common_cols = list(set(scfoundation_index).intersection(true_data.columns))
                        true_data = true_data[common_cols]
                        pred_data = pred_data[common_cols].reindex(columns=true_data.columns)
                        ctrl_data = ctrl_data[common_cols].reindex(columns=true_data.columns)

                        if pred_data.shape[0] > 300:
                            pred_data = pred_data.iloc[np.random.choice(pred_data.shape[0], 300, replace=False), :]
                        if true_data.shape[0] > 300:
                            true_data = true_data.iloc[np.random.choice(true_data.shape[0], 300, replace=False), :]
                        if ctrl_data.shape[0] > 300:
                            ctrl_data = ctrl_data.iloc[np.random.choice(ctrl_data.shape[0], 300, replace=False), :]

                        metrics = calculate_metrics_all(pred_data, true_data, ctrl_data)
                        mmd_per_gene = []
                        for gene_idx in range(pred_data.shape[1]):
                            gene_pred = pred_data.iloc[:, gene_idx].values.reshape(-1, 1)
                            gene_truth = true_data.iloc[:, gene_idx].values.reshape(-1, 1)
                            mmd = compute_mmd(gene_pred, gene_truth, kernel='rbf', gamma=1.0)
                            mmd_per_gene.append(mmd)

                        average_mmd = np.mean(mmd_per_gene)

                        ws_per_gene = []
                        for gene_idx in range(pred_data.shape[1]):
                            gene_pred = pred_data.iloc[:, gene_idx].values.reshape(-1, 1)
                            gene_truth = true_data.iloc[:, gene_idx].values.reshape(-1, 1)
                            ws = compute_wasserstein(gene_pred, gene_truth)
                            ws_per_gene.append(ws)

                        average_ws = np.mean(ws_per_gene)
                        metrics['MMD'] = average_mmd
                        metrics['Wasserstein'] = average_ws
                        metrics['condition'] = file_name.replace('.npz', '')
                        metrics['dataset'] = dataset
                        metrics["model"] = "scFoundation"

                
                        results_df = pd.concat([results_df, pd.DataFrame([metrics])], ignore_index=True)

print(results_df)
results_df.to_csv('/task2/scFoundation_results.csv', index=False)


                          dataset        condition         model  R_squared  \
0                  dixitregev2016        IRF1+ETS1  scFoundation   0.989105   
1                  dixitregev2016       NR2C2+ETS1  scFoundation   0.989920   
2                  dixitregev2016       NR2C2+IRF1  scFoundation   0.987741   
3                  dixitregev2016        EGR1+ETS1  scFoundation   0.991350   
4                  dixitregev2016        ELK1+ETS1  scFoundation   0.992096   
..                            ...              ...           ...        ...   
76  sunshine2023_crispri_sarscov2        COG6+COG5  scFoundation   0.990703   
77  sunshine2023_crispri_sarscov2       CCZ1+CCZ1B  scFoundation   0.989263   
78  sunshine2023_crispri_sarscov2  SLC35B2+B3GALT6  scFoundation   0.985408   
79  sunshine2023_crispri_sarscov2    GPR89B+GPR89A  scFoundation   0.980497   
80  sunshine2023_crispri_sarscov2        RELB+RELA  scFoundation   0.989796   

    R_squared_delta  Pearson_Correlation  Pearson_C