In [None]:
import os
from pathlib import Path
import itertools 

import numpy as np
import pandas as pd
import anndata as ad
import scanpy as sc
from statsmodels.stats.multitest import multipletests
from scipy.stats import norm
import geopandas as gpd

import multiprocessing as mp
from functools import partial

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns

import spatialdata as sd

from spida._constants import ren_to_exp_map
from spida.pl import plot_categorical, plot_continuous
# Notebook-wide Matplotlib defaults: 200 dpi figures with a white axes background.
plt.rcParams.update({
    'figure.dpi': 200,
    'axes.facecolor': 'white',
})

## Functions

In [None]:
# Distances based on the cell geometries
def get_geom_distances(
    shapes_df,
    cell_type_col = "Subclass",
    target_cell_type = "Astrocyte",
    query_cell_type = "STR D2 MSN",
):
    """Distance from every query-type cell shape to its nearest target-type cell shape.

    Parameters
    ----------
    shapes_df : geopandas.GeoDataFrame
        Cell geometries. Must contain ``cell_type_col`` and a ``uid`` column
        (the join suffixes it to ``uid_left`` / ``uid_right``).
    cell_type_col : str
        Column of ``shapes_df`` holding the cell-type labels.
    target_cell_type : str
        Label of the cells to measure distances *to*.
    query_cell_type : str
        Label of the cells to measure distances *from*.

    Returns
    -------
    tuple of numpy.ndarray
        ``(distances, query_uids, target_uids)``, one entry per joined row.
        Distances are in the units of the geometry coordinates.
    """
    # Use the frame that was passed in; the previous version silently read the
    # notebook-global `shapes_sub` and ignored this argument.
    shapes_qry = shapes_df.loc[shapes_df[cell_type_col] == query_cell_type]
    shapes_tgt = shapes_df.loc[shapes_df[cell_type_col] == target_cell_type]

    # NOTE: per the geopandas documentation, sjoin_nearest emits one row per
    # equidistant nearest neighbour, so a query cell can appear more than once.
    df_dist = gpd.sjoin_nearest(
        shapes_qry,
        shapes_tgt,
        how='inner',
        distance_col='distance'
    )
    dists = df_dist['distance'].values

    return dists, df_dist['uid_left'].values, df_dist['uid_right'].values

# Distances based on the cell centroids
def get_closest_cell_of_type(
    adata,
    cell_type_col = "Subclass",
    spatial_keys = ['CENTER_X', 'CENTER_Y'],
    target_cell_type = "Astrocyte",
    query_cell_type = "STR D2 MSN",
):
    """Centroid distance from every query-type cell to its nearest target-type cell.

    Parameters
    ----------
    adata : object exposing an ``obs`` DataFrame (e.g. an AnnData)
        ``obs`` must contain ``cell_type_col`` and the ``spatial_keys`` columns.
    cell_type_col : str
        Column of ``adata.obs`` holding the cell-type labels.
    spatial_keys : sequence of str
        Columns of ``adata.obs`` holding the centroid coordinates.
    target_cell_type : str
        Label of the cells to search among.
    query_cell_type : str
        Label of the cells to measure distances from.

    Returns
    -------
    tuple of numpy.ndarray
        ``(closest_distances, closest_indices)``, one entry per query cell.
        Indices are positions within the target-type rows of ``adata.obs``
        (not positions in the full table). When several targets are exactly
        equidistant, which of them is reported is not guaranteed.

    Raises
    ------
    ValueError
        If there are query cells but no target cells to compare against.
    """
    # Local import: the notebook's import cell only pulls in scipy.stats.
    from scipy.spatial import cKDTree

    # Only `obs` is needed, so index it directly instead of building AnnData views.
    obs = adata.obs
    coord_cols = list(spatial_keys)
    target_coords = obs.loc[obs[cell_type_col] == target_cell_type, coord_cols].to_numpy(dtype=float)
    query_coords = obs.loc[obs[cell_type_col] == query_cell_type, coord_cols].to_numpy(dtype=float)

    if len(query_coords) == 0:
        # Nothing to measure from: same empty result the loop-based version produced.
        return np.array([]), np.array([])
    if len(target_coords) == 0:
        raise ValueError(
            f"No cells with {cell_type_col} == {target_cell_type!r}; "
            "cannot compute nearest-target distances."
        )

    # A KD-tree lookup replaces the per-query Python loop over all target cells.
    closest_distances, closest_indices = cKDTree(target_coords).query(query_coords, k=1)
    return closest_distances, closest_indices

In [None]:
def agg_df_results(distances, cell_ids, cell_type):
    """Build one long table of per-cell distances for ``cell_type``.

    ``distances[cell_type]`` and ``cell_ids[cell_type]`` are dicts keyed by
    dataset id whose values are equal-length 1-D arrays (distances and the
    matching cell ids). The dataset id is split on ``"_"`` into
    ``brain_region``, ``donor`` and ``lab``; ``brain_region`` is renamed through
    ``ren_to_exp_map`` where a mapping exists.

    Returns a DataFrame with columns ``dataset_id``, ``distance``,
    ``brain_region``, ``donor``, ``lab``, ``CELL_ID`` and an ordered categorical
    ``distance_category`` binned at the 1st/15th/85th/99th distance percentiles.
    """
    def _to_long(per_dataset, value_name):
        # Ragged per-dataset arrays become a NaN-padded wide frame; melting and
        # dropping the padding leaves one row per cell with its original melt index.
        wide = pd.DataFrame.from_dict(per_dataset, orient='index').reset_index()
        long_form = wide.melt(id_vars='index').dropna()
        return (
            long_form
            .drop(columns=['variable'])
            .rename(columns={'index': 'dataset_id', 'value': value_name})
        )

    df_ids = _to_long(cell_ids[cell_type], 'CELL_ID')
    df = _to_long(distances[cell_type], 'distance')

    id_parts = df['dataset_id'].str.split("_", expand=True)
    df[['brain_region', 'donor', 'lab']] = id_parts
    df['brain_region'] = df['brain_region'].map(ren_to_exp_map).fillna(df['brain_region'])
    # Index-aligned: both long frames keep the row labels produced by melt.
    df['CELL_ID'] = df_ids['CELL_ID']

    category_labels = [
        'Very Close',
        'Close',
        'Intermediate',
        'Far',
        'Very Far'
    ]
    cutoffs = np.percentile(df['distance'], [1, 15, 85, 99])
    # side='left' puts d in bin i when cutoffs[i-1] < d <= cutoffs[i], i.e. the
    # same half-open intervals as explicit (> lower) & (<= upper) comparisons.
    bin_codes = np.searchsorted(cutoffs, df['distance'].to_numpy(), side='left')
    df['distance_category'] = pd.Categorical.from_codes(
        bin_codes,
        categories=category_labels,
        ordered=True
    )
    return df

## Read

In [None]:
# Load the AnnData object from disk; the bare `adata` on the last line renders its summary.
ad_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad"
adata = ad.read_h5ad(ad_path)
adata

In [None]:
# # Generate Master Shapes DataFrame
# shapes = []
# sdata_root = Path("/home/x-aklein2/projects/aklein/BICAN/data/zarr_store/")
# for sd_file in sdata_root.glob("BICAN_BG_*_CPSfilt"):
#     print(sd_file)
#     lab = sd_file.name.split("_")[-2]
#     sdata = sd.read_zarr(sd_file, selection=['shapes'])
#     for _elem in sdata.gen_elements(): 
#         dsid = _elem[1]
#         print(dsid)
#         dsl = dsid.split("_")[:2]
#         dsl.append(lab)
#         dsid = "_".join(dsl)
#         print(dsid)
        
#         adata_sub = adata[adata.obs['dataset_id'] == dsid].copy()
#         df_sub = adata_sub.obs[['CELL_ID', 'Subclass', 'Group', 'MSN_Groups', 'dataset_id']].reset_index(names="uid").set_index("CELL_ID")

#         shapes_sub = _elem[2].copy()
#         shapes_merged = shapes_sub.merge(df_sub, left_index=True, right_index=True)
#         shapes.append(shapes_merged)

# shapes_all = pd.concat(shapes)
# shapes_all.head()
# shapes_all.to_parquet("/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS_shapes.parquet")

In [None]:
# Load the cell-shape table from parquet (same path the commented-out generation cell writes to).
shapes = gpd.read_parquet("/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS_shapes.parquet")

## Geom Distances For MSNs Stratified by Astrocyte

In [None]:
# For each MSN subclass, collect per-dataset distances to the nearest astrocyte
# together with the uid of every query cell.
tgt_ct = "Astrocyte"
qry_cells = ['STR D1 MSN', 'STR D2 MSN']
distances = {}
cell_ids = {}
for qry_ct in qry_cells:
    distances[qry_ct] = {}
    cell_ids[qry_ct] = {}
    for _dsid in shapes['dataset_id'].unique():
        shapes_sub = shapes[shapes['dataset_id'] == _dsid]
        # Skip datasets that contain no cells of the query subclass.
        if not (shapes_sub['Subclass'] == qry_ct).any():
            continue
        # get_geom_distances returns (distances, query uids, target uids).
        dists, qry_uids = get_geom_distances(
            shapes_sub,
            cell_type_col="Subclass",
            target_cell_type=tgt_ct,
            query_cell_type=qry_ct,
        )[:2]
        distances[qry_ct][_dsid] = dists
        cell_ids[qry_ct][_dsid] = qry_uids

In [None]:
# Long per-cell distance tables (with distance categories) for each MSN subclass.
df_d1 = agg_df_results(distances, cell_ids, 'STR D1 MSN')
df_d2 = agg_df_results(distances, cell_ids, 'STR D2 MSN')

In [None]:
# 2 x 3 grid of distance densities: one row per MSN subclass, one column per
# stratifying variable. Each entry is (obs column, palette key in adata.uns, title text).
msn_frames = [('STR D1 MSN', df_d1), ('STR D2 MSN', df_d2)]
strata = [
    ('brain_region', 'brain_region_palette', 'Brain Region'),
    ('donor', 'donor_palette', 'Donor'),
    ('lab', 'replicate_palette', 'Lab'),
]

fig, axes = plt.subplots(figsize=(15, 8), nrows=2, ncols=3)
for row_axes, (msn_label, df_msn) in zip(axes, msn_frames):
    for panel_ax, (hue_col, palette_key, hue_title) in zip(row_axes, strata):
        # common_norm=False: every hue level is normalised to its own area.
        sns.kdeplot(
            data=df_msn,
            x='distance',
            fill=True,
            hue=hue_col,
            palette=adata.uns[palette_key],
            common_norm=False,
            ax=panel_ax,
        )
        panel_ax.set_title(f'Distance from {msn_label} to Astrocytes by {hue_title}')
plt.tight_layout()
plt.show()

In [None]:
# Number of cells per distance category, one panel per MSN subclass.
fig, ax = plt.subplots(figsize=(12, 3), nrows=1, ncols=2, sharey=True)
for panel_ax, (msn_label, df_msn) in zip(ax, [('STR D1 MSN', df_d1), ('STR D2 MSN', df_d2)]):
    # sort_index() orders the bars by the ordered categorical, not by count.
    df_msn['distance_category'].value_counts().sort_index().plot.bar(ax=panel_ax, rot=0)
    panel_ax.set_ylabel('Number of Cells')
    panel_ax.set_xlabel('Distance Category')
    panel_ax.set_title(f'{msn_label} to Astrocyte Distance Categories')

plt.tight_layout()
plt.show()

In [None]:
# Stacked percentage bars of distance categories, stratified by each metadata
# column, with one panel per MSN subclass.
category_order = ['Very Close', 'Close', 'Intermediate', 'Far', 'Very Far']
msn_frames = [('STR D1 MSN', df_d1), ('STR D2 MSN', df_d2)]

for _col in ['brain_region', 'donor', 'lab']:
    col_title = _col.replace('_', ' ').title()
    fig, ax = plt.subplots(figsize=(16, 3), nrows=1, ncols=2, sharey=True)

    for panel_ax, (msn_label, df_msn) in zip(ax, msn_frames):
        # Row-normalised crosstab -> percentage of each stratum per category.
        ct_table = pd.crosstab(df_msn[_col], df_msn['distance_category'], normalize='index') * 100
        # Re-add any category the crosstab dropped and fix the stacking order.
        ct_table = ct_table.reindex(columns=category_order)
        ct_table.plot.bar(stacked=True, ax=panel_ax, colormap='tab20')
        panel_ax.set_ylabel('Percentage of Cells')
        panel_ax.set_xlabel(col_title)
        panel_ax.set_title(f'{msn_label} to Astrocyte Distance Categories by {col_title}')
        panel_ax.legend(title='Distance Category', bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.tight_layout()
    plt.show()

In [None]:
def _close_far_subset(adata, dist_df, subclass):
    """Return a copy of the ``subclass`` cells labelled 'Close' or 'Far'.

    ``dist_df`` is a table with ``CELL_ID`` and ``distance_category`` columns.
    Cells whose obs name appears among the 'Close' (resp. 'Far') ``CELL_ID``
    values get that label in ``obs['distance_group']``; all other cells are dropped.
    """
    subset = adata[adata.obs['Subclass'] == subclass].copy()
    category = dist_df['distance_category']
    close_ids = dist_df.loc[category.isin(['Close']), 'CELL_ID']
    far_ids = dist_df.loc[category.isin(['Far']), 'CELL_ID']

    # Object dtype from the start: writing strings into a float NaN column is
    # deprecated by pandas (incompatible-dtype setitem).
    group = pd.Series(np.nan, index=subset.obs_names, dtype=object)
    group[subset.obs_names.isin(close_ids)] = 'Close'
    group[subset.obs_names.isin(far_ids)] = 'Far'
    subset.obs['distance_group'] = group.to_numpy()

    # Materialise the filtered object so later in-place tools do not operate on a view.
    return subset[group.notna().to_numpy()].copy()


# D1
adata_d1_test = _close_far_subset(adata, df_d1, 'STR D1 MSN')

# D2
adata_d2_test = _close_far_subset(adata, df_d2, 'STR D2 MSN')

In [None]:
# Rank genes between the `distance_group` levels of the D1 subset
# (t-test with overestimated variance), then draw the ranked-genes heatmap.
sc.tl.rank_genes_groups(adata_d1_test, groupby='distance_group', method='t-test_overestim_var', pts=True)
sc.pl.rank_genes_groups_heatmap(adata_d1_test)

In [None]:
# Rank genes between the `distance_group` levels of the D2 subset
# (t-test with overestimated variance), then draw the ranked-genes heatmap.
sc.tl.rank_genes_groups(adata_d2_test, groupby='distance_group', method='t-test_overestim_var', pts=True)
sc.pl.rank_genes_groups_heatmap(adata_d2_test)

## Get the cells mapped to snm3C

In [None]:
# Load the imputation metadata table (first CSV column becomes the index) and preview it.
# Later cells read its `merfish_cell` column and treat the index as the mC cell id.
imputation_path = "/anvil/projects/x-mcb130189/qzeng/analysis/251105_merfish_methylation_2/Imputation.Subclass_Restricted.mC_MERFISH.meta.csv"
df_impute = pd.read_csv(imputation_path, index_col=0)
df_impute.head()

In [None]:
# Fraction of the close/far test cells that appear in the imputation table.
merfish_ids = df_impute['merfish_cell']
perc_d1_included = adata_d1_test.obs_names.isin(merfish_ids).mean()
perc_d2_included = adata_d2_test.obs_names.isin(merfish_ids).mean()
print(f"Percentage of STR D1 MSN cells included in imputation data: {perc_d1_included:.2%}")
print(f"Percentage of STR D2 MSN cells included in imputation data: {perc_d2_included:.2%}")

In [None]:
# Imputation rows whose MERFISH cell is in each test subset, re-indexed by the
# MERFISH cell id (the former index is kept as `mc_cell_id`) and annotated with
# that cell's close/far group.
in_d1 = df_impute['merfish_cell'].isin(adata_d1_test.obs_names)
df_impute_d1 = (
    df_impute.loc[in_d1]
    .reset_index(names="mc_cell_id")
    .set_index('merfish_cell')
)
df_impute_d1['distance_category'] = adata_d1_test.obs.loc[df_impute_d1.index, 'distance_group']

in_d2 = df_impute['merfish_cell'].isin(adata_d2_test.obs_names)
df_impute_d2 = (
    df_impute.loc[in_d2]
    .reset_index(names="mc_cell_id")
    .set_index('merfish_cell')
)
df_impute_d2['distance_category'] = adata_d2_test.obs.loc[df_impute_d2.index, 'distance_group']

In [None]:
# for _col in ['Amit_Group', 'mC_Group']:
#     fig, ax = plt.subplots(figsize=(16, 4), nrows=1, ncols=2, sharey=True)

#     ct_table = pd.crosstab(df_impute_d1[_col], df_impute_d1['distance_category'], normalize='index') * 100
#     ct_table = ct_table.reindex(columns=['Very Close','Close','Intermediate','Far','Very Far'])
#     ct_table.plot.bar(stacked=True, ax=ax[0], colormap='tab20')
#     ax[0].set_ylabel('Percentage of Cells')
#     ax[0].set_xlabel(_col.replace('_', ' ').title())
#     ax[0].set_xticklabels(ax[0].get_xticklabels(), rotation=45, ha='right')
#     ax[0].set_title(f'STR D1 MSN to Astrocyte Distance Categories by {_col.replace("_", " ").title()}')
#     ax[0].legend(title='Distance Category', bbox_to_anchor=(1.05, 1), loc='upper left')

#     ct_table = pd.crosstab(df_impute_d2[_col], df_impute_d2['distance_category'], normalize='index') * 100
#     ct_table = ct_table.reindex(columns=['Very Close','Close','Intermediate','Far','Very Far'])
#     ct_table.plot.bar(stacked=True, ax=ax[1], colormap='tab20')
#     ax[1].set_ylabel('Percentage of Cells')
#     ax[1].set_xlabel(_col.replace('_', ' ').title())
#     ax[1].set_title(f'STR D2 MSN to Astrocyte Distance Categories by {_col.replace("_", " ").title()}')
#     ax[1].set_xticklabels(ax[1].get_xticklabels(), rotation=45, ha='right')
#     ax[1].legend(title='Distance Category', bbox_to_anchor=(1.05, 1), loc='upper left')

#     plt.tight_layout()
#     plt.show()

In [None]:
# mC cell ids split by the close/far label of their mapped MERFISH cell.
d1_far_mc_cells = df_impute_d1.loc[df_impute_d1['distance_category'] == 'Far', 'mc_cell_id']
d1_close_mc_cells = df_impute_d1.loc[df_impute_d1['distance_category'] == 'Close', 'mc_cell_id']

d2_far_mc_cells = df_impute_d2.loc[df_impute_d2['distance_category'] == 'Far', 'mc_cell_id']
d2_close_mc_cells = df_impute_d2.loc[df_impute_d2['distance_category'] == 'Close', 'mc_cell_id']

In [None]:
# Group sizes, printed as "close - far".
print("D1 Close vs. Far MC Cells: %i - %i" % (len(d1_close_mc_cells), len(d1_far_mc_cells)))
print("D2 Close vs. Far MC Cells: %i - %i" % (len(d2_close_mc_cells), len(d2_far_mc_cells)))

In [None]:
# Write each mC cell-id list as a headerless, index-free one-column text file.
dm_out_path = Path("/home/x-aklein2/projects/aklein/BICAN/BG/data/methylation_2/dm_comp")
# to_csv does not create missing directories, so make sure the target exists.
dm_out_path.mkdir(parents=True, exist_ok=True)

mc_cell_lists = {
    "d1_far_mc_cells": d1_far_mc_cells,
    "d1_close_mc_cells": d1_close_mc_cells,
    "d2_far_mc_cells": d2_far_mc_cells,
    "d2_close_mc_cells": d2_close_mc_cells,
}
for list_name, mc_cells in mc_cell_lists.items():
    mc_cells.to_csv(dm_out_path / f"{list_name}.txt", index=False, header=False)

## Testing Geom Distances

In [None]:
# Dataset id used for the single-dataset checks in this section.
dsid = "PU_UCI5224_salk"

In [None]:
# Shapes of the selected dataset. `shapes` is the frame loaded from parquet;
# `shapes_all` only exists in the commented-out generation cell.
shapes_sub = shapes[shapes['dataset_id'] == dsid]

In [None]:
# Nearest-astrocyte geometry distances for every other Subclass in `shapes_sub`.
tgt_ct = "Astrocyte"
distances = {}
for qry_ct in shapes_sub['Subclass'].unique():
    if qry_ct != tgt_ct:
        # Only the distance array (first element of the returned tuple) is kept.
        distances[qry_ct] = get_geom_distances(
            shapes_sub,
            target_cell_type=tgt_ct,
            query_cell_type=qry_ct,
        )[0]

In [None]:
# Subclasses to overlay. Alternatives are kept commented out; only the active
# assignment below takes effect.
# qry_types = ["Oligodendrocyte", "STR D1 MSN", "STR D2 MSN", "CN ST18 GABA", "CN Cholinergic GABA", "Microglia"]
# qry_types = ["STR D1 MSN", "STR D2 MSN"] # , "CN ST18 GABA"]
# qry_types = ["Oligodendrocyte", "Microglia"]
qry_types = ["CN ST18 GABA"]
color = [adata.uns['Subclass_palette'][c] for c in qry_types]

fig, ax = plt.subplots(figsize=(4, 4), dpi=200)
for _ct, col in zip(qry_types, color):
    sns.kdeplot(distances[_ct], label=_ct, ax=ax, fill=True, alpha=0.2, color=col, linewidth=1)
# Every curve is a distance to the target type, so label with `tgt_ct`. The
# previous label used `qry_ct`, which is just the last value left over from the
# loop that filled `distances`.
ax.set_xlabel(f'Distance to nearest {tgt_ct} (um)')
ax.set_ylabel('Density')
ax.set_title("Distribution of cell-types to Astrocyte Distances")
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
plt.show()


In [None]:
# Nearest-astrocyte geometry distances for every other Group in `shapes_sub`.
tgt_ct = "Astrocyte"
distances = {}
for qry_ct in shapes_sub['Group'].unique():
    if qry_ct != tgt_ct:
        # Only the distance array (first element of the returned tuple) is kept.
        distances[qry_ct] = get_geom_distances(
            shapes_sub,
            cell_type_col="Group",
            target_cell_type=tgt_ct,
            query_cell_type=qry_ct,
        )[0]

In [None]:
# List the labels for which distances were computed.
for _k in distances:
    print(_k)

In [None]:
# MSN group labels present in this dataset (missing values removed).
# NOTE(review): `distances` was filled from the `Group` column; this assumes every
# MSN_Groups label is also a Group label - confirm.
qry_types = shapes_sub['MSN_Groups'].dropna().unique()
palette_key = "MSN_Groups_palette"
# Fall back to grey for labels without a palette entry.
color = [adata.uns[palette_key][c] if c in adata.uns[palette_key] else "grey" for c in qry_types]

fig, ax = plt.subplots(figsize=(4, 4), dpi=200)
for _ct, col in zip(qry_types, color):
    sns.kdeplot(distances[_ct], label=_ct, ax=ax, fill=True, alpha=0.2, color=col, linewidth=1)
# Every curve is a distance to the target type, so label with `tgt_ct`. The
# previous label used `qry_ct`, which is just the last value left over from the
# loop that filled `distances`.
ax.set_xlabel(f'Distance to nearest {tgt_ct} (um)')
ax.set_ylabel('Density')
ax.set_title("Distribution of cell-types to Astrocyte Distances")
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
plt.show()


In [None]:
# Groups belonging to the CN ST18 GABA subclass, read straight from `obs`.
st18_groups = adata.obs.loc[adata.obs['Subclass'] == "CN ST18 GABA", 'Group'].unique()
# These labels come from all of `adata`, while `distances` only covers the
# current `shapes_sub`; keep the groups that actually have distances so the
# lookup below cannot raise KeyError.
qry_types = [g for g in st18_groups if g in distances]
palette_key = "Group_palette"
# Fall back to grey for labels without a palette entry.
color = [adata.uns[palette_key][c] if c in adata.uns[palette_key] else "grey" for c in qry_types]

fig, ax = plt.subplots(figsize=(4, 4), dpi=200)
for _ct, col in zip(qry_types, color):
    sns.kdeplot(distances[_ct], label=_ct, ax=ax, fill=True, alpha=0.2, color=col, linewidth=1)
# Every curve is a distance to the target type, so label with `tgt_ct`. The
# previous label used `qry_ct`, which is just the last value left over from the
# loop that filled `distances`.
ax.set_xlabel(f'Distance to nearest {tgt_ct} (um)')
ax.set_ylabel('Density')
ax.set_title("Distribution of cell-types to Astrocyte Distances")
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
plt.show()


In [None]:
# Inspect the `neuron_type` annotation column.
adata.obs['neuron_type']

In [None]:
# Inspect the current `qry_types` selection.
qry_types

In [None]:
# The same selection shown as a plain Python list.
list(qry_types)

In [None]:
# Non-neuronal groups (read straight from `obs`), restricted to those present in
# the current dataset and excluding the astrocyte target itself.
qry_types = adata.obs.loc[adata.obs['neuron_type'] == "Nonneuron", 'Group'].unique()
qry_types = qry_types[np.isin(qry_types, shapes_sub['Group'])]
qry_types = qry_types[qry_types != "Astrocyte"]
palette_key = "Group_palette"
# Fall back to grey for labels without a palette entry.
color = [adata.uns[palette_key][c] if c in adata.uns[palette_key] else "grey" for c in qry_types]

fig, ax = plt.subplots(figsize=(4, 4), dpi=200)
for _ct, col in zip(qry_types, color):
    sns.kdeplot(distances[_ct], label=_ct, ax=ax, fill=True, alpha=0.2, color=col, linewidth=1)
# Every curve is a distance to the target type, so label with `tgt_ct`. The
# previous label used `qry_ct`, which is just the last value left over from the
# loop that filled `distances`.
ax.set_xlabel(f'Distance to nearest {tgt_ct} (um)')
ax.set_ylabel('Density')
ax.set_title("Distribution of cell-types to Astrocyte Distances")
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
plt.show()


In [None]:
# Manual query/target split of `shapes_sub`, for stepping through the join by hand.
cell_type_col, query_ct, target_ct = "Subclass", "STR D2 MSN", "Astrocyte"

is_query = shapes_sub[cell_type_col] == query_ct
is_target = shapes_sub[cell_type_col] == target_ct
shapes_qry = shapes_sub.loc[is_query]
shapes_tgt = shapes_sub.loc[is_target]

In [None]:
# (rows, columns) of the query and target subsets.
shapes_qry.shape, shapes_tgt.shape

In [None]:
# Number of distinct `ID` values (missing values counted) in each subset.
shapes_qry['ID'].nunique(dropna=False), shapes_tgt['ID'].nunique(dropna=False)

In [None]:
# Nearest-target join of the manual subsets; the joined frame is kept so its
# columns can be inspected afterwards.
df_dist = gpd.sjoin_nearest(
    left_df=shapes_qry,
    right_df=shapes_tgt,
    how='inner',
    distance_col='distance',
)
dists = df_dist['distance'].to_numpy()

In [None]:
# `ID` of the matched target (right-hand) shape for each joined row.
df_dist['ID_right']

## Distance by centers

In [None]:
# Centroid distance from each astrocyte to the nearest cell of every other Subclass.
# NOTE(review): this runs on all of `adata`; confirm that CENTER_X / CENTER_Y share
# one coordinate frame across datasets, otherwise neighbours may be matched across
# datasets.
qry_ct = "Astrocyte"
distances = {
    _ct: get_closest_cell_of_type(adata, target_cell_type=_ct, query_cell_type=qry_ct)[0]
    for _ct in adata.obs['Subclass'].unique()
    if _ct != qry_ct
}

In [None]:
# Target cell types to overlay, with one "husl" palette colour per type.
qry_types = ["Oligodendrocyte", "STR D1 MSN", "STR D2 MSN", "CN ST18 GABA", "CN Cholinergic GABA", "Microglia"]
color=sns.color_palette("husl", len(qry_types))

In [None]:
# One density curve per target cell type, measured from the `qry_ct` cells.
fig, ax = plt.subplots(figsize=(4, 4), dpi=200)
for _ct, col in zip(qry_types, color):
    sns.kdeplot(distances[_ct], label=_ct, ax=ax, fill=True, alpha=0.2, color=col, linewidth=1)
ax.set(
    xlabel=f'Distances from {qry_ct} (um)',
    ylabel='Density',
    title="Distribution of Astrocyte Distances to cell type",
)
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
plt.show()


In [None]:
# Centroid distance from each CN Cholinergic GABA cell to the nearest cell of
# every other type, within a single dataset.
qry_ct = "CN Cholinergic GABA"
# Use one label column for both the iteration and the lookup. The previous cell
# iterated `AIT_Subclass` labels while the function matched on its default `Subclass`.
ct_col = "Subclass"
# `adata_sub` was only ever defined in commented-out code, so build it here.
# NOTE(review): assumes the intended subset is the `dsid` dataset - confirm.
adata_sub = adata[adata.obs['dataset_id'] == dsid]

distances = {}
for _ct in adata_sub.obs[ct_col].dropna().unique():
    if _ct == qry_ct:
        continue
    ret = get_closest_cell_of_type(
        adata_sub,
        cell_type_col=ct_col,
        target_cell_type=_ct,
        query_cell_type=qry_ct,
    )
    distances[_ct] = ret[0]

In [None]:
# Target cell types to overlay, with one "tab10" palette colour per type.
qry_types = ["Oligodendrocyte", "STR D1 MSN", "STR D2 MSN", "Astrocyte", "Microglia"]
color=sns.color_palette("tab10", len(qry_types))

In [None]:
# One density curve per target cell type, measured from the `qry_ct` cells.
fig, ax = plt.subplots(figsize=(4, 4), dpi=200)
for _ct, col in zip(qry_types, color):
    sns.kdeplot(distances[_ct], label=_ct, ax=ax, fill=True, alpha=0.2, color=col, linewidth=1)
ax.set(
    xlabel=f'Distances from {qry_ct} (um)',
    ylabel='Density',
    title="Distribution of CN Cholinergic GABA Distances to cell type",
)
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
plt.show()
