# Using the data from prev scRNAseq datasets to do probe selection

Start by loading prev scRNAseq datasets into the spapros framework using [this tutorial](https://spapros.readthedocs.io/en/latest/_tutorials/spapros_tutorial_basic_selection.html)

## Imports

In [None]:
import scanpy as sc
import spapros as sp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from sklearn.linear_model import LogisticRegression
from pathlib import Path
import iss_analysis.io as io
import iss_analysis.pick_genes as pick
from scipy.sparse import issparse, csr_matrix
from scipy.stats import entropy
import utils as ut

from abc_atlas_access.abc_atlas_cache.abc_project_cache import AbcProjectCache


## Load reference train dataset

We take the raw count data that we generated in the MOs panel notebook. For this, we use the tools we built to use Petr's panel (a bit bad that we need to import the whole iss_analysis for using these simple functions). We need to filter for the main classes of neurons we care about, plus region of interest and neurotransmitters. 

Without the filtering by class, we get:

| Code | Cell type            |
|------|----------------------|
| 01   | IT-ET Glut           |
| 02   | NP-CT-L6b Glut       |
| 03   | OB-CR Glut           |
| 05   | OB-IMN GABA          |
| 06   | CTX-CGE GABA         |
| 07   | CTX-MGE GABA         |
| 08   | CNU-MGE GABA         |
| 09   | CNU-LGE GABA         |
| 30   | Astro-Epen           |
| 32   | OEC                  |


In [None]:
# Root directory of the local ABC atlas cache for this project
download_base = Path("/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel")
abc_cache = AbcProjectCache.from_cache_dir(download_base)

# Display the manifest the cache currently points at
abc_cache.current_manifest

In [None]:


# Subclasses (Yao 2023 taxonomy labels) kept in the training reference
mos_subclasses = [
    "001 CLA-EPd-CTX Car3 Glut",
    "002 IT EP-CLA Glut",
    "003 L5/6 IT TPE-ENT Glut",
    "004 L6 IT CTX Glut",
    "005 L5 IT CTX Glut",
    "006 L4/5 IT CTX Glut",
    "007 L2/3 IT CTX Glut",
    "020 L2/3 IT RSP Glut",
    "022 L5 ET CTX Glut",
    "029 L6b CTX Glut",
    "030 L6 CT CTX Glut",
    "032 L5 NP CTX Glut",
    "046 Vip Gaba",
    "049 Lamp5 Gaba",
    "050 Lamp5 Lhx6 Gaba",
    "047 Sncg Gaba",
    "053 Sst Gaba",
    "056 Sst Chodl Gaba",
    "051 Pvalb chandelier Gaba",
    "052 Pvalb Gaba",
]

# 10Xv3 isocortex cells from MO-FRP, GABA/Glut only, restricted to the subclasses above
reference = io.main_yao_2023(
    abc_cache,
    'WMB-10Xv3',
    ['WMB-10Xv3-Isocortex-2', 'WMB-10Xv3-Isocortex-1'],
    download_base,
    taxa_subclass=mos_subclasses,
    region_of_interest='MO-FRP',
    neurotransmitters=['GABA', 'Glut'],
    extract_csv=False,
)

Now we run the selection with subclass as the prediction target. Yao 2023 reports very little CNU-LGE GABA in the isocortex, but we keep it anyway.

In [None]:
# Distinct cell classes remaining after filtering
set(reference.obs["class"])

In [None]:
# Distinct subclasses remaining after filtering
set(reference.obs["subclass"])

In [None]:
# Cells per subclass, largest first, followed by a TOTAL row
subclass_counts = (
    reference.obs['subclass']
    .value_counts()
    .sort_values(ascending=False)
)

df_subclass = pd.DataFrame(
    {"subclass": subclass_counts.index, "count": subclass_counts.values}
)
grand_total = df_subclass["count"].sum()
df_subclass.loc[len(df_subclass)] = ["TOTAL", grand_total]

print(df_subclass)

In [None]:
# Counts summed over all cells: one total per gene, as a 1-D array
reads_per_transcript = np.asarray(reference.X.sum(axis=0)).ravel()

# Genes ranked by their total counts, highest first
df_transcripts = pd.DataFrame(
    {"gene": reference.var_names, "total_reads": reads_per_transcript}
).sort_values("total_reads", ascending=False)

In [None]:
# Distribution of per-gene total counts (log-scaled y axis)
plt.hist(df_transcripts["total_reads"], log=True, bins=200)

In [None]:
# total reads per cell
reads_per_cell = np.array(reference.X.sum(axis=1))

reference.obs["total_reads"] = reads_per_cell



In [None]:
# Per-cell totals computed in the previous cell
reference.obs["total_reads"]

In [None]:
# Histogram of per-cell totals; the x axis always uses "value x 10^k"
# offset notation so large counts stay readable.
plt.hist(reference.obs.total_reads, bins=200, log=True)
plt.title("Distribution of reads per cell")
plt.xlabel("Reads per cell")
plt.ylabel("Number of cells (log scale)")

sci_fmt = plt.ScalarFormatter(useMathText=True)
sci_fmt.set_scientific(True)
sci_fmt.set_powerlimits((0, 0))  # (0, 0): use the offset/exponent for every range
plt.gca().xaxis.set_major_formatter(sci_fmt)

plt.show()


Next we preprocess the data the same way the spapros tutorial does: normalise, log-transform, and flag the top 1000 highly variable genes (`cell_ranger` flavour) for efficiency.

In [None]:
# Library-size normalisation, log1p, then annotate the top 1000 highly variable genes
sc.pp.normalize_total(reference)
sc.pp.log1p(reference)
sc.pp.highly_variable_genes(reference, flavor="cell_ranger", n_top_genes=1000)


## Run selector

Run vanilla selection

In [None]:
# Spapros selector: 100-gene panel discriminating the 'subclass' labels
selector = sp.se.ProbesetSelector(
    reference,
    n=100,
    celltype_key="subclass",
    verbosity=1,
    n_jobs=-1,
    save_dir="/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/panels/spapros",
)


In [None]:
# Run the spapros selection pipeline
selector.select_probeset()

In [None]:
# Rows of the probeset table that made it into the panel
selector.probeset.loc[selector.probeset["selection"]]


In [None]:

# Attach gene symbols to the selected genes (feature ids are reference.var's index)
lookup = pd.DataFrame(
    {"gene_symbol": reference.var['gene_symbol'], "feature_id": reference.var.index}
)

probeset = selector.probeset
selected = probeset[probeset["selection"]].copy()
selected = selected.merge(
    lookup,
    how="left",
    left_index=True,
    right_on="feature_id",
)

selected

In [None]:
# Widen pandas display limits for inspecting panel tables
pd.set_option("display.max_rows", 100)
pd.set_option("display.max_columns", 10)

## Compare to the greedy algorithm

We need to take this probeset and compare it to the greedy algorithm using the spapros metrics and, if need be, the Znamenskiy metrics. Let's run it again with the new filters. 

```python
df, gene_names = io.main_yao_2023(abc_cache, 
                  'WMB-10Xv3', 
                  ['WMB-10Xv3-Isocortex-2', 'WMB-10Xv3-Isocortex-1'], 
                  download_base,
                  taxa_class = ['01 IT-ET Glut', '06 CTX-CGE GABA', '07 CTX-MGE GABA',  '08 CNU-MGE GABA', '09 CNU-LGE GABA'],
                  region_of_interest = 'MO-FRP', 
                  neurotransmitters = ['GABA', 'Glut'], 
                  extract_csv = True)
```

In [None]:
# Greedy panel selection (pick.main) on the Yao 2023 data under datapath
greedy_options = dict(
    efficiency=0.01,
    datapath='/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel',
    subsample=1,
    classify="subclass",
    dataset='yao_2023',
)
pick.main(
    '/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/panels/greedy_panels',
    **greedy_options,
)

## Evaluate

We now have two alternatives. Now, we need to load them both as probesets in the spapros framework and compare. Ideally, we would also build the DE-only pipeline. 

In order to do this, we need:

- The training dataset loaded (10xv3)
- The two probesets loaded

In [None]:

# Re-create the selector on the same save_dir and collect the selected gene ids.
# NOTE(review): select_probeset() is not called here, so reading selector.probeset
# relies on the selector restoring earlier results from save_dir — confirm.
selector = sp.se.ProbesetSelector(
    reference,
    n=100,
    celltype_key="subclass",
    verbosity=1,
    n_jobs=-1,
    save_dir="/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/panels/spapros",
)
lookup = pd.DataFrame(
    {"gene_symbol": reference.var['gene_symbol'], "feature_id": reference.var.index}
)

chosen = selector.probeset["selection"]
selected = (
    selector.probeset[chosen]
    .copy()
    .merge(lookup, left_index=True, right_on="feature_id", how="left")
)

spapros_panel = list(selected.index)

spapros_panel

In [None]:
# Greedy panel saved by pick.main; allow_pickle is needed for its object arrays
# (acceptable only because this is our own output file).
greedy_path = '/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/panels/greedy_panels/greedy_panelsgenes_subclass_e0.01_s1_20250919_165916.npz'
znam = np.load(greedy_path, allow_pickle=True)
panel = pd.DataFrame(list(znam['gene_names'][znam['include_genes'] == True]))

lookup = pd.DataFrame(
    {"gene_symbol": reference.var['gene_symbol'], "feature_id": reference.var.index}
)

# gene_symbol -> list of feature ids. Some symbols map to several Ensembl ids,
# but (per the original note) none of the symbols in this panel do.
multi_mapping = lookup.groupby("gene_symbol")["feature_id"].apply(list).to_dict()

# One list of feature ids per panel gene (NaN when the symbol is absent)
panel_feature_ids = panel[0].map(multi_mapping)
petr_panel = list(panel_feature_ids)



In [None]:
# Flatten petr_panel (one list of feature ids per greedy-panel gene symbol) to a
# single feature id per gene, keeping the first id. Symbols missing from
# reference.var were mapped to NaN; report and skip them instead of crashing on x[0].
unmapped = [i for i, ids in enumerate(petr_panel) if not isinstance(ids, list) or not ids]
if unmapped:
    print(f'{len(unmapped)} greedy-panel genes have no feature id in the reference: '
          f'{[panel[0].iloc[i] for i in unmapped]}')
petr_str = [str(ids[0]) for ids in petr_panel if isinstance(ids, list) and ids]


In [None]:
# Feature ids of the spapros panel
spapros_panel

And compare:

In [None]:
# Baseline 100-gene probesets from spapros for comparison, keyed by set id
reference_sets = sp.se.select_reference_probesets(reference, n=100, obs_key="subclass")

In [None]:
# Spapros evaluator scoring probesets against the 'subclass' labels
evaluator = sp.ev.ProbesetEvaluator(
    reference,
    verbosity=2,
    celltype_key='subclass',
    results_dir="/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/panels/evaluation",
)


This step takes a while, possibly because no GPU is available.

In [None]:
# Score each baseline probeset under its own id (genes = rows flagged in 'selection')
for set_id, df in reference_sets.items():
    gene_set = df.loc[df["selection"]].index.to_list()
    evaluator.evaluate_probeset(gene_set, set_id=set_id)


In [None]:
# Evaluate the two candidate panels with the same evaluator as the baselines.
# evaluate_probeset takes a flat list of gene ids. petr_panel holds one *list* of
# feature ids per gene symbol, so pass its flattened form petr_str instead.
set_id = "greedy_panel"
evaluator.evaluate_probeset(petr_str, set_id=set_id)
set_id = "spapros_panel"
evaluator.evaluate_probeset(spapros_panel, set_id=set_id)

## Evaluate in test data

In [None]:
# Test reference: 10Xv2 isocortex datasets, then filtered, normalised and log1p'd.
# NOTE(review): the training reference was filtered with taxa_subclass, this one with
# taxa_class, and only this one goes through filter_genes, so labels/genes may differ — confirm.
test_datasets = ['WMB-10Xv2-Isocortex-3', 'WMB-10Xv2-Isocortex-1', 'WMB-10Xv2-Isocortex-4', 'WMB-10Xv2-Isocortex-2']
test_classes = ['01 IT-ET Glut', '06 CTX-CGE GABA', '07 CTX-MGE GABA', '08 CNU-MGE GABA', '09 CNU-LGE GABA']
reference_test = io.main_yao_2023(
    abc_cache,
    'WMB-10Xv2',
    test_datasets,
    download_base,
    taxa_class=test_classes,
    region_of_interest='MO-FRP',
    neurotransmitters=['GABA', 'Glut'],
    extract_csv=False,
)

sc.pp.filter_genes(reference_test, min_cells=1)
sc.pp.normalize_total(reference_test)
sc.pp.log1p(reference_test)
sc.pp.highly_variable_genes(reference_test, flavor="cell_ranger", n_top_genes=1000)

## Running Petr evaluation

In [None]:
# Expression table (presumably exon counts) + gene names for the pick pipeline,
# split into train/test sets labelled by subclass
exons_df, gene_names = io.read_yao_2023("/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/")
train_set, tesset, cluster_labels = pick.train_test_split(
    exons_df,
    classify_by="subclass",
    gene_filter=r"\d",
    efficiency=0.01,
)

In [None]:
# Gene symbols flagged for inclusion in the saved greedy panel
greedy_npz = '/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/panels/greedy_panels/greedy_panelsgenes_subclass_e0.01_s1_20250919_165916.npz'
znam = np.load(greedy_npz, allow_pickle=True)
included = znam['include_genes'] == True
gene_set = list(znam['gene_names'][included])
gene_set

In [None]:
# Train/test accuracy of the greedy gene set via pick.evaluate_gene_set
accuracy_train, accuracy_test = pick.evaluate_gene_set(
    train_set,
    tesset,
    gene_set,
    gene_names,
)

## Running expression constraints

In [None]:
# loading the yao2021 reference
# Yao 2021 subclass labels kept from the ALM dataset.
# The Sncg subclass was previously listed as "Scng". Assuming the loader matches labels
# exactly (io.load_yao_2021_to_anndata is not shown here), that silently dropped Sncg cells.
ALM_subclasses = [
    "Vip",
    "Lamp5",
    "Sncg",
    "Sst Chodl",
    "Sst",
    "Pvalb",
    "L2/3 IT CTX",
    "L4/5 IT CTX",
    "L5 IT CTX",
    "L6 IT CTX",
    "L5 PT CTX",
    "L4 RSP-ACA",
    "L5/6 NP CTX",
    "L6 CT CTX",
    "L6b CTX",
]


datapath = Path('/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/yao_2021')

reference = io.load_yao_2021_to_anndata(datapath, 'ALM', ALM_subclasses)

To compare expression thresholds on our expression matrix with those on the Allen matrix, both need to be in the same units. The average expression in the Allen dataset is the trimmed mean (25th–75th percentiles) of log2(CPM + 1). CPM (counts per million) is a cell's counts divided by its total counts and multiplied by 10^6. In the spapros training set, each cell is normalised so that its counts add up to the median total count of the dataset. The result is then transformed with log1p, i.e. ln(1 + x), where ln is the natural logarithm.

$ \text{allen}_{i,j} = \log_2\left(\frac{read_{i,j}}{read_j} \cdot 10^6 + 1\right) $

Therefore:

$ \frac{read_{i, j}}{read_j} = \frac{2^{\text{allen}_{i,j}} - 1}{10^6} $


Where $read_{i, j}$ are the reads for gene i in cell j, and $read_j$ are the total reads in cell j

We want to choose two expression constraints in this space. We use 4 as the low constraint (a value Petr once suggested) and 11 as the high constraint (Sst, which we can barely see).

In [None]:
# Low / high expression thresholds on the Allen log2(CPM + 1) scale
low_threshold_allen, high_threshold_allen = 4, 11

And we make them into thresholds for spapros

$ \text{spapros}_{i,j} = \ln\left(\frac{read_{i,j}}{read_j} \cdot \text{median} + 1\right) $

In [None]:
def allen_to_spapros(allen, spapros_median):
    """Convert an Allen log2(CPM + 1) value into spapros' log1p units.

    Allen values are log2(read / total * 1e6 + 1). Spapros data are normalised so each
    cell sums to ``spapros_median`` and then log1p'd. Hence:

        ratio   = (2**allen - 1) / 1e6          # fraction of the cell's counts
        spapros = ln(1 + ratio * spapros_median)

    Works element-wise on scalars or numpy arrays.
    """
    # CPM is per *million*: divide by 1e6 (the previous 10e6 is 1e7, a 10x error).
    # Float base so large integer inputs cannot overflow in np.power.
    ratio = (np.power(2.0, allen) - 1) / 1e6

    # back to spapros' scale: median-normalised counts, natural log1p
    return np.log1p(ratio * spapros_median)

def get_median_expression(adata, plot=True):
    """Return the median total count per cell of ``adata.X``.

    ``adata.X`` is summed across genes (axis=1), giving one total per cell. This is
    the quantity ``sc.pp.normalize_total`` uses as its default target. The previous
    version summed over axis=0 (one total per gene), so it returned the median gene
    total despite its labels.

    Parameters
    ----------
    adata : AnnData-like
        Object with a cells x genes ``.X`` (dense array or scipy-sparse matrix).
    plot : bool, default True
        Histogram the per-cell totals (log-scaled counts) on the current axes.

    Returns
    -------
    float
        Median of the per-cell totals.
    """
    # .sum on a sparse matrix returns an (n_cells, 1) matrix; flatten to 1-D
    cell_sums = np.asarray(adata.X.sum(axis=1)).ravel()
    if plot:
        plt.hist(cell_sums, bins = 200, log=True)
        plt.title('Total gene counts per cell')
        plt.xlabel('Total gene counts')
        plt.ylabel('#cells')

    return np.median(cell_sums)

In [None]:
# Median used as the spapros normalisation target (the function also draws a histogram)
median = get_median_expression(reference)
print(f'median cell expression is {median}')

In [None]:
# The same two thresholds converted to the spapros scale
low_threshold_spapros, high_threshold_spapros = (
    allen_to_spapros(threshold, median)
    for threshold in (low_threshold_allen, high_threshold_allen)
)

print(f'{low_threshold_spapros}, {high_threshold_spapros}')

In [None]:
# Normalised working copy; `reference` itself keeps its original values
reference_filtered = reference.copy()
sc.pp.filter_genes(reference_filtered, min_cells=1)  # drop genes detected in no cell
sc.pp.normalize_total(reference_filtered)            # per-cell library-size normalisation
sc.pp.log1p(reference_filtered)                      # natural log(1 + x)


In [None]:
# Same summary computed on the normalised, log1p-transformed copy
median = get_median_expression(reference_filtered)
print(f'median log1p expression is {median}')

The percentile trimming is probably distorting the comparison. Maybe just look at the Sst counts within the Sst cluster instead?

In [None]:
# Sst values (from reference.X) restricted to cells of the Sst subclass
gene_idx = reference.var_names.get_loc('Sst')
print(gene_idx)
sst_mask = reference.obs['subclass_label'] == 'Sst'
gene_expression = reference.X[:, gene_idx].toarray().flatten()
print(sum(sst_mask))
gene_expression = gene_expression[sst_mask]

In [None]:
# Sst values across *all* cells, with quartiles (red) and 99th percentile (green)
gene_idx = reference.var_names.get_loc('Sst')
print(gene_idx)
gene_expression = reference.X[:, gene_idx].toarray().flatten()
print(len(gene_expression))

plt.hist(gene_expression, bins=200, log=True)
plt.title('Expression levels of Sst globally (counts)')
p_99 = np.percentile(gene_expression, 99)
print(p_99)
print(len(gene_expression[gene_expression > p_99]))
p_25 = np.percentile(gene_expression, 25)
p_75 = np.percentile(gene_expression, 75)
print([p_25, p_75])
for quartile in (p_25, p_75):
    plt.axvline(quartile, color='red')
plt.axvline(p_99, color='green')

plt.show()


Maybe helpful?

> EDIT: I was able to get closer to the values in `trimmed_means.csv` by taking the `log2(CPM + 1)` of all cells, then excluding the top and bottom 25%, and taking the mean of the remaining values. As an example, for a single cluster and a single gene, the `trimmed_means.csv` value is **6.651905**, and I’m getting **6.652420762158425**.


In [None]:
# Sst values from the previous cell
gene_expression

In [None]:
# Mean of the values lying strictly inside the (p_25, p_75) band
inside_band = (gene_expression > p_25) & (gene_expression < p_75)
trimmed_expression = gene_expression[inside_band]
trimmed_mean = np.mean(trimmed_expression)

In [None]:
# Trimmed mean computed above
trimmed_mean

In [None]:
# After we normalise, the trimmed mean expression of Sst in Sst neurons.
# reference_filtered is normalised + log1p'd, so the plot title now says so
# (it previously claimed raw counts).

gene_idx = reference_filtered.var_names.get_loc('Sst')
print(gene_idx)
gene_expression = reference_filtered.X[:, gene_idx].toarray().flatten()

sst_mask = reference_filtered.obs['subclass_label'] == 'Sst'
print(sum(sst_mask))
gene_expression = gene_expression[sst_mask]

plt.hist(gene_expression, bins = 200, log = True)
plt.title('Expression levels of Sst in Sst cells (log1p normalised)')
[p_25, p_75] = [np.percentile(gene_expression, q) for q in [25, 75]]
print([p_25, p_75])
for q in [p_25, p_75]:
    plt.axvline(q, color = 'red')

plt.show()

# mean of the values strictly between the quartiles
trimmed_expression = gene_expression[(gene_expression>p_25)&(gene_expression<p_75)]
trimmed_mean = np.mean(trimmed_expression)
trimmed_mean

Okay, so: 

- For each gene, compute the trimmed-mean expression in every cluster and take the cluster with the highest value. Use that value for Sst as the reference threshold, and define thresholds for the remaining genes the same way.

In [None]:

# For every gene: the maximum, over clusters, of the within-cluster trimmed mean
# (mean of the values strictly between that gene's 25th and 75th percentile in the
# cluster). Stored as reference.var['expression_max'].
# NOTE(review): `reference` here appears to hold the values loaded from Yao 2021 (only
# reference_filtered was normalised), so these are presumably count-scale values, not
# Allen's log2(CPM + 1) — confirm this is intended.
# NOTE(review): trimming is by value, not by rank. With many tied values (e.g. zeros),
# few or no cells survive; a gene with none gets 0 for that cluster.

X = reference.X
# row slicing below is done on CSR (presumably for fast per-cell access)
if issparse(X) and not isinstance(X, csr_matrix):
    X = X.tocsr()

clusters = reference.obs['cluster_label'].to_numpy()
# inverse[i] is the index into uniq_clusters of cell i's cluster
uniq_clusters, inverse = np.unique(clusters, return_inverse=True)

n_genes = reference.n_vars
# running per-gene maximum; -inf so the first cluster always replaces it
max_expression = np.full(n_genes, -np.inf, dtype=float)

for k, cl in enumerate(uniq_clusters):
    # rows for this cluster
    rows = np.nonzero(inverse == k)[0]
    if rows.size == 0:
        continue

    # submatrix (cells_in_cluster x genes), dense for vectorized percentile ops
    Xc = X[rows, :].toarray() if issparse(X) else np.asarray(X[rows, :])

    # per-gene quartiles (vectorized across genes)
    p25, p75 = np.percentile(Xc, [25, 75], axis=0)

    # trim per gene
    mask = (Xc > p25) & (Xc < p75)            # shape: (cells_in_cluster, n_genes)
    count_trim = mask.sum(axis=0)             # per-gene counts after trimming
    sum_trim = (Xc * mask).sum(axis=0)        # per-gene sums after trimming

    # avoid division-by-zero when all values fall outside (rare but possible)
    trimmed_mean = np.divide(
        sum_trim, count_trim,
        out=np.zeros_like(sum_trim, dtype=float),
        where=count_trim > 0
    )

    # update running max across clusters
    np.maximum(max_expression, trimmed_mean, out=max_expression)

reference.var['expression_max'] = max_expression

In [None]:
# Distribution of the per-gene maximum trimmed means
plt.hist(reference.var['expression_max'], log=True, bins=200)

In [None]:
# expression_max value for Sst
reference.var.loc[reference.var_names == 'Sst', 'expression_max']

In [None]:
# Trimmed-mean Sst value (strictly between the 25th and 75th percentile) for each
# cluster of the Sst subclass, plotted against the cluster's cell count.
sst_mask = reference.obs['subclass_label'] == 'Sst'
sst_clusters = reference.obs['cluster_label'][sst_mask]

print(sst_clusters.unique())

# The Sst column does not depend on the cluster: extract it once.
gene_idx = reference.var_names.get_loc('Sst')
print(gene_idx)
sst_column = reference.X[:, gene_idx].toarray().flatten()

cluster_names = sst_clusters.unique()
# Explicit numeric arrays. np.ones_like(cluster_names) inherited the labels'
# object/string dtype, so the stored results were not plain floats/ints.
expression = np.zeros(len(cluster_names), dtype=float)
ncells = np.zeros(len(cluster_names), dtype=int)

for i, cluster in enumerate(cluster_names):
    sst_mask = reference.obs['cluster_label'] == cluster
    print(sum(sst_mask))
    gene_expression = sst_column[sst_mask.to_numpy()]

    [p_25, p_75] = [np.percentile(gene_expression, q) for q in [25, 75]]
    print([p_25, p_75])
    trimmed_expression = gene_expression[(gene_expression > p_25) & (gene_expression < p_75)]
    trimmed_mean = np.mean(trimmed_expression)

    expression[i] = trimmed_mean
    ncells[i] = len(gene_expression)

plt.scatter(expression, ncells)
plt.title('Expression of Sst per Sst cluster')
plt.xlabel('Trimmed mean(counts)')
plt.ylabel('ncells per cluster')

In [None]:
# Histogram of every entry of the normalised matrix (densified, then flattened)
raw_x = reference_filtered.X
X = raw_x.toarray() if hasattr(raw_x, "toarray") else np.asarray(raw_x)
values = X.ravel()

plt.hist(values, bins=200, log=True)
plt.title("Histogram of all values in reference.X")
plt.xlabel("Log-norm expression")
plt.ylabel("Frequency (log scale)")
plt.show()

Let's try to define a 1% expression level per gene, a la spapros

In [None]:
# Per-gene 99th percentile over all cells (dense copy of X), used as a spapros-style upper level
max_threshold = np.percentile(reference.X.toarray(), q=99, axis=0)
plt.hist(max_threshold, log=True, bins=200)

In [None]:
# Store the per-gene 99th percentiles and show the Sst row
reference.var['max_threshold_spapros'] = max_threshold
reference.var.loc[reference.var_names == 'Sst']

In [None]:
# Inspect the per-gene annotation table (the trailing '.' was a syntax error)
print(reference.var)

## Sparseness constraints

We want a gene to be expressed "sparsely". For that, we can either have an interesting metric, or we can have a boxplot. Let's start with the boxplot. 

In [None]:
# loading the yao2021 reference
# Yao 2021 subclass labels kept from the ALM dataset.
# The Sncg subclass was previously listed as "Scng". Assuming the loader matches labels
# exactly (io.load_yao_2021_to_anndata is not shown here), that silently dropped Sncg cells.
ALM_subclasses = [
    "Vip",
    "Lamp5",
    "Sncg",
    "Sst Chodl",
    "Sst",
    "Pvalb",
    "L2/3 IT CTX",
    "L4/5 IT CTX",
    "L5 IT CTX",
    "L6 IT CTX",
    "L5 PT CTX",
    "L4 RSP-ACA",
    "L5/6 NP CTX",
    "L6 CT CTX",
    "L6b CTX",
]


datapath = Path('/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/yao_2021')

reference = io.load_yao_2021_to_anndata(datapath, 'ALM', ALM_subclasses)
# untouched copy that will keep the un-normalised values
counts_reference = reference.copy()

In [None]:
# Drop genes detected in no cell first (per the original note, later steps crash otherwise)
sc.pp.filter_genes(reference, min_cells=1)
sc.pp.normalize_total(reference)
sc.pp.log1p(reference)
# large HVG pool for a good panel
sc.pp.highly_variable_genes(reference, flavor="cell_ranger", n_top_genes=10000)

# Same gene filter on the count copy, presumably so both objects keep the same genes
sc.pp.filter_genes(counts_reference, min_cells=1)
# NOTE(review): scanpy documents the 'cell_ranger' flavour as expecting log-transformed
# data, but this call runs on un-normalised counts — confirm intended.
sc.pp.highly_variable_genes(counts_reference, flavor="cell_ranger", n_top_genes=10000)


In [None]:
savepath = "/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/panels/spapros_10K_taylored_glia_2021"

# Known markers passed to spapros as preselected_genes
prior_marker = [
    "Chodl",
    "Cux2",
    "Fezf2",
    "Foxp2",
    "Rorb",
    "Pvalb",
    "Lamp5",
    "Adamts2",
    "Slco2a1",
]

selector = sp.se.ProbesetSelector(
    reference,
    n=100,
    celltype_key="subclass_label",
    verbosity=1,
    n_jobs=-1,
    save_dir=savepath,
    preselected_genes=prior_marker,
)

# NOTE(review): select_probeset() is not called in this cell, so reading
# selector.probeset relies on results already stored under save_dir — confirm.
probeset = selector.probeset
selected = probeset[probeset["selection"]].copy()
spapros_panel = list(selected.index)

Subset the reference dataset by the genes we care about

In [None]:
# Restrict the normalised and the count references to the selected panel genes.
# Each object is masked with its *own* var_names. Previously the count copy was
# masked with reference's mask, which is correct only while both gene lists match.
panel_reference = reference[:, reference.var_names.isin(spapros_panel)]
panel_counts_reference = counts_reference[:, counts_reference.var_names.isin(spapros_panel)]

Calculate the trimmed mean per cluster of each gene. In raw counts or in CPM? Do it in their lognorm way. 

We want to have a matrix of clusters x genes, or add it to the reference.var. Let's add it to the var if it's less than 100 clusters

In [None]:
# Number of distinct clusters in the panel reference
len(panel_reference.obs["cluster_label"].unique())

In [None]:
def calculate_expression_per_taxa(panel_reference, taxa_label):

    '''
    Build a taxa x genes table of trimmed (25-75) mean expression.

    For every taxon (unique value of ``panel_reference.obs[taxa_label]``) and every
    gene, the mean is taken over the cells of that taxon whose value lies strictly
    between the gene's 25th and 75th percentile within the taxon. Genes where no
    value survives the trimming (e.g. ties at the quartiles) get 0.

    Parameters
    ----------
    panel_reference : AnnData-like
        Needs ``.X`` (cells x genes, dense or scipy-sparse), ``.obs``,
        ``.var_names`` and ``.n_vars``.
    taxa_label : str
        Column of ``.obs`` holding each cell's taxon.

    Returns
    -------
    pandas.DataFrame
        float64 frame indexed by the sorted unique taxa, columns = ``var_names``.
    '''

    genes = panel_reference.var_names

    clusters = panel_reference.obs[taxa_label].to_numpy()
    # inverse[i] is the position of cell i's taxon in uniq_clusters
    uniq_clusters, inverse = np.unique(clusters, return_inverse=True)

    # dtype=float: an empty DataFrame defaults to object dtype, which forced every
    # caller to cast before doing arithmetic on the result
    expression_percluster = pd.DataFrame(index=uniq_clusters, columns=genes, dtype=float)

    n_genes = panel_reference.n_vars
    print(f'There are {n_genes} genes in this panel')

    X = panel_reference.X
    for k, cl in enumerate(uniq_clusters):
        # rows for this cluster
        rows = np.nonzero(inverse == k)[0]
        if rows.size == 0:
            continue

        # submatrix (cells_in_cluster x genes), dense for vectorized percentile ops
        Xc = X[rows, :].toarray() if issparse(X) else np.asarray(X[rows, :])

        # per-gene quartiles (vectorized across genes)
        p25, p75 = np.percentile(Xc, [25, 75], axis=0)

        # trim per gene
        mask = (Xc > p25) & (Xc < p75)            # shape: (cells_in_cluster, n_genes)
        count_trim = mask.sum(axis=0)             # per-gene counts after trimming
        sum_trim = (Xc * mask).sum(axis=0)        # per-gene sums after trimming

        # avoid division-by-zero when all values fall outside (rare but possible)
        trimmed_mean = np.divide(
            sum_trim, count_trim,
            out=np.zeros_like(sum_trim, dtype=float),
            where=count_trim > 0
        )

        # add to our matrix
        expression_percluster.loc[cl] = trimmed_mean

    return expression_percluster



This is how they're ordered in the panel, for easy comparison:

In [None]:
# Gene order of the panel selected with no sparseness constraint; used to order plot columns.
# NOTE: the next two cells rebind gene_priority, so only the last cell executed takes effect.

gene_priority = [
    "Chodl","Vip","Rorb","Pvalb","Sst","Fezf2","Lamp5","Cux2","Rrad","Adamts2","Foxp2","Npsr1",
    "Slco2a1","Olfm1","Calb1","Zfpm2","Cdh13","Cryab","Etv1","Cdh8","Rgs4","Grin3a","Cplx3","Gpc6",
    "Synpr","Kcnab1","Cacna2d3","Vwc2l","Rprm","Tox","Erbb4","Crh","Ptn","Lypd1","Zfp804b","Pcp4l1",
    "Il1rapl2","Lingo2","Pcp4","Nxph1","Dcc","Pcsk5","Caln1","Man1a","Nov","Prr16","Grik1","Cck",
    "Cpne4","S100a10","Sorcs1","Cdh9","Gng4","Pde1a","Penk","Alcam","Reln","Ptprt","Trp53i11","Marcksl1",
    "Fstl5","Igfbp4","Cbln4","Crhbp","Necab1","Crtac1","Unc5d","Col25a1","Ptchd4","Fxyd6","Stxbp6",
    "Thsd7a","Parm1","Sema3c","Medag","Rbp4","Ctgf","Chst8","Fam19a1","Adarb2","Dner","Ak5","Cacng3",
    "Cygb","Nxph4","Cplx1","Arpp19","Kcnk2","Hs6st3","Sparcl1","Gpx3","Car4","Gucy1a3","Spon1","Prss23",
    "Tac2","Ptprm","Tmem91","Sema5a","Cntnap5c"
]

In [None]:
# Gene order of the panel selected with the box_sparseness constraint (overrides any earlier gene_priority)

gene_priority = [
    "Chodl",
    "Vip",
    "Rorb",
    "Sst",
    "Pvalb",
    "Lamp5",
    "Fezf2",
    "Cux2",
    "Rrad",
    "Adamts2",
    "Foxp2",
    "Npsr1",
    "Slco2a1",
    "Tshz2",
    "Fam19a1",
    "Cplx3",
    "Calb1",
    "Hs3st4",
    "Blnk",
    "Igfbp6",
    "Synpr",
    "Adarb2",
    "Ctgf",
    "Syt6",
    "Tnnc1",
    "Erbb4",
    "Sstr2",
    "Cdh9",
    "Ptprk",
    "Stxbp6",
    "Rprm",
    "Slc30a3",
    "Crh",
    "Lypd1",
    "Man1a",
    "Rgs12",
    "Npas1",
    "Il1rapl2",
    "Ptprm",
    "Crym",
    "Igfbp4",
    "Kit",
    "Cryab",
    "Prss23",
    "Etv1",
    "Grin3a",
    "Nxph1",
    "Tcap",
    "Meis2",
    "Calb2",
    "Pcp4l1",
    "Egfr",
    "Nr2f2",
    "Grik1",
    "Thsd7a",
    "Cd63",
    "Kcnip1",
    "Tac2",
    "Cbln4",
    "Elfn1",
    "Unc13c",
    "Car4",
    "Moxd1",
    "Trp53i11",
    "Cpne7",
    "Rmst",
    "Npy",
    "Cartpt",
    "Ndnf",
    "Pthlh",
    "S100b",
    "Vwc2l",
    "Crhbp",
    "Arhgap25",
    "Cox6a2",
    "Gad1",
    "Adamts3",
    "Prex1",
    "Npr3",
    "Adcyap1",
    "Sema3c",
    "Sema5a",
    "Pamr1",
    "Npy2r",
    "Penk",
    "Fam46a",
    "Kctd8",
    "Pnoc",
    "Cidea",
    "Ackr3",
    "Lgals1",
    "Ptprz1",
    "Cxcl14",
    "Tcerg1l",
    "Lypd6",
    "Rxfp1",
    "Serpine2",
    "Reln",
    "Cdh20",
    "Timp3",
]


In [None]:
# Gene order of the panel selected with the entropy constraint (overrides any earlier gene_priority)
gene_priority = [
    "Chodl",
    "Vip",
    "Sst",
    "Pvalb",
    "Rorb",
    "Lamp5",
    "Cux2",
    "Fezf2",
    "Rrad",
    "Adamts2",
    "Foxp2",
    "Npsr1",
    "Slco2a1",
    "Tshz2",
    "Fam19a1",
    "Calb1",
    "Syt6",
    "Blnk",
    "Cplx3",
    "Adarb2",
    "Tcap",
    "Rprm",
    "Cpne9",
    "Tac1",
    "Ctgf",
    "Erbb4",
    "Slc30a3",
    "Synpr",
    "Cdh9",
    "Tnnc1",
    "Il1rapl2",
    "Meis2",
    "Crh",
    "Adcyap1",
    "Rgs12",
    "Lypd1",
    "Igfbp6",
    "Sema5a",
    "Cpne7",
    "S100a6",
    "Lypd6",
    "Calb2",
    "Pamr1",
    "Cidea",
    "Medag",
    "Ptprk",
    "Pcp4l1",
    "Nxph1",
    "Crym",
    "Ptprm",
    "Cbln4",
    "Igfbp4",
    "Pnoc",
    "Penk",
    "Npy",
    "Egfr",
    "Cryab",
    "Grik1",
    "Kctd8",
    "Car4",
    "S100b",
    "Nr2f2",
    "Pthlh",
    "Prss23",
    "Cdh20",
    "Ptprz1",
    "Vwc2l",
    "Elfn1",
    "Crhbp",
    "Gad1",
    "Angpt1",
    "Adamts3",
    "Nxph4",
    "Npr3",
    "Sema5b",
    "Ddit4l",
    "Kit",
    "Serpine2",
    "Unc13c",
    "Cd63",
    "Cartpt",
    "Thsd7a",
    "Tac2",
    "Moxd1",
    "Rmst",
    "Sema3c",
    "Rspo1",
    "Sox6",
    "Ppapdc1a",
    "Lgals1",
    "Ackr3",
    "Timp3",
    "Npas1",
    "Reln",
    "Vwc2",
    "Pcdh8",
    "Hgf",
    "Pcdh20",
    "Cp",
    "Ndnf",
]


In [None]:
# Per-cluster trimmed-mean expression of every panel gene, drawn as one row per gene,
# with background bands marking the subclass each cluster belongs to.
expression_percluster = ut.calculate_expression_per_taxa(panel_reference, 'cluster_label')

# --- Cluster -> subclass mapping taken from obs ---
meta = (
    reference.obs[["cluster_label", "subclass_label"]]
    .dropna()
    .drop_duplicates()
)

# Band order = order of first appearance in obs (this rebinds ALM_subclasses)
ALM_subclasses = reference.obs['subclass_label'].unique()

# ---- Explicit color allocation (edit as you wish) ----
# "Sncg" is the subclass label; the old "Scng" key presumably never matched any cluster.
subclass_colors = {
    "Vip":           "#1f77b4",
    "Lamp5":         "#ff7f0e",
    "Sncg":          "#2ca02c",
    "Sst Chodl":     "#d62728",
    "Sst":           "#9467bd",
    "Pvalb":         "#8c564b",
    "L2/3 IT CTX":   "#e377c2",
    "L4/5 IT CTX":   "#7f7f7f",
    "L5 IT CTX":     "#bcbd22",
    "L6 IT CTX":     "#17becf",
    "L5 PT CTX":     "#aec7e8",
    "L4 RSP-ACA":    "#ffbb78",
    "L5/6 NP CTX":   "#98df8a",
    "L6 CT CTX":     "#ff9896",
    "L6b CTX":       "#c5b0d5",
}

# Default color for any subclass not listed
default_color = "#dddddd"

# ---- Order clusters by subclass group, then by label (string sort) within group ----
ordered_clusters = []
spans = []  # list of (subclass, start_idx, end_idx) for shading bands

start = 0
for sub in ALM_subclasses:
    members = meta.loc[meta["subclass_label"] == sub].sort_values("cluster_label", kind="mergesort")
    if not members.empty:
        ordered_clusters.extend(members["cluster_label"].tolist())
        end = start + len(members) - 1
        spans.append((sub, start, end))
        start = end + 1

# ---- Column order: gene_priority first (those present), then the rest in original order ----
cols = list(expression_percluster.columns)
col_set = set(cols)
prio_set = set(gene_priority)
ordered_cols = [g for g in gene_priority if g in col_set]
ordered_cols += [g for g in cols if g not in prio_set]

expression_percluster_ordered = expression_percluster.loc[:, ordered_cols]

# and now, per clusters
expr_ord = expression_percluster_ordered.loc[ordered_clusters]

# ---- Plot: one subplot (row) per gene; shaded subclass bands ----
genes = list(expr_ord.columns)
n_genes = len(genes)

# ---- Layout knobs ----
height_per_subplot = 4.4    # bump this up if still too squished
fig_w = 12
legend_right_margin = 0.80  # fraction of the figure width kept for the plots; the rest holds legends

# squeeze=False always returns a 2-D array of axes, so no special case for a single gene
fig, axes = plt.subplots(
    n_genes, 1,
    figsize=(fig_w, max(4, height_per_subplot * n_genes)),
    sharex=True,
    squeeze=False,
)
axes = axes[:, 0]

x = np.arange(len(ordered_clusters))

for ax, gene in zip(axes, genes):
    # background subclass bands
    for sub, s, e in spans:
        ax.axvspan(s - 0.5, e + 0.5, facecolor=subclass_colors.get(sub, default_color), alpha=0.12, linewidth=0)

    # gene trace
    y = expr_ord[gene].astype(float).to_numpy()
    ax.plot(x, y, marker="o", linestyle="-", linewidth=1)
    ax.set_ylabel(f'log1p(norm({gene}))')
    ax.grid(True, alpha=0.3)

# X ticks
axes[-1].set_xticks(x)
axes[-1].set_xticklabels(ordered_clusters, rotation=90)
axes[-1].set_xlabel("Clusters (grouped by subclass)")

# Legend: one entry per subclass band that actually appears (dict keeps first-seen order)
legend_patches = [
    Patch(facecolor=subclass_colors.get(sub, default_color), alpha=0.4, label=sub)
    for sub in dict.fromkeys(sub for sub, _, _ in spans)
]
# Legend to the right of each row, outside the axes
for ax in axes:
    ax.legend(handles=legend_patches, loc="upper left", bbox_to_anchor=(1.01, 1.0), frameon=False, title="Subclass bands")

# Confine the subplots to the left part of the figure so the outside legends are not clipped
plt.tight_layout(rect=(0, 0, legend_right_margin, 1))
plt.show()

# If you also want to keep the explicit order for later use:
ordered_clusters_list = ordered_clusters
subclass_band_spans = spans  # [(subclass, start_idx, end_idx), ...]


In [None]:
# Clusters x genes trimmed-mean table
expression_percluster

In [None]:
# Widen pandas display limits for the per-taxa tables
pd.set_option("display.max_rows", 100)
pd.set_option("display.max_columns", 10)

In [None]:
# Subclasses x genes trimmed-mean table for the panel genes
expression_persubclass = ut.calculate_expression_per_taxa(panel_reference, "subclass_label")


In [None]:
# Sanity check: gene columns come out in sorted order
assert list(expression_persubclass.columns) == sorted(expression_persubclass.columns)

In [None]:
# Row count (subclasses), read off the fourth gene column
len(expression_persubclass.values[:, 3])

In [None]:
# Per-subclass trimmed means for the panel genes. Cast to float: an object-dtype
# frame would make the reductions below unreliable.
expression_persubclass = ut.calculate_expression_per_taxa(panel_reference, 'subclass_label')
expr = expression_persubclass.astype(float)

# Specificity proxy: how far a gene's best subclass sits above its average subclass
maxmean = (expr.max() - expr.mean()).to_numpy()

# Shannon entropy (bits) of each gene's profile across subclasses; scipy normalises
# each column to sum to 1, so high entropy = broadly expressed
gene_H = entropy(expr.to_numpy(), base=2, axis=0)

h_thresh = 3.0        # entropy threshold (bits)
mm_thresh = 2.0       # max-mean threshold (same units as your data)
high_entropy = gene_H > h_thresh
low_maxmean = maxmean < mm_thresh
mask = high_entropy & low_maxmean           # both criteria: red
almost_mask = high_entropy ^ low_maxmean    # exactly one criterion: orange
colors = np.where(mask, 'red', np.where(almost_mask, 'orange', 'black'))

plt.scatter(maxmean, gene_H, c=colors, s=20)
plt.ylabel('per gene Shannon entropy (bits)')
plt.xlabel('per gene max-mean (log1p(norm(counts)))')

plt.show()

In [None]:
def logistic(x, k, x0, plot = False):
    """Evaluate the logistic curve 1 / (1 + e^(-k * (x - x0))) at each value of ``x``.

    Parameters
    ----------
    x : iterable of float
        Points at which to evaluate the curve.
    k : float
        Steepness; larger values give a sharper transition.
    x0 : float
        Midpoint, where the curve equals 0.5.
    plot : bool, optional
        If True, scatter-plot the curve and show it.

    Returns
    -------
    list of float
        One value per element of ``x``.
    """
    y = [1/(1+np.power(np.e, -k*(value-x0))) for value in x]
    if plot:
        plt.scatter(x, y)
        # f-string so the actual midpoint and steepness appear in the title.
        plt.title(f'Logistic for midpoint {x0} and k {k}')
        plt.show()
    return y

In [None]:
# Visualise the family of "1 - logistic" penalty curves for a range of steepness
# values k, with the midpoint fixed at 2.
values = np.linspace(1, 5, 100)
print(values)
ks = np.linspace(0, 5, 10)
print(ks)
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap(name, lut)
# still returns a colormap resampled to `lut` discrete colours.
cmap = plt.get_cmap("viridis", len(ks))  # one discrete colour per k

for i, k in enumerate(ks):
    y = np.array(logistic(values, k, 2))
    plt.scatter(values, 1-y, color = cmap(i), label=f'k={k:.2f}')

plt.legend(fontsize=8)
plt.show()

In [None]:
# Check that gene_priority is just a re-ordering of `genes`: every gene must
# appear in it, and both lists must have the same length.
priority_set = set(gene_priority)  # O(1) membership instead of scanning the list per gene
missing_genes = [g for g in genes if g not in priority_set]
assert not missing_genes, f'Genes missing from gene_priority: {missing_genes}'
assert len(genes) == len(gene_priority), f'{len(genes)} genes vs {len(gene_priority)} priority entries'

In [None]:

# ---- Explicit per-subclass colour allocation (edit as you wish) ----
# (Not used by the plot in this cell.)
subclass_colors = {
    "Vip":           "#1f77b4",
    "Lamp5":         "#ff7f0e",
    "Sncg":          "#2ca02c",  # the subclass label is "Sncg" (see the VISp subclass listing)
    "Sst Chodl":     "#d62728",
    "Sst":           "#9467bd",
    "Pvalb":         "#8c564b",
    "L2/3 IT CTX":   "#e377c2",
    "L4/5 IT CTX":   "#7f7f7f",
    "L5 IT CTX":     "#bcbd22",
    "L6 IT CTX":     "#17becf",
    "L5 PT CTX":     "#aec7e8",
    "L4 RSP-ACA":    "#ffbb78",
    "L5/6 NP CTX":   "#98df8a",
    "L6 CT CTX":     "#ff9896",
    "L6b CTX":       "#c5b0d5",
}

expr_ord = expression_persubclass
# ---- Plot: one subplot (row) per gene, genes ranked by max-mean ----
genes_order = np.argsort((expr_ord.max() - expr_ord.mean()).values)
genes = list(expr_ord.columns[genes_order])

# Reorder the metrics with the same positional order so row i shows gene i's values.
# Convert to numpy first: `maxmean` is a gene-indexed Series, and positional
# `series[int_array]` / `series[i]` lookups are deprecated in pandas.
# (Assumes maxmean, gene_H and colors were computed from this same expression_persubclass.)
plot_maxmean = np.asarray(maxmean)[genes_order]
plot_entropy = np.asarray(gene_H)[genes_order]
plot_colors = np.array(colors)[genes_order]

n_genes = len(genes)

# ---- Layout knobs ----
height_per_subplot = 4.4   # bump this up if still too squished
fig_w = 12
legend_right_margin = 0.80 # (not used in this cell)

fig, axes = plt.subplots(
    n_genes, 1,
    figsize=(fig_w, max(4, height_per_subplot * n_genes)),
)

# A single-row subplots() call returns a bare Axes; wrap it for uniform indexing.
if n_genes == 1:
    axes = [axes]

# One x position per row of the plotted table, so x and y always have equal length.
x = np.arange(len(expr_ord.index))

for i, gene in enumerate(genes):
    ax = axes[i]
    # gene values per subclass
    y = expr_ord[gene].astype(float).to_numpy()
    ax.plot(x, y, marker="o", linestyle="none", markersize=20, linewidth=1)
    ax.text(0.99, 0.95,
        f'Entropy: {plot_entropy[i]:.3f}\nMax-mean: {plot_maxmean[i]:.3f}\n{gene}',
        transform=ax.transAxes,  # position relative to the subplot axes
        ha="right", va="top",
        color = plot_colors[i])

    ax.set_ylabel(f'log1p(norm({gene}))')
    ax.set_xticks(x, labels=expr_ord.index)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-0.5, 10)

plt.tight_layout()
plt.show()




We train a classifier on a hand-labelled set of genes to see how well the entropy and max-mean metrics capture gene sparseness.

In [None]:
train = pd.read_csv('sparseness_train.csv')
# Map Sparseness codes to a binary target: 1 -> 1; 0 and 2 -> 0.
result = [0, 1, 0]
train['target'] = list(map(result.__getitem__, train['Sparseness']))
train

In [None]:
# Align the labelled genes with the metric arrays. The metrics follow the column
# order of expression_persubclass (alphabetical), so sorting the training table by
# gene name puts its rows in the same order; the assert guards against mismatches.
train = train.sort_values(by='Gene')
assert list(train['Gene'])==list(expression_persubclass.columns), 'sorting went wrong somewhere'
# Positional assignment: row i of the sorted table receives gene i's metric.
train['entropy'] = gene_H
train['maxmean'] = maxmean.values
# Drop genes whose metrics are undefined (NaN; presumably zero expression everywhere).
train = train.dropna(subset=["entropy", "maxmean"])
train.head()  # call the method; a bare `train.head` only displays the bound-method repr


In [None]:
# Hand-labelled genes in metric space: red = target 0, green = target 1.
color = np.where(train['target'] == 0, "red", "green")
plt.scatter(train['maxmean'], train['entropy'], color=color, s=20)
plt.ylabel('per gene Shannon entropy (bits)')
plt.xlabel('per gene max-mean (log1p(norm(counts))')

# Candidate decision thresholds.
plt.axhline(3.28, color='lightblue')   # entropy
plt.axvline(3.38, color='lightgray')   # max-mean
plt.axvline(2, color='lightblue')      # max-mean

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report

# Fit a logistic regression on both metrics using every labelled gene (no held-out
# split here; generalisation is assessed by cross-validation further down).
X = train[["entropy", "maxmean"]]
y = train["target"]

full_model = LogisticRegression(max_iter=1000)
full_model.fit(X, y)

In [None]:
# Report the fitted parameters, with one coefficient per feature name.
coef_by_feature = {name: weight for name, weight in zip(X.columns, full_model.coef_[0])}
print("Intercept:", full_model.intercept_)
print("Coefficients:", coef_by_feature)


In [None]:
# Response curve of the fitted model as a function of entropy.
# predict_proba needs exactly the feature columns the model was fitted on (here
# entropy and max-mean), so every feature other than entropy is held at its
# median over the training genes. Passing a single unnamed column raises a
# feature-count ValueError.
entropy_grid = np.linspace(0, 7, 500)
test = pd.DataFrame(
    {feat: np.full_like(entropy_grid, train[feat].median()) for feat in full_model.feature_names_in_}
)
test["entropy"] = entropy_grid
score = full_model.predict_proba(test)
print(score.shape)
plt.plot(entropy_grid, score[:,1])
plt.ylabel('Gene score')
plt.xlabel('Gene entropy')

In [None]:
# Show the predict_proba output from the cell above (column 1 = probability of target 1).
score

In [None]:
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.linear_model import LogisticRegression
import numpy as np


def cv_score(features, k=5):
    """Return (mean, std) of k-fold cross-validated ROC AUC.

    A logistic regression is trained on ``train[features]`` to predict
    ``train["target"]``. Folds are stratified (class balance preserved) and
    shuffled with a fixed seed.
    """
    splitter = StratifiedKFold(n_splits=k, shuffle=True, random_state=42)
    aucs = cross_val_score(
        LogisticRegression(max_iter=1000),
        train[features],
        train["target"],
        cv=splitter,
        scoring="roc_auc",
    )
    return aucs.mean(), aucs.std()


for label, feature_set in (
    ("Both:", ["entropy", "maxmean"]),
    ("Entropy only:", ["entropy"]),
    ("Maxmean only:", ["maxmean"]),
):
    print(label, cv_score(feature_set))



In [None]:
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Rows usable for training: no missing predictor or target.
df = train.dropna(subset=["entropy", "maxmean", "target"])
X, y = df[["entropy", "maxmean"]], df["target"]

# 80/20 split with a fixed seed, then a shallow tree (tune max_depth if needed).
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X_train, y_train)

y_pred = tree_clf.predict(X_test)
print("Accuracy:", accuracy_score(y_test, y_pred))

importances = tree_clf.feature_importances_
for feat, imp in zip(X.columns, importances):
    print(f"{feat}: {imp:.3f}")

plt.figure(figsize=(12, 6))
plot_tree(
    tree_clf,
    feature_names=X.columns,
    class_names=list(map(str, tree_clf.classes_)),
    filled=True,
    rounded=True,
    fontsize=10,
)
plt.show()


In [None]:
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Tune a depth-1 tree (a single threshold) on max-mean alone.
df = train.dropna(subset=["entropy", "maxmean", "target"])
X = df[["maxmean"]]
y = df["target"]

param_grid = dict(
    max_depth=[1],
    min_samples_split=[2, 5, 10],
    min_samples_leaf=[1, 2, 4],
)

clf = DecisionTreeClassifier(random_state=42)

# 5-fold cross-validated ROC AUC over the grid, using all cores.
grid_search = GridSearchCV(clf, param_grid, cv=5, scoring="roc_auc", n_jobs=-1)
grid_search.fit(X, y)

print("Best params:", grid_search.best_params_)
print("Best CV score:", grid_search.best_score_)


In [None]:
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree

# Draw the best estimator found by the grid search.
best_tree = grid_search.best_estimator_
tree_class_labels = [str(label) for label in best_tree.classes_]

plt.figure(figsize=(12, 6))
plot_tree(best_tree, feature_names=X.columns, class_names=tree_class_labels,
          filled=True, rounded=True, fontsize=10)
plt.show()


### Applying final constraint

In [None]:

# Final model: recompute the panel metrics, rebuild the labelled training table,
# fit a logistic regression on entropy alone, then score every gene in `reference`.
expression_persubclass = ut.calculate_expression_per_taxa(panel_reference, 'subclass_label')

## Calculating metrics
maxmean = expression_persubclass.max() - expression_persubclass.mean()
#calculate entropy of each gene (bits, across subclasses)
gene_H = entropy(expression_persubclass.to_numpy(dtype = float), base = 2, axis = 0)

## Building train dataset
train = pd.read_csv('sparseness_train.csv')
result = [0, 1, 0] # Sparseness code 1 -> target 1; codes 0 and 2 -> target 0
train['target'] = [result[i] for i in train['Sparseness']]

#sort to match the metric. Metrics are in reference order, which is alphabetical
train = train.sort_values(by='Gene')
assert list(train['Gene'])==list(expression_persubclass.columns), 'sorting went wrong somewhere'
train['entropy'] = gene_H
train['maxmean'] = maxmean.values
train = train.dropna(subset=["entropy", "maxmean"])

##Training final model (entropy is the only feature)
X = train[["entropy"]]
y = train["target"]
full_model = LogisticRegression(max_iter=1000)
full_model.fit(X, y)
print("Intercept:", full_model.intercept_)
print("Coefficients:", dict(zip(X.columns, full_model.coef_[0])))

#Evaluate all genes
all_genes_subclass = ut.calculate_expression_per_taxa(reference, 'subclass_label')
all_genes_H = entropy(all_genes_subclass.to_numpy(dtype = float), base = 2, axis = 0)
assert list(reference.var_names)==list(all_genes_subclass.columns), 'Something is wrong with the sorting'
# Name the column 'entropy' to match the feature name the model was fitted with.
# An integer column name triggers sklearn's "X does not have valid feature names"
# warning and relies on positional matching.
all_genes_entropies = pd.DataFrame({'entropy': all_genes_H}, index=reference.var_names)
# Undefined (NaN) entropies are replaced by a large sentinel value.
all_genes_entropies = all_genes_entropies.fillna(100)
score = full_model.predict_proba(all_genes_entropies)
# Column 0 is the predicted probability of target == 0 (targets are 0/1).
reference.var['score'] = score[:,0]





Save parameters:

In [None]:
# Persist the fitted logistic-regression parameters as JSON.
params = dict(
    intercept=full_model.intercept_.tolist(),
    coef=full_model.coef_.tolist(),
    features=X.columns.tolist(),
)
pd.Series(params).to_json("logreg_params.json")

In [None]:
# Max-mean of every reference gene: highest subclass value minus its mean over subclasses.
all_genes_maxmean = all_genes_subclass.max().sub(all_genes_subclass.mean())


In [None]:
def box_penalty(entropy, maxmean, k=5, entropy_threshold = 3.28, maxmean_threshold = 2):
    """Soft "box" penalty from two sigmoids of steepness ``k``.

    The entropy term is ``1 - sigmoid(entropy - entropy_threshold)`` (high for low
    entropy). The max-mean term is ``sigmoid(maxmean - maxmean_threshold)`` (high
    for large max-mean). The penalty is their element-wise minimum, so it is only
    high when a gene has BOTH low entropy and high max-mean.

    Returns a numpy array with one value per gene.
    """
    def _sigmoid(values, midpoint):
        # Evaluated element by element, exactly as 1 / (1 + e^(-k * (v - midpoint))).
        return [1/(1+np.power(np.e, -k*(v-midpoint))) for v in values]

    low_entropy_term = 1 - np.array(_sigmoid(entropy, entropy_threshold))
    high_maxmean_term = _sigmoid(maxmean, maxmean_threshold)
    return np.minimum(low_entropy_term, high_maxmean_term)




In [None]:
# Box penalty for every reference gene (default k and thresholds). NaN metrics,
# if any, presumably propagate to NaN penalties.
penalty = box_penalty(all_genes_H, all_genes_maxmean)

In [None]:
# Assemble both metrics for every reference gene.
all_genes_entropies = all_genes_entropies.rename(columns={0:'entropy'})
all_genes_entropies['maxmean'] = all_genes_maxmean
X = all_genes_entropies[["entropy", 'maxmean']]
# Select exactly the features full_model was fitted on. The final model uses
# entropy only, and passing both columns raises a feature-count ValueError.
score = full_model.predict_proba(X[list(full_model.feature_names_in_)])


In [None]:
# Score every gene using only the columns the model was fitted on; the table now
# also carries a 'maxmean' column that the entropy-only model does not accept.
score = full_model.predict_proba(all_genes_entropies[list(full_model.feature_names_in_)])


In [None]:

# Colour every gene in the (max-mean, entropy) plane by its box penalty.
# Expects, aligned gene-for-gene:
# all_genes_maxmean   -> array-like of shape (n_genes,)
# all_genes_H         -> array-like of shape (n_genes,)
# penalty             -> box_penalty(all_genes_H, all_genes_maxmean)

#probs = score[:, 1]  # probability of class=1
probs = penalty

plt.figure(figsize=(7, 6))
# Keep this handle out of the name `sc`: that alias is the scanpy module, and
# shadowing it breaks the later sc.pp.* preprocessing calls.
scatter_artist = plt.scatter(
    all_genes_maxmean,
    all_genes_H,
    c=probs,
    cmap='Spectral',      # red=low, blue=high
    s=20,
    edgecolor="none"
)

# Reference lines at the entropy of a gene expressed uniformly in 1, 2, ..., N
# subclasses, i.e. log2(1), ..., log2(N), where N is the number of subclass rows
# in all_genes_subclass. Computed here because `entropies` is only defined
# further down the notebook.
n_reference_subclasses = all_genes_subclass.shape[0]
for e in np.log2(np.arange(1, n_reference_subclasses + 1)):
    plt.axhline(e, color = "lightgray")

plt.colorbar(scatter_artist, label="Penalty")
plt.xlabel("per gene max-mean (log1p(norm(counts)))")
plt.ylabel("per gene Shannon entropy (bits)")
plt.title("Gene distribution colored by box penalty score")
plt.grid(alpha=0.3)
plt.show()


In [None]:
# Distribution of per-gene entropy across all reference genes, log-scaled counts.
plt.hist(all_genes_H, log = True)

In [None]:
# Genes with low (but non-trivial) entropy AND low max-mean.
mask = (all_genes_H < 1) & (all_genes_maxmean < 1) & (all_genes_H>0.1)
interesting_subset = all_genes_subclass.loc[:, mask]
print("Subset shape:", interesting_subset.shape)

# Draw up to 50 of them without replacement, with a fixed global seed.
np.random.seed(42)
n_to_sample = min(50, interesting_subset.shape[1])
sampled_genes = np.random.choice(interesting_subset.columns, size=n_to_sample, replace=False)

# Expression and metrics restricted to the sampled genes, in the sampled order.
sampled_subset = interesting_subset.loc[:, sampled_genes]
gene_index = all_genes_subclass.columns
sampled_H = pd.Series(all_genes_H, index=gene_index).loc[sampled_genes]
sampled_maxmean = pd.Series(all_genes_maxmean, index=gene_index).loc[sampled_genes]

print("Sampled genes:", list(sampled_genes)[:5], "...")
print("Entropy subset shape:", sampled_H.shape)
print("Maxmean subset shape:", sampled_maxmean.shape)



In [None]:





expr_ord = sampled_subset
# ---- Plot: one subplot (row) per gene, genes ranked by max-mean ----
genes_order = np.argsort((expr_ord.max() - expr_ord.mean()).values)
genes = list(expr_ord.columns[genes_order])

# sampled_maxmean / sampled_H follow the same gene order as sampled_subset's
# columns, so the same positional order keeps metrics aligned with genes.
# Convert to numpy first: positional `series[int_array]` / `series[i]` access on a
# gene-indexed Series is deprecated in pandas.
plot_maxmean = sampled_maxmean.to_numpy()[genes_order]
plot_entropy = sampled_H.to_numpy()[genes_order]

n_genes = len(genes)

# ---- Layout knobs ----
height_per_subplot = 4.4   # bump this up if still too squished
fig_w = 12
legend_right_margin = 0.80 # (not used in this cell)

fig, axes = plt.subplots(
    n_genes, 1,
    figsize=(fig_w, max(4, height_per_subplot * n_genes)),
)

# A single-row subplots() call returns a bare Axes; wrap it for uniform indexing.
if n_genes == 1:
    axes = [axes]

# One x position per row of the plotted table, so x and y always have equal length.
x = np.arange(len(expr_ord.index))

for i, gene in enumerate(genes):
    ax = axes[i]
    y = expr_ord[gene].astype(float).to_numpy()
    ax.plot(x, y, marker="o", linestyle="none", markersize=20, linewidth=1)
    ax.text(0.99, 0.95,
        f'Entropy: {plot_entropy[i]:.3f}\nMax-mean: {plot_maxmean[i]:.3f}\n{gene}',
        transform=ax.transAxes,  # position relative to the subplot axes
        ha="right", va="top")

    ax.set_ylabel(f'log1p(norm({gene}))')
    ax.set_xticks(x, labels=expr_ord.index)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-0.5, 10)

plt.tight_layout()
plt.show()




In [None]:
# Reference entropies: for n = 1..13, the entropy (bits) of a distribution that is
# uniform over n of 13 slots and zero elsewhere (mathematically log2(n)).
n_subclasses = np.arange(1, 14)  # 1 to 13 inclusive

positions = np.arange(13)
sequences = [np.where(positions < n, 1.0 / n, 0.0) for n in n_subclasses]

# Stack into a 13 x 13 matrix, one distribution per row.
sequences_array = np.vstack(sequences)

entropies = np.array([entropy(row, base = 2) for row in sequences_array])

In [None]:
# Show the reference entropies (bits) of uniform distributions over 1..13 subclasses.
entropies

## Plot cumulative expression

For each cluster, sum the expression of all panel genes (undoing log1p, summing, then re-applying log1p) to get its total expression.

In [None]:
# Spot-check one entry of expr_ord (second row, first column).
expr_ord.iloc[1,0]

In [None]:



# Total expression of the panel genes per cluster, with clusters grouped by subclass.
expression_percluster = ut.calculate_expression_per_taxa(panel_reference, 'cluster_label')

# --- Build cluster → subclass mapping from obs ---
meta = (
    reference.obs[["cluster_label", "subclass_label"]]
    .dropna()
    .drop_duplicates()
)

# Subclasses in order of first appearance in reference.obs.
ALM_subclasses = reference.obs['subclass_label'].unique()

# ---- Explicit color allocation (edit as you wish) ----
subclass_colors = {
    "Vip":           "#1f77b4",
    "Lamp5":         "#ff7f0e",
    "Sncg":          "#2ca02c",  # the subclass label is "Sncg" (see the VISp subclass listing)
    "Sst Chodl":     "#d62728",
    "Sst":           "#9467bd",
    "Pvalb":         "#8c564b",
    "L2/3 IT CTX":   "#e377c2",
    "L4/5 IT CTX":   "#7f7f7f",
    "L5 IT CTX":     "#bcbd22",
    "L6 IT CTX":     "#17becf",
    "L5 PT CTX":     "#aec7e8",
    "L4 RSP-ACA":    "#ffbb78",
    "L5/6 NP CTX":   "#98df8a",
    "L6 CT CTX":     "#ff9896",
    "L6b CTX":       "#c5b0d5",
}

# Default color for any subclass not listed
default_color = "#dddddd"

# ---- Order clusters by subclass group, then by cluster label (stable sort) within group ----
ordered_clusters = []
spans = []  # list of (subclass, start_idx, end_idx) for shading bands

start = 0
for sub in ALM_subclasses:
    members = meta.loc[meta["subclass_label"] == sub].sort_values("cluster_label", kind="mergesort")
    if not members.empty:
        ordered_clusters.extend(members["cluster_label"].tolist())
        end = start + len(members) - 1
        spans.append((sub, start, end))
        start = end + 1

# Reindex the dataframe to the new cluster order
expr_ord = expression_percluster.loc[ordered_clusters]

# ---- Layout knobs ----
height_per_subplot = 4.4   # bump this up if still too squished
fig_w = 12
legend_right_margin = 0.80 # (not used in this cell)

fig, axes = plt.subplots(
    1, 1,
    figsize=(fig_w, max(4, height_per_subplot ))
)

x = np.arange(len(ordered_clusters))
ax = axes

# background subclass bands
for sub, s, e in spans:
    color = subclass_colors.get(sub, default_color)
    ax.axvspan(s - 0.5, e + 0.5, facecolor=color, alpha=0.12, linewidth=0)

# Total per cluster: undo log1p, sum over genes, re-apply log1p.
expr_ord = expr_ord.astype(float)
y = np.log1p(np.expm1(expr_ord).sum(axis=1)).to_numpy()

ax.set_ylabel('log1p(sum of normalized counts)')
ax.plot(x, y, marker="o", linestyle="-", linewidth=1)
ax.grid(True, alpha=0.3)

# X ticks
ax.set_xticks(x)
ax.set_xticklabels(ordered_clusters, rotation=90)
ax.set_xlabel("Clusters (grouped by subclass)")

# Legend: one entry per distinct subclass band, in order of first appearance.
legend_patches = [
    Patch(facecolor=subclass_colors.get(name, default_color), alpha=0.4, label=name)
    for name in dict.fromkeys(band[0] for band in spans)
]
# Place the legend to the side to avoid overlapping the plot.
ax.legend(handles=legend_patches, loc="upper left", bbox_to_anchor=(1.01, 1.0), frameon=False, title="Subclass bands")
ax.set_title('Total expression per cluster')

plt.tight_layout()
plt.show()

# Keep the explicit cluster order and band spans for later cells.
ordered_clusters_list = ordered_clusters
subclass_band_spans = spans  # [(subclass, start_idx, end_idx), ...]


In [None]:
# Per-cell expression of the Sst gene in cells of the Sst subclass, as a flat 1-D array.
is_sst_gene = np.asarray(panel_reference.var.index == 'Sst')
is_sst_cell = (panel_reference.obs['subclass_label'] == 'Sst').to_numpy()
sst_X = panel_reference.X[:, is_sst_gene][is_sst_cell, :]
# .X may be a scipy sparse matrix or a dense array depending on how the AnnData was
# built or subset; only sparse matrices have .toarray().
sst_exp = (sst_X.toarray() if issparse(sst_X) else np.asarray(sst_X)).flatten()

In [None]:
# Histogram of Sst expression in Sst cells, with the median marked in red.
sst_median = np.median(sst_exp)

plt.hist(sst_exp, bins=100, log=True)
plt.axvline(sst_median, color='red')
plt.text(sst_median, 10, s=f'Median: {sst_median:.1f}', color='red')
plt.title('Sst expression in Sst cells')
plt.xlabel('log1p(norm(counts))')
plt.ylabel('Count (log scale)')
plt.show()

In [None]:
# Show the AnnData summary of the current reference.
reference

## Confusion matrices

In [None]:
# Load the forest-classifier evaluation table for the glia-tailored panel. Its first,
# unnamed column holds row labels, which the next cell turns into the index.
confusion_glia = pd.read_csv("/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/panels/evaluation_2023/forest_clfs/forest_clfs_adata1_spapros_10K_2023_taylored_glia_panel.csv")

In [None]:
# Use the unnamed first CSV column as row labels and drop it from the data columns.
confusion_glia = confusion_glia.set_index('Unnamed: 0')

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Annotated heatmap of the evaluation table (rows: true, columns: predicted subclass).
heatmap_style = dict(
    cmap="viridis",            # or "magma", "coolwarm", etc.
    annot=True,                # write each value in its cell
    fmt=".3f",                 # 3 decimal places
    cbar_kws={'label': 'Score'},
)

plt.figure(figsize=(10, 8))
sns.heatmap(confusion_glia, **heatmap_style)
plt.title("Subclass similarity heatmap")
plt.ylabel("True subclass")
plt.xlabel("Predicted subclass")
plt.tight_layout()
plt.show()


## Testing expression outside our target subclasses

In [None]:
# loading the yao2021 reference

# Target neuronal subclasses (not used for loading in this cell).
ALM_neuron_subclasses = [
    "Vip",
    "Lamp5",
    "Sncg",  # the subclass label is "Sncg"; "Scng" matches no subclass
    "Sst Chodl",
    "Sst",
    "Pvalb",
    "L2/3 IT CTX",
    "L4/5 IT CTX",
    "L5 IT CTX",
    "L6 IT CTX",
    "L5 PT CTX",
    "L4 RSP-ACA",
    "L5/6 NP CTX",
    "L6 CT CTX",
    "L6b CTX",
]

# Non-target subclasses loaded into glia_reference. Besides glial/vascular types,
# this list also includes CR and Meis2.
ALM_glia_subclasses = ['Oligo', 'Astro', 'SMC-Peri', 'VLMC', 'Micro-PVM', 'Endo', 'CR', 'Meis2']

datapath = Path('/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/yao_2021')

glia_reference = io.load_yao_2021_to_anndata(datapath, 'ALM', ALM_glia_subclasses)


In [None]:
# Order-dependent scanpy preprocessing of glia_reference (in place): drop genes
# detected in no cell, normalise each cell's total counts, log1p-transform, then
# flag the 10,000 most highly variable genes (cell_ranger flavour).
# NOTE(review): requires `sc` to still be the scanpy module (no later rebinding).
sc.pp.filter_genes(glia_reference, min_cells=1) # downstream steps reportedly crash when many genes are expressed in no cells
sc.pp.normalize_total(glia_reference)
sc.pp.log1p(glia_reference)
sc.pp.highly_variable_genes(glia_reference,flavor="cell_ranger",n_top_genes=10000) #for a good panel

In [None]:
savepath = "/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/panels/spapros_10K_taylored_glia_2021"

# Genes forced into the panel via spapros' preselected_genes.
prior_marker = [
    "Chodl",
    "Cux2",
    "Fezf2",
    "Foxp2",
    "Rorb",
    "Pvalb",
    "Lamp5",
    "Adamts2",
    "Slco2a1",
]

selector = sp.se.ProbesetSelector(glia_reference, n=100, 
                                  celltype_key="subclass_label", 
                                  verbosity=1, 
                                  n_jobs=-1, 
                                  save_dir=savepath, 
                                  preselected_genes=prior_marker)

# Run the selection. Without this call, `selector.probeset` is not guaranteed to be
# populated in a fresh kernel. With save_dir set, results already saved there are
# presumably reused rather than recomputed (confirm).
selector.select_probeset()

# Keep only the genes spapros marked as selected.
selected = selector.probeset[selector.probeset["selection"]].copy()
spapros_panel = list(selected.index)

panel_reference = glia_reference[:,glia_reference.var_names.isin(spapros_panel)]

# Ranked gene priority list used for the entropy constraint (most important first).
gene_priority = [
    "Chodl", "Vip", "Sst", "Pvalb", "Rorb", "Lamp5", "Cux2", "Fezf2",
    "Rrad", "Adamts2", "Foxp2", "Npsr1", "Slco2a1", "Tshz2", "Fam19a1", "Calb1",
    "Syt6", "Blnk", "Cplx3", "Adarb2", "Tcap", "Rprm", "Cpne9", "Tac1",
    "Ctgf", "Erbb4", "Slc30a3", "Synpr", "Cdh9", "Tnnc1", "Il1rapl2", "Meis2",
    "Crh", "Adcyap1", "Rgs12", "Lypd1", "Igfbp6", "Sema5a", "Cpne7", "S100a6",
    "Lypd6", "Calb2", "Pamr1", "Cidea", "Medag", "Ptprk", "Pcp4l1", "Nxph1",
    "Crym", "Ptprm", "Cbln4", "Igfbp4", "Pnoc", "Penk", "Npy", "Egfr",
    "Cryab", "Grik1", "Kctd8", "Car4", "S100b", "Nr2f2", "Pthlh", "Prss23",
    "Cdh20", "Ptprz1", "Vwc2l", "Elfn1", "Crhbp", "Gad1", "Angpt1", "Adamts3",
    "Nxph4", "Npr3", "Sema5b", "Ddit4l", "Kit", "Serpine2", "Unc13c", "Cd63",
    "Cartpt", "Thsd7a", "Tac2", "Moxd1", "Rmst", "Sema3c", "Rspo1", "Sox6",
    "Ppapdc1a", "Lgals1", "Ackr3", "Timp3", "Npas1", "Reln", "Vwc2", "Pcdh8",
    "Hgf", "Pcdh20", "Cp", "Ndnf",
]

In [None]:
# Per-gene specificity metrics of the (glia-tailored) panel, across subclasses.
expression_persubclass = ut.calculate_expression_per_taxa(panel_reference, 'subclass_label')

# max-mean: how far each gene's highest subclass value sits above its mean over subclasses.
maxmean = expression_persubclass.max() - expression_persubclass.mean()

# Shannon entropy (bits) of each gene's profile across subclasses (column-wise).
gene_H = entropy(expression_persubclass.to_numpy(dtype = float), base = 2, axis = 0)

h_thresh = 3.0   # entropy threshold (bits)
mm_thresh = 2.0  # max-mean threshold (same units as the data)

high_entropy = gene_H > h_thresh
low_maxmean = maxmean < mm_thresh
mask = high_entropy & low_maxmean          # both criteria met  -> code 1 (red)
almost_mask = high_entropy ^ low_maxmean   # exactly one met    -> code 2 (orange)

colors_1 = np.where(mask, 1, 0)
colors_2 = np.where(almost_mask, 2, 0)
colors_sum = colors_1 + colors_2           # neither met        -> code 0 (black)
colist = ['black', 'red', 'orange']
colors = [colist[code] for code in colors_sum]

plt.scatter(maxmean, gene_H, c=colors, s=20)
plt.ylabel('per gene Shannon entropy (bits)')
plt.xlabel('per gene max-mean (log1p(norm(counts))')
plt.show()

In [None]:


expr_ord = expression_persubclass
# ---- Plot: one subplot (row) per gene, genes ranked by max-mean ----
genes_order = np.argsort((expr_ord.max() - expr_ord.mean()).values)
genes = list(expr_ord.columns[genes_order])

# Reorder the metrics with the same positional order so row i shows gene i's values.
# Convert to numpy first: `maxmean` is a gene-indexed Series, and positional
# `series[int_array]` / `series[i]` lookups are deprecated in pandas.
# (Assumes maxmean, gene_H and colors were computed from this same expression_persubclass.)
plot_maxmean = np.asarray(maxmean)[genes_order]
plot_entropy = np.asarray(gene_H)[genes_order]
plot_colors = np.array(colors)[genes_order]

n_genes = len(genes)

# ---- Layout knobs ----
height_per_subplot = 4.4   # bump this up if still too squished
fig_w = 12
legend_right_margin = 0.80 # (not used in this cell)

fig, axes = plt.subplots(
    n_genes, 1,
    figsize=(fig_w, max(4, height_per_subplot * n_genes)),
)

# A single-row subplots() call returns a bare Axes; wrap it for uniform indexing.
if n_genes == 1:
    axes = [axes]

# One x position per row of the plotted table, so x and y always have equal length.
x = np.arange(len(expr_ord.index))

for i, gene in enumerate(genes):
    ax = axes[i]
    y = expr_ord[gene].astype(float).to_numpy()
    ax.plot(x, y, marker="o", linestyle="none", markersize=20, linewidth=1)
    ax.text(0.99, 0.95,
        f'Entropy: {plot_entropy[i]:.3f}\nMax-mean: {plot_maxmean[i]:.3f}\n{gene}',
        transform=ax.transAxes,  # position relative to the subplot axes
        ha="right", va="top",
        color = plot_colors[i])

    ax.set_ylabel(f'log1p(norm({gene}))')
    ax.set_xticks(x, labels=expr_ord.index)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-0.5, 10)

plt.tight_layout()
plt.show()




In [None]:

# 1) Total expression per subclass: undo log1p, sum over genes, re-apply log1p.
expr_values = expression_persubclass.to_numpy(dtype=float)
total_expr = pd.Series(
    np.log1p(np.expm1(expr_values).sum(axis=1)),
    index=expression_persubclass.index,
    name="total_expr",
)

# 2) Number of cells per subclass in glia_reference, sorted by subclass name.
subclass_sizes = glia_reference.obs["subclass_label"].value_counts()
cells_per_subclass = subclass_sizes.rename_axis("subclass").to_frame("n_cells").sort_index()

# 3) Keep only subclasses present in both tables.
df = pd.concat([total_expr, cells_per_subclass["n_cells"]], axis=1, join="inner")

# 4) One labelled point per subclass.
fig, ax = plt.subplots(figsize=(7, 5))
ax.scatter(df["n_cells"], df["total_expr"], s=60)
for subclass, row in df.iterrows():
    ax.text(row["n_cells"], row["total_expr"], subclass, fontsize=8, ha="center", va="bottom")

ax.set_xlabel("Number of cells per subclass")
ax.set_ylabel("log1p(sum of normalized counts)")
ax.set_title("Total expression vs. cell count per subclass")
ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()


In [None]:
# Sanity check: the per-subclass expression table must be numeric and non-negative.
# to_numpy(dtype=float) raises if any entry is not convertible to float; negative
# values are then reported explicitly (the undefined `is_type` helper is not used).
flatvalues = expression_persubclass.to_numpy(dtype=float).flatten()
negative_values = flatvalues[flatvalues < 0]
# The assert message is only evaluated on failure, so indexing [0] is safe.
assert negative_values.size == 0, f'Negative value {negative_values[0]} found somewhere'

Omit Meis2, CR, SMC-Peri and Micro-PVM, and penalise each gene by its maximum expression across the remaining subclasses, using a sigmoid with its midpoint at 7 (threshold to be confirmed).

In [None]:
# Glial off-target penalty: a sigmoid of each gene's maximum per-subclass expression
# across the retained non-neuronal subclasses.
threshold = 7  # sigmoid midpoint, in the table's units (log1p(norm(counts)))
k = 6          # sigmoid steepness; set explicitly rather than relying on a `k` left over from an earlier cell

# Meis2, CR, SMC-Peri and Micro-PVM are omitted here.
ALM_glia_subclasses = ['Oligo', 'Astro', 'VLMC', 'Endo']
datapath = Path('/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/yao_2021')
glia_penalty_reference = io.load_yao_2021_to_anndata(datapath, 'ALM', ALM_glia_subclasses)

sc.pp.filter_genes(glia_penalty_reference, min_cells=1) # downstream steps reportedly crash when many genes are expressed in no cells
sc.pp.normalize_total(glia_penalty_reference)
sc.pp.log1p(glia_penalty_reference)
sc.pp.highly_variable_genes(glia_penalty_reference,flavor="cell_ranger",n_top_genes=10000) #for a good panel

expression_persubclass = ut.calculate_expression_per_taxa(glia_penalty_reference, 'subclass_label')

max_values = expression_persubclass.max()

penalty = [1/(1+np.power(np.e, -k*(value-threshold))) for value in max_values]


In [None]:
# Sigmoid penalty on each gene's maximum per-subclass expression, and its distribution.
k = 6           # steepness
threshold = 7   # midpoint
max_values = expression_persubclass.max()

penalty = [1/(1+np.power(np.e, -k*(peak-threshold))) for peak in max_values]

plt.hist(penalty)
plt.title(f'Glial expression penalty for k={k} and threshold={threshold}')
plt.xlabel('Penalty')
plt.ylabel('# genes')

In [None]:
# Write genes and their penalties to CSV, sorted by increasing penalty.
penalty_order = np.argsort(penalty)
ordered_genes = np.array(expression_persubclass.columns)[penalty_order]
penalty_table = pd.DataFrame({'gene': ordered_genes, 'penalty': np.sort(penalty)})
penalty_table.to_csv('glia_penalty.csv', index = False)

## VISp panel

In [None]:
# Target neuronal subclasses (not used for loading in this cell).
ALM_neuron_subclasses = [
    "Vip",
    "Lamp5",
    "Sncg",  # the subclass label is "Sncg"; "Scng" matches no subclass
    "Sst Chodl",
    "Sst",
    "Pvalb",
    "L2/3 IT CTX",
    "L4/5 IT CTX",
    "L5 IT CTX",
    "L6 IT CTX",
    "L5 PT CTX",
    "L4 RSP-ACA",
    "L5/6 NP CTX",
    "L6 CT CTX",
    "L6b CTX",
]

datapath = Path('/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/yao_2021')

# Load the VISp data with None passed as the subclass argument.
reference = io.load_yao_2021_to_anndata(datapath, 'VISp', None)

In [None]:
# Show the AnnData summary of the VISp reference.
reference

In [None]:
# Cells per subclass in the VISp reference, most abundant first, plus a TOTAL row.
subclass_counts = reference.obs['subclass_label'].value_counts()

df_counts = pd.DataFrame(
    {
        'Subclass': subclass_counts.index,
        'Number of cells': subclass_counts.values,
    }
)

total_cells = df_counts['Number of cells'].sum()
df_counts.loc[len(df_counts)] = ['TOTAL', total_cells]

print(df_counts.to_string(index=False))

L4/5 IT CTX
Sst
L6 IT CTX
Vip
Pvalb
Lamp5
L2/3 IT CTX
L6 CT CTX
L5 IT CTX
L5 PT CTX
L5/6 NP CTX
L6b CTX
Sst Chodl
Car3
Sncg
