In [None]:
import scANVI_prediction_utils as spu
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import scanpy as sc
import anndata as ad
import scvi
import torch


## 1. Check GPU status

In [None]:
# Report GPU status via the project helper before any model training is started.
spu.gpu_status()

## 2. Load reference and query

In [None]:
# Input files: the labelled reference and the query dataset to annotate.
REF_H5AD = "/tscc/lustre/ddn/scratch/aopatel/adata_ref.h5ad"
QUERY_H5AD = "/tscc/lustre/ddn/scratch/aopatel/fin_adata_mtg.h5ad"

adata_ref = ad.read_h5ad(REF_H5AD)
adata_query = ad.read_h5ad(QUERY_H5AD)

In [None]:
# Display the freshly loaded query object for inspection.
adata_query

## 3. Prepare query

In [None]:
# Compare donors between reference and query; the helper's return value replaces adata_query.
# presumably filter=True removes query cells from donors that also appear in the reference —
# verify against spu.check_donor_overlap.
adata_query=spu.check_donor_overlap(ref=adata_ref, query=adata_query, filter=True)

In [None]:
# Second pass over the already-filtered query, this time with filter=False.
# NOTE(review): presumably a report-only re-check. adata_query is rebound to the return value,
# so confirm the helper returns the AnnData (not None) when filter=False.
adata_query=spu.check_donor_overlap(ref=adata_ref, query=adata_query, filter=False)

In [None]:
# Display the query object again after the donor-overlap steps.
adata_query

In [None]:
# Sanity-check the gene symbols stored in adata_query.var["gene_symbols"].
# Missing values must be counted BEFORE the string cast: astype(str) turns NaN into
# the literal text "nan", after which isna() can never report anything.
symbols_raw = adata_query.var["gene_symbols"]
print("Missing symbols:", symbols_raw.isna().sum())

# String view of the symbols (missing entries become "nan" here), used for the
# duplicate count.
symbols = symbols_raw.astype(str)
print("Duplicated symbols:", symbols.duplicated().sum())


### Set the var index back to gene symbols; when a symbol is repeated, keep the entry with the highest total expression

In [None]:
# Step 1: total expression per gene, summed over all cells. np.asarray(...).ravel()
# flattens the result to a 1-D vector before it is stored in var.
gene_totals = np.asarray(adata_query.X.sum(axis=0)).ravel()
adata_query.var['total_expr'] = gene_totals

# Step 2: within each gene symbol, rank entries by total expression (highest first)
# and keep only the top-ranked entry.
ranked_var = adata_query.var.sort_values(
    ['gene_symbols', 'total_expr'],
    ascending=[True, False],
)
best_per_symbol = ranked_var.drop_duplicates(subset='gene_symbols', keep='first').index
adata_query = adata_query[:, best_per_symbol].copy()

# Step 3: use the gene symbols (as strings) as var_names.
adata_query.var_names = adata_query.var['gene_symbols'].astype(str)

# Step 4: stop here if any symbol is still duplicated.
assert adata_query.var_names.is_unique, "Oh no! var_names still have duplicates!"
print("All good – var_names are unique:", adata_query.var_names.is_unique)

In [None]:
# Find the overlap of gene symbols between reference and query; both objects are
# rebound to the helper's return values.
# presumably filter=True restricts both to the shared genes — verify against spu.overlapper.
adata_ref, adata_query = spu.overlapper(ref=adata_ref, query=adata_query, filter=True)

In [None]:
# Spot check: is the gene REST still among the query's var_names after the overlap step?
"REST" in adata_query.var_names

In [None]:
# Same spot check for the reference's var_names.
"REST" in adata_ref.var_names

## 4. Start the hierarchical scANVI pipeline

In [None]:
# 1. Highly-variable-gene selection, computed on the reference only.
#    subset=True trims adata_ref in place to the selected genes.
sc.pp.highly_variable_genes(
    adata_ref,
    n_top_genes=3000,
    flavor="seurat_v3",
    subset=True,
)

# Restrict the query to exactly the same genes, in the reference's order.
hvg_genes = list(adata_ref.var_names)
adata_query = adata_query[:, hvg_genes].copy()

In [None]:
# Spot check: did REST survive the HVG subset of the reference?
"REST" in adata_ref.var_names

In [None]:
# 2. Concatenate reference and query into one object; the 'source' column records
#    each cell's origin ('ref' or 'query').
# NOTE(review): AnnData.concatenate is deprecated upstream in favour of anndata.concat.
# The later fill of missing 'class_label' / 'subclass_label' values with 'Unknown' needs
# obs columns that exist only in the reference to be kept here — verify obs-column
# handling before migrating.
adata = adata_ref.concatenate(adata_query, batch_key='source', batch_categories=['ref', 'query'])

In [None]:
# 3. Global scVI → SCANVI
#### Setup scVI model and train

# Training is stochastic. Fix the seed so the model, and everything derived from it
# (predictions, confidence filter, embeddings), is repeatable across runs, up to GPU
# nondeterminism.
scvi.settings.seed = 0

# No layer is passed, so the model is registered on adata.X (layer="counts" was left disabled).
# libraryBatch is the batch key; donor and sex enter as categorical covariates and
# age_numeric as a continuous covariate.
scvi.model.SCVI.setup_anndata(
    adata,
    batch_key="libraryBatch",
    categorical_covariate_keys=["individualID", "sex"],
    continuous_covariate_keys=["age_numeric"],
)
vae = scvi.model.SCVI(adata, n_latent=30)

# training #1: up to 200 epochs, stopping early once the validation ELBO has not
# improved for 20 epochs.
vae.train(
    max_epochs=200,
    early_stopping=True,
    early_stopping_patience=20,
    early_stopping_monitor="elbo_validation",
    batch_size=2048,
    precision="16-mixed",
)

In [None]:
# Training vs. validation ELBO per epoch for the unsupervised scVI model.
fig, ax = plt.subplots(figsize=(6, 4))
for curve in ('elbo_train', 'elbo_validation'):
    ax.plot(vae.history[curve], label=curve)
ax.legend()
ax.set_xlabel("Epoch")
ax.set_ylabel("ELBO")
ax.set_title("scVI (Unsupervised) Training History")
plt.show()

In [None]:
#### Mark cells without an annotation as "Unknown" in both label columns
#### ('class_label' and 'subclass_label'); scANVI treats that category as unlabeled.
for label_key in ('class_label', 'subclass_label'):
    # Cast defensively: the .cat accessor only exists on categorical columns.
    labels = adata.obs[label_key].astype('category')
    # add_categories raises if the category already exists, so guard it to keep
    # this cell safe to re-run.
    if 'Unknown' not in labels.cat.categories:
        labels = labels.cat.add_categories('Unknown')
    adata.obs[label_key] = labels.fillna('Unknown')

adata.obs

In [None]:
# Initialise scANVI from the trained scVI model. 'class_label' supplies the labels
# and 'Unknown' marks the unlabeled cells.
lvae = scvi.model.SCANVI.from_scvi_model(
    vae,
    adata=adata,
    labels_key='class_label',
    unlabeled_category='Unknown',
)

# Semi-supervised training settings, gathered in one place.
scanvi_train_kwargs = dict(
    max_epochs=25,
    n_samples_per_label=100,
    batch_size=2048,
    precision="16-mixed",
    early_stopping=True,
    early_stopping_monitor="elbo_validation",
    early_stopping_patience=20,
)
lvae.train(**scanvi_train_kwargs)

In [None]:
# Training vs. validation ELBO per epoch for the semi-supervised scANVI model.
fig, ax = plt.subplots(figsize=(6, 4))
for curve in ('elbo_train', 'elbo_validation'):
    ax.plot(lvae.history[curve], label=curve)
ax.legend()
ax.set_xlabel("Epoch")
ax.set_ylabel("ELBO")
ax.set_title("scANVI (Semi-Supervised) Training History")
plt.show()

In [None]:
# Run both prediction modes on the combined object: hard class calls, and the
# per-class probabilities behind them.
pred_labels = lvae.predict(adata)
soft_preds = lvae.predict(adata, soft=True)

# Store the call and its confidence (the largest class probability in each row).
adata.obs['class_predicted'] = pred_labels
adata.obs['class_prob_max'] = soft_preds.max(axis=1)

In [None]:
# Model-normalized expression scaled to a library size of 1e4, returned as a NumPy
# array and stored as the 'scanvi_norm' layer (used by the marker plots further down).
norm_expr = lvae.get_normalized_expression(
    transform_batch=None,
    library_size=1e4,
    return_numpy=True,
)
adata.layers['scanvi_norm'] = norm_expr

In [None]:
# Display the combined object with the new prediction columns and layer.
adata

In [None]:
# Keep only cells whose top class probability exceeds the cutoff. The mask is applied
# to the whole combined object, so reference and query cells are filtered alike.
MIN_CLASS_PROB = 0.85
confident_mask = adata.obs['class_prob_max'] > MIN_CLASS_PROB
adata = adata[confident_mask].copy()
adata

In [None]:
# 4. Per-class subclass refinement
#
# For each predicted class, a class-specific scVI -> scANVI model is trained on the
# cells assigned to that class and used to predict 'subclass_label'. The results are
# written back into arrays that span the whole object:
#   latent_sub      per-cell embedding from the model of that cell's class
#   pred_subclass   hard subclass call ("Unknown" where no model was trained)
#   soft_sub_preds  subclass probabilities, columns ordered as subclass_names

SUB_N_LATENT = 20          # latent size of every per-class model, and the width of latent_sub
MIN_CELLS_PER_CLASS = 200  # classes with fewer cells are not refined

latent_sub = np.zeros((adata.n_obs, SUB_N_LATENT))
pred_subclass = np.full(adata.n_obs, "Unknown", dtype=object)

# Global subclass list based on the full data, so every per-class model's
# probabilities can be placed into one shared column layout.
subclass_names = np.sort(adata.obs['subclass_label'].unique())
n_subclasses_total = len(subclass_names)
soft_sub_preds = np.zeros((adata.n_obs, n_subclasses_total))
subclass_to_idx = {k: i for i, k in enumerate(subclass_names)}

# String view of the labels, used only to count labelled cells per class.
subclass_as_str = adata.obs['subclass_label'].astype(str).to_numpy()

skipped_classes = []

for c in adata.obs['class_predicted'].unique():
    # Plain NumPy boolean mask: it indexes the NumPy result arrays and the AnnData alike.
    idx = (adata.obs['class_predicted'] == c).to_numpy()
    n_cells = int(idx.sum())
    n_labelled = int((subclass_as_str[idx] != "Unknown").sum())

    # Too few cells, or no labelled cell to learn subclasses from: leave this class
    # at its defaults (zero embedding, "Unknown" subclass, zero probabilities) and say so.
    if n_cells < MIN_CELLS_PER_CLASS or n_labelled == 0:
        print(f"Skipping: {c} ({n_cells} cells, {n_labelled} labelled)")
        skipped_classes.append(c)
        continue

    print(f"Starting: {c}")
    sub = adata[idx].copy()

    scvi.model.SCVI.setup_anndata(
        sub,
        batch_key="libraryBatch",
        categorical_covariate_keys=["individualID", "sex"],
        continuous_covariate_keys=["age_numeric"]
    )

    vae_c = scvi.model.SCVI(sub, n_latent=SUB_N_LATENT)
    vae_c.train(
        max_epochs=200,
        early_stopping=True,
        batch_size=2048,
        precision="16-mixed"
    )

    scanvi_c = scvi.model.SCANVI.from_scvi_model(
        vae_c, adata=sub,
        unlabeled_category="Unknown",
        labels_key='subclass_label'
    )
    scanvi_c.train(
        max_epochs=20,
        batch_size=2048,
        precision="16-mixed"
    )

    # embeddings + hard predictions
    latent_sub[idx] = scanvi_c.get_latent_representation()
    pred_subclass[idx] = scanvi_c.predict()

    # soft predictions: one column per subclass known to this class's model
    soft_local_df = scanvi_c.predict(soft=True)
    soft_local = soft_local_df.to_numpy()

    local_subclass_names = soft_local_df.columns.values  # column order of soft_local

    # Scatter each local column into its slot in the global layout.
    for j, local_name in enumerate(local_subclass_names):
        global_idx = subclass_to_idx[local_name]
        soft_sub_preds[idx, global_idx] = soft_local[:, j]

if skipped_classes:
    print(f"Classes left unrefined: {skipped_classes}")

adata.obsm['X_scANVI_subclass'] = latent_sub
adata.obs['predicted_subclass'] = pred_subclass
adata.obs['subclass_prob_max'] = np.max(soft_sub_preds, axis=1)
adata.obsm['X_scANVI_subclass_soft'] = soft_sub_preds
adata.uns['subclass_names'] = subclass_names


In [None]:
# 5. Metacells + QC
# Build the neighbour graph on the per-class scANVI embeddings, then cluster with Leiden.
# NOTE(review): each class's rows in X_scANVI_subclass come from a separately trained
# model, and cells from classes skipped in the refinement loop have all-zero rows —
# confirm that one shared graph across classes is intended.
sc.pp.neighbors(adata, use_rep='X_scANVI_subclass')
sc.tl.leiden(adata, key_added='leiden')  #resolution=5 for metacells

In [None]:
# Compute the UMAP (fixed random_state) and colour it by predicted class, predicted
# subclass and Leiden cluster.
sc.tl.umap(adata,random_state=11, min_dist=0.15)
sc.pl.umap(adata, color=["class_predicted", "predicted_subclass", "leiden"])

In [None]:
# UMAP restricted to reference cells, coloured by source.
sc.pl.umap(adata[adata.obs["source"] == "ref"], color=["source"],size=2)

In [None]:
# UMAP of all cells, coloured by predicted subclass.
sc.pl.umap(adata, color="predicted_subclass",size=0.5)

In [None]:
# UMAP restricted to reference cells, coloured by predicted subclass.
sc.pl.umap(adata[adata.obs["source"] == "ref"], color=["predicted_subclass"],size=2)

In [None]:
#sc.pp.normalize_total(adata)
#sc.pp.log1p(adata)


In [None]:
# Astrocyte markers, plotted on the query cells from the 'scanvi_norm' layer.
markers = ['GFAP', 'AQP4', 'ALDH1L1']

# adata only holds the HVG-selected genes, so a marker outside that set would make
# the plot call fail. Report such genes and plot the rest.
missing = [g for g in markers if g not in adata.var_names]
markers = [g for g in markers if g in adata.var_names]
if missing:
    print(f"Markers not in adata.var_names, skipped: {missing}")

if markers:
    sc.pl.umap(adata[adata.obs["source"] == "query"], color=markers, size=2, layer='scanvi_norm')

In [None]:
# Neuronal markers, plotted on the query cells from the 'scanvi_norm' layer.
markers = ['RBFOX3', 'DCX', 'ELAVL4']

# adata only holds the HVG-selected genes, so a marker outside that set would make
# the plot call fail. Report such genes and plot the rest.
missing = [g for g in markers if g not in adata.var_names]
markers = [g for g in markers if g in adata.var_names]
if missing:
    print(f"Markers not in adata.var_names, skipped: {missing}")

if markers:
    fig = sc.pl.umap(
        adata[adata.obs["source"] == "query"],
        color=markers,
        size=2,
        layer='scanvi_norm',
        return_fig=True
    )

    fig.suptitle("Neuronal Marker Expression (scanvi_norm)", fontsize=18, y=1.02)

    plt.show()

In [None]:
# Inhibitory neuronal markers, plotted on the query cells from the 'scanvi_norm' layer.
markers = ['GAD1', 'GAD2', 'ADARB2']

# adata only holds the HVG-selected genes, so a marker outside that set would make
# the plot call fail. Report such genes and plot the rest.
missing = [g for g in markers if g not in adata.var_names]
markers = [g for g in markers if g in adata.var_names]
if missing:
    print(f"Markers not in adata.var_names, skipped: {missing}")

if markers:
    fig = sc.pl.umap(
        adata[adata.obs["source"] == "query"],
        color=markers,
        size=2,
        layer='scanvi_norm',
        return_fig=True
    )

    fig.suptitle("Inhibitory Neuronal Marker Expression (scanvi_norm)", fontsize=18, y=1.02)

    plt.show()

In [None]:
# Excitatory neuronal markers, plotted on the query cells from the 'scanvi_norm' layer.
markers = ['SLC17A6', 'RORB']

# adata only holds the HVG-selected genes, so a marker outside that set would make
# the plot call fail. Report such genes and plot the rest.
missing = [g for g in markers if g not in adata.var_names]
markers = [g for g in markers if g in adata.var_names]
if missing:
    print(f"Markers not in adata.var_names, skipped: {missing}")

if markers:
    sc.pl.umap(adata[adata.obs["source"] == "query"], color=markers, size=2, layer='scanvi_norm')