In [None]:
"""
Example Code for Running scVI

Code was developed for scvi-tools version 0.9.3,
but should work for versions up to 0.14.6. Starting
with version 0.15.0, setup_anndata moved from scvi.data
to the model class (scvi.model.SCVI.setup_anndata; see the scvi-tools tutorials).
"""

In [None]:
#%% scVI: load important modules

import csv
import os
import sys

import numpy as np
import pandas as pd
import scanpy
import scvi
import torch

In [None]:
#%% scVI: define paths to data files

"""
variable_genes_file = path_to_variable_genes

A .csv file with two columns corresponding to gene names and
their indices within the dataset (see variable_genes.csv files 
for example). We recommend using variable genes defined by
Seurat's VST algorithm. NOTE: If managing data in both R and
Python, be sure to account for differences in indexing between
the languages (e.g., Gene 1 in R would map to Gene 0 in Python).

dataset_file = path_to_dataset

A .csv file of the expression matrix, with genes as rows and
cells as columns (standard format in Seurat). Gene names and
cell barcodes should be included. If your dataset contains cells
as rows and genes as columns (standard format in Scanpy), then
do not transpose the adata object in the "scVI: create AnnData"
cell. Alternatively, if your data is in the .h5 or .mtx formats
(common outputs from CellRanger), you should use the appropriate
Scanpy data-reading functions (e.g., scanpy.read_10x_h5 or
scanpy.read_10x_mtx) in the "scVI: create AnnData" cell.

batch_id_file = path_to_batch_file

A .csv file with a single column corresponding to the batch
variable for each cell. This file should not have headers nor
row / column names.
"""

In [None]:
#%% scVI: variable genes

#retrieve highly variable genes from the two-column CSV described above
#(column 0 = gene name, column 1 = gene index within the dataset)
var_gene_index = []
var_gene_name = []

#newline="" is how the csv module expects files passed to csv.reader to be opened
with open(variable_genes_file, newline="") as csvfile:
    for row in csv.reader(csvfile, delimiter=","):
        if not row:
            #csv.reader yields [] for blank lines (e.g. a stray empty line);
            #skip them rather than failing with IndexError on row[1]
            continue
        var_gene_index.append(int(row[1]))
        var_gene_name.append(row[0])

print("Number of variable genes: " + str(len(var_gene_name)))

In [None]:
#%% scVI: batch identities

#read the per-cell batch labels (single-column CSV, no header) as floats
batch_identities = np.loadtxt(
    batch_id_file,
    delimiter=",",
    dtype=np.float64,
)

#turn the labels into an explicit column: one row per entry in the file
n_batch_rows = len(batch_identities)
batch_identities = batch_identities.reshape((n_batch_rows, 1))

#show which distinct batch labels were found
unique_batches = np.unique(batch_identities)
print("Batch Identities: " + str(unique_batches))

In [None]:
#%% scVI: output files

directory_path = "./" #folder location where you want to save scVI outputs

#create the output folder now, so a bad location fails here rather than in the
#save step that only runs after the (long) training has finished
os.makedirs(directory_path, exist_ok=True)

#define the output files; os.path.join works whether or not directory_path ends in "/"
model_save_file = os.path.join(directory_path, "scVI_model")
adata_save_file = os.path.join(directory_path, "scVI_adata.h5ad")
latent_save_file = os.path.join(directory_path, "scVI_latent.csv")
normalized_save_file = os.path.join(directory_path, "scVI_normalized.csv")
imputed_save_file = os.path.join(directory_path, "scVI_imputed.csv")

print("Model Save Path: " + model_save_file,
      "AnnData Save Path: " + adata_save_file,
      "Latent Save Path: " + latent_save_file,
      "Normalized Save Path: " + normalized_save_file,
      "Imputed Save Path: " + imputed_save_file,
      sep="\n\n")

In [None]:
#%% scVI: create AnnData

#create Scanpy AnnData object

#we use Scanpy's read_csv function to create the AnnData object,
#but scanpy.read_10x_h5 or scanpy.read_10x_mtx can also work here.
#writing and loading larger datasets as .csv files can be very slow,
#so these alternative functions will certainly be faster.

adata = scanpy.read_csv(dataset_file, first_column_names=True)
adata = adata.transpose() #do not transpose if cells are rows and genes are columns
adata = adata[:, var_gene_name].copy() #subset the dataset to only include variable genes

#batch labels are matched to cells purely by position, so the lengths must agree
if len(batch_identities) != adata.n_obs:
    raise ValueError(
        f"Batch file has {len(batch_identities)} entries but the dataset has "
        f"{adata.n_obs} cells; both must list the cells in the same order."
    )

#add batch identities to AnnData object (flattened to one label per cell)
adata.obs["batch"] = np.ravel(batch_identities)

print(adata)

#number of cells in each batch, printed as "label count", sorted by label
print("Batch count:")
for batch_label, n_cells in adata.obs["batch"].value_counts().sort_index().items():
    print(batch_label, n_cells)

In [None]:
#%% scVI: specify model parameters

#parameters used in the Worley, Everetts, et al. paper
#can be altered to user's preference
#gene_likelihood: use "zinb" for zero-inflated negative binomial, "nb" for negative binomial
#"use_cuda", "python" and "scvi_version" are informational: they are written to the
#parameter log, but the model/training cells below do not read them
scvi_params = {"use_cuda" : torch.cuda.is_available(),
               "n_layers" : 2,
               "n_latent" : 15,
               "gene_likelihood" : "nb",
               "n_epochs" : 400,
               "train_size" : 0.8,
               "python" : sys.executable,
               "scvi_version" : scvi.__version__}

#one "key<TAB>value" pair per line, the same layout as the saved parameter log
for key, val in scvi_params.items():
    print(key, val, sep="\t")

In [None]:
#%% scVI: create the model

#register the "batch" column with scvi-tools. From scvi-tools 0.15.0 onward this is a
#class method of the model (scvi.model.SCVI.setup_anndata); older versions
#(0.9.3 through 0.14.x) only provide scvi.data.setup_anndata. Use whichever the
#installed version offers so the notebook runs on both.
if hasattr(scvi.model.SCVI, "setup_anndata"):
    scvi.model.SCVI.setup_anndata(adata, batch_key = "batch")
else:
    scvi.data.setup_anndata(adata, batch_key = "batch")

model = scvi.model.SCVI(adata,
                        n_latent = scvi_params["n_latent"],
                        n_layers = scvi_params["n_layers"],
                        gene_likelihood = scvi_params["gene_likelihood"])

In [None]:
#%% scVI: train the model

model.train(max_epochs = scvi_params["n_epochs"], train_size = scvi_params["train_size"])

model_latent = model.get_latent_representation()
model_normalized = model.get_normalized_expression()
model_imputed = model.get_normalized_expression(library_size = "latent")

#save all of the output to file
#overwrite=True: re-running the notebook replaces the previous model folder (as the
#files below are replaced) instead of failing after training has already finished
model.save(dir_path = model_save_file, overwrite = True)
adata.write(filename = adata_save_file)
np.savetxt(latent_save_file, model_latent, fmt='%s', delimiter = ",")
np.savetxt(normalized_save_file, model_normalized, fmt='%s', delimiter = ",")
np.savetxt(imputed_save_file, model_imputed, fmt='%s', delimiter = ",")
#Outputs:
#model_save_file: folder containing trained scVI model
#adata_save_file: AnnData used for model training, in .h5ad format
#latent_save_file: CSV file containing the latent representation (matrix) of the data
#normalized_save_file: CSV file containing the denoised expression matrix, scaled to 1
#imputed_save_file: CSV file containing the denoised expression matrix, scaled to library size

#save a log of parameters used ("key<TAB>value" per line);
#the with-block closes the file even if a write fails
with open(os.path.join(directory_path, "scVI_param_log.txt"), "w") as output_log:
    for key, val in scvi_params.items():
        output_log.write(key + "\t" + str(val) + "\n")