In [None]:
# imports
import os
from pathlib import Path
import itertools 
from tqdm import tqdm

import math
import numpy as np
import pandas as pd
import anndata as ad
import spatialdata as sd

from rich import inspect
from spida.utilities._ad_utils import normalize_adata

import geopandas as gpd
from shapely import Polygon, Point, box
from sklearn.mixture import GaussianMixture
import libpysal as lps
import networkx as nx

import matplotlib.pyplot as plt
import seaborn as sns
import spatialdata_plot as sdp # type: ignore
plt.rcParams['figure.dpi'] = 150
from matplotlib.colors import TwoSlopeNorm


from spida.utilities.tiling import create_hexagonal_grid

In [None]:
# parameters
# Previous dataset (kept for reference):
# ad_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_PFV8_annotated_v5.h5ad"
# sd_store = "/home/x-aklein2/projects/aklein/BICAN/data/zarr_store"
# wm_dir_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/regions/wm_v2"
# matstr_dir_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/regions/mat_str"
# geom_store_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/regions/region_geometries.parquet"

# Every input/output lives under one project root; change it here to relocate
# the whole analysis. The derived values are plain strings because the cells
# below build file names from them with f-strings.
project_root = "/home/x-aklein2/projects/aklein/BICAN"
bg_data_dir = f"{project_root}/BG/data"
regions_dir = f"{bg_data_dir}/regions"

ad_path = f"{bg_data_dir}/BICAN_BG_CPS.h5ad"                      # annotated AnnData (read, later overwritten)
sd_store = f"{project_root}/data/zarr_store"                      # zarr store root
wm_dir_path = f"{regions_dir}/wm_v4"                              # *_wm_regions.gpkg files
matstr_dir_path = f"{regions_dir}/mat_str_CPS"                    # *_str_regions.gpkg / *_mat_regions.gpkg files
geom_store_path = f"{regions_dir}/region_geometries_cps.parquet"  # combined geometries (written)

### Functions (TODO: move the plotting helpers I reuse into a shared collection)

In [None]:
def create_stacked_bar_chart(df, group_column, cell_type_column='cell_type', 
                           figsize=(12, 8), title=None, colors=None, 
                           show_percentages=True, rotation=45, rasterized=False):
    """
    Create a stacked bar chart showing cell type percentages across groups.
    
    Parameters:
    -----------
    df : pandas.DataFrame
        The input dataframe containing the data
    group_column : str
        Column name to group by (x-axis categories)
    cell_type_column : str, default 'cell_type'
        Column name containing cell type information
    figsize : tuple, default (12, 8)
        Figure size (width, height)
    title : str, optional
        Chart title. If omitted, a title is derived from ``group_column``
    colors : list or dict, optional
        Colors for cell types. A dict is looked up per cell type (missing
        entries fall back to 'gray'); a sequence is used in column order and
        is cycled if it is shorter than the number of cell types. If None or
        empty, the seaborn "Set3" palette is used
    show_percentages : bool, default True
        Whether to show percentage labels on bars (only segments > 2%)
    rotation : int, default 45
        Rotation angle for x-axis labels
    rasterized : bool, default False
        Forwarded as the ``rasterized`` artist property to the bars, the
        percentage labels, the axis labels and the title
    
    Returns:
    --------
    fig, ax : matplotlib figure and axis objects
    """
    
    # Calculate cell type counts and percentages.
    # observed=True keeps only group/cell-type values that actually occur, so
    # unused categories of a categorical column cannot create all-zero rows
    # (which would turn into 0/0 = NaN percentages) or empty legend entries.
    counts = (
        df.groupby([group_column, cell_type_column], observed=True)
        .size()
        .unstack(fill_value=0)
    )
    percentages = counts.div(counts.sum(axis=1), axis=0) * 100
    
    # Set up colors
    n_cell_types = len(counts.columns)
    if colors is None or len(colors) == 0:
        colors = sns.color_palette("Set3", n_cell_types)
    elif isinstance(colors, dict):
        colors = [colors.get(ct, 'gray') for ct in counts.columns]
    else:
        colors = list(colors)
    
    # Create the plot
    fig, ax = plt.subplots(figsize=figsize)
    
    # Bars are drawn at explicit integer positions so that the percentage
    # labels (also placed by position) line up with them whatever the dtype
    # of the group values is.
    x_positions = np.arange(len(percentages))
    bottom = np.zeros(len(percentages))
    bars = []
    
    for i, cell_type in enumerate(percentages.columns):
        heights = percentages[cell_type].to_numpy()
        bar = ax.bar(x_positions, heights, 
                    bottom=bottom, label=cell_type, color=colors[i % len(colors)],
                    rasterized=rasterized)
        bars.append(bar)
        
        # Add percentage labels if requested
        if show_percentages:
            for j, value in enumerate(heights):
                if value > 2:  # Only show label if percentage > 2%
                    ax.text(x_positions[j], bottom[j] + value/2, f'{value:.1f}%', 
                           ha='center', va='center', fontsize=8, fontweight='bold',
                           rasterized=rasterized)
        
        # New array each time: never mutate the array already handed to ax.bar
        bottom = bottom + heights
    
    # Customize the plot
    ax.set_xlabel(group_column.replace('_', ' ').title(), fontsize=12, rasterized=rasterized)
    ax.set_ylabel('Percentage (%)', fontsize=12, rasterized=rasterized)
    ax.set_ylim(0, 100)
    
    if title:
        ax.set_title(title, fontsize=14, fontweight='bold', rasterized=rasterized)
    else:
        ax.set_title(f'Cell Type Distribution by {group_column.replace("_", " ").title()}', 
                    fontsize=14, fontweight='bold', rasterized=rasterized)
    
    # Label and rotate the x ticks on this axes explicitly (plt.xticks would
    # act on whichever axes happens to be current).
    ax.set_xticks(x_positions)
    ax.set_xticklabels([str(group) for group in percentages.index],
                       rotation=rotation, ha='right')
    
    # Add legend
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=10)
    
    # Add grid for better readability
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    
    return fig, ax

# Example usage:
# Assuming you have a dataframe 'df' with columns 'region' and 'cell_type'
# fig, ax = create_stacked_bar_chart(df, group_column='region', cell_type_column='cell_type')
# plt.show()

# Alternative simpler version for quick use:
def quick_stacked_bar(df, group_col, cell_type_col='cell_type', rasterized=False):
    """Minimal stacked bar chart of cell-type percentages per group.

    Counts rows per (group_col, cell_type_col) pair, converts each group's
    counts to percentages and draws them with pandas' built-in bar plot.
    Returns the matplotlib Axes.
    """
    size_table = df.groupby([group_col, cell_type_col]).size().unstack(fill_value=0)
    group_totals = size_table.sum(axis=1)
    pct_table = size_table.div(group_totals, axis=0) * 100

    bar_options = dict(
        kind='bar',
        stacked=True,
        figsize=(10, 6),
        colormap='Set3',
        rot=45,
        rasterized=rasterized,
    )
    ax = pct_table.plot(**bar_options)

    ax.set_ylabel('Percentage (%)')
    ax.set_title(f'Cell Type Distribution by {group_col}')
    # Legend sits outside the plotting area, to the right
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    return ax

In [None]:
### Function for combining overlapping matrix - striosome - and white matter geometries
# Buffer distance used when cutting striosomes and matrix/white-matter overlaps
# out of the other geometries (same units as the geometry coordinates).
str_buffer = 0.25
def _combine_ms_wm(geoms): 
    """Resolve overlaps between Striosome, Matrix and White_Matter geometries.

    Parameters
    ----------
    geoms : geopandas.GeoDataFrame
        Must contain a ``geometry`` column and a ``type`` column with the
        values "Striosome", "Matrix" and/or "White_Matter".

    Returns
    -------
    geopandas.GeoDataFrame
        A copy of ``geoms`` in which (1) the buffered union of all striosomes
        has been cut out of every non-striosome geometry, (2) the buffered
        Matrix/White_Matter intersections have been cut out of every geometry,
        and (3) each intersection piece has been appended as its own row,
        carrying the attributes of whichever parent polygon (matrix or white
        matter) has the smaller area. Rows whose geometry ended up empty are
        dropped. The index is not reset.

    Notes
    -----
    Reads the module-level ``str_buffer``.
    """

    new_geoms = geoms.copy()
    str_geoms = geoms[geoms['type'] == "Striosome"].copy()
    # Cut the buffered striosome footprint out of *every* row (this also
    # empties the striosome rows themselves) ...
    new_geoms.geometry = new_geoms.geometry.difference(str_geoms.union_all().buffer(str_buffer))
    # ... then put the original, untouched striosome rows back.
    new_geoms.loc[str_geoms.index, str_geoms.columns] = str_geoms
    mat_geoms = new_geoms[new_geoms['type'] == "Matrix"].copy()
    wm_geoms = new_geoms[new_geoms['type'] == "White_Matter"].copy()

    # Sequential ids so each intersection piece can be traced back to its
    # parent polygons after the overlay. (str_id is not used further below.)
    wm_geoms['wm_id'] = range(len(wm_geoms))
    mat_geoms['mat_id'] = range(len(mat_geoms))
    str_geoms['str_id'] = range(len(str_geoms))

    # Pieces where a matrix polygon and a white-matter polygon overlap.
    ints = gpd.overlay(mat_geoms, wm_geoms, how='intersection')
    # Area of the full parent polygon (not of the intersection piece).
    ints['mat_area'] = ints['mat_id'].apply(lambda x: mat_geoms.loc[mat_geoms['mat_id'] == x, 'geometry'].values[0].area)
    ints['wm_area'] = ints['wm_id'].apply(lambda x: wm_geoms.loc[wm_geoms['wm_id'] == x, 'geometry'].values[0].area)

    # Remove overlapping regions
    new_geoms.geometry = new_geoms.geometry.difference(ints.geometry.union_all().buffer(str_buffer))

    # Re-add each overlap piece, labelled with the attributes of the smaller
    # parent polygon (ties go to white matter because the comparison is strict).
    new_ints = []
    for index, row in ints.iterrows(): 
        if row['mat_area'] < row['wm_area']:
            keep = "1"  # first overlay input, i.e. the matrix polygon
            keep_idx = row['mat_id']
        else: 
            keep = "2"  # second overlay input, i.e. the white-matter polygon
            keep_idx = row['wm_id']
        keep_row = [keep_idx]
        for _col in geoms.columns: 
            if _col == "geometry": 
                keep_row.append(row['geometry'])
            else: 
                # Relies on the overlay result exposing every shared attribute
                # column as "<col>_1" (matrix side) and "<col>_2" (white-matter side).
                keep_row.append(row[f"{_col}_{keep}"])
        new_ints.append(keep_row)
    # NOTE(review): this frame is built without a CRS even if `geoms` has one — confirm that is fine for the concat below.
    add_ints = gpd.GeoDataFrame(new_ints, columns=["keep_idx"] + list(geoms.columns))

    # keep_idx was only bookkeeping; drop it before appending the new rows.
    new_geoms = pd.concat([new_geoms, add_ints.drop(columns=["keep_idx"])])
    # Differencing can leave empty geometries behind; discard those rows.
    new_geoms = new_geoms.loc[~new_geoms.is_empty]
    return new_geoms



### Continue

In [None]:
# Load the annotated AnnData object; the cells below add per-cell compartment
# labels and scores to adata.obs.
adata = ad.read_h5ad(ad_path)
adata

In [None]:
# One point geometry per cell, built from the CENTER_X / CENTER_Y columns of
# adata.obs; no CRS is attached.
cell_points = gpd.points_from_xy(adata.obs['CENTER_X'], adata.obs['CENTER_Y'])
gdf_cells = gpd.GeoDataFrame(adata.obs, geometry=cell_points, crs=None)
gdf_cells.head()

In [None]:
# Unique values of the metadata columns; the section loops below iterate over
# the product donors x brain_regions x replicates.
donors = adata.obs['donor'].unique().tolist()
replicates = adata.obs['replicate'].unique().tolist()
brain_regions = adata.obs['brain_region'].unique().tolist()
experiments = adata.obs['experiment'].unique().tolist()
# (donor, brain_region, replicate) combinations that the section loops skip explicitly.
skip = [("UWA7648", "CAT", "ucsd"), ("UWA7648", "CAT", "salk")]

In [None]:
gdf_cells.shape, adata.shape

In [None]:
# The annotated object is written back to ad_path further down, so a reloaded
# file may already carry these columns; drop them so they are recomputed.
remove_cols = ["MS_NORM", "MS_SCORE", "wm_compartment", "MS_compartment"]
stale_cols = [name for name in remove_cols if name in adata.obs.columns]
adata.obs.drop(columns=stale_cols, inplace=True)

In [None]:
def _tag_regions(region_gdf, region_type, donor, brain_region, replicate):
    """Stamp the region type and section identifiers onto a geometry frame (in place) and return it."""
    region_gdf['type'] = region_type
    region_gdf['donor'] = donor
    region_gdf['brain_region'] = brain_region
    region_gdf['lab'] = replicate
    return region_gdf


def _q99_minmax(score):
    """Scale to [0, 1]: subtract the minimum, divide by (99th percentile - minimum), clip."""
    lowest = score.min()
    return score.subtract(lowest).div(score.quantile(0.99) - lowest).clip(0, 1)


# For every section (donor x brain region x replicate):
#   * label cells inside white-matter polygons  -> adata.obs['wm_compartment']
#   * label cells inside matrix / striosome     -> adata.obs['MS_compartment']
#   * score matrix cells by distance to the nearest striosome (positive) and
#     striosome cells by distance to the nearest matrix polygon (negative)
#     -> adata.obs['MS_SCORE'] / adata.obs['MS_NORM']
#   * collect the (overlap-resolved) geometries in `geoms`
# The per-section variables (_donor, sub_geoms, str_geoms, mat_geoms, cells, ...)
# intentionally stay at notebook scope; later cells plot the last section.
geoms = []
n_combinations = len(donors) * len(brain_regions) * len(replicates)
pbar = tqdm(itertools.product(donors, brain_regions, replicates), total=n_combinations)
for _i in pbar:
    if _i in skip:
        continue
    cur_donor, cur_brain_region, cur_replicate = _i

    # product() also yields combinations that do not exist in the data;
    # indexing .values[0] on those would raise IndexError.
    in_section = (
        (adata.obs['donor'] == cur_donor)
        & (adata.obs['brain_region'] == cur_brain_region)
        & (adata.obs['replicate'] == cur_replicate)
    )
    if not in_section.any():
        continue

    path_wm = f"{wm_dir_path}/{cur_donor}_{cur_brain_region}_{cur_replicate}_wm_regions.gpkg"
    path_str = f"{matstr_dir_path}/{cur_donor}_{cur_brain_region}_{cur_replicate}_str_regions.gpkg"
    path_mat = f"{matstr_dir_path}/{cur_donor}_{cur_brain_region}_{cur_replicate}_mat_regions.gpkg"
    has_wm = Path(path_wm).exists()
    has_ms = Path(path_str).exists() and Path(path_mat).exists()
    if not (has_wm or has_ms):
        # Nothing drawn for this section; pd.concat([]) would raise.
        continue

    # Only now publish the section identifiers, so the notebook-level names
    # always describe the section whose geometries are held in `sub_geoms`.
    _donor, _brain_region, _replicate = cur_donor, cur_brain_region, cur_replicate
    _experiment, _region = adata.obs.loc[in_section, ['experiment', 'region']].values[0]
    pbar.set_description(f"Processing {_i} ({_experiment}, {_region})")

    gdf_sub = gdf_cells[(gdf_cells['donor'] == _donor) &
                    (gdf_cells['brain_region'] == _brain_region) &
                    (gdf_cells['replicate'] == _replicate)]

    section_parts = []
    if has_wm:
        # explode() splits multi-part geometries into one row per part
        wm_geoms = _tag_regions(gpd.read_file(path_wm).explode(), 'White_Matter',
                                _donor, _brain_region, _replicate)
        section_parts.append(wm_geoms)
        wm_cells = gpd.sjoin(gdf_sub, wm_geoms, how="inner", predicate='within')
        adata.obs.loc[wm_cells.index, 'wm_compartment'] = "WM"

    if has_ms:
        str_geoms = _tag_regions(gpd.read_file(path_str), 'Striosome',
                                 _donor, _brain_region, _replicate)
        mat_geoms = _tag_regions(gpd.read_file(path_mat), 'Matrix',
                                 _donor, _brain_region, _replicate)
        section_parts.extend([str_geoms, mat_geoms])

        # Resolve striosome / matrix / white-matter overlaps before assigning cells
        sub_geoms = _combine_ms_wm(pd.concat(section_parts, ignore_index=True))
        mat_geoms = sub_geoms[sub_geoms['type'] == "Matrix"].copy()
        str_geoms = sub_geoms[sub_geoms['type'] == "Striosome"].copy()

        # A cell inside several polygons of the same type appears once per
        # match; keep the first. .copy() so the column assignments below act
        # on an independent frame.
        mat_cells = gpd.sjoin(gdf_sub, mat_geoms, how="inner", predicate='within')
        mat_cells = mat_cells.loc[~mat_cells.index.duplicated(keep="first")].copy()
        str_cells = gpd.sjoin(gdf_sub, str_geoms, how="inner", predicate='within')
        str_cells = str_cells.loc[~str_cells.index.duplicated(keep="first")].copy()

        adata.obs.loc[mat_cells.index, 'MS_compartment'] = "Matrix"
        adata.obs.loc[str_cells.index, 'MS_compartment'] = "Striosome"

        # The distance to the nearest polygon equals the distance to the union
        # of all polygons, so one distance call per compartment is enough.
        # (The previous version built one distance column per polygon and
        # reused the same list for both compartments, which only worked
        # through index alignment and NaN skipping.)
        mat_cells['MS_SCORE'] = mat_cells.distance(str_geoms.geometry.union_all())
        mat_cells['MS_NORM'] = _q99_minmax(mat_cells['MS_SCORE'])

        str_cells['MS_SCORE'] = str_cells.distance(mat_geoms.geometry.union_all())
        # Striosome cells get the negative side of the scale
        str_cells['MS_NORM'] = _q99_minmax(str_cells['MS_SCORE']) * -1
        str_cells['MS_SCORE'] = str_cells['MS_SCORE'] * -1

        cells = pd.concat([mat_cells, str_cells])
        adata.obs.loc[cells.index, 'MS_SCORE'] = cells['MS_SCORE']
        adata.obs.loc[cells.index, 'MS_NORM'] = cells['MS_NORM']
    else:
        # White matter only: no overlap resolution needed
        sub_geoms = pd.concat(section_parts, ignore_index=True)

    geoms.append(sub_geoms)

df_geoms = pd.concat(geoms)

In [None]:
# Plot colour for each region type, stored on the geometry table so that
# plotting cells can pass it straight to `color=`.
geoms_colors = {
    "White_Matter": "lightgray",
    "Striosome": "red",
    "Matrix": "blue",
}
region_type_colors = df_geoms['type'].map(geoms_colors)
df_geoms['type_color'] = region_type_colors

In [None]:
# NOTE: ad_path is the same file `adata` was read from, so this overwrites the
# input with the version that carries the new compartment / score columns.
adata.write_h5ad(ad_path)
# Persist the combined region geometries of all sections.
df_geoms.to_parquet(geom_store_path)

## Make Scoring Distribution

In [None]:
# Overview of the region geometries held in `sub_geoms`, i.e. the last section
# handled by the loop above (the title uses the loop's leftover identifiers).
fig, ax = plt.subplots()
region_styles = [
    ("White_Matter", dict(color='lightgray', edgecolor='black')),
    ("Striosome", dict(color='red', edgecolor='black', alpha=0.5)),
    ("Matrix", dict(color='blue', edgecolor='black', alpha=0.5)),
]
for region_type, style in region_styles:
    sub_geoms[sub_geoms['type'] == region_type].plot(ax=ax, **style)
ax.set_title(f"{_donor} {_brain_region} {_replicate}")
plt.show()

In [None]:
# distances = []
# for i, _str in enumerate(str_geoms.geometry):
#     distances.append(mat_cells.distance(_str))
# mat_cells['MS_SCORE'] = pd.DataFrame(distances).T.min(axis=1)

# for i, _mat in enumerate(mat_geoms.geometry):
#     distances.append(str_cells.distance(_mat))
# str_cells['MS_SCORE'] = pd.DataFrame(distances).T.min(axis=1)*-1

In [None]:
# cells = pd.concat([mat_cells, str_cells])

In [None]:
adata

In [None]:
fig, ax = plt.subplots()

# Scored cells of the last processed section, coloured by MS_NORM.
cells.plot(
    ax=ax,
    column="MS_NORM",
    cmap='coolwarm_r',
    edgecolor='none',
    markersize=10,
    alpha=0.5,
    legend_kwds={"label" : "MS_NORM"},
    legend=True,
).axis("off");
# The colorbar is the second axes of this single-panel figure.
cbar_ax = fig.axes[1]
cbar_ax.tick_params(labelsize=12)

# Faint fills first, then crisp outlines on top (striosome before matrix in both passes).
for region_gdf, fill in ((str_geoms, 'red'), (mat_geoms, 'blue')):
    region_gdf.plot(ax=ax, color=fill, edgecolor='none', alpha=0.05)
for region_gdf in (str_geoms, mat_geoms):
    region_gdf.plot(ax=ax, color='none', edgecolor='black', linewidth=1, alpha=1)

ax.set_title(f"{_donor} {_brain_region} {_replicate}")
plt.show()

In [None]:
# Rebuild the point table from adata.obs so that it includes the columns the
# section loop added (compartments and scores).
scored_points = gpd.points_from_xy(adata.obs['CENTER_X'], adata.obs['CENTER_Y'])
ad_cell_ss = gpd.GeoDataFrame(adata.obs, geometry=scored_points, crs=None)
ad_cell_ss.head()

In [None]:
_donor, _brain_region, _replicate

In [None]:
# All cells (scored or not) of the section named by the loop's leftover identifiers.
is_donor = ad_cell_ss['donor'] == _donor
is_brain_region = ad_cell_ss['brain_region'] == _brain_region
is_replicate = ad_cell_ss['replicate'] == _replicate
cells_ss = ad_cell_ss[is_donor & is_brain_region & is_replicate].copy()
cells_ss.head()

In [None]:
fig, ax = plt.subplots()

# Diverging colour scale centred on 0 (matrix > 0, striosome < 0).
# TwoSlopeNorm raises ValueError unless vmin < 0 < vmax, e.g. when the section
# has no striosome cells; fall back to the default scaling in that case.
try:
    norm = TwoSlopeNorm(vmin=cells_ss['MS_NORM'].min(), vcenter=0, vmax=cells_ss['MS_NORM'].max())
except ValueError:
    norm = None

# Every cell in gray as background, then the MS_NORM colouring on top.
cells_ss.plot(ax=ax, color="gray", edgecolor='none', markersize=2, alpha=0.5);
cells_ss.plot(ax=ax, column="MS_NORM", cmap='coolwarm_r', edgecolor='none', markersize=2, alpha=0.75, legend_kwds={"label" : "MS_NORM"}, norm=norm, legend=True).axis("off");

# cells_ss is filtered by brain_region, so label the figure with it (the other
# section plots in this notebook do the same).
ax.set_title(f"{_donor} {_brain_region} {_replicate}")
plt.show()

## Correlation with MS SCORE

In [None]:
adata.obs['MSN_Groups'].value_counts()

In [None]:
# Restrict to cells that have an MSN_Groups annotation.
has_msn_group = adata.obs['MSN_Groups'].notna()
adata_msn = adata[has_msn_group].copy()
adata_msn

In [None]:
# Same colour mapping that is assigned right after df_geoms is built; running
# it again just rewrites 'type_color' with identical values.
geoms_colors = {
    "White_Matter": "lightgray",
    "Striosome": "red",
    "Matrix": "blue"
}
df_geoms['type_color'] = df_geoms['type'].map(geoms_colors)

### Investigating NaNs in the data
There are two cases in which the MS_SCORE column contains NaNs:
1. No striosomes were detected in the section (e.g. NAC UCI4723).
2. The cell lies within neither a Matrix nor a Striosome geometry (see CAH UWA7648 salk / ucsd for examples).

In [None]:
adata_msn[adata_msn.obs['MS_SCORE'].isna()].obs[['brain_region', 'donor', 'replicate']].value_counts()

In [None]:
# Example section for inspecting cells without an MS_SCORE.
_donor, _brain_region, _replicate = "UWA7648", "CAH", "salk"

geoms_in_section = (
    (df_geoms['brain_region'] == _brain_region)
    & (df_geoms['donor'] == _donor)
    & (df_geoms['lab'] == _replicate)
)
sub_geoms = df_geoms.loc[geoms_in_section]

cells_in_section = (
    (adata_msn.obs['brain_region'] == _brain_region)
    & (adata_msn.obs['donor'] == _donor)
    & (adata_msn.obs['replicate'] == _replicate)
)
sub_cells = adata_msn[cells_in_section].obs.copy()
section_points = gpd.points_from_xy(sub_cells['CENTER_X'], sub_cells['CENTER_Y'])
sub_cells = gpd.GeoDataFrame(sub_cells, geometry=section_points, crs=None)

In [None]:
sub_cells.shape

In [None]:
sub_cells['MS_SCORE'].isna().sum(), sub_cells.shape[0]

In [None]:
fig, ax = plt.subplots()

# MSN cells of the example section coloured by MS_SCORE on a scale centred at 0.
score_min = sub_cells['MS_SCORE'].min()
score_max = sub_cells['MS_SCORE'].max()
norm = TwoSlopeNorm(vmin=score_min, vcenter=0, vmax=score_max)
point_style = dict(cmap='coolwarm_r', norm=norm, edgecolor='none', markersize=10, alpha=0.5)
sub_cells.plot(ax=ax, column="MS_SCORE", legend_kwds={"label" : "MS_SCORE"}, legend=True, **point_style).axis("off");
cbar_ax = fig.axes[1]
cbar_ax.tick_params(labelsize=12)

# Region geometries: faint fill in the per-type colour, then black outlines.
sub_geoms.plot(ax=ax, column="type", color=sub_geoms['type_color'], edgecolor='none', legend=True, alpha=0.1).axis("off");
sub_geoms.plot(ax=ax, column="type", color="none", edgecolor='black', legend=True).axis("off");
ax.legend()
ax.set_title(f"{_donor} {_brain_region} {_replicate}")
plt.show()

### Continue corr. analysis

In [None]:
# Reset X to a copy of the 'counts' layer, then normalise with log1p enabled.
# NOTE(review): the return value of normalize_adata is not captured, so it
# presumably modifies `adata` in place — confirm.
adata.X = adata.layers['counts'].copy()
normalize_adata(
    adata, 
    log1p=True,
)

In [None]:
from spida.utilities.sd_utils import _get_obs_or_gene
from scipy.stats import pearsonr
from statsmodels.stats.multitest import multipletests 

In [None]:
# Keep only the MSN cells that received an MS_SCORE.
# NOTE(review): adata_msn was copied from `adata` in an earlier cell, before
# the normalize_adata call on `adata`, and this cell subsets that copy. The
# expression values used below are therefore whatever X held at copy time,
# not the re-normalised matrix — confirm which one the correlations should use.
adata_msn = adata_msn[~adata_msn.obs['MS_SCORE'].isna()].copy()
adata_msn

In [None]:
sns.histplot(adata_msn.obs['MS_SCORE'])

In [None]:
layer = None
returns = {}

# Balance the two compartments by down-sampling Matrix cells to the number of
# Striosome cells. The draw depends only on MS_compartment, so it is done once
# here instead of once per gene (same seed -> same cells as before).
# min() keeps the draw valid when striosome cells outnumber matrix cells.
ms_compartment = adata_msn.obs['MS_compartment']
striosome_idx = ms_compartment.index[ms_compartment == "Striosome"]
matrix_pool = ms_compartment[ms_compartment == "Matrix"]
n_obs = min(len(striosome_idx), len(matrix_pool))
keep_idx = matrix_pool.sample(n_obs, random_state=42).index
balanced_idx = keep_idx.tolist() + striosome_idx.tolist()

# Pearson correlation between MS_SCORE and the expression of every gene.
for _gene in tqdm(adata_msn.var_names, desc="Pearson r vs MS_SCORE"):
    adata_msn, _drop_col = _get_obs_or_gene(adata_msn, _gene, layer) # get the column from obs or var
    df_obs = adata_msn.obs.loc[balanced_idx, ["MS_SCORE", "MS_compartment", _gene]].copy()
    if _drop_col:  # drop the column again if it was only added for this lookup
        adata_msn.obs.drop(columns=[_gene], inplace=True)

    _stat, _p = pearsonr(df_obs['MS_SCORE'], df_obs[_gene])
    returns[_gene] = (_stat, _p)

In [None]:
# One row per gene: Pearson statistic, raw p-value and Benjamini-Hochberg
# adjusted p-value; most significant genes first.
stat_columns = ['pearsonr_stat', 'pearsonr_pval']
df_returns = pd.DataFrame.from_dict(returns, orient='index', columns=stat_columns)
_, pvals_adjusted, *_ = multipletests(df_returns['pearsonr_pval'], method='fdr_bh', alpha=0.05, maxiter=1)
df_returns['pearsonr_pval_adj'] = pvals_adjusted
df_returns = df_returns.sort_values(by='pearsonr_pval_adj', ascending=True)
df_returns.head()

In [None]:
toplot = df_returns.index[:10]
toplot

In [None]:
layer = None

# Same balanced subsample as used for the correlations: all Striosome cells
# plus an equally sized, seeded draw of Matrix cells. Computed once because it
# does not depend on the gene; min() keeps the draw valid when striosome cells
# outnumber matrix cells.
ms_compartment = adata_msn.obs['MS_compartment']
striosome_idx = ms_compartment.index[ms_compartment == "Striosome"]
matrix_pool = ms_compartment[ms_compartment == "Matrix"]
n_obs = min(len(striosome_idx), len(matrix_pool))
keep_idx = matrix_pool.sample(n_obs, random_state=42).index
balanced_idx = keep_idx.tolist() + striosome_idx.tolist()

for _gene in toplot:
    adata_msn, _drop_col = _get_obs_or_gene(adata_msn, _gene, layer) # get the column from obs or var
    df_obs = adata_msn.obs.loc[balanced_idx, ["MS_SCORE", "MS_compartment", _gene]].copy()
    if _drop_col:  # drop the column if it was added for plotting
        adata_msn.obs.drop(columns=[_gene], inplace=True)

    fig, ax = plt.subplots(figsize=(3, 3))
    ax.scatter(df_obs['MS_SCORE'], df_obs[_gene], alpha=0.1, s=1)
    ax.set_title(f"Correlation between MS_SCORE and {_gene}")
    plt.show()
    plt.close(fig)  # one figure per gene: release it once shown

    _stat, _p = pearsonr(df_obs['MS_SCORE'], df_obs[_gene])
    print((_stat, _p))

In [None]:
_gene = "FKBP5"
max_sections = 11  # preview limit: stop after this many figures

# Two panels per section: MS_SCORE (left) and expression of `_gene` (right),
# both with the region outlines on top.
n_plotted = 0
pbar = tqdm(itertools.product(donors, brain_regions, replicates))
for _i in pbar:
    if _i in skip:
        continue
    _donor, _brain_region, _replicate = _i
    sub_geoms = df_geoms.loc[(df_geoms['brain_region'] == _brain_region) & (df_geoms['donor'] == _donor) & (df_geoms['lab'] == _replicate)]
    sub_cells = adata_msn[(adata_msn.obs['brain_region'] == _brain_region) & (adata_msn.obs['donor'] == _donor) & (adata_msn.obs['replicate'] == _replicate)].obs.copy()
    if sub_cells.empty:
        # Combination without scored MSN cells: nothing to draw
        continue

    expr = adata_msn[sub_cells.index, _gene].X
    # X may be sparse or dense; either way end up with a flat 1-D array
    expr = expr.toarray() if hasattr(expr, "toarray") else np.asarray(expr)
    sub_cells = gpd.GeoDataFrame(sub_cells, geometry=gpd.points_from_xy(sub_cells['CENTER_X'], sub_cells['CENTER_Y']), crs=None)
    sub_cells[_gene] = expr.ravel()

    fig, axes = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)

    # --- left panel: MS_SCORE ---
    ax = axes[0]
    # TwoSlopeNorm needs vmin < 0 < vmax; fall back to default scaling otherwise
    try:
        norm = TwoSlopeNorm(vmin=sub_cells['MS_SCORE'].min(), vcenter=0, vmax=sub_cells['MS_SCORE'].max())
    except ValueError:
        norm = None
    sub_cells.plot(ax=ax, column="MS_SCORE", cmap='coolwarm_r', norm=norm, edgecolor='none', markersize=10, alpha=0.5, legend_kwds={"label" : "MS_SCORE"}, legend=True).axis("off");
    # The colorbar is the axes added last; fig.axes[1] is the second panel
    # in this two-panel figure, not the colorbar.
    fig.axes[-1].tick_params(labelsize=12)

    # With an explicit `color`, geopandas ignores `column`/`legend`, so they
    # are not passed here; there are no labelled artists for a legend either.
    sub_geoms.plot(ax=ax, color=sub_geoms['type_color'], edgecolor='none', alpha=0.1).axis("off");
    sub_geoms.plot(ax=ax, color="none", edgecolor='black').axis("off");
    ax.set_title(f"{_donor} {_brain_region} {_replicate}")

    # --- right panel: gene expression ---
    ax = axes[1]
    try: 
        norm = TwoSlopeNorm(vmin=sub_cells[_gene].min(), vcenter=sub_cells[_gene].median()*0.8, vmax=sub_cells[_gene].max()*0.8)
    except ValueError:
        norm = None
    sub_cells.plot(ax=ax, column=_gene, cmap='RdYlGn_r', norm=norm, edgecolor='none', markersize=10, alpha=0.9, legend=True).axis("off");
    fig.axes[-1].tick_params(labelsize=12)

    # Outlines of matrix / striosome only on the expression panel
    ms_geoms = sub_geoms.loc[sub_geoms['type'] != "White_Matter"].copy()
    ms_geoms.plot(ax=ax, color="none", edgecolor='black').axis("off");
    ax.set_title(f"{_donor} {_brain_region} {_replicate}")

    plt.show()
    plt.close(fig)  # avoid accumulating open figures across sections

    # Count figures actually drawn: a counter over all combinations is never
    # checked on skipped iterations, so the loop could run past the limit.
    n_plotted += 1
    if n_plotted >= max_sections:
        break

In [None]:
import PyComplexHeatmap as pch

In [None]:
# Column (cell) annotations for the heatmap: a seeded 25% sample of the scored
# MSN cells, ordered by MS_SCORE.
annotation_columns = ['Group', 'MS_SCORE', 'MS_compartment']
df_col = adata_msn.obs[annotation_columns].copy()
df_col = df_col.sample(frac=0.25, random_state=42)
df_col = df_col.sort_values('MS_SCORE')
print(df_col.shape)
df_col.head()

In [None]:
# Row (gene) annotations: the 100 genes with the smallest adjusted p-value.
# `df_returns_w0` is not defined anywhere in this notebook and fails on a
# fresh kernel. `df_returns` is the table computed above; its correlations
# include zero-expression cells, which is presumably what "w0" stood for.
df_row = df_returns.index[:100].to_frame()
print(df_row.shape)
df_row.head()

In [None]:
# Genes to label on the heatmap ("PCDH11X" was listed twice).
genes_to_plot = ["PDYN", "PCDH11X", "FKBP5", "KIRREL3", "CADM1", "DRD1", "DRD2", "TMEM132D", "PDE10A", "PDE8B", "SLC35D3", "RXRG"]

In [None]:
# Label only the genes of interest; every other row gets NaN (no label).
highlighted_genes = set(genes_to_plot)
df_row['annot'] = [gene if gene in highlighted_genes else np.nan for gene in df_row.index]
df_row.head()

In [None]:
# Genes x cells expression table for the heatmap. Subset to the selected cells
# and genes *before* densifying, so only the part that is plotted is expanded
# to a dense array (row/column order follows df_row / df_col).
expr_block = adata_msn[df_col.index, df_row.index].X
expr_block = expr_block.toarray() if hasattr(expr_block, "toarray") else np.asarray(expr_block)
df_expr = pd.DataFrame(expr_block.T, index=df_row.index, columns=df_col.index)

# Per-gene min-max scaling to [0, 1]; a gene with constant expression across
# the sampled cells yields NaN (0/0).
gene_min = df_expr.min(axis=1)
gene_max = df_expr.max(axis=1)
df_expr_norm = df_expr.subtract(gene_min, axis=0).div(gene_max - gene_min, axis=0)
print(df_expr_norm.shape)
df_expr_norm.head()

In [None]:
# Diverging colour scale for the MS_SCORE annotation bar, centred at 0 and
# limited to 10% of the observed extremes.
ms_score_norm = TwoSlopeNorm(vmin=df_col['MS_SCORE'].values.min() * 0.1, vcenter=0, vmax=df_col['MS_SCORE'].values.max() * 0.1)

# Column (cell) annotations: merged compartment label, compartment colour bar
# and MS_SCORE colour bar.
col_ha = pch.HeatmapAnnotation(
    label=pch.anno_label(
        df_col['MS_compartment'], merge=True, rotation=90, extend=True,
        colors={"Matrix": "blue", "Striosome": "red"}, 
    ),
    MS_COMPARTMENT=pch.anno_simple(df_col['MS_compartment'], colors={"Matrix": "blue", "Striosome": "red"}),
    MS_SCORE=pch.anno_simple(df_col['MS_SCORE'], cmap="coolwarm_r", norm=ms_score_norm),
    verbose=1, axis=1, plot_legend=False
)

# Row (gene) annotations: text labels taken from df_row['annot'], which is NaN
# for genes that are not in genes_to_plot.
left_ha = pch.HeatmapAnnotation(
    label=pch.anno_label(
        df_row['annot'], merge=True, rotation=0, extend=True,
        colors="black", relpos=(1, 0.5), 
    ),
    # Genes=pch.anno_simple(df_row[0]),
    verbose=1, axis=0
)

# Heatmap of the min-max scaled expression; rows and columns keep the order of
# df_expr_norm (no clustering), so cells stay sorted by MS_SCORE.
plt.figure(figsize=(6,8))
cm = pch.ClusterMapPlotter(data=df_expr_norm,
                           top_annotation=col_ha,
                           left_annotation=left_ha,
                           row_cluster=False,
                           col_cluster=False,
                           row_dendrogram=False,
                           label="Expression",
                           cmap='LaJolla_r',
                           rasterized=True, 
                           ylabel="Genes",
                           xlabel="Cells",
                           # vmin=0, vmax=df_expr.values.max()*0.5
                           )

# Plot

In [None]:
# Per-compartment cell metadata for the stacked bar charts below.
in_white_matter = adata.obs['wm_compartment'] == "WM"
in_matrix = adata.obs['MS_compartment'] == "Matrix"
in_striosome = adata.obs['MS_compartment'] == "Striosome"

df_wm = adata[in_white_matter].obs.copy()
df_mat = adata[in_matrix].obs.copy()
df_str = adata[in_striosome].obs.copy()

In [None]:
fig, ax = create_stacked_bar_chart(df_wm, group_column='brain_region', cell_type_column='Subclass', title='Cell Type Distribution in White Matter Regions', colors=adata.uns['Subclass_palette'])

In [None]:
fig, ax = create_stacked_bar_chart(df_mat, group_column='brain_region', cell_type_column='MSN_Groups', title='Cell Type Distribution in Matrix Regions', colors=adata.uns['MSN_Groups_palette'])

In [None]:
fig, ax = create_stacked_bar_chart(df_str, group_column='brain_region', cell_type_column='MSN_Groups', title='Cell Type Distribution in Striosome Regions', colors=adata.uns['MSN_Groups_palette'])

In [None]:
# counter = 0
# ## Iterating for all elements
# pbar = tqdm(itertools.product(donors, brain_regions, replicates))
# for _i in pbar:
#     if _i in skip:
#         # print(f"Skipping {_i}")
#         continue
#     _donor, _brain_region, _replicate, = _i
#     _experiment, _region = adata.obs.loc[(adata.obs['donor'] == _donor) & 
#                            (adata.obs['brain_region'] == _brain_region) & 
#                            (adata.obs['replicate'] == _replicate), ['experiment', 'region']].values[0]
#     pbar.set_description(f"Processing {_i} ({_experiment}, {_region})")
#     out_path_wm = f"{output_path}/{_donor}_{_brain_region}_{_replicate}_wm_regions.gpkg"
#     if (Path(out_path_wm).exists()):
#         gdf_geoms = gpd.read_file(out_path_wm)

#     zarr_path = f"{sd_store}/{_experiment}/{_region}"
#     sdata = sd.read_zarr(zarr_path)
#     print(sdata)

#     cs = "pixel"
#     ch = "DAPI"
#     image_key = f"default_{_experiment}_{_region}_z3"
#     points_key = f"proseg_fv38_{_experiment}_{_region}_transcripts"
#     shapes_key = f"proseg_fv38_{_experiment}_{_region}_polygons"
#     tab_key1 = "proseg_fv38_table_filt"
#     tab_key2 = "proseg_fv38_annot"

#     geoms_key = "wm_regions"

#     sdata[geoms_key] = sd.models.ShapesModel().parse(gdf_geoms)
#     sd.transformations.set_transformation(
#         sdata[geoms_key],
#         sd.transformations.get_transformation(sdata[points_key], to_coordinate_system="pixel"),
#         to_coordinate_system="pixel"
#     )

#     sd.transformations.set_transformation(
#         sdata[geoms_key],
#         sd.transformations.get_transformation(sdata[points_key], to_coordinate_system="global"),
#         to_coordinate_system="global"
#     )

#     fig, ax = plt.subplots(figsize=(5,5))
#     (
#         sdata.pl.render_images(image_key, channel=ch, cmap="gray")
#         .pl.render_shapes(geoms_key, color="none", outline_color="red", outline_width=2, outline_alpha=1, fill_alpha=0.5)
#         .pl.show(ax=ax, coordinate_systems=cs)
#     )
#     ax.set_title(f"{_donor} {_brain_region} {_replicate} - WM Regions")
#     plt.show()
#     plt.close()

#     if counter == 10: 
#         break
    
#     counter += 1
    


#     # sdata[geoms_key] = sd.models.ShapesModel().parse(gdf_geoms)

