In [2]:
%load_ext autoreload
%autoreload 2


In [57]:
import glob
import pandas as pd
import torch
import numpy as np
import scanpy as sc
import anndata
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score


# Load the embeddings and ground truth data

In [81]:
DATASET = "Lymph Node"

if DATASET == "Breast Cancer":
    ADATA_PATH = 'data/V1_Breast_Cancer_Block_A_Section_1/breast_cancer.h5ad'
    GT_PATH = 'data/V1_Breast_Cancer_Block_A_Section_1/cpdb_scores/thresholded_interaction_matrix.tsv'
    dataset_name = "breast_cancer"
elif DATASET == "Lymph Node":
    ADATA_PATH = 'data/V1_Human_Lymph_Node/lymph_node.h5ad'
    GT_PATH = 'data/V1_Human_Lymph_Node/cpdb_scores/thresholded_interaction_matrix.tsv'
    dataset_name = "lymph_node"


In [82]:
# Load the cell types
adata = anndata.read_h5ad(ADATA_PATH)
cell_types = adata.obs['leiden'].values.astype(int)

# Load the ground truth data
ground_truth = pd.read_csv(GT_PATH, sep='\t', index_col=0).values.astype(int)

In [99]:
accuracy_scores = []
f1_scores = []

for embedding_path in glob.glob(f'training_output/{dataset_name}_seed_*'):
    seed = embedding_path.split('_')[-1]
    embeddings = torch.load(embedding_path + '/final_embeddings.pt').numpy()
    embedding_pairs = []
    interaction_labels = []

    num_samples = 11000
    indices = torch.randint(0, embeddings.shape[0], (num_samples, 2))

    embedding_pairs = []
    interaction_labels = []

    for idx in indices:
        i, j = idx
        cell_type_i = cell_types[i]
        cell_type_j = cell_types[j]
        interaction = ground_truth[cell_type_i, cell_type_j]
        
        concatenated_embedding = np.concatenate((embeddings[i], embeddings[j]))
        
        embedding_pairs.append(concatenated_embedding)
        interaction_labels.append(interaction)

    embedding_pairs = np.stack(embedding_pairs)
    interaction_labels = np.array(interaction_labels)

    # Create train/test split
    train_pairs, test_pairs, train_labels, test_labels = train_test_split(
        embedding_pairs, interaction_labels, test_size=(1./11), random_state=42
    )

    # Flatten the embeddings for logistic regression
    train_pairs_flat = train_pairs.reshape(train_pairs.shape[0], -1)
    test_pairs_flat = test_pairs.reshape(test_pairs.shape[0], -1)

    # Initialize the logistic regression model
    log_reg = LogisticRegression(max_iter=1000)

    # Train the model
    log_reg.fit(train_pairs_flat, train_labels)

    # Predict on the test set
    test_predictions = log_reg.predict(test_pairs_flat)

    # Calculate accuracy
    accuracy = accuracy_score(test_labels, test_predictions)
    f1 = f1_score(test_labels, test_predictions)

    print(f"Seed {seed} | Test Accuracy: {accuracy:.4f} | Test F1 Score: {f1:.4f}")

    accuracy_scores.append(accuracy)
    f1_scores.append(f1)

print(f"Average Test Accuracy: {np.mean(accuracy_scores):.4f} ± {np.std(accuracy_scores):.4f}")
print(f"Average Test F1 Score: {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")


Seed 42 | Test Accuracy: 0.8560 | Test F1 Score: 0.8583
Seed 21 | Test Accuracy: 0.8580 | Test F1 Score: 0.8605
Seed 13 | Test Accuracy: 0.8300 | Test F1 Score: 0.8365
Average Test Accuracy: 0.8480 ± 0.0128
Average Test F1 Score: 0.8518 ± 0.0108


# Compare against logistic regression model on expression data

In [88]:
sc.pp.highly_variable_genes(adata, flavor='cell_ranger', n_top_genes=2000, subset=True)
adata

AnnData object with n_obs × n_vars = 4022 × 2000
    obs: 'in_tissue', 'array_row', 'array_col', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'leiden'
    var: 'gene_ids', 'feature_types', 'genome', 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'mean', 'std'
    uns: 'hvg', 'leiden', 'leiden_colors', 'log1p', 'neighbors', 'pca', 'spatial', 'umap'
    obsm: 'X_pca', 'X_umap', 'spatial'
    varm: 'PCs'
    layers: 'log_norm'
    obsp: 'connectivities', 'distances'

In [89]:
accuracy_scores = []
f1_scores = []

for i in range(3):
    embeddings = adata.X.toarray()
    embedding_pairs = []
    interaction_labels = []

    num_samples = 11000
    indices = torch.randint(0, embeddings.shape[0], (num_samples, 2))

    embedding_pairs = []
    interaction_labels = []

    for idx in indices:
        i, j = idx
        cell_type_i = cell_types[i]
        cell_type_j = cell_types[j]
        interaction = ground_truth[cell_type_i, cell_type_j]
        
        concatenated_embedding = np.concatenate((embeddings[i], embeddings[j]))
        
        embedding_pairs.append(concatenated_embedding)
        interaction_labels.append(interaction)

    embedding_pairs = np.stack(embedding_pairs)
    interaction_labels = np.array(interaction_labels)

    # Create train/test split
    train_pairs, test_pairs, train_labels, test_labels = train_test_split(
        embedding_pairs, interaction_labels, test_size=(1./11), random_state=42
    )

    # Flatten the embeddings for logistic regression
    train_pairs_flat = train_pairs.reshape(train_pairs.shape[0], -1)
    test_pairs_flat = test_pairs.reshape(test_pairs.shape[0], -1)

    # Initialize the logistic regression model
    log_reg = LogisticRegression(max_iter=1000)

    # Train the model
    log_reg.fit(train_pairs_flat, train_labels)

    # Predict on the test set
    test_predictions = log_reg.predict(test_pairs_flat)

    # Calculate accuracy
    accuracy = accuracy_score(test_labels, test_predictions)
    f1 = f1_score(test_labels, test_predictions)

    print(f"Run {i} | Test Accuracy: {accuracy:.4f} | Test F1 Score: {f1:.4f}")

    accuracy_scores.append(accuracy)
    f1_scores.append(f1)

print(f"Average Test Accuracy: {np.mean(accuracy_scores):.4f} ± {np.std(accuracy_scores):.4f}")
print(f"Average Test F1 Score: {np.mean(f1_scores):.4f} ± {np.std(f1_scores):.4f}")


Run 46 | Test Accuracy: 0.8220 | Test F1 Score: 0.8282
Run 3145 | Test Accuracy: 0.8260 | Test F1 Score: 0.8214
Run 3540 | Test Accuracy: 0.8240 | Test F1 Score: 0.8163
Average Test Accuracy: 0.8240 ± 0.0016
Average Test F1 Score: 0.8219 ± 0.0049
