# Train a totalVI model on CITE-seq data and impute batch-corrected (denoised) RNA and protein expression

In [49]:
import scvi
import scanpy as sc
import anndata as ad
import pandas as pd
import matplotlib.pyplot as plt
import mudata as md

In [None]:
# Directory holding the raw BNHL CITE-seq T-cell inputs. It ends with a slash
# because the file names below are appended to it directly.
citeseq_data_dir = '/home/projects/amit/floriani/Lab/PROJECTS/FlowVI/data/raw/CITE_seq/BNHL/'

# RNA modality of the CITE-seq experiment.
adata_citeseq_rna = ad.read_h5ad(
    f'{citeseq_data_dir}2024-11-06_BNHL_CITEseq_Tcells_RNA.h5ad'
)
# Protein (ADT) modality of the same cells.
adata_citeseq_prot = ad.read_h5ad(
    f'{citeseq_data_dir}2024-12-05_BNHL_CITEseq_Tcells_protein_cleaned_integrated_imputed_arcsinh_BB.h5ad'
)

In [None]:
# Bundle the two AnnData objects into one MuData container. The keys
# ("rna", "protein") become the modality names used throughout this notebook.
mdata = md.MuData(dict(rna=adata_citeseq_rna, protein=adata_citeseq_prot))
mdata  # last expression of the cell -> rendered as the cell output

  self._update_attr("var", axis=0, join_common=join_common)
  self._update_attr("obs", axis=1, join_common=join_common)


In [None]:
# Flag the 4,000 most variable genes of the RNA modality. The "seurat_v3"
# flavor works on raw counts, hence layer="counts". With batch_key="Run", the
# selection is done within each run and then combined (see scanpy docs).
sc.pp.highly_variable_genes(
    mdata.mod["rna"],
    layer="counts",
    flavor="seurat_v3",
    n_top_genes=4000,
    batch_key="Run",
)

# totalVI is trained on the HVG subset only. Keep that subset as its own
# modality ("rna_subset") so the full "rna" modality stays intact.
mdata.mod["rna_subset"] = mdata.mod["rna"][
    :, mdata.mod["rna"].var["highly_variable"]
].copy()

# Re-sync MuData's global obs/var after adding a modality.
mdata.update()

In [None]:
# Register the MuData with totalVI:
#   - RNA counts come from the "counts" layer of the HVG modality "rna_subset";
#   - protein values come from the "raw" layer of the "protein" modality;
#   - the batch covariate "Run" is read from the obs of "rna_subset".
# NOTE(review): totalVI models raw protein counts; the protein file name says its
# data is arcsinh-transformed, so confirm the "raw" layer holds untransformed counts.
scvi.model.TOTALVI.setup_mudata(
    mdata,
    modalities=dict(
        rna_layer="rna_subset",
        protein_layer="protein",
        batch_key="rna_subset",
    ),
    rna_layer="counts",
    protein_layer='raw',
    batch_key="Run",
)

# Build the model, train it for a fixed 50 epochs, and persist it.
# overwrite=True replaces any earlier save at the same path.
model = scvi.model.TOTALVI(mdata)
model.train(max_epochs=50)
model.save(
    '/home/projects/amit/floriani/Lab/PROJECTS/FlowVI/models/BNHL/2025-01-17_CITEseq_TotalVI',
    overwrite=True,
)

INFO     Computing empirical prior initialization for protein background.


  _verify_and_correct_data_format(adata, self.attr_name, self.attr_key)


In [None]:
# Handles to the two modalities used by the downstream cells.
rna, protein = mdata.mod["rna_subset"], mdata.mod["protein"]

# Store the totalVI latent embedding (one row per cell) in .obsm. Putting it in
# the RNA modality rather than the protein one is an arbitrary choice.
TOTALVI_LATENT_KEY = "X_totalVI"
rna.obsm[TOTALVI_LATENT_KEY] = model.get_latent_representation()

In [None]:
# Batch-corrected ("denoised") expression estimates. Passing a list to
# transform_batch makes scvi-tools decode each cell under every listed run and
# average the results.
# Take the runs from the modality the model's batch_key was registered on
# ("rna_subset", bound to `rna`). This keeps the list in sync with the data the
# model was actually set up on, rather than the standalone AnnData loaded from
# disk.
all_batches = rna.obs["Run"].unique().tolist()

# Shared posterior-sampling settings, so both calls below stay consistent.
sampling_kwargs = dict(n_samples=10, return_mean=True, transform_batch=all_batches)

rna_denoised, protein_denoised = model.get_normalized_expression(**sampling_kwargs)
rna.layers["denoised_rna"] = rna_denoised
protein.layers["denoised_protein"] = protein_denoised

# Probability that each protein signal is foreground rather than background,
# scaled by 100 to read as a percentage.
protein.layers["protein_foreground_prob"] = 100 * model.get_protein_foreground_probability(
    **sampling_kwargs
)

# Re-sync MuData's global obs/var with the modified modalities.
mdata.update()

  self._update_attr("var", axis=0, join_common=join_common)
  self._update_attr("obs", axis=1, join_common=join_common)


In [43]:
# Persist both modalities, now carrying the totalVI outputs, next to the raw inputs.
# NOTE(review): `rna` is the HVG-only "rna_subset" modality, so the RNA file holds
# only the 4,000 selected genes despite "combined_RNA" in its name. Confirm this
# is intended.
for adata_out, out_name in (
    (rna, '2025-01-17_BNHL_CITEseq_combined_RNA_TotalVI_imputed.h5ad'),
    (protein, '2025-01-17_BNHL_CITEseq_Tcells_protein_cleaned_integrated_imputed_arcsinh_BB_TotalVI_imputed.h5ad'),
):
    adata_out.write(f'{citeseq_data_dir}{out_name}')