This notebook plots expression of known marker genes across brain regions, broken down by cell type (e.g. DRD1 / DRD2 expression by cell type for every brain region).

Markers include:
- BCAS1 - oligodendrocytes
- S100B - astrocytes
- PENK - medium spiny neurons (MSNs)
- GAD1 / GAD2 - GABAergic neurons
- PVALB
- TAC3 
- ST18
- LAMP5
- LSAMP
- GATA1/3

- P2RY12, CX3CR1 - Microglia

In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import anndata as ad
from spida.utilities._ad_utils import normalize_adata

import matplotlib.pyplot as plt
import seaborn as sns


# Global matplotlib defaults for every figure in this notebook.
plt.rcParams.update({
    # High-resolution on-screen and saved figures, small square default size.
    'figure.dpi': 300,
    'savefig.dpi': 300,
    'figure.figsize': (4, 4),
    # Embed TrueType (Type 42) fonts in PDF/PS output instead of Type 3.
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'sans-serif',
    'font.sans-serif': 'Arial',
    'axes.facecolor': 'white',
})

In [None]:
# Load the AnnData object used by every plot below. An alternative dataset:
# ad_path = "/home/x-aklein2/projects/aklein/BICAN/BG/data/BICAN_BG_PFV8_annotated_v5.h5ad"
data_dir = "/home/x-aklein2/projects/aklein/BICAN/BG/data"
ad_path = f"{data_dir}/BICAN_BG_CPS.h5ad"
adata = ad.read_h5ad(ad_path)
adata

In [None]:
def _resolve_palette_colors(groups, palette):
    """Return one color per entry of ``groups`` derived from ``palette``.

    Returns None when ``palette`` is None, so the caller can apply its own
    default. A string palette is passed to ``sns.color_palette``. A dict
    palette is looked up per group, and groups missing from the dict get
    ``'gray'``.

    Raises
    ------
    ValueError
        If ``palette`` is not None, a string, or a dict.
    """
    if palette is None:
        return None
    if isinstance(palette, str):
        return sns.color_palette(palette, len(groups))
    if isinstance(palette, dict):
        return [palette.get(group, 'gray') for group in groups]
    raise ValueError("Palette must be a string or a dictionary.")


def grouped_boxplot(data: pd.DataFrame, 
                   x: str, 
                   y: str, 
                   hue: str | None = None,
                   ax: plt.Axes | None = None,
                   colors: list[str] | None = None,
                   palette: str | dict | None = None,
                   labels: list[str] | None = None,
                   title: str | None = None,
                   xlabel: str | None = None,
                   ylabel: str | None = None,
                   legend: bool = True,
                   box_width: float = 0.6,
                   spacing: float = 0.8,
                   **boxplot_kwargs) -> plt.Axes:
    """
    Create a grouped boxplot with optional hue variable using matplotlib's boxplot.
    
    Parameters:
    -----------
    data : pd.DataFrame
        The input dataframe containing the data
    x : str
        Column name for the groupby variable (x-axis groups)
    y : str
        Column name for the values to plot (y-axis)
    hue : str, optional
        Column name for the hue variable (subgroups within each x group)
    ax : plt.Axes, optional
        Matplotlib axes object. If None, creates a new 10x6 figure
    colors : list, optional
        Explicit box colors: one per hue group when ``hue`` is set, otherwise
        one per x group. Takes precedence over ``palette``.
    palette : str or dict, optional
        Used when ``colors`` is None. A seaborn palette name, or a dict that
        maps group values to colors (missing groups are drawn gray). It colors
        the hue groups when ``hue`` is set, and the x groups otherwise.
    labels : list, optional
        Custom labels for x-axis groups. If None, uses unique values from x column
    title : str, optional
        Plot title
    xlabel : str, optional
        X-axis label. If None, uses the x column name
    ylabel : str, optional
        Y-axis label. If None, uses the y column name
    legend : bool, default True
        Whether to show legend (only applies when hue is used)
    box_width : float, default 0.6
        Width of individual boxes
    spacing : float, default 0.8
        Center-to-center distance between x groups when each group holds a
        single box. With several hue groups, each group is additionally widened
        by the hue offsets so that neighbouring groups do not overlap.
    **boxplot_kwargs
        Additional keyword arguments passed to matplotlib's boxplot
    
    Returns:
    --------
    plt.Axes
        The matplotlib axes object

    Raises:
    -------
    ValueError
        If ``palette`` is neither a string nor a dict.
    """
    
    if ax is None:
        _, ax = plt.subplots(figsize=(10, 6))
    
    # x groups are plotted in order of first appearance in `data`.
    x_groups = data[x].unique()
    
    if labels is None:
        labels = [str(group) for group in x_groups]
    
    if hue is None:
        # Simple boxplot: one box per x group.
        box_data = [data[data[x] == group][y].dropna().values for group in x_groups]

        # Use matplotlib's default 1..N positions, but place the tick labels
        # ourselves. Axes.boxplot's `labels` keyword is deprecated
        # (renamed to `tick_labels` in matplotlib 3.9).
        positions = list(boxplot_kwargs.pop('positions', range(1, len(x_groups) + 1)))
        bp = ax.boxplot(box_data,
                        positions=positions,
                        widths=box_width,
                        patch_artist=True,
                        **boxplot_kwargs)

        # Without a hue, the palette colors the x groups themselves.
        if colors is None:
            colors = _resolve_palette_colors(x_groups, palette)
        if colors is not None and len(colors) >= len(x_groups):
            face_colors = colors
        else:
            # No (or too few) colors available: use a uniform default.
            face_colors = ['lightblue'] * len(bp['boxes'])
        for patch, color in zip(bp['boxes'], face_colors):
            patch.set_facecolor(color)

        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha='right')
                
    else:
        # Grouped boxplot: one box per (x group, hue group) pair that has data.
        hue_groups = data[hue].unique()
        n_hue = len(hue_groups)
        
        if colors is None:
            colors = _resolve_palette_colors(hue_groups, palette)
        if colors is None:
            # Default: sample evenly across the Set2 colormap.
            colors = plt.cm.Set2(np.linspace(0, 1, n_hue))

        # Slight gap between neighbouring hue boxes inside one x group.
        hue_spacing = box_width * 1.2
        # Widen the x-group pitch by the hue offsets so that adjacent groups
        # never overlap. With a single hue group this equals `spacing`.
        group_pitch = spacing + max(n_hue - 1, 0) * hue_spacing
        
        positions = []
        all_box_data = []
        all_colors = []
        
        for i, x_group in enumerate(x_groups):
            x_data = data[data[x] == x_group]
            # Center the hue boxes around this x group's tick position.
            center = i * group_pitch + 1
            start_pos = center - (n_hue - 1) * hue_spacing / 2
            
            for j, hue_group in enumerate(hue_groups):
                hue_data = x_data[x_data[hue] == hue_group][y].dropna().values
                if len(hue_data) > 0:  # Skip empty (x, hue) combinations.
                    all_box_data.append(hue_data)
                    positions.append(start_pos + j * hue_spacing)
                    all_colors.append(colors[j])
        
        bp = ax.boxplot(all_box_data,
                        positions=positions,
                        widths=box_width,
                        patch_artist=True,
                        **boxplot_kwargs)
        
        for patch, color in zip(bp['boxes'], all_colors):
            patch.set_facecolor(color)
        
        # One tick per x group, at the group center.
        x_positions = [i * group_pitch + 1 for i in range(len(x_groups))]
        ax.set_xticks(x_positions)
        ax.set_xticklabels(labels, rotation=45, ha='right')
        
        if legend:
            legend_elements = [plt.Rectangle((0, 0), 1, 1, facecolor=colors[i], 
                                             label=str(hue_group)) 
                               for i, hue_group in enumerate(hue_groups)]
            ax.legend(handles=legend_elements, title=hue, loc='upper left', bbox_to_anchor=(1, 1))
    
    if xlabel is None:
        xlabel = x
    if ylabel is None:
        ylabel = y
        
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    
    if title is not None:
        ax.set_title(title)
    
    ax.grid(True, alpha=0.3)
    # Lay out the figure that owns `ax`; plt.tight_layout() would act on
    # whichever figure happens to be current. The hasattr guard skips
    # containers without tight_layout (e.g. subfigures).
    fig = ax.figure
    if hasattr(fig, 'tight_layout'):
        fig.tight_layout()
    
    return ax

In [None]:
def _plot_boxplot(adata, _gene, groupby='brain_region', hue='Subclass', **kwargs): 
    """Box-plot one gene's expression across cell groups with ``grouped_boxplot``.

    Parameters
    ----------
    adata : AnnData
        Object whose ``.X`` holds the expression values and whose ``.obs``
        holds the grouping columns.
    _gene : str
        Gene name; must be present in ``adata.var_names``.
    groupby : str or None, default 'brain_region'
        ``adata.obs`` column used for the x-axis groups. If None, all cells are
        drawn as a single group labelled ``'all'``.
    hue : str or None, default 'Subclass'
        ``adata.obs`` column used to split each x group into sub-boxes. If
        None, one box is drawn per x group.
    **kwargs
        Forwarded to ``grouped_boxplot`` (e.g. ``palette``), which forwards
        unknown keywords to matplotlib's ``boxplot`` (e.g. ``showfliers``).

    Returns
    -------
    plt.Axes
        The axes returned by ``grouped_boxplot``.

    Raises
    ------
    KeyError
        If ``_gene`` is not in ``adata.var_names``.
    """
    if _gene not in adata.var_names:
        raise KeyError(f"Gene {_gene!r} not found in adata.var_names")

    expr = adata[:, _gene].X
    # .X may be a sparse matrix (has .toarray()) or an already-dense array.
    expr = expr.toarray() if hasattr(expr, 'toarray') else np.asarray(expr)

    df = pd.DataFrame({
        # grouped_boxplot needs an x column, so without a groupby column every
        # cell is placed in one constant group.
        'group': adata.obs[groupby] if groupby is not None else 'all',
        'hue': adata.obs[hue] if hue is not None else None,
        'value': expr.ravel(),
    })
    h = "hue" if hue is not None else None
    return grouped_boxplot(df, x='group', y='value', hue=h, title=_gene,
                           xlabel=groupby, ylabel='Expression', **kwargs)

In [None]:
"EBF1" in adata.var_names

In [None]:
# EBF1 expression per Subclass across all cells (no hue split).
_plot_boxplot(
    adata,
    "EBF1",
    groupby="Subclass",
    hue=None,
    palette=adata.uns['Subclass_palette'],
    showfliers=False,
)

In [None]:
# CX3CR1 (microglia marker) restricted to the MGM1 region, split by Subclass.
mgm1_cells = adata[adata.obs['brain_region'] == "MGM1"]
_plot_boxplot(
    mgm1_cells,
    "CX3CR1",
    palette=adata.uns['Subclass_palette'],
    showfliers=False,
)

In [None]:
# NOTE(review): "N" does not look like a standard gene symbol. Confirm it is
# present in adata.var_names, or whether a different gene was intended.
_plot_boxplot(
    adata,
    "N",
    groupby="Subclass",
    hue=None,
    palette=adata.uns['Subclass_palette'],
    showfliers=False,
)