In [None]:
import os
from pathlib import Path
import scanpy as sc
from scipy.sparse import issparse
import scvi
import numpy as np
import matplotlib.pyplot as plt
from lightning.pytorch import seed_everything
import random
import sys

In [None]:
# Single seed shared by every RNG source, to make re-runs reproducible.
SEED = 0

random.seed(SEED)
seed_everything(SEED)  # Lightning's global seeding helper
scvi.settings.seed = SEED

# Number of PyTorch DataLoader workers scvi-tools uses.
scvi.settings.num_workers = 32

## Set paths and import data

In [None]:
# Project root: every input and output path below is derived from it.
base_dir = Path("/path/to/tbi-seq")

# Inputs: h5ad files live under data/h5ad.
data_dir = base_dir / "data" / "h5ad"

# Outputs: the trained scVI model goes to scvi/01_model; the annotated h5ad is
# written back into the input directory.
scvi_dir = base_dir / "scvi" / "01_model"
output_dir = data_dir
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the cleaned neurons AnnData.
input_path = data_dir / "03_neurons-clean.h5ad"
adata = sc.read_h5ad(input_path)

In [None]:
# The negative-binomial likelihood used below expects raw integer counts in
# adata.X. For sparse X only the explicitly stored entries (.data) need
# checking, since implicit zeros are whole numbers. For dense X, compare the
# array values themselves: ndarray.data is the raw memory buffer, not the values.
x_is_sparse = issparse(adata.X)
x_values = adata.X.data if x_is_sparse else np.asarray(adata.X)
print("adata.X is sparse:", x_is_sparse)
print("adata.X has only whole numbers:", np.all(x_values == np.round(x_values)))

In [None]:
# Flag mitochondrial genes ("mt-" prefix, mouse nomenclature) and ribosomal
# protein genes (Rpl/Rps followed by a digit) for the QC metrics below.
gene_names = adata.var_names
adata.var['mt'] = gene_names.str.startswith('mt-')
adata.var['ribosomal'] = gene_names.str.match(r'^(Rpl|Rps)\d+')

In [None]:
# Compute QC metrics in place, including per-cell percentages for the 'mt'
# and 'ribosomal' gene sets flagged above (used later as covariates).
sc.pp.calculate_qc_metrics(
    adata,
    inplace=True,
    log1p=False,
    percent_top=None,
    qc_vars=['mt', 'ribosomal'],
)

In [None]:
# Inspect the per-cell annotation columns, now including the QC metrics.
adata.obs.keys()

## scVI

### Train model

In [None]:
# Set up AnnData for SCVI
scvi.model.SCVI.setup_anndata(
    adata, batch_key="group"
)

In [None]:
# Build an SCVI model from the current registration, with a
# negative-binomial gene likelihood.
model = scvi.model.SCVI(adata, gene_likelihood="nb")

In [None]:
# Re-check available obs columns before choosing covariates.
adata.obs.keys()

In [None]:
# Register adata with the 'group' obs column as a categorical covariate and
# the per-cell QC metrics (from calculate_qc_metrics above) as continuous
# covariates.
scvi.model.SCVI.setup_anndata(
    adata,
    categorical_covariate_keys=['group'],
    continuous_covariate_keys=['total_counts', 'pct_counts_mt', 'pct_counts_ribosomal'],
)

# An SCVI model keeps the registration that was active when it was
# constructed, so the model built earlier would not pick up these covariates.
# Rebuild it here so training actually uses this setup.
model = scvi.model.SCVI(
    adata,
    gene_likelihood="nb",
)

In [None]:
# Train on a single GPU. Trainer options must be passed to model.train();
# constructing a scvi.train.Trainer on its own has no effect on training.
# check_val_every_n_epoch=1 logs the validation ELBO every epoch, so the
# loss-curve cell has an "elbo_validation" history to plot.
model.train(accelerator="gpu", devices=1, check_val_every_n_epoch=1)

In [None]:
# Save the trained model under scvi_dir. overwrite=True lets the notebook be
# re-run end-to-end instead of failing because the directory already exists.
model.save(scvi_dir, prefix="01_tbi-seq_", overwrite=True)

In [None]:
# Visualize training loss
plt.plot(model.history["elbo_train"], label="Train ELBO Loss")
plt.plot(model.history["elbo_validation"], label="Validation ELBO Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("SCVI Training Loss Curve")
plt.show()

Extract the scVI latent representation into `adata.obsm["X_scVI"]`.

In [None]:
# Store the scVI latent embedding in obsm as float32 and display its shape.
latent = model.get_latent_representation(adata)
adata.obsm["X_scVI"] = latent.astype(np.float32)
adata.obsm["X_scVI"].shape

## Export

In [None]:
# Write the AnnData, including the X_scVI embedding, as gzip-compressed h5ad.
filename = str(output_dir / "03_neurons-clean-scvi.h5ad")
Path(filename).parent.mkdir(parents=True, exist_ok=True)
adata.write_h5ad(filename, compression="gzip")