In [43]:
 import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np
from tqdm import tqdm
import math
import yaml
from torch.utils.data.dataloader import DataLoader
import pickle as pkl
import seaborn as sns
import matplotlib.pyplot as plt
import anndata as ad
from scipy.spatial.distance import euclidean, cosine
from sklearn.metrics import mean_absolute_error, mean_squared_error
from scipy.stats import wasserstein_distance
from itertools import combinations
from numpy import corrcoef
import random
from scipy.stats import spearmanr


In [26]:
adata = ad.read_h5ad("../../data/sciplex/sciplex_preprocessed.h5ad")
adata = adata[adata.obs['cell_type'] == "A549"]

In [27]:
print("Vehicle_0" in list(adata.obs['product_dose'].unique())) #eliminate Vehicle_0 to get all perturbed cell populations

True


Create 5 disjoint sets of n=300 control cell populations

In [28]:
adata_control = adata[adata.obs['product_name'] == "Vehicle"]

num_sets = 5
cells_per_set  = 300
n_total = adata_control.n_obs

all_indices = np.arange(n_total)
random.shuffle(all_indices)

# Split indices into disjoint sets
disjoint_sets = []
for i in range(num_sets):
    start_idx = i * cells_per_set
    end_idx = start_idx + cells_per_set
    disjoint_set_indices = all_indices[start_idx:end_idx]
    disjoint_sets.append(disjoint_set_indices)

# Optionally, create new AnnData objects for each set
disjoint_adata_sets = [adata_control[indices].copy() for indices in disjoint_sets]

# Output the sets
for i, indices in enumerate(disjoint_sets):
    print(f"Set {i + 1}: {len(indices)} cells")

Set 1: 300 cells
Set 2: 300 cells
Set 3: 300 cells
Set 4: 300 cells
Set 5: 300 cells


Compute distance between pairwise control subsets

In [29]:
def get_energy_distance(dist1, dist2):
    """
    Calculate the energy distance between two 2D distributions.

    Parameters:
        dist1 (numpy.ndarray): First distribution, a 2D array of shape (n1, d).
        dist2 (numpy.ndarray): Second distribution, a 2D array of shape (n2, d).

    Returns:
        float: The energy distance between the two distributions.
    """
    # Ensure inputs are numpy arrays
    dist1 = np.asarray(dist1)
    dist2 = np.asarray(dist2)

    # Pairwise distances within dist1
    a = np.linalg.norm(dist1[:, None, :] - dist1[None, :, :], axis=-1)
    mean_within_dist1 = np.mean(a)
    
    # Pairwise distances within dist2
    b = np.linalg.norm(dist2[:, None, :] - dist2[None, :, :], axis=-1)
    mean_within_dist2 = np.mean(b)
    
    # Pairwise distances between dist1 and dist2
    c = np.linalg.norm(dist1[:, None, :] - dist2[None, :, :], axis=-1)
    mean_between_dist = np.mean(c)
    
    # Energy distance formula
    energy_dist = 2 * mean_between_dist - mean_within_dist1 - mean_within_dist2
    return energy_dist

In [30]:
# Linear Maximum Mean Discrepancy function
def linear_mmd(X, Y):
    """
    Compute the linear Maximum Mean Discrepancy (MMD) between two distributions X and Y.
    """
    return np.abs(np.mean(X, axis=0) - np.mean(Y, axis=0)).mean()

In [31]:
# Initialize lists to store metrics
euclidean_dist_control = []
cosine_sim_control = []
mae_control = []
mse_control = []
wasserstein_dist_control = []
mmd_control = []
edist_control = []



# Compute metrics between disjoint sets
for i in range(len(disjoint_adata_sets)):
    # Initialize accumulators for each metric
    euclidean_distance = 0
    cosine_similarity = 0
    mae_distance = 0
    mse_distance = 0
    wasserstein_distance_accum = 0
    mmd_distance = 0
    energy_distance = 0
    
    for j in range(len(disjoint_adata_sets)):
        if i != j:
            ad1 = disjoint_adata_sets[i]
            ad2 = disjoint_adata_sets[j]
            
            centroid_1 = np.mean(ad1.X, axis=0)
            centroid_2 = np.mean(ad2.X, axis=0)
            
            # Calculate metrics
            euclidean_distance += euclidean(centroid_1, centroid_2)
            cosine_similarity += cosine(centroid_1, centroid_2)  # Cosine similarity
            mae_distance += mean_absolute_error(centroid_1, centroid_2)
            mse_distance += mean_squared_error(centroid_1, centroid_2)
            wasserstein_distance_accum += wasserstein_distance(centroid_1, centroid_2)
            mmd_distance += linear_mmd(ad1.X, ad2.X)
            energy_distance += get_energy_distance(ad1.X, ad2.X)
    
    # Average metrics over all other sets
    euclidean_dist_control.append(euclidean_distance / (len(disjoint_adata_sets) - 1))
    cosine_sim_control.append(cosine_similarity / (len(disjoint_adata_sets) - 1))
    mae_control.append(mae_distance / (len(disjoint_adata_sets) - 1))
    mse_control.append(mse_distance / (len(disjoint_adata_sets) - 1))
    wasserstein_dist_control.append(wasserstein_distance_accum / (len(disjoint_adata_sets) - 1))
    mmd_control.append(mmd_distance / (len(disjoint_adata_sets) - 1))
    edist_control.append(energy_distance / (len(disjoint_adata_sets)) - 1)

For every combination of control - perturbed populations, calculate distance

In [32]:
# Initialize lists to store all metrics
all_euclidean_dist = []
all_cosine_sim = []
all_mae_dist = []
all_mse_dist = []
all_wasserstein_dist = []
all_mmd_dist = []
all_energy_dist = []

# Iterate over each product-dose combination
for product_dose in tqdm(list(adata.obs['product_dose'].unique())):
    if product_dose == "Vehicle_0":
        continue

    # Initialize accumulators for the distances
    dist_euclidean_mean = 0
    dist_cosine_mean = 0
    dist_mae_mean = 0
    dist_mse_mean = 0
    dist_wasserstein_mean = 0
    dist_mmd_mean = 0
    dist_energy_mean = 0

    # Perturbed data for the current product-dose
    pert_adata = adata[adata.obs['product_dose'] == product_dose]
    centroid_pert = np.mean(pert_adata.X, axis=0)

    # Compare with control sets
    for ctrl_adata in disjoint_adata_sets:
        centroid_ctrl = np.mean(ctrl_adata.X, axis=0)

        # Calculate metrics
        dist_euclidean_mean += euclidean(centroid_ctrl, centroid_pert)
        dist_cosine_mean += cosine(centroid_ctrl, centroid_pert)
        dist_mae_mean += mean_absolute_error(centroid_ctrl, centroid_pert)
        dist_mse_mean += mean_squared_error(centroid_ctrl, centroid_pert)
        dist_wasserstein_mean += wasserstein_distance(centroid_ctrl, centroid_pert)
        dist_mmd_mean += linear_mmd(ctrl_adata.X, pert_adata.X)
        dist_energy_mean += get_energy_distance(ctrl_adata.X, pert_adata.X)

    # Average the distances across control sets
    num_ctrl_sets = len(disjoint_adata_sets)
    
    all_euclidean_dist.append(dist_euclidean_mean / num_ctrl_sets)
    all_cosine_sim.append(dist_cosine_mean / num_ctrl_sets)
    all_mae_dist.append(dist_mae_mean / num_ctrl_sets)
    all_mse_dist.append(dist_mse_mean / num_ctrl_sets)
    all_wasserstein_dist.append(dist_wasserstein_mean / num_ctrl_sets)
    all_mmd_dist.append(dist_mmd_mean / num_ctrl_sets)
    all_energy_dist.append(dist_energy_mean / num_ctrl_sets)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 741/741 [27:09<00:00,  2.20s/it]


In [33]:
def get_distance_stats(within_ctrl_distances, ctr_pert_distances):
    CRP_list = list()

    for ctrl_ctrl_dist in within_ctrl_distances:
    
        no_smaller = 0
        for ctrl_pert_dist in ctr_pert_distances:
            if ctrl_ctrl_dist > ctrl_pert_dist:
                no_smaller += 1
                
        crp_fold = no_smaller / len(all_euclidean_dist)
        CRP_list.append(crp_fold)

    ROBUSTNESS = 1 / (1 + np.var(CRP_list))
    CRP = 1 - float(np.mean(CRP_list))

    return CRP, ROBUSTNESS
    

In [34]:
crp, robustness = get_distance_stats(euclidean_dist_control, all_euclidean_dist)
print("Euclidean Distance - CRP:", crp, "Robustness:", robustness)

Euclidean Distance - CRP: 0.9886486486486487 Robustness: 0.999919947972025


In [35]:
crp, robustness = get_distance_stats(cosine_sim_control, all_cosine_sim)
print("Cosine Distance - CRP:", crp, "Robustness:", robustness)

Cosine Distance - CRP: 0.9902702702702703 Robustness: 0.9999653773857647


In [36]:
crp, robustness = get_distance_stats(mae_control, all_mae_dist)
print("Mean Absolute Error - CRP:", crp, "Robustness:", robustness)

Mean Absolute Error - CRP: 0.9905405405405405 Robustness: 0.9999649391401909


In [37]:
crp, robustness = get_distance_stats(mse_control, all_mse_dist)
print("Mean Squared Error - CRP:", crp, "Robustness:", robustness)

Mean Squared Error - CRP: 0.9881081081081081 Robustness: 0.9999018357876335


In [38]:
crp, robustness = get_distance_stats(wasserstein_dist_control, all_wasserstein_dist)
print("Wasserstein Distance - CRP:", crp, "Robustness:", robustness)

Wasserstein Distance - CRP: 0.9127027027027027 Robustness: 0.9922464168959922


In [39]:
crp, robustness = get_distance_stats(mmd_control, all_mmd_dist)
print("Mean Maximum Discrepancy - CRP:", crp, "Robustness:", robustness)

Mean Maximum Discrepancy - CRP: 0.9905405405405405 Robustness: 0.9999649391401909


In [40]:
crp, robustness = get_distance_stats(edist_control, all_energy_dist)
print("Mean Maximum Discrepancy - CRP:", crp, "Robustness:", robustness)

Mean Maximum Discrepancy - CRP: 1.0 Robustness: 1.0


In [44]:
ad_control = adata[adata.obs['product_name'] == "Vehicle"]
centroid_control = np.mean(ad_control.X, axis=0)

correlations_euclidean = list()
correlations_cosine = list()
correlations_mae = list()
correlations_mse = list()
correlations_wasserstein = list()
correlations_mmd = list()
#correlations_energy = list()

for drug in tqdm(list(adata.obs['product_name'].unique())):
    if drug == "Vehicle":
        continue

    ad_perturb = adata[adata.obs['product_name'] == drug]

    doses = list()
    
    distances_euclidean = list()
    distances_cosine = list()
    distances_mae = list()
    distances_mse = list()
    distances_wasserstein = list()
    distances_mmd = list()
    #distances_energy = list()
    
    for dose in list(ad_perturb.obs['dose'].unique()):
        
        ad_perturb_dose = ad_perturb[ad_perturb.obs['dose'] == dose]
    
        centroid_perturbation = np.mean(ad_perturb_dose.X, axis=0)

        doses.append(dose)

        #calculate distances
        dist_euclidean = euclidean(centroid_control, centroid_perturbation)
        dist_cosine_mean = cosine(centroid_control, centroid_perturbation)
        dist_mae_mean = mean_absolute_error(centroid_control, centroid_perturbation)
        dist_mse_mean = mean_squared_error(centroid_control, centroid_perturbation)
        dist_wasserstein_mean = wasserstein_distance(centroid_control, centroid_perturbation)
        dist_mmd_mean = linear_mmd(ad_control.X, ad_perturb_dose.X)
        #dist_energy_mean = get_energy_distance(ad_control.X, ad_perturb_dose.X)

        #append to corresponding list
        distances_euclidean.append(dist_euclidean)
        distances_cosine.append(dist_cosine_mean)
        distances_mae.append(dist_mae_mean)
        distances_mse.append(dist_mse_mean)
        distances_wasserstein.append(dist_wasserstein_mean)
        distances_mmd.append(dist_wasserstein_mean)
        #distances_energy.append(dist_energy_mean)


    
    corr_euclidean, _ = spearmanr(doses, distances_euclidean)
    corr_cosine, _ = spearmanr(doses, distances_cosine)
    corr_mae, _ = spearmanr(doses, distances_mae)
    corr_mse, _ = spearmanr(doses, distances_mse)
    corr_mmd, _ = spearmanr(doses, distances_mmd)
    #corr_energy, _ = spearmanr(doses, distances_energy)

    correlations_euclidean.append(corr_euclidean)
    correlations_cosine.append(corr_cosine)
    correlations_mae.append(corr_mae)
    correlations_mse.append(corr_mse)
    correlations_mmd.append(corr_mmd)
    #correlations_energy.append(corr_energy)

print("Bio Rep Scores:")
print("Euclidean", np.mean(correlations_euclidean))
print("Cosine", np.mean(correlations_cosine))
print("MAE", np.mean(correlations_mae))
print("MSE", np.mean(correlations_mse))
print("MMD", np.mean(correlations_mmd))
#print("Energy", np.mean(correlations_energy))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 186/186 [00:06<00:00, 27.50it/s]

Bio Rep Scores:
Euclidean 0.6345945945945946
Cosine 0.6345945945945947
MAE 0.6389189189189189
MSE 0.6345945945945946
MMD 0.5254054054054054



