In [None]:
import pandas as pd
import os
import numpy as np
import scanpy as sc
from scipy.spatial.distance import cdist
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import wasserstein_distance, pearsonr
from warnings import catch_warnings, simplefilter


np.random.seed(42)

def compute_mmd(pred_data, true_data, kernel='rbf', gamma=1.0):
    if kernel == 'rbf':
        dist_pred = cdist(pred_data, pred_data, metric='sqeuclidean')
        dist_truth = cdist(true_data, true_data, metric='sqeuclidean')
        dist_cross = cdist(pred_data, true_data, metric='sqeuclidean')

        Kxx = np.exp(-gamma * dist_pred)
        Kyy = np.exp(-gamma * dist_truth)
        Kxy = np.exp(-gamma * dist_cross)

        return np.mean(Kxx) + np.mean(Kyy) - 2 * np.mean(Kxy)
    else:
        raise ValueError("Unsupported kernel type. Use 'rbf'.")


def compute_wasserstein(pred_data, true_data):
    return wasserstein_distance(pred_data.flatten(), true_data.flatten())

def calculate_metrics_all(pred_data, true_data, ctrl_data):

    if isinstance(true_data, pd.DataFrame):
        true_data = true_data.values
    if isinstance(pred_data, pd.DataFrame):
        pred_data = pred_data.values
    if isinstance(ctrl_data, pd.DataFrame):
        ctrl_data = ctrl_data.values

    metrics = {}

    mean_true = np.mean(true_data, axis=0)
    mean_pred = np.mean(pred_data, axis=0)
    mean_ctrl = np.mean(ctrl_data, axis=0)

    with catch_warnings():
        simplefilter("ignore")
        try:
            metrics['R_squared'] = r2_score(mean_true, mean_pred)
            metrics['R_squared_delta'] = r2_score(mean_true - mean_ctrl, mean_pred - mean_ctrl)
        except Exception:
            metrics['R_squared'] = np.nan
            metrics['R_squared_delta'] = np.nan

        try:
            corr, _ = pearsonr(mean_true, mean_pred)
            metrics['Pearson_Correlation'] = corr
            corr_delta, _ = pearsonr(mean_true - mean_ctrl, mean_pred - mean_ctrl)
            metrics['Pearson_Correlation_delta'] = corr_delta
        except Exception:
            metrics['Pearson_Correlation'] = np.nan
            metrics['Pearson_Correlation_delta'] = np.nan

    mse = mean_squared_error(mean_true, mean_pred)
    mse_delta = mean_squared_error(mean_true - mean_ctrl, mean_pred - mean_ctrl)
    metrics.update({
        'MSE': mse,
        'RMSE': np.sqrt(mse),
        'MAE': mean_absolute_error(mean_true, mean_pred),
        'Cosine_Similarity': cosine_similarity([mean_true], [mean_pred])[0, 0],
        'L2': np.linalg.norm(mean_true - mean_pred),
        'MSE_delta': mse_delta,
        'RMSE_delta': np.sqrt(mse_delta),
        'MAE_delta': mean_absolute_error(mean_true - mean_ctrl, mean_pred - mean_ctrl),
        'Cosine_Similarity_delta': cosine_similarity(
            [(mean_true - mean_ctrl)], [(mean_pred - mean_ctrl)]
        )[0, 0],
        'L2_delta': np.linalg.norm((mean_true - mean_ctrl) - (mean_pred - mean_ctrl)),
    })

    return metrics


results_df = pd.DataFrame(columns=[
    'dataset', 'condition', 'model',
    'R_squared', 'R_squared_delta',
    'Pearson_Correlation', 'Pearson_Correlation_delta',
    'MSE', 'MSE_delta',
    'RMSE', 'RMSE_delta',
    'MAE', 'MAE_delta',
    'Cosine_Similarity', 'Cosine_Similarity_delta',
    'L2', 'L2_delta',
    'MMD','Wasserstein'
])


datasets = [
    'dixitregev2016',
    "normanweissman2019_filtered",
    'sunshine2023_crispri_sarscov2'
]


for dataset in datasets:
    true_path = f'/data2/lanxiang/perturb_benchmarking/Task_pred_out/True_csv/{dataset}'
    ctrl_path = f'/data2/lanxiang/perturb_benchmarking/Task_pred_out/Ctrl_csv/{dataset}/ctrl_data.csv'

    ctrl_data = pd.read_csv(ctrl_path, index_col=0)

    for file_name in os.listdir(true_path):
        if file_name.endswith('.csv') and '+' in file_name:
           
            true_file_path = os.path.join(true_path, file_name)
            true_data = pd.read_csv(true_file_path, index_col=0)

            if true_data.shape[0] > 300:
                true_data = true_data.iloc[np.random.choice(true_data.shape[0], 300, replace=False), :]
            if ctrl_data.shape[0] > 300:
                ctrl_data = ctrl_data.iloc[np.random.choice(ctrl_data.shape[0], 300, replace=False), :]

            # use ctrl_data as pred
            metrics = calculate_metrics_all(ctrl_data, true_data, ctrl_data)

            
            mmd_per_gene = []
            ws_per_gene = []
            for gene_idx in range(true_data.shape[1]):
                gene_pred = ctrl_data.iloc[:, gene_idx].values.reshape(-1, 1)
                gene_truth = true_data.iloc[:, gene_idx].values.reshape(-1, 1)
                mmd_per_gene.append(compute_mmd(gene_pred, gene_truth, kernel='rbf', gamma=1.0))
                ws_per_gene.append(compute_wasserstein(gene_pred, gene_truth))

            metrics['MMD'] = np.mean(mmd_per_gene)
            metrics['Wasserstein'] = np.mean(ws_per_gene)
            metrics['condition'] = file_name.replace('.csv', '')
            metrics['dataset'] = dataset
            metrics['model'] = "NoPerturb"

            results_df = pd.concat([results_df, pd.DataFrame([metrics])], ignore_index=True)

output_path = '/task2/noperturb_results.csv'
results_df.to_csv(output_path, index=False)


In [2]:
results_df

Unnamed: 0,dataset,condition,model,R_squared,R_squared_delta,Pearson_Correlation,Pearson_Correlation_delta,MSE,MSE_delta,RMSE,RMSE_delta,MAE,MAE_delta,Cosine_Similarity,Cosine_Similarity_delta,L2,L2_delta,MMD,Wasserstein
0,dixitregev2016,NR2C2+IRF1,NoPerturb,0.990990,-0.014436,0.996223,,0.004120,0.004120,0.064184,0.064184,0.031553,0.031553,0.996959,0.0,4.542103,4.542103,0.004146,0.044374
1,dixitregev2016,YY1+GABPA,NoPerturb,0.989763,-0.061917,0.996363,,0.004800,0.004800,0.069281,0.069281,0.034151,0.034151,0.997104,0.0,4.902854,4.902854,0.006782,0.055287
2,dixitregev2016,NR2C2+YY1,NoPerturb,0.988996,-0.051348,0.995663,,0.005100,0.005100,0.071414,0.071414,0.035575,0.035575,0.996551,0.0,5.053781,5.053781,0.006514,0.055143
3,dixitregev2016,YY1+ELK1,NoPerturb,0.990022,-0.021585,0.995879,,0.004591,0.004591,0.067754,0.067754,0.033601,0.033601,0.996695,0.0,4.794794,4.794794,0.006370,0.054268
4,dixitregev2016,NR2C2+ELF1,NoPerturb,0.992798,-0.054267,0.997445,,0.003319,0.003319,0.057607,0.057607,0.028100,0.028100,0.997963,0.0,4.076682,4.076682,0.004894,0.046618
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
76,sunshine2023_crispri_sarscov2,CCZ1+CCZ1B,NoPerturb,0.961647,-0.195213,0.991259,,0.012697,0.012697,0.112682,0.112682,0.050046,0.050046,0.992348,0.0,8.060486,8.060486,0.014259,0.071948
77,sunshine2023_crispri_sarscov2,GPR89B+GPR89A,NoPerturb,0.966620,-0.082235,0.986760,,0.009897,0.009897,0.099484,0.099484,0.041816,0.041816,0.988630,0.0,7.116388,7.116388,0.011339,0.064811
78,sunshine2023_crispri_sarscov2,IFNAR2+IFNAR1,NoPerturb,0.953953,-0.083757,0.982344,,0.014381,0.014381,0.119922,0.119922,0.051271,0.051271,0.984877,0.0,8.578402,8.578402,0.015368,0.075729
79,sunshine2023_crispri_sarscov2,COG6+COG5,NoPerturb,0.970656,-0.138435,0.988488,,0.008416,0.008416,0.091741,0.091741,0.039992,0.039992,0.989876,0.0,6.562543,6.562543,0.010191,0.062397
