# Train scVI Model

**Pinned Environment:** [`envs/sc-charter.yaml`](../../envs/sc-charter.yaml)  

In [None]:
import os
from pathlib import Path
import scanpy as sc
from scipy.sparse import issparse
import scvi
import numpy as np
import matplotlib.pyplot as plt
from lightning.pytorch import seed_everything
import random
import sys

In [None]:
# Single seed shared by Python's RNG, Lightning, and scvi-tools
SEED = 0
# Worker count handed to scvi-tools via its global settings
NUM_WORKERS = 32

random.seed(SEED)
seed_everything(SEED)

scvi.settings.seed = SEED
scvi.settings.num_workers = NUM_WORKERS

## Set paths, import data

In [None]:
# Expose the directory two levels above the working directory on sys.path
# so that the `config` package can be imported (presumably the project root).
import_root = Path.cwd().resolve().parents[1]
sys.path.append(str(import_root))

from config.paths import BASE_DIR

# Input: filtered AnnData files from the first export
base_dir = BASE_DIR / "data" / "h5ad" / "export_01"
data_dir = base_dir / "05_filtered"

# Outputs: saved scVI model and the second export of AnnData files
scvi_dir = BASE_DIR / "scvi" / "01_model"
output_dir = BASE_DIR / "data" / "h5ad" / "export_02"

output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the preprocessed AnnData object from the filtered-data directory
adata = sc.read_h5ad(data_dir / "artis-naive-pp.h5ad")

In [None]:
# Use raw counts layer for scVI
adata.X = adata.layers["counts"].copy()

# Verify that X now holds integer-valued counts.
# `.data` is the stored-values array only for scipy sparse matrices; for a
# dense array it is a raw buffer, so pick the values to test by matrix type.
x_is_sparse = issparse(adata.X)
count_values = adata.X.data if x_is_sparse else np.asarray(adata.X)

print("adata.X is sparse:", x_is_sparse)
print(
    "adata.X has only whole numbers:",
    np.all(count_values == np.round(count_values)),
)

## Train model

In [None]:
# Register the AnnData object with scVI: model the "counts" layer and
# treat "sample_id" as the batch covariate
scvi.model.SCVI.setup_anndata(adata, layer="counts", batch_key="sample_id")

In [None]:
# Build the scVI model; hyperparameters follow Reina-Campos et al., 2025
# (2 layers, 30 latent dimensions, negative-binomial gene likelihood)
model = scvi.model.SCVI(adata, n_layers=2, n_latent=30, gene_likelihood="nb")

In [None]:
# Fit the scVI model on the GPU with early stopping enabled.
# NOTE(review): max_epochs=10 is a small budget — confirm this cap is intended.
model.train(
    max_epochs=10,
    accelerator="gpu",
    enable_progress_bar=True,
    # Early-stopping configuration
    early_stopping=True,
    early_stopping_patience=3,
    early_stopping_min_delta=0.001,
)

In [None]:
# Persist the trained model. `overwrite=True` lets the notebook be re-run
# end to end: without it, scvi-tools raises if `scvi_dir` already exists
# from a previous run.
model.save(scvi_dir, prefix="01_artis_scvi_", overwrite=True)

In [None]:
# Plot the train and validation ELBO curves recorded in the model history
fig, ax = plt.subplots()
ax.plot(model.history["elbo_train"], label="Train ELBO Loss")
ax.plot(model.history["elbo_validation"], label="Validation ELBO Loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("SCVI Training Loss Curve")
ax.legend()
plt.show()

Extract the latent representation from the trained model and store it in `adata.obsm["X_scVI"]`.

In [None]:
# Compute the latent embedding with the trained model and keep it as float32
latent = model.get_latent_representation(adata)
adata.obsm["X_scVI"] = latent.astype(np.float32)
adata.obsm["X_scVI"].shape

## Export

In [None]:
# Destination for the AnnData object that now carries the scVI embedding
filename = str(output_dir / "artis-naive-scvi.h5ad")

# Ensure the destination directory exists before writing
Path(filename).parent.mkdir(parents=True, exist_ok=True)

adata.write_h5ad(filename, compression="gzip")