In [None]:
import math
import os
from pathlib import Path

import anndata as ad
import numpy as np
import pandas as pd
import scipy.stats
from sklearn.model_selection import train_test_split

# import scanpy as sc
# from statsmodels.stats.multitest import multipletests

import matplotlib.pyplot as plt
import seaborn as sns
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42
plt.rcParams['figure.dpi'] = 300

from datetime import datetime 
current_datetime = datetime.now().strftime("%Y-%m-%d_%H:%M")
image_path = "/home/x-aklein2/projects/aklein/BICAN/BG/images"

np.random.seed(13)

In [None]:
adata_rna = ad.read_h5ad("/home/x-aklein2/projects/aklein/BICAN/data/reference/AIT/AIT_PU.h5ad", backed='r')
adata_rna

In [None]:
max_cells = 10000

In [None]:
levels = ["Class", "Subclass", "Group", "Cluster"]

### split

In [None]:
# Build ONE example subset to experiment with: both loops `break` after their
# first iteration, so only the first cell type of the first level is extracted.
# The cell leaves `_level`, `_cell_type`, `adata_sub` and `expr` in the kernel
# namespace; several later cells read `adata_sub` / `expr`.
for _level in levels: 
    cell_types = adata_rna.obs[_level].unique().tolist()
    for _cell_type in cell_types:
        adata_sub = adata_rna[adata_rna.obs[_level] == _cell_type, :].copy()
        # Keep the raw matrix next to .X as a 'counts' layer.
        # NOTE(review): assumes .raw has the same shape as .X — TODO confirm.
        adata_sub.layers['counts'] = adata_sub.raw.X.copy()
        # Cap the subset at `max_cells` randomly sampled cells.
        if adata_sub.n_obs > max_cells:
            adata_sub = adata_sub[adata_sub.obs.sample(max_cells).index, :].copy()
        # Dense cells x genes matrix; relies on .X exposing .toarray() (sparse).
        expr = adata_sub.X.toarray()
        break
    break

In [None]:
X1, X2 = train_test_split(expr, test_size=0.5, random_state=42)

## Pearson Correlation 
Not working so far! 

In [None]:
X1_mean = np.mean(X1, axis=0)
X2_mean = np.mean(X2, axis=0)

In [None]:
# For each mean-expression cutoff, count the genes that clear it in BOTH halves.
# NOTE(review): this loop is repeated verbatim at the top of the next cell, which
# rebuilds `thr_list` from scratch — this cell looks redundant.
thr_list = []
for _thr in np.arange(0, 5, 0.05): 
    mask = (X1_mean >= _thr) & (X2_mean >= _thr)
    thr_list.append((_thr, np.sum(mask)))
    # X1_mean_thr = X1_mean[mask]
    # X2_mean_thr = X2_mean[mask]


In [None]:
# Sweep a mean-expression cutoff and record how many genes pass it in both
# halves of the split, then draw the resulting curve.
thr_list = []
for _thr in np.arange(0, 5, 0.05):
    mask = np.logical_and(X1_mean >= _thr, X2_mean >= _thr)
    thr_list.append((_thr, mask.sum()))

thr_df = pd.DataFrame(thr_list, columns=["Threshold", "Num_genes"])

fig, ax = plt.subplots(figsize=(4,3), dpi=150)
sns.lineplot(x="Threshold", y="Num_genes", data=thr_df, ax=ax)
# ax.set_yscale("log")
ax.set(xlabel="Expression threshold", ylabel="Number of genes")
fig.tight_layout()
plt.show()

In [None]:
scipy.stats.pearsonr(X1_mean, X2_mean)

In [None]:
X1_sum = np.sum(X1, axis=0)
X2_sum = np.sum(X2, axis=0)
scipy.stats.pearsonr(X1_sum, X2_sum)

In [None]:
# NOTE(review): leftover scratch call. `sf` is invoked without any arguments, so
# this cell is expected to raise a TypeError when run — TODO: pass the test
# statistic and degrees of freedom, or delete the cell.
scipy.stats.t.sf()

In [None]:
# Sweep a summed-count cutoff. At each cutoff record
#   (a) how many genes clear it in BOTH halves of the split, and
#   (b) the Pearson correlation between the two halves over those genes.
thr_list = []
corr = []
for _thr in np.arange(0, 1000, 1):
    mask = (X1_sum >= _thr) & (X2_sum >= _thr)
    n_kept = int(np.sum(mask))
    thr_list.append((_thr, n_kept))
    X1_sum_thr = X1_sum[mask]
    X2_sum_thr = X2_sum[mask]
    if n_kept < 2:
        # pearsonr needs at least two points; keep the sweep going with NaN
        pcorr = np.nan
    else:
        # index instead of unpacking into `_`, which is IPython's "last output"
        pcorr = scipy.stats.pearsonr(X1_sum_thr, X2_sum_thr)[0]
    corr.append((pcorr, _thr))
thr_df = pd.DataFrame(thr_list, columns=["Threshold", "Num_genes"])
corr_df = pd.DataFrame(corr, columns=["Pearson_corr", "Threshold"])

fig, axes = plt.subplots(1, 2, figsize=(8,3), dpi=150)

# Left panel: number of genes surviving each cutoff
ax = axes[0]
sns.lineplot(data=thr_df, x="Threshold", y="Num_genes", ax=ax)
# ax.set_yscale("log")
ax.set_xlabel("Expression threshold")
ax.set_ylabel("Number of genes")

# Right panel: correlation between the halves over the surviving genes
ax = axes[1]
sns.lineplot(data=corr_df, x="Threshold", y="Pearson_corr", ax=ax)
# ax.set_yscale("log")
ax.set_ylim(0, 1)
ax.set_xlabel("Expression threshold")
ax.set_ylabel("Pearson correlation")
fig.tight_layout()
plt.show()

In [None]:
(X1_sum == 0).sum(), (X2_sum == 0).sum()

In [None]:
X2.shape

In [None]:
X1.shape

## PCA loadings


In [None]:
# Flag highly variable genes on the 'counts' layer of `adata_sub`, requesting 5000.
# `subset=False` is passed, and the 'highly_variable' column of `adata_sub.var`
# is read right below (and again by later cells).
# NOTE(review): `sc` is presumably scanpy; its import at the top of the notebook
# is commented out, so this cell raises NameError on a fresh kernel until that
# import is re-enabled.
sc.experimental.pp.highly_variable_genes(
    adata_sub, n_top_genes=5000, layer='counts', subset=False,
)
n_hvgs = int(adata_sub.var['highly_variable'].sum())
print(f"Number of highly variable genes: {n_hvgs}")

In [None]:
adata_pca = adata_sub[:, adata_sub.var['highly_variable']].copy()
sc.pp.scale(adata_pca, max_value=10)
sc.tl.pca(adata_pca, n_comps=100, svd_solver='arpack')

In [None]:
# Per-PC explained-variance ratios and their running total.
explained = adata_pca.uns['pca']['variance_ratio']
cum_expl = np.cumsum(explained)


def _n_pcs_to_reach(cumulative, target):
    """Return how many leading PCs are needed for `cumulative` to reach `target`.

    Returns None when even the last computed PC does not reach the target; the
    previous `(cumulative < target).sum() + 1` silently reported `len + 1` then.
    """
    # number of entries strictly below the target (cumulative is non-decreasing)
    n_below = int(np.searchsorted(cumulative, target, side="left"))
    return n_below + 1 if n_below < len(cumulative) else None


n_pc_80 = _n_pcs_to_reach(cum_expl, 0.8)
n_pc_90 = _n_pcs_to_reach(cum_expl, 0.9)
for _pct, _n_pc in ((80, n_pc_80), (90, n_pc_90)):
    if _n_pc is None:
        print(
            f"Number of PCs to explain {_pct}% variance: more than {len(cum_expl)} "
            f"(the computed PCs explain only {cum_expl[-1]:.1%})"
        )
    else:
        print(f"Number of PCs to explain {_pct}% variance: {_n_pc}")

In [None]:
# Per-gene PCA score: for every row of the loadings matrix, sum the squared
# loadings weighted by the explained-variance ratio of the matching component.
loadings = adata_pca.varm['PCs']
# one weight per loadings column (explained ratios truncated to that many PCs)
weights = explained[:loadings.shape[1]]
# weights broadcast across rows; summing over axis 1 collapses the PC axis
gene_scores = (loadings**2 * weights).sum(axis=1)

In [None]:
gene_df = pd.DataFrame({
    "gene" : adata_pca.var_names, 
    "gene_score": gene_scores
}).sort_values("gene_score", ascending=False)

In [None]:
top_frac = 0.10
k = max(1, int(top_frac * gene_df.shape[0]))
variable_genes_pca = gene_df.head(k)
print(f"[PCA approach] #genes in top {int(top_frac*100)}% by PCA variance score: {k}")

In [None]:
gene_df

## SCVI NB model

In [None]:
import numpy as np
import pandas as pd
import scvi
from anndata import AnnData
from typing import Optional, Dict, Any

def _stratified_quotas(counts: np.ndarray, n_target: int) -> np.ndarray:
    """Split `n_target` draws across strata proportionally to `counts`.

    Largest-remainder rounding makes the quotas add up to `n_target`; each quota
    is then capped at the number of cells available in its stratum.
    """
    exact = counts / counts.sum() * n_target
    quotas = np.floor(exact).astype(int)
    remainder = int(n_target - quotas.sum())
    if remainder > 0:
        # hand the leftover draws to the strata with the largest fractional parts
        order = np.argsort(-(exact - quotas))
        quotas[order[:remainder]] += 1
    return np.minimum(quotas, counts)


def nb_within_cluster_variability(
    adata: AnnData,
    cluster_key: str,
    cluster_value: str,
    *,
    counts_layer: str = "counts",
    batch_key: Optional[str] = None,
    model: Optional[scvi.model.SCVI] = None,
    max_epochs: int = 100,
    n_latent: int = 10,
    top_frac: float = 0.10,
    robust_call: bool = True,
    seed: int = 0,
    return_residuals: bool = False,
    max_cells_per_cluster: Optional[int] = None,
    stratify_by: Optional[str] = None,   # e.g., "batch" to preserve batch proportions
    min_cells_required: int = 20,        # sanity check for tiny clusters after downsampling
) -> Dict[str, Any]:
    """
    Estimate within-cluster gene variability under a Negative Binomial model using scVI,
    with optional downsampling of the target cluster.

    The cells of one cluster are selected (and optionally downsampled), an scVI model
    with an NB likelihood is trained on `adata` (or a supplied `model` is reused), and
    genes are ranked by the per-gene variance of their Pearson residuals
    ``(x - mu) / sqrt(mu + mu**2 / theta)`` across the selected cells.

    Parameters
    ----------
    adata : AnnData
        Input data. If `counts_layer` is missing from `adata.layers`, it is created
        in place from a copy of `adata.X`.
    cluster_key : str
        Column of `adata.obs` holding the cluster labels.
    cluster_value : str
        Label of the cluster to analyse (labels are compared as strings).
    counts_layer : str
        Layer used as the count matrix for scVI and for the residuals.
    batch_key : str or None
        Passed to `scvi.model.SCVI.setup_anndata` when a model is trained here.
    model : scvi.model.SCVI or None
        Pre-trained model to reuse; it is assumed to be aligned with `adata`.
        When None, a model is trained on a copy of `adata`.
    max_epochs, n_latent : int
        Training settings, used only when `model` is None.
    top_frac : float
        Fraction of genes reported in `top_fraction_genes` (at least one gene).
    robust_call : bool
        Also report genes whose residual variance exceeds
        ``max(1, median + 3 * 1.4826 * MAD)``.
    seed : int
        Seed for the downsampling generator and, when training, `scvi.settings.seed`.
    return_residuals : bool
        Include the cells x genes residual matrix in the result.
    max_cells_per_cluster : int or None
        If set and the cluster has more cells than this, downsample without replacement
        to this target size (optionally stratified by `stratify_by`).
    stratify_by : str or None
        Column in `adata.obs` used to stratify the downsampling (preserves proportions).
        Only applies within the target cluster.
    min_cells_required : int
        Raise an error if the final cluster size is below this threshold.

    Returns
    -------
    dict
        Keys: ``cluster_id``, ``gene_table`` (genes sorted by descending residual
        variance), ``n_genes``, ``top_fraction_genes``, ``n_top_fraction``,
        ``final_n_cells``, ``downsampled``, ``sampled_cell_indices``, ``stratify_by``,
        ``max_cells_per_cluster``; plus ``robust_genes``, ``n_robust``,
        ``robust_threshold``, ``median_res_var``, ``mad_res_var`` when `robust_call`
        is set, and ``residuals`` when `return_residuals` is set.

    Raises
    ------
    KeyError
        If `cluster_key` (or a needed `stratify_by`) is not a column of `adata.obs`.
    ValueError
        If no cell matches `cluster_value`, or fewer than `min_cells_required`
        cells remain after downsampling.
    """
    # -------------------------
    # Select the cluster cells FIRST: every check below is cheap, so a typo in the
    # arguments fails before any model training instead of after it.
    # -------------------------
    if cluster_key not in adata.obs.columns:
        raise KeyError(f"cluster_key '{cluster_key}' not found in adata.obs")

    mask = (adata.obs[cluster_key].astype(str) == str(cluster_value)).values
    if mask.sum() == 0:
        raise ValueError(f"No cells found for {cluster_key} == '{cluster_value}'")

    cluster_idx = np.where(mask)[0]
    rng = np.random.default_rng(seed)

    if max_cells_per_cluster is not None and len(cluster_idx) > max_cells_per_cluster:
        if stratify_by is None:
            # Simple random sample without replacement
            sampled_idx = rng.choice(cluster_idx, size=max_cells_per_cluster, replace=False)
        else:
            if stratify_by not in adata.obs.columns:
                raise KeyError(f"stratify_by '{stratify_by}' not found in adata.obs")

            # Stratum label of every cell in the cluster (aligned with cluster_idx)
            labels = adata.obs[stratify_by].astype(str).values[cluster_idx]
            unique, counts = np.unique(labels, return_counts=True)
            quotas = _stratified_quotas(counts, max_cells_per_cluster)
            picks = [
                rng.choice(cluster_idx[labels == lab], size=q, replace=False)
                for lab, q in zip(unique, quotas)
                if q > 0
            ]
            sampled_idx = np.concatenate(picks) if picks else np.array([], dtype=int)

        final_idx = np.sort(sampled_idx)
    else:
        final_idx = cluster_idx

    if final_idx.size < min_cells_required:
        raise ValueError(
            f"Final cluster size ({final_idx.size}) is below min_cells_required={min_cells_required}."
        )

    # -------------------------
    # Ensure counts layer (falls back to .X; this modifies the caller's AnnData)
    # -------------------------
    if counts_layer not in adata.layers:
        adata.layers[counts_layer] = adata.X.copy()

    # -------------------------
    # Train or reuse scVI model
    # -------------------------
    if model is None:
        adata_nb = adata.copy()
        scvi.settings.seed = seed
        scvi.model.SCVI.setup_anndata(adata_nb, layer=counts_layer, batch_key=batch_key)
        model = scvi.model.SCVI(adata_nb, n_layers=2, n_latent=n_latent, gene_likelihood="nb")
        model.train(max_epochs=max_epochs, early_stopping=True, plan_kwargs={"lr": 1e-3})
    else:
        adata_nb = adata  # assume model already aligned to this AnnData

    adata_c = adata_nb[final_idx].copy()

    # Observed counts (cells x genes)
    X = adata_c.layers[counts_layer]
    if hasattr(X, "toarray"):
        X = X.toarray()
    X = np.asarray(X).astype(np.float32, copy=False)

    # -------------------------
    # NB parameters from scVI
    # -------------------------
    lik = model.get_likelihood_parameters(adata=adata_c)
    # NOTE(review): scvi-tools documents the returned keys as "mean" and
    # "dispersions"; "mu"/"theta" (what this code originally read) are kept as a
    # fallback. Confirm against the installed scvi-tools version.
    mu = np.asarray(lik["mean"] if "mean" in lik else lik["mu"])
    theta = np.asarray(lik["dispersions"] if "dispersions" in lik else lik["theta"])
    if theta.ndim == 1:
        # one dispersion per gene -> broadcast over cells
        theta = theta[None, :]

    # -------------------------
    # Pearson residuals & per-gene variance
    # -------------------------
    den = np.sqrt(mu + (mu ** 2) / theta) + 1e-8
    residuals = (X - mu) / den
    res_var = np.asarray(residuals.var(axis=0, ddof=1)).ravel()

    gene_table = pd.DataFrame({
        "gene": adata_c.var_names,
        "residual_variance": res_var
    }).sort_values("residual_variance", ascending=False).reset_index(drop=True)

    # Top-fraction call
    n_genes = gene_table.shape[0]
    k = max(1, int(round(top_frac * n_genes)))
    top_fraction_genes = gene_table.head(k).copy()

    result: Dict[str, Any] = {
        "cluster_id": str(cluster_value),
        "gene_table": gene_table,
        "n_genes": n_genes,
        "top_fraction_genes": top_fraction_genes,
        "n_top_fraction": k,
        "final_n_cells": int(adata_c.n_obs),
        "downsampled": bool(final_idx.size < cluster_idx.size),
        "sampled_cell_indices": final_idx,  # positional row indices into `adata`
        "stratify_by": stratify_by,
        "max_cells_per_cluster": max_cells_per_cluster,
    }

    if robust_call:
        med = float(np.median(res_var))
        mad = float(np.median(np.abs(res_var - med)) + 1e-8)
        thr = med + 3 * 1.4826 * mad
        thr = float(max(1.0, thr))  # NB baseline variance ~1
        # Filter on the table's own column: `gene_table` is sorted and re-indexed,
        # so a mask built from `res_var` (original gene order) would pick the
        # wrong rows.
        robust_genes = gene_table.loc[gene_table["residual_variance"] > thr].copy()
        result.update({
            "robust_genes": robust_genes,
            "n_robust": robust_genes.shape[0],
            "robust_threshold": thr,
            "median_res_var": med,
            "mad_res_var": mad,
        })

    if return_residuals:
        result["residuals"] = residuals  # cells x genes (after any downsampling)

    return result


In [None]:
adata_rna.layers['counts'] = adata_rna.raw.X.copy()

In [None]:
adata_rna.obs[levels[0]].unique()[0]

In [None]:
# One-time scVI training (reuse model across clusters if you want)
# If you already have raw counts in adata.layers["counts"], you're set.
res = nb_within_cluster_variability(
    adata_rna,
    cluster_key=levels[0],
    cluster_value=adata_rna.obs[levels[0]].unique()[0],
    counts_layer="counts",   # or whatever your raw counts layer is
    batch_key=None,          # e.g., "batch" if present
    max_epochs=100,
    max_cells_per_cluster=2000,
    top_frac=0.10,
    robust_call=True,
)

print("Cluster:", res["cluster_id"])
print("# genes (top 10%):", res["n_top_fraction"])
print("# genes (robust):", res.get("n_robust", "n/a"))
res["gene_table"].head()


In [None]:
adata_sub.var['highly_variable'].sum()

In [None]:
adata_sub.raw.X

## Dispersion based approach

In [None]:
output_path = Path("/home/x-aklein2/projects/aklein/BICAN/BG/data/cluster_dispersion")
output_path.mkdir(parents=True, exist_ok=True)

In [None]:
from sklearn.svm import SVR
def calculate_hvf_svr(adata, max_cells = 50000, min_cells = 20, random_state = None):
    """Score genes by how far their dispersion lies above an SVR mean-dispersion trend.

    For every gene the dispersion (variance / mean of `adata.X`) is computed, an
    RBF support-vector regression of log2(dispersion) on log2(mean) is fitted
    across genes, and the residual of that fit is reported as the gene's score.

    Parameters
    ----------
    adata : AnnData
        Cells x genes object; `adata.X` may be sparse or dense.
    max_cells : int
        If `adata` has more cells, a random sample of this size is used.
    min_cells : int
        Minimum number of cells required (checked after any downsampling).
    random_state : int, numpy Generator/RandomState or None
        Forwarded to `DataFrame.sample` for the downsampling. None (default)
        keeps the previous behaviour of drawing from NumPy's global state.

    Returns
    -------
    pandas.DataFrame
        Columns ``gene``, ``svr_score`` and ``dispersion``, sorted by descending
        ``svr_score``.

    Raises
    ------
    ValueError
        If fewer than `min_cells` cells are available.
    """
    if adata.n_obs > max_cells:
        adata = adata[adata.obs.sample(max_cells, random_state=random_state).index, :].copy()
    if adata.n_obs < min_cells:
        raise ValueError(f"Not enough cells ({adata.n_obs}) to calculate HVF.")

    # Densify only when needed: a dense .X has no .toarray()
    expr = adata.X
    if hasattr(expr, "toarray"):
        expr = expr.toarray()
    expr = np.asarray(expr)

    expr_mean = np.mean(expr, axis=0)
    expr_var = np.var(expr, axis=0, ddof=1)
    # small epsilons keep the ratio and the logs finite for unexpressed genes
    dispersion = expr_var / (expr_mean + 1e-8)
    log2_disp = np.log2(dispersion + 1e-8)
    log2_expr_mean = np.log2(expr_mean + 1e-8)
    # single-feature design matrix: one row per gene
    X = log2_expr_mean.reshape(-1, 1)

    svr_gamma = 1000 / X.shape[0]
    svr = SVR(kernel='rbf', C=1.0, gamma=svr_gamma)
    svr.fit(X, log2_disp)

    # residual above the fitted trend
    score = log2_disp - svr.predict(X)

    hvf_df = pd.DataFrame({
        "gene": adata.var_names,
        "svr_score": score,
        "dispersion": dispersion
    }).sort_values("svr_score", ascending=False)

    return hvf_df

In [None]:
# hvf_df = calculate_hvf_svr(adata_rna)
# fout = output_path / f"hvf_svr_all.csv"
# hvf_df.to_csv(fout)

In [None]:
# for _level in levels:
#     for _cell_type in adata_rna.obs[_level].unique().tolist():
#         print(f"Level: {_level}, Cell type: {_cell_type}, #cells: {(adata_rna.obs[_level] == _cell_type).sum()}")
#         try: 
#             hvf_df = calculate_hvf_svr(adata_rna[adata_rna.obs[_level] == _cell_type, :])
#         except ValueError as e:
#             print(f"  Skipping {_cell_type} due to error: {e}")
#             continue
#         out_ct = _cell_type.replace("/", "_").replace(" ", "_")
#         fout = output_path / f"hvf_svr_{_level}_{out_ct}.csv"
#         hvf_df.to_csv(fout) 

In [None]:
### TODO:
# - Load the results at the Class, Subclass, Group, and Cluster levels.
# - For each gene, compute the ratio of both its dispersion and its score to the dispersion / score calculated on the entire dataset.
# - For each gene, express the difference between the new and the original value on a min-max scale where the min is 0 and the max is the original value.
# - (Account for edge cases where the dispersion was 0 in both the original and the new calculation!)
# - For each threshold in np.arange(0, 2, 0.05), count, per cell type, the genes that are above that threshold.
# - Plot a histogram of those counts, colored by level (one histogram per threshold).

In [None]:
# hvf_df_all = pd.read_csv(output_path / "hvf_svr_all.csv", index_col=0)
# hvf_df_all = hvf_df_all.set_index("gene")
# for _level in levels:
#     for _cell_type in adata_rna.obs[_level].unique().tolist():
#         out_ct = _cell_type.replace("/", "_").replace(" ", "_")
#         try: 
#             hvf_df = pd.read_csv(output_path / f"hvf_svr_{_level}_{out_ct}.csv", index_col=0)
#             hvf_df = hvf_df.set_index("gene")
#         except FileNotFoundError as e:
#             print(f"  Skipping {_cell_type} due to error: {e}")
#             continue
#         hvf_df = hvf_df.join(hvf_df_all, lsuffix="_sub", rsuffix="_all", how="inner")
#         hvf_df['dispersion_ratio'] = hvf_df['dispersion_sub'] / (hvf_df['dispersion_all'] + 1e-8)
#         hvf_df['score_ratio'] = hvf_df['svr_score_sub'] / (hvf_df['svr_score_all'] + 1e-8)
#         # min-max scaling of the ratios 
#         # hvf_df['dispersion_ratio_mm'] = (hvf_df['dispersion_ratio'] - hvf_df['dispersion_ratio'].min()) / (hvf_df['dispersion_ratio'].max() - hvf_df['dispersion_ratio'].min() + 1e-8)
#         # hvf_df['score_ratio_mm'] = (hvf_df['score_ratio'] - hvf_df['score_ratio'].min()) / (hvf_df['score_ratio'].max() - hvf_df['score_ratio'].min() + 1e-8)
#         out_ct = _cell_type.replace("/", "_").replace(" ", "_")
#         fout = output_path / f"hvf_svr_ratios_{_level}_{out_ct}.csv"
#         hvf_df.to_csv(fout)

In [None]:
# For every level and cell type, load the saved ratio table and count how many
# genes exceed each ratio threshold, for the dispersion ratio and the score ratio.
# Result: all_levels[level][cell_type] -> [(threshold, n_dispersion, n_score), ...]
ratio_thresholds = np.arange(0.05, 2.05, 0.05)

all_levels = {}
for _level in levels:
    level_specific = {}
    for _cell_type in adata_rna.obs[_level].unique().tolist():
        out_ct = _cell_type.replace("/", "_").replace(" ", "_")
        ratios_csv = output_path / f"hvf_svr_ratios_{_level}_{out_ct}.csv"
        try:
            hvf_df = pd.read_csv(ratios_csv, index_col=0)
        except FileNotFoundError as e:
            # cell types without a saved table are reported and left out
            print(f"  Skipping {_cell_type} due to error: {e}")
            continue
        thr_list = [
            (
                thr,
                (hvf_df['dispersion_ratio'] > thr).sum(),
                (hvf_df['score_ratio'] > thr).sum(),
            )
            for thr in ratio_thresholds
        ]
        level_specific[_cell_type] = thr_list
    all_levels[_level] = level_specific

In [None]:
def plot_level_histograms_grid(data, mode='D'):
    """
    Draw one subplot per threshold, each overlaying one density histogram per level.

    Parameters
    ----------
    data : dict
        ``{level: {cell_type: [(threshold, count_dispersion, count_score), ...]}}``.
    mode : str
        'D' plots the dispersion counts; any other value plots the score counts.

    Each histogram shows, for one level, the distribution across cell types of the
    count recorded at that threshold (density, not frequency). Levels with no
    value at a threshold are skipped in that subplot. Nothing is drawn (a message
    is printed) when `data` contains no thresholds at all.
    """
    # Regroup once into {level: {threshold: [count per cell type]}} instead of
    # rescanning every cell type's list for every threshold.
    counts_by_level = {}
    for level, per_cell_type in data.items():
        by_threshold = counts_by_level.setdefault(level, {})
        for vals in per_cell_type.values():
            seen = set()
            for t, count_d, count_s in vals:
                if t in seen:
                    # only the first entry per threshold counts for a cell type
                    continue
                seen.add(t)
                by_threshold.setdefault(t, []).append(count_d if mode == 'D' else count_s)

    thresholds = sorted({t for by_threshold in counts_by_level.values() for t in by_threshold})
    n = len(thresholds)
    if n == 0:
        # avoids a ZeroDivisionError in the grid-size computation below
        print("No thresholds found in data; nothing to plot.")
        return

    # near-square grid with enough cells for every threshold
    ncols = math.ceil(math.sqrt(n))
    nrows = math.ceil(n / ncols)

    fig, axes = plt.subplots(nrows, ncols, figsize=(4*ncols, 4*nrows), squeeze=False)
    for idx, threshold in enumerate(thresholds):
        row, col = divmod(idx, ncols)
        ax = axes[row][col]
        plotted = False
        for level_pos, (level, by_threshold) in enumerate(counts_by_level.items()):
            values = by_threshold.get(threshold, [])
            if not values:
                # an empty histogram has no defined density
                continue
            # explicit colour keeps each level's colour stable across subplots
            ax.hist(values, bins=20, alpha=0.5, label=level, density=True, color=f"C{level_pos}")
            plotted = True
        ax.set_title(f'Threshold {threshold:.2f}')
        ax.set_xlabel('Count')
        ax.set_ylabel('Density')
        if plotted:
            ax.legend()
    # Hide unused subplots
    for idx in range(n, nrows * ncols):
        row, col = divmod(idx, ncols)
        fig.delaxes(axes[row][col])
    plt.tight_layout()
    plt.show()

# Example usage:
# plot_level_histograms_grid(all_levels, mode='D')  # For dispersion
# plot_level_histograms_grid(all_levels, mode='S')  # For score


In [None]:
def calculate_median_counts_per_threshold(data, mode='D'):
    """
    For each threshold, calculate the median count across cell types for each level.

    Parameters
    ----------
    data : dict
        ``{level: {cell_type: [(threshold, count_dispersion, count_score), ...]}}``.
    mode : str
        'D' uses the dispersion counts; any other value uses the score counts.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``Threshold`` with one column per level (in `data` order).
        A level with no value at a threshold gets NaN. Empty `data` yields an
        empty frame.
    """
    levels = list(data.keys())

    # Regroup once into {level: {threshold: [count per cell type]}}.
    counts_by_level = {}
    for level in levels:
        by_threshold = {}
        for vals in data[level].values():
            seen = set()
            for t, count_d, count_s in vals:
                if t in seen:
                    # only the first entry per threshold counts for a cell type
                    continue
                seen.add(t)
                by_threshold.setdefault(t, []).append(count_d if mode == 'D' else count_s)
        counts_by_level[level] = by_threshold

    thresholds = sorted({t for by_threshold in counts_by_level.values() for t in by_threshold})

    thr_lists = []
    for threshold in thresholds:
        medians = [
            np.median(counts_by_level[level][threshold])
            if threshold in counts_by_level[level]
            else np.nan
            for level in levels
        ]
        thr_lists.append((threshold, *medians))
    df_meds = pd.DataFrame(thr_lists, columns=["Threshold"] + levels).set_index("Threshold")
    return df_meds
        

In [None]:
df_meds_D = calculate_median_counts_per_threshold(all_levels, mode='D')  # For dispersion
df_meds_S = calculate_median_counts_per_threshold(all_levels, mode='S')  # For score

In [None]:
fig, ax = plt.subplots(figsize=(6,4), dpi=150)
sns.lineplot(data=df_meds_D, ax=ax)
ax.set_xlabel("Thresholded Dispersion Ratio")
ax.set_ylabel("Median count")
ax.set_title("Median Gene Counts Above Dispersion Ratio Threshold by Level")
fig.tight_layout()
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(6,4), dpi=150)
sns.lineplot(data=df_meds_S, ax=ax)
ax.set_xlabel("Thresholded Score Ratio")
ax.set_ylabel("Median count")
ax.set_title("Median Gene Counts Above Score Ratio Threshold by Level")
fig.tight_layout()
plt.show()

### Specific Gene Examples

In [None]:
from rich import print as rprint

In [None]:
# Pick a lineage: 
class_vc = adata_rna.obs.loc[adata_rna.obs['Class'] == "CN LGE GABA", 'Subclass'].value_counts()
rprint(class_vc[class_vc > 0].index.tolist())
subclass_vc = adata_rna.obs.loc[adata_rna.obs['Subclass'] == "STR D2 MSN", 'Group'].value_counts()
rprint(subclass_vc[subclass_vc > 0].index.tolist())
group_vc = adata_rna.obs.loc[adata_rna.obs['Group'] == "STRd D2 Matrix MSN", 'Cluster'].value_counts()
rprint(group_vc[group_vc > 0])

In [None]:
def get_gene_expr_for_celltype(
    adata : ad.AnnData, 
    level : str | list[str],
    cell_type : str | list[str],
    gene : str | list[str] | None = None,
    layer : str | None = None,
    max_cells : int = 10000,
): 
    """Extract dense expression matrices for one or more (level, cell type) pairs.

    Parameters
    ----------
    adata : AnnData
        Source object; each subset is loaded with `.to_memory()`.
    level, cell_type : str or list of str
        Parallel lists: `level[i]` is the `adata.obs` column in which
        `cell_type[i]` is looked up. Single strings are treated as one pair.
    gene : str, list of str or None
        Genes to keep. Requested genes that are not in `adata.var_names` are
        reported and left out. None keeps all genes.
    layer : str or None
        Layer to read; None reads `.X`.
    max_cells : int
        Subsets with more cells are randomly downsampled to this many.

    Returns
    -------
    (expr_list, gene_list)
        `expr_list[i]` is the dense cells x genes matrix of pair `i`.
        `gene_list` names the matrix columns IN COLUMN ORDER (the order of
        `adata.var_names`, which may differ from the order of `gene`).

    Raises
    ------
    ValueError
        If `level` and `cell_type` differ in length, are empty, or a pair
        matches no cells.
    KeyError
        If a level is not a column of `adata.obs`.
    """
    if isinstance(gene, str):
        gene = [gene]
    if isinstance(level, str):
        level = [level]
    if isinstance(cell_type, str):
        cell_type = [cell_type]
    if len(level) != len(cell_type):
        raise ValueError("Length of level and cell_type must be the same.")
    if len(level) == 0:
        raise ValueError("At least one level / cell_type pair is required.")
    if gene is not None:
        missing = [g for g in gene if g not in adata.var_names]
        if missing:
            # make the silent drop of unknown genes visible
            print(f"Genes not found in adata.var_names (skipped): {missing}")

    expr_list = []
    for lvl, ct in zip(level, cell_type):
        if lvl not in adata.obs.columns:
            raise KeyError(f"Level '{lvl}' not found in adata.obs")
        adata_sub = adata[adata.obs[lvl] == ct].to_memory().copy()
        n_cells = adata_sub.shape[0]
        print(f"Level: {lvl}, Cell type: {ct}, #cells: {n_cells}")
        if n_cells == 0:
            raise ValueError(f"No cells found for the specified level(s) and cell_type(s).")

        # Row selection: random downsample only when over the cap.
        row_sel = adata_sub.obs.sample(max_cells).index if n_cells > max_cells else slice(None)
        # Column selection: boolean mask, so columns keep var_names order.
        col_sel = adata_sub.var_names.isin(gene) if gene is not None else slice(None)
        adata_sub = adata_sub[row_sel, col_sel].copy()

        expr = adata_sub.X if layer is None else adata_sub.layers[layer]
        if hasattr(expr, "toarray"):
            # sparse -> dense; a dense matrix is used as is
            expr = expr.toarray()
        expr_list.append(expr)

    # Report the names that actually label the matrix columns, in column order.
    # Returning the caller's `gene` list here mislabeled columns whenever its
    # order differed from var_names or a requested gene was absent.
    gene_list = adata_sub.var_names if gene is None else adata_sub.var_names.tolist()
    return expr_list, gene_list
    # expr_mean = np.mean(expr, axis=0)
    # expr_var = np.var(expr, axis=0, ddof=1)
    # dispersion = expr_var / (expr_mean + 1e-8)
    # df = pd.DataFrame({
    #     "gene": adata_sub.var_names,
    #     "mean_expr": expr_mean,
    #     "variance": expr_var,
    #     "dispersion": dispersion
    # }).set_index("gene")
    # return df

In [None]:
# expr = get_gene_expr_for_celltype(
#     adata_rna, 
#     level = "Subclass",
#     cell_type = "STR D2 MSN",
#     gene = "DRD2",
#     max_cells=10000,
# )

In [None]:
def get_topn_dispersion_genes(
    level: str | list[str], 
    cell_type: str | list[str],
    n: int = 100,
    data_path = Path("/home/x-aklein2/projects/aklein/BICAN/BG/data/cluster_dispersion")
):
    if isinstance(level, str):
        level = [level]
    if isinstance(cell_type, str):
        cell_type = [cell_type]
    if len(level) != len(cell_type):
        raise ValueError("Length of level and cell_type must be the same.")
    dfs = []
    for l, ct in zip(level, cell_type):
        out_ct = ct.replace("/", "_").replace(" ", "_")
        try: 
            hvf_df = pd.read_csv(data_path / f"hvf_svr_ratios_{l}_{out_ct}.csv", index_col=0)
            # hvf_df = hvf_df.set_index("gene")
        except FileNotFoundError as e:
            print(f"  Skipping {ct} due to error: {e}")
            continue
        hvf_df_top = hvf_df.sort_values("svr_score_sub", ascending=False).head(n)
        hvf_df_top['category'] = 'top'
        hvf_df_bottom = hvf_df.sort_values("svr_score_sub", ascending=True).head(n)
        hvf_df_bottom['category'] = 'bottom'
        hvf_df = pd.concat([hvf_df_top, hvf_df_bottom], axis=0)
        hvf_df = hvf_df.reset_index().copy()
        hvf_df['level'] = l
        hvf_df['cell_type'] = ct
        dfs.append(hvf_df)
    if len(dfs) == 0:
        raise ValueError("No dataframes to concatenate.")
    df_all = pd.concat(dfs, axis=0)
    return df_all

In [None]:
df = get_topn_dispersion_genes(
    level = ["Class", "Subclass", "Group"],
    cell_type = ["CN LGE GABA", "STR D2 MSN", "STRd D2 Matrix MSN"],
    n = 100,
)

In [None]:
df.loc[df['level'] == 'Class'].sort_values("svr_score_sub", ascending=False)

In [None]:
expr_list, var_names = get_gene_expr_for_celltype(
    adata_rna, 
    level = ["Class", "Subclass", "Group", "Cluster"],
    cell_type = ["CN LGE GABA", "STR D2 MSN", "STRd D2 Matrix MSN", "Human-84"],
    max_cells=50000, 
    gene=["GPC5", "LRRC7", "XIST", "CHRM3", "FTX"]
)

In [None]:
class_expr = expr_list[0]
subclass_expr = expr_list[1]
group_expr = expr_list[2]
cluster_expr = expr_list[3]

In [None]:
class_expr.max(axis=0), subclass_expr.max(axis=0), group_expr.max(axis=0), cluster_expr.max(axis=0)

In [None]:
# One figure per gene, overlaying its expression distribution at each level of
# the chosen lineage. Column `_id` of every matrix is taken to be `_gene`.
_lineage_expr = [
    ("Class: CN LGE GABA", class_expr),
    ("Subclass: STR D2 MSN", subclass_expr),
    ("Group: STRd D2 Matrix MSN", group_expr),
    ("Cluster: Human-84", cluster_expr),
]
for _id, _gene in enumerate(var_names):
    fig, ax = plt.subplots(figsize=(6,4), dpi=150)
    for _label, _expr in _lineage_expr:
        sns.histplot(_expr[:, _id].flatten(), label=_label, ax=ax, stat="probability", alpha=0.5, binwidth=0.1)
    ax.set_ylim((0, 1))
    ax.set_xlabel("Expression")
    # the histograms use stat="probability", so label the axis accordingly
    ax.set_ylabel("Probability")
    ax.set_title(f"Expression Distribution of {_gene} Across Cell Types")
    ax.legend()
    plt.show()
    # release the figure; otherwise every gene's figure stays open in memory
    plt.close(fig)

In [None]:
expr

## Call DEGs

In [None]:
indices = np.arange(adata_sub.shape[0])
x1_indices, x2_indices = train_test_split(indices, test_size=0.5, random_state=13)
x1_index = adata_sub.obs.iloc[x1_indices].index
x2_index = adata_sub.obs.iloc[x2_indices].index
adata_sub.obs.loc[x1_index, "split"] = "X1"
adata_sub.obs.loc[x2_index, "split"] = "X2"

In [None]:
from spida.utilities._degs import call_degs_by_celltype, summarize_deg_results, plot_deg_summary, plot_volcano
# def call_degs_by_celltype(
#     adata: ad.AnnData,
#     celltype_col: str,
#     layer: str | None = None,
#     min_cells: int = 10,
#     max_cells: int = 50000,
#     min_genes: int = 100,
#     logfc_threshold: float = 0.25,
#     pval_threshold: float = 0.05,
#     method: str = 'wilcoxon',
#     correction_method: str = 'benjamini-hochberg',
#     n_genes: int | None = None,
#     save_results: bool = True,
#     output_dir: str | None = None,
#     verbose: bool = False
# ) -> dict[str, pd.DataFrame]:


In [None]:
adata_sub = adata_sub[:, adata_sub.var['highly_variable']].copy()

In [None]:
results = call_degs_by_celltype(
    adata_sub, 
    celltype_col='split',
    layer=None,
    min_cells=10,
    max_cells=50000,
    min_genes=100,
    logfc_threshold=0.25,
    pval_threshold=0.05,
    method='wilcoxon',
    correction_method='benjamini-hochberg',
    n_genes=None,
    save_results=False,
    output_dir=None,
    verbose=True
)

In [None]:
results['X1']['significant'].sum()

In [None]:
# Show summary
summary_subclass = summarize_deg_results(results, top_n=10)
print("\nSubclass-level DEG Summary:")
print(summary_subclass)

In [None]:
# Rank genes between the groups of the 'split' column of `adata_sub`.
# NOTE(review): this looks like a scratch copy of code from inside a function.
# `celltype`, `method`, `layer`, `n_genes` and `adata_ds` are not assigned
# anywhere in the visible notebook, so the cell raises NameError unless they are
# defined elsewhere — TODO: define them or remove the cell.
sc.tl.rank_genes_groups(
    adata_sub, 
    groupby='split',
    groups=[celltype],
    reference='others',
    method=method,
    use_raw=False,
    layer=layer,
    n_genes=adata_ds.n_vars if n_genes is None else n_genes,
    tie_correct=True
)

In [None]:
# Rank genes for one group, pull the full result table, adjust the p-values and
# flag significant genes.
# NOTE(review): this looks like a scratch copy of code from inside a function.
# `adata_ds`, `celltype`, `method`, `layer`, `n_genes`, `correction_method`,
# `pval_threshold`, `logfc_threshold` and `deg_results` are not assigned anywhere
# in the visible notebook, and the imports of `sc` and `multipletests` are
# commented out at the top — the cell raises NameError unless they are defined
# elsewhere. TODO: define them or remove the cell.
sc.tl.rank_genes_groups(
    adata_ds,
    groupby='temp_groups',
    groups=[celltype],
    reference='others',
    method=method,
    use_raw=False,
    layer=layer,
    n_genes=adata_ds.n_vars if n_genes is None else n_genes,
    tie_correct=True
)

# Extract results
result = sc.get.rank_genes_groups_df(
    adata_ds, 
    group=celltype,
    pval_cutoff=1.0,  # Get all genes, filter later
    log2fc_min=None   # Get all genes, filter later
)

# Add multiple testing correction
if len(result) > 0:
    if correction_method == 'benjamini-hochberg':
        rejected, pvals_corrected, _, _ = multipletests(
            result['pvals'], 
            method='fdr_bh'
        )
        result['pvals_adj'] = pvals_corrected
    elif correction_method == 'bonferroni':
        rejected, pvals_corrected, _, _ = multipletests(
            result['pvals'], 
            method='bonferroni'
        )
        result['pvals_adj'] = pvals_corrected
    else:
        # any other value: no correction, raw p-values are reused
        result['pvals_adj'] = result['pvals']
    
    # Add significance flags
    # (adjusted p-value below the cutoff AND absolute log fold change above it)
    result['significant'] = (
        (result['pvals_adj'] < pval_threshold) & 
        (abs(result['logfoldchanges']) > logfc_threshold)
    )
    
    # Sort by adjusted p-value and log fold change
    # (significant rows first, then ascending p-value, then descending fold change)
    result = result.sort_values(['significant', 'pvals_adj', 'logfoldchanges'], 
                                ascending=[False, True, False])
    
    # Add cell type information
    result['cell_type'] = celltype
    result['comparison'] = f"{celltype}_vs_others"
    
    deg_results[celltype] = result
    
    n_sig = result['significant'].sum()

In [None]:
adata_sub