In [None]:
import scANVI_prediction_utils_03 as spu
import pandas as pd
import numpy as np
import scanpy as sc
import anndata as ad
import scvi
from scipy.stats import entropy



## 1. Check GPU status

In [None]:
# Report GPU availability before the (GPU-accelerated) training steps below.
# Output format is defined by the project helper scANVI_prediction_utils_03.
spu.gpu_status()

## 2. Load reference and query

In [None]:
# Reference dataset used to train the scVI / scANVI reference models.
adata_ref=ad.read_h5ad("/tscc/lustre/ddn/scratch/aopatel/adata_ref.h5ad")

In [None]:
# Query dataset (MTG) to be mapped onto the reference. The same file is reloaded
# near the end of the notebook as adata_query_X to receive the transferred results.
adata_query=ad.read_h5ad("/tscc/lustre/ddn/scratch/aopatel/preprocessed_adata_mtg.h5ad")

## 3. Prepare query

#### Keep overlapping genes between query and ref only

In [None]:
# Find the overlap of gene symbols between ref and query. With do_filter=True the
# helper presumably subsets both objects to the shared genes — spu.overlapper's
# exact behavior is not visible here; confirm in scANVI_prediction_utils_03.
adata_ref, adata_query = spu.overlapper(ref=adata_ref, query=adata_query, do_filter=True)

## 4. scANVI query to ref mapping protocol

### A. HVG Selection

In [None]:
# 1. HVG calculation on the reference only. The query is then restricted to the
#    very same genes, in the same order, so its features line up with the model.
#    NOTE(review): flavor="seurat_v3" expects raw counts in .X — confirm adata_ref.X holds counts.
sc.pp.highly_variable_genes(
    adata_ref,
    n_top_genes=3000,
    flavor="seurat_v3",
    batch_key="sample",
    subset=True,
)

# subset=True has already dropped every non-HVG from adata_ref, so its var_names are the HVGs.
hvg_genes = list(adata_ref.var_names)
adata_query = adata_query[:, hvg_genes].copy()

### B. Train ref with scVI  

In [None]:
# Train the scVI model on the reference.
#
# Architecture settings used for scArches-style reference mapping (layer norm
# instead of batch norm, covariates encoded), so the model can later be extended
# to the query batches with SCANVI.load_query_data.

# Fix the global seed so reference training (and every result derived from it)
# is reproducible when the notebook is re-run from a fresh kernel.
scvi.settings.seed = 0

arches_params = dict(
    use_layer_norm="both",
    use_batch_norm="none",
    encode_covariates=True,
    dropout_rate=0.2,
    n_layers=2,
)

scvi.model.SCVI.setup_anndata(adata_ref, batch_key="sample")
vae = scvi.model.SCVI(adata_ref, **arches_params)
vae.train(
    max_epochs=200,
    early_stopping=True,
    batch_size=1024,
    precision="16-mixed",
)

In [None]:
#### View elbow plot for training (scVI-ref only)
# Visual convergence check of the reference scVI training; plotting is done by the
# project helper spu.disp_elbow (presumably from vae.history — confirm).
spu.disp_elbow(vae)

### C. Train Ref with scANVI

<div class="alert alert-block alert-info">
<b>scANVI needs a labels column in the query to know which cells to label. Here we let scANVI create this column for us by creating it in the reference first (so we don't have to do this later). This works because scvi.model.SCANVI.load_query_data() inherits this column and passes it on to the query object, where the query cells are treated as unlabeled ("Unknown"). *Note that we fill this column with the known cell types for the reference because we do not want to predict cell types for the reference.</b>

</div>


In [None]:
# Copy the known reference labels into the scANVI labels column. Assuming
# subclass_label contains no "Unknown" entries, every reference cell is labeled.
adata_ref.obs["labels_scanvi"] = adata_ref.obs["subclass_label"].values

In [None]:
# Initialise scANVI from the trained scVI reference model, using labels_scanvi as
# the label column and "Unknown" as the category that marks unlabeled cells.
vae_ref_scan = scvi.model.SCANVI.from_scvi_model(
    vae,
    unlabeled_category="Unknown",
    labels_key="labels_scanvi",
)

In [None]:
# Fine-tune scANVI on the labeled reference. n_samples_per_label controls how many
# cells per label are subsampled for the classifier each epoch (see scvi-tools docs).
vae_ref_scan.train(max_epochs=20, n_samples_per_label=100, batch_size=1024, precision="16-mixed")

<div class="alert alert-block alert-info">
<b>Make a UMAP of the reference alone to gauge how well the cell types separate.</b>
</div>

In [None]:
# Reference-only embedding: store the scANVI latent space, then build the kNN graph
# on it; leiden and UMAP both read that graph, so neighbors must run first.
adata_ref.obsm["X_scANVI"] = vae_ref_scan.get_latent_representation()
sc.pp.neighbors(adata_ref, use_rep="X_scANVI")
sc.tl.leiden(adata_ref)
sc.tl.umap(adata_ref)

In [None]:
# Reference UMAP colored by sex and by subclass label, to judge how well the
# cell types separate in the scANVI latent space.
sc.pl.umap(
    adata_ref,
    color=["sex", "subclass_label"],
    frameon=False,
    ncols=1,
)

### D. Train the query model, using the reference model to map the query onto the reference and predict cell types

In [None]:
# Store the query's batch labels under the same obs key ("sample") that was used as
# batch_key for the reference model. astype(str) makes the IDs plain string labels.
adata_query.obs["sample"] = adata_query.obs["10X_ID"].astype(str)

In [None]:
# Build the query model from the trained reference scANVI model (scArches-style
# reference mapping). adata_query must have the reference genes in the reference
# order, which the HVG subset above ensures.
vae_q = scvi.model.SCANVI.load_query_data(
    adata_query,
    vae_ref_scan)

In [None]:
# Train the query model. weight_decay=0.0 follows the scvi-tools reference-mapping
# tutorial, where it keeps the reference cells' latent representation unchanged
# under the updated model.
vae_q.train(
    max_epochs=100,
    plan_kwargs={"weight_decay": 0.0},
    check_val_every_n_epoch=10,
    batch_size=1024,
    precision="16-mixed"
)

In [None]:
#### View elbow plot for mapping (scANVI-query only)
# Convergence check of the query training, via the same project helper as above.
spu.disp_elbow(vae_q)

In [None]:
# Query latent embedding (same space as the reference) and hard cell-type predictions.
# predict() without arguments uses the AnnData registered with vae_q, i.e. adata_query.
adata_query.obsm["X_scANVI"] = vae_q.get_latent_representation()
adata_query.obs["predictions"] = vae_q.predict()

<div class="alert alert-block alert-info">
<b>Get query-specific gene expression profiles (we'll use these to check whether the query cells express the correct marker genes for their predicted cell types). Remember: transform_batch=None, because we do not want to express our query's gene expression as if it came from a reference batch. That would be a counterfactual analysis!</b>

</div>

In [None]:
# Normalized layer: model-decoded expression for each query cell, scaled to
# 1e4 counts per cell. transform_batch=None decodes every cell in its own
# observed batch (no counterfactual transfer into a reference batch).
norm_expr = vae_q.get_normalized_expression(
    transform_batch=None,
    library_size=1e4,
    return_numpy=True,
)

# Probabilities: per-class prediction probabilities; the largest one per cell
# serves as a confidence score for the hard prediction.
soft_preds = vae_q.predict(adata_query, soft=True)
adata_query.obs['sub_class_prob_max'] = soft_preds.max(axis=1)

adata_query.layers['scanvi_norm'] = norm_expr

### E. Plot both ref and query together in latent space 

In [None]:
## Together
# (No code in this cell.) The following cells concatenate query and reference
# and embed them jointly.

In [None]:
# Inspect the reference object before concatenating it with the query.
adata_ref

In [None]:
# Concatenate query and reference. AnnData.concatenate creates an obs column called
# 'batch'; by default '0' is adata_query and '1' is adata_ref in this context. It also
# appends "-0"/"-1" to obs_names (the "-0" suffix is stripped again further below).
# NOTE(review): AnnData.concatenate is deprecated in recent anndata in favor of
# anndata.concat; switching would change the 'batch' column and obs_name suffixes used later.
adata_full = adata_query.concatenate(adata_ref)

In [None]:
# Check the batch labels created by concatenate ('0' = query, '1' = reference).
adata_full.obs['batch']

In [None]:
# Positional rename: first category ('0') -> "Query", second ('1') -> "Reference".
# Relies on the category order produced by concatenate above.
adata_full.obs["batch"] = adata_full.obs["batch"].cat.rename_categories(
    ["Query", "Reference"]
)

In [None]:
# Joint kNN graph, leiden clustering and UMAP on the shared X_scANVI space, which was
# set for both query and reference cells before concatenation. These joint leiden
# labels are what the donor-entropy check below uses.
sc.pp.neighbors(adata_full, use_rep="X_scANVI")
sc.tl.leiden(adata_full)
sc.tl.umap(adata_full)

In [None]:
# Query cells on the joint UMAP, colored by their predicted subclass.
sc.pl.umap(
    adata_full[adata_full.obs["batch"]=="Query"],
    color="predictions",
    frameon=False,
)

In [None]:
# Reference cells on the joint UMAP, colored by their known subclass label,
# for comparison with the query predictions above.
sc.pl.umap(
    adata_full[adata_full.obs["batch"]=="Reference"],
    color="subclass_label",
    frameon=False,
)

In [None]:
# Query cells colored by donor, to spot clusters dominated by a single donor.
# NOTE(review): vmin only affects continuous color scales; if individualID is
# categorical (as a donor ID presumably is) vmin=7 has no effect — confirm and drop.
sc.pl.umap(
    adata_full[adata_full.obs['batch']=="Query"],
    color="individualID",
    frameon=False,
    vmin=7
)

In [None]:
# Query cells colored by the maximum prediction probability (low = uncertain label).
sc.pl.umap(
    adata_full[adata_full.obs['batch']=="Query"],
    color="sub_class_prob_max",
    frameon=False
)

## 5. Donor Entropy Calculations

In [None]:
def compute_donor_entropy(adata, cluster_key='leiden', donor_key='individualID', base=2):
    """Shannon entropy of the donor composition of each cluster.

    For every cluster in ``adata.obs[cluster_key]``, the fraction of its cells
    coming from each donor (``adata.obs[donor_key]``) is summarised as a Shannon
    entropy. Low values mean the cluster is dominated by one or a few donors
    (donor-biased); high values mean its cells are spread across many donors.
    The entropy cannot exceed log(min(#cells, #donors)), so very small clusters
    score low partly because they contain few cells.

    Parameters
    ----------
    adata : AnnData (or any object exposing a pandas ``.obs`` DataFrame)
        Cells to evaluate; only ``adata.obs`` is read.
    cluster_key : str
        obs column holding the cluster labels.
    donor_key : str
        obs column holding the donor identifiers.
    base : float
        Logarithm base of the entropy; the default 2 gives bits.

    Returns
    -------
    pandas.Series
        Entropy per cluster, indexed by cluster label in order of first appearance.
        Cells whose donor value is missing are not counted.
    """
    obs = adata.obs

    def _cluster_entropy(donors):
        # Donor proportions within one cluster; zero entries (unused categories of
        # a categorical column) contribute nothing to the entropy.
        return entropy(donors.value_counts(normalize=True), base=base)

    # A single grouped pass over obs instead of building an AnnData view per cluster.
    # observed=True / sort=False keep only clusters that occur, in order of appearance,
    # like iterating over obs[cluster_key].unique().
    entropies = (
        obs.groupby(cluster_key, observed=True, sort=False)[donor_key]
        .agg(_cluster_entropy)
    )
    entropies.name = None
    entropies.index.name = None
    return entropies


# Donor entropy is evaluated on query cells only, using the joint leiden clusters.
q = adata_full[adata_full.obs["batch"].eq("Query")]

donor_entropy = compute_donor_entropy(q, cluster_key="leiden", donor_key="individualID")
# Ascending sort puts the lowest-entropy (most donor-biased) clusters first.
print(donor_entropy.sort_values())  # Low values = donor-biased clusters

In [None]:
# Clusters judged donor-biased from the donor-entropy table above.
# NOTE: leiden labels are only meaningful for this exact clustering run; re-check
# this list whenever neighbors/leiden are re-computed.
exclude_clusters = ["40", "42", "15", "19", "31"]

# Query cells that are NOT in an excluded cluster. Named keep_mask rather than
# `mask` because the following cells define `mask` as the opposite selection
# (query cells to remove); sharing one name made results depend on cell order.
keep_mask = (
    (adata_full.obs["batch"] == "Query")
    & ~adata_full.obs["leiden"].isin(exclude_clusters)
)

# Preview the query UMAP with the donor-biased clusters hidden.
sc.pl.umap(
    adata_full[keep_mask],
    color="individualID",
    frameon=False,
)

In [None]:
exclude_clusters = ["40", "42", "15", "19", "31"]

# Boolean selections over adata_full.obs: query cells, cells in an excluded
# cluster, and their intersection (the query cells that will be removed).
query_mask = adata_full.obs["batch"] == "Query"
exclude_mask = adata_full.obs["leiden"].isin(exclude_clusters)
mask = query_mask & exclude_mask

n_total_query = query_mask.sum()
n_removed = mask.sum()
n_remaining = n_total_query - n_removed

print("\n".join([
    f"Query cells total:     {n_total_query:,}",
    f"Query cells removed:   {n_removed:,} ({n_removed/n_total_query:.2%})",
    f"Query cells remaining: {n_remaining:,}",
]))

In [None]:
#### Extract the relevant obs columns for the cells we're removing
# `mask` is the "query cell in an excluded cluster" selection from the previous cell.
removed_obs = adata_full.obs.loc[
    mask, ["individualID", "sex", "Consensus clinical diagnosis"]
].copy()

#### Overall breakdown (across all removed clusters); dropna=False keeps missing values visible.
print("=== Overall breakdown of removed Query cells ===")
for col, heading in (
    ("Consensus clinical diagnosis", "\nBy Consensus clinical diagnosis:"),
    ("sex", "\nBy sex:"),
):
    print(heading)
    print(removed_obs[col].value_counts(dropna=False))


## 6. Final save and transfer of essential data to final query object

<div class="alert alert-block alert-info">
<b>Save the essential files:</b> 1. adata_full (ref + query, 3000 HVGs, modelled); 2. adata_query (query only, now just 3000 HVGs, modelled); 3. adata_query_X (the original adata_query you started with + cell type predictions, sub_class_prob_max and X_scANVI for all cells + leiden cluster labels for all query cells, with a cluster_flagged column ("Y"/"N") marking the donor-biased clusters identified by the donor entropy check).

</div>

### A. Save essential files now!

In [None]:
#### Save adata_full
# Joint ref+query object (3000 HVGs) with the joint leiden/UMAP results.
adata_full.write_h5ad("/tscc/lustre/ddn/scratch/aopatel/adata_full_mtg.h5ad")

In [None]:
#### In case of emergency save the query up to this point
# Query subset to the HVGs, carrying X_scANVI, predictions, sub_class_prob_max
# and the scanvi_norm layer.
adata_query.write_h5ad("/tscc/lustre/ddn/scratch/aopatel/adata_query_mtg_partial.h5ad")

### B. Get the correct objects in the correct notation for transfer

In [None]:
# Reload both saved objects so the transfer steps below can be run without re-training.
adata_full=ad.read_h5ad("/tscc/lustre/ddn/scratch/aopatel/adata_full_mtg.h5ad")
adata_query=ad.read_h5ad("/tscc/lustre/ddn/scratch/aopatel/adata_query_mtg_partial.h5ad")

In [None]:
# All query cells from adata_full, with the joint leiden labels. obs_names still
# carry the "-0" suffix added by concatenate.
leiden_q_adata=adata_full[adata_full.obs["batch"]=="Query"].copy()

In [None]:
#### Optional cluster flagging: mark query cells in the donor-biased clusters
#### with a Y/N column instead of dropping them.
clusters_to_flag = {"40", "42", "15", "19", "31"}

mask = leiden_q_adata.obs["leiden"].isin(clusters_to_flag)
# "Y" for cells in a flagged cluster, "N" for every other cell.
leiden_q_adata.obs["cluster_flagged"] = np.where(mask, "Y", "N")

In [None]:
# Inspect the leiden labels (and the "-0"-suffixed obs_names) before renaming.
leiden_q_adata.obs["leiden"]

In [None]:
#### remove trailing "-0" that adata.concatenate() leaves behind
# concatenate() appended "-0" to every query barcode, so on the first run all names
# end with it. The guard makes the cell safe to re-run: once stripped, the names no
# longer all end in "-0" (unless the original barcodes themselves all did) and are
# left untouched instead of losing another "-0".
if leiden_q_adata.obs_names.str.endswith("-0").all():
    leiden_q_adata.obs_names = leiden_q_adata.obs_names.str.replace(r"-0$", "", regex=True)

# The transfer helpers below match cells by barcode, so names must stay unique.
assert leiden_q_adata.obs_names.is_unique, "obs_names not unique after stripping '-0'"

In [None]:
# Same labels, now indexed by the original query barcodes.
leiden_q_adata.obs["leiden"]

In [None]:
#### Make copy of adata_query to work with 
# q_data is the source of predictions, sub_class_prob_max and X_scANVI below.
q_data=adata_query.copy()

### C. Prepare transfer of cell type predictions and other info to adata_query_X (the file we will use moving forward)

In [None]:
#### Load original adata_query that we started this notebook with to transfer our hard earned information!
# Unlike q_data, this object keeps all genes (not just the 3000 HVGs).
adata_query_X=ad.read_h5ad("/tscc/lustre/ddn/scratch/aopatel/preprocessed_adata_mtg.h5ad")

In [None]:
#### leiden labels (joint query+reference clustering) exist only in leiden_q_adata,
#### the query-only subset of adata_full. This covers all query cells, including
#### those in flagged clusters. The helper presumably copies obs[column] matched
#### on barcode (obs_names) — confirm in scANVI_prediction_utils_03.

spu.transfer_predictions_by_barcode(
    source_adata=leiden_q_adata,
    target_adata=adata_query_X,
    column="leiden",
)

In [None]:
#### cluster_flagged ("Y"/"N") was defined only on leiden_q_adata (the
#### optional cluster-flagging cell above).

spu.transfer_predictions_by_barcode(
    source_adata=leiden_q_adata,
    target_adata=adata_query_X,
    column="cluster_flagged",
)

In [None]:
#### predictions info is in q_data (hard scANVI labels from vae_q.predict())

spu.transfer_predictions_by_barcode(
    source_adata=q_data,
    target_adata=adata_query_X,
    column="predictions",
)

In [None]:
#### sub_class_prob_max is in q_data (max soft-prediction probability per cell)

spu.transfer_predictions_by_barcode(
    source_adata=q_data,
    target_adata=adata_query_X,
    column="sub_class_prob_max",
)

In [None]:
# Inspect adata_query_X after the obs-column transfers.
adata_query_X

In [None]:
#### Transfer X_scANVI from q_data (our copy of adata_query with 3000 HVGs)
#### so we can feed it to neighbors and UMAP when we start afresh

#### Safety checks: embedding present, barcodes unique on both sides, some overlap.
assert "X_scANVI" in q_data.obsm, "X_scANVI not found in q_data.obsm"
assert q_data.obs_names.is_unique
assert adata_query_X.obs_names.is_unique

overlap = q_data.obs_names.intersection(adata_query_X.obs_names)
assert overlap.size > 0, "No overlapping barcodes"

#### Transfer latent embedding: rows are aligned by barcode, and cells of
#### adata_query_X without a match in q_data get an all-NaN row.
X_scANVI_df = pd.DataFrame(q_data.obsm["X_scANVI"], index=q_data.obs_names)
adata_query_X.obsm["X_scANVI"] = X_scANVI_df.reindex(index=adata_query_X.obs_names).to_numpy()

##### Post-check: rows containing any NaN (normally barcodes that had no match).
n_missing = np.isnan(adata_query_X.obsm["X_scANVI"]).any(axis=1).sum()
print("Cells with missing X_scANVI:", n_missing)


In [None]:
# Final check of adata_query_X (obs columns + X_scANVI in obsm) before saving.
adata_query_X

In [None]:
#### The final save!
# Full-gene query with transferred leiden, cluster_flagged, predictions,
# sub_class_prob_max and X_scANVI.
adata_query_X.write_h5ad("/tscc/lustre/ddn/scratch/aopatel/adata_scANVI_mtg.h5ad")