In [None]:
import scanpy as sc
import spapros as sp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from sklearn.linear_model import LogisticRegression
import anndata as ann
from pathlib import Path
import iss_analysis.io as io
import iss_analysis.pick_genes as pick
from scipy.sparse import issparse, csr_matrix
from scipy.stats import entropy
import utils as ut

from abc_atlas_access.abc_atlas_cache.abc_project_cache import AbcProjectCache


In [None]:
# Areas whose cells are loaded and pooled into a single reference.
AREA = ['VISp', 'ALM']

# Subclass labels passed to the loader for every area.
subclasses = [
    "L4/5 IT CTX", "Sst", "Vip", "Pvalb", "Lamp5",
    "L2/3 IT CTX", "L5 IT CTX", "L5 PT CTX", "Sst Chodl",
]

# Initial marker list; the selection cell further down assigns `prior_marker`
# again before it is passed to the selector.
prior_marker = ["Chodl", "Vip", "Gad1", "Lamp5", "Sst", "Slc17a7"]

datapath = Path('/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/yao_2021')

# One AnnData per area, stacked along the cell axis with an inner join.
reference_list = [
    io.load_yao_2021_to_anndata(datapath, area, subclasses)
    for area in AREA
]
reference = ann.concat(reference_list, join='inner', axis=0)

# Independent copy taken before `reference` is modified by later cells.
counts_reference = reference.copy()

In [None]:
# Log-normalised reference: drop genes detected in no cell, normalise per-cell
# totals, log1p-transform, then flag highly variable genes. All four calls
# modify `reference` in place, so their order matters.
sc.pp.filter_genes(reference, min_cells=1) # it crashes with a lot of genes expressed in no cells
sc.pp.normalize_total(reference)
sc.pp.log1p(reference)
sc.pp.highly_variable_genes(reference,flavor="cell_ranger",n_top_genes=10000) #for a good panel

# Count-level copy: same gene filter and HVG call, but without the
# normalisation / log1p steps applied above.
# NOTE(review): scanpy documents flavor="cell_ranger" as expecting
# logarithmized data — confirm that running it on un-normalised counts is intended.
sc.pp.filter_genes(counts_reference, min_cells=1) # it crashes with a lot of genes expressed in no cells
sc.pp.highly_variable_genes(counts_reference,flavor="cell_ranger",n_top_genes=10000) #for a good panel


In [None]:
# Directory handed to spapros as `save_dir` for its selection results.
savepath = "/nemo/lab/znamenskiyp/home/shared/projects/colasa_MOs_panel/panels/spapros_Rob_joint_25"

# Genes passed to the selector as `preselected_genes`. This rebinds the
# `prior_marker` list defined in the data-loading cell.
prior_marker = [
    "Chodl",
    "Cux2",
    "Fezf2",
    "Foxp2",
    "Rorb",
    "Pvalb",
    "Lamp5",
    "Adamts2",
    "Slco2a1",
]

selector = sp.se.ProbesetSelector(
    reference,
    n=25,
    celltype_key="subclass_label",
    verbosity=1,
    n_jobs=-1,
    n_pca_genes=0,
    save_dir=savepath,
    preselected_genes=prior_marker,
)

# Constructing the selector only configures it. The selection itself has to be
# run explicitly; without this call `selector.probeset` is only populated when
# results from an earlier run happen to exist in `save_dir`, so a fresh run
# (or a new `savepath`) fails on the next line.
selector.select_probeset()

# Keep the rows flagged as selected; their index holds the gene names.
selected = selector.probeset[selector.probeset["selection"]].copy()
spapros_panel = list(selected.index)

In [None]:
# Restrict both AnnData objects to the panel genes. Each column mask is built
# from the var_names of the object it indexes: the two objects are filtered
# independently upstream, so borrowing `reference`'s mask for
# `counts_reference` would select the wrong columns (or fail on a length
# mismatch) as soon as their gene sets or gene order differ.
panel_reference = reference[:, reference.var_names.isin(spapros_panel)]
panel_counts_reference = counts_reference[:, counts_reference.var_names.isin(spapros_panel)]

In [None]:
def calculate_expression_per_taxa(panel_reference, taxa_label):

    '''
    Build a taxa x genes table of 25%-trimmed mean expression.

    For every distinct value of ``panel_reference.obs[taxa_label]`` the cells
    carrying that label are collected and, gene by gene, the lowest and the
    highest 25% of the cells are discarded before averaging.

    Trimming is done by rank (``scipy.stats.trim_mean``) rather than by
    comparing values against the quartiles. A value band with strict
    inequalities drops every cell tied with a quartile, which reports 0 for a
    gene that is constant within a taxon (and for any two-cell taxon) and
    throws away all zeros whenever the lower quartile is 0.

    Parameters
    ----------
    panel_reference
        AnnData-like object exposing ``X`` (dense or scipy-sparse, indexed as
        cells x genes), ``obs``, ``var_names`` and ``n_vars``.
    taxa_label : str
        Column of ``panel_reference.obs`` holding the taxon of each cell.

    Returns
    -------
    pandas.DataFrame
        Float table indexed by the sorted unique taxa, one column per gene.
    '''
    # Local import so the function does not depend on an extra top-level import.
    from scipy.stats import trim_mean

    genes = panel_reference.var_names

    clusters = panel_reference.obs[taxa_label].to_numpy()
    # `inverse[i]` is the position in `uniq_clusters` of the taxon of cell i.
    uniq_clusters, inverse = np.unique(clusters, return_inverse=True)

    n_genes = panel_reference.n_vars
    print(f'There are {n_genes} genes in this panel')

    X = panel_reference.X
    sparse_input = issparse(X)

    # Numeric buffer so the returned table has a float dtype instead of object.
    trimmed_means = np.zeros((len(uniq_clusters), n_genes), dtype=float)

    for k in range(len(uniq_clusters)):
        rows = np.flatnonzero(inverse == k)

        # Dense (cells_in_taxon x genes) block for the column-wise trimmed mean.
        block = X[rows, :]
        Xc = block.toarray() if sparse_input else np.asarray(block)

        # Discard the lowest and highest 25% of cells per gene, average the rest.
        trimmed_means[k, :] = trim_mean(Xc, 0.25, axis=0)

    return pd.DataFrame(trimmed_means, index=uniq_clusters, columns=genes)



In [None]:
# Ordered list of 25 gene names.
# NOTE(review): not referenced by any of the visible cells; presumably a
# manually curated priority ranking for the panel — confirm where it is used.
gene_priority = [
    "Chodl",
    "Vip",
    "Sst",
    "Slc17a7",
    "Lamp5",
    "Gad1",
    "Stard8",
    "Cox6a2",
    "Cplx3",
    "Fezf2",
    "Rorb",
    "Arhgap25",
    "Synpr",
    "Pvalb",
    "Tmem163",
    "Cryab",
    "Meis2",
    "Adarb2",
    "Car4",
    "Crhbp",
    "Hs3st2",
    "Myl4",
    "Myh7",
    "Grin3a",
    "Gad2"
]

In [None]:
# Per-subclass expression table for the panel genes.
# NOTE(review): this calls the `utils` implementation, not the function of the
# same name defined in this notebook — confirm the two agree.
expression_persubclass = ut.calculate_expression_per_taxa(panel_reference, 'subclass_label')

# Metric 1: column-wise maximum minus column-wise mean of the table.
maxmean = expression_persubclass.max() - expression_persubclass.mean()

# Metric 2: Shannon entropy (bits) of each column of the table.
gene_H = entropy(expression_persubclass.to_numpy(dtype = float), base = 2, axis = 0)

h_thresh = 3.0        # entropy threshold (bits)
mm_thresh = 2.0       # max-mean threshold (same units as your data)

# `mask`: both criteria met; `almost_mask`: exactly one of them met (XOR).
mask = (gene_H > h_thresh) & (maxmean < mm_thresh)
almost_mask = (gene_H > h_thresh) ^ (maxmean < mm_thresh)

# The two masks are mutually exclusive, so the sum is 0, 1 or 2 and indexes
# directly into `colist` (black = neither, red = both, orange = one).
colors_1 = np.where(mask, 1, 0)
colors_2 = np.where(almost_mask, 2, 0)
colors_sum = colors_1+colors_2
colist = ['black', 'red', 'orange']
colors = [colist[i] for i in colors_sum]

fig, ax = plt.subplots()
ax.scatter(maxmean, gene_H, c=colors, s=20)
ax.set_ylabel('per gene Shannon entropy (bits)')
# Axis label previously had an unbalanced parenthesis.
ax.set_xlabel('per gene max-mean (log1p(norm(counts)))')

plt.show()

In [None]:

# --- Build robust cluster -> subclass mapping from obs ---
# NOTE(review): no mapping code follows this heading — confirm whether a step
# is missing here.


# ---- Explicit color allocation (edit as you wish) ----
# NOTE(review): currently not consumed by the plotting code below.
subclass_colors = {
    "Vip":           "#1f77b4",
    "Lamp5":         "#ff7f0e",
    "Scng":          "#2ca02c",
    "Sst Chodl":     "#d62728",
    "Sst":           "#9467bd",
    "Pvalb":         "#8c564b",
    "L2/3 IT CTX":   "#e377c2",
    "L4/5 IT CTX":   "#7f7f7f",
    "L5 IT CTX":     "#bcbd22",
    "L6 IT CTX":     "#17becf",
    "L5 PT CTX":     "#aec7e8",
    "L4 RSP-ACA":    "#ffbb78",
    "L5/6 NP CTX":   "#98df8a",
    "L6 CT CTX":     "#ff9896",
    "L6b CTX":       "#c5b0d5",
}

expr_ord = expression_persubclass

# ---- Plot: one subplot (row) per gene ----
# A single ordering (ascending max-mean) is computed once and applied to the
# genes, both metrics and the text colors, so they cannot drift apart.
# Everything is reordered as plain NumPy arrays: indexing a gene-labelled
# Series with integer positions via `[]` is deprecated in pandas.
maxmean_values = (expr_ord.max() - expr_ord.mean()).to_numpy(dtype=float)
genes_order = np.argsort(maxmean_values)

genes = list(expr_ord.columns[genes_order])
plot_maxmean = maxmean_values[genes_order]
plot_entropy = np.asarray(gene_H, dtype=float)[genes_order]
plot_colors = np.asarray(colors)[genes_order]

n_genes = len(genes)

# ---- Layout knobs ----
height_per_subplot = 4.4   # bump this up if still too squished
fig_w = 12
legend_right_margin = 0.80 # reserve space for legend (0.80 = 80% of width for plots); unused for now

fig, axes = plt.subplots(
    n_genes, 1,
    figsize=(fig_w, max(4, height_per_subplot * n_genes)),
)
# plt.subplots returns a bare Axes for a single row; normalise to a 1-D array.
axes = np.atleast_1d(axes)

# One x position per row of the plotted table, so x and y always match.
x = np.arange(len(expr_ord.index))

for i, gene in enumerate(genes):
    ax = axes[i]
    # gene trace
    y = expr_ord[gene].astype(float).to_numpy()
    ax.plot(x, y, marker="o", linestyle="none", markersize=20, linewidth=1)
    ax.text(
        0.99, 0.95,
        f'Entropy: {plot_entropy[i]:.3f}\nMax-mean: {plot_maxmean[i]:.3f}\n{gene}',
        transform=ax.transAxes,  # <--- relative to subplot axes
        ha="right", va="top",
        color=plot_colors[i],
    )

    ax.set_ylabel(f'log1p(norm({gene}))')
    ax.set_xticks(x, labels=expr_ord.index)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(-0.5, 10)

plt.tight_layout()
plt.show()




In [None]:
# Display the summary of the (log-normalised) reference object.
reference

In [None]:
# Display the list of genes selected for the panel.
spapros_panel

In [None]:
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import numpy as np
import pandas as pd
from sklearn import tree

# One shallow (depth-2) decision tree per subclass, trained one-vs-rest on the
# panel genes, to see which one or two genes separate each subclass.

# Expression matrix restricted to the panel, columns in `spapros_panel` order.
X = reference[:, spapros_panel].X
if issparse(X):
    X = X.toarray()

y = reference.obs['subclass_label'].to_numpy()
subclass_names = np.unique(y)

print("="*80)
print("ONE-VS-REST DECISION TREES (depth=2) FOR EACH SUBCLASS")
print("="*80)
print(f"Number of cells: {X.shape[0]}")
print(f"Number of genes (spapros_panel): {X.shape[1]}")
print(f"Number of subclasses: {len(subclass_names)}")

# The splitter has a fixed seed, so one instance yields the same folds for
# every subclass; no need to rebuild it inside the loop.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# One summary record per subclass.
results = []

for subclass in subclass_names:
    print(f"\n{'='*80}")
    print(f"Subclass: {subclass}")
    print('='*80)

    # Binary target: 1 for cells of this subclass, 0 for all others.
    y_binary = (y == subclass).astype(int)
    # Counted directly: with average='binary' the `support` returned by
    # precision_recall_fscore_support is None.
    n_positive = int(y_binary.sum())

    clf = DecisionTreeClassifier(max_depth=2, random_state=42)

    # 5-fold cross-validated accuracy (cross_val_score fits clones of `clf`).
    cv_scores = cross_val_score(clf, X, y_binary, cv=cv, scoring='accuracy')

    print(f"CV Accuracy: {cv_scores.mean():.3f} (+/- {cv_scores.std()*2:.3f})")
    print(f"Per fold: {[f'{s:.3f}' for s in cv_scores]}")

    # Refit on all cells to read off the tree structure; the metrics below are
    # therefore training-set metrics, not held-out ones.
    clf.fit(X, y_binary)
    y_pred = clf.predict(X)

    # zero_division=0 keeps the 0 score but silences the warning when the tree
    # never predicts the positive class.
    precision, recall, f1, _ = precision_recall_fscore_support(
        y_binary, y_pred, average='binary', pos_label=1, zero_division=0
    )

    print(f"\nFull dataset metrics:")
    print(f"  Precision: {precision:.3f}")
    print(f"  Recall: {recall:.3f}")
    print(f"  F1-score: {f1:.3f}")
    print(f"  True class size: {n_positive}")

    # Genes actually used by the tree, most important first.
    feature_importance = pd.DataFrame({
        'gene': spapros_panel,
        'importance': clf.feature_importances_
    }).sort_values('importance', ascending=False)

    important_genes = feature_importance[feature_importance['importance'] > 0]
    print(f"\nImportant genes ({len(important_genes)}):")
    for _, row in important_genes.iterrows():
        print(f"  {row['gene']}: {row['importance']:.3f}")

    print(f"\nDecision tree structure:")
    tree_rules = tree.export_text(clf, feature_names=spapros_panel)
    print(tree_rules)

    # A tree that never splits has no important gene at all.
    has_split = len(important_genes) > 0
    results.append({
        'subclass': subclass,
        'cv_mean_accuracy': cv_scores.mean(),
        'cv_std_accuracy': cv_scores.std(),
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'n_samples': n_positive,
        'top_gene': important_genes.iloc[0]['gene'] if has_split else None,
        'top_gene_importance': important_genes.iloc[0]['importance'] if has_split else 0
    })

# Summary table
print("\n" + "="*80)
print("SUMMARY: ALL SUBCLASSES")
print("="*80)
results_df = pd.DataFrame(results)
print(results_df.to_string(index=False))

print("\n" + "="*80)
print("KEY GENES PER SUBCLASS")
print("="*80)
for _, row in results_df.iterrows():
    # A missing top gene cannot be formatted with ':15s'; print a placeholder.
    top_gene = 'n/a' if pd.isna(row['top_gene']) else str(row['top_gene'])
    print(f"{row['subclass']:20s}: {top_gene:15s} (importance={row['top_gene_importance']:.3f}, F1={row['f1']:.3f})")