In [None]:
# imports
import os
from pathlib import Path
import itertools
from tqdm import tqdm

import numpy as np
import pandas as pd
import anndata as ad
from scipy.stats import norm
from statsmodels.stats.multitest import multipletests

import geopandas as gpd
from shapely import Polygon, Point, box

import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['figure.dpi'] = 150
from matplotlib.colors import TwoSlopeNorm
from adjustText import adjust_text

In [None]:
# Global figure defaults for every plot below (300 dpi on screen and on disk,
# TrueType/Type-42 font embedding in PDF/PS output, Arial 14 pt on white axes).
plt.rcParams.update({
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'figure.figsize': (4, 4),
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 14,
    'axes.facecolor': 'white',
})

### Functions

In [None]:
def permute_geometry(geometry_col, rng=None):
    """
    Randomly permute a GeoSeries of point geometries.

    The geometries are shuffled but the original index (and CRS) is kept, so
    ``gdf.geometry = permute_geometry(gdf.geometry)`` really reassigns
    geometries to rows instead of having the shuffle undone by index alignment.

    Parameters
    ----------
    geometry_col : geopandas.GeoSeries or array-like
        A column of shapely Points (e.g., gdf.geometry). Non-GeoSeries input is
        wrapped in a GeoSeries first.
    rng : numpy.random.Generator, optional
        Source of randomness. If None, the legacy global ``np.random`` state is
        used, so ``np.random.seed`` still controls the shuffle.

    Returns
    -------
    geopandas.GeoSeries
        Same geometries in a new order, indexed like ``geometry_col``.
    """
    # Ensure input is a GeoSeries
    if not isinstance(geometry_col, gpd.GeoSeries):
        geometry_col = gpd.GeoSeries(geometry_col)

    sampler = np.random if rng is None else rng
    order = sampler.permutation(len(geometry_col))

    shuffled = geometry_col.iloc[order].copy()
    # Put the original labels back so alignment on assignment keeps the new order.
    shuffled.index = geometry_col.index
    return shuffled

# functions from xingjiepan 2023 mouse atlas paper
def adjust_p_value_matrix_by_BH(p_val_mtx):
    '''Adjust the p-values in a matrix by the Benjamini/Hochberg method.

    Only the upper triangle (diagonal included) of the first ``N x N`` block is
    read, so the matrix should be symmetric. Those entries form one BH family;
    each adjusted value is written to both (i, j) and (j, i) of a new float
    matrix of shape (N, N).
    '''
    n = p_val_mtx.shape[0]
    # triu_indices walks the upper triangle row by row: (0,0), (0,1), ..., (n-1,n-1)
    rows, cols = np.triu_indices(n)
    flat_pvals = [p_val_mtx[r, c] for r, c in zip(rows, cols)]
    corrected = multipletests(flat_pvals, method='fdr_bh')[1]

    symmetric = np.zeros((n, n))
    symmetric[rows, cols] = corrected
    symmetric[cols, rows] = corrected
    return symmetric

def one_sided_pval(real, null_dist, min_std=1e-6):
    """
    Normal-approximation p-value(s) of observed value(s) against a null distribution.

    The observed value is turned into a z-score using the mean and standard
    deviation of ``null_dist`` (both taken over all of its elements). The
    upper-tail probability of ``|z|`` is returned, i.e. a one-sided p-value in
    the direction of the observed deviation (half the two-sided p-value).

    Parameters
    ----------
    real : float or array-like
        The observed value(s).
    null_dist : array-like
        The null distribution values.
    min_std : float, default 1e-6
        Floor on the null standard deviation, to avoid division by zero.

    Returns
    -------
    p_vals : float or numpy.ndarray
        One-sided p-value(s), same shape as ``real``.
    adj_p_value : numpy.ndarray
        Benjamini–Hochberg adjusted ``p_vals`` (at least 1-D).
    """
    # Imported locally so this function does not depend on the global name
    # `norm`, which notebook cells may rebind.
    from scipy.stats import norm

    # Keep the converted array (list/tuple input otherwise breaks the subtraction).
    real = np.asarray(real, dtype=float)
    null_dist = np.asarray(null_dist, dtype=float)
    null_mean = np.mean(null_dist)
    null_std = np.maximum(np.std(null_dist), min_std)
    z_score = (real - null_mean) / null_std
    p_vals = norm.sf(np.abs(z_score))
    # multipletests expects a 1-D sequence; atleast_1d covers a scalar `real`.
    adj_p_value = multipletests(np.atleast_1d(p_vals), method='fdr_bh')[1]
    return p_vals, adj_p_value

### Functions for pooling tests across experiments (random-effects meta-analysis)

In [None]:
def dl_tau2(yi, vi):
    """
    DerSimonian–Laird estimator of the between-study variance tau^2.

    Parameters
    ----------
    yi : array-like
        Per-study effect sizes.
    vi : float or array-like
        Per-study sampling variances; a scalar is broadcast to every study.

    Returns
    -------
    tau2 : float
        Estimated between-study variance, truncated at 0.
    Q : float
        Cochran's Q heterogeneity statistic.
    """
    yi = np.asarray(yi, dtype=float)
    # Broadcast a scalar variance so sum(w) counts every study. With a scalar
    # w, sum(w * yi) / sum(w) would reduce to sum(yi) instead of the mean.
    vi = np.broadcast_to(np.asarray(vi, dtype=float), yi.shape)
    k = len(yi)
    if k == 0:
        return 0.0, 0.0
    w = 1.0 / vi
    sum_w = np.sum(w)
    ybar = np.sum(w * yi) / sum_w
    Q = np.sum(w * (yi - ybar) ** 2)
    c = sum_w - np.sum(w ** 2) / sum_w
    tau2 = max(0.0, (Q - (k - 1)) / c) if c > 0 else 0.0
    return tau2, Q

def re_meta(yi, vi):
    """
    Random-effects (DerSimonian–Laird) meta-analysis of effect sizes.

    Parameters
    ----------
    yi : array-like
        Per-study effect sizes.
    vi : float or array-like
        Per-study sampling variances; a scalar is broadcast to every study.

    Returns
    -------
    dict
        mu (pooled estimate), se, z, p (two-sided), ci_lb / ci_ub (mu ± 1.96·se),
        tau2, Q, k (number of studies) and I2 (heterogeneity, percent).
    """
    # Imported locally so this function does not depend on the global name
    # `norm`, which notebook cells may rebind.
    from scipy.stats import norm

    yi = np.asarray(yi, dtype=float)
    # Broadcast a scalar vi: with a scalar weight, sum(w_star * yi) / sum(w_star)
    # collapses to sum(yi) and se to sqrt(vi), instead of the pooled values.
    vi = np.broadcast_to(np.asarray(vi, dtype=float), yi.shape)
    tau2, Q = dl_tau2(yi, vi)
    w_star = 1.0 / (vi + tau2)
    mu = np.sum(w_star * yi) / np.sum(w_star)
    se = np.sqrt(1.0 / np.sum(w_star))
    z = mu / se if se > 0 else np.nan
    p = 2 * norm.sf(abs(z)) if np.isfinite(z) else np.nan
    ci_lb, ci_ub = mu - 1.96 * se, mu + 1.96 * se
    k = len(yi)
    I2 = max(0.0, (Q - (k - 1)) / Q) * 100 if (k > 1 and Q > 0) else 0.0
    return dict(mu=mu, se=se, z=z, p=p,
                ci_lb=ci_lb, ci_ub=ci_ub,
                tau2=tau2, Q=Q, k=k, I2=I2)

In [None]:
def create_stacked_bar_chart(df, group_column, cell_type_column='cell_type', 
                           figsize=(12, 8), title=None, colors=None, 
                           show_percentages=True, rotation=45, rasterized=False,
                           legend_threshold=5.0, text_threshold=2.0,
                           legend_fontsize=12, def_fontsize=12, title_fontsize=12,
                           xlabel=None, 
                        ):
    """
    Create a stacked bar chart showing cell type percentages across groups.

    Parameters
    ----------
    df : pandas.DataFrame
        Input table with one row per cell.
    group_column : str
        Column to group by (one bar per value, x-axis categories).
    cell_type_column : str, default 'cell_type'
        Column containing the cell type of each row.
    figsize : tuple, default (12, 8)
        Figure size (width, height).
    title : str, optional
        Chart title; defaults to "Cell Type Distribution by <Group Column>".
    colors : list or dict, optional
        Either a list aligned with the cell-type column order of the per-group
        counts, or a {cell_type: color} mapping (missing keys become 'gray').
        If None, seaborn's "Set3" palette is used.
    show_percentages : bool, default True
        Whether to write percentage labels on the bar segments.
    rotation : int, default 45
        Rotation of the x tick labels (0 centers them, otherwise right-aligned).
    rasterized : bool, default False
        Rasterize bars and text (keeps vector outputs small).
    legend_threshold : float, default 5.0
        A cell type is listed in the legend only if it reaches this percentage
        in at least one group.
    text_threshold : float, default 2.0
        Only segments above this percentage get a text label.
    legend_fontsize, def_fontsize, title_fontsize : int, default 12
        Font sizes for the legend, axis labels/ticks/segment text, and title.
    xlabel : str, optional
        x-axis label; defaults to a title-cased ``group_column``.

    Returns
    -------
    fig, ax : matplotlib figure and axis objects
    """
    # observed=True: only combinations that occur are counted, so unused
    # categorical levels do not create all-zero (NaN %) bars. No-op for
    # non-categorical columns.
    counts = (df.groupby([group_column, cell_type_column], observed=True)
                .size()
                .unstack(fill_value=0))
    percentages = counts.div(counts.sum(axis=1), axis=0) * 100

    # Set up colors
    n_cell_types = len(counts.columns)
    if colors is None:
        colors = sns.color_palette("Set3", n_cell_types)
    elif isinstance(colors, dict):
        colors = [colors.get(ct, 'gray') for ct in counts.columns]

    fig, ax = plt.subplots(figsize=figsize)

    # A cell type enters the legend if its largest share across groups reaches the threshold.
    max_percentages = percentages.max(axis=0)
    legend_cell_types = set(max_percentages[max_percentages >= legend_threshold].index)

    # Explicit integer positions so bars and their text labels share one x-coordinate
    # system, whatever the dtype of the group values.
    positions = np.arange(len(percentages))
    bottom = np.zeros(len(percentages))
    bars = []

    for i, cell_type in enumerate(percentages.columns):
        values = percentages[cell_type].to_numpy(dtype=float)
        label = cell_type if cell_type in legend_cell_types else None

        bar = ax.bar(positions, values,
                     bottom=bottom, label=label, color=colors[i],
                     rasterized=rasterized)
        bars.append(bar)

        if show_percentages:
            for j, value in enumerate(values):
                if value > text_threshold:
                    ax.text(positions[j], bottom[j] + value / 2, f'{value:.1f}%',
                            ha='center', va='center', fontsize=def_fontsize, fontweight='bold',
                            rasterized=rasterized)

        bottom = bottom + values

    # Customize the plot
    if xlabel is None:
        ax.set_xlabel(group_column.replace('_', ' ').title(), fontsize=def_fontsize)
    else:
        ax.set_xlabel(xlabel, fontsize=def_fontsize)
    ax.set_ylabel('Percentage (%)', fontsize=def_fontsize)
    ax.set_ylim(0, 100)

    if title:
        ax.set_title(title, fontsize=title_fontsize, fontweight='bold', rasterized=rasterized)
    else:
        ax.set_title(f'Cell Type Distribution by {group_column.replace("_", " ").title()}',
                     fontsize=title_fontsize, fontweight='bold', rasterized=rasterized)

    # Tick labels on this axes object (not the pyplot "current" axes).
    ax.set_xticks(positions)
    tick_labels = [str(group) for group in percentages.index]
    if rotation != 0:
        ax.set_xticklabels(tick_labels, rotation=rotation, ha='right', fontsize=def_fontsize)
    else:
        ax.set_xticklabels(tick_labels, ha='center', fontsize=def_fontsize)

    # Legend only for cell types that meet the threshold; two columns past 20 entries.
    legend_handles = [bar for bar, ct in zip(bars, percentages.columns) if ct in legend_cell_types]
    if 0 < len(legend_handles) <= 20:
        ax.legend(handles=legend_handles, bbox_to_anchor=(1.05, 1), loc='upper left',
                  fontsize=legend_fontsize)
    elif len(legend_handles) > 20:
        ax.legend(handles=legend_handles, bbox_to_anchor=(1.05, 1), loc='upper left',
                  fontsize=legend_fontsize, ncol=2)

    # Add grid for better readability
    ax.grid(axis='y', alpha=0.3, linestyle='--')

    return fig, ax

# Example usage:
# Assuming you have a dataframe 'df' with columns 'region' and 'cell_type'
# fig, ax = create_stacked_bar_chart(df, group_column='region', cell_type_column='cell_type')
# plt.show()

# Alternative simpler version for quick use:
def quick_stacked_bar(df, group_col, cell_type_col='cell_type'):
    """Stacked percentage bar chart of cell types per group, minimal options.

    Returns the matplotlib Axes produced by ``DataFrame.plot``.
    """
    per_group = df.groupby([group_col, cell_type_col]).size().unstack(fill_value=0)
    totals = per_group.sum(axis=1)
    pct = per_group.div(totals, axis=0).mul(100)

    axis = pct.plot(kind='bar', stacked=True, figsize=(10, 6),
                    colormap='Set3', rot=45)
    axis.set_ylabel('Percentage (%)')
    axis.set_title(f'Cell Type Distribution by {group_col}')
    plt.tight_layout()
    return axis

In [None]:
def plot_volcano(
    df,
    lfc_col="log_2FC",
    p_col="p_value",
    label_col=None,
    alpha=0.05,
    lfc_thresh=1.0,
    top_labels=10,
    title="Volcano Plot",
    figsize=(7,6),
    ax=None,
    save=None,
    label_fontsize=8,
    rasterized=False, 
):
    """
    Volcano plot (log2FC vs -log10 p-value) using pure Matplotlib,
    with adjustText-based non-overlapping labels.

    Parameters
    ----------
    df : pd.DataFrame
        DataFrame containing log2 fold change and p-values. Rows with NaN/inf
        in either column are dropped.
    lfc_col : str
        Column name for log2 fold change.
    p_col : str
        Column name for p-value.
    label_col : str, optional
        Column name for labeling points (e.g., cell type, interaction).
    alpha : float
        p-value threshold for significance.
    lfc_thresh : float
        Fold change threshold for significance.
    top_labels : int
        Number of most significant (smallest p) features to label.
    title : str
        Plot title.
    figsize : tuple
        Figure size (only used when ``ax`` is None).
    ax : matplotlib.axes.Axes or None
        Axis to plot on (creates new figure if None).
    save : str or None
        Path to save figure. If None and the function created its own figure,
        the figure is shown with ``plt.show()``.
    label_fontsize : int
        Font size for labels.
    rasterized : bool
        Rasterize scatter points, threshold lines and labels.

    Returns
    -------
    fig, ax : the figure and the axes that were drawn on.
    """
    df = df.copy()
    df = df.replace([np.inf, -np.inf], np.nan).dropna(subset=[lfc_col, p_col])
    df["neglog10p"] = -np.log10(df[p_col].clip(lower=1e-300))  # avoid log(0)

    # Classify points
    sig_up = (df[p_col] < alpha) & (df[lfc_col] > lfc_thresh)
    sig_down = (df[p_col] < alpha) & (df[lfc_col] < -lfc_thresh)
    nonsig = ~(sig_up | sig_down)

    # Remember whether the figure is ours BEFORE `ax` gets reassigned; the
    # show-decision at the end depends on it.
    created_own_axes = ax is None
    if created_own_axes:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.figure

    # Scatter points
    ax.scatter(
        df.loc[nonsig, lfc_col],
        df.loc[nonsig, "neglog10p"],
        color="lightgray", s=20, alpha=0.7, label="Not significant",
        rasterized=rasterized
    )
    ax.scatter(
        df.loc[sig_up, lfc_col],
        df.loc[sig_up, "neglog10p"],
        color="red", s=25, alpha=0.8, label="Up",
        rasterized=rasterized
    )
    ax.scatter(
        df.loc[sig_down, lfc_col],
        df.loc[sig_down, "neglog10p"],
        color="blue", s=25, alpha=0.8, label="Down",
        rasterized=rasterized
    )

    # Threshold lines
    ax.axhline(-np.log10(alpha), color="black", lw=1, ls="--", rasterized=rasterized)
    ax.axvline(lfc_thresh, color="black", lw=1, ls="--", rasterized=rasterized)
    ax.axvline(-lfc_thresh, color="black", lw=1, ls="--", rasterized=rasterized)

    # Labels and title
    ax.set_xlabel("log2(Fold Change)", fontsize=12)
    ax.set_ylabel("−log10(p-value)", fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.legend(frameon=False)

    # --- Add non-overlapping labels ---
    if label_col is not None and label_col in df.columns:
        top = df.nsmallest(top_labels, p_col)
        texts = []
        for _, row in top.iterrows():
            txt = ax.text(
                row[lfc_col],
                row["neglog10p"],
                str(row[label_col]),
                fontsize=label_fontsize,
                color="black",
                weight="bold",
                ha="center",
                va="bottom",
                rasterized=rasterized
            )
            texts.append(txt)

        # Nothing to de-overlap when the (filtered) table is empty.
        if texts:
            adjust_text(
                texts,
                ax=ax,
                arrowprops=dict(arrowstyle="-", color="gray", lw=0.5)
            )

    fig.tight_layout()

    if save:
        fig.savefig(save, dpi=300, bbox_inches="tight")
        print(f"Saved volcano plot to {save}")
    elif created_own_axes:  # Only show if not using external axes
        plt.show()

    return fig, ax


### Read data

In [None]:
# Input files: the combined AnnData and the region-geometry table (read in the next cell).
ad_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad"
geom_store_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/regions/region_geometries_cps.parquet"
# NOTE(review): str_buffer and N_permute are not referenced in the cells shown;
# presumably a striosome buffer distance and a permutation count for the
# upstream enrichment step — confirm or remove.
str_buffer = 25
N_permute = 1000

In [None]:
# Load the cell-level AnnData and the region polygons (columns used below:
# donor, brain_region, lab, type).
adata = ad.read_h5ad(ad_path)
geoms = gpd.read_parquet(geom_store_path)

In [None]:
# Unique experiment identifiers, in first-appearance order (from .unique()).
# NOTE: `donors` is reassigned later by the simulation cell.
donors = adata.obs['donor'].unique().tolist()
replicates = adata.obs['replicate'].unique().tolist()
brain_regions = adata.obs['brain_region'].unique().tolist()
# NOTE(review): `skip` is not referenced in the cells shown; presumably
# (donor, region, lab) experiments to exclude — confirm.
skip = [("UWA7648", "CAT", "ucsd"), ("UWA7648", "CAT", "salk")]

In [None]:
# Directory with the per-experiment ms_composition_*.csv tables read below.
DIR = Path("/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/ms_enrichment")

In [None]:
# One example experiment, picked by position in the .unique() lists above.
_donor = donors[1]
_replicate = replicates[1]
_region = brain_regions[4]
print(f"Processing donor: {_donor}, replicate: {_replicate}, region: {_region}")

In [None]:
# Annotation level ("subclass" or "group") and compartment ("white_matter",
# "matrix" or "striosome") of the example table.
_level = "group"
_compartment = "white_matter"

In [None]:
# Composition table of the example experiment. The path variable is named
# ms_comp_path (not `input`) so the builtin input() is not shadowed.
ms_comp_path = DIR / f"ms_composition_{_level}_{_compartment}_{_donor}_{_region}_{_replicate}.csv"
ms_comp = pd.read_csv(ms_comp_path, index_col=0)

In [None]:
# Plotting a single experiment:
# fig, axes = plt.subplots(2, 3, figsize=(15,10))
# for _i, _level in enumerate(["subclass", "group"]):
#     for _j, _compartment in enumerate(["white_matter", "matrix", "striosome"]):
#         ms_comp = pd.read_csv(
#             DIR / f"ms_composition_{_level}_{_compartment}_{_donor}_{_region}_{_replicate}.csv",
#             index_col=0
#         )
#         plot_volcano(ms_comp.reset_index(), p_col="p_value", lfc_thresh=.1,
#                      label_col="cell_type", ax=axes[_i, _j], title=f"{_level} - {_compartment}",
#                      top_labels=10)
# plt.show()

In [None]:
# Panel titles keyed by "<level>_<compartment>" (the keys of agg_tables).
naming_map = dict([
    ("subclass_white_matter", "Subclass - White Matter"),
    ("subclass_matrix", "Subclass - Matrix"),
    ("subclass_striosome", "Subclass - Striosome"),
    ("group_white_matter", "Group - White Matter"),
    ("group_matrix", "Group - Matrix"),
    ("group_striosome", "Group - Striosome"),
])

In [None]:
# Pool the per-experiment log2 fold changes of every (level, compartment) with
# a random-effects meta-analysis per cell type, then BH-correct across cell types.
agg_tables = {}
for _level in ["subclass", "group"]:
    for _compartment in ["white_matter", "matrix", "striosome"]:
        df_list = []
        # sorted(): deterministic concatenation order regardless of filesystem listing.
        for _file in sorted(DIR.glob(f"ms_composition_{_level}_{_compartment}*.csv")):
            # Stem: ms_composition_<level>_<compartment>_<donor>_<region>_<lab>;
            # split from the end because the compartment may contain "_".
            _donor, _region, _lab = _file.stem.split("_")[-3:]
            df = pd.read_csv(_file)
            df['donor'] = _donor
            df['region'] = _region
            df['lab'] = _lab
            df['id'] = f"{_donor}|{_region}|{_lab}"
            df_list.append(df)
        if not df_list:
            raise FileNotFoundError(
                f"No ms_composition_{_level}_{_compartment}*.csv files found in {DIR}"
            )
        df_ms = pd.concat(df_list, axis=0)

        rows = []
        for cat, df_cat in df_ms.groupby('cell_type'):
            # df_cat['var_null'] = df_cat['std_null_count'] ** 2
            yi = df_cat['log_2FC'].to_numpy(dtype=float)
            # Unit variance per experiment, passed as a per-study array so the
            # weighted sums in re_meta count every experiment. With a scalar
            # weight, sum(w * yi) / sum(w) reduces to sum(yi).
            res = re_meta(yi, np.ones_like(yi))
            res['cell_type'] = cat
            rows.append(res)
        df_rows = pd.DataFrame(rows)
        df_rows["p_fdr"] = multipletests(df_rows["p"], method="fdr_bh")[1]
        agg_tables[f"{_level}_{_compartment}"] = df_rows

In [None]:
# Preview one pooled table (re_meta output columns plus cell_type and p_fdr).
agg_tables['group_striosome'].head()

## Matrix/Striosome (MS) enrichment plots

In [None]:
# Output directory for the supplementary figures saved below.
image_path = "/home/x-aklein2/projects/aklein/BICAN/BG/images/figures/supp"

In [None]:
def _get_palette(key):
    """Return the adata.uns colour palette that matches an agg_tables key.

    Keys whose first "_"-separated token is "subclass" use 'Subclass_palette';
    every other key uses 'Group_palette'.
    """
    palette_name = 'Subclass_palette' if key.split("_")[0] == "subclass" else 'Group_palette'
    return adata.uns[palette_name]


In [None]:
# Forest-style figure: for every agg_tables entry, one horizontal segment per
# cell type spanning the pooled 95% CI (ci_lb..ci_ub), rows sorted by ci_lb,
# with FDR significance stars in a column to the right.
fig, axes = plt.subplots(2, 3, figsize=(15,10))
for i, (_key, df_ms) in enumerate(agg_tables.items()):
    # print(f"{_key}: {_value.shape}")
    # 2 x 3 grid following agg_tables insertion order:
    # row = level (subclass, group), column = compartment.
    ax = axes[i // 3, i % 3]
    palette = _get_palette(_key)

    # Star column sits a tenth of the CI range past the largest upper bound.
    rightmost = max(df_ms['ci_ub'])
    leftmost = min(df_ms['ci_lb'])
    star_x = rightmost + (rightmost - leftmost) / 10
    labels = []
    for idx, (_c, _row) in enumerate(df_ms.sort_values(by="ci_lb").iterrows()): 
        point = ax.scatter([_row['ci_lb'], _row['ci_ub']], [idx, idx], color=palette[_row['cell_type']])
        ax.plot([_row['ci_lb'], _row['ci_ub']], [idx, idx], color=palette[_row['cell_type']])
        labels.append(_row['cell_type'])
        # Stars: * p_fdr < 0.01, ** < 0.001, *** < 0.0001.
        stars = ""
        if _row['p_fdr'] < 0.01: 
            stars += "*"
        if _row['p_fdr'] < 0.001: 
            stars += "*"
        if _row['p_fdr'] < 0.0001: 
            stars += "*"
        # NOTE(review): text is anchored at y = idx - 1 with default baseline
        # alignment, i.e. one row below its segment; this may be a deliberate
        # offset for the raised '*' glyph — confirm in the rendered figure
        # (ax.text(star_x, idx, stars, va='center') is the explicit alternative).
        ax.text(star_x, idx-1, stars)
        # print(stars)

    ax.axvline(0, linestyle='--', color='gray', alpha=0.5)
    ax.set_yticks(np.arange(0, len(labels)))
    ax.set_yticklabels(labels, fontsize=6)
    ax.grid(axis='y', linestyle='--', color='gray', alpha=0.25)
    ax.set_title(naming_map[_key])
plt.tight_layout()
plt.savefig(image_path + "/ms_enrichment_PI_all.pdf", dpi=300, bbox_inches="tight")
plt.savefig(image_path + "/ms_enrichment_PI_all.png", dpi=300, bbox_inches="tight")
plt.savefig(image_path + "/ms_enrichment_PI_all.svg", dpi=300, bbox_inches="tight")
plt.show()

In [None]:
# Volcano plots of the pooled effect (mu) against BH-adjusted p (p_fdr), one
# panel per agg_tables entry. plot_volcano draws into the given axes and calls
# fig.tight_layout() itself.
fig, axes = plt.subplots(2, 3, figsize=(15,10))
for i, (_key, df_ms) in enumerate(agg_tables.items()):
    # print(f"{_key}: {_value.shape}")
    ax = axes[i // 3, i % 3]
    
    plot_volcano(df_ms, p_col="p_fdr", lfc_thresh=.1, lfc_col="mu",
                label_col="cell_type", ax=ax, title=naming_map[_key],
                top_labels=10, rasterized=True)
# plt.savefig(image_path + "/ms_enrichment_volcano_all.png", dpi=300, bbox_inches="tight")
# plt.savefig(image_path + "/ms_enrichment_volcano_all.pdf", dpi=300, bbox_inches="tight")
plt.savefig(image_path + "/ms_enrichment_volcano_all.svg", dpi=300, bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# For each Subclass, the Group labels that co-occur with it in adata.obs
# (the 'unknown' Group label is excluded).
subclass_to_group_map = {
    _sub: [g for g in _gr['Group'].cat.remove_unused_categories().values if g != 'unknown']
    for _sub, _gr in adata.obs[['Subclass', 'Group']].drop_duplicates().groupby("Subclass", observed=True)
}

print(subclass_to_group_map)

In [None]:
### Specific subclasses:
# Group-level tables are filtered to the Groups belonging to these subclasses
# (via subclass_to_group_map); subclass-level tables are filtered directly.
# NOTE: subclass_to_group_map[_sub] raises KeyError if a listed subclass is absent.
subclasses_to_investigate = ['CN ST18 GABA', 'CN VIP GABA', 'CN LAMP5-CXCL14 GABA', "STR D1 MSN", "STR D2 MSN", "STR Hybrid MSN"]
fig, axes = plt.subplots(2, 3, figsize=(15,10))
for i, (_key, df_ms) in enumerate(agg_tables.items()):
    ax = axes[i // 3, i % 3]
    if _key.startswith("group"):
        df_sub = df_ms[df_ms['cell_type'].isin(
            list(itertools.chain.from_iterable(
                subclass_to_group_map[_sub] for _sub in subclasses_to_investigate
            )))]
    else:
        df_sub = df_ms[df_ms['cell_type'].isin(subclasses_to_investigate)]

    plot_volcano(df_sub, p_col="p_fdr", lfc_thresh=.1, lfc_col="mu",
                label_col="cell_type", ax=ax, title=naming_map[_key],
                top_labels=10)

# plt.savefig(image_path + "/ms_enrichment_volcano_sub.png", dpi=300, bbox_inches="tight")
# plt.savefig(image_path + "/ms_enrichment_volcano_sub.pdf", dpi=300, bbox_inches="tight")
plt.savefig(image_path + "/ms_enrichment_volcano_sub.svg", dpi=300, bbox_inches="tight")
plt.show()
plt.close()

In [None]:
# Interactive check of the available table keys.
agg_tables.keys()

In [None]:
### Specific subclasses:
# One column per subclass, one row per compartment (matrix, striosome); each
# panel shows the Group-level rows belonging to that subclass.
subclasses_to_investigate = ['CN ST18 GABA', 'CN VIP GABA', 'CN LAMP5-CXCL14 GABA']
ncols = len(subclasses_to_investigate)
nrows = 2
fig, axes = plt.subplots(nrows, ncols, figsize=(15,10))
for c, _sub in enumerate(subclasses_to_investigate):
    for i, (_key) in enumerate(['group_matrix', 'group_striosome']):
        df_ms = agg_tables[_key]
        ax = axes[i, c]
        df_sub = df_ms[df_ms['cell_type'].isin(subclass_to_group_map[_sub])]
        plot_volcano(df_sub, p_col="p_fdr", lfc_thresh=.1, lfc_col="mu",
                    label_col="cell_type", ax=ax, title=f"{naming_map[_key]} - {_sub}",
                    top_labels=10)
plt.show()

In [None]:
# Drop cells whose neuron_type is "unknown" (shape printed before and after).
# This rebinds `adata`, so every later cell works on the filtered object;
# re-running the cell applies the same filter again and changes nothing.
print(adata.shape)
adata = adata[adata.obs['neuron_type'] != "unknown"].copy()
print(adata.shape)

In [None]:
# Per-compartment copies of .obs: white matter from wm_compartment, matrix and
# striosome from MS_compartment.
df_wm = adata[adata.obs['wm_compartment'] == "WM"].obs.copy()
df_mat = adata[adata.obs['MS_compartment'] == "Matrix"].obs.copy()
df_str = adata[adata.obs['MS_compartment'] == "Striosome"].obs.copy()

In [None]:
# White-matter composition at the Group level, one bar per brain_region_corr.
fig, ax = create_stacked_bar_chart(
    df_wm,
    group_column='brain_region_corr',
    cell_type_column='Group',
    colors=adata.uns['Group_palette'],
    title='Cell Type Distribution in White Matter Regions',
    xlabel="",
    rotation=0,
    rasterized=True,
    legend_threshold=2.0,
    text_threshold=5.0,
    def_fontsize=14,
    legend_fontsize=14,
    title_fontsize=24,
)
# plt.savefig(image_path + "/ms_composition_wm_group.png", dpi=300, bbox_inches="tight")
# plt.savefig(image_path + "/ms_composition_wm_group.pdf", dpi=300, bbox_inches="tight")
fig.savefig(image_path + "/ms_composition_wm_group.svg", dpi=300, bbox_inches="tight")
plt.show()
plt.close(fig)

In [None]:
# White-matter composition at the Subclass level, one bar per brain_region_corr.
fig, ax = create_stacked_bar_chart(
    df_wm,
    group_column='brain_region_corr',
    cell_type_column='Subclass',
    colors=adata.uns['Subclass_palette'],
    title='Cell Type Distribution in White Matter Regions',
    xlabel="",
    rotation=0,
    rasterized=True,
    legend_threshold=2.0,
    text_threshold=5.0,
    def_fontsize=14,
    legend_fontsize=14,
    title_fontsize=24,
)
# plt.savefig(image_path + "/ms_composition_wm_subclass.png", dpi=300, bbox_inches="tight")
# plt.savefig(image_path + "/ms_composition_wm_subclass.pdf", dpi=300, bbox_inches="tight")
fig.savefig(image_path + "/ms_composition_wm_subclass.svg", dpi=300, bbox_inches="tight")
plt.show()
plt.close(fig)

In [None]:
# Matrix composition by MSN_Groups, one bar per brain_region_corr.
fig, ax = create_stacked_bar_chart(
    df_mat,
    group_column='brain_region_corr',
    cell_type_column='MSN_Groups',
    colors=adata.uns['MSN_Groups_palette'],
    title='Cell Type Distribution in Matrix Regions',
    xlabel="",
    rotation=0,
    rasterized=True,
    legend_threshold=2.0,
    text_threshold=5.0,
    def_fontsize=14,
    legend_fontsize=14,
    title_fontsize=24,
)
# plt.savefig(image_path + "/ms_composition_matrix_subclass.png", dpi=300, bbox_inches="tight")
# plt.savefig(image_path + "/ms_composition_matrix_subclass.pdf", dpi=300, bbox_inches="tight")
fig.savefig(image_path + "/ms_composition_matrix_subclass.svg", dpi=300, bbox_inches="tight")
plt.show()
plt.close(fig)

In [None]:
# Striosome composition by MSN_Groups, one bar per brain_region_corr.
fig, ax = create_stacked_bar_chart(
    df_str,
    group_column='brain_region_corr',
    cell_type_column='MSN_Groups',
    colors=adata.uns['MSN_Groups_palette'],
    title='Cell Type Distribution in Striosome Regions',
    xlabel="",
    rotation=0,
    rasterized=True,
    legend_threshold=2.0,
    text_threshold=5.0,
    def_fontsize=14,
    legend_fontsize=14,
    title_fontsize=24,
)
# plt.savefig(image_path + "/ms_composition_str_subclass.png", dpi=300, bbox_inches="tight")
# plt.savefig(image_path + "/ms_composition_str_subclass.pdf", dpi=300, bbox_inches="tight")
fig.savefig(image_path + "/ms_composition_str_subclass.svg", dpi=300, bbox_inches="tight")
plt.show()
plt.close(fig)

### Plotting the MS_NORM axis

In [None]:
# Cell centroids as a GeoDataFrame: all .obs columns plus Point geometry from
# CENTER_X / CENTER_Y; no CRS is set.
gdf_cells = gpd.GeoDataFrame(adata.obs, geometry=gpd.points_from_xy(adata.obs['CENTER_X'], adata.obs['CENTER_Y']), crs=None)

In [None]:
# Pick one example experiment: the (donor, brain_region, lab) group at position
# EXAMPLE_EXPERIMENT_IDX in groupby order (falling back to the last group when
# there are fewer), then keep its region polygons.
EXAMPLE_EXPERIMENT_IDX = 34
_experiment_keys = list(geoms.groupby(['donor', 'brain_region', 'lab']).groups)
if not _experiment_keys:
    raise ValueError("geoms has no (donor, brain_region, lab) groups to select from")
_donor, _region, _replicate = _experiment_keys[min(EXAMPLE_EXPERIMENT_IDX, len(_experiment_keys) - 1)]
sub_geoms = geoms[(geoms['donor'] == _donor) & 
                  (geoms['brain_region'] == _region) & 
                  (geoms['lab'] == _replicate)].copy()
sub_geoms.head()

In [None]:
# sub_geoms = geoms[(geoms['donor'] == _donor) & 
#                  (geoms['brain_region'] == _region) & 
#                  (geoms['lab'] == _replicate)].copy()
# sub_geoms.head()

In [None]:
# Region polygons of the selected experiment, layered in this order:
# white matter (light gray), striosome (red, 50%), matrix (blue, 50%).
fig, ax = plt.subplots()
_layer_styles = [
    ("White_Matter", dict(color='lightgray', edgecolor='black')),
    ("Striosome", dict(color='red', edgecolor='black', alpha=0.5)),
    ("Matrix", dict(color='blue', edgecolor='black', alpha=0.5)),
]
for _rtype, _kw in _layer_styles:
    sub_geoms[sub_geoms['type'] == _rtype].plot(ax=ax, **_kw)
# int_geoms.plot(ax=ax, color='purple', edgecolor='black', alpha=0.7)
ax.set_title(f"{_donor} {_region} {_replicate}")
plt.show()

In [None]:
# Cells of the selected experiment. The geometry table's 'lab' value
# (_replicate) is matched against the cells' 'replicate' column.
cells = gdf_cells[(gdf_cells['donor'] == _donor) & 
                  (gdf_cells['brain_region'] == _region) & 
                  (gdf_cells['replicate'] == _replicate)].copy()
# Corrected region name for titles; raises IndexError if no cell matched.
_region_corr = cells['brain_region_corr'].unique()[0]
# cells.head()

In [None]:
# gdf.plot(ax=ax, column="Group", markersize=1, color=gdf['group_color'])

In [None]:
# Per-cell colours looked up from the AnnData palettes; labels missing from a
# palette presumably map to NaN.
cells['group_color'] = cells['Group'].map(adata.uns['Group_palette'])
cells['subclass_color'] = cells['Subclass'].map(adata.uns['Subclass_palette'])

In [None]:
# --- Mat-Str score (MS_NORM) map of the selected experiment -------------------
fig, ax = plt.subplots(dpi=300)
# Diverging scale centred at 0. Named ms_norm rather than `norm` so this cell
# does not rebind the scipy.stats `norm` used by one_sided_pval / re_meta (norm.sf).
# TwoSlopeNorm requires vmin < vcenter < vmax, so MS_NORM must take both signs.
ms_norm = TwoSlopeNorm(vmin=cells['MS_NORM'].min(), vcenter=0, vmax=cells['MS_NORM'].max())
cells.plot(ax=ax, color="gray", edgecolor='none', markersize=2, alpha=0.5, rasterized=True);
cells.plot(ax=ax, column="MS_NORM", cmap='coolwarm_r', edgecolor='none', markersize=2, alpha=0.75,
           legend_kwds={"label" : "Mat-Str Score", "shrink":.6}, norm=ms_norm, legend=True,
           rasterized=True).axis("off");
cbar_ax = fig.axes[1]  # colorbar axes (used only by the commented-out tick tweaks below)
ax.set_title(f"{_region_corr} - {_donor}")
for _ext in ("png", "pdf", "svg"):
    plt.savefig(image_path + f"/ms_score_{_donor}_{_region}_{_replicate}.{_ext}", dpi=300, bbox_inches="tight")
plt.show()
plt.close(fig)

# --- Cells coloured by Group ---------------------------------------------------
fig, ax = plt.subplots()
cells.plot(ax=ax, color=cells['group_color'], edgecolor='none', markersize=1, alpha=1, rasterized=True).axis("off");
ax.set_title(f"{_region_corr} - {_donor}")
for _ext in ("png", "pdf", "svg"):
    plt.savefig(image_path + f"/ms_group_{_donor}_{_region}_{_replicate}.{_ext}", dpi=300, bbox_inches="tight")
# plt.show()
plt.close(fig)

# --- Cells coloured by Subclass ------------------------------------------------
fig, ax = plt.subplots()
cells.plot(ax=ax, color=cells['subclass_color'], edgecolor='none', markersize=1, alpha=1, rasterized=True).axis("off");
ax.set_title(f"{_region_corr} - {_donor}")
for _ext in ("png", "pdf", "svg"):
    plt.savefig(image_path + f"/ms_subclass_{_donor}_{_region}_{_replicate}.{_ext}", dpi=300, bbox_inches="tight")
plt.show()
plt.close(fig)

# --- MSN cells coloured by MSN_Groups, with region outlines --------------------
# .copy(): the column assignment below must write to a new frame, not a slice of `cells`.
gdf_msn = cells.loc[~cells['MSN_Groups'].isna()].copy()
# astype(object): if map() returns a Categorical, fillna("gray") would require
# "gray" to already be one of its categories.
gdf_msn['group_color'] = (gdf_msn['MSN_Groups']
                          .map(adata.uns['MSN_Groups_palette'])
                          .astype(object)
                          .fillna("gray"))

fig, ax = plt.subplots()
gdf_msn.plot(ax=ax, color=gdf_msn['group_color'], edgecolor='none', markersize=2, alpha=1, rasterized=True).axis("off");
for _type in ("White_Matter", "Striosome", "Matrix"):
    sub_geoms[sub_geoms['type'] == _type].plot(ax=ax, color='none', edgecolor='black', alpha=1, linewidth=0.5)
ax.set_title(f"{_region_corr} - {_donor}")
for _ext in ("png", "pdf", "svg"):
    plt.savefig(image_path + f"/ms_msn_group_{_donor}_{_region}_{_replicate}.{_ext}", dpi=300, bbox_inches="tight")
# plt.show()
plt.close(fig)

# cbar_ax.set_yticks([-400, -200, 0, 500, 1000, 1500])
# cbar_ax.set_yticklabels(['-400', '-200', '0', '500', '1000', '1500'])
# cbar_ax.tick_params(labelsize=12)

# sub_geoms[sub_geoms['type'] == "White_Matter"].plot(ax=ax, color=sub_geoms['type_color'], edgecolor='none', alpha=0.1)
# sub_geoms[sub_geoms['type'] == "White_Matter"].plot(ax=ax, color='none', edgecolor='black', alpha=1, linewidth=0.5)

# sub_geoms[sub_geoms['type'] == "Striosome"].plot(ax=ax, color=sub_geoms['type_color'], edgecolor='none', alpha=0.5)
# sub_geoms[sub_geoms['type'] == "Striosome"].plot(ax=ax, color='none', edgecolor='black', alpha=1, linewidth=0.5)

# sub_geoms[sub_geoms['type'] == "Matrix"].plot(ax=ax, color=sub_geoms['type_color'], edgecolor='none', alpha=0.5)
# sub_geoms[sub_geoms['type'] == "Matrix"].plot(ax=ax, color='none', edgecolor='black', alpha=1, linewidth=0.5)

# sub_geoms[sub_geoms['type'] == "Striosome"].plot(ax=ax, color='red', edgecolor='black', alpha=0.1)
# sub_geoms[sub_geoms['type'] == "Matrix"].plot(ax=ax, color='blue', edgecolor='black', alpha=0.1)

# str_geoms.plot(ax=ax, color='red', edgecolor='none', alpha=0.05)
# mat_geoms.plot(ax=ax, color='blue', edgecolor='none', alpha=0.05)
# str_geoms.plot(ax=ax, color='none', edgecolor='black', linewidth=1, alpha=1)
# mat_geoms.plot(ax=ax, color='none', edgecolor='black', linewidth=1, alpha=1)

# ax.legend()
# ax.set_title(f"{_donor} {_region} {_replicate}")
# plt.show()

## Accumulating tests across experiments (per pair × brain region)

In [None]:
# Pool experiments per (pair, brain_region) with re_meta, then BH-correct
# across all pooled tests and write meta_region_pooled.csv.
# NOTE(review): expects a DataFrame `df` with columns pair, brain_region, yi, vi;
# no cell above this one builds such a frame — confirm where it comes from.
region_rows = []
for pair, dfg in df.groupby("pair"):
    for region, dfr in dfg.groupby("brain_region"):
        # Need ≥2 experiments to estimate heterogeneity
        if len(dfr) >= 2:
            res = re_meta(dfr["yi"].values, dfr["vi"].values)
            res.update(pair=pair, brain_region=region)
            region_rows.append(res)

per_region = pd.DataFrame(region_rows)
if len(per_region):
    # Benjamini–Hochberg FDR across all pooled tests. Uses multipletests from the
    # import cell; `sm` (statsmodels.api) is only imported in a later cell.
    per_region["p_fdr"] = multipletests(per_region["p"], method="fdr_bh")[1]
    per_region.to_csv(DIR / "meta_region_pooled.csv", index=False)
    print("Wrote meta_region_pooled.csv (pooled z per pair × region)")
else:
    print("No pair × region with ≥2 experiments; meta_region_pooled.csv not written")

In [None]:
def i_squared(effect_sizes, weights):
    """
    Higgins' I^2 heterogeneity statistic, in percent.

    Parameters
    ----------
    effect_sizes : array-like
        Per-study effect estimates.
    weights : float or array-like
        Per-study weights (typically inverse variances); a scalar applies to all.

    Returns
    -------
    float
        max(0, (Q - dof) / Q) * 100, where Q is Cochran's Q around the weighted
        mean and dof = k - 1. Returns 0 when Q == 0 and NaN for fewer than two studies.
    """
    # Arrays so list inputs support the element-wise arithmetic below.
    effect_sizes = np.asarray(effect_sizes, dtype=float)
    weights = np.broadcast_to(np.asarray(weights, dtype=float), effect_sizes.shape)
    k = len(effect_sizes)
    if k <= 1:
        return np.nan
    mean_eff = np.average(effect_sizes, weights=weights)
    Q = np.sum(weights * (effect_sizes - mean_eff) ** 2)
    dof = k - 1  # not named `df`, to avoid confusion with the notebook's DataFrames
    return max(0, (Q - dof) / Q) * 100 if Q > 0 else 0

In [None]:
import numpy as np
import pandas as pd
import statsmodels.api as sm

# Toy check of the pooled REML approach: simulate z-scale effect sizes with known donor- and
# replicate-level variance components, fit the mixed model, and compare estimates with the truth.
# The simulated region/donor lists and data frame carry a `sim_` prefix so this cell no longer
# overwrites `donors` (iterated by the MS enrichment loop) or `df` (read by the meta-analysis cell).
rng = np.random.default_rng(42)

# ------------------------------------------------------------
# 1) Simulate data: 8 regions × 4 donors × 2 reps
#    True variance components (you can change these):
# ------------------------------------------------------------
R   = 8
D   = 4
REP = 2
sim_regions = [f"R{i+1}" for i in range(R)]
sim_donors  = [f"D{j+1}" for j in range(D)]

# True means per region (arbitrary)
mu_region = {r: rng.normal(loc=0.0, scale=0.5) for r in sim_regions}

# True variance components
tau2_donor_true = 0.15   # donor heterogeneity
tau2_rep_true   = 0.05   # replicate heterogeneity
sigma2_within   = 1.0    # z-scale residual variance

# The order of the rng draws below determines the simulated values; keep it if editing the loop.
sim_rows = []
for r in sim_regions:
    for d in sim_donors:
        donor_uid = f"{r}:{d}"      # make donors unique per region
        b_donor = rng.normal(0, np.sqrt(tau2_donor_true))
        for rep in range(1, REP + 1):
            study_id = f"{donor_uid}:rep{rep}"
            b_rep   = rng.normal(0, np.sqrt(tau2_rep_true))
            # one "study" (effect estimate) per replicate
            y = mu_region[r] + b_donor + b_rep + rng.normal(0, np.sqrt(sigma2_within))
            sim_rows.append({"region": r, "donor_uid": donor_uid, "study_id": study_id, "yi": y})

# ------------------------------------------------------------
# 2) Fit pooled REML mixed model:
#    yi ~ C(region) + (1 | donor_uid) + (1 | study_id)
#    In statsmodels: donor as groups, study_id via vc_formula
# ------------------------------------------------------------
# Grouping variables as strings, response as float
sim_df = pd.DataFrame(sim_rows).astype(
    {"region": str, "donor_uid": str, "study_id": str, "yi": float}
)

# NOTE(review): each study_id labels exactly one row, so the `rep` variance component and the
# residual variance enter the likelihood identically and are not separately identifiable. The
# split the optimizer returns between fit.vcomp and fit.scale (hence tau2_rep_hat / I2_rep)
# should not be trusted without more than one observation per study_id.
model = sm.MixedLM.from_formula(
    "yi ~ C(region)",
    groups="donor_uid",
    vc_formula={"rep": "0 + C(study_id)"},
    data=sim_df,
)

fit = model.fit(reml=True, method="lbfgs", maxiter=1000, disp=False)
if not getattr(fit, "converged", True):
    print("WARNING: MixedLM fit did not converge; variance components may be unreliable")

# ------------------------------------------------------------
# 3) Extract variance components robustly across versions
# ------------------------------------------------------------
# donor variance:
if hasattr(fit, "cov_re") and getattr(fit.cov_re, "shape", (0, 0))[0] > 0:
    tau2_donor_hat = float(fit.cov_re.iloc[0, 0])
else:
    tau2_donor_hat = 0.0  # boundary solution

# replicate variance (vc component); fit.vcomp can be array/Series/float
tau2_rep_hat = 0.0
if hasattr(fit, "vcomp"):
    v = fit.vcomp
    if isinstance(v, (float, int, np.floating)):
        tau2_rep_hat = float(v)
    elif hasattr(v, "__len__") and len(v) > 0:
        tau2_rep_hat = float(v[0])

# ------------------------------------------------------------
# 4) Convert to I^2 on the z-scale (within-study variance is the known sigma2_within)
# ------------------------------------------------------------
I2_donor = 100.0 * tau2_donor_hat / (tau2_donor_hat + sigma2_within)
I2_rep   = 100.0 * tau2_rep_hat   / (tau2_rep_hat   + sigma2_within)

print("=== REML pooled across regions ===")
print(f"tau^2_donor (true {tau2_donor_true:.3f})  = {tau2_donor_hat:.3f}  → I^2_donor = {I2_donor:.1f}%")
print(f"tau^2_rep   (true {tau2_rep_true:.3f})    = {tau2_rep_hat:.3f}    → I^2_rep   = {I2_rep:.1f}%")

# Optional: view fixed effects (region means)
print("\nFixed effects (region means, relative to reference):")
print(fit.summary().tables[1])


In [None]:
def _aggregate_tests(DIR, file_format, ): 


## MS enrichment: cell-type counts in Matrix, Striosome and White_Matter compartments vs. a permutation null

In [None]:
# From here on this needs to be iterable.
# MS enrichment: for every (donor, brain region, replicate) section, count cells of each Subclass
# and Group that fall inside that section's Matrix, Striosome and White_Matter polygons, and
# compare the counts with a null built by shuffling cell positions among cells N_permute times.
# Reads donors, brain_regions, replicates, skip, adata, geoms, N_permute, permute_geometry and
# one_sided_pval from elsewhere in the notebook.
# NOTE(review): permute_geometry draws from the global np.random state, which is never seeded
# here, so null counts differ between runs.

MS_OUTPUT_DIR = None  # set to a directory (str or Path) to also write one CSV per section/test


def _cells_within(cells_gdf, compartment_geoms):
    """Rows of `cells_gdf` lying within any polygon of `compartment_geoms`, each cell kept once."""
    # Join on the geometry column only, so polygon metadata cannot collide with cell metadata.
    joined = gpd.sjoin(cells_gdf, compartment_geoms[[compartment_geoms.geometry.name]],
                       how="inner", predicate="within")
    return joined.loc[~joined.index.duplicated(keep="first")]


def _counts_by(cells, column, labels):
    """Number of `cells` per value of `column`, returned for every label in `labels` (0 if absent)."""
    counts = cells[column].value_counts().reindex(labels, fill_value=0).tolist()
    # Key by the `labels` objects themselves so lookups into dicts built from `labels` always hit.
    return dict(zip(labels, counts))


ms_composition_results = []  # one DataFrame per (section, test); concatenated after the loop
pbar = tqdm(itertools.product(donors, brain_regions, replicates))
for _i in pbar:
    if _i in skip:
        continue
    _donor, _brain_region, _replicate = _i
    pbar.set_description(f"Processing {_donor} | {_brain_region} | {_replicate}")
    adata_sub = adata[(adata.obs['donor'] == _donor) &
                      (adata.obs['brain_region'] == _brain_region) &
                      (adata.obs['replicate'] == _replicate)].copy()
    geoms_sub = geoms[(geoms['donor'] == _donor) &
                      (geoms['brain_region'] == _brain_region) &
                      (geoms['lab'] == _replicate)].copy()
    if geoms_sub.shape[0] == 0:
        continue

    gdf = gpd.GeoDataFrame(
        adata_sub.obs,
        geometry=gpd.points_from_xy(adata_sub.obs['CENTER_X'], adata_sub.obs['CENTER_Y']),
        crs=None,
    )

    subclass_cells = gdf['Subclass'].unique().tolist()
    group_cells = gdf['Group'].unique().tolist()

    # Use THIS section's polygons. Joining against all of `geoms` could assign cells to
    # compartments drawn on other donors/regions/replicates. The subsets do not change across
    # permutations, so select them once.
    mat_geoms_sub = geoms_sub[geoms_sub['type'] == 'Matrix']
    str_geoms_sub = geoms_sub[geoms_sub['type'] == 'Striosome']
    wm_geoms_sub = geoms_sub[geoms_sub['type'] == 'White_Matter']

    mat_cells = _cells_within(gdf, mat_geoms_sub)
    str_cells = _cells_within(gdf, str_geoms_sub)
    wm_cells = _cells_within(gdf, wm_geoms_sub)

    # Observed counts, aligned to the full label lists so absent cell types count as 0 and
    # observed/null dicts share the same keys in the same order.
    sub_mat_counts = _counts_by(mat_cells, "Subclass", subclass_cells)
    sub_str_counts = _counts_by(str_cells, "Subclass", subclass_cells)
    sub_wm_counts = _counts_by(wm_cells, "Subclass", subclass_cells)

    gr_mat_counts = _counts_by(mat_cells, "Group", group_cells)
    gr_str_counts = _counts_by(str_cells, "Group", group_cells)
    gr_wm_counts = _counts_by(wm_cells, "Group", group_cells)

    null_sub_mat_counts = {a: [] for a in subclass_cells}
    null_sub_str_counts = {a: [] for a in subclass_cells}
    null_sub_wm_counts = {a: [] for a in subclass_cells}
    null_gr_mat_counts = {a: [] for a in group_cells}
    null_gr_str_counts = {a: [] for a in group_cells}
    null_gr_wm_counts = {a: [] for a in group_cells}

    for _ in range(N_permute):
        gdf.geometry = permute_geometry(gdf.geometry)

        mat_cells = _cells_within(gdf, mat_geoms_sub)
        str_cells = _cells_within(gdf, str_geoms_sub)
        wm_cells = _cells_within(gdf, wm_geoms_sub)

        # Every label receives exactly one value (possibly 0) per permutation.
        for null_dict, cells, column, labels in (
            (null_sub_mat_counts, mat_cells, "Subclass", subclass_cells),
            (null_sub_str_counts, str_cells, "Subclass", subclass_cells),
            (null_sub_wm_counts, wm_cells, "Subclass", subclass_cells),
            (null_gr_mat_counts, mat_cells, "Group", group_cells),
            (null_gr_str_counts, str_cells, "Group", group_cells),
            (null_gr_wm_counts, wm_cells, "Group", group_cells),
        ):
            for ct, n in _counts_by(cells, column, labels).items():
                null_dict[ct].append(n)

    naming = ["subclass_matrix", "subclass_striosome", "subclass_white_matter",
              "group_matrix", "group_striosome", "group_white_matter"]
    real_dicts = [sub_mat_counts, sub_str_counts, sub_wm_counts,
                  gr_mat_counts, gr_str_counts, gr_wm_counts]
    null_dicts = [null_sub_mat_counts, null_sub_str_counts, null_sub_wm_counts,
                  null_gr_mat_counts, null_gr_str_counts, null_gr_wm_counts]

    for _name, _real, _null in zip(naming, real_dicts, null_dicts):
        cell_types = list(_real)  # build both p-value inputs in this order so they stay paired
        ps, adj_ps = one_sided_pval([_real[ct] for ct in cell_types],
                                    [_null[ct] for ct in cell_types])
        result_df = pd.DataFrame({
            "cell_type": cell_types,
            "real_count": [_real[ct] for ct in cell_types],
            "mean_null_count": [np.mean(_null[ct]) for ct in cell_types],
            "std_null_count": [np.std(_null[ct]) for ct in cell_types],
            "p_value": ps,
            "adj_p_value": adj_ps,
            "log_2FC": [np.log2((_real[ct] + 1) / (np.mean(_null[ct]) + 1)) for ct in cell_types],
        })
        ms_composition_results.append(result_df.assign(
            test=_name, donor=_donor, brain_region=_brain_region, replicate=_replicate))
        if MS_OUTPUT_DIR is not None:
            os.makedirs(MS_OUTPUT_DIR, exist_ok=True)
            result_df.to_csv(
                Path(MS_OUTPUT_DIR) / f"ms_composition_{_name}_{_donor}_{_brain_region}_{_replicate}.csv",
                index=False,
            )

ms_composition_df = (pd.concat(ms_composition_results, ignore_index=True)
                     if ms_composition_results else pd.DataFrame())

In [None]:
# Total polygon area of the striosome and matrix compartments in `geoms_sub`
# (as left by the last iteration of the enrichment loop).
str_area, mtr_area = (
    geoms_sub.loc[geoms_sub['type'] == kind].geometry.area.sum()
    for kind in ("Striosome", "Matrix")
)
print(f"Total striosome area: {str_area:.2f}, total matrix area: {mtr_area:.2f}")

In [None]:
# Subclass labels from the last section processed by the enrichment loop
# (a dict_keys view of `sub_mat_counts`, so it follows that dict's key order).
cell_types = sub_mat_counts.keys()

In [None]:
# Ad-hoc rerun of the one-sided permutation test for Subclass counts in the Matrix of the last
# processed section. Both lists are built in the same key order, so each observed count is
# compared with its own null distribution even if the two dicts were filled in different orders.
# `ps` holds whatever one_sided_pval returns (the enrichment loop unpacks it as p-values and
# adjusted p-values).
ps = one_sided_pval([sub_mat_counts[ct] for ct in sub_mat_counts],
                    [null_sub_mat_counts[ct] for ct in sub_mat_counts])

### Old white-matter (WM) integration

In [None]:
# geoms

In [None]:
# geoms_colors = {
#     "White_Matter": "lightgray",
#     "Striosome": "red",
#     "Matrix": "blue"
# }
# geoms['type_color'] = geoms['type'].map(geoms_colors)

In [None]:
# def _combine_ms_wm(geoms): 

#     new_geoms = geoms.copy()
#     str_geoms = geoms[geoms['type'] == "Striosome"].copy()
#     new_geoms.geometry = new_geoms.geometry.difference(str_geoms.unary_union.buffer(str_buffer))
#     new_geoms.loc[str_geoms.index, str_geoms.columns] = str_geoms
#     mat_geoms = new_geoms[new_geoms['type'] == "Matrix"].copy()
#     wm_geoms = new_geoms[new_geoms['type'] == "White_Matter"].copy()

#     wm_geoms['wm_id'] = range(len(wm_geoms))
#     mat_geoms['mat_id'] = range(len(mat_geoms))
#     str_geoms['str_id'] = range(len(str_geoms))

#     ints = gpd.overlay(mat_geoms, wm_geoms, how='intersection')
#     ints['mat_area'] = ints['mat_id'].apply(lambda x: mat_geoms.loc[mat_geoms['mat_id'] == x, 'geometry'].values[0].area)
#     ints['wm_area'] = ints['wm_id'].apply(lambda x: wm_geoms.loc[wm_geoms['wm_id'] == x, 'geometry'].values[0].area)

#     # Remove overlapping regions
#     new_geoms.geometry = new_geoms.geometry.difference(ints.geometry.unary_union.buffer(str_buffer))

#     new_ints = []
#     for index, row in ints.iterrows(): 
#         if row['mat_area'] < row['wm_area']:
#             keep = "1"
#             keep_idx = row['mat_id']
#         else: 
#             keep = "2"
#             keep_idx = row['wm_id']
#         keep_row = [keep_idx]
#         for _col in geoms.columns: 
#             if _col == "geometry": 
#                 keep_row.append(row['geometry'])
#             else: 
#                 keep_row.append(row[f"{_col}_{keep}"])
#         new_ints.append(keep_row)
#     add_ints = gpd.GeoDataFrame(new_ints, columns=["keep_idx"] + list(geoms.columns))

#     new_geoms = pd.concat([new_geoms, add_ints.drop(columns=["keep_idx"])])
#     new_geoms = new_geoms.loc[~new_geoms.is_empty]
#     return new_geoms



In [None]:
# # From here on this needs to be iterable. 
# contact_list = []
# pbar = tqdm(itertools.product(donors, brain_regions, replicates))
# for _i in pbar:
#     if _i in skip:
#         # print(f"Skipping {_i}")
#         continue
#     _donor, _brain_region, _replicate, = _i
#     pbar.set_description(f"Processing {_donor} | {_brain_region} | {_replicate}")
    
#     ### combine
#     sub_geoms = geoms[(geoms['brain_region'] == _brain_region) & (geoms['lab'] == _replicate) & (geoms['donor'] == _donor)]
#     fig, ax = plt.subplots(figsize=(5,5))
#     sub_geoms.plot(ax=ax, color=sub_geoms['type_color'], edgecolor='none', legend=True, alpha=0.1).axis("off");
#     sub_geoms.plot(ax=ax, color="none", edgecolor='black', legend=True).axis("off");
#     plt.show()

#     new_geoms = _combine_ms_wm(sub_geoms)

#     # new_geoms = sub_geoms.copy()
#     # str_geoms = sub_geoms[sub_geoms['type'] == "Striosome"].copy()
#     # new_geoms.geometry = new_geoms.geometry.difference(str_geoms.unary_union.buffer(str_buffer))
#     # new_geoms.loc[str_geoms.index, str_geoms.columns] = str_geoms
#     # mat_geoms = new_geoms[new_geoms['type'] == "Matrix"].copy()
#     # wm_geoms = new_geoms[new_geoms['type'] == "White_Matter"].copy()

#     # wm_geoms['wm_id'] = range(len(wm_geoms))
#     # mat_geoms['mat_id'] = range(len(mat_geoms))
#     # str_geoms['str_id'] = range(len(str_geoms))

#     # ints = gpd.overlay(mat_geoms, wm_geoms, how='intersection')
#     # ints['mat_area'] = ints['mat_id'].apply(lambda x: mat_geoms.loc[mat_geoms['mat_id'] == x, 'geometry'].values[0].area)
#     # ints['wm_area'] = ints['wm_id'].apply(lambda x: wm_geoms.loc[wm_geoms['wm_id'] == x, 'geometry'].values[0].area)

#     # # Remove overlapping regions
#     # new_geoms.geometry = new_geoms.geometry.difference(ints.geometry.unary_union.buffer(str_buffer))

#     # new_ints = []
#     # for index, row in ints.iterrows(): 
#     #     if row['mat_area'] < row['wm_area']:
#     #         keep = "1"
#     #         keep_idx = row['mat_id']
#     #     else: 
#     #         keep = "2"
#     #         keep_idx = row['wm_id']
#     #     keep_row = [keep_idx]
#     #     for _col in sub_geoms.columns: 
#     #         if _col == "geometry": 
#     #             keep_row.append(row['geometry'])
#     #         else: 
#     #             keep_row.append(row[f"{_col}_{keep}"])
#     #     new_ints.append(keep_row)
#     # add_ints = gpd.GeoDataFrame(new_ints, columns=["keep_idx"] + list(sub_geoms.columns))

#     # new_geoms = pd.concat([new_geoms, add_ints.drop(columns=["keep_idx"])])
#     # new_geoms = new_geoms.loc[~new_geoms.is_empty]

#     fig, ax = plt.subplots(figsize=(5,5))
#     new_geoms.plot(ax=ax, color=new_geoms['type_color'], edgecolor='none', legend=True, alpha=0.1).axis("off");
#     new_geoms.plot(ax=ax, color="none", edgecolor='black', legend=True).axis("off");
#     plt.show()

#     contact_list.append(new_geoms)

In [None]:
# Single example section (PU / ucsd / UCI5224) used to develop the Matrix–White_Matter overlap handling
sub_geoms = geoms.query("brain_region == 'PU' and lab == 'ucsd' and donor == 'UCI5224'")
sub_geoms.head()

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
# Translucent fill coloured by compartment type, then black outlines drawn on top
sub_geoms.plot(ax=ax, color=sub_geoms['type_color'], edgecolor='none', legend=True, alpha=0.1)
sub_geoms.plot(ax=ax, color="none", edgecolor='black', legend=True)
ax.set_axis_off()
plt.show()

In [None]:
# Carve the buffered striosome footprint out of every polygon, then restore the striosome
# polygons themselves, so the other compartments no longer overlap the (buffered) striosomes.
str_geoms = sub_geoms.loc[sub_geoms['type'] == "Striosome"].copy()
new_geoms = sub_geoms.copy()
new_geoms.geometry = new_geoms.geometry.difference(str_geoms.unary_union.buffer(str_buffer))
new_geoms.loc[str_geoms.index, str_geoms.columns] = str_geoms
mat_geoms, wm_geoms = (
    new_geoms.loc[new_geoms['type'] == kind].copy() for kind in ("Matrix", "White_Matter")
)
new_geoms.plot()

In [None]:
# 0-based positional ids so overlay pieces can be traced back to their parent polygons
wm_geoms['wm_id'] = np.arange(wm_geoms.shape[0])
mat_geoms['mat_id'] = np.arange(mat_geoms.shape[0])
str_geoms['str_id'] = np.arange(str_geoms.shape[0])
wm_geoms.shape, str_geoms.shape, mat_geoms.shape

In [None]:
# Pairwise Matrix ∩ White_Matter pieces; the overlay keeps mat_id / wm_id, so each piece can be
# traced to its two parents, whose full areas are attached for deciding which parent keeps it.
ints = gpd.overlay(mat_geoms, wm_geoms, how='intersection')
# Look parent areas up through an id -> area mapping built in one pass, instead of filtering the
# whole parent frame once per intersection row (ids are unique positional ranges).
mat_area_by_id = dict(zip(mat_geoms['mat_id'], mat_geoms.geometry.area))
wm_area_by_id = dict(zip(wm_geoms['wm_id'], wm_geoms.geometry.area))
ints['mat_area'] = ints['mat_id'].map(mat_area_by_id)
ints['wm_area'] = ints['wm_id'].map(wm_area_by_id)
ints.head()

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
# Compartments (translucent fill + black outlines) with the Matrix ∩ White_Matter pieces in yellow
new_geoms.plot(ax=ax, color=new_geoms['type_color'], edgecolor='none', legend=True, alpha=0.1)
new_geoms.plot(ax=ax, color="none", edgecolor='black', legend=True)
ints.plot(ax=ax, color="yellow", edgecolor='black', legend=True, alpha=0.5)
ax.set_axis_off()
plt.show()


In [None]:
# Remove overlapping regions: subtract every Matrix ∩ White_Matter piece, grown by `str_buffer`,
# from all polygons (striosomes included).
# NOTE(review): the overlap buffer reuses the striosome buffer size — confirm that is intended.
new_geoms.geometry = new_geoms.geometry.difference(
    ints.geometry.unary_union.buffer(str_buffer)
)

In [None]:
# Give each Matrix ∩ White_Matter piece to ONE parent: the Matrix polygon when it is the smaller
# of the two, otherwise the White_Matter polygon. gpd.overlay suffixed the attribute columns the
# two inputs share with "_1" (mat_geoms) and "_2" (wm_geoms); `keep` selects which side to copy.
new_ints = []
for _, row in ints.iterrows():
    mat_is_smaller = row['mat_area'] < row['wm_area']
    keep = "1" if mat_is_smaller else "2"
    keep_idx = row['mat_id'] if mat_is_smaller else row['wm_id']
    keep_row = [keep_idx] + [
        row['geometry'] if _col == "geometry" else row[f"{_col}_{keep}"]
        for _col in sub_geoms.columns
    ]
    new_ints.append(keep_row)
# Declare the geometry column explicitly and carry over the overlay's CRS, so `add_ints` matches
# the frames it is later concatenated with instead of silently having no CRS.
add_ints = gpd.GeoDataFrame(
    new_ints,
    columns=["keep_idx"] + list(sub_geoms.columns),
    geometry="geometry",
    crs=ints.crs,
)

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
# Carved compartments with the reassigned overlap pieces highlighted in yellow
new_geoms.plot(ax=ax, color=new_geoms['type_color'], edgecolor='none', legend=True, alpha=0.1)
new_geoms.plot(ax=ax, color="none", edgecolor='black', legend=True)
add_ints.plot(ax=ax, color="yellow", edgecolor='black', legend=True, alpha=0.5)
ax.set_axis_off()
plt.show()


In [None]:
# Append the reassigned overlap pieces (minus their bookkeeping id) and drop polygons that the
# carving left empty. NOTE(review): re-running this cell appends `add_ints` a second time.
merged_parts = [new_geoms, add_ints.drop(columns="keep_idx")]
new_geoms = pd.concat(merged_parts)
new_geoms = new_geoms[~new_geoms.is_empty]

In [None]:
fig, ax = plt.subplots(figsize=(5, 5))
# Final compartments after striosome carving and overlap reassignment
new_geoms.plot(ax=ax, color=new_geoms['type_color'], edgecolor='none', legend=True, alpha=0.1)
new_geoms.plot(ax=ax, color="none", edgecolor='black', legend=True)
ax.set_axis_off()
plt.show()

In [None]:
# for _donor in sub_geoms['donor'].unique():
#     print(f"Donor {_donor}:")
#     # display(sub_geoms[sub_geoms['donor'] == _donor]['type'].value_counts())
#     sub_geomsd = sub_geoms[sub_geoms['donor'] == _donor].copy()
#     fig, ax = plt.subplots(figsize=(5,5))
#     sub_geomsd.plot(ax=ax, column="type", color=sub_geomsd['type_color'], edgecolor='none', legend=True, alpha=0.1).axis("off");
#     sub_geomsd.plot(ax=ax, column="type", color="none", edgecolor='black', legend=True).axis("off");
#     plt.show()