In [1]:
import os
import sys
import time
import warnings
from itertools import combinations, chain

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from scipy.sparse import csr_matrix
from scipy.spatial.distance import cdist
from sklearn.metrics import pairwise_distances


In [None]:
# Location of the merged AnnData file. Set the HSC_ADATA_PATH environment variable
# to run elsewhere; the default is the original scratch location.
fpath = os.environ.get(
    "HSC_ADATA_PATH",
    "/scratch/indikar_root/indikar1/cstansbu/HSC/geneformer_adata/merged.anndata.h5ad",
)

adata = sc.read_h5ad(fpath)
sc.logging.print_memory_usage()
# Store X as a CSR sparse matrix; the memory printouts bracket the conversion.
adata.X = csr_matrix(adata.X)
sc.logging.print_memory_usage()
adata

In [None]:
# Frequency of each raw cell-type label, ordered by label, before standardization.
adata.obs['cell_type'].value_counts(sort=False).sort_index()

# Clean up cell types

In [None]:
# Collapse source-specific labels into standardized cell types. Several labels are
# merged on purpose (e.g. CD4/CD8 -> T_cell; B/Plasma/ProB -> B_cell; the four
# LinNeg* gates -> LinNeg; Refined.HSC -> HSC; cDC/pDC -> Dendritic_cell).
cell_type_map = {
    'B': 'B_cell',
    'CD4': 'T_cell',
    'CD8': 'T_cell',
    'CLP': 'CLP',
    'CMP': 'CMP',
    'EryP': 'EryP',
    'FB': 'Fib',
    'GMP': 'GMP',
    'HSC': 'HSC',
    'LMPP': 'LMPP',
    'LinNegCD34NegCD164high': 'LinNeg',
    'LinNegCD34NegCD164low': 'LinNeg',
    'LinNegCD34PosCD164Pos': 'LinNeg',
    'LinNegCD34lowCD164high': 'LinNeg',
    'MDP': 'MDP',
    'MEP': 'MEP',
    'MKP': 'MKP',
    'MLP': 'MLP',
    'MPP': 'MPP',
    'Mono': 'Mono',
    'NK': 'NK',
    'Plasma': 'B_cell',
    'PreBNK': 'PreBNK',
    'ProB': 'B_cell',
    'Refined.HSC': 'HSC',
    'cDC': 'Dendritic_cell',
    'iHSC': 'iHSC',
    'pDC': 'Dendritic_cell',
}

# Labels absent from the map become NaN under .map() and those cells are then
# dropped by the per-cell-type groupby later on; make that visible instead of silent.
unmapped_labels = sorted(
    set(map(str, adata.obs['cell_type'].dropna().unique())) - set(cell_type_map)
)
if unmapped_labels:
    warnings.warn(f"cell_type labels with no standard mapping (set to NaN): {unmapped_labels}")

adata.obs['cell_type_standard'] = adata.obs['cell_type'].map(cell_type_map)
adata.obs['cell_type_standard'].value_counts().sort_index()

# Gene set selection

In [None]:
# Boolean mask of genes detected in at least 10 cells; inplace=False leaves adata untouched.
gene_mask, counts = sc.pp.filter_genes(
    adata,
    min_cells=10,
    inplace=False,
)

# Flag the top 1000 highly variable genes (adds adata.var['highly_variable']),
# taking the `dataset` column into account as the batch.
# NOTE(review): flavor='seurat_v3' expects raw counts in adata.X — confirm X is
# unnormalized at this point. HVGs are ranked over all genes, not only gene_mask.
sc.pp.highly_variable_genes(
    adata,
    n_top_genes=1000,
    batch_key='dataset',
    flavor='seurat_v3',
)

# Genes passing the min_cells filter (not yet intersected with the HVG flag).
selected_genes = adata.var[gene_mask]
print(f"{selected_genes.shape=}")

print(f"Number highly variable genes: {adata.var['highly_variable'].sum()}")

# The gene set used downstream is the intersection of both criteria, which can be
# smaller than n_top_genes; report it so the printed numbers match the analysis.
n_analysis_genes = int((gene_mask & adata.var['highly_variable'].to_numpy()).sum())
print(f"Genes used downstream (min_cells filter & highly variable): {n_analysis_genes}")

selected_genes.head()

# Checkpoint before preprocessing (no PCA step follows in this notebook)

In [None]:
# Checkpoint: halt "Run All" here so the gene selection above can be inspected
# before the preprocessing and distance computations below.
# Set to False to execute the rest of the notebook.
STOP_AFTER_GENE_SELECTION = True

if STOP_AFTER_GENE_SELECTION:
    # A bare top-level `break` is a SyntaxError; raise an explicit, readable halt instead.
    raise SystemExit("Stopped after gene selection; set STOP_AFTER_GENE_SELECTION = False to continue.")

# Preprocessing

In [None]:
def process_anndata(adata, gene_list=None, target_sum=1e3):
    """
    Normalize an AnnData object (optionally restricted to a gene subset) and
    return its expression matrix as a sparse pandas DataFrame.

    The caller's object is never modified: all work happens on a copy.

    Args:
        adata: The AnnData object (or view) to process.
        gene_list: Optional list of var_names to keep. If None, all genes are kept.
        target_sum: Total each cell is scaled to by ``sc.pp.normalize_total``
            (default 1e3, the value previously hard-coded).

    Returns:
        A sparse DataFrame (cells x genes) indexed by ``obs_names`` with
        ``var_names`` as columns.
    """
    # Subsetting/copying first materializes views and protects the caller's data.
    if gene_list is not None:
        adata = adata[:, gene_list].copy()
    else:
        adata = adata.copy()

    # normalize_total modifies the (already copied) object in place.
    # NOTE(review): with gene_list given, per-cell totals are taken over the
    # selected genes only, not the full transcriptome — confirm this is intended.
    sc.pp.normalize_total(adata, target_sum=target_sum)

    # Ensure a CSR sparse matrix for pandas' sparse constructor.
    sparse_matrix = csr_matrix(adata.X)

    return pd.DataFrame.sparse.from_spmatrix(
        sparse_matrix,
        index=adata.obs_names,
        columns=adata.var_names,
    )


# Analysis gene set: genes passing the min_cells filter that are also highly variable.
# Select var_names (AnnData's variable index) with the boolean mask rather than
# values of the 'gene_name' column: AnnData subsets by var_names, so column values
# can raise KeyError or duplicate columns when they differ from / repeat in var_names.
analysis_gene_mask = gene_mask & adata.var['highly_variable'].to_numpy()
gene_list = adata.var_names[analysis_gene_mask].to_list()

total_mem = 0
matrix = {}  # cell_type_standard -> normalized (cells x genes) sparse DataFrame
# Cells whose cell_type_standard is NaN are skipped (groupby drops NaN keys by default).
for cell_type, group in adata.obs.groupby('cell_type_standard', observed=True):
    X = process_anndata(adata[group.index, :], gene_list=gene_list)

    memory_usage = X.memory_usage(deep=True).sum() / 1024**3
    total_mem += memory_usage

    print(f"{cell_type=} {X.shape=} ({memory_usage:.2f} Gb)")
    matrix[cell_type] = X

print('done.')
print(f'Total Mem: {total_mem:.2f} Gb')

In [None]:
# break

# Pairwise distances between cell types

In [None]:
# Distance metric passed to sklearn's pairwise_distances (alternative: 'cosine').
metric = 'euclidean'
results = []

start_time = time.time()  # Start timing the entire process

def get_pairs(items):
    """
    List every unordered pair of distinct items, followed by each item paired with itself.

    Args:
        items: Any iterable (list, tuple, generator, ...).

    Returns:
        list of tuples: ``combinations(items, 2)`` in its usual order, then ``(x, x)``
        for every item in input order.
    """
    # combinations() consumes its input immediately, so a one-shot iterator would be
    # empty for the self-loop pass; materialize it once up front.
    items = list(items)
    return list(combinations(items, 2)) + [(x, x) for x in items]

# Cell types to compare; swap in the commented-out line to compare every type in `matrix`.
cell_types = [
    'iHSC',
    'HSC',
    'MEP',
    'MKP',
    'MPP',
    'MLP',
    'Fib',
    'CMP',
    'CLP',
    'GMP',
]

# cell_types = list(matrix.keys())
key_pairs = get_pairs(cell_types)
print(f"Considered comparisons: {len(key_pairs)}")


def _summarize_distances(D, same_group):
    """
    Summary statistics of a pairwise distance matrix.

    Args:
        D: 2-D float array from pairwise_distances (rows: cells of type i,
            columns: cells of type j). Modified in place when ``same_group`` is True.
        same_group: True when rows and columns are the same cells (i == j).

    Returns:
        dict with mean/std/median/min/max distance; NaN values if no pairs remain.
    """
    stat_names = ('mean_distance', 'std_distance', 'median_distance', 'min_distance', 'max_distance')
    if same_group:
        # The diagonal is each cell's zero distance to itself: it would drag the
        # mean/median down and pin the minimum at 0. Mask it in place (no extra
        # N x N copy) and use NaN-aware reductions. The remaining symmetric
        # double-counting does not change any of these statistics.
        n_values = D.size - min(D.shape)
        np.fill_diagonal(D, np.nan)
        reducers = (np.nanmean, np.nanstd, np.nanmedian, np.nanmin, np.nanmax)
    else:
        n_values = D.size
        reducers = (np.mean, np.std, np.median, np.min, np.max)
    if n_values == 0:
        # e.g. a cell type with a single cell compared with itself
        return dict.fromkeys(stat_names, np.nan)
    return {name: fn(D) for name, fn in zip(stat_names, reducers)}


rows = []  # local accumulator so this cell can be re-run on its own
for i, j in key_pairs:
    Xi = matrix[i]
    Xj = matrix[j]

    iter_start_time = time.time()  # Start timing this iteration
    D = pairwise_distances(Xi, Xj, metric=metric)
    iter_end_time = time.time()  # End timing this iteration
    elapsed = iter_end_time - iter_start_time
    print(f"Iteration ({i}, {j}) took {elapsed:.2f} seconds")

    rows.append({
        'cell_i': i,
        'cell_j': j,
        'N_i': Xi.shape[0],
        'N_j': Xj.shape[0],
        **_summarize_distances(D, same_group=(i == j)),
        'seconds': elapsed,
    })
    del D  # release the N_i x N_j matrix before allocating the next one

results = pd.DataFrame(rows)

end_time = time.time()  # End timing the entire process
print(f"Total processing time: {end_time - start_time:.2f} seconds")

results.head()

In [None]:
# break

# Visualization

In [None]:
# Square matrix of median distances between cell types. Each unordered pair was
# computed in only one orientation (see key_pairs), so mirror the available
# entries across the diagonal. (Zero-filling and averaging A with A.T would
# halve every off-diagonal value while leaving the diagonal intact.)
A = pd.pivot_table(
    results,
    index='cell_i',
    columns='cell_j',
    values='median_distance',
)
A = A.combine_first(A.T)
labels = sorted(set(A.index) | set(A.columns))
A = A.reindex(index=labels, columns=labels)  # same order on both axes

print(f"{A.shape=}")

fig, ax = plt.subplots(figsize=(10, 10), dpi=300)
sns.heatmap(
    A,
    ax=ax,
    square=True,
    linecolor='k',
    linewidths=1,
    fmt=".3f",
    annot=True,
    cmap='plasma',
    cbar_kws={'shrink': 0.25, 'label' : f'{metric.title()} Distance'},
)

ax.set_title(f"Median pairwise {metric} distance between cell types")
ax.set_ylabel("")
ax.set_xlabel("")
ax.tick_params(axis='y', rotation=0)

plt.show()

In [None]:
# Intentionally empty: a leftover scratch expression unrelated to the analysis was removed.

In [None]:
# End of analysis (a top-level `break` here raised a SyntaxError and has been removed).

In [None]:
# End of notebook (a top-level `break` here raised a SyntaxError and has been removed).