# scANVI Label Transfer

This notebook transfers cell type labels from the uninfected mouse gut Xenium reference dataset (Reina-Campos et al., 2025, *Nature*) onto the current Xenium spatial dataset using scANVI. The transferred labels are one line of evidence used to inform cell type annotation.

**Pinned Environment:** [`envs/sc-charter.yaml`](../../envs/sc-charter.yaml)  

In [None]:
import os
import sys
from pathlib import Path
import scanpy as sc
import scvi
import pandas as pd
import matplotlib.pyplot as plt
import anndata as ad
from lightning.pytorch import seed_everything
import random
import sys
import session_info

In [None]:
# Seed every RNG the pipeline touches so scVI/scANVI training is reproducible.
SEED = 0
random.seed(SEED)  # Python stdlib RNG
seed_everything(SEED)  # Lightning helper: seeds python.random, numpy and torch
scvi.settings.seed = SEED

# Number of worker processes scvi-tools uses for its data loaders
scvi.settings.num_workers = 32

### Set paths

In [None]:
# Make the repository root (two levels above the notebook's working directory) importable so
# `config.paths` resolves. Guarded so re-running this cell does not append duplicate entries.
repo_root = str(Path.cwd().resolve().parents[1])
if repo_root not in sys.path:
    sys.path.append(repo_root)

from config.paths import BASE_DIR

# Inputs
adata_dir = BASE_DIR / "data/h5ad/export_03"  # query AnnData
ref_data_dir = BASE_DIR / "data/h5ad/max-data"  # reference AnnData

# Outputs
output_dir = BASE_DIR / "data/h5ad/export_04"
scvi_dir = BASE_DIR / "scvi"  # saved scVI and scANVI models
scanvi_dir = BASE_DIR / "scanvi"

# Create every output directory up front so later cells can write without extra checks
for directory in (output_dir, scvi_dir, scanvi_dir):
    directory.mkdir(parents=True, exist_ok=True)

### Read data

In [None]:
# Query dataset (this study) and the uninfected reference dataset
query_path = adata_dir / "artis-naive-scvi-leiden.h5ad"
reference_path = ref_data_dir / "uninfected.h5ad"

adata = sc.read_h5ad(query_path)
refdata = sc.read_h5ad(reference_path)

In [None]:
# Make reference cell barcodes (obs_names) unique; anndata appends a numeric suffix to duplicates.
refdata.obs_names_make_unique()

In [None]:
# Cells per reference batch; this column is copied to `sample_id` further below.
refdata.obs["batch"].value_counts()

### Prepare adatas

In [None]:
# scVI is trained on raw counts. Store the reference's "raw" layer under the same layer
# name ("counts") the query uses, and mirror it into X.
raw_counts = refdata.layers["raw"]
refdata.layers["counts"] = raw_counts.copy()
refdata.X = refdata.layers["counts"].copy()

In [None]:
# Put raw counts in X for the query too, matching the reference (scVI itself reads layer="counts").
adata.X = adata.layers["counts"].copy()

In [None]:
# Reference batches double as sample identifiers
# (presumably matching the query's existing sample_id column; verify).
refdata.obs["sample_id"] = refdata.obs["batch"]

# Flag each cell's origin ("Yes" = reference, "No" = query); used for filtering and as the scVI batch key
for obj, flag in ((refdata, "Yes"), (adata, "No")):
    obj.obs["ref_data"] = flag

## Gene alignment and concatenation

In [None]:
# Genes shared by the query and reference panels, used for integration.
# Sorted so the gene order (and therefore the model's input layout) is deterministic across
# kernel restarts; iteration order of a set of strings depends on PYTHONHASHSEED.
common_genes = sorted(set(adata.var_names) & set(refdata.var_names))
len(common_genes)

### Create subsets

In [None]:
# Restrict both objects to the shared genes. Both use the same list, so columns come out in the same order.
adata_subset, refdata_subset = (obj[:, common_genes].copy() for obj in (adata, refdata))

In [None]:
# Number of genes before and after restricting to the shared set
for label, full, subset in (("adata", adata, adata_subset), ("refdata", refdata, refdata_subset)):
    print(f"{label} before subsetting: {full.n_vars}")
    print(f"{label} after subsetting: {subset.n_vars}")

In [None]:
# Both subsets were indexed with the same `common_genes` list, so their var_names must be
# identical and in the same order before concatenation; stop here if they are not.
index_match = adata_subset.var_names.equals(refdata_subset.var_names)
print(f"adata.var_names match: {index_match}")
assert index_match, "var_names of query and reference subsets differ"

### Concatenate

In [None]:
# Stack query (first, key "0") and reference (second, key "1").
# - `label` records each cell's source in obs["scanvi_batch"].
# - `index_unique` appends "-<key>" to every barcode.
# - Both subsets carry the same genes, so inner vs outer join makes no difference here.
adata_list = [adata_subset, refdata_subset]
adata_concat = ad.concat(adata_list, join="outer", label="scanvi_batch", index_unique="-")

### Perform filtering on concatenated adata

In [None]:
# Recompute QC metrics (including total_counts) over the shared gene set only.
# These intersection-based totals drive the reference filtering below.
sc.pp.calculate_qc_metrics(adata_concat, inplace=True, percent_top=(10, 20, 50, 150))

The step below filters cells in the combined dataset:
- All study (query) cells are kept.
- Reference cells are kept only if their total transcript count is strictly between 20 and 800 (i.e. `20 < total_counts < 800`). These thresholds follow the reference dataset's published manuscript.

In [None]:
is_query = adata_concat.obs["ref_data"] == "No"
is_reference = adata_concat.obs["ref_data"] == "Yes"
total_counts = adata_concat.obs["total_counts"]

# All query cells are retained. Reference cells need 20 < total_counts < 800
# (thresholds from the reference publication).
keep = is_query | (is_reference & (total_counts > 20) & (total_counts < 800))
adata_concat = adata_concat[keep].copy()

## Train scVI model

In [None]:
# Cells per sample_id after filtering; sample_id is passed to scVI as a categorical covariate
# (the batch key itself is ref_data)
adata_concat.obs["sample_id"].value_counts()

In [None]:
# scVI on raw counts.
# - batch_key: the reference-vs-query flag.
# - categorical covariate: sample_id, to absorb run-to-run variability.
setup_kwargs = dict(
    layer="counts",
    batch_key="ref_data",
    categorical_covariate_keys=["sample_id"],
)
scvi.model.SCVI.setup_anndata(adata_concat, **setup_kwargs)

# Architecture matches Reina-Campos et al., 2025, Nature
model = scvi.model.SCVI(adata_concat, n_latent=30, n_layers=2)
print("starting model training")

model.train(accelerator="gpu", enable_progress_bar=True, early_stopping=True)

In [None]:
# overwrite=True so re-running the notebook replaces the saved model instead of raising
# because the target directory already exists.
model.save(os.path.join(scvi_dir, "02_model"), prefix="02_label_transfer_", overwrite=True)

In [None]:
SCVI_LATENT_KEY = "X_scVI_refalign"

# Latent representation of every cell in adata_concat under the trained scVI model
scvi_latent = model.get_latent_representation()
adata_concat.obsm[SCVI_LATENT_KEY] = scvi_latent

### Compute neighbors graph and UMAP

In [None]:
# kNN graph and UMAP on the reference-aligned scVI latent space.
# The graph is stored under its own key rather than the default "neighbors".
SCVI_NEIGHBORS_KEY = "neighbors_scvi_refalign"
sc.pp.neighbors(adata_concat, use_rep=SCVI_LATENT_KEY, key_added=SCVI_NEIGHBORS_KEY)
sc.tl.umap(adata_concat, neighbors_key=SCVI_NEIGHBORS_KEY)

### Export refaligned adata

In [None]:
# Persist the scVI-aligned concatenated object (reloaded at the checkpoint below)
filename = os.path.join(output_dir, "adata_concat_scvi_refalign.h5ad")
Path(filename).parent.mkdir(parents=True, exist_ok=True)
adata_concat.write_h5ad(filename, compression="gzip")

### Checkpoint

In [None]:
# Reload the scVI-aligned object so the scANVI steps below can resume from here.
filename = os.path.join(output_dir, "adata_concat_scvi_refalign.h5ad")
adata_concat = sc.read_h5ad(filename)

# Also reload the saved scVI model against the reloaded AnnData.
# - After a kernel restart, `model` would otherwise be undefined.
# - Any in-memory `model` was bound to the pre-checkpoint object.
model = scvi.model.SCVI.load(
    os.path.join(scvi_dir, "02_model"), adata=adata_concat, prefix="02_label_transfer_"
)
adata_concat

### Prep data for scANVI

In [None]:
# Subtype annotations are the most granular level of the reference labeling hierarchy.
# Counts cover reference cells only; query cells have no Subtype yet, and NaN is not counted.
adata_concat.obs["Subtype"].value_counts()

In [None]:
# Query (Artis) cells have no reference Subtype. Label them "Unknown", which is the
# unlabeled_category passed to scANVI below.
subtype = adata_concat.obs["Subtype"].astype("category")  # no-op if already categorical
if "Unknown" not in subtype.cat.categories:
    # Guard: add_categories raises if the category already exists (e.g. on re-run)
    subtype = subtype.cat.add_categories("Unknown")
adata_concat.obs["Subtype"] = subtype.fillna("Unknown")

The cell below records which cells are query (unlabeled) cells. This mask is needed later to map the predicted labels back onto the original, unintegrated query AnnData.

In [None]:
# Boolean mask of the cells that enter scANVI unlabeled (Subtype == "Unknown").
# Saved to disk so the split can be recovered later.
query_cells = adata_concat.obs["Subtype"].eq("Unknown")
query_cells.to_csv(scanvi_dir / "query_cells-subtype.csv")
query_cells

In [None]:
# Strip the 2-character "-<key>" suffix ("-0"/"-1") that ad.concat(index_unique="-") appended,
# so barcodes match the original query AnnData.
query_cells_copy = query_cells.copy()
query_cells_copy.index = query_cells_copy.index.str[:-2]
query_cells_copy

## Run scANVI

In [None]:
# Initialise scANVI from the trained scVI model.
# Cells whose Subtype is "Unknown" are treated as unlabeled; their labels get predicted.
scanvi_model = scvi.model.SCANVI.from_scvi_model(
    model,
    adata=adata_concat,
    labels_key="Subtype",  # column to transfer labels from
    unlabeled_category="Unknown",
)

SCANVI_MAX_EPOCHS = 20
scanvi_model.train(max_epochs=SCANVI_MAX_EPOCHS, n_samples_per_label=None, accelerator="gpu")

In [None]:
# Save the scANVI model. overwrite=True so re-running the notebook replaces the previous
# save instead of raising because the directory already exists.
scanvi_model.save(
    os.path.join(scvi_dir, "03_model_scanvi"), prefix="03_max-artis_scanvi_", overwrite=True
)

### Reference mapping step

In [None]:
SCANVI_LATENT_KEY = "X_scANVI"
SCANVI_PREDICTION_KEY = "scanvi_labels_xenium"

# Embed every cell in the scANVI latent space, then predict a Subtype label for every cell
scanvi_latent = scanvi_model.get_latent_representation(adata_concat)
adata_concat.obsm[SCANVI_LATENT_KEY] = scanvi_latent

scanvi_predictions = scanvi_model.predict(adata_concat)
adata_concat.obs[SCANVI_PREDICTION_KEY] = scanvi_predictions

In [None]:
# Persist the concatenated object with the scANVI latent space and predictions
filename = os.path.join(output_dir, "adata_concat-scvi-scanvi-predictions.h5ad")
Path(filename).parent.mkdir(parents=True, exist_ok=True)
adata_concat.write_h5ad(filename, compression="gzip")

## Map query labels onto original adata

In [None]:
# Re-read the full query AnnData (all genes, not only the shared subset) to receive the labels
adata_query = sc.read_h5ad(adata_dir / "artis-naive-scvi-leiden.h5ad")
adata_query.obs.head()

In [None]:
# Query cells were the first object passed to ad.concat, so their scanvi_batch key is "0"
# (stored as a string).
is_query_batch = adata_concat.obs["scanvi_batch"] == "0"
adata_labeled = adata_concat[is_query_batch].copy()

print(adata_labeled.obs["scanvi_batch"].value_counts())
print()
print(adata_labeled.obs["sample_id"].value_counts())

In [None]:
# Drop the 2-character "-0" suffix added by ad.concat, so obs_names match the original query barcodes
adata_labeled.obs_names = adata_labeled.obs_names.astype(str).str[:-2]
adata_labeled.obs.head()

In [None]:
# Barcodes (concat suffix stripped) of the cells scANVI treated as unlabeled.
# Rebuilt from the unmodified `query_cells` mask rather than overwriting `query_cells_copy`
# with a filtered version of itself, so this cell can be re-run safely.
query_cells_copy = query_cells.index[query_cells.to_numpy()].str[:-2].tolist()

# Every query barcode must exist in both objects. Otherwise the `.loc` lookup below raises
# KeyError, or labels are silently dropped when assigned to adata_query.
missing_in_labeled = set(query_cells_copy) - set(adata_labeled.obs.index)
missing_in_query = set(query_cells_copy) - set(adata_query.obs.index)
assert not missing_in_labeled, f"{len(missing_in_labeled)} query cells missing from adata_labeled"
assert not missing_in_query, f"{len(missing_in_query)} query cells missing from adata_query"
print(f"{len(query_cells_copy)} query cells found in adata_labeled and adata_query")

In [None]:
# Copy scANVI predictions onto the full query AnnData; the assignment aligns rows on barcode.
predicted_labels = adata_labeled.obs.loc[query_cells_copy, "scanvi_labels_xenium"]
adata_query.obs["scanvi_labels_xenium"] = predicted_labels
adata_query.obs["scanvi_labels_xenium"]

In [None]:
# Number of query cells per transferred scANVI label (NaN, if any, is not counted)
adata_query.obs["scanvi_labels_xenium"].value_counts()

## Export labeled adata

In [None]:
# Write the labeled query AnnData and echo where it went
h5ad_path = str(output_dir / "adata-scanvi-labels.h5ad")
adata_query.write_h5ad(h5ad_path, compression="gzip")

print(h5ad_path)