In [None]:
import os
from pathlib import Path
import itertools 

import numpy as np
import pandas as pd
import anndata as ad
from statsmodels.stats.multitest import multipletests
from scipy.stats import norm

import multiprocessing as mp
from functools import partial

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from spida.pl import plot_categorical, plot_continuous
plt.rcParams['figure.dpi'] = 200
plt.rcParams['axes.facecolor'] = 'white'

In [None]:
# Read the annotated AnnData object (.h5ad) that every later cell slices from.
# NOTE(review): this is an absolute, user-specific path — it only resolves on the
# original author's machine; consider a configurable data directory.
path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPSAM_annotated_v2.h5ad"
# new_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CP_annotated_v1.h5ad"
adata = ad.read_h5ad(path)
# adata.write_h5ad(new_path)

In [None]:
adata.obs['brain_region'].unique()

In [None]:
# Distinct values of each slice key (kept as separate lists for later cells).
donors = adata.obs['donor'].unique().tolist()
brain_regions = adata.obs['brain_region'].unique().tolist()
replicates = adata.obs['replicate'].unique().tolist()

# Slices that are deliberately left out of the analysis.
excluded_slices = {
    ("UWA7648", "CAT", "salk"),
    ("UWA7648", "CAT", "ucsd"),
}

# Enumerate only the (donor, brain_region, replicate) combinations that actually
# occur in the data. A full Cartesian product can contain combinations with no
# cells at all, and `list.remove` raises ValueError if an excluded tuple is absent.
# Order is first occurrence in `adata.obs`, not product order.
observed_slices = (
    adata.obs[['donor', 'brain_region', 'replicate']]
    .drop_duplicates()
    .itertuples(index=False, name=None)
)
slices = [s for s in observed_slices if s not in excluded_slices]

In [None]:
_donor, _brain_region, _replicate = "UCI5224", "PU", "salk"

In [None]:
# Boolean mask over all cells: True for the chosen donor / brain region / replicate.
_slice_mask = (
    adata.obs['donor'].eq(_donor)
    & adata.obs['brain_region'].eq(_brain_region)
    & adata.obs['replicate'].eq(_replicate)
)
# Materialise the selection as an independent AnnData (not a view of `adata`).
adata_sub = adata[_slice_mask].copy()

In [None]:
# adata_sub = adata_sub[(adata_sub.obs['AIT_Subclass'] != "unknown") & (adata_sub.obs['AIT_Group'] != "unknown")].copy()

In [None]:
adata_sub

# Permutation Testing

### functions

In [None]:
# TODO: this uses centroid-to-centroid distance only; a version based on the true
# cell geometries still needs to be implemented.
def get_cells_in_radius(adata, center, radius, cols=['CENTER_X', 'CENTER_Y']):
    """Return a copy of `adata` restricted to cells strictly inside a circle.

    Parameters
    ----------
    adata : object with an `.obs` table, boolean-mask indexing and `.copy()`
        (e.g. an AnnData).
    center : sequence of two numbers, the (x, y) centre of the circle.
    radius : circle radius, in the same units as the coordinate columns.
    cols : names of the x and y coordinate columns in `adata.obs`.

    Returns
    -------
    A copy of the subset whose centre satisfies
    (x - cx)^2 + (y - cy)^2 < radius^2. Cells exactly on the circle are excluded.
    """
    x_col, y_col = cols[0], cols[1]
    # Vectorised squared distance; avoids a Python-level call per row and also
    # behaves sensibly when `adata.obs` is empty.
    dx = adata.obs[x_col].to_numpy(dtype=float) - center[0]
    dy = adata.obs[y_col].to_numpy(dtype=float) - center[1]
    inside = dx**2 + dy**2 < radius**2
    return adata[inside].copy()


def get_cell_by_cell_contacts(
    data : "ad.AnnData | pd.DataFrame",
    cell_type_col = "Subclass",
    spatial_keys = ['center_x', 'center_y'],
    cell_type_list = None,
    radius = 50,
):
    """Count contacts between cell types and normalise by cells per type.

    Two cells are in contact when the Euclidean distance between their centres
    is greater than 1e-4 (which excludes the cell itself) and at most `radius`.

    Parameters
    ----------
    data : per-cell DataFrame, or any object exposing one as `.obs`
        (e.g. an AnnData).
    cell_type_col : column holding the cell-type label of each cell.
    spatial_keys : names of the two coordinate columns.
    cell_type_list : ordered cell types defining the rows/columns of the
        result; defaults to the sorted unique labels found in `data`.
    radius : contact radius, in the units of the coordinates.

    Returns
    -------
    (norm_counts, cell_type_list)
        `norm_counts[a, b]` is the number of (cell of type a, neighbour of
        type b) pairs divided by the number of cells of type b. Types in
        `cell_type_list` with no cells in `data` get all-zero entries.

    Raises
    ------
    KeyError if a cell carries a label that is not in `cell_type_list`.
    """
    obs = data if isinstance(data, pd.DataFrame) else data.obs

    if cell_type_list is None:
        cell_type_list = np.unique(obs[cell_type_col])
    n_types = len(cell_type_list)

    coords = obs[list(spatial_keys)].to_numpy(dtype=float)
    type_to_idx = {ct: k for k, ct in enumerate(cell_type_list)}
    # Integer type code for every cell, resolved once up front.
    type_idx = np.fromiter(
        (type_to_idx[ct] for ct in obs[cell_type_col].values),
        dtype=np.intp,
        count=coords.shape[0],
    )

    contact_counts = np.zeros((n_types, n_types), dtype=int)
    for i in range(coords.shape[0]):
        dists = np.linalg.norm(coords - coords[i], axis=1)
        neighbors = (dists > 1e-4) & (dists <= radius)
        # Histogram of neighbour types, added to the row of the centre cell's type.
        contact_counts[type_idx[i]] += np.bincount(type_idx[neighbors], minlength=n_types)

    # Normalise by the number of cells per type, counted from the table that was
    # passed in (not from any notebook-level variable). The 1-D divisor
    # broadcasts along the last axis, i.e. column b is divided by the count of b.
    cells_per_type = np.bincount(type_idx, minlength=n_types)
    norm_counts = np.divide(
        contact_counts,
        cells_per_type,
        out=np.zeros(contact_counts.shape, dtype=float),
        where=cells_per_type > 0,  # absent types stay 0 instead of dividing by zero
    )

    return norm_counts, cell_type_list
    
def get_cell_contacts(
    data : "ad.AnnData | pd.DataFrame",
    cell_type_col = "Subclass",
    spatial_keys = ['center_x', 'center_y'],
    cell_type_list = None,
    radius = 50,
):
    """Count, for every cell, its neighbours of each cell type.

    A neighbour is any other cell whose centre lies at a Euclidean distance
    greater than 0 and at most `radius` from the cell's own centre.

    Parameters
    ----------
    data : per-cell DataFrame, or any object exposing one as `.obs`
        (e.g. an AnnData).
    cell_type_col : column holding the cell-type label of each cell.
    spatial_keys : names of the two coordinate columns.
    cell_type_list : ordered cell types defining the columns of the result;
        defaults to the sorted unique labels found in `data`.
    radius : neighbourhood radius, in the units of the coordinates.

    Returns
    -------
    (contact_counts, cell_type_list)
        `contact_counts` is an integer array of shape (n_cells, n_types);
        row i holds the neighbour-type histogram of cell i.

    Raises
    ------
    KeyError if a cell carries a label that is not in `cell_type_list`.
    """
    obs = data if isinstance(data, pd.DataFrame) else data.obs

    if cell_type_list is None:
        cell_type_list = np.unique(obs[cell_type_col])
    n_types = len(cell_type_list)

    coords = obs[list(spatial_keys)].to_numpy(dtype=float)
    n_cells = coords.shape[0]
    type_to_idx = {ct: k for k, ct in enumerate(cell_type_list)}
    # Integer type code for every cell, resolved once up front.
    type_idx = np.fromiter(
        (type_to_idx[ct] for ct in obs[cell_type_col].values),
        dtype=np.intp,
        count=n_cells,
    )

    contact_counts = np.zeros((n_cells, n_types), dtype=int)
    for i in range(n_cells):
        dists = np.linalg.norm(coords - coords[i], axis=1)
        neighbors = (dists > 0) & (dists <= radius)
        contact_counts[i] = np.bincount(type_idx[neighbors], minlength=n_types)

    return contact_counts, cell_type_list

# functions from xingjiepan 2023 mouse atlas paper
def adjust_p_value_matrix_by_BH(p_val_mtx):
    '''Benjamini/Hochberg-adjust the p-values of a square matrix.

    Only the upper triangle (diagonal included) is read; each adjusted value
    is written to both [i, j] and [j, i], so the input is expected to be
    symmetric.
    '''
    n = p_val_mtx.shape[0]
    # Row-major walk over the upper triangle: (0,0), (0,1), ..., (1,1), ...
    rows, cols = np.triu_indices(n)
    upper_p = [p_val_mtx[r, c] for r, c in zip(rows, cols)]

    upper_adj = multipletests(upper_p, method='fdr_bh')[1]

    adjusted = np.zeros((n, n))
    adjusted[rows, cols] = upper_adj
    adjusted[cols, rows] = upper_adj
    return adjusted

def get_data_frame_from_metrices(cell_types, mtx_dict):
    '''Flatten per-pair matrices into a long table.

    One row is produced per unordered cell-type pair (i <= j, in row-major
    order). Columns are 'cell_type1', 'cell_type2' and one column per key of
    `mtx_dict` holding that matrix's [i, j] entry.
    '''
    n_types = len(cell_types)
    pairs = [(a, b) for a in range(n_types) for b in range(a, n_types)]

    columns = {
        'cell_type1': [cell_types[a] for a, _ in pairs],
        'cell_type2': [cell_types[b] for _, b in pairs],
    }
    for key, mtx in mtx_dict.items():
        columns[key] = [mtx[a, b] for a, b in pairs]

    return pd.DataFrame(columns)
    

def sort_cell_type_contact_p_values(p_val_mtx, cell_types):
    '''Return (cell_type1, cell_type2, p_value) tuples in ascending p-value order.

    Only pairs from the upper triangle of `p_val_mtx` (diagonal included) are listed.
    '''
    n = p_val_mtx.shape[0]
    triples = [
        (cell_types[r], cell_types[c], p_val_mtx[r, c])
        for r in range(n)
        for c in range(r, n)
    ]
    # Stable sort: ties keep their row-major order.
    triples.sort(key=lambda entry: entry[2])
    return triples

### image ex

In [None]:
radius = 50
fig, axes = plt.subplots(1, 2, figsize=(8,4))
# Left panel: the whole slice, with the query circle drawn on top.
ax=axes[0]
plot_categorical(adata_sub, cluster_col='Group', coord_base="spatial", ax=ax, show=False)
# Fixed random_state so the same example cell (and figure) is drawn on every run.
center = adata_sub.obs.sample(random_state=0)[['CENTER_X', 'CENTER_Y']].values[0]
ax.add_patch(mpatches.Circle(center, radius, color='red', fill=False))
# Right panel: only the cells whose centre falls inside the circle.
ax=axes[1]
adata_sub_sub = get_cells_in_radius(adata_sub, center, radius)
plot_categorical(adata_sub_sub, cluster_col='Group', coord_base="spatial", ax=ax, show=False)
plt.show()

### Get cell contacts

In [None]:
# Annotation column used as the cell-type label for this contact computation.
level = "Group"
# Missing labels are replaced in place by "unknown", so they form their own type.
# NOTE(review): if this column is categorical without an "unknown" category,
# fillna presumably raises — confirm the dtype.
adata_sub.obs[level] = adata_sub.obs[level].fillna("unknown")
cell_types = np.unique(adata_sub.obs[level])
# Returns a tuple: (normalised type-by-type contact matrix, cell type list).
cell_contacts = get_cell_by_cell_contacts(adata_sub, cell_type_col=level, spatial_keys=['CENTER_X', 'CENTER_Y'], cell_type_list=cell_types, radius=30)

In [None]:
cts = cell_contacts[1]
conts = cell_contacts[0]

In [None]:
conts.shape

In [None]:
N_cells = adata_sub.obs.groupby(level, observed=True).size().to_dict()
N_cells

In [None]:
divide = np.asarray([N_cells[i] for i in cts])
divide

In [None]:
# NOTE(review): `conts` is the first element returned by get_cell_by_cell_contacts,
# which already divides by the per-type cell counts — this divides a second time.
norm_counts = conts / divide
norm_counts.shape

In [None]:
conts.shape

In [None]:
plt.imshow(conts)

In [None]:
plt.imshow(conts / divide)

In [None]:
# for i in cts: 
#     print(i, N_cells[i], np.divide(conts[:, i], N_cells[i]))

In [None]:
cell_contacts[0]

In [None]:
# One row per cell, one column per cell type: this needs the per-cell counts from
# get_cell_contacts. The type-by-type matrix in `cell_contacts[0]` has one row per
# cell type and cannot be indexed by `adata_sub.obs_names`.
per_cell_counts, per_cell_types = get_cell_contacts(adata_sub, cell_type_col=level, spatial_keys=['CENTER_X', 'CENTER_Y'], cell_type_list=cell_types, radius=30)
df_contacts = pd.DataFrame(per_cell_counts, columns=per_cell_types, index=adata_sub.obs_names)

In [None]:
df_contacts.sum(axis=0).sort_values(ascending=False)

In [None]:
sns.histplot(df_contacts.sum(axis=1))

### get real contacts

In [None]:
adata_sub.obs['Group'] = adata_sub.obs['Group'].fillna("unknown")

In [None]:
cell_types = np.unique(adata_sub.obs['Group'])

In [None]:
contact_counts = get_cell_by_cell_contacts(adata_sub, cell_type_col='Group', spatial_keys=['CENTER_X', 'CENTER_Y'], cell_type_list=cell_types, radius=15)

In [None]:
# Label the type-by-type contact matrix with cell type names on both axes.
df_contacts = pd.DataFrame(contact_counts[0], index=contact_counts[1], columns=contact_counts[1])
# Rescale every column by its own sum, so each column of `to_plot` sums to 1.
to_plot = df_contacts / np.sum(df_contacts, axis=0)

In [None]:
# Heatmap of the column-normalised real contact matrix, ticks labelled by cell type.
fig, ax = plt.subplots()
img = ax.imshow(to_plot, cmap='cividis')
# One tick per matrix column / row, in the same order as the DataFrame labels.
ax.set_xticks(np.arange(len(to_plot.columns)))
ax.set_yticks(np.arange(len(to_plot.index)))
ax.set_xticklabels(to_plot.columns, rotation=45, ha='right', fontsize=6)
ax.set_yticklabels(to_plot.index, fontsize=6)
plt.colorbar(img, ax=ax, fraction=0.046, pad=0.04)
plt.show()

### get perturbed contacts

In [None]:
Np = 1000          # number of permutations
r_permute = 100    # maximum displacement applied to each cell centre
# Must be the same label column as the real contacts computed above, otherwise
# the null distribution is not comparable (and labels missing from `cell_types`
# raise a KeyError).
perm_cell_type_col = 'Group'
N_cell_types = len(cell_types)
# Seeded generator so the null distribution is reproducible.
rng = np.random.default_rng(0)
base_obs = adata_sub.obs[[perm_cell_type_col, 'CENTER_X', 'CENTER_Y']]
# Float buffer: get_cell_by_cell_contacts returns normalised (fractional) counts,
# which an integer buffer would silently truncate.
merged_contact_counts = np.zeros((Np, N_cell_types, N_cell_types), dtype=float)
for i in range(Np):
    df_slide = base_obs.copy()
    # Uniform random displacement inside a disc of radius r_permute
    # (sqrt of a uniform variate gives uniform density over the disc area).
    r = r_permute * np.sqrt(rng.uniform(size=df_slide.shape[0]))
    theta = rng.uniform(size=df_slide.shape[0]) * 2 * np.pi
    df_slide['CENTER_X'] = df_slide['CENTER_X'] + r * np.cos(theta)
    df_slide['CENTER_Y'] = df_slide['CENTER_Y'] + r * np.sin(theta)
    contacts = get_cell_by_cell_contacts(df_slide, cell_type_col=perm_cell_type_col, spatial_keys=['CENTER_X', 'CENTER_Y'], cell_type_list=cell_types, radius=15)
    merged_contact_counts[i] = contacts[0]


In [None]:
merged_contact_counts_mean = np.mean(merged_contact_counts, axis=0)

In [None]:
# Mean permuted contact matrix, labelled with the cell types returned by the last
# permutation call, then column-normalised the same way as the real matrix.
df_contacts_perm = pd.DataFrame(merged_contact_counts_mean, index=contacts[1], columns=contacts[1])
to_plot = df_contacts_perm / np.sum(df_contacts_perm, axis=0)
# Same heatmap layout as for the real contacts.
fig, ax = plt.subplots()
img = ax.imshow(to_plot, cmap='cividis')
ax.set_xticks(np.arange(len(to_plot.columns)))
ax.set_yticks(np.arange(len(to_plot.index)))
ax.set_xticklabels(to_plot.columns, rotation=45, ha='right', fontsize=6)
ax.set_yticklabels(to_plot.index, fontsize=6)
plt.colorbar(img, ax=ax, fraction=0.046, pad=0.04)
plt.show()

### save

In [None]:
output_path = Path("/home/x-aklein2/projects/aklein/BICAN/BG/data/cell_contacts_15um")
output_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Persist the results for the current slice; file names embed donor, brain region
# and replicate, plus a fixed "15um" suffix.
np.save(Path(output_path) / f"contact_counts_real_{_donor}_{_brain_region}_{_replicate}_15um.npy", contact_counts[0])
# Full stack of permuted matrices (first axis = permutation).
np.save(Path(output_path) / f"contact_counts_permuted_{_donor}_{_brain_region}_{_replicate}_15um.npy", merged_contact_counts)
# Per-entry mean and standard deviation across permutations.
permuted_means = np.mean(merged_contact_counts, axis=0)
permuted_std = np.std(merged_contact_counts, axis=0)
np.save(Path(output_path) / f"contact_counts_permuted_mean_{_donor}_{_brain_region}_{_replicate}_15um.npy", permuted_means)
np.save(Path(output_path) / f"contact_counts_permuted_std_{_donor}_{_brain_region}_{_replicate}_15um.npy", permuted_std)

In [None]:
# Reload the four arrays written by the save cell for the current slice.
real_contacts = np.load(Path(output_path) / f"contact_counts_real_{_donor}_{_brain_region}_{_replicate}_15um.npy")
null_contacts = np.load(Path(output_path) / f"contact_counts_permuted_{_donor}_{_brain_region}_{_replicate}_15um.npy")
null_contacts_mean = np.load(Path(output_path) / f"contact_counts_permuted_mean_{_donor}_{_brain_region}_{_replicate}_15um.npy")
null_contacts_std = np.load(Path(output_path) / f"contact_counts_permuted_std_{_donor}_{_brain_region}_{_replicate}_15um.npy")

### p-testing

In [None]:
# Floor the null standard deviation so entries with (near-)zero spread do not
# produce infinite z-scores. The floor is tied to the number of permutations that
# were actually loaded instead of a hard-coded 1000.
n_perm = null_contacts.shape[0]
null_contacts_std = np.maximum(null_contacts_std, np.sqrt(1 / n_perm))
# z-score of the real contacts against the permutation null, then the normal
# upper-tail probability of |z|.
permuted_z_score = (real_contacts - null_contacts_mean) / null_contacts_std
local_p_values = norm.sf(np.abs(permuted_z_score))
# NOTE(review): the BH helper reads only the upper triangle and mirrors it,
# i.e. it expects a symmetric matrix — confirm that holds for these inputs.
adjusted_local_p_value = adjust_p_value_matrix_by_BH(local_p_values)
# Small epsilon avoids division by zero where the null mean is 0.
fold_changes = real_contacts / (null_contacts_mean + 1e-6)
# Gather all results into a data frame
contact_result_df = get_data_frame_from_metrices(cell_types, 
                                                 {'pval-adjusted': adjusted_local_p_value,
                                                  'pval': local_p_values,
                                                  'z_score': permuted_z_score,
                                                  'contact_count': real_contacts,
                                                  'permutation_mean': null_contacts_mean,
                                                  'permutation_std': null_contacts_std,
                                                  'fold-change' : fold_changes,
                                        }).sort_values('z_score', ascending=False)

In [None]:
# Keep only pairs that pass the adjusted p-value cut-off and the contact threshold.
# The frame is overwritten in place, so the unfiltered table is gone after this cell.
contact_result_df = contact_result_df[contact_result_df['pval-adjusted'] < 0.05]
# NOTE(review): 'contact_count' holds `real_contacts`; if those are the normalised
# values returned by get_cell_by_cell_contacts, a threshold of 50 may remove
# nearly every row — confirm the intended scale.
contact_result_df = contact_result_df[contact_result_df['contact_count'] > 50]
contact_result_df

### image contacts

In [None]:
# Pick one cell-type pair and pull out its null distribution and observed value.
_ct1 = "Astrocyte"
_ct2 = "STR D2 MSN"
# Positions of the two types in `cell_types`; the [0][0] lookup raises IndexError
# if a name is not present.
i = np.where(cell_types == _ct1)[0][0]
j = np.where(cell_types == _ct2)[0][0]
# One value per permutation for entry (i, j), and the corresponding real value.
null_dist = null_contacts[:, i, j]
real_count = real_contacts[i, j]

In [None]:
# Histogram of the permutation null for the selected pair, with the observed
# value marked by a dashed red line.
fig, ax = plt.subplots(figsize=(4, 4))
ax.hist(null_dist, bins=30, color='lightgrey', density=True)
ax.axvline(real_count, color='red', linestyle='--', label='Real Count')
ax.set_xlabel(f'Contact Counts between {_ct1} and {_ct2}')
ax.set_ylabel('Density')
ax.set_title(f'Contact Counts Distribution\n{_donor}, {_brain_region}, {_replicate}')
# Legend is anchored outside the axes, to the right of the plot.
ax.legend(loc='upper right', bbox_to_anchor=(1.6, 1), fontsize=12)
plt.show()

## Chain together for all slices

In [None]:
### TODO:
# For each slice (possibly in parallel with multiprocessing):
# - calculate the real contacts
# - calculate the null (permuted) contacts
# - save the outputs (merging across slices can be done afterwards)

In [None]:
def _get_contacts_for_slice(
    _slice,
    adata,
    Np = 1000,
    r_permute = 100,
    r_test = 15,
    alpha = 0.05,
    min_contacts = 50,
    output_dir = None,
    seed = None,
):
    """Run the contact permutation test for one slice and save its arrays.

    Parameters
    ----------
    _slice : tuple of (donor, brain_region, replicate).
    adata : AnnData object containing all slices.
    Np : number of permutations (must be >= 1).
    r_permute : maximum displacement applied to each cell centre per permutation.
    r_test : contact radius passed to `get_cell_by_cell_contacts`.
    alpha : cut-off on the BH-adjusted p-value.
    min_contacts : rows are kept only if 'contact_count' exceeds this value.
    output_dir : directory for the .npy outputs; defaults to the notebook-level
        `output_path`.
    seed : seed for the random generator; None draws fresh entropy per call.

    Returns
    -------
    DataFrame with one row per retained cell-type pair, sorted by descending
    z-score, with an 'id' column of the form "<donor>_<brain_region>_<replicate>".

    Raises
    ------
    ValueError if `Np` is smaller than 1.
    """
    if Np < 1:
        raise ValueError("Np must be at least 1")

    donor, brain_region, replicate = _slice
    in_slice = (
        (adata.obs['donor'] == donor)
        & (adata.obs['brain_region'] == brain_region)
        & (adata.obs['replicate'] == replicate)
    )
    adata_sub = adata[in_slice].copy()
    # Drop cells without a usable annotation at either level.
    adata_sub = adata_sub[(adata_sub.obs['AIT_Subclass'] != "unknown") & (adata_sub.obs['AIT_Group'] != "unknown")].copy()
    cell_types = np.unique(adata_sub.obs['AIT_Subclass'])
    N_cell_types = len(cell_types)

    contact_kwargs = dict(
        cell_type_col='AIT_Subclass',
        spatial_keys=['CENTER_X', 'CENTER_Y'],
        cell_type_list=cell_types,
        radius=r_test,
    )
    real_contacts, _ = get_cell_by_cell_contacts(adata_sub, **contact_kwargs)

    # Local generator: forked pool workers inherit the same global numpy RNG
    # state, so drawing from np.random.* would give every worker the same stream.
    rng = np.random.default_rng(seed)
    base_obs = adata_sub.obs[['AIT_Subclass', 'CENTER_X', 'CENTER_Y']]
    n_cells = base_obs.shape[0]
    # Float buffer: the contact matrices are normalised (fractional) values that
    # an integer buffer would silently truncate.
    merged_contact_counts = np.zeros((Np, N_cell_types, N_cell_types), dtype=float)
    for i in range(Np):
        # Uniform random displacement inside a disc of radius r_permute.
        r = r_permute * np.sqrt(rng.uniform(size=n_cells))
        theta = rng.uniform(size=n_cells) * 2 * np.pi
        df_slide = base_obs.assign(
            CENTER_X=base_obs['CENTER_X'] + r * np.cos(theta),
            CENTER_Y=base_obs['CENTER_Y'] + r * np.sin(theta),
        )
        merged_contact_counts[i] = get_cell_by_cell_contacts(df_slide, **contact_kwargs)[0]
    null_contacts_mean = np.mean(merged_contact_counts, axis=0)
    null_contacts_std = np.std(merged_contact_counts, axis=0)

    # File names are built from THIS slice and THIS radius, so slices never
    # overwrite each other's outputs.
    out_dir = Path(output_path if output_dir is None else output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    tag = f"{donor}_{brain_region}_{replicate}_{r_test:g}um"
    np.save(out_dir / f"contact_counts_real_{tag}.npy", real_contacts)
    np.save(out_dir / f"contact_counts_permuted_{tag}.npy", merged_contact_counts)
    np.save(out_dir / f"contact_counts_permuted_mean_{tag}.npy", null_contacts_mean)
    np.save(out_dir / f"contact_counts_permuted_std_{tag}.npy", null_contacts_std)

    # Floor the null std (tied to the number of permutations) to avoid infinite z-scores.
    null_contacts_std = np.maximum(null_contacts_std, np.sqrt(1 / Np))
    permuted_z_score = (real_contacts - null_contacts_mean) / null_contacts_std
    local_p_values = norm.sf(np.abs(permuted_z_score))
    adjusted_local_p_value = adjust_p_value_matrix_by_BH(local_p_values)
    fold_changes = real_contacts / (null_contacts_mean + 1e-6)
    # Gather all results into a data frame
    contact_result_df = get_data_frame_from_metrices(cell_types,
                                                    {'pval-adjusted': adjusted_local_p_value,
                                                    'pval': local_p_values,
                                                    'z_score': permuted_z_score,
                                                    'contact_count': real_contacts,
                                                    'permutation_mean': null_contacts_mean,
                                                    'permutation_std': null_contacts_std,
                                                    'fold-change' : fold_changes,
                                            }).sort_values('z_score', ascending=False)
    contact_result_df['id'] = f"{donor}_{brain_region}_{replicate}"
    contact_result_df = contact_result_df[contact_result_df['pval-adjusted'] < alpha]
    # NOTE(review): 'contact_count' is whatever get_cell_by_cell_contacts returns;
    # if that is normalised per cell type, confirm `min_contacts` is on the same scale.
    contact_result_df = contact_result_df[contact_result_df['contact_count'] > min_contacts]
    return contact_result_df

In [None]:
slices[0]

In [None]:
from tqdm import tqdm

In [None]:
# Keyword arguments shared by every slice; bound into the worker function below.
function_args = {}
function_args['Np'] = 1000
function_args['r_permute'] = 100
function_args['r_test'] = 15
function_args['alpha'] = 0.05
function_args['min_contacts'] = 50
function_args['adata'] = adata

# NOTE(review): `adata` is bound into the partial, so it is presumably serialised
# along with the tasks sent to the workers — confirm this is acceptable for the
# size of the dataset.
parallel_func = partial(_get_contacts_for_slice, **function_args)
# 8 worker processes; imap_unordered yields results in completion order, so
# `contacts_list` is not aligned with `slices`.
with mp.Pool(8) as pool:
    contacts_list = list(  # noqa: F841
        tqdm(
            pool.imap_unordered(parallel_func, slices),
            total=len(slices),
        )
    )

In [None]:
# Single-slice run for inspection. The signature is (_slice, adata, ...), so the
# slice tuple goes first.
contact_df = _get_contacts_for_slice(slices[0], adata)

In [None]:
contact_df

# Radii to closest nuclei

### Functions

In [None]:
def get_closest_cell_of_type(
    adata,
    cell_type_col = "AIT_Subclass",
    spatial_keys = ['CENTER_X', 'CENTER_Y'],
    target_cell_type = "Astrocyte",
    query_cell_type = "STR D2 MSN",
):
    """For every query-type cell, find the nearest target-type cell.

    Parameters
    ----------
    adata : object exposing a per-cell table as `.obs` (e.g. an AnnData).
    cell_type_col : column holding the cell-type label of each cell.
    spatial_keys : names of the coordinate columns.
    target_cell_type : label of the cells to search among.
    query_cell_type : label of the cells to search from.

    Returns
    -------
    (closest_distances, closest_indices)
        One entry per query cell, in `.obs` order. Distances are Euclidean;
        indices are positions within the target-type subset (not within
        `adata`). If there are no target cells, distances are `inf` and
        indices are -1.
    """
    labels = adata.obs[cell_type_col]
    coords = adata.obs[list(spatial_keys)]
    # Slice the coordinate table directly; no AnnData views are needed.
    target_coords = coords[(labels == target_cell_type).to_numpy()].to_numpy(dtype=float)
    query_coords = coords[(labels == query_cell_type).to_numpy()].to_numpy(dtype=float)

    n_query = query_coords.shape[0]
    if target_coords.shape[0] == 0:
        # argmin over an empty array would raise; report "no neighbour" instead.
        return np.full(n_query, np.inf), np.full(n_query, -1, dtype=np.intp)

    closest_distances = np.empty(n_query, dtype=float)
    closest_indices = np.empty(n_query, dtype=np.intp)
    for k, qc in enumerate(query_coords):
        dists = np.linalg.norm(target_coords - qc, axis=1)
        nearest = np.argmin(dists)
        closest_indices[k] = nearest
        closest_distances[k] = dists[nearest]

    return closest_distances, closest_indices

In [None]:
# For every other subclass `_ct`, store the distance from each `qry_ct` cell to
# its nearest `_ct` cell (qry_ct is the query type, _ct the target type).
qry_ct = "Astrocyte"
distances = {}
for _ct in adata_sub.obs['AIT_Subclass'].unique():
    # Skip the query type itself.
    if _ct == qry_ct:
        continue
    ret = get_closest_cell_of_type(adata_sub, target_cell_type=_ct, query_cell_type=qry_ct)
    # Keep only the distances; the nearest-cell indices (ret[1]) are discarded.
    distances[_ct] = ret[0]

In [None]:
distances.keys()

In [None]:
qry_types = ["Oligodendrocyte", "STR D1 MSN", "STR D2 MSN", "CN ST18 GABA", "CN Cholinergic GABA", "Microglia"]
color=sns.color_palette("husl", len(qry_types))

In [None]:
# One KDE per target cell type of the nearest-neighbour distances from `qry_ct` cells.
fig, ax = plt.subplots(figsize=(4, 4), dpi=200)
for _ct, col in zip(qry_types, color):
    if _ct not in distances:
        # Type absent from this slice (or equal to qry_ct): nothing to draw.
        continue
    sns.kdeplot(distances[_ct], label=_ct, ax=ax, fill=True, alpha=0.2, color=col, linewidth=1)
ax.set_xlabel(f'Distances from {qry_ct} (um)')
ax.set_ylabel('Density')
# Title follows qry_ct so it cannot disagree with the x-axis label.
ax.set_title(f"Distribution of {qry_ct} Distances to cell type")
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
plt.show()


In [None]:
# Same nearest-neighbour computation with a different query type. This overwrites
# `qry_ct` and `distances` from the previous section.
qry_ct = "CN Cholinergic GABA"
distances = {}
for _ct in adata_sub.obs['AIT_Subclass'].unique():
    # Skip the query type itself.
    if _ct == qry_ct:
        continue
    ret = get_closest_cell_of_type(adata_sub, target_cell_type=_ct, query_cell_type=qry_ct)
    # Keep only the distances; the nearest-cell indices (ret[1]) are discarded.
    distances[_ct] = ret[0]

In [None]:
qry_types = ["Oligodendrocyte", "STR D1 MSN", "STR D2 MSN", "Astrocyte", "Microglia"]
color=sns.color_palette("tab10", len(qry_types))

In [None]:
# One KDE per target cell type of the nearest-neighbour distances from `qry_ct` cells.
fig, ax = plt.subplots(figsize=(4, 4), dpi=200)
for _ct, col in zip(qry_types, color):
    if _ct not in distances:
        # Type absent from this slice (or equal to qry_ct): nothing to draw.
        continue
    sns.kdeplot(distances[_ct], label=_ct, ax=ax, fill=True, alpha=0.2, color=col, linewidth=1)
ax.set_xlabel(f'Distances from {qry_ct} (um)')
ax.set_ylabel('Density')
# Title follows qry_ct so it cannot disagree with the x-axis label.
ax.set_title(f"Distribution of {qry_ct} Distances to cell type")
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
plt.show()
