In [None]:
import os
from pathlib import Path
from rich import print as rprint, inspect

from tqdm import tqdm
import itertools

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
import scipy.sparse as scp

from spida.P.setup_adata import _calc_embeddings, multi_round_clustering

import matplotlib.pyplot as plt
import seaborn as sns
from spida.pl import plot_categorical, plot_continuous, categorical_scatter, continuous_scatter
plt.rcParams['figure.dpi'] = 150
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['axes.facecolor'] = 'white'

## functions

In [None]:
def normalize_adata(
    adata: "ad.AnnData",
    log1p: bool = True,
) -> None:
    """Scale every cell (row) of ``adata.X`` to the median per-cell total, in place.

    Each row is divided by its own sum and multiplied by the median of all row
    sums. ``sc.pp.log1p`` is applied afterwards when ``log1p`` is true.

    Parameters
    ----------
    adata
        Object whose ``X`` is a SciPy sparse matrix (any format) or a dense
        array. ``X`` is updated on this object; nothing is returned.
    log1p
        Whether to call ``sc.pp.log1p(adata)`` after scaling.

    Notes
    -----
    * The per-entry scaling relies on stored values being grouped by row, which
      is only true for CSR. Other sparse formats are therefore converted to CSR
      for the computation and converted back to their original format.
    * Rows whose total is zero are left as zeros instead of becoming NaN/inf.
    """
    X = adata.X
    n_counts = np.ravel(X.sum(axis=1))
    target = np.median(n_counts)

    if scp.issparse(X):
        # tocsr() returns the matrix itself when it is already CSR, so CSR
        # input is still updated in place.
        csr = X.tocsr()
        # One row total per stored entry, in CSR storage order.
        per_entry_total = np.repeat(n_counts, np.diff(csr.indptr))
        with np.errstate(divide="ignore", invalid="ignore"):
            scaled = csr.data / per_entry_total * target
        scaled[per_entry_total == 0] = 0
        csr.data = scaled
        adata.X = csr.asformat(X.format)
    else:
        dense = np.asarray(X)
        with np.errstate(divide="ignore", invalid="ignore"):
            scaled = dense / n_counts[:, None] * target
        scaled[n_counts == 0] = 0
        adata.X = scaled

    if log1p:
        sc.pp.log1p(adata)

In [None]:
# For distance between centers, need to implement a version of this using the true geometries.
def get_cells_in_radius(adata, center, radius, cols=['CENTER_X', 'CENTER_Y']):
    """Return a copy of ``adata`` restricted to cells strictly within ``radius`` of ``center``.

    Parameters
    ----------
    adata
        Object with an ``obs`` DataFrame that supports boolean-mask indexing
        and ``.copy()``.
    center
        Sequence whose first two entries are the x and y of the query point.
    radius
        Cells whose squared distance is ``< radius**2`` are kept (the boundary
        itself is excluded).
    cols
        Names of the x and y coordinate columns in ``adata.obs``; only the
        first two entries are used.

    The distance test is evaluated on whole columns at once rather than row by
    row, which also yields a proper (empty) mask when ``obs`` has no rows.
    """
    xy = adata.obs[[cols[0], cols[1]]].to_numpy(dtype=float)
    within = (xy[:, 0] - center[0]) ** 2 + (xy[:, 1] - center[1]) ** 2 < radius ** 2
    return adata[within].copy()


def _contact_inputs(data, cell_type_col, spatial_keys, cell_type_list):
    """Extract coordinates and per-cell type indices for the contact counters.

    Returns ``(coords, type_idx, cell_type_list)`` where ``type_idx[i]`` is the
    position of cell ``i``'s type in ``cell_type_list`` or ``-1`` when that
    type is not listed.
    """
    if not isinstance(data, pd.DataFrame):
        # AnnData (or any object exposing per-cell metadata as ``.obs``).
        data = data.obs

    if cell_type_list is None:
        cell_type_list = np.unique(data[cell_type_col])
    cell_type_to_idx = {ct: i for i, ct in enumerate(cell_type_list)}

    coords = data[list(spatial_keys)].to_numpy(dtype=float)
    type_idx = np.fromiter(
        (cell_type_to_idx.get(ct, -1) for ct in data[cell_type_col].values),
        dtype=np.intp,
        count=len(data),
    )
    return coords, type_idx, cell_type_list


def _neighbor_pairs(coords, radius):
    """Return index arrays ``(a, b)`` of all unordered pairs with ``0 < distance <= radius``."""
    from scipy.spatial import cKDTree

    empty = np.empty(0, dtype=np.intp)
    # Rows with NaN/inf coordinates can never satisfy the distance test, and the
    # KD-tree rejects them, so they are left out of the search.
    finite_rows = np.flatnonzero(np.isfinite(coords).all(axis=1))
    if not radius > 0 or finite_rows.size < 2:
        return empty, empty

    tree = cKDTree(coords[finite_rows])
    # Search with a marginally inflated radius to get a superset of candidates,
    # then apply the exact Euclidean test so boundary cases are decided by the
    # same norm computation regardless of the tree's internal rounding.
    candidates = tree.query_pairs(radius * (1 + 1e-9), output_type="ndarray")
    a = finite_rows[candidates[:, 0]]
    b = finite_rows[candidates[:, 1]]
    dists = np.linalg.norm(coords[a] - coords[b], axis=1)
    keep = (dists > 0) & (dists <= radius)
    return a[keep], b[keep]


def get_cell_by_cell_contacts(
    data : "ad.AnnData | pd.DataFrame",
    cell_type_col = "Subclass",
    spatial_keys = ['center_x', 'center_y'],
    cell_type_list = None,
    radius = 50,
):
    """Count contacts between cell types.

    Two cells are in contact when their Euclidean distance is greater than 0
    and at most ``radius``. Every contact is counted once from each side, so
    entry ``[s, t]`` is the number of (cell of type ``s``, neighbour of type
    ``t``) ordered pairs and the matrix is symmetric.

    Parameters
    ----------
    data
        A DataFrame with one row per cell, or an object whose ``.obs`` is such
        a DataFrame.
    cell_type_col
        Column holding each cell's type label.
    spatial_keys
        Coordinate columns used for the distance.
    cell_type_list
        Types defining the rows/columns of the result, in order. Defaults to
        the sorted unique labels found in ``cell_type_col``. Cells whose type
        is not listed are ignored.
    radius
        Contact distance threshold (inclusive).

    Returns
    -------
    (contact_counts, cell_type_list)
        ``contact_counts`` is an integer array of shape ``(n_types, n_types)``.
    """
    coords, type_idx, cell_type_list = _contact_inputs(
        data, cell_type_col, spatial_keys, cell_type_list
    )
    N_cell_types = len(cell_type_list)
    contact_counts = np.zeros((N_cell_types, N_cell_types), dtype=int)

    a, b = _neighbor_pairs(coords, radius)
    listed = (type_idx[a] >= 0) & (type_idx[b] >= 0)
    type_a = type_idx[a[listed]]
    type_b = type_idx[b[listed]]
    # np.add.at accumulates correctly when the same index pair repeats.
    np.add.at(contact_counts, (type_a, type_b), 1)
    np.add.at(contact_counts, (type_b, type_a), 1)

    return contact_counts, cell_type_list


def get_cell_contacts(
    data : "ad.AnnData | pd.DataFrame",
    cell_type_col = "Subclass",
    spatial_keys = ['center_x', 'center_y'],
    cell_type_list = None,
    radius = 50,
):
    """Count, for every cell, how many neighbours of each type it touches.

    Two cells are in contact when their Euclidean distance is greater than 0
    and at most ``radius``.

    Parameters
    ----------
    data
        A DataFrame with one row per cell, or an object whose ``.obs`` is such
        a DataFrame.
    cell_type_col
        Column holding each cell's type label.
    spatial_keys
        Coordinate columns used for the distance.
    cell_type_list
        Types defining the columns of the result, in order. Defaults to the
        sorted unique labels found in ``cell_type_col``. Neighbours whose type
        is not listed are not counted; every cell still gets a row.
    radius
        Contact distance threshold (inclusive).

    Returns
    -------
    (contact_counts, cell_type_list)
        ``contact_counts`` is an integer array of shape ``(n_cells, n_types)``
        with rows in the order of ``data``.
    """
    coords, type_idx, cell_type_list = _contact_inputs(
        data, cell_type_col, spatial_keys, cell_type_list
    )
    N_cell_types = len(cell_type_list)
    contact_counts = np.zeros((coords.shape[0], N_cell_types), dtype=int)

    a, b = _neighbor_pairs(coords, radius)
    # Each unordered pair contributes to both of its cells' rows.
    for cell, neighbor in ((a, b), (b, a)):
        neighbor_type = type_idx[neighbor]
        listed = neighbor_type >= 0
        np.add.at(contact_counts, (cell[listed], neighbor_type[listed]), 1)

    return contact_counts, cell_type_list

## Read

In [None]:
# Root directory of the input data. Set the BICAN_BG_DATA_DIR environment
# variable to run the notebook against a copy stored elsewhere; the default is
# the original location.
data_root = Path(os.environ.get("BICAN_BG_DATA_DIR", "/home/x-aklein2/projects/aklein/BICAN/BG/data"))

# Alternative dataset (PF) -- swap these in for the two assignments below:
# ad_path = str(data_root / "BICAN_BG_PFV8_annotated_v5.h5ad")
# data_path = str(data_root / "PF" / "cell_contacts_15um")

ad_path = str(data_root / "BICAN_BG_CPS.h5ad")
data_path = str(data_root / "CPS" / "cell_contacts_15um")

In [None]:
adata = ad.read_h5ad(ad_path)

# Label cells that have no Group annotation as "unknown". A categorical column
# only accepts values that are registered categories, so the label is added
# first when needed; otherwise fillna() raises instead of filling.
group_labels = adata.obs['Group']
if isinstance(group_labels.dtype, pd.CategoricalDtype) and "unknown" not in group_labels.cat.categories:
    group_labels = group_labels.cat.add_categories("unknown")
adata.obs['Group'] = group_labels.fillna("unknown")
adata

### Group-Specific Clusters

In [None]:
# Keep only the cells annotated as neurons, then show how many fall in each Group.
adata_neu = adata[adata.obs['neuron_type'].eq('Neuron')].copy()
adata_neu.obs['Group'].value_counts()

In [None]:
# Neuronal Group analysed in the rest of the notebook. To step through other
# groups, iterate over adata_neu.obs['Group'].unique() and subset the same way.
_group = "STR TAC3-PLPP4 GABA"
adata_group = adata_neu[adata_neu.obs['Group'].eq(_group)].copy()

In [None]:
# Reset X from the 'volume_norm' layer (so this cell can be re-run), normalise
# it, and run a single round of clustering within the selected group.
adata_group.X = adata_group.layers['volume_norm'].copy()
normalize_adata(adata_group)

multi_round_clustering(
    adata_group,
    layer=None,
    key_added="group_",
    num_rounds=1,
    # clustering / embedding settings
    knn=50,
    leiden_res=1,
    min_dist=0.25,
    min_group_size=50,
    # batch-correction settings
    run_harmony=True,
    batch_key=["replicate", "donor"],
    harmony_nclust=3,
    max_iter_harmony=10,
)
adata_group

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(20, 4), dpi=300)

# One panel per annotation, left to right, all on the group-level UMAP.
for ax, column in zip(axes, ("group_round1_leiden", "donor", "replicate", "brain_region")):
    plot_categorical(adata_group, cluster_col=column, coord_base="X_group_round1_umap", show=False, ax=ax)
plt.suptitle("Group: " + _group)

plt.show()
plt.close()

### Read in Contacts

In [None]:
# Distinct values of each sample-defining annotation.
donors, replicates, brain_regions = (
    adata.obs[column].unique().tolist()
    for column in ("donor", "replicate", "brain_region")
)
# (donor, brain_region, replicate) combinations to leave out of the contact computation.
skip = [
    ("UWA7648", "CAT", "ucsd"),
    ("UWA7648", "CAT", "salk"),
]

In [None]:
# Per-cell contact counts, computed separately for every
# (donor, brain_region, replicate) sample so that cells from different samples
# are never treated as neighbours.
contact_list = []
n_combinations = len(donors) * len(brain_regions) * len(replicates)
for _i in tqdm(itertools.product(donors, brain_regions, replicates), total=n_combinations):
    if _i in skip:
        continue
    _donor, _brain_region, _replicate = _i
    in_sample = (
        (adata.obs["donor"] == _donor)
        & (adata.obs["brain_region"] == _brain_region)
        & (adata.obs["replicate"] == _replicate)
    )
    if not in_sample.any():
        # This combination does not occur in the data.
        continue
    # Only the per-cell metadata is needed, so subset obs instead of copying
    # the whole AnnData (X and layers) for every sample.
    obs_sub = adata.obs.loc[in_sample]
    cell_types = np.unique(obs_sub['Group'])
    sample_counts, sample_types = get_cell_contacts(
        obs_sub,
        cell_type_col='Group',
        spatial_keys=['CENTER_X', 'CENTER_Y'],
        cell_type_list=cell_types,
        radius=15,
    )
    contact_list.append(pd.DataFrame(sample_counts, columns=sample_types, index=obs_sub.index))

    # Precomputed per-sample contact matrices, kept for reference:
    # real_contacts = np.load(Path(data_path) / f"contact_counts_real_{_donor}_{_brain_region}_{_replicate}_15um.npy")
    # null_contacts = np.load(Path(data_path) / f"contact_counts_permuted_{_donor}_{_brain_region}_{_replicate}_15um.npy")
    # null_contacts_mean = np.load(Path(data_path) / f"contact_counts_permuted_mean_{_donor}_{_brain_region}_{_replicate}_15um.npy")
    # null_contacts_std = np.load(Path(data_path) / f"contact_counts_permuted_std_{_donor}_{_brain_region}_{_replicate}_15um.npy")

In [None]:
# Stack the per-sample tables. A cell type that is absent from a sample has no
# column in that sample's table, so the resulting gaps mean zero contacts.
df_contacts = pd.concat(contact_list).fillna(0)
# Store the counts in the smallest unsigned integer type that can hold the
# largest value: a fixed uint8 would silently wrap any count above 255.
max_contacts = int(df_contacts.to_numpy().max()) if df_contacts.size else 0
df_contacts = df_contacts.astype(np.min_scalar_type(max_contacts))

### Thoughts

I want to use this to investigate whether a group has differential contacts across its clusters. So, using the group-level clusters defined above, I want to run something like `sc.tl.rank_genes_groups()`, but on the contact counts instead of gene expression.

Let's try this with the STR D1 Striosome MSN example

In [None]:
# Contact counts (rows: cells, columns: neighbouring Group) for the cells of
# the selected group, in the same order as adata_group.
group_contacts = df_contacts.loc[adata_group.obs_names, :]
group_contacts.head()

In [None]:
# Neighbour types that at least one cell of the group is in contact with.
gc_cols = group_contacts.columns[group_contacts.sum(axis=0) != 0]
ncols = 6
# At least one row, so the figure can still be created when nothing is in contact.
nrows = max(1, int(np.ceil(len(gc_cols) / ncols)))
# squeeze=False keeps `axes` two-dimensional for any grid size.
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*4, nrows*4), dpi=300, squeeze=False)
panels = axes.ravel()
for ax, col in zip(panels, gc_cols):
    # plot_continuous is given an obs column name, so the counts are attached
    # temporarily and removed again even if plotting fails.
    tmp_col = f'{col}_contacts'
    adata_group.obs[tmp_col] = group_contacts[col]
    try:
        plot_continuous(adata_group, color_by=tmp_col, coord_base="X_group_round1_umap", cmap="YlOrRd", show=False, ax=ax, hue_portion=1)
    finally:
        del adata_group.obs[tmp_col]
    ax.set_title(f"{col} contacts")
# Hide the unused panels at the end of the grid.
for ax in panels[len(gc_cols):]:
    ax.set_visible(False)
plt.suptitle("Cell type contacts in group: " + _group)
plt.show()
plt.close()

In [None]:
# Wrap the contact table in its own AnnData: cells stay the observations and
# the neighbouring cell types take the place of genes as variables.
adata_group_contacts = ad.AnnData(
    X=scp.csc_matrix(group_contacts.to_numpy()),
    obs=adata_group.obs.copy(),
    var=pd.DataFrame(index=group_contacts.columns),
)
# Keep an untouched copy of the counts before any later transformation of X.
adata_group_contacts.layers["raw"] = adata_group_contacts.X.copy()

In [None]:
# Cluster cells by their contact profile (X is used as-is, without normalisation).
# Optional preprocessing that was tried:
#   normalize_adata(adata_group_contacts, log1p=True)
#   sc.pp.log1p(adata_group_contacts)
multi_round_clustering(
    adata_group_contacts,
    layer=None,
    key_added="contacts_",
    num_rounds=1,
    # clustering / embedding settings
    knn=35,
    leiden_res=0.1,
    min_dist=0.5,
    p_cutoff=0.01,
    min_group_size=50,
    # batch-correction settings (run_harmony is off for this run)
    run_harmony=False,
    batch_key=["replicate", "donor"],
    harmony_nclust=3,
    max_iter_harmony=10,
)
adata_group_contacts

In [None]:
fig, axes = plt.subplots(1, 4, figsize=(20, 4), dpi=300)

# One panel per annotation, left to right, all on the contact-based UMAP.
for ax, column in zip(axes, ("contacts_round1_leiden", "donor", "replicate", "brain_region")):
    plot_categorical(adata_group_contacts, cluster_col=column, coord_base="X_contacts_round1_umap", show=False, ax=ax)
plt.suptitle("Group: " + _group)

plt.show()
plt.close()

In [None]:
# Contact counts per cell, grouped by contact cluster; the colour scale is capped at 2.
sc.pl.heatmap(
    adata_group_contacts,
    var_names=gc_cols,
    groupby="contacts_round1_leiden",
    dendrogram=True,
    show_gene_labels=True,
    vmax=2,
    show=True,
)

In [None]:
# Names of the cells assigned to contact cluster '5'.
# NOTE(review): '5' is a hard-coded Leiden label -- presumably the cluster with
# astrocyte contacts in the heatmap above; re-check it after any re-clustering.
in_cluster_5 = adata_group_contacts.obs['contacts_round1_leiden'].isin(['5'])
astro_associated = adata_group_contacts.obs_names[in_cluster_5]

In [None]:
# Boolean flag on the expression object: is the cell in the selected contact cluster?
adata_group.obs['Astro_associated'] = adata_group.obs.index.isin(astro_associated)

In [None]:
# Convert the flag to a categorical column so it can be used as a `groupby` key below.
adata_group.obs['Astro_associated'] = adata_group.obs['Astro_associated'].astype(pd.CategoricalDtype())

In [None]:
# Differential expression between flagged and unflagged cells (Wilcoxon),
# ranking every gene.
sc.tl.rank_genes_groups(
    adata_group,
    groupby="Astro_associated",
    method="wilcoxon",
    n_genes=adata_group.n_vars,
)
# The results can also be pulled into a table with sc.get.rank_genes_groups_df.

# Heatmap of the top 20 genes per group.
sc.pl.rank_genes_groups_heatmap(
    adata_group,
    groupby="Astro_associated",
    n_genes=20,
    dendrogram=False,
    show=True,
)

In [None]:
# Expression of a hand-picked gene panel, split by the Astro_associated flag.
sc.pl.heatmap(
    adata_group,
    var_names=[
        "WFS1", "DNMT3A", "DNMT3B", "TET1", "TET2", "PENK",
        "SYNPR", "PDE10A", "MAG", "DRD1", "DRD2", "KCNIP4",
    ],
    groupby="Astro_associated",
    show=True,
)

In [None]:
# Same ranking machinery applied to the contact table: which neighbour types
# distinguish each contact cluster? All variables are ranked.
sc.tl.rank_genes_groups(
    adata_group_contacts,
    groupby="contacts_round1_leiden",
    method="t-test_overestim_var",
    n_genes=adata_group_contacts.n_vars,
)

In [None]:
# Top 3 ranked neighbour types per contact cluster.
sc.pl.rank_genes_groups_heatmap(
    adata_group_contacts,
    groupby="contacts_round1_leiden",
    n_genes=3,
    show=True,
)

In [None]:
# adata_sub = adata[(adata.obs["donor"] == _donor) & (adata.obs["brain_region"] == _brain_region) & (adata.obs["replicate"] == _replicate)]
# adata_sub.obs['Group'] = adata_sub.obs['Group'].fillna("unknown")
# cell_types = np.unique(adata_sub.obs['Group'])
# cell_contacts = get_cell_contacts(adata_sub, cell_type_col='Group', spatial_keys=['CENTER_X', 'CENTER_Y'], cell_type_list=cell_types, radius=30)
# df_contacts = pd.DataFrame(cell_contacts[0], columns=cell_contacts[1], index=adata_sub.obs_names)
# # df_contacts = pd.concat(contact_list).fillna(0).astype(np.uint8)

In [None]:
# pd.read_csv("/home/x-aklein2/projects/aklein/BICAN/BG/data/PF/cell_contacts_30um/cell_contacts_UWA7648_CAB_ucsd_30um.csv")