In [None]:
import pandas as pd
import os
import numpy as np
import scanpy as sc
from scipy.spatial.distance import cdist
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import wasserstein_distance, pearsonr
from warnings import catch_warnings, simplefilter

np.random.seed(42)

def compute_mmd(pred_data, true_data, kernel='rbf', gamma=1.0):
    if kernel == 'rbf':
        dist_pred = cdist(pred_data, pred_data, metric='sqeuclidean')
        dist_truth = cdist(true_data, true_data, metric='sqeuclidean')
        dist_cross = cdist(pred_data, true_data, metric='sqeuclidean')

        Kxx = np.exp(-gamma * dist_pred)
        Kyy = np.exp(-gamma * dist_truth)
        Kxy = np.exp(-gamma * dist_cross)

        return np.mean(Kxx) + np.mean(Kyy) - 2 * np.mean(Kxy)
    else:
        raise ValueError("Unsupported kernel type. Use 'rbf'.")

def compute_wasserstein(pred_data, true_data):
    return wasserstein_distance(pred_data.flatten(), true_data.flatten())

def calculate_metrics_all(pred_data, true_data, ctrl_data):
   
    if isinstance(true_data, pd.DataFrame):
        true_data = true_data.values
    if isinstance(pred_data, pd.DataFrame):
        pred_data = pred_data.values
    if isinstance(ctrl_data, pd.DataFrame):
        ctrl_data = ctrl_data.values

    metrics = {}

    mean_true = np.mean(true_data, axis=0)
    mean_pred = np.mean(pred_data, axis=0)
    mean_ctrl = np.mean(ctrl_data, axis=0)

    with catch_warnings():
        simplefilter("ignore")
        try:
            metrics['R_squared'] = r2_score(mean_true, mean_pred)
            metrics['R_squared_delta'] = r2_score(mean_true - mean_ctrl, mean_pred - mean_ctrl)
        except Exception:
            metrics['R_squared'] = np.nan
            metrics['R_squared_delta'] = np.nan

        try:
            corr, _ = pearsonr(mean_true, mean_pred)
            metrics['Pearson_Correlation'] = corr
            corr_delta, _ = pearsonr(mean_true - mean_ctrl, mean_pred - mean_ctrl)
            metrics['Pearson_Correlation_delta'] = corr_delta
        except Exception:
            metrics['Pearson_Correlation'] = np.nan
            metrics['Pearson_Correlation_delta'] = np.nan

    mse = mean_squared_error(mean_true, mean_pred)
    mse_delta = mean_squared_error(mean_true - mean_ctrl, mean_pred - mean_ctrl)
    metrics.update({
        'MSE': mse,
        'RMSE': np.sqrt(mse),
        'MAE': mean_absolute_error(mean_true, mean_pred),
        'Cosine_Similarity': cosine_similarity([mean_true], [mean_pred])[0, 0],
        'L2': np.linalg.norm(mean_true - mean_pred),
        'MSE_delta': mse_delta,
        'RMSE_delta': np.sqrt(mse_delta),
        'MAE_delta': mean_absolute_error(mean_true - mean_ctrl, mean_pred - mean_ctrl),
        'Cosine_Similarity_delta': cosine_similarity(
            [(mean_true - mean_ctrl)], [(mean_pred - mean_ctrl)]
        )[0, 0],
        'L2_delta': np.linalg.norm((mean_true - mean_ctrl) - (mean_pred - mean_ctrl)),
    })

    return metrics

results_df = pd.DataFrame(columns=[
    'dataset', 'condition', 'model',
    'R_squared', 'R_squared_delta',
    'Pearson_Correlation', 'Pearson_Correlation_delta',
    'MSE', 'MSE_delta',
    'RMSE', 'RMSE_delta',
    'MAE', 'MAE_delta',
    'Cosine_Similarity', 'Cosine_Similarity_delta',
    'L2', 'L2_delta',
    'MMD','Wasserstein'
])

datasets = [
    'datlingerbock2017',
    'datlingerbock2021',
    'dixitregev2016',
    'frangiehizar2021_rna',
    "normanweissman2019_filtered",
    'papalexisatija2021_eccite_rna',
    'replogleweissman2022_rpe1',
    'sunshine2023_crispri_sarscov2',
    "tiankampmann2021_crispra",
    "tiankampmann2021_crispri"
]


for dataset in datasets:
    true_path = f'/data2/lanxiang/perturb_benchmarking/Task_pred_out/True_csv/{dataset}'
    ctrl_path = f'/data2/lanxiang/perturb_benchmarking/Task_pred_out/Ctrl_csv/{dataset}/ctrl_data.csv'

    ctrl_data = pd.read_csv(ctrl_path, index_col=0)

    for file_name in os.listdir(true_path):
        if file_name.endswith('.csv') and '+' not in file_name:
            true_file_path = os.path.join(true_path, file_name)
            true_data = pd.read_csv(true_file_path, index_col=0)

            if true_data.shape[0] > 300:
                true_data = true_data.iloc[np.random.choice(true_data.shape[0], 300, replace=False), :]
            if ctrl_data.shape[0] > 300:
                ctrl_data = ctrl_data.iloc[np.random.choice(ctrl_data.shape[0], 300, replace=False), :]

            # ctrl_data was used as the pred data
            metrics = calculate_metrics_all(ctrl_data, true_data, ctrl_data)

            mmd_per_gene = []
            ws_per_gene = []
            for gene_idx in range(true_data.shape[1]):
                gene_pred = ctrl_data.iloc[:, gene_idx].values.reshape(-1, 1)
                gene_truth = true_data.iloc[:, gene_idx].values.reshape(-1, 1)
                mmd_per_gene.append(compute_mmd(gene_pred, gene_truth, kernel='rbf', gamma=1.0))
                ws_per_gene.append(compute_wasserstein(gene_pred, gene_truth))

            metrics['MMD'] = np.mean(mmd_per_gene)
            metrics['Wasserstein'] = np.mean(ws_per_gene)
            metrics['condition'] = file_name.replace('.csv', '')
            metrics['dataset'] = dataset
            metrics['model'] = "NoPerturb"

            results_df = pd.concat([results_df, pd.DataFrame([metrics])], ignore_index=True)

output_path = '/task1/noperturb_results.csv'
results_df.to_csv(output_path, index=False)


In [2]:
results_df

Unnamed: 0,dataset,condition,model,R_squared,R_squared_delta,Pearson_Correlation,Pearson_Correlation_delta,MSE,MSE_delta,RMSE,RMSE_delta,MAE,MAE_delta,Cosine_Similarity,Cosine_Similarity_delta,L2,L2_delta,MMD,Wasserstein
0,datlingerbock2017,NFAT5,NoPerturb,0.990850,-0.031681,0.995646,,0.003209,0.003209,0.056647,0.056647,0.031607,0.031607,0.996415,0.0,4.012776,4.012776,0.002890,0.043083
1,datlingerbock2017,JUND,NoPerturb,0.992952,-0.002380,0.996506,,0.002438,0.002438,0.049372,0.049372,0.027612,0.027612,0.997132,0.0,3.497394,3.497394,0.001925,0.036667
2,datlingerbock2017,RUNX2,NoPerturb,0.993095,-0.020594,0.996695,,0.002413,0.002413,0.049127,0.049127,0.027882,0.027882,0.997286,0.0,3.480043,3.480043,0.002209,0.038431
3,datlingerbock2017,JUNB,NoPerturb,0.993494,-0.030954,0.996949,,0.002281,0.002281,0.047765,0.047765,0.026730,0.026730,0.997491,0.0,3.383539,3.383539,0.002375,0.038636
4,datlingerbock2017,NFATC1,NoPerturb,0.994176,-0.022860,0.997257,,0.002041,0.002041,0.045174,0.045174,0.025021,0.025021,0.997749,0.0,3.200054,3.200054,0.001891,0.034785
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
517,tiankampmann2021_crispri,PVR,NoPerturb,0.998058,-0.007909,0.999061,,0.000363,0.000363,0.019046,0.019046,0.008279,0.008279,0.999115,0.0,1.369483,1.369483,0.000389,0.011133
518,tiankampmann2021_crispri,DNAJC6,NoPerturb,0.998721,-0.001785,0.999374,,0.000240,0.000240,0.015484,0.015484,0.006924,0.006924,0.999412,0.0,1.113366,1.113366,0.000247,0.009205
519,tiankampmann2021_crispri,SNCB,NoPerturb,0.998669,-0.000036,0.999339,,0.000253,0.000253,0.015920,0.015920,0.006865,0.006865,0.999378,0.0,1.144689,1.144689,0.000267,0.008978
520,tiankampmann2021_crispri,EIF4G1,NoPerturb,0.997903,-0.000010,0.998951,,0.000397,0.000397,0.019916,0.019916,0.008414,0.008414,0.999015,0.0,1.431988,1.431988,0.000409,0.011137
