In [None]:
# parameters
# Input AnnData file, output directory for the region geometries,
# and the per-cell contact table.
ad_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad"
output_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/regions/mat_str_CPS"
contacts_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/cell_contacts_group_200um.csv"

# Distance-band threshold used when building the spatial neighbour weights.
w_thr = 150
# Alpha value passed to alphashape for the striosome components.
str_alpha = 0.006
# Buffer applied to the merged striosome shapes before they are subtracted
# from the matrix shapes.
str_buffer = 25

In [None]:
from pathlib import Path

import pandas as pd
import numpy as np
import anndata as ad
import itertools
from tqdm import tqdm

from sklearn.decomposition import PCA, NMF
from sklearn.cluster import KMeans
import networkx as nx

import matplotlib.pyplot as plt
import seaborn as sns
from spida.pl import plot_categorical

plt.rcParams['figure.dpi'] = 150
plt.rcParams['savefig.dpi'] = 150
plt.rcParams['axes.facecolor'] = 'white'

import warnings
warnings.filterwarnings('ignore')

import libpysal as lps
import geopandas as gpd
import alphashape

In [None]:
# Create the output directory (and any missing parents); no error if it already exists.
Path(output_path).mkdir(parents=True, exist_ok=True)

In [None]:
# Get Adata
# Load the AnnData object from `ad_path`.
adata = ad.read_h5ad(ad_path)
# Replace missing `Group` labels with the literal "unknown".
# NOTE(review): if `Group` is loaded as a categorical that lacks an "unknown"
# category, this fillna may raise — confirm the column dtype.
adata.obs['Group'] = adata.obs['Group'].fillna("unknown")
adata

In [None]:
# Neuron-only copy of the data.
adata_neu = adata[adata.obs['neuron_type'] == 'Neuron'].copy()
# adata_neu.obs['Group'].value_counts()

# Distinct donors / replicates / brain regions present in the full dataset.
donors, replicates, brain_regions = (
    adata.obs[_key].unique().tolist()
    for _key in ("donor", "replicate", "brain_region")
)

# (donor, brain_region, replicate) combinations that are skipped by the
# per-section loop.
skip = [
    ("UWA7648", "CAT", "ucsd"),
    ("UWA7648", "CAT", "salk"),
]

In [None]:
# Get Contact Map
df_contacts = pd.read_csv(contacts_path, index_col=0)
df_contacts.head()

In [None]:
# MSN Specific adata
# Keep only the cells that carry an `MSN_Groups` label (i.e. the value is not missing).
adata_msn = adata[adata.obs['MSN_Groups'].notna()].copy()
adata_msn

In [None]:
# Get the msn cell types
# Collect the distinct MSN group labels, leaving out the "unknown" placeholder.
# Filtering (instead of list.remove) keeps this cell from raising ValueError
# when no cell carries the "unknown" label.
msn_types = [_t for _t in adata_msn.obs['MSN_Groups'].unique() if _t != "unknown"]
print(msn_types)

In [None]:
# Get the MSN specific contact map
# Rows: the MSN cells, in `adata_msn.obs_names` order.
# Columns: only those contact columns whose name is one of `msn_types`.
df_contacts_msn = (
    df_contacts
    .loc[adata_msn.obs_names, df_contacts.columns.isin(msn_types)]
    .copy()
)
print(df_contacts_msn.shape)
df_contacts_msn.head()

In [None]:
# Calculate the NMF and KMeans Clustering to call Mat. vs. Str. regions
# Factorise the MSN contact map with NMF and keep the per-cell factor matrix.
# NOTE(review): with n_components='auto' and a non-custom init the number of
# components presumably resolves to the number of input columns — confirm
# against the installed scikit-learn version.
W_msn = NMF(n_components='auto', init='nndsvd', random_state=444).fit_transform(df_contacts_msn)
# Cluster the cells into 3 groups on the NMF factors; `kmeans` holds the per-cell labels.
# NOTE(review): the cluster ids are mapped to Matrix/Striosome by a hard-coded
# dictionary further down, so any change that alters the labelling (data,
# seed, scikit-learn defaults such as n_init) requires re-checking that mapping.
kmeans = KMeans(n_clusters=3, random_state=444).fit_predict(W_msn)

# Store the labels on the MSN cells; rows of df_contacts_msn were selected in
# adata_msn.obs_names order, so the positions line up.
adata_msn.obs['kmeans'] = kmeans
adata_msn.obs['kmeans'] = adata_msn.obs['kmeans'].astype('category')

# Subset for the "salk" replicate in "PU"; only referenced by the commented-out
# plot call below within this cell.
adata_msn_sub = adata_msn[(adata_msn.obs['replicate'] == 'salk') & (adata_msn.obs['brain_region'] == "PU")].copy()
# plot_categorical(adata_msn_sub, coord_base="spatial", cluster_col='kmeans')
# One spatial plot of the cluster labels per brain region, "salk" replicate only.
for _brain_region in adata_msn.obs['brain_region'].unique().tolist():
    adata_msn_br = adata_msn[adata_msn.obs['brain_region'] == _brain_region].copy()
    print(f"Brain Region: {_brain_region}")
    plot_categorical(adata_msn_br[adata_msn_br.obs['replicate'] == "salk"], coord_base="spatial", cluster_col='kmeans')

## Assign Splits and get regional definitions

In [None]:
# Assign Splits (based on prev. results)
# Cluster 2 is labelled Striosome; clusters 0 and 1 are both labelled Matrix.
adata_msn.obs['MS_split'] = (
    adata_msn.obs['kmeans']
    .map({0: "Matrix", 1: "Matrix", 2: "Striosome"})
    .astype('category')
)
# Persist the annotated MSN object next to the region outputs.
adata_msn.write_h5ad(f"{output_path}/adata_msn_mat_str_split.h5ad")

In [None]:
# Get the spatial gdf
# One point per MSN cell, built from the CENTER_X / CENTER_Y columns and
# indexed by cell name.
gdf_msn = gpd.GeoDataFrame(
    index=adata_msn.obs_names,
    geometry=gpd.points_from_xy(
        x=adata_msn.obs['CENTER_X'],
        y=adata_msn.obs['CENTER_Y'],
    ),
)
# NOTE(review): presumably leaves the frame without a CRS — confirm that the
# installed geopandas accepts crs=None here.
gdf_msn = gdf_msn.set_crs(None, allow_override=True)
gdf_msn.head()

In [None]:
# Get the appropriate cols
# Copy the listed annotation columns from the MSN obs table onto the GeoDataFrame.
transfer_cols = ["donor", "replicate", "brain_region", "MSN_Groups", "MS_split"]
for _col, _values in adata_msn.obs[transfer_cols].items():
    gdf_msn[_col] = _values
gdf_msn.head()

In [None]:
# Fixed list of brain regions to process; this replaces the list that was
# derived from the data earlier in the notebook.
brain_regions = [
    "PU",
    "CAH",
    "CAB",
    "CAT",
    "NAC",
]

In [None]:
# From here on this needs to be iterable.
#
# For every (donor, brain_region, replicate) section this cell:
#   1. splits the MSN cells into Matrix / Striosome,
#   2. drops cells with fewer than 2 same-class neighbours within `w_thr`,
#   3. draws one alphashape per connected component (>= 5 cells) of each class,
#   4. removes the buffered striosome area from the matrix shapes,
#   5. writes both sets of shapes to GeoPackage files in `output_path`.
# Sections whose two output files already exist are not recomputed.


def _flag_isolated_cells(gdf, threshold, min_neighbors=2):
    """Return a copy of ``gdf`` annotated with distance-band neighbour counts.

    Adds two columns: ``num_neighbors`` (number of neighbours each cell has in
    a DistanceBand weights object built with ``threshold``) and ``drop_cell``
    (True where ``num_neighbors < min_neighbors``). Working on a copy avoids
    writing columns onto a slice of the parent frame.
    """
    gdf = gdf.copy()
    weights = lps.weights.DistanceBand.from_dataframe(gdf, threshold=threshold)
    gdf['num_neighbors'] = [len(neigh) for neigh in weights.neighbors.values()]
    gdf['drop_cell'] = gdf['num_neighbors'] < min_neighbors
    return gdf


def _component_alphashapes(gdf, weights, alpha=None, min_cells=5):
    """Build one alphashape per connected component of ``weights``.

    ``gdf`` is modified in place: a ``comp`` column is added holding the
    component number of each cell, or -1 for cells in components smaller than
    ``min_cells`` (those components get no shape). ``alpha`` is forwarded to
    ``alphashape.alphashape`` unchanged (``None`` is that function's default).

    Returns a GeoDataFrame whose geometry column holds the shapes.

    NOTE(review): this relies on the ids of ``weights`` being the index labels
    of ``gdf`` (they are used with ``gdf.loc``) — confirm for the installed
    libpysal version.
    """
    graph = nx.from_dict_of_lists(weights.neighbors)
    gdf['comp'] = -1
    geoms = []
    for i, comp in enumerate(nx.connected_components(graph)):
        if len(comp) < min_cells:
            continue
        members = list(comp)
        gdf.loc[members, "comp"] = i
        geoms.append(alphashape.alphashape(gdf[gdf.index.isin(members)], alpha=alpha))
    return gpd.GeoDataFrame(geometry=geoms)


def _plot_components(points, shapes):
    """Plot cells coloured by their ``comp`` column with the shape outlines on top."""
    fig, ax = plt.subplots()
    points.plot(ax=ax, markersize=1, column="comp", cmap="tab20").axis("off")
    shapes.plot(ax=ax, edgecolor="k", facecolor="none").axis("off")
    plt.show()
    plt.close(fig)


contact_list = []
pbar = tqdm(
    itertools.product(donors, brain_regions, replicates),
    # A product iterator has no length, so give tqdm the total explicitly.
    total=len(donors) * len(brain_regions) * len(replicates),
)
for _i in pbar:
    if _i in skip:
        # print(f"Skipping {_i}")
        continue
    _donor, _brain_region, _replicate = _i
    pbar.set_description(f"Processing {_donor} | {_brain_region} | {_replicate}")
    out_path_mat = f"{output_path}/{_donor}_{_brain_region}_{_replicate}_mat_regions.gpkg"
    out_path_str = f"{output_path}/{_donor}_{_brain_region}_{_replicate}_str_regions.gpkg"
    if Path(out_path_mat).exists() and Path(out_path_str).exists():
        print(f"Regions already exist for {_i}, skipping...")
        continue

    gdf_msn_sub = gdf_msn[(gdf_msn['replicate'] == _replicate) & (gdf_msn['brain_region'] == _brain_region) & (gdf_msn['donor'] == _donor)]

    gdf_msn_sub_mat = gdf_msn_sub[gdf_msn_sub['MS_split'] == "Matrix"]
    gdf_msn_sub_str = gdf_msn_sub[gdf_msn_sub['MS_split'] == "Striosome"]

    # Not every donor/region/replicate combination necessarily has cells of
    # both classes; skip those explicitly instead of failing on an empty frame.
    if gdf_msn_sub_mat.empty or gdf_msn_sub_str.empty:
        print(f"No Matrix and/or Striosome cells for {_i}, skipping...")
        continue

    # Get the Weights KNN and remove cells that are isolated (probable noise)
    gdf_msn_sub_mat = _flag_isolated_cells(gdf_msn_sub_mat, w_thr)
    gdf_msn_sub_str = _flag_isolated_cells(gdf_msn_sub_str, w_thr)

    # Plotting for validation: neighbour counts per class, with the cells that
    # are about to be dropped drawn larger on top.
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    gdf_msn_sub_mat.plot(ax=ax, column="num_neighbors", cmap="Blues", markersize=1).axis("off")
    gdf_msn_sub_mat[gdf_msn_sub_mat['drop_cell']].plot(ax=ax, column="drop_cell", markersize=5, cmap="Greens_r").axis("off")
    gdf_msn_sub_str.plot(ax=ax, column="num_neighbors", cmap="Reds", markersize=1).axis("off")
    gdf_msn_sub_str[gdf_msn_sub_str['drop_cell']].plot(ax=ax, column="drop_cell", markersize=5, cmap="Purples_r").axis("off")
    plt.show()
    plt.close(fig)

    # Recalculate the weights after dropping cells
    gdf_msn_sub_mat = gdf_msn_sub_mat[~gdf_msn_sub_mat['drop_cell']].copy()
    gdf_msn_sub_str = gdf_msn_sub_str[~gdf_msn_sub_str['drop_cell']].copy()

    if gdf_msn_sub_mat.empty or gdf_msn_sub_str.empty:
        print(f"No cells left after removing isolated cells for {_i}, skipping...")
        continue

    wb_mat = lps.weights.DistanceBand.from_dataframe(gdf_msn_sub_mat, threshold=w_thr)
    wb_str = lps.weights.DistanceBand.from_dataframe(gdf_msn_sub_str, threshold=w_thr)

    # STR
    # One alphashape (alpha=str_alpha) per sufficiently large connected component.
    gdf_str_geoms = _component_alphashapes(gdf_msn_sub_str, wb_str, alpha=str_alpha)
    _plot_components(gdf_msn_sub_str, gdf_str_geoms)

    # MAT
    # No alpha is supplied for the matrix shapes, so alphashape's default is used.
    gdf_mat_geoms = _component_alphashapes(gdf_msn_sub_mat, wb_mat)
    _plot_components(gdf_msn_sub_mat, gdf_mat_geoms)

    # Remove overlapping regions: cut the striosome area, grown by str_buffer,
    # out of every matrix shape.
    # NOTE(review): `unary_union` is kept as-is; newer geopandas releases offer
    # `union_all()` instead — switch once the minimum geopandas version is known.
    gdf_mat_geoms.geometry = gdf_mat_geoms.geometry.difference(gdf_str_geoms.unary_union.buffer(str_buffer))

    # Plot Final Geoms:
    fig, ax = plt.subplots()
    gdf_str_geoms.plot(ax=ax, edgecolor="Red", facecolor="none").axis("off")
    gdf_mat_geoms.plot(ax=ax, edgecolor="Blue", facecolor="none").axis("off")
    plt.show()
    plt.close(fig)

    # Save Geometries
    gdf_mat_geoms.to_file(out_path_mat, driver="GPKG")
    gdf_str_geoms.to_file(out_path_str, driver="GPKG")