In [9]:
from itertools import combinations

import anndata as ad
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pertpy as pt
import seaborn as sns
from sklearn.metrics import pairwise_distances
from tqdm import tqdm

In [3]:
# Load the AnnData object from disk.
adata = ad.read_h5ad("../../data/sciplex/sciplex3_uce_adata.h5ad")

# Compound names to keep, one per line of the validation list file.
with open("../../data/sciplex/drugs_validation_list.txt") as f:
    test_compounds = list(map(str.strip, f))

# Drop the "Vehicle" control cells first, then keep only cells whose
# compound appears in the validation list (this removes training compounds).
adata = adata[adata.obs['product_name'].ne("Vehicle")]
adata = adata[adata.obs['product_name'].isin(test_compounds)]

In [4]:
# Display the filtered AnnData summary (shape plus the obs / var / obsm keys).
adata

View of AnnData object with n_obs × n_vars = 137979 × 17376
    obs: 'cell_type', 'dose', 'dose_character', 'dose_pattern', 'g1s_score', 'g2m_score', 'pathway', 'pathway_level_1', 'pathway_level_2', 'product_dose', 'product_name', 'proliferation_index', 'replicate', 'size_factor', 'target', 'vehicle', 'n_genes'
    var: 'id', 'num_cells_expressed-0-0', 'num_cells_expressed-1-0', 'num_cells_expressed-1', 'n_cells'
    obsm: 'X_uce'

In [10]:
def calculate_edistance(X, Y):
    """
    Energy distance (E-distance) between two sets of observations.

    Computed as ``2 * delta - sigma_X - sigma_Y`` where ``delta`` is the mean
    squared Euclidean distance over all (x, y) pairs taken across the two
    sets, and ``sigma_X`` / ``sigma_Y`` are the same mean taken within each
    set. The within-set means run over the full square distance matrix, so
    the zero self-distances on its diagonal are included.

    Args:
        X: first sample matrix, one observation per row (anything accepted
            by ``sklearn.metrics.pairwise_distances``).
        Y: second sample matrix with the same number of columns as ``X``.

    Returns:
        The E-distance as a scalar.
    """
    def _mean_sq_dist(A, B):
        # Mean over every entry of the |A| x |B| squared-distance matrix.
        return pairwise_distances(A, B, metric="sqeuclidean").mean()

    within_x = _mean_sq_dist(X, X)
    within_y = _mean_sq_dist(Y, Y)
    between = _mean_sq_dist(X, Y)
    return 2 * between - within_x - within_y

In [12]:
def get_pariwise_edistances(adata):
    """
    Compute the E-distance between every unordered pair of compounds.

    Cells are grouped by ``adata.obs['product_name']`` and each pair of
    groups is compared on ``.obsm['X_uce']`` with ``calculate_edistance``.

    Args:
        adata: AnnData (or a view of one) with a ``product_name`` column in
            ``.obs`` and an ``X_uce`` entry in ``.obsm``.

    Returns:
        A list with one E-distance per compound pair, in the order produced
        by ``itertools.combinations`` over the unique compound names. The
        list is empty when fewer than two compounds are present.
    """
    product_names = adata.obs['product_name']
    compounds = list(product_names.unique())

    # Extract each compound's embedding once up front. Doing this inside the
    # pair loop repeats the same AnnData subsetting for every pair that a
    # compound takes part in.
    embeddings = {
        compound: adata[product_names == compound].obsm['X_uce']
        for compound in compounds
    }

    # Materialize the pairs so tqdm knows the total and can show progress.
    return [
        calculate_edistance(embeddings[first], embeddings[second])
        for first, second in tqdm(list(combinations(compounds, 2)))
    ]

In [24]:
def plot_edist_distribution(edist_list, title):
    """
    Plot a histogram of E-distances with their mean marked.

    Args:
        edist_list: sequence of E-distance values to plot.
        title: title shown above the figure.

    Returns:
        None. The figure is rendered with ``plt.show()``.
    """
    # Everything below is derived from the values passed in, so the plot
    # does not depend on any notebook-level variable.
    mean_value = np.mean(edist_list)

    # Histogram of the distances (30 bins, raw counts).
    fig, ax = plt.subplots(figsize=(8, 5))
    sns.histplot(edist_list, bins=30, color="blue", alpha=0.5, ax=ax)

    # Mark the mean with a dashed vertical line.
    ax.axvline(mean_value, color="red", linestyle="--", label=f"Mean: {mean_value:.4f}")

    # Labels and title; the y-axis shows counts, which is histplot's default statistic.
    ax.set_xlabel("Value")
    ax.set_ylabel("Count")
    ax.set_title(title)
    ax.legend()

    # Show plot
    plt.show()

In [None]:
# Mean pairwise E-distance between compounds for every (cell type, dose)
# condition. Each entry of `results` is a single-key dict
# {"<cell_type>_<dose>": mean E-distance}.
results = list()
doses = list(adata.obs['dose'].unique())
for cell_type in list(adata.obs['cell_type'].unique()):
    # The cell-type subset does not depend on the dose, so build it once
    # per cell type instead of once per (cell type, dose) combination.
    adata_cell_type = adata[adata.obs['cell_type'] == cell_type]

    for dose in doses:
        adata_subset = adata_cell_type[adata_cell_type.obs['dose'] == dose]

        title = cell_type + "_" + str(dose)

        pw_dists = get_pariwise_edistances(adata_subset)

        results.append({title: np.mean(pw_dists)})
        # Optional: uncomment to plot the distance distribution per condition.
        #plot_edist_distribution(pw_dists, title)

In [29]:
# Print one {condition: mean E-distance} entry per line.
for entry in results:
    print(entry)

{'A549_1000.0': 0.04998170317530318}
{'A549_100.0': 0.03311512477407399}
{'A549_10.0': 0.016399602604891154}
{'A549_10000.0': 0.09285112448817988}
{'MCF7_1000.0': 0.04512665252011626}
{'MCF7_100.0': 0.029308987419350757}
{'MCF7_10.0': 0.01338384093676058}
{'MCF7_10000.0': 0.09148223034269391}
{'K562_1000.0': 0.050168518140024945}
{'K562_100.0': 0.039993097427767146}
{'K562_10.0': 0.017432546783475774}
{'K562_10000.0': 0.08976637647921036}
