In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad

import matplotlib.pyplot as plt
import seaborn as sns
from spida.pl import plot_categorical, plot_continuous, categorical_scatter
# Force a white axes background for every figure drawn in this notebook.
plt.rcParams.update({'axes.facecolor': 'white'})

In [None]:
# Location of the annotated AnnData file. The default is the original
# cluster path; set the BICAN_BG_ADATA_PATH environment variable to point the
# notebook at a copy elsewhere without editing this cell.
adata_path = os.environ.get(
    "BICAN_BG_ADATA_PATH",
    "/home/x-aklein2/projects/aklein/BICAN/BG/data/annotation/BICAN_BG_ALL/BG_pfv8_all.h5ad",
)
adata = ad.read_h5ad(adata_path)
# Bare last expression so the notebook renders the AnnData summary.
adata

In [None]:
def plot_counts(
    adata,
    groupby,
    col
):
    """Print the value counts of ``col`` separately for each ``groupby`` category.

    For every category of ``adata.obs[groupby]`` (in category order) this
    prints the category label, then ``value_counts()`` of ``col`` restricted
    to the observations in that category, then an empty line.

    Parameters
    ----------
    adata
        Object exposing an ``obs`` table and supporting boolean-mask
        indexing over observations (``adata[mask]``).
    groupby
        Name of a categorical column in ``adata.obs`` (``.cat`` is used).
    col
        Name of the ``adata.obs`` column whose values are tallied.

    Returns
    -------
    None
        Output is written to stdout only.
    """
    group_labels = adata.obs[groupby]
    for category in group_labels.cat.categories:
        print(category)
        # Subset through the container itself (not just the obs table).
        members = adata[group_labels == category]
        tally = members.obs[col].value_counts()
        print(tally)
        print()


In [None]:
# Grid of log1p(`col`) distributions: one row per donor, one column per brain
# region, with one overlaid density histogram (plus a dashed median line) per
# replicate inside each panel.
col = "nCount_RNA"
obs = adata.obs
# Drop missing labels so a NaN replicate never reaches the palette lookup.
labs = obs['replicate'].dropna().unique()
lab_palette = adata.uns['replicate_palette']
# Derive the grid shape from the same lists that are iterated below, so the
# loop indices can never run past the axes array.
donors = sorted(obs["donor"].dropna().unique())
regions = sorted(obs["brain_region"].dropna().unique())
nrows = len(donors)
ncols = len(regions)

fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(4*ncols, 4*nrows), dpi=200, squeeze=False)
for i, _donor in enumerate(donors):
    # Filter the obs table once per donor instead of re-masking the whole
    # AnnData for every donor x region x replicate combination.
    donor_obs = obs[obs["donor"] == _donor]
    for j, _region in enumerate(regions):
        ax = axes[i, j]
        panel_obs = donor_obs[donor_obs["brain_region"] == _region]
        n_cells = 0
        for _lab in labs:
            counts = panel_obs.loc[panel_obs["replicate"] == _lab, col]
            if counts.empty:
                # A missing replicate must not blank a panel that still has
                # data from the other replicates.
                continue
            log_counts = np.log1p(counts.to_numpy())
            sns.histplot(log_counts, color=lab_palette[_lab], stat="density", fill=True, element="step", alpha=0.5, edgecolor='k', ax=ax)
            ax.axvline(x=np.median(log_counts), color=lab_palette[_lab], linestyle='--', label=f"{_lab} Median (n={counts.size})")
            n_cells += counts.size
        if n_cells == 0:
            # Nothing at all for this donor/region: hide the empty panel.
            ax.axis('off')
            continue
        # Title reports the total over all plotted replicates; the per-replicate
        # counts are carried by the legend entries.
        ax.set_title(f"{_region} - {_donor}\n(n={n_cells})")
        ax.set_xlabel(f"log1p({col})")  # plotted values are log1p-transformed
        ax.set_ylabel("Density")  # matches stat="density"
        ax.legend()
        # Fixed window (in log1p units) so panels are directly comparable.
        ax.set_xlim((2, 8))

# Lay out once, after every panel has been drawn.
plt.tight_layout()
plt.show()

In [None]:
# Settings for the per-group histogram: the per-cell metric to plot and the
# obs column that defines the groups.
col = "nCount_RNA"
groupby = "dataset_id"
# NOTE(review): the palette lookup below is disabled — presumably the key is
# absent from adata.uns for this grouping; confirm before re-enabling.
# palette = adata.uns[f"{groupby}_palette"]

In [None]:
    
# Overlaid, log-scaled density histograms of `col`, one per `groupby` category.
group_order = list(adata.obs[groupby].cat.categories)

# `palette` has no live definition in the notebook, so resolve it here: prefer
# a label -> colour mapping stored under adata.uns[f"{groupby}_palette"]
# (assumed to follow the same convention as 'replicate_palette' — confirm),
# and fall back to generated colours for any group it does not cover.
generated_colors = dict(zip(group_order, sns.color_palette("husl", n_colors=len(group_order))))
stored_palette = adata.uns.get(f"{groupby}_palette")
if not isinstance(stored_palette, dict):
    stored_palette = {}
palette = {_group: stored_palette.get(_group, generated_colors[_group]) for _group in group_order}

fig, ax = plt.subplots(figsize=(8, 4), dpi=200, constrained_layout=True)

for _group in group_order:
    group_obs = adata.obs[adata.obs[groupby] == _group]
    if group_obs.empty:
        # Declared-but-unused categories have nothing to draw.
        continue
    sns.histplot(
        data=group_obs,
        x=col,
        color=palette[_group],
        label=_group,
        stat="density",
        element="step",
        log_scale=True,
        fill=True,
        alpha=0.5,
        ax=ax
    )
ax.set_title(f"{col} by {groupby}")
ax.set_xlabel(f"{col} (log scale)")
ax.set_ylabel("Density")
ax.legend(title=groupby)
plt.show()