## scVI-tools based integration

In [None]:
import os
import tempfile

import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scvi
import seaborn as sns
import torch

In [None]:
# Global random seed for scvi-tools, set before any model is constructed.
scvi.settings.seed = 0
# Record which scvi-tools release this notebook was last executed with.
print(
    "Last run with scvi-tools version:",
    scvi.__version__,
)

In [None]:
# Plotting / numerics setup. Keep these statements in their current order:
# several of them touch matplotlib / IPython display settings, so a later call
# may override what an earlier one configured.
sc.set_figure_params(figsize=(6, 6), frameon=False)  # 6x6 figures without axis frames for scanpy plots
sns.set_theme()  # apply seaborn's default theme
# Permit reduced-precision (e.g. TF32) float32 matmuls where the hardware supports it.
torch.set_float32_matmul_precision("high")
save_dir = tempfile.TemporaryDirectory()  # NOTE(review): not referenced in the cells visible here

# White figure background and high-DPI ("retina") inline figures.
%config InlineBackend.print_figure_kwargs={"facecolor": "w"}
%config InlineBackend.figure_format="retina"

# Load datasets

In [None]:
# Input objects: the new samples (query, samples_2025) and the reference
# (aging_all_2024) whose cell_type labels are transferred to the query.
query_path = "/ocean/projects/cis240075p/asachan/datasets/TA_muscle/ERCC1_KO_mice/samples_2025/objects/query_norm_v2.h5ad"
ref_path = "/ocean/projects/cis240075p/asachan/datasets/TA_muscle/ERCC1_KO_mice/aging_all_2024/objects/ref_norm_v3.h5ad"

# Read both .h5ad files into memory: query first, then reference.
query_data, ref_data = (sc.read_h5ad(p) for p in (query_path, ref_path))

In [None]:
# Display a summary of the query AnnData (n_obs x n_vars and annotation keys).
query_data

In [None]:
# Display a summary of the reference AnnData before any filtering.
ref_data

In [None]:
# The reference stores its sample identifier as "orig.ident"; expose it as
# "sample_id" instead.
ref_data.obs.rename(columns={"orig.ident": "sample_id"}, inplace=True)
# Sex is encoded by the first character of each sample ID; keep it as a
# categorical column.
ref_data.obs["sex"] = ref_data.obs["sample_id"].str[0].astype("category")


#### Filter reference cells using the same thresholds as for the query cells

In [None]:
# Keep only droplets that CellBender scores as real cells rather than ambient
# "soup" (cell_probability > 0.99).
# `.copy()` materialises the subset. Boolean indexing alone returns an AnnData
# *view*, which the in-place `filter_cells` call below would then have to
# convert implicitly (anndata warns when a view is modified).
ref_data = ref_data[ref_data.obs["cell_probability"] > 0.99].copy()
# Basic filtering: drop cells with fewer than 200 detected genes.
sc.pp.filter_cells(ref_data, min_genes=200)

In [None]:
# Re-inspect the reference after filtering to see how many cells remain.
ref_data

In [None]:
# Write the filtered reference to a separate file next to the input instead of
# back over `ref_path`. Overwriting the file this notebook reads from would
# destroy the unfiltered reference, and later runs would start from
# already-filtered data.
# NOTE(review): point any downstream loaders that expected the filtered object
# at ref_norm_v3.h5ad to this new "_filtered" file.
ref_filtered_path = os.path.splitext(ref_path)[0] + "_filtered.h5ad"
ref_data.write_h5ad(ref_filtered_path)


### Concatenate the datasets

In [None]:
# Label every cell with the dataset it came from ("2024" reference,
# "2025" query); this column is the scVI batch covariate.
for _ad, _year in ((ref_data, "2024"), (query_data, "2025")):
    _ad.obs["batch"] = _year

In [None]:
# Query cells are not annotated yet. "Unknown" is the placeholder that scANVI
# is told to treat as the unlabeled category (unlabeled_category="Unknown").
# Store it as a categorical rather than plain strings: if the reference
# `cell_type` column is categorical, anndata.concat can then keep the merged
# column categorical (mixing categorical with object columns typically yields
# object dtype, which has no `.cat` accessor).
query_data.obs["cell_type"] = pd.Categorical(["Unknown"] * query_data.n_obs)

In [None]:
# Make the "cellbender" layer the active `.X` matrix of both datasets.
for _ad in (ref_data, query_data):
    _ad.X = _ad.layers["cellbender"]

In [None]:
# Genes present in both datasets; with sort=False (pandas' default) the order
# follows the reference's var_names.
common_genes = ref_data.var_names.intersection(
    query_data.var_names,
    sort=False,
)
print(f"Common genes: {len(common_genes)}")

In [None]:
# Restrict both objects to the shared genes (same order in each); `.copy()`
# turns the resulting views into independent AnnData objects.
ref_data, query_data = (
    ad[:, common_genes].copy() for ad in (ref_data, query_data)
)

In [None]:
# Stack reference and query cells into a single AnnData; genes were already
# aligned by the previous cell.
# NOTE(review): with anndata's default join="inner", obs columns present in
# only one of the two objects are dropped. Confirm that "sex" and "condition"
# exist in both, since they are plotted at the end of the notebook.
adata = anndata.concat([ref_data, query_data])
# Barcodes from separately processed runs can collide. Duplicated obs_names
# break label-based indexing and make anndata warn, so de-duplicate them
# (no-op when they are already unique).
adata.obs_names_make_unique()
# Keep the unnormalised "cellbender" matrix as a layer; scVI and the seurat_v3
# HVG flavor below both read this layer.
adata.layers["cellbender"] = adata.X.copy()
# Library-size normalisation + log1p in .X for visualisation / .raw.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
adata.raw = adata  # keep full dimension
# Flag 2000 highly variable genes with the batch-aware seurat_v3 method on the
# "cellbender" layer. subset=False keeps all genes.
# NOTE(review): the scVI model is later set up on the full `adata`, so the
# `highly_variable` flag is not currently used for training.
sc.pp.highly_variable_genes(
    adata,
    flavor="seurat_v3",
    n_top_genes=2000,
    layer="cellbender",
    batch_key="batch",
    subset=False,
)

In [None]:
# Inspect the combined object (log-normalised .X, "cellbender" layer, HVG flags in .var).
adata

### Train scVI model on the concatenated dataset

In [None]:
# Register the "cellbender" layer with scVI and use the dataset label
# ("2024" / "2025") as the batch covariate.
scvi.model.SCVI.setup_anndata(adata, batch_key="batch", layer="cellbender")
# Encoder/decoder with 2 hidden layers of 128 units and a 32-dim latent space.
scvi_arch = {"n_layers": 2, "n_hidden": 128, "n_latent": 32}
model = scvi.model.SCVI(adata, **scvi_arch)
# NOTE(review): trains on every gene in `adata` (not only the HVGs) with
# scvi-tools' default max_epochs.
model.train()

In [None]:
### Save the model
# NOTE(review): scvi-tools' `save` writes a *directory* (despite the ".pt"
# suffix) and, by default, refuses to overwrite an existing one, so re-running
# this cell unchanged is expected to fail.
model.save(
    dir_path="/ocean/projects/cis240075p/asachan/datasets/TA_muscle/models/scvi_model_v3.pt",
)

In [None]:
SCVI_LATENT_KEY = "X_scVI"
# Latent embedding from the trained scVI model, stored under a named obsm key.
scvi_latent = model.get_latent_representation()
adata.obsm[SCVI_LATENT_KEY] = scvi_latent

In [None]:
# kNN graph and UMAP computed on the scVI latent space.
sc.pp.neighbors(adata, use_rep=SCVI_LATENT_KEY)
sc.tl.umap(adata, min_dist=0.3)
# The scANVI cells further down recompute neighbors/UMAP and overwrite
# obsm["X_umap"]. Keep a copy of the scVI embedding so it is still present in
# the saved integrated h5ad.
adata.obsm["X_umap_scVI"] = adata.obsm["X_umap"].copy()

In [None]:
# UMAP of the scVI embedding coloured first by batch, then by cell type
# (query cells show up as "Unknown").
for _key in ("batch", "cell_type"):
    sc.pl.umap(adata, color=[_key], frameon=False, ncols=1)

## scANVI for label transfer

In [None]:
SCANVI_CELLTYPE_KEY = "cell_type"
# Cells per label; "Unknown" marks the not-yet-annotated query cells.
np.unique(adata.obs[SCANVI_CELLTYPE_KEY].to_numpy(), return_counts=True)

In [None]:
# Initialise scANVI from the trained scVI model. Cells whose label is
# "Unknown" (the query) are treated as unlabeled.
scanvi_model = scvi.model.SCANVI.from_scvi_model(
    model,
    adata=adata,
    labels_key=SCANVI_CELLTYPE_KEY,
    unlabeled_category="Unknown",
)

In [None]:
# Train scANVI for 20 epochs, sampling up to 100 labelled cells per class
# (n_samples_per_label).
scanvi_model.train(
    n_samples_per_label=100,
    max_epochs=20,
)

In [None]:
SCANVI_LATENT_KEY = "X_scANVI"
SCANVI_PREDICTION_KEY = "C_scANVI"

# scANVI latent embedding and predicted label for every cell (reference and
# query alike).
scanvi_latent = scanvi_model.get_latent_representation(adata)
scanvi_pred = scanvi_model.predict(adata)
adata.obsm[SCANVI_LATENT_KEY] = scanvi_latent
adata.obs[SCANVI_PREDICTION_KEY] = scanvi_pred

In [None]:
# Rebuild the kNN graph and UMAP from the scANVI latent space; this replaces
# the scVI-based neighbors graph and obsm["X_umap"].
sc.pp.neighbors(
    adata,
    use_rep=SCANVI_LATENT_KEY,
)
sc.tl.umap(
    adata,
    min_dist=0.3,
)

In [None]:
# Give the predictions exactly the categories (same order) as the reference
# labels, so both columns can share one colour mapping in plots.
# `.astype("category")` makes this work even if `cell_type` is still a plain
# object column here; it is a no-op for a column that is already categorical.
adata.obs[SCANVI_PREDICTION_KEY] = pd.Categorical(
    adata.obs[SCANVI_PREDICTION_KEY].values,
    categories=adata.obs[SCANVI_CELLTYPE_KEY].astype("category").cat.categories,
)

In [None]:
# Reuse the reference cell-type colours for the predicted labels, which share
# the same categories. This is done via `.uns` rather than `palette=`: scanpy
# applies `palette` to every plotted categorical key and stores it in
# `<key>_colors`, which would overwrite the batch/sex/condition colour maps
# with cell-type colours.
adata.uns[f"{SCANVI_PREDICTION_KEY}_colors"] = list(adata.uns["cell_type_colors"])
sc.pl.umap(
    adata,
    color=["cell_type", SCANVI_PREDICTION_KEY, "batch", "sex", "condition"],
    frameon=False,
    ncols=2,
)

In [None]:
# Persist the integrated object ("cellbender" layer, .raw, scVI/scANVI
# embeddings and scANVI predictions) for downstream analysis.
adata.write_h5ad(
    filename="/ocean/projects/cis240075p/asachan/datasets/TA_muscle/ERCC1_KO_mice/integrated_samples/scANVI_v1.h5ad"
)