This notebook plots the cell-cell proximity analysis results for both the supplementary figures and the main spatial figure (Fig. 6).

Author: Amit Klein 
Email: a3klein@ucsd.edu

In [None]:
import os
from pathlib import Path
import itertools 
from tqdm import tqdm

import numpy as np
import pandas as pd
import anndata as ad
from scipy.stats import norm 
from statsmodels.stats.multitest import multipletests

import multiprocessing as mp
from functools import partial

import matplotlib.pyplot as plt
import seaborn as sns
from adjustText import adjust_text  # pip install adjustText

In [None]:
# Publication-style matplotlib defaults for every figure in this notebook:
# 300-dpi rendering and saving, TrueType (Type 42) font embedding in PDF/PS,
# 8-pt Arial text, and tight, transparent saved images.
plt.rcParams.update({
    'figure.dpi': 300,
    'figure.figsize': (4, 4),
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'font.size': 8,
    'axes.facecolor': 'white',
    'image.interpolation': 'nearest',
    'savefig.dpi': 300,
    'savefig.transparent': True,
    'savefig.bbox': 'tight',
    'savefig.pad_inches': 0.01,
})

# Global switch forwarded as `rasterized=` to the plotting calls below.
RASTERIZED = False

In [None]:
# Output directory for every figure written by this notebook; created (with parents) if missing.
image_path = Path("/home/x-aklein2/projects/aklein/BICAN/BG/images/figures/ccpa_v2")
os.makedirs(image_path, exist_ok=True)

In [None]:
# Load the annotated dataset; its .obs metadata and .uns colour palettes are used below.
ad_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_CPS.h5ad"
adata = ad.read_h5ad(filename=ad_path)
adata  # notebook display of the AnnData summary

In [None]:
# Map each raw `brain_region` code to its display name in `brain_region_corr`
# (used for legend labels). De-duplicate (region, name) pairs *before* indexing:
# DataFrame.drop_duplicates() ignores the index, so de-duplicating after
# set_index('brain_region') would drop every region whose display name is shared
# with an earlier region, and later lookups for it would raise KeyError.
br_to_brc_map = (
    adata.obs[['brain_region', 'brain_region_corr']]
    .drop_duplicates()
    .set_index('brain_region')['brain_region_corr']
    .to_dict()
)

# Summary Contact Plots (Fig. 6 I-K)

## Helper functions

In [None]:
# Chat GPT
def combine_results(meta_df, cell_types_to_combine, new_category_name, comparison_ct):
    """
    Average the meta-analysis results of several cell-type pairs, per brain region.

    For every brain region, take all pairs that join one of ``cell_types_to_combine``
    with ``comparison_ct`` (in either ct1/ct2 order) and combine them into one row.
    This computes the AVERAGE effect across different cell type pairs,
    not a meta-analysis of the same effect.

    Parameters:
    -----------
    meta_df : pd.DataFrame
        DataFrame with columns: pair, brain_region, ct1, ct2, mu, se.
    cell_types_to_combine : list
        Cell types whose pairs with ``comparison_ct`` are averaged together.
    new_category_name : str
        Label stored in the ``category`` column of the result.
    comparison_ct : str
        Partner cell type that every selected pair must contain.

    Returns:
    --------
    pd.DataFrame
        One row per brain region with at least one matching pair; columns
        brain_region, category, mu, se, n_combined, original_pairs,
        individual_mus, individual_ses. The columns are present even when
        nothing matches, so callers can still group or select on them.
    """
    combined_results = []

    for region, region_df in meta_df.groupby('brain_region'):
        # Pairs of a target cell type with the comparison type, in either orientation.
        subset = region_df[
            (region_df['ct1'].isin(cell_types_to_combine) & (region_df['ct2'] == comparison_ct)) |
            (region_df['ct2'].isin(cell_types_to_combine) & (region_df['ct1'] == comparison_ct))
        ]

        if subset.empty:
            continue

        mus = subset['mu'].values
        ses = subset['se'].values
        n = len(subset)

        # Unweighted mean of the pair effects. Treating the estimates as independent,
        # Var(mean) = sum(SE_i^2) / n^2, hence SE = sqrt(sum(SE_i^2)) / n.
        # (A precision-weighted alternative would use w_i = 1 / SE_i^2:
        #  mu = sum(w_i * mu_i) / sum(w_i), SE = 1 / sqrt(sum(w_i)).)
        combined_results.append({
            'brain_region': region,
            'category': new_category_name,
            'mu': np.mean(mus),
            'se': np.sqrt(np.sum(ses ** 2)) / n,
            'n_combined': n,
            'original_pairs': subset['pair'].tolist(),
            'individual_mus': mus.tolist(),
            'individual_ses': ses.tolist(),
        })

    # Fixed column list so an empty result still has the expected schema
    # (pd.DataFrame([]) has no columns and breaks groupby('category') downstream).
    return pd.DataFrame(combined_results, columns=[
        'brain_region', 'category', 'mu', 'se', 'n_combined',
        'original_pairs', 'individual_mus', 'individual_ses',
    ])

def combine_regions(df_meta, col_pair='pair', regions=None, region_col='brain_region'):
    """
    Average each pair's per-region results into one estimate across brain regions.

    For every value of ``col_pair`` the unweighted mean of ``mu`` is taken over its
    rows (one per region), and its standard error is propagated assuming
    independent estimates: ``se = sqrt(sum(se_i**2)) / n``.

    Parameters:
    -----------
    df_meta : pd.DataFrame
        Must contain ``col_pair``, ``region_col``, ``mu`` and ``se`` columns.
    col_pair : str
        Column identifying the pair/category to combine over.
    regions : str or list-like of str, optional
        Restrict the average to these regions (a single string means one region).
        Pairs with no rows in these regions are omitted from the result.
    region_col : str
        Column holding the brain-region label (default ``'brain_region'``).

    Returns:
    --------
    pd.DataFrame
        One row per pair with columns ``col_pair``, mu, se, n_combined,
        original_pairs, individual_mus, individual_ses, regions. The columns are
        present even when the result is empty.
    """
    if isinstance(regions, str):
        regions = [regions]
    # Apply the region filter once rather than once per pair.
    if regions is not None:
        df_meta = df_meta[df_meta[region_col].isin(regions)]

    results = []
    for pair, pair_df in df_meta.groupby(col_pair):
        # Empty groups can appear for unobserved categories of a categorical column.
        if pair_df.empty:
            continue
        mus = pair_df['mu'].to_numpy()
        ses = pair_df['se'].to_numpy()
        n = len(pair_df)

        results.append({
            col_pair: pair,
            # Unweighted mean; SE of that mean for independent estimates.
            'mu': np.mean(mus),
            'se': np.sqrt(np.sum(ses ** 2)) / n,
            'n_combined': n,
            'original_pairs': pair_df[col_pair].tolist(),
            'individual_mus': mus.tolist(),
            'individual_ses': ses.tolist(),
            'regions': pair_df[region_col].tolist(),
        })

    return pd.DataFrame(results, columns=[
        col_pair, 'mu', 'se', 'n_combined', 'original_pairs',
        'individual_mus', 'individual_ses', 'regions',
    ])

### plots

In [None]:
def plot_pair(
    meta,
    pair,
    pair_col = 'pair',
    order=None,
    ax=None,
    color='blue',
    label=None,
    rasterized=False,
    opacity=1.0,
    region_col = "brain_region",
):
    """Plot one pair's pooled z-score (mean with SE error bars) across brain regions.

    Rows of ``meta`` whose ``pair_col`` equals ``pair`` are drawn as points with
    error bars joined by a dashed line, plus a dashed zero reference line.
    ``order`` fixes the region order (regions absent for this pair are skipped);
    otherwise regions are sorted by ascending ``mu``. A new 6x4 figure is created
    when ``ax`` is None. Returns the axes drawn on.
    """
    rows = meta[meta[pair_col] == pair]
    if order is None:
        rows = rows.sort_values(by='mu')
    else:
        available = rows[region_col].values
        kept = [reg for reg in order if reg in available]
        rows = rows.set_index(region_col).loc[kept].reset_index()

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 4))

    xs = rows[region_col]
    ys = rows['mu']
    ax.errorbar(x=xs, y=ys, yerr=rows['se'], fmt='o', color=color, capsize=5, label=label, alpha=opacity, rasterized=rasterized)
    ax.plot(xs, ys, color=color, linestyle='--', alpha=opacity, rasterized=rasterized)
    ax.axhline(0, color='black', linestyle='--', linewidth=1, rasterized=rasterized)

    ax.set_xlabel('Brain Region')
    ax.set_ylabel('Pooled Z-score')
    ax.set_title(f'Pooled Contact Enrichment Z-scores for {pair} Across Brain Regions')
    return ax

## Run

### Astrocytes

In [None]:
# Average astrocyte-neuron contact enrichment for three neuron groups at each
# contact radius; one summary row per group per radius.
MSN_types = ["STR D1 MSN", "STR D2 MSN", "STR Hybrid MSN"]
GP_types = ["CN Cholinergic GABA", "CN LHX8 GABA", "CN MEIS2 GABA", "CN ONECUT1 GABA", "CN GABA-Glut"]
IT_types = ["CN ST18 GABA", "STR RSPO2 GABA", "CN LAMP5-CXCL14 GABA", "CN LAMP5-LHX6 GABA", "CN VIP GABA"]
rs = [15, 30, 50]

astro_msn, astro_gp, astro_it = [], [], []
for r in rs:
    DIR = Path(f'/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/meta_contacts/{r}um_2')
    # NOTE(review): `df` is not used in this cell; kept only in case later cells rely on it.
    df = pd.read_csv(DIR / "contacts_meta_input.csv")
    # Read the pooled per-region table once per radius and reuse it for all three groups.
    meta = pd.read_csv(DIR / "meta_region_pooled.csv")
    meta['ct1'] = meta['pair'].astype(str).str.split('|').str[0]
    meta['ct2'] = meta['pair'].astype(str).str.split('|').str[1]

    # MSNs: averaged over every region that has astrocyte-MSN pairs.
    combined_MSN = combine_results(meta, MSN_types, "All_MSN", "Astrocyte")
    rc_MSN = combine_regions(combined_MSN, col_pair='category')
    rc_MSN['radius_um'] = r
    astro_msn.append(rc_MSN)

    # GP neuron types: GP region only.
    combined_GP = combine_results(meta, GP_types, "All_GP", "Astrocyte")
    rc_GP = combine_regions(combined_GP, col_pair='category', regions=['GP'])
    rc_GP['radius_um'] = r
    astro_gp.append(rc_GP)

    # Interneuron types: restricted to the regions listed here.
    combined_IT = combine_results(meta, IT_types, "All_IT", "Astrocyte")
    rc_IT = combine_regions(combined_IT, col_pair='category', regions=['CAH', 'CAB', 'PU', 'NAC', 'CAT', 'GP', 'MGM1'])
    rc_IT['radius_um'] = r
    astro_it.append(rc_IT)

astro_msn = pd.concat(astro_msn)
astro_msn['id'] = astro_msn['radius_um'].astype(str) + "um, MSN"
astro_msn['rid'] = astro_msn['radius_um'].astype(str) + "um"

astro_gp = pd.concat(astro_gp)
astro_gp['id'] = astro_gp['radius_um'].astype(str) + "um, GP"
astro_gp['rid'] = astro_gp['radius_um'].astype(str) + "um"

astro_it = pd.concat(astro_it)
astro_it['id'] = astro_it['radius_um'].astype(str) + "um, INT"
astro_it['rid'] = astro_it['radius_um'].astype(str) + "um"

In [None]:
def _draw_group(ax, x, summary, color):
    """Draw one neuron group on ``ax``.

    Bars show the region-averaged z-score (with SE error bars) at each radius, one
    row of ``summary`` per bar. A dashed line per brain region traces that region's
    own z-score across the same radii, coloured by the brain-region palette.
    """
    ax.bar(x=x, height=summary['mu'], yerr=summary['se'], capsize=5, color=color, alpha=0.7,
           width=0.9, edgecolor='black', rasterized=RASTERIZED)
    # Index each radius' per-region values by region name. Matching by name (not by
    # list position) keeps every line on its own region even when a radius lacks a
    # region; the missing point becomes a gap instead of a value shifted onto another line.
    per_radius = [dict(zip(regs, mus)) for regs, mus in zip(summary['regions'], summary['individual_mus'])]
    for reg in summary['regions'].iloc[0]:
        ax.plot(
            x,
            [vals.get(reg, np.nan) for vals in per_radius],
            marker='o',
            linestyle='--',
            linewidth=1,
            markersize=5,
            color=adata.uns['brain_region_palette'].get(reg, 'gray'),
            rasterized=RASTERIZED,
        )


fig, ax = plt.subplots(figsize=(8, 5), dpi=100)

# One bar slot per radius, with an empty slot separating consecutive neuron groups.
xpos = np.arange(len(astro_msn) + len(astro_gp) + len(astro_it) + 2)
xmsn = xpos[:len(astro_msn)]
xgp = xpos[len(astro_msn)+1:len(astro_msn)+len(astro_gp)+1]
xit = xpos[len(astro_msn)+len(astro_gp)+2:len(astro_msn)+len(astro_gp)+len(astro_it)+2]
xlabels = astro_msn['radius_um'].tolist() + astro_gp['radius_um'].tolist() + astro_it['radius_um'].tolist()

_draw_group(ax, xmsn, astro_msn, 'purple')
_draw_group(ax, xgp, astro_gp, 'green')
_draw_group(ax, xit, astro_it, 'blue')
ax.axhline(0, color='black', linestyle='--', linewidth=1, rasterized=RASTERIZED)

# Proxy handles for the brain-region legend, which is saved as a separate figure
# below; they are detached from `ax`, so nothing is added to the main plot.
legend_handles = [
    plt.Line2D([], [], color=adata.uns['brain_region_palette'].get(reg, 'gray'),
               marker='o', linestyle='--', label=br_to_brc_map.get(reg, reg))
    for reg in astro_it['regions'].iloc[0]
]

# Primary x ticks: the radius of each bar.
ax.set_xticks(np.concatenate((xmsn, xgp, xit)))
ax.set_xticklabels(xlabels, fontsize=14, rasterized=RASTERIZED)
# Neuron-group names centred under each group; the leading newline drops them
# below the radius labels.
sec = ax.secondary_xaxis(location=0)
sec.set_xticks([x[len(x) // 2] for x in (xmsn, xgp, xit)])
sec.set_xticklabels(['\nMSN', '\nGP', '\nINT'], fontsize=14, rasterized=RASTERIZED)

# Long tick marks at each group's outer edges act as visual group separators.
sec2 = ax.secondary_xaxis(location=0)
sec2.set_xticks([xmsn[0]-0.5, xmsn[-1]+0.5, xgp[0]-0.5, xgp[-1]+0.5, xit[0]-0.5, xit[-1]+0.5], labels=[])
sec2.tick_params('x', length=40, width=1.5)

ax.set_ylim((-4, 4))
ax.set_yticks(np.arange(-4, 5, 2))
ax.set_yticklabels(np.arange(-4, 5, 2), fontsize=14, rasterized=RASTERIZED)
ax.set_ylabel('Combined Pooled Z-score', fontsize=14, rasterized=RASTERIZED)

ax.set_title('Astrocyte - Neuron Pairs, By Radius', y=1.05, fontsize=24, rasterized=RASTERIZED)

plt.tight_layout()
plt.savefig(image_path / 'astrocyte_combined_contact_enrichment_allradii.png', bbox_inches='tight', dpi=300)
plt.savefig(image_path / 'astrocyte_combined_contact_enrichment_allradii.pdf', bbox_inches='tight', dpi=300)
plt.show()
plt.close()

# Brain-region legend saved as its own small figure.
fig, ax = plt.subplots(figsize=(2, 1), dpi=300)
fig.legend(
    handles=legend_handles,
    bbox_to_anchor=(0.5, 0.5),
    loc='center',
    title="Brain Region",
    fontsize=12,
    title_fontsize=14,
    frameon=True
)
ax.set_axis_off()
plt.tight_layout()
plt.savefig(image_path / 'astrocyte_combined_contact_enrichment_allradii_legend.png', bbox_inches='tight', dpi=300)
plt.savefig(image_path / 'astrocyte_combined_contact_enrichment_allradii_legend.pdf', bbox_inches='tight', dpi=300)
plt.close()


### Microglia

In [None]:
# Average microglia-neuron contact enrichment for three neuron groups at each
# contact radius; one summary row per group per radius.
MSN_types = ["STR D1 MSN", "STR D2 MSN", "STR Hybrid MSN"]
GP_types = ["CN Cholinergic GABA", "CN LHX8 GABA", "CN MEIS2 GABA", "CN ONECUT1 GABA", "CN GABA-Glut"]
IT_types = ["CN ST18 GABA", "STR RSPO2 GABA", "CN LAMP5-CXCL14 GABA", "CN LAMP5-LHX6 GABA", "CN VIP GABA"]
rs = [15, 30, 50]

ct_msn, ct_gp, ct_it = [], [], []
for r in rs:
    DIR = Path(f'/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/meta_contacts/{r}um_2')
    # NOTE(review): `df` is not used in this cell; kept only in case later cells rely on it.
    df = pd.read_csv(DIR / "contacts_meta_input.csv")
    # Read the pooled per-region table once per radius and reuse it for all three groups.
    meta = pd.read_csv(DIR / "meta_region_pooled.csv")
    meta['ct1'] = meta['pair'].astype(str).str.split('|').str[0]
    meta['ct2'] = meta['pair'].astype(str).str.split('|').str[1]

    # MSNs: averaged over every region that has microglia-MSN pairs.
    combined_MSN = combine_results(meta, MSN_types, "All_MSN", "Microglia")
    rc_MSN = combine_regions(combined_MSN, col_pair='category')
    rc_MSN['radius_um'] = r
    ct_msn.append(rc_MSN)

    # GP neuron types: GP region only.
    combined_GP = combine_results(meta, GP_types, "All_GP", "Microglia")
    rc_GP = combine_regions(combined_GP, col_pair='category', regions=['GP'])
    rc_GP['radius_um'] = r
    ct_gp.append(rc_GP)

    # Interneuron types: restricted to the regions listed here.
    combined_IT = combine_results(meta, IT_types, "All_IT", "Microglia")
    rc_IT = combine_regions(combined_IT, col_pair='category', regions=['CAH', 'CAB', 'PU', 'NAC', 'CAT', "GP", 'MGM1'])
    rc_IT['radius_um'] = r
    ct_it.append(rc_IT)

ct_msn = pd.concat(ct_msn)
ct_msn['id'] = ct_msn['radius_um'].astype(str) + "um, MSN"
ct_msn['rid'] = ct_msn['radius_um'].astype(str) + "um"

ct_gp = pd.concat(ct_gp)
ct_gp['id'] = ct_gp['radius_um'].astype(str) + "um, GP"
ct_gp['rid'] = ct_gp['radius_um'].astype(str) + "um"

ct_it = pd.concat(ct_it)
ct_it['id'] = ct_it['radius_um'].astype(str) + "um, INT"
ct_it['rid'] = ct_it['radius_um'].astype(str) + "um"

In [None]:
def _draw_group(ax, x, summary, color):
    """Draw one neuron group on ``ax``.

    Bars show the region-averaged z-score (with SE error bars) at each radius, one
    row of ``summary`` per bar. A dashed line per brain region traces that region's
    own z-score across the same radii, coloured by the brain-region palette.
    """
    ax.bar(x=x, height=summary['mu'], yerr=summary['se'], capsize=5, color=color, alpha=0.7,
           width=0.9, edgecolor='black', rasterized=RASTERIZED)
    # Index each radius' per-region values by region name. Matching by name (not by
    # list position) keeps every line on its own region even when a radius lacks a
    # region; the missing point becomes a gap instead of a value shifted onto another line.
    per_radius = [dict(zip(regs, mus)) for regs, mus in zip(summary['regions'], summary['individual_mus'])]
    for reg in summary['regions'].iloc[0]:
        ax.plot(
            x,
            [vals.get(reg, np.nan) for vals in per_radius],
            marker='o',
            linestyle='--',
            linewidth=1,
            markersize=5,
            color=adata.uns['brain_region_palette'].get(reg, 'gray'),
            rasterized=RASTERIZED,
        )


fig, ax = plt.subplots(figsize=(8, 5), dpi=100)

# One bar slot per radius, with an empty slot separating consecutive neuron groups.
xpos = np.arange(len(ct_msn) + len(ct_gp) + len(ct_it) + 2)
xmsn = xpos[:len(ct_msn)]
xgp = xpos[len(ct_msn)+1:len(ct_msn)+len(ct_gp)+1]
xit = xpos[len(ct_msn)+len(ct_gp)+2:len(ct_msn)+len(ct_gp)+len(ct_it)+2]
xlabels = ct_msn['radius_um'].tolist() + ct_gp['radius_um'].tolist() + ct_it['radius_um'].tolist()

_draw_group(ax, xmsn, ct_msn, 'purple')
_draw_group(ax, xgp, ct_gp, 'green')
_draw_group(ax, xit, ct_it, 'blue')
ax.axhline(0, color='black', linestyle='--', linewidth=1, rasterized=RASTERIZED)

# No brain-region legend is drawn on this figure (the astrocyte cell saves one
# as a separate file).

# Primary x ticks: the radius of each bar.
ax.set_xticks(np.concatenate((xmsn, xgp, xit)))
ax.set_xticklabels(xlabels, fontsize=14, rasterized=RASTERIZED)
# Neuron-group names centred under each group; the leading newline drops them
# below the radius labels.
sec = ax.secondary_xaxis(location=0)
sec.set_xticks([x[len(x) // 2] for x in (xmsn, xgp, xit)])
sec.set_xticklabels(['\nMSN', '\nGP', '\nINT'], fontsize=14, rasterized=RASTERIZED)

# Long tick marks at each group's outer edges act as visual group separators.
sec2 = ax.secondary_xaxis(location=0)
sec2.set_xticks([xmsn[0]-0.5, xmsn[-1]+0.5, xgp[0]-0.5, xgp[-1]+0.5, xit[0]-0.5, xit[-1]+0.5], labels=[])
sec2.tick_params('x', length=40, width=1.5)

ax.set_ylim((-2, 3))
ax.set_yticks(np.arange(-2, 4, 1))
ax.set_yticklabels(np.arange(-2, 4, 1), fontsize=14, rasterized=RASTERIZED)
ax.set_ylabel('Combined Pooled Z-score', fontsize=14, rasterized=RASTERIZED)

ax.set_title('Microglia - Neuron Pairs, By Radius', y=1.05, fontsize=24, rasterized=RASTERIZED)

plt.tight_layout()
plt.savefig(image_path / 'microglia_combined_contact_enrichment_allradii.png', bbox_inches='tight', dpi=300)
plt.savefig(image_path / 'microglia_combined_contact_enrichment_allradii.pdf', bbox_inches='tight', dpi=300)
plt.show()
plt.close()

### Oligo

In [None]:
# Average oligodendrocyte-neuron contact enrichment for three neuron groups at
# each contact radius; one summary row per group per radius.
MSN_types = ["STR D1 MSN", "STR D2 MSN", "STR Hybrid MSN"]
GP_types = ["CN Cholinergic GABA", "CN LHX8 GABA", "CN MEIS2 GABA", "CN ONECUT1 GABA", "CN GABA-Glut"]
IT_types = ["CN ST18 GABA", "STR RSPO2 GABA", "CN LAMP5-CXCL14 GABA", "CN LAMP5-LHX6 GABA", "CN VIP GABA"]
rs = [15, 30, 50]

ct_msn, ct_gp, ct_it = [], [], []
for r in rs:
    DIR = Path(f'/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/meta_contacts/{r}um_2')
    # NOTE(review): `df` is not used in this cell; kept only in case later cells rely on it.
    df = pd.read_csv(DIR / "contacts_meta_input.csv")
    # Read the pooled per-region table once per radius and reuse it for all three groups.
    meta = pd.read_csv(DIR / "meta_region_pooled.csv")
    meta['ct1'] = meta['pair'].astype(str).str.split('|').str[0]
    meta['ct2'] = meta['pair'].astype(str).str.split('|').str[1]

    # MSNs: averaged over every region that has oligodendrocyte-MSN pairs.
    combined_MSN = combine_results(meta, MSN_types, "All_MSN", "Oligodendrocyte")
    rc_MSN = combine_regions(combined_MSN, col_pair='category')
    rc_MSN['radius_um'] = r
    ct_msn.append(rc_MSN)

    # GP neuron types: GP region only.
    combined_GP = combine_results(meta, GP_types, "All_GP", "Oligodendrocyte")
    rc_GP = combine_regions(combined_GP, col_pair='category', regions=['GP'])
    rc_GP['radius_um'] = r
    ct_gp.append(rc_GP)

    # Interneuron types: restricted to the regions listed here.
    combined_IT = combine_results(meta, IT_types, "All_IT", "Oligodendrocyte")
    rc_IT = combine_regions(combined_IT, col_pair='category', regions=['CAH', 'CAB', 'PU', 'NAC', 'CAT', 'GP', 'MGM1'])
    rc_IT['radius_um'] = r
    ct_it.append(rc_IT)

ct_msn = pd.concat(ct_msn)
ct_msn['id'] = ct_msn['radius_um'].astype(str) + "um, MSN"
ct_msn['rid'] = ct_msn['radius_um'].astype(str) + "um"

ct_gp = pd.concat(ct_gp)
ct_gp['id'] = ct_gp['radius_um'].astype(str) + "um, GP"
ct_gp['rid'] = ct_gp['radius_um'].astype(str) + "um"

ct_it = pd.concat(ct_it)
# "INT" matches the interneuron label used by the astrocyte/microglia cells and the plot.
ct_it['id'] = ct_it['radius_um'].astype(str) + "um, INT"
ct_it['rid'] = ct_it['radius_um'].astype(str) + "um"

In [None]:
def _draw_group(ax, x, summary, color):
    """Draw one neuron group on ``ax``.

    Bars show the region-averaged z-score (with SE error bars) at each radius, one
    row of ``summary`` per bar. A dashed line per brain region traces that region's
    own z-score across the same radii, coloured by the brain-region palette.
    """
    ax.bar(x=x, height=summary['mu'], yerr=summary['se'], capsize=5, color=color, alpha=0.7,
           width=0.9, edgecolor='black', rasterized=RASTERIZED)
    # Index each radius' per-region values by region name. Matching by name (not by
    # list position) keeps every line on its own region even when a radius lacks a
    # region; the missing point becomes a gap instead of a value shifted onto another line.
    per_radius = [dict(zip(regs, mus)) for regs, mus in zip(summary['regions'], summary['individual_mus'])]
    for reg in summary['regions'].iloc[0]:
        ax.plot(
            x,
            [vals.get(reg, np.nan) for vals in per_radius],
            marker='o',
            linestyle='--',
            linewidth=1,
            markersize=5,
            color=adata.uns['brain_region_palette'].get(reg, 'gray'),
            rasterized=RASTERIZED,
        )


fig, ax = plt.subplots(figsize=(8, 5), dpi=100)

# One bar slot per radius, with an empty slot separating consecutive neuron groups.
xpos = np.arange(len(ct_msn) + len(ct_gp) + len(ct_it) + 2)
xmsn = xpos[:len(ct_msn)]
xgp = xpos[len(ct_msn)+1:len(ct_msn)+len(ct_gp)+1]
xit = xpos[len(ct_msn)+len(ct_gp)+2:len(ct_msn)+len(ct_gp)+len(ct_it)+2]
xlabels = ct_msn['radius_um'].tolist() + ct_gp['radius_um'].tolist() + ct_it['radius_um'].tolist()

_draw_group(ax, xmsn, ct_msn, 'purple')
_draw_group(ax, xgp, ct_gp, 'green')
_draw_group(ax, xit, ct_it, 'blue')
ax.axhline(0, color='black', linestyle='--', linewidth=1, rasterized=RASTERIZED)

# No brain-region legend is drawn on this figure (the astrocyte cell saves one
# as a separate file).

# Primary x ticks: the radius of each bar.
ax.set_xticks(np.concatenate((xmsn, xgp, xit)))
ax.set_xticklabels(xlabels, fontsize=14, rasterized=RASTERIZED)
# Neuron-group names centred under each group; the leading newline drops them
# below the radius labels.
sec = ax.secondary_xaxis(location=0)
sec.set_xticks([x[len(x) // 2] for x in (xmsn, xgp, xit)])
sec.set_xticklabels(['\nMSN', '\nGP', '\nINT'], fontsize=14, rasterized=RASTERIZED)

# Long tick marks at each group's outer edges act as visual group separators.
sec2 = ax.secondary_xaxis(location=0)
sec2.set_xticks([xmsn[0]-0.5, xmsn[-1]+0.5, xgp[0]-0.5, xgp[-1]+0.5, xit[0]-0.5, xit[-1]+0.5], labels=[])
sec2.tick_params('x', length=40, width=1.5)

ax.set_ylim((-10, 5))
ax.set_yticks(np.arange(-10, 6, 2))
ax.set_yticklabels(np.arange(-10, 6, 2), fontsize=14, rasterized=RASTERIZED)
ax.set_ylabel('Combined Pooled Z-score', fontsize=14, rasterized=RASTERIZED)

ax.set_title('Oligodendrocyte - Neuron Pairs, By Radius', y=1.05, fontsize=24, rasterized=RASTERIZED)

plt.tight_layout()
plt.savefig(image_path / 'oligo_combined_contact_enrichment_allradii.png', bbox_inches='tight', dpi=300)
plt.savefig(image_path / 'oligo_combined_contact_enrichment_allradii.pdf', bbox_inches='tight', dpi=300)
plt.show()
plt.close()

# Split by Brain Region + Interaction Type (Fig 6 S. 10)

## Helper Functions

In [None]:
# Chat GPT
def combine_results(meta_df, cell_types_to_combine, new_category_name, comparison_ct):
    """
    Average the meta-analysis results of several cell-type pairs, per brain region.

    For every brain region, take all pairs that join one of ``cell_types_to_combine``
    with ``comparison_ct`` (in either ct1/ct2 order) and combine them into one row.
    This computes the AVERAGE effect across different cell type pairs,
    not a meta-analysis of the same effect.

    Parameters:
    -----------
    meta_df : pd.DataFrame
        DataFrame with columns: pair, brain_region, ct1, ct2, mu, se.
    cell_types_to_combine : list
        Cell types whose pairs with ``comparison_ct`` are averaged together.
    new_category_name : str
        Label stored in the ``category`` column of the result.
    comparison_ct : str
        Partner cell type that every selected pair must contain.

    Returns:
    --------
    pd.DataFrame
        One row per brain region with at least one matching pair; columns
        brain_region, category, mu, se, n_combined, original_pairs,
        individual_mus, individual_ses. The columns are present even when
        nothing matches, so callers can still group or select on them.
    """
    combined_results = []

    for region, region_df in meta_df.groupby('brain_region'):
        # Pairs of a target cell type with the comparison type, in either orientation.
        subset = region_df[
            (region_df['ct1'].isin(cell_types_to_combine) & (region_df['ct2'] == comparison_ct)) |
            (region_df['ct2'].isin(cell_types_to_combine) & (region_df['ct1'] == comparison_ct))
        ]

        if subset.empty:
            continue

        mus = subset['mu'].values
        ses = subset['se'].values
        n = len(subset)

        # Unweighted mean of the pair effects. Treating the estimates as independent,
        # Var(mean) = sum(SE_i^2) / n^2, hence SE = sqrt(sum(SE_i^2)) / n.
        # (A precision-weighted alternative would use w_i = 1 / SE_i^2:
        #  mu = sum(w_i * mu_i) / sum(w_i), SE = 1 / sqrt(sum(w_i)).)
        combined_results.append({
            'brain_region': region,
            'category': new_category_name,
            'mu': np.mean(mus),
            'se': np.sqrt(np.sum(ses ** 2)) / n,
            'n_combined': n,
            'original_pairs': subset['pair'].tolist(),
            'individual_mus': mus.tolist(),
            'individual_ses': ses.tolist(),
        })

    # Fixed column list so an empty result still has the expected schema
    # (pd.DataFrame([]) has no columns and breaks groupby('category') downstream).
    return pd.DataFrame(combined_results, columns=[
        'brain_region', 'category', 'mu', 'se', 'n_combined',
        'original_pairs', 'individual_mus', 'individual_ses',
    ])

# Example usage:
# MSN_types = ["STR D1 MSN", "STR D2 MSN", "STR Hybrid MSN"]
# combined_MSN = combine_results(meta, MSN_types, "All_MSN", "Astrocyte")

def plot_pair(
    meta,
    pair,
    pair_col = 'pair',
    order=None,
    ax=None,
    color='blue',
    label=None,
    rasterized=False,
    opacity=1.0,
    region_col = "brain_region",
    xlabel="Brain Region",
    ylabel="Pooled Z-score",
    title=None,
    axis_fontsize=14,
    ticklabel_fontsize=12, 
    title_fontsize=18,
    show_ticks=True
):
    """Plot pooled ``mu`` +/- ``se`` of one pair (or category) across regions.

    Parameters
    ----------
    meta : pd.DataFrame
        Must contain ``pair_col``, ``region_col``, ``mu`` and ``se`` columns.
    pair : str
        Value of ``pair_col`` selecting the rows to plot.
    pair_col : str
        Column holding the pair / category identifier.
    order : list, optional
        x-axis order of regions. Regions missing from the selected rows are
        skipped, and selected rows whose region is not listed are dropped.
        When None, rows are sorted by ``mu``.
    ax : matplotlib Axes, optional
        Axes to draw on; a new 6x4 figure is created when None.
    color, label, rasterized, opacity
        Styling forwarded to ``errorbar`` / ``plot`` (``label`` only to ``errorbar``).
    region_col : str
        Column used as the (categorical) x-axis.
    xlabel, ylabel : str or None
        Axis labels (None clears the label).
    title : str, optional
        Axes title; left untouched when None.
    axis_fontsize, ticklabel_fontsize, title_fontsize : float
        Font sizes for axis labels, x tick labels and the title.
    show_ticks : bool
        When True, label x ticks with this call's regions; otherwise remove x ticks.

    Returns
    -------
    matplotlib Axes
    """
    meta_pair = meta[meta[pair_col] == pair]
    if order is not None:
        present = set(meta_pair[region_col])
        order = [region for region in order if region in present]
        meta_pair = meta_pair.set_index(region_col).loc[order].reset_index()
    else:
        meta_pair = meta_pair.sort_values(by='mu')

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 4))
    regions = meta_pair[region_col]
    ax.errorbar(x=regions, y=meta_pair['mu'], yerr=meta_pair['se'], fmt='o', color=color, capsize=5, label=label, alpha=opacity, rasterized=rasterized)
    ax.plot(regions, meta_pair['mu'], color=color, linestyle='--', alpha=opacity, rasterized=rasterized)
    ax.axhline(0, color='black', linestyle='--', linewidth=1, rasterized=rasterized)
    ax.set_xlabel(xlabel, fontsize=axis_fontsize)
    ax.set_ylabel(ylabel, fontsize=axis_fontsize)

    if title is not None:
        ax.set_title(title, fontsize=title_fontsize)

    if show_ticks:
        # Use the region names as tick locations. set_xticks converts them
        # through the axis' categorical units. Positions 0..n-1 would be wrong
        # when several pairs share this Axes, because category positions follow
        # first appearance across *all* previous calls, not this call's order.
        tick_regions = list(regions)
        ax.set_xticks(tick_regions, tick_regions, fontsize=ticklabel_fontsize)
    else:
        ax.set_xticks([])

    return ax

## Run

In [None]:
# Styling shared by the summary contact plots in this section.
title_fontsize, axis_fontsize, legend_fontsize, ticklabel_fontsize = 18, 14, 12, 14
RASTERIZED = False

# Distance values (in um) with contact results; each selects a `meta_contacts/{r}um_2` directory below.
rs = [15, 30, 50]

In [None]:
# Neuronal subclass groups whose contacts with one non-neuronal type are plotted and combined below.
MSN_types = [
    "STR D1 MSN",
    "STR D2 MSN",
    "STR Hybrid MSN",
]
GP_types = [
    "CN LHX8 GABA",
    "CN MEIS2 GABA",
    "CN ONECUT1 GABA",
    "CN GABA-Glut",
]
INT_types = [
    "CN ST18 GABA",
    "STR RSPO2 GABA",
    "CN LAMP5-CXCL14 GABA",
    "CN LAMP5-LHX6 GABA",
    "CN VIP GABA",
]

### Astrocyte interactions

In [None]:
# Astrocyte contacts with each MSN subtype, plus their combined (averaged) estimate, per distance.
types = MSN_types
combined_type_name = "ALL"
nn_comparison = "Astrocyte"
ORDER = ['CaH', 'CaB', 'CaT', 'Pu', 'NAC']
plot_label = "Combined MSNs"
neu_plot_label = "MSNs"

for r in rs:
    DIR = Path(f'/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/meta_contacts/{r}um_2')
    # NOTE(review): `df` is not used in this cell; kept only in case later cells read the global.
    df = pd.read_csv(DIR / "contacts_meta_input.csv")
    meta = pd.read_csv(DIR / "meta_region_pooled.csv")
    meta['brain_region_corr'] = meta['brain_region'].map(br_to_brc_map)

    # Pairs are stored as "<cell type 1>|<cell type 2>".
    pair_parts = meta['pair'].astype(str).str.split('|')
    meta['ct1'] = pair_parts.str[0]
    meta['ct2'] = pair_parts.str[1]

    combined = combine_results(meta, types, combined_type_name, nn_comparison)
    if combined.empty:
        # combine_results returns a column-less frame when nothing matched.
        print(f"No data for {nn_comparison} and {neu_plot_label} at {r}um, skipping.")
        continue
    combined['brain_region_corr'] = combined['brain_region'].map(br_to_brc_map)

    fig, axs = plt.subplots(1, 1, figsize=(8, 5), dpi=100)
    for _type in types:
        # The pair may be stored in either orientation.
        pair = nn_comparison + "|" + _type
        sub = meta[(meta['ct1'] == nn_comparison) & (meta['ct2'] == _type)]
        if sub.empty:
            pair = _type + "|" + nn_comparison
            sub = meta[(meta['ct2'] == nn_comparison) & (meta['ct1'] == _type)]
        if sub.empty:
            print(f"No data for {nn_comparison} and {_type}, skipping.")
            continue

        # Individual subtypes: semi-transparent, without x tick labels.
        plot_pair(
            sub, pair, order=ORDER, ax=axs,
            color=adata.uns['Subclass_palette'].get(_type, 'blue'),
            label=_type, opacity=0.5, rasterized=RASTERIZED,
            region_col="brain_region_corr",
            xlabel=None, ylabel=None, show_ticks=False,)

    # Combined estimate in red; this call also sets the x tick labels.
    plot_pair(
        combined, combined_type_name, pair_col="category", order=ORDER, ax=axs,
        color='red', label=plot_label, rasterized=RASTERIZED, region_col="brain_region_corr",
        xlabel=None, ylabel=None, ticklabel_fontsize=ticklabel_fontsize, show_ticks=True
    )

    axs.set_ylabel("Pooled Z-score", fontsize=axis_fontsize)
    # Keep every other y tick, labelled with its exact value. Casting to int
    # would truncate fractional ticks (0.5 -> "0", -2.5 -> "-2").
    # Adding 0.0 turns -0.0 into 0.0.
    yticks = axs.get_yticks()[::2]
    ytick_labels = [f"{v:g}" for v in np.round(yticks, 6) + 0.0]
    axs.set_yticks(yticks, ytick_labels, fontsize=ticklabel_fontsize)
    axs.set_title(f'Pooled Contact Z-scores for {neu_plot_label} - {nn_comparison} for {r}um', fontsize=title_fontsize)
    # axs.legend(loc='upper left', bbox_to_anchor=(1, 1), fontsize=legend_fontsize)

    plt.tight_layout()
    for ext in ("png", "pdf"):
        plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_{r}um.{ext}', bbox_inches='tight', dpi=300)
    plt.show()
    plt.close()


# Stand-alone legend: one entry per subtype plus the combined (red) trace.
fig, ax = plt.subplots(figsize=(2, 1), dpi=300)
labels = list(types) + [plot_label]
legend_colors = [adata.uns['Subclass_palette'].get(_type, 'blue') for _type in types] + ['red']
handles = [
    ax.plot([], [], color=c, marker='o', linestyle='--', label=lbl, rasterized=RASTERIZED)[0]
    for c, lbl in zip(legend_colors, labels)
]
fig.legend(
    handles=handles,
    labels=labels,
    bbox_to_anchor=(0.5, 0.5),
    loc='center',
    title=f"{neu_plot_label} Subtypes",
    fontsize=12,
    title_fontsize=14,
    frameon=True
)
ax.set_axis_off()
plt.tight_layout()
for ext in ("png", "pdf"):
    plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_legend.{ext}', bbox_inches='tight', dpi=300)
# plt.show()
plt.close()


In [None]:
# Astrocyte contacts with each GP subtype, plus their combined (averaged) estimate, per distance.
# Uses GP_types: the labels, region order and output file names below all refer to GP.
types = GP_types
combined_type_name = "ALL"
nn_comparison = "Astrocyte"
ORDER = ['GP']
plot_label = "Combined GP"
neu_plot_label = "GP"

for r in rs:
    DIR = Path(f'/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/meta_contacts/{r}um_2')
    # NOTE(review): `df` is not used in this cell; kept only in case later cells read the global.
    df = pd.read_csv(DIR / "contacts_meta_input.csv")
    meta = pd.read_csv(DIR / "meta_region_pooled.csv")
    meta['brain_region_corr'] = meta['brain_region'].map(br_to_brc_map)

    # Pairs are stored as "<cell type 1>|<cell type 2>".
    pair_parts = meta['pair'].astype(str).str.split('|')
    meta['ct1'] = pair_parts.str[0]
    meta['ct2'] = pair_parts.str[1]

    combined = combine_results(meta, types, combined_type_name, nn_comparison)
    if combined.empty:
        # combine_results returns a column-less frame when nothing matched.
        print(f"No data for {nn_comparison} and {neu_plot_label} at {r}um, skipping.")
        continue
    combined['brain_region_corr'] = combined['brain_region'].map(br_to_brc_map)

    fig, axs = plt.subplots(1, 1, figsize=(8, 5), dpi=100)
    for _type in types:
        # The pair may be stored in either orientation.
        pair = nn_comparison + "|" + _type
        sub = meta[(meta['ct1'] == nn_comparison) & (meta['ct2'] == _type)]
        if sub.empty:
            pair = _type + "|" + nn_comparison
            sub = meta[(meta['ct2'] == nn_comparison) & (meta['ct1'] == _type)]
        if sub.empty:
            print(f"No data for {nn_comparison} and {_type}, skipping.")
            continue

        # Individual subtypes: semi-transparent, without x tick labels.
        plot_pair(
            sub, pair, order=ORDER, ax=axs,
            color=adata.uns['Subclass_palette'].get(_type, 'blue'),
            label=_type, opacity=0.5, rasterized=RASTERIZED,
            region_col="brain_region_corr",
            xlabel=None, ylabel=None, show_ticks=False,)

    # Combined estimate in red; this call also sets the x tick labels.
    plot_pair(
        combined, combined_type_name, pair_col="category", order=ORDER, ax=axs,
        color='red', label=plot_label, rasterized=RASTERIZED, region_col="brain_region_corr",
        xlabel=None, ylabel=None, ticklabel_fontsize=ticklabel_fontsize, show_ticks=True
    )

    axs.set_ylabel("Pooled Z-score", fontsize=axis_fontsize)
    # Keep every other y tick, labelled with its exact value. Casting to int
    # would truncate fractional ticks (0.5 -> "0", -2.5 -> "-2").
    # Adding 0.0 turns -0.0 into 0.0.
    yticks = axs.get_yticks()[::2]
    ytick_labels = [f"{v:g}" for v in np.round(yticks, 6) + 0.0]
    axs.set_yticks(yticks, ytick_labels, fontsize=ticklabel_fontsize)
    axs.set_title(f'Pooled Contact Z-scores for {neu_plot_label} - {nn_comparison} for {r}um', fontsize=title_fontsize)
    # axs.legend(loc='upper left', bbox_to_anchor=(1, 1), fontsize=legend_fontsize)

    plt.tight_layout()
    for ext in ("png", "pdf"):
        plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_{r}um.{ext}', bbox_inches='tight', dpi=300)
    plt.show()
    plt.close()


# Stand-alone legend: one entry per subtype plus the combined (red) trace.
fig, ax = plt.subplots(figsize=(2, 1), dpi=300)
labels = list(types) + [plot_label]
legend_colors = [adata.uns['Subclass_palette'].get(_type, 'blue') for _type in types] + ['red']
handles = [
    ax.plot([], [], color=c, marker='o', linestyle='--', label=lbl, rasterized=RASTERIZED)[0]
    for c, lbl in zip(legend_colors, labels)
]
fig.legend(
    handles=handles,
    labels=labels,
    bbox_to_anchor=(0.5, 0.5),
    loc='center',
    title=f"{neu_plot_label} Subtypes",
    fontsize=12,
    title_fontsize=14,
    frameon=True
)
ax.set_axis_off()
plt.tight_layout()
for ext in ("png", "pdf"):
    plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_legend.{ext}', bbox_inches='tight', dpi=300)
plt.show()
plt.close()


In [None]:
# Astrocyte contacts with each interneuron subtype, plus their combined (averaged) estimate, per distance.
types = INT_types
combined_type_name = "ALL"
nn_comparison = "Astrocyte"
ORDER = ['CaH', 'CaB', 'CaT', 'Pu', 'GP', 'NAC', 'MGM1']
plot_label = "Combined INTs"
neu_plot_label = "INTs"

for r in rs:
    DIR = Path(f'/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/meta_contacts/{r}um_2')
    # NOTE(review): `df` is not used in this cell; kept only in case later cells read the global.
    df = pd.read_csv(DIR / "contacts_meta_input.csv")
    meta = pd.read_csv(DIR / "meta_region_pooled.csv")
    meta['brain_region_corr'] = meta['brain_region'].map(br_to_brc_map)

    # Pairs are stored as "<cell type 1>|<cell type 2>".
    pair_parts = meta['pair'].astype(str).str.split('|')
    meta['ct1'] = pair_parts.str[0]
    meta['ct2'] = pair_parts.str[1]

    combined = combine_results(meta, types, combined_type_name, nn_comparison)
    if combined.empty:
        # combine_results returns a column-less frame when nothing matched.
        print(f"No data for {nn_comparison} and {neu_plot_label} at {r}um, skipping.")
        continue
    combined['brain_region_corr'] = combined['brain_region'].map(br_to_brc_map)

    fig, axs = plt.subplots(1, 1, figsize=(8, 5), dpi=100)
    for _type in types:
        # The pair may be stored in either orientation.
        pair = nn_comparison + "|" + _type
        sub = meta[(meta['ct1'] == nn_comparison) & (meta['ct2'] == _type)]
        if sub.empty:
            pair = _type + "|" + nn_comparison
            sub = meta[(meta['ct2'] == nn_comparison) & (meta['ct1'] == _type)]
        if sub.empty:
            print(f"No data for {nn_comparison} and {_type}, skipping.")
            continue

        # Individual subtypes: semi-transparent, without x tick labels.
        plot_pair(
            sub, pair, order=ORDER, ax=axs,
            color=adata.uns['Subclass_palette'].get(_type, 'blue'),
            label=_type, opacity=0.5, rasterized=RASTERIZED,
            region_col="brain_region_corr",
            xlabel=None, ylabel=None, show_ticks=False,)

    # Combined estimate in red; this call also sets the x tick labels.
    plot_pair(
        combined, combined_type_name, pair_col="category", order=ORDER, ax=axs,
        color='red', label=plot_label, rasterized=RASTERIZED, region_col="brain_region_corr",
        xlabel=None, ylabel=None, ticklabel_fontsize=ticklabel_fontsize, show_ticks=True
    )

    axs.set_ylabel("Pooled Z-score", fontsize=axis_fontsize)
    # Keep every other y tick, labelled with its exact value. Casting to int
    # would truncate fractional ticks (0.5 -> "0", -2.5 -> "-2").
    # Adding 0.0 turns -0.0 into 0.0.
    yticks = axs.get_yticks()[::2]
    ytick_labels = [f"{v:g}" for v in np.round(yticks, 6) + 0.0]
    axs.set_yticks(yticks, ytick_labels, fontsize=ticklabel_fontsize)
    axs.set_title(f'Pooled Contact Z-scores for {neu_plot_label} - {nn_comparison} for {r}um', fontsize=title_fontsize)
    # axs.legend(loc='upper left', bbox_to_anchor=(1, 1), fontsize=legend_fontsize)

    plt.tight_layout()
    for ext in ("png", "pdf"):
        plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_{r}um.{ext}', bbox_inches='tight', dpi=300)
    plt.show()
    plt.close()


# Stand-alone legend: one entry per subtype plus the combined (red) trace.
fig, ax = plt.subplots(figsize=(2, 1), dpi=300)
labels = list(types) + [plot_label]
legend_colors = [adata.uns['Subclass_palette'].get(_type, 'blue') for _type in types] + ['red']
handles = [
    ax.plot([], [], color=c, marker='o', linestyle='--', label=lbl, rasterized=RASTERIZED)[0]
    for c, lbl in zip(legend_colors, labels)
]
fig.legend(
    handles=handles,
    labels=labels,
    bbox_to_anchor=(0.5, 0.5),
    loc='center',
    title=f"{neu_plot_label} Subtypes",
    fontsize=12,
    title_fontsize=14,
    frameon=True
)
ax.set_axis_off()
plt.tight_layout()
for ext in ("png", "pdf"):
    plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_legend.{ext}', bbox_inches='tight', dpi=300)
plt.show()
plt.close()


### Microglia Interactions

In [None]:
# Microglia contacts with each MSN subtype, plus their combined (averaged) estimate, per distance.
types = MSN_types
combined_type_name = "ALL"
nn_comparison = "Microglia"
ORDER = ['CaH', 'CaB', 'CaT', 'Pu', 'NAC']
plot_label = "Combined MSNs"
neu_plot_label = "MSNs"

for r in rs:
    DIR = Path(f'/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/meta_contacts/{r}um_2')
    # NOTE(review): `df` is not used in this cell; kept only in case later cells read the global.
    df = pd.read_csv(DIR / "contacts_meta_input.csv")
    meta = pd.read_csv(DIR / "meta_region_pooled.csv")
    meta['brain_region_corr'] = meta['brain_region'].map(br_to_brc_map)

    # Pairs are stored as "<cell type 1>|<cell type 2>".
    pair_parts = meta['pair'].astype(str).str.split('|')
    meta['ct1'] = pair_parts.str[0]
    meta['ct2'] = pair_parts.str[1]

    combined = combine_results(meta, types, combined_type_name, nn_comparison)
    if combined.empty:
        # combine_results returns a column-less frame when nothing matched.
        print(f"No data for {nn_comparison} and {neu_plot_label} at {r}um, skipping.")
        continue
    combined['brain_region_corr'] = combined['brain_region'].map(br_to_brc_map)

    fig, axs = plt.subplots(1, 1, figsize=(8, 5), dpi=100)
    for _type in types:
        # The pair may be stored in either orientation.
        pair = nn_comparison + "|" + _type
        sub = meta[(meta['ct1'] == nn_comparison) & (meta['ct2'] == _type)]
        if sub.empty:
            pair = _type + "|" + nn_comparison
            sub = meta[(meta['ct2'] == nn_comparison) & (meta['ct1'] == _type)]
        if sub.empty:
            print(f"No data for {nn_comparison} and {_type}, skipping.")
            continue

        # Individual subtypes: semi-transparent, without x tick labels.
        plot_pair(
            sub, pair, order=ORDER, ax=axs,
            color=adata.uns['Subclass_palette'].get(_type, 'blue'),
            label=_type, opacity=0.5, rasterized=RASTERIZED,
            region_col="brain_region_corr",
            xlabel=None, ylabel=None, show_ticks=False,)

    # Combined estimate in red; this call also sets the x tick labels.
    plot_pair(
        combined, combined_type_name, pair_col="category", order=ORDER, ax=axs,
        color='red', label=plot_label, rasterized=RASTERIZED, region_col="brain_region_corr",
        xlabel=None, ylabel=None, ticklabel_fontsize=ticklabel_fontsize, show_ticks=True
    )

    axs.set_ylabel("Pooled Z-score", fontsize=axis_fontsize)
    # Keep every other y tick, labelled with its exact value. Casting to int
    # would truncate fractional ticks (0.5 -> "0", -2.5 -> "-2").
    # Adding 0.0 turns -0.0 into 0.0.
    yticks = axs.get_yticks()[::2]
    ytick_labels = [f"{v:g}" for v in np.round(yticks, 6) + 0.0]
    axs.set_yticks(yticks, ytick_labels, fontsize=ticklabel_fontsize)
    axs.set_title(f'Pooled Contact Z-scores for {neu_plot_label} - {nn_comparison} for {r}um', fontsize=title_fontsize)
    # axs.legend(loc='upper left', bbox_to_anchor=(1, 1), fontsize=legend_fontsize)

    plt.tight_layout()
    for ext in ("png", "pdf"):
        plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_{r}um.{ext}', bbox_inches='tight', dpi=300)
    plt.show()
    plt.close()


# Stand-alone legend: one entry per subtype plus the combined (red) trace.
fig, ax = plt.subplots(figsize=(2, 1), dpi=300)
labels = list(types) + [plot_label]
legend_colors = [adata.uns['Subclass_palette'].get(_type, 'blue') for _type in types] + ['red']
handles = [
    ax.plot([], [], color=c, marker='o', linestyle='--', label=lbl, rasterized=RASTERIZED)[0]
    for c, lbl in zip(legend_colors, labels)
]
fig.legend(
    handles=handles,
    labels=labels,
    bbox_to_anchor=(0.5, 0.5),
    loc='center',
    title=f"{neu_plot_label} Subtypes",
    fontsize=12,
    title_fontsize=14,
    frameon=True
)
ax.set_axis_off()
plt.tight_layout()
# plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_legend.png', bbox_inches='tight', dpi=300)
# plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_legend.pdf', bbox_inches='tight', dpi=300)
plt.show()
plt.close()


In [None]:
# Microglia contacts with each GP subtype, plus their combined (averaged) estimate, per distance.
# Uses GP_types: the labels, region order and output file names below all refer to GP.
types = GP_types
combined_type_name = "ALL"
nn_comparison = "Microglia"
ORDER = ['GP']
plot_label = "Combined GP"
neu_plot_label = "GP"

for r in rs:
    DIR = Path(f'/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/meta_contacts/{r}um_2')
    # NOTE(review): `df` is not used in this cell; kept only in case later cells read the global.
    df = pd.read_csv(DIR / "contacts_meta_input.csv")
    meta = pd.read_csv(DIR / "meta_region_pooled.csv")
    meta['brain_region_corr'] = meta['brain_region'].map(br_to_brc_map)

    # Pairs are stored as "<cell type 1>|<cell type 2>".
    pair_parts = meta['pair'].astype(str).str.split('|')
    meta['ct1'] = pair_parts.str[0]
    meta['ct2'] = pair_parts.str[1]

    combined = combine_results(meta, types, combined_type_name, nn_comparison)
    if combined.empty:
        # combine_results returns a column-less frame when nothing matched.
        print(f"No data for {nn_comparison} and {neu_plot_label} at {r}um, skipping.")
        continue
    combined['brain_region_corr'] = combined['brain_region'].map(br_to_brc_map)

    fig, axs = plt.subplots(1, 1, figsize=(8, 5), dpi=100)
    for _type in types:
        # The pair may be stored in either orientation.
        pair = nn_comparison + "|" + _type
        sub = meta[(meta['ct1'] == nn_comparison) & (meta['ct2'] == _type)]
        if sub.empty:
            pair = _type + "|" + nn_comparison
            sub = meta[(meta['ct2'] == nn_comparison) & (meta['ct1'] == _type)]
        if sub.empty:
            print(f"No data for {nn_comparison} and {_type}, skipping.")
            continue

        # Individual subtypes: semi-transparent, without x tick labels.
        plot_pair(
            sub, pair, order=ORDER, ax=axs,
            color=adata.uns['Subclass_palette'].get(_type, 'blue'),
            label=_type, opacity=0.5, rasterized=RASTERIZED,
            region_col="brain_region_corr",
            xlabel=None, ylabel=None, show_ticks=False,)

    # Combined estimate in red; this call also sets the x tick labels.
    plot_pair(
        combined, combined_type_name, pair_col="category", order=ORDER, ax=axs,
        color='red', label=plot_label, rasterized=RASTERIZED, region_col="brain_region_corr",
        xlabel=None, ylabel=None, ticklabel_fontsize=ticklabel_fontsize, show_ticks=True
    )

    axs.set_ylabel("Pooled Z-score", fontsize=axis_fontsize)
    # Keep every other y tick, labelled with its exact value. Casting to int
    # would truncate fractional ticks (0.5 -> "0", -2.5 -> "-2").
    # Adding 0.0 turns -0.0 into 0.0.
    yticks = axs.get_yticks()[::2]
    ytick_labels = [f"{v:g}" for v in np.round(yticks, 6) + 0.0]
    axs.set_yticks(yticks, ytick_labels, fontsize=ticklabel_fontsize)
    axs.set_title(f'Pooled Contact Z-scores for {neu_plot_label} - {nn_comparison} for {r}um', fontsize=title_fontsize)
    # axs.legend(loc='upper left', bbox_to_anchor=(1, 1), fontsize=legend_fontsize)

    plt.tight_layout()
    for ext in ("png", "pdf"):
        plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_{r}um.{ext}', bbox_inches='tight', dpi=300)
    plt.show()
    plt.close()


# Stand-alone legend: one entry per subtype plus the combined (red) trace.
fig, ax = plt.subplots(figsize=(2, 1), dpi=300)
labels = list(types) + [plot_label]
legend_colors = [adata.uns['Subclass_palette'].get(_type, 'blue') for _type in types] + ['red']
handles = [
    ax.plot([], [], color=c, marker='o', linestyle='--', label=lbl, rasterized=RASTERIZED)[0]
    for c, lbl in zip(legend_colors, labels)
]
fig.legend(
    handles=handles,
    labels=labels,
    bbox_to_anchor=(0.5, 0.5),
    loc='center',
    title=f"{neu_plot_label} Subtypes",
    fontsize=12,
    title_fontsize=14,
    frameon=True
)
ax.set_axis_off()
plt.tight_layout()
# plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_legend.png', bbox_inches='tight', dpi=300)
# plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_legend.pdf', bbox_inches='tight', dpi=300)
plt.show()
plt.close()


In [None]:
# Microglia contacts with each interneuron subtype, plus their combined (averaged) estimate, per distance.
types = INT_types
combined_type_name = "ALL"
nn_comparison = "Microglia"
ORDER = ['CaH', 'CaB', 'CaT', 'Pu', 'GP', 'NAC', 'MGM1']
plot_label = "Combined INTs"
neu_plot_label = "INTs"

for r in rs:
    DIR = Path(f'/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/meta_contacts/{r}um_2')
    # NOTE(review): `df` is not used in this cell; kept only in case later cells read the global.
    df = pd.read_csv(DIR / "contacts_meta_input.csv")
    meta = pd.read_csv(DIR / "meta_region_pooled.csv")
    meta['brain_region_corr'] = meta['brain_region'].map(br_to_brc_map)

    # Pairs are stored as "<cell type 1>|<cell type 2>".
    pair_parts = meta['pair'].astype(str).str.split('|')
    meta['ct1'] = pair_parts.str[0]
    meta['ct2'] = pair_parts.str[1]

    combined = combine_results(meta, types, combined_type_name, nn_comparison)
    if combined.empty:
        # combine_results returns a column-less frame when nothing matched.
        print(f"No data for {nn_comparison} and {neu_plot_label} at {r}um, skipping.")
        continue
    combined['brain_region_corr'] = combined['brain_region'].map(br_to_brc_map)

    fig, axs = plt.subplots(1, 1, figsize=(8, 5), dpi=100)
    for _type in types:
        # The pair may be stored in either orientation.
        pair = nn_comparison + "|" + _type
        sub = meta[(meta['ct1'] == nn_comparison) & (meta['ct2'] == _type)]
        if sub.empty:
            pair = _type + "|" + nn_comparison
            sub = meta[(meta['ct2'] == nn_comparison) & (meta['ct1'] == _type)]
        if sub.empty:
            print(f"No data for {nn_comparison} and {_type}, skipping.")
            continue

        # Individual subtypes: semi-transparent, without x tick labels.
        plot_pair(
            sub, pair, order=ORDER, ax=axs,
            color=adata.uns['Subclass_palette'].get(_type, 'blue'),
            label=_type, opacity=0.5, rasterized=RASTERIZED,
            region_col="brain_region_corr",
            xlabel=None, ylabel=None, show_ticks=False,)

    # Combined estimate in red; this call also sets the x tick labels.
    plot_pair(
        combined, combined_type_name, pair_col="category", order=ORDER, ax=axs,
        color='red', label=plot_label, rasterized=RASTERIZED, region_col="brain_region_corr",
        xlabel=None, ylabel=None, ticklabel_fontsize=ticklabel_fontsize, show_ticks=True
    )

    axs.set_ylabel("Pooled Z-score", fontsize=axis_fontsize)
    # Keep every other y tick, labelled with its exact value. Casting to int
    # would truncate fractional ticks (0.5 -> "0", -2.5 -> "-2").
    # Adding 0.0 turns -0.0 into 0.0.
    yticks = axs.get_yticks()[::2]
    ytick_labels = [f"{v:g}" for v in np.round(yticks, 6) + 0.0]
    axs.set_yticks(yticks, ytick_labels, fontsize=ticklabel_fontsize)
    axs.set_title(f'Pooled Contact Z-scores for {neu_plot_label} - {nn_comparison} for {r}um', fontsize=title_fontsize)
    # axs.legend(loc='upper left', bbox_to_anchor=(1, 1), fontsize=legend_fontsize)

    plt.tight_layout()
    for ext in ("png", "pdf"):
        plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_{r}um.{ext}', bbox_inches='tight', dpi=300)
    plt.show()
    plt.close()


# Stand-alone legend: one entry per subtype plus the combined (red) trace.
fig, ax = plt.subplots(figsize=(2, 1), dpi=300)
labels = list(types) + [plot_label]
legend_colors = [adata.uns['Subclass_palette'].get(_type, 'blue') for _type in types] + ['red']
handles = [
    ax.plot([], [], color=c, marker='o', linestyle='--', label=lbl, rasterized=RASTERIZED)[0]
    for c, lbl in zip(legend_colors, labels)
]
fig.legend(
    handles=handles,
    labels=labels,
    bbox_to_anchor=(0.5, 0.5),
    loc='center',
    title=f"{neu_plot_label} Subtypes",
    fontsize=12,
    title_fontsize=14,
    frameon=True
)
ax.set_axis_off()
plt.tight_layout()
# plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_legend.png', bbox_inches='tight', dpi=300)
# plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_legend.pdf', bbox_inches='tight', dpi=300)
plt.show()
plt.close()


### Oligodendrocyte Interactions

In [None]:
# Oligodendrocyte contacts with each MSN subtype, plus their combined (averaged) estimate, per distance.
types = MSN_types
combined_type_name = "ALL"
nn_comparison = "Oligodendrocyte"
ORDER = ['CaH', 'CaB', 'CaT', 'Pu', 'NAC']
plot_label = "Combined MSNs"
neu_plot_label = "MSNs"

for r in rs:
    DIR = Path(f'/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/meta_contacts/{r}um_2')
    # NOTE(review): `df` is not used in this cell; kept only in case later cells read the global.
    df = pd.read_csv(DIR / "contacts_meta_input.csv")
    meta = pd.read_csv(DIR / "meta_region_pooled.csv")
    meta['brain_region_corr'] = meta['brain_region'].map(br_to_brc_map)

    # Pairs are stored as "<cell type 1>|<cell type 2>".
    pair_parts = meta['pair'].astype(str).str.split('|')
    meta['ct1'] = pair_parts.str[0]
    meta['ct2'] = pair_parts.str[1]

    combined = combine_results(meta, types, combined_type_name, nn_comparison)
    if combined.empty:
        # combine_results returns a column-less frame when nothing matched.
        print(f"No data for {nn_comparison} and {neu_plot_label} at {r}um, skipping.")
        continue
    combined['brain_region_corr'] = combined['brain_region'].map(br_to_brc_map)

    fig, axs = plt.subplots(1, 1, figsize=(8, 5), dpi=100)
    for _type in types:
        # The pair may be stored in either orientation.
        pair = nn_comparison + "|" + _type
        sub = meta[(meta['ct1'] == nn_comparison) & (meta['ct2'] == _type)]
        if sub.empty:
            pair = _type + "|" + nn_comparison
            sub = meta[(meta['ct2'] == nn_comparison) & (meta['ct1'] == _type)]
        if sub.empty:
            print(f"No data for {nn_comparison} and {_type}, skipping.")
            continue

        # Individual subtypes: semi-transparent, without x tick labels.
        plot_pair(
            sub, pair, order=ORDER, ax=axs,
            color=adata.uns['Subclass_palette'].get(_type, 'blue'),
            label=_type, opacity=0.5, rasterized=RASTERIZED,
            region_col="brain_region_corr",
            xlabel=None, ylabel=None, show_ticks=False,)

    # Combined estimate in red; this call also sets the x tick labels.
    plot_pair(
        combined, combined_type_name, pair_col="category", order=ORDER, ax=axs,
        color='red', label=plot_label, rasterized=RASTERIZED, region_col="brain_region_corr",
        xlabel=None, ylabel=None, ticklabel_fontsize=ticklabel_fontsize, show_ticks=True
    )

    axs.set_ylabel("Pooled Z-score", fontsize=axis_fontsize)
    # Keep every other y tick, labelled with its exact value. Casting to int
    # would truncate fractional ticks (0.5 -> "0", -2.5 -> "-2").
    # Adding 0.0 turns -0.0 into 0.0.
    yticks = axs.get_yticks()[::2]
    ytick_labels = [f"{v:g}" for v in np.round(yticks, 6) + 0.0]
    axs.set_yticks(yticks, ytick_labels, fontsize=ticklabel_fontsize)
    axs.set_title(f'Pooled Contact Z-scores for {neu_plot_label} - {nn_comparison} for {r}um', fontsize=title_fontsize)
    # axs.legend(loc='upper left', bbox_to_anchor=(1, 1), fontsize=legend_fontsize)

    plt.tight_layout()
    for ext in ("png", "pdf"):
        plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_{r}um.{ext}', bbox_inches='tight', dpi=300)
    plt.show()
    plt.close()


# Stand-alone legend: one entry per subtype plus the combined (red) trace.
fig, ax = plt.subplots(figsize=(2, 1), dpi=300)
labels = list(types) + [plot_label]
legend_colors = [adata.uns['Subclass_palette'].get(_type, 'blue') for _type in types] + ['red']
handles = [
    ax.plot([], [], color=c, marker='o', linestyle='--', label=lbl, rasterized=RASTERIZED)[0]
    for c, lbl in zip(legend_colors, labels)
]
fig.legend(
    handles=handles,
    labels=labels,
    bbox_to_anchor=(0.5, 0.5),
    loc='center',
    title=f"{neu_plot_label} Subtypes",
    fontsize=12,
    title_fontsize=14,
    frameon=True
)
ax.set_axis_off()
plt.tight_layout()
# plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_legend.png', bbox_inches='tight', dpi=300)
# plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_legend.pdf', bbox_inches='tight', dpi=300)
# plt.show()
plt.close()


In [None]:
# Oligodendrocyte contacts with each GP subtype, plus their combined (averaged) estimate, per distance.
# Uses GP_types: the labels, region order and output file names below all refer to GP.
types = GP_types
combined_type_name = "ALL"
nn_comparison = "Oligodendrocyte"
ORDER = ['GP']
plot_label = "Combined GP"
neu_plot_label = "GP"

for r in rs:
    DIR = Path(f'/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/meta_contacts/{r}um_2')
    # NOTE(review): `df` is not used in this cell; kept only in case later cells read the global.
    df = pd.read_csv(DIR / "contacts_meta_input.csv")
    meta = pd.read_csv(DIR / "meta_region_pooled.csv")
    meta['brain_region_corr'] = meta['brain_region'].map(br_to_brc_map)

    # Pairs are stored as "<cell type 1>|<cell type 2>".
    pair_parts = meta['pair'].astype(str).str.split('|')
    meta['ct1'] = pair_parts.str[0]
    meta['ct2'] = pair_parts.str[1]

    combined = combine_results(meta, types, combined_type_name, nn_comparison)
    if combined.empty:
        # combine_results returns a column-less frame when nothing matched.
        print(f"No data for {nn_comparison} and {neu_plot_label} at {r}um, skipping.")
        continue
    combined['brain_region_corr'] = combined['brain_region'].map(br_to_brc_map)

    fig, axs = plt.subplots(1, 1, figsize=(8, 5), dpi=100)
    for _type in types:
        # The pair may be stored in either orientation.
        pair = nn_comparison + "|" + _type
        sub = meta[(meta['ct1'] == nn_comparison) & (meta['ct2'] == _type)]
        if sub.empty:
            pair = _type + "|" + nn_comparison
            sub = meta[(meta['ct2'] == nn_comparison) & (meta['ct1'] == _type)]
        if sub.empty:
            print(f"No data for {nn_comparison} and {_type}, skipping.")
            continue

        # Individual subtypes: semi-transparent, without x tick labels.
        plot_pair(
            sub, pair, order=ORDER, ax=axs,
            color=adata.uns['Subclass_palette'].get(_type, 'blue'),
            label=_type, opacity=0.5, rasterized=RASTERIZED,
            region_col="brain_region_corr",
            xlabel=None, ylabel=None, show_ticks=False,)

    # Combined estimate in red; this call also sets the x tick labels.
    plot_pair(
        combined, combined_type_name, pair_col="category", order=ORDER, ax=axs,
        color='red', label=plot_label, rasterized=RASTERIZED, region_col="brain_region_corr",
        xlabel=None, ylabel=None, ticklabel_fontsize=ticklabel_fontsize, show_ticks=True
    )

    axs.set_ylabel("Pooled Z-score", fontsize=axis_fontsize)
    # Keep every other y tick, labelled with its exact value. Casting to int
    # would truncate fractional ticks (0.5 -> "0", -2.5 -> "-2").
    # Adding 0.0 turns -0.0 into 0.0.
    yticks = axs.get_yticks()[::2]
    ytick_labels = [f"{v:g}" for v in np.round(yticks, 6) + 0.0]
    axs.set_yticks(yticks, ytick_labels, fontsize=ticklabel_fontsize)
    axs.set_title(f'Pooled Contact Z-scores for {neu_plot_label} - {nn_comparison} for {r}um', fontsize=title_fontsize)
    # axs.legend(loc='upper left', bbox_to_anchor=(1, 1), fontsize=legend_fontsize)

    plt.tight_layout()
    for ext in ("png", "pdf"):
        plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_{r}um.{ext}', bbox_inches='tight', dpi=300)
    plt.show()
    plt.close()


# Stand-alone legend: one entry per subtype plus the combined (red) trace.
fig, ax = plt.subplots(figsize=(2, 1), dpi=300)
labels = list(types) + [plot_label]
legend_colors = [adata.uns['Subclass_palette'].get(_type, 'blue') for _type in types] + ['red']
handles = [
    ax.plot([], [], color=c, marker='o', linestyle='--', label=lbl, rasterized=RASTERIZED)[0]
    for c, lbl in zip(legend_colors, labels)
]
fig.legend(
    handles=handles,
    labels=labels,
    bbox_to_anchor=(0.5, 0.5),
    loc='center',
    title=f"{neu_plot_label} Subtypes",
    fontsize=12,
    title_fontsize=14,
    frameon=True
)
ax.set_axis_off()
plt.tight_layout()
# plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_legend.png', bbox_inches='tight', dpi=300)
# plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_legend.pdf', bbox_inches='tight', dpi=300)
plt.show()
plt.close()


In [None]:
# Oligodendrocyte contacts with each interneuron subtype, plus their combined (averaged) estimate, per distance.
types = INT_types
combined_type_name = "ALL"
nn_comparison = "Oligodendrocyte"
ORDER = ['CaH', 'CaB', 'CaT', 'Pu', 'GP', 'NAC', 'MGM1']
plot_label = "Combined INTs"
neu_plot_label = "INTs"

for r in rs:
    DIR = Path(f'/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/meta_contacts/{r}um_2')
    # NOTE(review): `df` is not used in this cell; kept only in case later cells read the global.
    df = pd.read_csv(DIR / "contacts_meta_input.csv")
    meta = pd.read_csv(DIR / "meta_region_pooled.csv")
    meta['brain_region_corr'] = meta['brain_region'].map(br_to_brc_map)

    # Pairs are stored as "<cell type 1>|<cell type 2>".
    pair_parts = meta['pair'].astype(str).str.split('|')
    meta['ct1'] = pair_parts.str[0]
    meta['ct2'] = pair_parts.str[1]

    combined = combine_results(meta, types, combined_type_name, nn_comparison)
    if combined.empty:
        # combine_results returns a column-less frame when nothing matched.
        print(f"No data for {nn_comparison} and {neu_plot_label} at {r}um, skipping.")
        continue
    combined['brain_region_corr'] = combined['brain_region'].map(br_to_brc_map)

    fig, axs = plt.subplots(1, 1, figsize=(8, 5), dpi=100)
    for _type in types:
        # The pair may be stored in either orientation.
        pair = nn_comparison + "|" + _type
        sub = meta[(meta['ct1'] == nn_comparison) & (meta['ct2'] == _type)]
        if sub.empty:
            pair = _type + "|" + nn_comparison
            sub = meta[(meta['ct2'] == nn_comparison) & (meta['ct1'] == _type)]
        if sub.empty:
            print(f"No data for {nn_comparison} and {_type}, skipping.")
            continue

        # Individual subtypes: semi-transparent, without x tick labels.
        plot_pair(
            sub, pair, order=ORDER, ax=axs,
            color=adata.uns['Subclass_palette'].get(_type, 'blue'),
            label=_type, opacity=0.5, rasterized=RASTERIZED,
            region_col="brain_region_corr",
            xlabel=None, ylabel=None, show_ticks=False,)

    # Combined estimate in red; this call also sets the x tick labels.
    plot_pair(
        combined, combined_type_name, pair_col="category", order=ORDER, ax=axs,
        color='red', label=plot_label, rasterized=RASTERIZED, region_col="brain_region_corr",
        xlabel=None, ylabel=None, ticklabel_fontsize=ticklabel_fontsize, show_ticks=True
    )

    axs.set_ylabel("Pooled Z-score", fontsize=axis_fontsize)
    # Keep every other y tick, labelled with its exact value. Casting to int
    # would truncate fractional ticks (0.5 -> "0", -2.5 -> "-2").
    # Adding 0.0 turns -0.0 into 0.0.
    yticks = axs.get_yticks()[::2]
    ytick_labels = [f"{v:g}" for v in np.round(yticks, 6) + 0.0]
    axs.set_yticks(yticks, ytick_labels, fontsize=ticklabel_fontsize)
    axs.set_title(f'Pooled Contact Z-scores for {neu_plot_label} - {nn_comparison} for {r}um', fontsize=title_fontsize)
    # axs.legend(loc='upper left', bbox_to_anchor=(1, 1), fontsize=legend_fontsize)

    plt.tight_layout()
    for ext in ("png", "pdf"):
        plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_{r}um.{ext}', bbox_inches='tight', dpi=300)
    plt.show()
    plt.close()


# Stand-alone legend: one entry per subtype plus the combined (red) trace.
fig, ax = plt.subplots(figsize=(2, 1), dpi=300)
labels = list(types) + [plot_label]
legend_colors = [adata.uns['Subclass_palette'].get(_type, 'blue') for _type in types] + ['red']
handles = [
    ax.plot([], [], color=c, marker='o', linestyle='--', label=lbl, rasterized=RASTERIZED)[0]
    for c, lbl in zip(legend_colors, labels)
]
fig.legend(
    handles=handles,
    labels=labels,
    bbox_to_anchor=(0.5, 0.5),
    loc='center',
    title=f"{neu_plot_label} Subtypes",
    fontsize=12,
    title_fontsize=14,
    frameon=True
)
ax.set_axis_off()
plt.tight_layout()
# plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_legend.png', bbox_inches='tight', dpi=300)
# plt.savefig(image_path / f'{nn_comparison.lower()}_{neu_plot_label.lower()}_contact_enrichment_legend.pdf', bbox_inches='tight', dpi=300)
plt.show()
plt.close()


# Heterogeneity Plots

## Helper Functions

### Ternary plotting functions
Adapted from ChatGPT-generated code.

In [None]:
def ternary_to_cartesian(A, B, C):
    """
    Map ternary (barycentric) coordinates onto the 2D plane.

    The triangle corners are A -> (0, 0), B -> (1, 0) and
    C -> (0.5, sqrt(3)/2). Inputs are expected to be normalized
    (A + B + C == 1); A is therefore implied by B and C and is not used.

    Returns
    -------
    (x, y) : scalars or arrays following numpy broadcasting of B and C.
    """
    height = np.sqrt(3) / 2
    x = B + C / 2
    y = height * C
    return x, y


def normalize_ternary(A, B, C):
    """
    Rescale three components so that, element-wise, A + B + C == 1.

    Inputs may be scalars, sequences or array-likes (e.g. pandas Series);
    each is converted with ``np.array``. Returns a tuple of three numpy
    arrays. A zero total yields nan/inf following numpy division rules.
    """
    parts = [np.array(component) for component in (A, B, C)]
    total = parts[0] + parts[1] + parts[2]
    return tuple(part / total for part in parts)


def draw_ternary_grid(ax, n=5, color="lightgray", lw=0.8, alpha=0.8):
    """
    Draw iso-lines for each ternary component on ``ax``.

    For each of A, B and C, ``n`` interior levels k / (n + 1), k = 1..n,
    are drawn. Each triangle edge is therefore split into n + 1 equal parts.
    Each line is sampled at 200 points and drawn with ``ax.plot`` using the
    given color, line width and alpha. All A lines are drawn first, then all
    B lines, then all C lines.
    """
    levels = np.linspace(0, 1, n + 2)[1:-1]  # drop the 0 and 1 endpoints

    def _line_for(component, level):
        # One component is fixed at `level`; a second sweeps 0 -> (1 - level)
        # and the third takes the remainder so the three always sum to 1.
        free = np.linspace(0, 1 - level, 200)
        rest = (1 - level) - free
        fixed = level * np.ones_like(free)
        if component == "A":
            return ternary_to_cartesian(fixed, free, rest)
        if component == "B":
            return ternary_to_cartesian(free, fixed, rest)
        return ternary_to_cartesian(free, rest, fixed)

    for component in ("A", "B", "C"):
        for level in levels:
            xs, ys = _line_for(component, level)
            ax.plot(xs, ys, color=color, lw=lw, alpha=alpha)



def ternary_contour(ax, func, levels=10, cmap="viridis", alpha=0.8, resolution=150):
    """
    Draw a filled contour map of ``func(A, B, C)`` over the ternary triangle.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes; drawing is done with ``ax.tricontourf``.
    func : callable
        ``func(A, B, C)`` must accept 1D arrays of equal length (points inside
        the triangle, with A + B + C == 1) and return an array of the same length.
    levels, cmap, alpha :
        Forwarded to ``ax.tricontourf``.
    resolution : int
        Number of sample points along each of the A and B axes.

    Returns
    -------
    The object returned by ``ax.tricontourf``.
    """
    grid = np.linspace(0, 1, resolution)
    A_mesh, B_mesh = np.meshgrid(grid, grid)
    C_mesh = 1 - A_mesh - B_mesh

    # Keep only points inside the simplex. The small tolerance keeps the A + B == 1
    # edge, where floating-point rounding can leave C as a tiny negative number.
    # Those C values are then clipped to exactly 0.
    inside = C_mesh >= -1e-12
    A_dom = A_mesh[inside]
    B_dom = B_mesh[inside]
    C_dom = np.clip(C_mesh[inside], 0.0, None)

    Z = func(A_dom, B_dom, C_dom)
    X, Y = ternary_to_cartesian(A_dom, B_dom, C_dom)

    return ax.tricontourf(X, Y, Z, levels=levels, cmap=cmap, alpha=alpha)

    
def barycentric_color(A, B, C, colA, colB, colC):
    """
    Mix three RGB colors with per-point barycentric weights.

    Parameters
    ----------
    A, B, C : array-like, shape (N,)
        Weights of each point (expected to sum to 1 per point).
    colA, colB, colC : tuple of 3 floats
        RGB color attached to each corner.

    Returns
    -------
    np.ndarray, shape (N, 3)
        ``A*colA + B*colB + C*colC`` evaluated per point.
    """
    weights = [np.asarray(w)[:, None] for w in (A, B, C)]         # (N, 1) each
    corners = [np.asarray(c)[None, :] for c in (colA, colB, colC)]  # (1, 3) each
    return weights[0] * corners[0] + weights[1] * corners[1] + weights[2] * corners[2]


In [None]:
def ternary_plot_colored(
    ax,
    A, B, C,
    axis_colors=((1,0,0), (0,1,0), (0,0,1)),
    labels=("A", "B", "C"),
    point_labels=None,        # optional text label per point
    label_top_n=0,            # annotate this many of the most one-sided points
    label_kws=None,           # ax.text keyword arguments for those annotations
    grid=True,
    grid_n=5,
    contour_func=None,
    contour_levels=10,
    contour_cmap="viridis",
    point_size=40,
    axis_fontsize=12, 
    title_fontsize=14, 
    title=None, 
):
    """
    Scatter ternary compositions on ``ax`` with colors blended from the corners.

    Each point's color is the barycentric mix of ``axis_colors`` weighted by
    its normalized (A, B, C). Optionally draws a contour of ``contour_func``
    underneath and a triangular grid. It can also annotate the
    ``label_top_n`` points whose largest normalized component is highest
    (the most "one-sided" points).

    Parameters
    ----------
    ax : matplotlib Axes
    A, B, C : array-like
        Raw (unnormalized) component values, one entry per point.
    axis_colors : 3 RGB tuples
        Colors of the A, B and C corners (also used for the corner labels).
    labels : 3 str
        Corner labels, placed at A (bottom-left), B (bottom-right) and C (top).
    point_labels : sequence of str, optional
        Text for each point; defaults to the point's positional index.
    label_top_n : int
        Number of most one-sided points to annotate (0 disables annotation).
    label_kws : dict, optional
        Keyword arguments for ``ax.text`` on the point annotations.
    grid, grid_n :
        Whether to draw the grid, and how many interior lines per component.
    contour_func, contour_levels, contour_cmap :
        Optional background contour drawn via ``ternary_contour``.
    point_size : float
        Scatter marker size.
    axis_fontsize, title_fontsize : float
    title : str, optional

    Returns
    -------
    The same ``ax``.
    """
    A, B, C = normalize_ternary(A, B, C)
    x, y = ternary_to_cartesian(A, B, C)
    col_a, col_b, col_c = axis_colors
    colors = barycentric_color(A, B, C, col_a, col_b, col_c)

    if contour_func is not None:
        ternary_contour(ax, contour_func, levels=contour_levels, cmap=contour_cmap)

    # Triangle outline: A corner -> B corner -> C corner -> back to A.
    height = np.sqrt(3) / 2
    ax.plot([0, 1, 0.5, 0], [0, 0, height, 0], color="black", lw=1.5)

    if grid:
        draw_ternary_grid(ax, n=grid_n)

    ax.scatter(x, y, c=colors, s=point_size, edgecolor="black", linewidth=0.3, zorder=4)

    if label_top_n > 0:
        # "One-sidedness" = largest normalized component of each point.
        dominance = np.vstack([A, B, C]).max(axis=0)
        chosen = np.argsort(dominance)[-label_top_n:]
        names = point_labels if point_labels is not None else [str(k) for k in range(len(A))]
        text_kws = label_kws if label_kws is not None else dict(color="black", fontsize=10, ha="center", va="bottom")
        for k in chosen:
            ax.text(x[k], y[k], names[k], zorder=6, **text_kws)

    # Ternary axes carry no meaningful Cartesian ticks; keep the triangle equilateral.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect("equal")

    corner_specs = (
        (0, 0, "right", "top"),
        (1, 0, "left", "top"),
        (0.5, height, "center", "bottom"),
    )
    for k, (cx, cy, ha, va) in enumerate(corner_specs):
        ax.text(cx, cy, labels[k], ha=ha, va=va, color=axis_colors[k], fontsize=axis_fontsize)

    ax.set_title(title, fontsize=title_fontsize)

    return ax


### Stats

## Run

In [None]:
# Load contact z-scores (`df`, one row per pair per donor/region/replicate) and
# the pooled per-region results (`meta`) for contacts within RADIUS um.
# NOTE(review): column contents inferred from their use below; confirm against the writer.
RADIUS = 50
DIR = Path(f"/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/meta_contacts/{RADIUS}um_2")
df = pd.read_csv(DIR / "contacts_meta_input.csv")
meta = pd.read_csv(DIR / "meta_region_pooled.csv")

# Column names of the two partner cell types. `pair` is an ordered "type1|type2"
# key, so "A|B" and "B|A" are distinct.
ct1, ct2 = 'cell_type1', 'cell_type2'
df['pair'] = df[ct1].astype(str) + '|' + df[ct2].astype(str)

In [None]:
# Keep only cell-type pairs with at least 8 rows in `df`, so the heterogeneity
# estimates below rest on a minimum number of observations (heuristic threshold).
df_pair_vc = df.groupby("pair", observed=True).size().loc[lambda counts: counts >= 8]
keep_pairs = list(df_pair_vc.index)
df = df.loc[df['pair'].isin(keep_pairs)].copy()

### Heterogeneity Comparison

In [None]:
# For every cell-type pair, measure how much its z-score varies along one factor
# while the other two are held fixed:
#   - group by the two fixed factors,
#   - take the within-group variance,
#   - summarize with the median over groups.
# Single-row groups give NaN variance, which median() skips.
_fixed_factors = {
    'replicate_heterogeneity': ["brain_region", "donor"],      # varies replicate
    'donor_heterogeneity': ["brain_region", "replicate"],      # varies donor
    'brain_region_heterogeneity': ["donor", "replicate"],      # varies brain region
}
results = []
for _pair in df['pair'].unique():
    sub = df.loc[df['pair'] == _pair].sort_values(by='brain_region')
    row = [_pair]
    for grouping in _fixed_factors.values():
        row.append(sub.groupby(grouping, observed=True)['z_score'].var().median())
    results.append(row)

df_het = pd.DataFrame(results, columns=['pair', *_fixed_factors])
# Drop pairs for which any of the three estimates is undefined.
df_het = df_het.dropna()
df_het.shape

In [None]:
# Ternary plot of the three heterogeneity estimates per pair. Each point is
# normalized so its three components sum to 1.
label_A = "Replicate"
label_B = "Donor"
label_C = "Brain Region"

# Select by column name so the mapping to the corner labels is explicit.
A = df_het['replicate_heterogeneity']
B = df_het['donor_heterogeneity']
C = df_het['brain_region_heterogeneity']

fig, ax = plt.subplots(figsize=(6,6))

ternary_plot_colored(
    ax,
    A, B, C,
    labels=(label_A, label_B, label_C),
    axis_colors=[(0.18, 0.18, 0.65), (0.62, 0.92, 0.61), (1.00, 0.50, 0.42)],
    point_size=20,
    point_labels=df_het['pair'].tolist(),
    label_top_n=0,  # set > 0 to annotate the most one-sided pairs
    label_kws=dict(fontsize=4, color='black'),
    axis_fontsize=14,
    title=f"Heterogeneity Ratios (Radius {RADIUS}um)",
    title_fontsize=18,
)
ax.axis("off")

plt.tight_layout()
plt.savefig(image_path / f'contact_enrichment_heterogeneity_ternary_r{RADIUS}.png', bbox_inches='tight', dpi=300)
plt.savefig(image_path / f'contact_enrichment_heterogeneity_ternary_r{RADIUS}.pdf', bbox_inches='tight', dpi=300)
plt.show()
plt.close()

### Total Heterogeneity Comparison by Cell Type

In [None]:
# Cell types grouped as "vasculature" for this comparison. Build every ordered
# pair among them and keep only the pair keys present in `df`.
vasculature_cells = ['VLMC', "SMC", "Pericyte", "Endo", "Ependymal", "Lymphocyte"]
# Compute the observed pair keys once, instead of re-running unique() per candidate.
_observed_pairs = set(df['pair'].unique())
pairs = [
    f"{a}|{b}"
    for a, b in itertools.product(vasculature_cells, repeat=2)
    if f"{a}|{b}" in _observed_pairs
]

In [None]:
# Variance of each pair's z-score over all of that pair's rows in `df`. Pairs in
# `pairs` (vasculature) are tagged "vasc"; all others are tagged "all".
vv_all = df.groupby('pair', observed=True)['z_score'].var().to_frame()
vv_all['condition'] = "all"
vv_all.loc[pairs, 'condition'] = "vasc"

In [None]:
# Inspect the Neighborhood annotation (display only).
adata.obs['Neighborhood']

### Example Results

In [None]:
# List the cell_type1 partners observed with `_ct2` as cell_type2.
# NOTE(review): `_ct2` is first assigned in a later cell of this section; running
# top-to-bottom raises NameError unless it was set earlier in the notebook.
df[df[ct2] == _ct2][ct1].unique()   

In [None]:
# List the cell_type2 partners observed with `_ct1` as cell_type1.
# NOTE(review): `_ct1` is first assigned in a later cell of this section; running
# top-to-bottom raises NameError unless it was set earlier in the notebook.
df[df[ct1] == _ct1][ct2].unique()   

In [None]:
# Per-region example plots for one cell-type pair. Each matching row of `df` is
# drawn as an x at its z-score on its donor's row; the pooled z-score (`mu`
# from `meta`) is drawn as a red line.
_ct2 = "Pericyte"
_ct1 = "Pericyte"
pair = f"{_ct1}|{_ct2}"
for _region in adata.obs['brain_region'].unique(): 
    # Subset for this pair/region. If nothing is recorded as _ct1|_ct2, fall back
    # to the reversed orientation _ct2|_ct1.
    pair_key = pair
    d = df[(df["cell_type1"] == _ct1) & (df["cell_type2"] == _ct2) & (df["brain_region"] == _region)]
    if d.empty: 
        d = df[(df["cell_type1"] == _ct2) & (df["cell_type2"] == _ct1) & (df["brain_region"] == _region)]
        # Look up the pooled estimate under the orientation actually found
        # (presumably meta["pair"] uses the same "type1|type2" convention as df).
        pair_key = f"{_ct2}|{_ct1}"
    if d.empty:
        continue
    m = meta[(meta["pair"] == pair_key) & (meta["brain_region"] == _region)]
    if m.empty:
        # No pooled estimate for this pair/region: skip rather than raising
        # IndexError on m["mu"].values[0].
        continue

    fig, ax = plt.subplots(figsize=(4, 2))
    ax.scatter(d["z_score"], d["donor"], color="purple", s=50, marker='x')
    ax.axvline(0, color="k", lw=1)
    ax.axvline(m["mu"].values[0], color="red", lw=2, label="pooled z")
    ax.set_title(f"{pair_key} in {br_to_brc_map[_region]}")
    ax.set_xlabel("z-score vs null")
    ax.legend()
    plt.tight_layout()
    plt.savefig(image_path / f'vasculature_pair_{_ct1}_{_ct2}_{_region}.png', bbox_inches='tight', dpi=300)
    plt.savefig(image_path / f'vasculature_pair_{_ct1}_{_ct2}_{_region}.pdf', bbox_inches='tight', dpi=300)
    plt.show()
    plt.close()

In [None]:
# _ct2 = "VLMC"
# _ct1 = "Endo"
# pair = f"{_ct1}|{_ct2}"
# region = "CAH"

# # subset for that pair/region
# d = df[(df["cell_type1"] == _ct1) & (df["cell_type2"] == _ct2) & (df["brain_region"] == region)]
# m = meta[(meta["pair"] == pair) & (meta["brain_region"] == region)]

# fig, ax = plt.subplots(figsize=(4, 2))
# ax.scatter(d["z_score"], d["donor"], color="purple", s=50, marker='x')
# ax.axvline(0, color="k", lw=1)
# ax.axvline(m["mu"].values[0], color="red", lw=2, label="pooled z")
# ax.set_title(f"{pair} in {region}")
# ax.set_xlabel("z-score vs null")
# ax.legend()
# plt.show()
# plt.close()


# Plot Example Real vs. Null Contact Distributions (Fig. 6S9)

## Helpers

In [None]:
def get_cells_in_radius(adata, center, radius, cols=None):
    """
    Return a copy of the cells lying strictly within ``radius`` of ``center``.

    Parameters
    ----------
    adata : AnnData-like
        Object with an ``obs`` DataFrame holding the coordinate columns that
        supports boolean-mask indexing and ``.copy()``.
    center : pair of float
        (x, y) of the query point, in the same units as the coordinate columns.
    radius : float
        Cells with squared distance < radius**2 are kept (boundary excluded).
    cols : list of 2 str, optional
        x and y column names in ``adata.obs``; defaults to ['CENTER_X', 'CENTER_Y'].

    Returns
    -------
    A copy of the selected subset of ``adata``.
    """
    if cols is None:
        cols = ['CENTER_X', 'CENTER_Y']
    # Vectorized distance test (replaces a slow row-wise DataFrame.apply).
    dx = adata.obs[cols[0]].to_numpy() - center[0]
    dy = adata.obs[cols[1]].to_numpy() - center[1]
    inside = dx ** 2 + dy ** 2 < radius ** 2
    return adata[inside].copy()


def get_cell_by_cell_contacts(
    data : ad.AnnData | pd.DataFrame, 
    cell_type_col = "Subclass", 
    spatial_keys = ['center_x', 'center_y'],
    cell_type_list = None,
    radius = 50,
): 
    """
    Count neighbors within ``radius`` for every ordered pair of cell types.

    Parameters
    ----------
    data : AnnData or DataFrame
        Cells; for AnnData, a copy of ``data.obs`` is used.
    cell_type_col : str
        Column with the cell-type label of each cell.
    spatial_keys : list of str
        Coordinate columns used for Euclidean distances (not mutated).
    cell_type_list : array-like, optional
        Cell-type order for both matrix axes; defaults to the sorted unique
        labels in ``data``. Every label present in ``data`` must be listed,
        and every listed type must occur in ``data`` (otherwise KeyError).
    radius : float
        Neighbors must satisfy 1e-4 < distance <= radius. The lower bound
        excludes the cell itself (and any other cell closer than 1e-4).

    Returns
    -------
    norm_counts : np.ndarray, shape (K, K)
        ``contact_counts[i, j] / n_cells[j]``. Here ``contact_counts[i, j]`` is
        the number of (cell of type i, neighbor of type j) pairs, and
        ``n_cells[j]`` is the number of cells of type ``cell_type_list[j]``.
        NOTE(review): the broadcast divides each column j by n_cells[j];
        confirm this, rather than per-row normalization, is intended.
    cell_type_list : array-like
        The cell-type order used for both axes.
    """
    if isinstance(data, ad.AnnData):
        data = data.obs.copy()

    if cell_type_list is None:
        cell_type_list = np.unique(data[cell_type_col])
    N_cell_types = len(cell_type_list)
    contact_counts = np.zeros((N_cell_types, N_cell_types), dtype=int)

    coords = data[spatial_keys].values
    cell_type_to_idx = {ct: i for i, ct in enumerate(cell_type_list)}
    # Map each cell to its matrix index once, instead of a dict lookup per neighbor.
    type_idx = np.array([cell_type_to_idx[ct] for ct in data[cell_type_col].values], dtype=int)

    for i in range(data.shape[0]):
        dists = np.linalg.norm(coords - coords[i], axis=1)
        neighbors = np.where((dists > 1e-4) & (dists <= radius))[0]
        # Tally all neighbor types of cell i in one vectorized step.
        contact_counts[type_idx[i]] += np.bincount(type_idx[neighbors], minlength=N_cell_types)

    # Normalize by the number of cells of each cell type (see Returns for the axis).
    N_cells = data.groupby(cell_type_col, observed=True).size().to_dict()
    divide = np.asarray([N_cells[ct] for ct in cell_type_list])
    norm_counts = contact_counts / divide

    return norm_counts, cell_type_list
    
# functions from xingjiepan 2023 mouse atlas paper
def adjust_p_value_matrix_by_BH(p_val_mtx):
    '''Adjust the p-values in a matrix by the Benjamini/Hochberg method.
    The matrix should be symmetric.

    Only the upper triangle (diagonal included) is read. Those N*(N+1)/2
    values are corrected together with statsmodels' ``fdr_bh``, and the
    corrected values are mirrored into a symmetric N x N float matrix.
    '''
    n = p_val_mtx.shape[0]
    rows, cols = np.triu_indices(n)  # row-major upper-triangle positions
    upper_values = [p_val_mtx[r, c] for r, c in zip(rows, cols)]
    corrected = multipletests(upper_values, method='fdr_bh')[1]

    adjusted = np.zeros((n, n))
    adjusted[rows, cols] = corrected
    adjusted[cols, rows] = corrected
    return adjusted

def get_data_frame_from_metrices(cell_types, mtx_dict):
    """
    Flatten a set of N x N per-pair matrices into a long-format DataFrame.

    Parameters
    ----------
    cell_types : sequence of length N
        Labels for the matrix rows and columns.
    mtx_dict : dict[str, 2D array]
        Column name -> matrix. Entry ``[i, j]`` becomes the value for the row
        (cell_types[i], cell_types[j]).

    Returns
    -------
    pd.DataFrame
        N*N rows in row-major order (i outer, j inner). Columns are
        ``cell_type1`` and ``cell_type2``, followed by the keys of
        ``mtx_dict`` in insertion order.
    """
    n = len(cell_types)
    index_pairs = [(i, j) for i in range(n) for j in range(n)]

    columns = {
        'cell_type1': [cell_types[i] for i, _ in index_pairs],
        'cell_type2': [cell_types[j] for _, j in index_pairs],
    }
    for name, mtx in mtx_dict.items():
        columns[name] = [mtx[i, j] for i, j in index_pairs]
    return pd.DataFrame(columns)
    

def sort_cell_type_contact_p_values(p_val_mtx, cell_types):
    '''Return a list of (cell_type1, cell_type2, p_value) sorted by p_values.

    Every ordered pair (i, j) of the N x N matrix is included. Ties keep
    row-major order because ``sorted`` is stable.
    '''
    n = p_val_mtx.shape[0]
    triples = (
        (cell_types[i], cell_types[j], p_val_mtx[i, j])
        for i, j in itertools.product(range(n), repeat=2)
    )
    return sorted(triples, key=lambda triple: triple[2])

In [None]:
def _permute_and_get_contacts(i, adata_sub, r_permute, r_test, cell_types=None):
    """
    Compute the contact matrix for one random spatial permutation of a slice.

    Every cell is moved to a point drawn uniformly inside a disk of radius
    ``r_permute`` around its original position. Taking the sqrt of a uniform
    variate for the radius gives uniform density over the disk area. Contacts
    within ``r_test`` are then recounted with ``get_cell_by_cell_contacts``.

    Parameters
    ----------
    i : int
        Permutation index. Unused; it lets the function be mapped over range(Np).
    adata_sub : AnnData
        Cells of one slice; ``obs`` must contain CENTER_X and CENTER_Y.
    r_permute : float
        Maximum jitter radius.
    r_test : float
        Contact radius.
    cell_types : array-like, optional
        Fixed cell-type order for the output matrix.

    Returns
    -------
    np.ndarray
        The normalized K x K contact matrix of the jittered cells.

    NOTE(review): reads the module-level ``cell_type_col_subclass`` global.
    """
    # Reseed from fresh entropy (seed=None); presumably so forked pool workers
    # do not replay the parent's RNG stream.
    np.random.seed()
    jittered = adata_sub.obs.copy()
    n_cells = jittered.shape[0]
    radii = r_permute * np.sqrt(np.random.uniform(size=n_cells))
    angles = np.random.uniform(size=n_cells) * 2 * np.pi
    jittered['CENTER_X'] = jittered['CENTER_X'] + radii * np.cos(angles)
    jittered['CENTER_Y'] = jittered['CENTER_Y'] + radii * np.sin(angles)
    norm_counts, _ = get_cell_by_cell_contacts(
        jittered,
        cell_type_col=cell_type_col_subclass,
        spatial_keys=['CENTER_X', 'CENTER_Y'],
        cell_type_list=cell_types,
        radius=r_test,
    )
    return norm_counts

def _get_contacts_for_slice(
    _slice,
    adata,
    Np = 1000,
    r_permute = 100,
    r_test = 15,
    alpha = 0.05,
    min_contacts = 50,
): 
    """
    Run the permutation-based cell-cell contact test for one tissue slice.

    Parameters
    ----------
    _slice : tuple (donor, brain_region, replicate)
        Selects the cells of ``adata`` to analyze.
    adata : AnnData
        All cells. ``obs`` must hold donor, brain_region, replicate,
        CENTER_X, CENTER_Y and the ``cell_type_col_subclass`` column.
    Np : int
        Number of permutations, run in a multiprocessing pool using all CPUs.
    r_permute : float
        Maximum jitter radius of each permutation.
    r_test : float
        Contact radius.
    alpha, min_contacts :
        Accepted for compatibility but currently not used.

    Side effects
    ------------
    Writes the real, permuted, permuted-mean and permuted-std ``.npy`` arrays
    (std saved before flooring) and a ``cell_contacts_*.csv`` table. All files
    go to the module-level ``output_path``.

    Returns
    -------
    tuple
        (contact_result_df, real_contacts, merged_contact_counts, cell_types,
        donor, brain_region, replicate)
    """
    donor, brain_region, replicate = _slice

    # Cells of this slice, excluding cells labelled "unknown".
    obs = adata.obs
    in_slice = (
        (obs['donor'] == donor)
        & (obs['brain_region'] == brain_region)
        & (obs['replicate'] == replicate)
    )
    adata_sub = adata[in_slice].copy()
    adata_sub = adata_sub[adata_sub.obs[cell_type_col_subclass] != "unknown"].copy()
    cell_types = np.unique(adata_sub.obs[cell_type_col_subclass])

    # Observed contact matrix.
    real_contacts, _ = get_cell_by_cell_contacts(
        adata_sub,
        cell_type_col=cell_type_col_subclass,
        spatial_keys=['CENTER_X', 'CENTER_Y'],
        cell_type_list=cell_types,
        radius=r_test,
    )

    # Null distribution: Np jittered permutations computed in parallel -> (Np, K, K).
    one_permutation = partial(
        _permute_and_get_contacts,
        adata_sub=adata_sub,
        r_permute=r_permute,
        r_test=r_test,
        cell_types=cell_types,
    )
    with mp.Pool(mp.cpu_count()) as pool:
        permuted = pool.map(one_permutation, range(Np))
    merged_contact_counts = np.array(permuted)

    null_contacts_mean = merged_contact_counts.mean(axis=0)
    null_contacts_std = merged_contact_counts.std(axis=0)

    out_dir = Path(output_path)
    np.save(out_dir / f"contact_counts_real_{donor}_{brain_region}_{replicate}_{r_test}um.npy", real_contacts)
    np.save(out_dir / f"contact_counts_permuted_{donor}_{brain_region}_{replicate}_{r_test}um.npy", merged_contact_counts)
    np.save(out_dir / f"contact_counts_permuted_mean_{donor}_{brain_region}_{replicate}_{r_test}um.npy", null_contacts_mean)
    np.save(out_dir / f"contact_counts_permuted_std_{donor}_{brain_region}_{replicate}_{r_test}um.npy", null_contacts_std)

    # Floor the null std at sqrt(1/1000) so near-constant nulls do not divide by ~0.
    null_contacts_std = np.maximum(null_contacts_std, np.sqrt(1/1000))
    permuted_z_score = (real_contacts - null_contacts_mean) / null_contacts_std
    # Upper-tail probability of |z| (a two-sided p-value would be twice this).
    local_p_values = norm.sf(np.abs(permuted_z_score))
    adjusted_local_p_value = adjust_p_value_matrix_by_BH(local_p_values)
    # Small epsilon avoids division by zero when the null mean is 0.
    fold_changes = real_contacts / (null_contacts_mean + 1e-6)

    metrics = {
        'pval-adjusted': adjusted_local_p_value,
        'pval': local_p_values,
        'z_score': permuted_z_score,
        'contact_count': real_contacts,
        'permutation_mean': null_contacts_mean,
        'permutation_std': null_contacts_std,
        'fold-change' : fold_changes,
    }
    contact_result_df = get_data_frame_from_metrices(cell_types, metrics).sort_values('z_score', ascending=False)
    contact_result_df['id'] = f"{donor}_{brain_region}_{replicate}"
    contact_result_df.to_csv(out_dir / f"cell_contacts_{donor}_{brain_region}_{replicate}_{r_test}um.csv", index=False)

    return (contact_result_df, real_contacts, merged_contact_counts, cell_types, donor, brain_region, replicate)

## Run

In [None]:
# Parameters for the example permutation test below.
Np = 1000               # number of permutations
r_permute = 200         # max jitter radius per permutation (same units as CENTER_X/Y)
r_test = 30             # contact radius
alpha = 0.05            # adjusted p-value cutoff for filtering results
min_contacts = 50       # contact_count cutoff for filtering results
cell_type_col_subclass="Subclass"   # label column; read as a global by the contact helpers
cell_type_col_group="Group"         # not used in this section

In [None]:
# Every (donor, brain_region, replicate) combination.
donors = adata.obs['donor'].unique().tolist()
brain_regions = adata.obs['brain_region'].unique().tolist()
replicates = adata.obs['replicate'].unique().tolist()
slices = list(itertools.product(donors, brain_regions, replicates))
# Exclude UWA7648 / CAT for both replicates.
# NOTE(review): presumably these slices have no data. list.remove raises ValueError
# if either tuple is absent, and any other absent combination is not filtered here.
slices.remove(("UWA7648", "CAT", "salk"))
slices.remove(("UWA7648", "CAT", "ucsd"))

In [None]:
# Pick one slice as the worked example.
_slice = slices[2]

In [None]:
# Step-by-step version of the slice selection inside _get_contacts_for_slice:
# cells of the example slice, excluding "unknown" cell types.
donor, brain_region, replicate = _slice
adata_sub = adata[(adata.obs['donor'] == donor) &
                    (adata.obs['brain_region'] == brain_region) &
                    (adata.obs['replicate'] == replicate)].copy()
adata_sub = adata_sub[(adata_sub.obs[cell_type_col_subclass] != "unknown")].copy()
cell_types = np.unique(adata_sub.obs[cell_type_col_subclass])
N_cell_types = len(cell_types)  # not used further in this section

In [None]:
# Observed (normalized) contact matrix for the example slice.
real_contacts, _ = get_cell_by_cell_contacts(adata_sub, cell_type_col=cell_type_col_subclass, spatial_keys=['CENTER_X', 'CENTER_Y'], cell_type_list=cell_types, radius=r_test)

In [None]:
# Null distribution: Np jittered permutations computed in parallel.
# Result shape is (Np, K, K), with K = len(cell_types).
with mp.Pool(mp.cpu_count()) as pool: 
    ret_list = pool.map(partial(_permute_and_get_contacts,
                                adata_sub=adata_sub,
                                r_permute=r_permute,
                                r_test=r_test,
                                cell_types=cell_types),
                            range(Np))
merged_contact_counts = np.array(ret_list)

In [None]:
# Alias the permutation stack and display its shape.
null_contacts = merged_contact_counts
null_contacts.shape

In [None]:
# Rerun the full pipeline for the example slice, writing outputs to a sandbox directory.
# `output_path` is read as a global by _get_contacts_for_slice. This overwrites
# real_contacts / merged_contact_counts / cell_types computed in the cells above.
output_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/sandbox"
function_args = {}
function_args['Np'] = Np
function_args['r_permute'] = r_permute
function_args['r_test'] = r_test
function_args['alpha'] = alpha
function_args['min_contacts'] = min_contacts
function_args['adata'] = adata

res = _get_contacts_for_slice(_slice, **function_args)
(contact_result_df, real_contacts, merged_contact_counts, cell_types, donor, brain_region, replicate) = res

# Batch version over all slices (disabled):
# parallel_func = partial(_get_contacts_for_slice, **function_args)
# contacts_list = []
# for _slice in tqdm(slices): 
#     contacts_list.append(_get_contacts_for_slice(_slice, **function_args))

In [None]:
# Re-alias the permutation stack returned by _get_contacts_for_slice and show its shape.
null_contacts = merged_contact_counts
null_contacts.shape

In [None]:
# Shape of the per-pair null mean (K x K).
null_contacts.mean(axis=0).shape

In [None]:
# Example figure: permutation (null) distribution of the normalized contact value
# for one pair in the example slice, with the observed value marked in red.
_ct1 = "STR D1 MSN"
_ct2 = "CN ST18 GABA"
i = np.where(cell_types == _ct1)[0][0]
j = np.where(cell_types == _ct2)[0][0]
null_dist = null_contacts[:, i, j]
real_count = real_contacts[i, j]

fig, ax = plt.subplots(figsize=(4, 3))
hist = ax.hist(null_dist, bins=10, color='lightgrey', density=True)

# Y ticks: 5 evenly spaced integers spanning the bin densities. np.unique drops
# duplicates created by int rounding when the density range is small (a 0-1
# range would otherwise give [0, 0, 0, 1, 1]).
ylim = int(min(hist[0])), int(max(hist[0])+1)
yticks = np.unique(np.linspace(ylim[0], ylim[1], 5).round(0).astype(int))

# X ticks: 5 evenly spaced values covering the bin edges and the real value,
# padded by 10% on each side, rounded to 2 decimals and deduplicated.
x_all = np.append(hist[1], real_count)
xlim = x_all.min(), x_all.max()
xrange = xlim[1] - xlim[0]
xticks = np.unique(np.linspace(xlim[0] - 0.1*xrange, xlim[1] + 0.1*xrange, 5).round(2))

ax.axvline(real_count, color='red', linestyle='--', label='Real Count')
ax.set_xlabel('Contact Counts', fontsize=14)
ax.set_ylabel('Density', fontsize=14)
ax.set_title('Example Contact Counts Distribution', fontsize=18)
ax.set_xticks(xticks, xticks, fontsize=12)
ax.set_yticks(yticks, yticks, fontsize=12)
# ax.legend(loc='upper right', bbox_to_anchor=(1.6, 1), fontsize=12)

plt.tight_layout()
plt.savefig(image_path / f'z_ccd_{_ct1}_{_ct2}_{donor}_{brain_region}_{replicate}_{r_test}um.png', bbox_inches='tight', dpi=300)
plt.savefig(image_path / f'z_ccd_{_ct1}_{_ct2}_{donor}_{brain_region}_{replicate}_{r_test}um.pdf', bbox_inches='tight', dpi=300)
plt.show()
plt.close()

In [None]:
# _donor, _brain_region, _replicate = "UCI5224", "PU", "salk"
# adata_sub = adata[(adata.obs['donor'] == _donor) &
#                   (adata.obs['brain_region'] == _brain_region) &
#                   (adata.obs['replicate'] == _replicate)].copy()

In [None]:
# level = "Subclass"
# adata_sub.obs[level] = adata_sub.obs[level].fillna("unknown")
# cell_types = np.unique(adata_sub.obs[level])
# cell_contacts = get_cell_by_cell_contacts(adata_sub, cell_type_col=level, spatial_keys=['CENTER_X', 'CENTER_Y'], cell_type_list=cell_types, radius=30)

In [None]:
# Np = 1000
# r_permute = 200
# N_cell_types = len(cell_types)
# merged_contact_counts = np.zeros((Np, N_cell_types, N_cell_types), dtype=int)
# for i in range(Np): 
#     df_slide = adata_sub.obs.copy()
#     r = r_permute * np.sqrt(np.random.uniform(size=df_slide.shape[0]))
#     theta = np.random.uniform(size=df_slide.shape[0]) * 2 * np.pi
#     df_slide['CENTER_X'] = df_slide['CENTER_X'] + r * np.cos(theta)
#     df_slide['CENTER_Y'] = df_slide['CENTER_Y'] + r * np.sin(theta)
#     contacts = get_cell_by_cell_contacts(df_slide, cell_type_col='Subclass', spatial_keys=['CENTER_X', 'CENTER_Y'], cell_type_list=cell_types, radius=30)
#     merged_contact_counts[i] = contacts[0]

In [None]:
# merged_contact_counts

In [None]:
# Recompute the z-score table for the example slice from the in-memory permutations.
# These are the same steps as in _get_contacts_for_slice, without the 'id' column.
# The null mean/std are derived from `null_contacts` here. At top level in this
# section they are otherwise only assigned by a later cell, which loads a different slice.
null_contacts_mean = np.mean(null_contacts, axis=0)
null_contacts_std = np.std(null_contacts, axis=0)
# Floor the null std so near-constant nulls do not divide by ~0.
null_contacts_std = np.maximum(null_contacts_std, np.sqrt(1/1000))
permuted_z_score = (real_contacts - null_contacts_mean) / null_contacts_std
local_p_values = norm.sf(np.abs(permuted_z_score))
adjusted_local_p_value = adjust_p_value_matrix_by_BH(local_p_values)
fold_changes = real_contacts / (null_contacts_mean + 1e-6)
# Gather all results into a data frame
contact_result_df = get_data_frame_from_metrices(cell_types, 
                                                 {'pval-adjusted': adjusted_local_p_value,
                                                  'pval': local_p_values,
                                                  'z_score': permuted_z_score,
                                                  'contact_count': real_contacts,
                                                  'permutation_mean': null_contacts_mean,
                                                  'permutation_std': null_contacts_std,
                                                  'fold-change' : fold_changes,
                                        }).sort_values('z_score', ascending=False)

In [None]:
# Keep significant pairs with enough contacts, using the `alpha` (0.05) and
# `min_contacts` (50) thresholds from the parameter cell instead of repeated literals.
# NOTE(review): 'contact_count' holds normalized counts (divided by cell numbers);
# confirm that a cutoff of 50 is meaningful on that scale.
contact_result_df = contact_result_df[
    (contact_result_df['pval-adjusted'] < alpha)
    & (contact_result_df['contact_count'] > min_contacts)
]
contact_result_df

In [None]:
# Select the Astrocyte / STR D2 MSN entry of the example slice's null and real matrices.
_ct1 = "Astrocyte"
_ct2 = "STR D2 MSN"
i = np.where(cell_types == _ct1)[0][0]
j = np.where(cell_types == _ct2)[0][0]
null_dist = null_contacts[:, i, j]
real_count = real_contacts[i, j]

In [None]:
# Null distribution vs. observed value for the pair selected in the previous
# cell (`_ct1`, `_ct2`) in the example slice.
fig, ax = plt.subplots(figsize=(4, 4))
ax.hist(null_dist, bins=30, color='lightgrey', density=True)
ax.axvline(real_count, color='red', linestyle='--', label='Real Count')
ax.set_xlabel(f'Contact Counts between {_ct1} and {_ct2}')
ax.set_ylabel('Density')
# Use the example slice's identifiers. The underscored _donor/_brain_region/_replicate
# names are only assigned in a later cell, for a different slice loaded from disk.
ax.set_title(f'Contact Counts Distribution\n{donor}, {brain_region}, {replicate}')
ax.legend(loc='upper right', bbox_to_anchor=(1.6, 1), fontsize=12)
plt.show()
plt.close()

In [None]:
# Load the saved real/permuted contact arrays for one slice from a production run
# at r um (files written by _get_contacts_for_slice).
r = 30
DIR = Path(f"/home/x-aklein2/projects/aklein/BICAN/BG/data/CPS/cell_contacts_{r}um")
_donor = "UWA7648"
_brain_region = "CAH"
_replicate = "ucsd"


real_contacts = np.load(DIR / f"contact_counts_real_{_donor}_{_brain_region}_{_replicate}_{r}um.npy")
null_contacts = np.load(DIR / f"contact_counts_permuted_{_donor}_{_brain_region}_{_replicate}_{r}um.npy")
null_contacts_mean = np.load(DIR / f"contact_counts_permuted_mean_{_donor}_{_brain_region}_{_replicate}_{r}um.npy")
null_contacts_std = np.load(DIR / f"contact_counts_permuted_std_{_donor}_{_brain_region}_{_replicate}_{r}um.npy")

In [None]:
# Pick one matrix entry of the loaded slice by index.
# NOTE(review): the name-based lookup is disabled, presumably because `cell_types`
# belongs to the example slice rather than this loaded one; indices 5 and 15 are
# hard-coded and their cell-type names are not recorded here.
# _ct1 = "Astrocyte"
# _ct2 = "STR D2 MSN"
# i = np.where(cell_types == _ct1)[0][0]
# j = np.where(cell_types == _ct2)[0][0]
i = 5
j = 15
null_dist = null_contacts[:, i, j]
real_count = real_contacts[i, j]

In [None]:
# Null distribution vs. observed value for entry (i, j) of the slice loaded from
# disk (_donor, _brain_region, _replicate).
fig, ax = plt.subplots(figsize=(4, 4))
ax.hist(null_dist, bins=30, color='lightgrey', density=True)
ax.axvline(real_count, color='red', linestyle='--', label='Real Count')
# Cell-type names are not available for this slice, so label the pair by its
# matrix indices (the previous label was truncated to "Contact Counts between").
ax.set_xlabel(f'Contact Counts between cell types {i} and {j}')
ax.set_ylabel('Density')
ax.set_title(f'Contact Counts Distribution\n{_donor}, {_brain_region}, {_replicate}')
ax.legend(loc='upper right', bbox_to_anchor=(1.6, 1), fontsize=12)
plt.show()
plt.close()