In [None]:
#parameters
# Directory that receives the per-slice .npy/.csv outputs and the combined CSV.
output_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/cell_contacts_15um"
# AnnData (.h5ad) file read with ad.read_h5ad.
path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_PFV8_annotated_v5.h5ad"
# Number of coordinate permutations drawn per slice to build the null distribution.
Np = 1000
# Maximum distance a cell is displaced by when its coordinates are permuted.
r_permute = 100
# Neighbor radius used when counting contacts; also embedded in output file names as "<r_test>um".
r_test = 15
# Forwarded to _get_contacts_for_slice, which currently does not apply them to its results.
alpha = 0.05
min_contacts = 50
# Names of obs columns holding cell-type labels. Only cell_type_col_group is
# referenced by the code visible in this notebook.
cell_type_col_subclass="Subclass"
cell_type_col_group="Group"

In [None]:
from functools import partial
import itertools
import multiprocessing as mp
from pathlib import Path

import anndata as ad
import numpy as np
import pandas as pd
from scipy.spatial import cKDTree
from scipy.stats import norm
from statsmodels.stats.multitest import multipletests
from tqdm import tqdm

In [None]:
# Load the AnnData object from `path` (defined in the parameters cell) and
# display its summary as the cell output.
adata = ad.read_h5ad(path)
adata

In [None]:
# Turn the configured output location into a Path and create the directory
# (including missing parents); an already existing directory is not an error.
output_path = Path(output_path)
output_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Enumerate every (donor, brain_region, replicate) combination of the values
# present in adata.obs; each combination is treated as one slice.
donors, brain_regions, replicates = (
    adata.obs[column].unique().tolist()
    for column in ('donor', 'brain_region', 'replicate')
)
slices = list(itertools.product(donors, brain_regions, replicates))

# These two combinations are excluded from the analysis. list.remove raises
# ValueError if a combination is not in the list.
slices.remove(("UWA7648", "CAT", "salk"))
slices.remove(("UWA7648", "CAT", "ucsd"))

In [None]:
def get_cells_in_radius(adata, center, radius, cols=['CENTER_X', 'CENTER_Y']):
    """Return a copy of the cells lying strictly within `radius` of `center`.

    Parameters
    ----------
    adata
        Object with an `obs` table and boolean-mask indexing (e.g. AnnData).
    center
        Sequence whose first two entries are the x and y of the query point.
    radius
        Search radius; a cell is kept when its squared distance to `center`
        is strictly smaller than `radius ** 2`.
    cols
        Names of the two `obs` columns holding the x and y coordinates.

    Returns
    -------
    A copy of `adata` restricted to the selected cells.
    """
    obs = adata.obs
    # Vectorised squared distance: avoids calling a Python lambda once per row
    # and always yields a boolean mask, including for an empty table.
    dist_sq = (obs[cols[0]] - center[0]) ** 2 + (obs[cols[1]] - center[1]) ** 2
    return adata[dist_sq < radius ** 2].copy()


def get_cell_by_cell_contacts(
    data : "ad.AnnData | pd.DataFrame", 
    cell_type_col = "Subclass", 
    spatial_keys = ['center_x', 'center_y'],
    cell_type_list = None,
    radius = 50,
): 
    """Count spatial contacts between cell types, normalised by cell number.

    Two cells are in contact when their Euclidean distance d satisfies
    1e-4 < d <= radius (cells at practically identical coordinates are not
    counted). Every contacting pair is counted in both directions, so the raw
    count matrix is symmetric and a same-type pair adds 2 to the diagonal.

    Parameters
    ----------
    data
        Per-cell table as a DataFrame, or any object exposing that table as
        `.obs` (e.g. an AnnData object).
    cell_type_col
        Column holding the cell-type label of each cell.
    spatial_keys
        Names of the coordinate columns.
    cell_type_list
        Cell types defining the row/column order of the result. Defaults to
        the sorted unique labels found in `data`. Types listed here but absent
        from `data` get all-zero rows and columns.
    radius
        Contact radius, in the same units as the coordinates.

    Returns
    -------
    norm_counts : np.ndarray of shape (n_types, n_types)
        Entry [i, j] is the raw count for types (i, j) divided by the number
        of cells of type j (the per-type divisor is broadcast across columns).
    cell_type_list
        The cell-type ordering used for both axes.

    Raises
    ------
    KeyError
        If a cell carries a label that is missing from `cell_type_list`.
    """
    # Anything that is not already a DataFrame is expected to carry the
    # per-cell table in `.obs`. The table is only read, so no copy is needed.
    if not isinstance(data, pd.DataFrame):
        data = data.obs

    coords = np.asarray(data[list(spatial_keys)].values, dtype=float)
    cell_types = np.asarray(data[cell_type_col].values)

    if cell_type_list is None:
        cell_type_list = np.unique(cell_types)
    n_types = len(cell_type_list)
    cell_type_to_idx = {ct: i for i, ct in enumerate(cell_type_list)}

    # Integer type code per cell; an unknown label raises KeyError here.
    type_idx = np.fromiter(
        (cell_type_to_idx[ct] for ct in cell_types),
        dtype=np.intp,
        count=len(cell_types),
    )

    contact_counts = np.zeros((n_types, n_types), dtype=int)

    # Cells with non-finite coordinates can never be within `radius` of
    # anything, so they are left out of the spatial index.
    finite_rows = np.flatnonzero(np.isfinite(coords).all(axis=1))
    if finite_rows.size > 1 and radius > 0:
        points = coords[finite_rows]
        # Query with a hair of slack, then apply the exact distance test below
        # so boundary cases are decided by the same norm-based comparison as a
        # brute-force scan would use.
        pairs = cKDTree(points).query_pairs(
            r=radius * (1 + 1e-9), output_type='ndarray'
        )
        if pairs.size:
            dists = np.linalg.norm(points[pairs[:, 0]] - points[pairs[:, 1]], axis=1)
            pairs = pairs[(dists > 1e-4) & (dists <= radius)]
            first = type_idx[finite_rows[pairs[:, 0]]]
            second = type_idx[finite_rows[pairs[:, 1]]]
            # Each unordered pair is tallied once, then mirrored so that both
            # directions (i sees j, j sees i) are counted.
            one_way = np.bincount(
                first * n_types + second, minlength=n_types * n_types
            ).reshape(n_types, n_types)
            contact_counts = one_way + one_way.T

    # Normalize by number of cells per cell type to make cell type specific.
    # The 1-D divisor broadcasts over columns: [i, j] is divided by the number
    # of cells of type j. Types without cells keep a value of 0.
    cells_per_type = np.bincount(type_idx, minlength=n_types)
    norm_counts = np.divide(
        contact_counts,
        cells_per_type,
        out=np.zeros((n_types, n_types), dtype=float),
        where=cells_per_type > 0,
    )

    return norm_counts, cell_type_list
    
# functions from xingjiepan 2023 mouse atlas paper
def adjust_p_value_matrix_by_BH(p_val_mtx):
    '''Adjust the p-values in a square matrix by the Benjamini/Hochberg method.

    For a symmetric matrix each unordered pair is one hypothesis: only the
    upper triangle (diagonal included) is corrected and the adjusted values
    are mirrored into the lower triangle.

    If the matrix is not symmetric, [i, j] and [j, i] are different
    hypotheses, so every entry is corrected instead of discarding the lower
    triangle.

    Returns an (N, N) float array of adjusted p-values.
    '''
    p_val_mtx = np.asarray(p_val_mtx, dtype=float)
    N = p_val_mtx.shape[0]

    if p_val_mtx.size == 0:
        # Nothing to correct.
        return np.zeros((N, N))

    # Relative tolerance only, so floating-point noise still counts as
    # symmetric while genuinely different tiny p-values do not.
    is_symmetric = np.allclose(
        p_val_mtx, p_val_mtx.T, rtol=1e-6, atol=0.0, equal_nan=True
    )

    if not is_symmetric:
        adjusted_flat = multipletests(p_val_mtx.ravel(), method='fdr_bh')[1]
        return np.asarray(adjusted_flat).reshape(p_val_mtx.shape)

    # triu_indices walks the upper triangle row by row: (0,0), (0,1), ...
    rows, cols = np.triu_indices(N)
    adjusted_upper = multipletests(p_val_mtx[rows, cols], method='fdr_bh')[1]

    adjusted_p_val_mtx = np.zeros((N, N))
    adjusted_p_val_mtx[rows, cols] = adjusted_upper
    adjusted_p_val_mtx[cols, rows] = adjusted_upper

    return adjusted_p_val_mtx

def get_data_frame_from_metrices(cell_types, mtx_dict):
    '''Flatten square per-pair matrices into a long-format data frame.

    One row is produced for every ordered (cell_type1, cell_type2) pair, in
    row-major order. Each key of mtx_dict becomes a column holding the
    corresponding matrix entry.
    '''
    n_types = len(cell_types)
    index_pairs = [(row, col) for row in range(n_types) for col in range(n_types)]

    columns = {
        'cell_type1': [cell_types[row] for row, _ in index_pairs],
        'cell_type2': [cell_types[col] for _, col in index_pairs],
    }
    for name, mtx in mtx_dict.items():
        columns[name] = [mtx[row, col] for row, col in index_pairs]

    return pd.DataFrame(columns)
    

def sort_cell_type_contact_p_values(p_val_mtx, cell_types):
    '''List every ordered cell-type pair with its p-value, smallest p first.

    Returns (cell_type1, cell_type2, p_value) tuples. Pairs with equal
    p-values keep their row-major matrix order.
    '''
    size = p_val_mtx.shape[0]
    triples = [
        (cell_types[row], cell_types[col], p_val_mtx[row, col])
        for row in range(size)
        for col in range(size)
    ]
    triples.sort(key=lambda triple: triple[2])
    return triples

In [None]:
def _permute_and_get_contacts(i, adata_sub, r_permute, r_test, cell_types=None):
    """Run one spatial permutation and return its normalised contact matrix.

    Every cell is displaced by a random offset drawn uniformly from a disc of
    radius `r_permute` (the sqrt of a uniform radius gives uniform density
    over the disc area). Contacts are then counted with radius `r_test`.

    Parameters
    ----------
    i
        Permutation index supplied by the mapping call; not used.
    adata_sub
        Object whose `.obs` table holds 'CENTER_X', 'CENTER_Y' and the
        cell-type column named by the module-level `cell_type_col_group`.
    r_permute
        Maximum displacement applied to a cell.
    r_test
        Contact radius passed on to get_cell_by_cell_contacts.
    cell_types
        Cell-type ordering for the returned matrix (None: derived from data).

    Returns
    -------
    The normalised contact matrix from get_cell_by_cell_contacts.
    """
    # A local generator seeded from fresh OS entropy gives each call (and each
    # worker process) its own random stream without resetting NumPy's global
    # random state.
    rng = np.random.default_rng()

    # Only these three columns are read downstream, so only they are copied.
    df_slide = adata_sub.obs[[cell_type_col_group, 'CENTER_X', 'CENTER_Y']].copy()
    n_cells = df_slide.shape[0]

    r = r_permute * np.sqrt(rng.uniform(size=n_cells))
    theta = rng.uniform(size=n_cells) * 2 * np.pi
    df_slide['CENTER_X'] = df_slide['CENTER_X'] + r * np.cos(theta)
    df_slide['CENTER_Y'] = df_slide['CENTER_Y'] + r * np.sin(theta)

    norm_counts, _ = get_cell_by_cell_contacts(
        df_slide,
        cell_type_col=cell_type_col_group,
        spatial_keys=['CENTER_X', 'CENTER_Y'],
        cell_type_list=cell_types,
        radius=r_test,
    )
    return norm_counts

def _get_contacts_for_slice(
    _slice,
    adata,
    Np = 1000,
    r_permute = 100,
    r_test = 15,
    alpha = 0.05,
    min_contacts = 50,
): 
    """Test cell-type contacts in one slice against a permutation null.

    The observed normalised contact matrix is compared with `Np` matrices
    obtained after randomly displacing every cell. The z-scores, p-values and
    fold changes are written to disk and returned.

    Parameters
    ----------
    _slice
        Tuple of (donor, brain_region, replicate) identifying the slice.
    adata
        AnnData object containing all data; `obs` must hold 'donor',
        'brain_region', 'replicate', 'CENTER_X', 'CENTER_Y' and the column
        named by the module-level `cell_type_col_group`.
    Np
        Number of permutations.
    r_permute
        Max radius for the random displacement in a permutation.
    r_test
        Radius for the contact calculation.
    alpha, min_contacts
        Accepted for interface compatibility; no filtering is applied to the
        returned table.

    Returns
    -------
    pd.DataFrame with one row per ordered cell-type pair, sorted by
    descending z-score. A slice without cells returns an empty frame with the
    same columns and writes no files.

    Notes
    -----
    Reads the module-level `output_path` and `cell_type_col_group`. Writes
    four .npy arrays and one .csv per slice into `output_path`.
    """
    donor, brain_region, replicate = _slice
    out_dir = Path(output_path)
    file_tag = f"{donor}_{brain_region}_{replicate}_{r_test}um"

    # Select the slice and drop cells labelled "unknown" in a single step so
    # the (potentially large) AnnData is copied only once.
    obs = adata.obs
    in_slice = (
        (obs['donor'] == donor)
        & (obs['brain_region'] == brain_region)
        & (obs['replicate'] == replicate)
        & (obs[cell_type_col_group] != "unknown")
    )
    adata_sub = adata[in_slice].copy()

    if adata_sub.n_obs == 0:
        # The slice list may contain combinations with no cells; return an
        # empty result rather than failing in the statistics below.
        return pd.DataFrame(columns=[
            'cell_type1', 'cell_type2', 'pval-adjusted', 'pval', 'z_score',
            'contact_count', 'permutation_mean', 'permutation_std',
            'fold-change', 'id',
        ])

    cell_types = np.unique(adata_sub.obs[cell_type_col_group])
    real_contacts, _ = get_cell_by_cell_contacts(
        adata_sub,
        cell_type_col=cell_type_col_group,
        spatial_keys=['CENTER_X', 'CENTER_Y'],
        cell_type_list=cell_types,
        radius=r_test,
    )

    # NOTE(review): adata_sub is shipped to the workers inside the partial;
    # only its obs table is read there, so passing a lighter object would cut
    # the transfer cost.
    permute_once = partial(
        _permute_and_get_contacts,
        adata_sub=adata_sub,
        r_permute=r_permute,
        r_test=r_test,
        cell_types=cell_types,
    )
    with mp.Pool(mp.cpu_count()) as pool: 
        merged_contact_counts = np.array(pool.map(permute_once, range(Np)))

    null_contacts_mean = np.mean(merged_contact_counts, axis=0)
    null_contacts_std = np.std(merged_contact_counts, axis=0)

    np.save(out_dir / f"contact_counts_real_{file_tag}.npy", real_contacts)
    np.save(out_dir / f"contact_counts_permuted_{file_tag}.npy", merged_contact_counts)
    np.save(out_dir / f"contact_counts_permuted_mean_{file_tag}.npy", null_contacts_mean)
    np.save(out_dir / f"contact_counts_permuted_std_{file_tag}.npy", null_contacts_std)

    # Floor the null std so that pairs with (near) zero permutation variance do
    # not produce infinite z-scores. The saved .npy above keeps the raw std.
    # NOTE(review): the 1000 is a fixed constant, not tied to Np - confirm.
    null_contacts_std = np.maximum(null_contacts_std, np.sqrt(1/1000))
    permuted_z_score = (real_contacts - null_contacts_mean) / null_contacts_std

    # Both enrichment and depletion are of interest (|z| is tested), so the
    # p-value is two-sided: twice the upper-tail probability of |z|.
    local_p_values = 2 * norm.sf(np.abs(permuted_z_score))
    adjusted_local_p_value = adjust_p_value_matrix_by_BH(local_p_values)
    fold_changes = real_contacts / (null_contacts_mean + 1e-6)

    # Gather all results into a data frame
    contact_result_df = get_data_frame_from_metrices(
        cell_types,
        {
            'pval-adjusted': adjusted_local_p_value,
            'pval': local_p_values,
            'z_score': permuted_z_score,
            'contact_count': real_contacts,
            'permutation_mean': null_contacts_mean,
            'permutation_std': null_contacts_std,
            'fold-change' : fold_changes,
        },
    ).sort_values('z_score', ascending=False)
    contact_result_df['id'] = f"{donor}_{brain_region}_{replicate}"
    contact_result_df.to_csv(out_dir / f"cell_contacts_{file_tag}.csv", index=False)
    return contact_result_df

In [None]:
# Keyword arguments shared by every per-slice call.
function_args = dict(
    Np=Np,
    r_permute=r_permute,
    r_test=r_test,
    alpha=alpha,
    min_contacts=min_contacts,
    adata=adata,
)

# Slices are processed one after another: _get_contacts_for_slice already
# opens its own multiprocessing pool for the permutations.
contacts_list = []
for _slice in tqdm(slices): 
    contacts_list.append(_get_contacts_for_slice(_slice, **function_args))

In [None]:
# Stack the per-slice result tables and write them as a single CSV in the
# output directory; the file name carries the contact radius r_test.
pd.concat(contacts_list).to_csv(Path(output_path) / f"cell_type_contacts_{r_test}um_all_slices.csv", index=False)