### Notebook for the label transfer of Healthy Human PBMCs to mouse ACM heart (merged Pkp2+Ttn dataset) using `scANVI`

- **Developed by:** Carlos Talavera-López Ph.D
- **Modified by:** Alexandra Cirnu
- **Würzburg Institute for Systems Immunology & Julius-Maximilian-Universität Würzburg**
- **Date of creation:** 230918
- **Date of modification:** 240220

### Import required modules

In [1]:
import scvi
import torch
import anndata
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import plotnine as p
from pywaffle import Waffle
import matplotlib.pyplot as plt
#from scib_metrics.benchmark import Benchmarker

### Set up working environment

In [2]:
sc.settings.verbosity = 3
sc.logging.print_versions()
sc.settings.set_figure_params(dpi = 180, color_map = 'magma_r', dpi_save = 300, vector_friendly = True, format = 'svg')

-----
anndata     0.10.5.post1
scanpy      1.9.8
-----
PIL                 9.4.0
absl                NA
asttokens           NA
attr                23.2.0
chex                0.1.85
comm                0.2.1
contextlib2         NA
cycler              0.12.1
cython_runtime      NA
dateutil            2.8.2
debugpy             1.8.1
decorator           5.1.1
docrep              0.3.2
etils               1.7.0
exceptiongroup      1.2.0
executing           2.0.1
flax                0.8.1
fsspec              2024.2.0
h5py                3.10.0
importlib_resources NA
ipykernel           6.29.2
ipywidgets          8.1.2
jax                 0.4.24
jaxlib              0.4.24
jedi                0.19.1
joblib              1.3.2
kiwisolver          1.4.5
lightning           2.1.4
lightning_utilities 0.10.1
llvmlite            0.42.0
matplotlib          3.8.3
mizani              0.9.3
ml_collections      NA
ml_dtypes           0.3.2
mpl_toolkits        NA
msgpack             1.0.7
mudata           

In [3]:
 def X_is_raw(adata): return np.array_equal(adata.X.sum(axis=0).astype(int), adata.X.sum(axis=0))

In [4]:
torch.set_float32_matmul_precision('high')

In [5]:
warnings.simplefilter(action = 'ignore')
scvi.settings.seed = 1712
%config InlineBackend.print_figure_kwargs = {'facecolor' : "w"}
%config InlineBackend.figure_format = 'retina'

Seed set to 1712


In [6]:
arches_params = dict(
    use_layer_norm = "both",
    use_batch_norm = "none",
    encode_covariates = True,
    dropout_rate = 0.2,
    n_layers = 2,
)

### Read in Healthy data

In [7]:
reference = sc.read_h5ad('/home/acirnu/data/ACM_cardiac_leuco/Annotated_PBMC/meyer_nikolic_covid_pbmc_raw.h5ad')
reference

AnnData object with n_obs × n_vars = 422220 × 33559
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'nCount_ADT', 'nFeature_ADT', 'Age_group', 'BMI', 'COVID_severity', 'COVID_status', 'Ethnicity', 'Group', 'Sex', 'Smoker', 'annotation_broad', 'annotation_detailed', 'annotation_detailed_fullNames', 'patient_id', 'sample_id', 'sequencing_library', 'Protein_modality_weight'
    var: 'name'

In [8]:
X_is_raw(reference)

True

In [9]:
reference.obs['annotation_detailed_fullNames'].value_counts()

annotation_detailed_fullNames
T CD4 naive                           77439
Classical monocyte                    47520
NK                                    43630
B naive                               40193
T CD4 helper                          37922
T CD8 CTL                             36217
T CD8 naive                           33590
T CD8 central mem                     12830
T regulatory                           8567
Non-classical monocyte                 8342
T CD4 naive IFN stim                   8134
Classical monocyte IFN stim            7124
T CD8 effector mem CD45RA+             6443
B non-switched mem                     5574
T CD4 CTL                              5294
T gamma/delta                          5144
NK CD56 bright                         4723
B switched mem                         4060
MAIT                                   3961
Cycling                                3593
cDC2                                   2963
T CD8 effector mem                     2486
B 

- Remove annotations with less than 10 cells

In [10]:
reference.obs['seed_labels'] = reference.obs['annotation_detailed_fullNames'].copy()
reference.obs['seed_labels'].value_counts()

seed_labels
T CD4 naive                           77439
Classical monocyte                    47520
NK                                    43630
B naive                               40193
T CD4 helper                          37922
T CD8 CTL                             36217
T CD8 naive                           33590
T CD8 central mem                     12830
T regulatory                           8567
Non-classical monocyte                 8342
T CD4 naive IFN stim                   8134
Classical monocyte IFN stim            7124
T CD8 effector mem CD45RA+             6443
B non-switched mem                     5574
T CD4 CTL                              5294
T gamma/delta                          5144
NK CD56 bright                         4723
B switched mem                         4060
MAIT                                   3961
Cycling                                3593
cDC2                                   2963
T CD8 effector mem                     2486
B naive IFN stim    

In [11]:
reference.obs['seed_labels'].cat.categories

Index(['Hematopoietic progenitors IFN stim', 'B non-switched mem IFN stim',
       'B naive IFN stim', 'Non-classical monocyte IFN stim',
       'Classical monocyte IFN stim', 'NK IFN stim', 'T CD8 CTL IFN stim',
       'T CD4 naive IFN stim', 'Red blood cells', 'Platelets', 'Cycling',
       'Basophils & Eosinophils', 'Hematopoietic progenitors', 'Plasmablasts',
       'Plasma cells', 'B invariant', 'B switched mem', 'B non-switched mem',
       'B naive', 'cDC2', 'cDC1', 'AS-DC', 'pDC',
       'Non-classical monocyte complement+', 'Non-classical monocyte',
       'Classical monocyte IL6+', 'Classical monocyte', 'ILC',
       'NK CD56 bright', 'NK', 'NKT', 'MAIT', 'T regulatory', 'T gamma/delta',
       'T CD8 CTL', 'T CD8 effector mem CD45RA+', 'T CD8 effector mem',
       'T CD8 central mem', 'T CD8 naive', 'T CD4 CTL', 'T CD4 helper',
       'T CD4 naive'],
      dtype='object')

### Subset populations of interest

In [12]:
reference = reference[~reference.obs['seed_labels'].isin(['nan'])]              # = filter out not assigned cells?
reference.obs['seed_labels'].value_counts()

seed_labels
T CD4 naive                           77439
Classical monocyte                    47520
NK                                    43630
B naive                               40193
T CD4 helper                          37922
T CD8 CTL                             36217
T CD8 naive                           33590
T CD8 central mem                     12830
T regulatory                           8567
Non-classical monocyte                 8342
T CD4 naive IFN stim                   8134
Classical monocyte IFN stim            7124
T CD8 effector mem CD45RA+             6443
B non-switched mem                     5574
T CD4 CTL                              5294
T gamma/delta                          5144
NK CD56 bright                         4723
B switched mem                         4060
MAIT                                   3961
Cycling                                3593
cDC2                                   2963
T CD8 effector mem                     2486
B naive IFN stim    

In [13]:
reference.obs['donor'] = reference.obs['sample_id'].copy()

In [14]:
sc.pp.filter_cells(reference, min_genes = 200)              #to remove empty cells from table
sc.pp.filter_cells(reference, min_counts = 100)

In [15]:
cell_source_reference = "Yoshida"
reference.obs["cell_source"] = cell_source_reference
reference.obs['cell_source'].value_counts()

cell_source
Yoshida    422220
Name: count, dtype: int64

### Read in other query

In [16]:
query = sc.read_h5ad('/home/acirnu/data/ACM_cardiac_leuco/processed_merged/Merge_demux_QCed_ac240220.raw.h5ad')
query.obs['cell_source'] = 'AG_Gerull'
query.obs['seed_labels'] = 'Unknown'
query.obs['donor'] = query.obs['sample'].copy()
query

AnnData object with n_obs × n_vars = 44593 × 32285
    obs: 'sample', 'condition', 'genotype', 'infection', 'library', 'n_genes', 'doublet_scores', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'percent_mt', 'n_counts', 'percent_chrY', 'XIST-counts', 'S_score', 'G2M_score', 'cell_source', 'seed_labels', 'donor'
    var: 'gene_ids', 'feature_types', 'mt', 'ribo', 'n_cells_by_counts-A1', 'mean_counts-A1', 'pct_dropout_by_counts-A1', 'total_counts-A1', 'n_cells_by_counts-A2', 'mean_counts-A2', 'pct_dropout_by_counts-A2', 'total_counts-A2', 'n_cells_by_counts-A3', 'mean_counts-A3', 'pct_dropout_by_counts-A3', 'total_counts-A3', 'n_cells_by_counts-A4', 'mean_counts-A4', 'pct_dropout_by_counts-A4', 'total_counts-A4', 'n_cells_by_counts-B1', 'mean_counts-B1', 'pct_dropout_by_counts-B1', 'total_counts-B1', 'n_cells_by_counts-B2', 'mean_counts-B2', 'pct_dropout_by_counts-B2', 'total_counts-B2'
    layers: 'counts', 'sqrt_norm'

In [17]:
X_is_raw(query)

False

In [18]:
sc.pp.filter_cells(query, min_genes = 200)
sc.pp.filter_cells(query, min_counts = 100)

### Change gene symbols for label transfer

In [19]:
query.var_names = [gene_name.upper() for gene_name in query.var_names]
query.var_names

Index(['XKR4', 'GM1992', 'GM19938', 'GM37381', 'RP1', 'SOX17', 'GM37587',
       'GM37323', 'MRPL15', 'LYPLA1',
       ...
       'GM16367', 'AC163611.1', 'AC163611.2', 'AC140365.1', 'AC124606.2',
       'AC124606.1', 'AC133095.2', 'AC133095.1', 'AC234645.1', 'AC149090.1'],
      dtype='object', length=32285)

In [20]:
adata = reference.concatenate(query, batch_key = 'batch', batch_categories = ['reference', 'query'], join = 'inner')                #merge to one adata object
adata

AnnData object with n_obs × n_vars = 466813 × 16209
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'nCount_ADT', 'nFeature_ADT', 'Age_group', 'BMI', 'COVID_severity', 'COVID_status', 'Ethnicity', 'Group', 'Sex', 'Smoker', 'annotation_broad', 'annotation_detailed', 'annotation_detailed_fullNames', 'patient_id', 'sample_id', 'sequencing_library', 'Protein_modality_weight', 'seed_labels', 'donor', 'n_genes', 'n_counts', 'cell_source', 'sample', 'condition', 'genotype', 'infection', 'library', 'doublet_scores', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'percent_mt', 'percent_chrY', 'XIST-counts', 'S_score', 'G2M_score', 'batch'
    var: 'gene_ids-query', 'feature_types-query', 'mt-query', 'ribo-query', 'n_cells_by_counts-A1-query', 'mean_counts-A1-query', 'pct_dropout_by_counts-A1-query', 'total_counts-A1-query', 'n_cells_by_counts-A2-query', 'mean_counts-A2-query', 'pct_dropout_by_counts-A2-query', 'total_counts-

In [21]:
adata.obs['seed_labels'] = adata.obs['seed_labels'].astype('category')
adata.obs['seed_labels'].value_counts()

seed_labels
T CD4 naive                           77439
Classical monocyte                    47520
Unknown                               44593
NK                                    43630
B naive                               40193
T CD4 helper                          37922
T CD8 CTL                             36217
T CD8 naive                           33590
T CD8 central mem                     12830
T regulatory                           8567
Non-classical monocyte                 8342
T CD4 naive IFN stim                   8134
Classical monocyte IFN stim            7124
T CD8 effector mem CD45RA+             6443
B non-switched mem                     5574
T CD4 CTL                              5294
T gamma/delta                          5144
NK CD56 bright                         4723
B switched mem                         4060
MAIT                                   3961
Cycling                                3593
cDC2                                   2963
T CD8 effector mem  

In [22]:
adata.obs['batch'].value_counts()

batch
reference    422220
query         44593
Name: count, dtype: int64

### Select HVGs

In [23]:
adata.obs['donor'] = adata.obs['donor'].astype('str') 
adata.obs['donor'].value_counts()

donor
AP6                   15700
AP11                  12207
PP3                   10293
AP5                   10097
AP12                  10033
                      ...  
Pkp2_Ctr_MCMV_4         698
Pkp2_Ctr_PBS_1          684
Ttn_HetKO_noninf_1      563
PC27                    431
PP11                     72
Name: count, Length: 110, dtype: int64

In [24]:
adata_raw = adata.copy()
adata.layers['counts'] = adata.X.copy()

sc.pp.highly_variable_genes(
    adata,
    flavor = "seurat_v3",
    n_top_genes = 3000,
    layer = "counts",
    batch_key = "donor",
    subset = True,
    span = 1
    )

adata

If you pass `n_top_genes`, all cutoffs are ignored.
extracting highly variable genes
--> added
    'highly_variable', boolean vector (adata.var)
    'highly_variable_rank', float vector (adata.var)
    'means', float vector (adata.var)
    'variances', float vector (adata.var)
    'variances_norm', float vector (adata.var)


AnnData object with n_obs × n_vars = 466813 × 3000
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'nCount_ADT', 'nFeature_ADT', 'Age_group', 'BMI', 'COVID_severity', 'COVID_status', 'Ethnicity', 'Group', 'Sex', 'Smoker', 'annotation_broad', 'annotation_detailed', 'annotation_detailed_fullNames', 'patient_id', 'sample_id', 'sequencing_library', 'Protein_modality_weight', 'seed_labels', 'donor', 'n_genes', 'n_counts', 'cell_source', 'sample', 'condition', 'genotype', 'infection', 'library', 'doublet_scores', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'percent_mt', 'percent_chrY', 'XIST-counts', 'S_score', 'G2M_score', 'batch'
    var: 'gene_ids-query', 'feature_types-query', 'mt-query', 'ribo-query', 'n_cells_by_counts-A1-query', 'mean_counts-A1-query', 'pct_dropout_by_counts-A1-query', 'total_counts-A1-query', 'n_cells_by_counts-A2-query', 'mean_counts-A2-query', 'pct_dropout_by_counts-A2-query', 'total_counts-A

### Transfer of annotation with scANVI

In [25]:
scvi.model.SCVI.setup_anndata(adata, 
                              batch_key = 'donor', 
                              labels_key = 'seed_labels',
                              layer = 'counts',
                              categorical_covariate_keys = ['donor', 'cell_source'])

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [26]:
scvi_model = scvi.model.SCVI(adata, 
                             n_latent = 50, 
                             n_layers = 3, 
                             dispersion = 'gene-batch', 
                             gene_likelihood = 'nb')

In [27]:
scvi_model.train(20, 
                 check_val_every_n_epoch = 1, 
                 enable_progress_bar = True,
                 accelerator = "gpu",
                 devices = -1                                        
                 )

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[rank: 0] Seed set to 1712
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
[rank: 1] Seed set to 1712
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 2 processes
----------------------------------------------------------------------------------------------------

LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [0,1]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Epoch 1/20:   0%|          | 0/20 [00:00<?, ?it/s]

ProcessRaisedException: 

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 68, in _wrap
    fn(i, *args)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/multiprocessing.py", line 170, in _wrapping_function
    results = function(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 989, in _run
    results = self._run_stage()
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1035, in _run_stage
    self.fit_loop.run()
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 202, in run
    self.advance()
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 359, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 136, in run
    self.advance(data_fetcher)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 240, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 187, in run
    self._optimizer_step(batch_idx, closure)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 265, in _optimizer_step
    call._call_lightning_module_hook(
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 157, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 1291, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py", line 151, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/strategies/ddp.py", line 265, in optimizer_step
    optimizer_output = super().optimizer_step(optimizer, closure, model, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 230, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py", line 117, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/torch/optim/optimizer.py", line 385, in wrapper
    out = func(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
    ret = func(self, *args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/torch/optim/adam.py", line 146, in step
    loss = closure()
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/plugins/precision/precision.py", line 104, in _wrap_closure
    closure_result = closure()
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 140, in __call__
    self._result = self.closure(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 126, in closure
    step_output = self._step_fn()
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 315, in _training_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 381, in training_step
    return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 633, in __call__
    wrapper_output = wrapper_module(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1523, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 626, in wrapped_forward
    out = method(*_args, **_kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/scvi/train/_trainingplans.py", line 344, in training_step
    _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/scvi/train/_trainingplans.py", line 278, in forward
    return self.module(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/scvi/module/base/_decorators.py", line 32, in auto_transfer_args
    return fn(self, *args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/scvi/module/base/_base_module.py", line 203, in forward
    return _generic_forward(
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/scvi/module/base/_base_module.py", line 743, in _generic_forward
    generative_outputs = module.generative(**generative_inputs, **generative_kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/scvi/module/base/_decorators.py", line 32, in auto_transfer_args
    return fn(self, *args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/scvi/module/_vae.py", line 393, in generative
    px_scale, px_r, px_rate, px_dropout = self.decoder(
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/scvi/nn/_base_components.py", line 402, in forward
    px = self.px_decoder(z, *cat_list)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/scvi/nn/_base_components.py", line 159, in forward
    one_hot_cat = one_hot(cat, n_cat)
  File "/home/acirnu/miniforge3/envs/scANVI/lib/python3.10/site-packages/scvi/nn/_utils.py", line 7, in one_hot
    onehot.scatter_(1, index.type(torch.long), 1)
RuntimeError: Index tensor must have the same number of dimensions as self tensor


### Evaluate model performance a la _Sevensson_

In [None]:
history_df = (
    scvi_model.history['elbo_train'].astype(float)
    .join(scvi_model.history['elbo_validation'].astype(float))
    .reset_index()
    .melt(id_vars = ['epoch'])
)

p.options.figure_size = 12, 6

p_ = (
    p.ggplot(p.aes(x = 'epoch', y = 'value', color = 'variable'), history_df.query('epoch > 0'))
    + p.geom_line()
    + p.geom_point()
    + p.scale_color_manual({'elbo_train': 'black', 'elbo_validation': 'red'})
    + p.theme_minimal()
)

p_.save('fig2.png', dpi = 300)

print(p_)

### Label transfer with `scANVI` 

In [None]:
scanvi_model = scvi.model.SCANVI.from_scvi_model(scvi_model, 'Unknown')

In [None]:
scanvi_model.train(80, 
                   check_val_every_n_epoch = 1, 
                   enable_progress_bar = True,              # use_gpu = 1)
)

### Evaluate model performance a la Svensson

In [None]:
history_df = (
    scanvi_model.history['elbo_train'].astype(float)
    .join(scanvi_model.history['elbo_validation'].astype(float))
    .reset_index()
    .melt(id_vars = ['epoch'])
)

p.options.figure_size = 12, 6

p_ = (
    p.ggplot(p.aes(x = 'epoch', y = 'value', color = 'variable'), history_df.query('epoch > 0'))
    + p.geom_line()
    + p.geom_point()
    + p.scale_color_manual({'elbo_train': 'black', 'elbo_validation': 'red'})
    + p.theme_minimal()
)

p_.save('fig1.png', dpi = 300)

print(p_)

In [None]:
adata.obs["C_scANVI"] = scanvi_model.predict(adata)

- Extract latent representation

In [None]:
adata.obsm["X_scANVI"] = scanvi_model.get_latent_representation(adata)

- Visualise corrected dataset

In [None]:
sc.pp.neighbors(adata, use_rep = "X_scANVI", n_neighbors = 50, metric = 'minkowski')
sc.tl.umap(adata, min_dist = 0.3, spread = 4, random_state = 1712)
sc.pl.umap(adata, frameon = False, color = ['C_scANVI', 'seed_labels', 'sample', 'batch'], size = 1, legend_fontsize = 5, ncols = 3)

### Export annotated object

In [None]:
adata_export_merged = anndata.AnnData(X = adata_raw.X, var = adata_raw.var, obs = adata.obs)
adata_export_merged

In [None]:
adata_export = adata_export_merged[adata_export_merged.obs['batch'].isin(['query'])]
adata_export

### Add new gene symbols

In [None]:
query_export = anndata.AnnData(X = query.X, var = query.var, obs = adata_export.obs)
query_export

### Revert gene symbols

In [None]:
query_export.var_names = [gene_name.capitalize() for gene_name in query_export.var_names]
query_export.var_names

In [None]:
query_export.obs['C_scANVI'].value_counts()

In [None]:
query_export

In [None]:
query_export.write('../data/heart_mm_nuclei-23-0092_scANVI-CellStates_ctl231123.raw.h5ad')