# Implementation of the scANVI method

We save the scANVI results as a CSV table and then load them in R for downstream analysis.

In [None]:
import scvi
import torch
import anndata as ad
import numpy as np
import pandas as pd
import math
from scipy.sparse import csr_matrix

In [None]:
# Build the AnnData object consumed by the scVI / scANVI cells below.

# load count: first row is the header, first column is the index; only the
# numeric values are kept after the conversion to a NumPy array.
count = pd.read_csv("count", header=0, index_col=0)
count = count.to_numpy()
# Transposed so that the columns of the table become AnnData observations
# (rows of X) and the rows of the table become variables.
count_csr = csr_matrix(count.T, dtype=np.float32)

# load meta: one row per observation; the "batch" and "Y" columns are used.
meta = pd.read_csv("meta", header=0, index_col=0)

# batch variable (not referenced again in this section; scVI reads the
# "batch" column from adata.obs via batch_key instead)
batch = meta["batch"].tolist()

# convert format
# NOTE(review): rows of `meta` are assumed to be in the same order as the
# columns of the count table -- nothing here checks that; confirm upstream.
# Observation names are taken from the index of `meta`.
adata = ad.AnnData(count_csr, obs=meta)
# Variable names are set on the final object: rebuilding an AnnData from
# `.X` afterwards would drop any names assigned earlier.
adata.var_names = [f"Taxon_{i:d}" for i in range(adata.n_vars)]
adata.layers["count"] = adata.X

# sample covariate used in data integration (stored as strings so it can be
# used as the scANVI label column)
adata.obs["Y"] = adata.obs["Y"].astype(str)

# scVI step (batch-effect removal): register the "count" layer as the model
# input and the "batch" column of adata.obs as the batch covariate, then fit
# the model with the library's default training settings.
scvi.model.SCVI.setup_anndata(
    adata,
    batch_key="batch",
    layer="count",
)
model = scvi.model.SCVI(adata)
model.train()

# scANVI step: start from the trained scVI model and add the sample covariate
# "Y" as the label column. Observations whose label equals "Unknown" are
# treated as unlabeled.
scanvi_model = scvi.model.SCANVI.from_scvi_model(
    model, adata=adata, labels_key="Y", unlabeled_category="Unknown"
)
scanvi_model.train(n_samples_per_label=100, max_epochs=20)

# Export the latent representation for the R analysis. The DataFrame is built
# from the bare array, so the CSV carries a positional 0..n-1 index and
# integer column labels rather than sample names.
res = scanvi_model.get_latent_representation(adata)
df2 = pd.DataFrame(res)
df2.to_csv("scanvi_res.csv")