### Notebook for the label transfer of annotated gut stem cells to predicted stem cells in the GCA with `scANVI`

- **Developed by:** Carlos Talavera-López Ph.D
- **Institute of Computational Biology - Computational Health Department - Helmholtz Munich**
- v230419

### Import required modules

In [1]:
import torch
import scvi
import anndata
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt

Global seed set to 0
  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)


### Set up working environment

In [2]:
# Raise scanpy's logging verbosity and record the package versions used for this run.
sc.settings.verbosity = 3
sc.logging.print_versions()

# Figure defaults shared by every plot in this notebook.
sc.settings.set_figure_params(
    dpi = 180,
    dpi_save = 300,
    color_map = 'magma_r',
    format = 'svg',
    vector_friendly = True,
)

-----
anndata     0.8.0
scanpy      1.9.3
-----
PIL                 9.4.0
absl                NA
appnope             0.1.3
asttokens           NA
backcall            0.2.0
beta_ufunc          NA
binom_ufunc         NA
brotli              NA
certifi             2022.12.07
cffi                1.15.1
charset_normalizer  2.1.1
chex                0.1.6
colorama            0.4.6
contextlib2         NA
cycler              0.10.0
cython_runtime      NA
dateutil            2.8.2
debugpy             1.6.6
decorator           5.1.1
docrep              0.3.2
executing           1.2.0
flax                0.5.0
fsspec              2023.3.0
h5py                3.8.0
hypergeom_ufunc     NA
idna                3.4
igraph              0.10.4
importlib_resources NA
invgauss_ufunc      NA
ipykernel           6.15.0
ipywidgets          8.0.6
jax                 0.4.6
jaxlib              0.4.6
jedi                0.18.2
joblib              1.2.0
kiwisolver          1.4.4
leidenalg           0.9.1
lightning

In [3]:
warnings.simplefilter(action = 'ignore')
scvi.settings.seed = 1712
%config InlineBackend.print_figure_kwargs = {'facecolor' : "w"}
%config InlineBackend.figure_format = 'retina'
torch.set_float32_matmul_precision('medium')

Global seed set to 1712


In [4]:
# Architecture settings kept together so they can be unpacked into a model constructor.
# NOTE(review): no later cell visible in this notebook reads `arches_params` — confirm
# whether the SCVI model below was meant to receive it.
arches_params = {
    "use_layer_norm": "both",
    "use_batch_norm": "none",
    "encode_covariates": True,
    "dropout_rate": 0.2,
    "n_layers": 3,
}

### Read in Kong 2023 and predicted GCA stem cell data

In [5]:
# Load the combined, unprocessed object and display its summary.
kong_gca_path = '../../data/Kong_2023_and_predicted_GCA_stem_cells_unprocessed.h5ad'
kong_gca = sc.read_h5ad(kong_gca_path)
kong_gca

AnnData object with n_obs × n_vars = 31298 × 23616
    obs: 'cell_type', 'Location', 'batch', 'Sample_ID', 'n_genes', 'n_counts', 'Chem', 'Site', 'Type', 'Donor_ID', 'Layer', 'Celltype', 'sex', 'species', 'species__ontology_label', 'library_preparation_protocol', 'library_preparation_protocol__ontology_label', 'organ', 'organ__ontology_label', 'disease', 'disease__ontology_label', 'seed_labels', 'Study_name', 'UniqueCell_ID', 'CellType', 'Diagnosis', 'Age', 'Region code', 'Fraction', 'Gender', 'n_genes_by_counts', 'total_counts_mt', 'doublet_scores', 'predicted_doublets', 'Age_group', 'total_counts_ribo', 'percent_mito', 'percent_ribo', 'Cell States', 'Cell Label', 'dataset'
    var: 'gene_id-reference', 'gene_name-reference'

### Check if an object has raw or log-transformed counts

In [6]:
def contains_raw_counts(adata: "anndata.AnnData") -> bool:
    """Report whether ``adata.X`` looks like a raw count matrix.

    A matrix is treated as raw counts when it holds no negative values and
    every value is a whole number.

    Parameters
    ----------
    adata
        Object exposing the expression matrix as ``adata.X`` (dense array or
        scipy sparse matrix).

    Returns
    -------
    bool
        ``True`` if all values are non-negative integers, ``False`` otherwise.
    """
    matrix = adata.X
    if hasattr(matrix, 'tocsr'):
        # Sparse input: implicit zeros are already valid counts, so only the
        # explicitly stored entries need checking. This avoids materialising a
        # dense cells x genes array.
        values = matrix.tocsr().data
    else:
        values = np.asarray(matrix)

    has_negative_values = bool(np.any(values < 0))
    has_non_integer_values = not np.all(np.mod(values, 1) == 0)
    return not (has_negative_values or has_non_integer_values)


In [7]:
# Confirm that X still holds raw counts before any count-based step.
is_raw_counts = contains_raw_counts(adata = kong_gca)
print(f"AnnData contains raw counts: {is_raw_counts}")

AnnData contains raw counts: True


In [8]:
# Number of cells contributed by each study in the combined object.
kong_gca.obs['Study_name'].value_counts()

Kong 2023         16360
Gut Cell Atlas    12471
Smilie             1828
Wang                639
Name: Study_name, dtype: int64

- Remove samples with fewer than `min_cells` cells

In [9]:
# Keep only cells from samples that contribute at least `min_cells` cells.
min_cells = 10
cell_counts_per_sample = kong_gca.obs['Sample_ID'].value_counts()
samples_to_keep = cell_counts_per_sample.loc[lambda counts: counts >= min_cells].index
cell_in_kept_sample = kong_gca.obs['Sample_ID'].isin(samples_to_keep)
kong_gca_filtered = kong_gca[cell_in_kept_sample].copy()
kong_gca_filtered

AnnData object with n_obs × n_vars = 31005 × 23616
    obs: 'cell_type', 'Location', 'batch', 'Sample_ID', 'n_genes', 'n_counts', 'Chem', 'Site', 'Type', 'Donor_ID', 'Layer', 'Celltype', 'sex', 'species', 'species__ontology_label', 'library_preparation_protocol', 'library_preparation_protocol__ontology_label', 'organ', 'organ__ontology_label', 'disease', 'disease__ontology_label', 'seed_labels', 'Study_name', 'UniqueCell_ID', 'CellType', 'Diagnosis', 'Age', 'Region code', 'Fraction', 'Gender', 'n_genes_by_counts', 'total_counts_mt', 'doublet_scores', 'predicted_doublets', 'Age_group', 'total_counts_ribo', 'percent_mito', 'percent_ribo', 'Cell States', 'Cell Label', 'dataset'
    var: 'gene_id-reference', 'gene_name-reference'

In [10]:
# Number of cells per seed label; 'Unknown' is the category later passed to scANVI as unlabeled.
kong_gca_filtered.obs['seed_labels'].value_counts()

Unknown                   14674
Stem cells OLFM4           7048
Stem cells OLFM4 LGR5      4954
Stem cells OLFM4 PCNA      2927
Stem cells OLFM4 GSTA1     1402
Name: seed_labels, dtype: int64

### Select HVGs

In [11]:
# Keep a full copy of the filtered object, taken before the HVG step subsets the genes
# and before the 'counts' layer is added.
adata_raw = kong_gca_filtered.copy()
# Store the current X under layers['counts']; the HVG and scVI setup cells read this layer.
kong_gca_filtered.layers['counts'] = kong_gca_filtered.X.copy()

In [12]:
# Pick the 3000 most variable genes from the 'counts' layer, computed per sample,
# and subset the object to those genes in place.
hvg_settings = {
    "flavor": "seurat_v3",
    "n_top_genes": 3000,
    "layer": "counts",
    "batch_key": "Sample_ID",
    "subset": True,
    "span": 1,
}
sc.pp.highly_variable_genes(kong_gca_filtered, **hvg_settings)
kong_gca_filtered

If you pass `n_top_genes`, all cutoffs are ignored.
extracting highly variable genes
--> added
    'highly_variable', boolean vector (adata.var)
    'highly_variable_rank', float vector (adata.var)
    'means', float vector (adata.var)
    'variances', float vector (adata.var)
    'variances_norm', float vector (adata.var)


AnnData object with n_obs × n_vars = 31005 × 3000
    obs: 'cell_type', 'Location', 'batch', 'Sample_ID', 'n_genes', 'n_counts', 'Chem', 'Site', 'Type', 'Donor_ID', 'Layer', 'Celltype', 'sex', 'species', 'species__ontology_label', 'library_preparation_protocol', 'library_preparation_protocol__ontology_label', 'organ', 'organ__ontology_label', 'disease', 'disease__ontology_label', 'seed_labels', 'Study_name', 'UniqueCell_ID', 'CellType', 'Diagnosis', 'Age', 'Region code', 'Fraction', 'Gender', 'n_genes_by_counts', 'total_counts_mt', 'doublet_scores', 'predicted_doublets', 'Age_group', 'total_counts_ribo', 'percent_mito', 'percent_ribo', 'Cell States', 'Cell Label', 'dataset'
    var: 'gene_id-reference', 'gene_name-reference', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'highly_variable_nbatches'
    uns: 'hvg'
    layers: 'counts'

### Set up and train the `scVI` base model for `scANVI`

In [13]:
# Pre-flight checks before registering the data with scVI.
# Training on this object aborts with NaN latent parameters in the first epoch.
# A missing or non-finite continuous covariate propagates NaN through the loss, so
# fail here with an actionable message instead of an opaque error during training.
continuous_covariates = ["n_genes", "n_counts"]

covariate_values = kong_gca_filtered.obs[continuous_covariates].astype(float)
non_finite_per_covariate = (~np.isfinite(covariate_values)).sum()
if non_finite_per_covariate.any():
    offending = non_finite_per_covariate[non_finite_per_covariate > 0].to_dict()
    raise ValueError(
        "Continuous covariates contain missing or non-finite values "
        f"(covariate -> number of cells): {offending}. "
        "Fill or recompute them before calling setup_anndata."
    )

# Cells left without any counts after gene subsetting are another common source of
# numerical trouble; report them without stopping the run.
counts_per_cell = np.asarray(kong_gca_filtered.layers["counts"].sum(axis = 1)).ravel()
n_empty_cells = int((counts_per_cell == 0).sum())
if n_empty_cells > 0:
    print(f"NOTE: {n_empty_cells} cells have zero counts across the selected genes; consider removing them before training.")

# Register the counts layer, the seed labels and the technical covariates with scVI.
scvi.model.SCVI.setup_anndata(kong_gca_filtered,
                              layer = "counts",
                              labels_key = "seed_labels",
                              categorical_covariate_keys = ["Sample_ID"],
                              continuous_covariate_keys = continuous_covariates
                              )

In [14]:
# Build the scVI model on the registered object.
# NOTE(review): `arches_params` defined earlier is not passed here; only n_layers matches it — confirm intent.
scvi_model = scvi.model.SCVI(
    kong_gca_filtered,
    n_latent = 50,
    n_layers = 3,
    dispersion = 'gene-batch',
    gene_likelihood = 'nb',
)

In [15]:
# Train with library defaults; no epoch count is given (the recorded run shows 258 epochs).
# NOTE(review): the recorded run failed in epoch 1 with NaN latent parameters — inputs need checking.
scvi_model.train()

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Epoch 1/258:   0%|          | 0/258 [00:00<?, ?it/s]

ValueError: Expected parameter loc (Tensor of shape (128, 50)) of distribution Normal(loc: torch.Size([128, 50]), scale: torch.Size([128, 50])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], grad_fn=<AddmmBackward0>)

### Label transfer with `scANVI` 

In [None]:
# Initialise scANVI from the trained scVI model; cells labelled 'Unknown' are treated as unlabeled.
scanvi_model = scvi.model.SCANVI.from_scvi_model(
    scvi_model,
    unlabeled_category = 'Unknown',
)

In [None]:
# Train scANVI for 25 epochs.
scanvi_model.train(max_epochs = 25)


In [None]:
# Predict a label for every cell and store it in obs['C_scANVI'].
scanvi_predictions = scanvi_model.predict(kong_gca_filtered)
kong_gca_filtered.obs["C_scANVI"] = scanvi_predictions

- Extract latent representation

In [None]:
# Store the scANVI latent representation in obsm['X_scANVI'] for the neighbour graph below.
scanvi_latent = scanvi_model.get_latent_representation(kong_gca_filtered)
kong_gca_filtered.obsm["X_scANVI"] = scanvi_latent

- Visualise corrected dataset

In [None]:
# Neighbour graph and UMAP computed on the scANVI latent space.
sc.pp.neighbors(kong_gca_filtered, use_rep = "X_scANVI", n_neighbors = 50, metric = 'minkowski')
sc.tl.umap(kong_gca_filtered, min_dist = 0.2, spread = 2, random_state = 1712)

# Colour by columns that exist in this object's obs table. The previous keys
# 'group', 'infection' and 'cell_states' are not present here and would raise a KeyError.
umap_colour_keys = ['Study_name', 'disease', 'seed_labels', 'C_scANVI', 'Cell States']
sc.pl.umap(kong_gca_filtered, frameon = False, color = umap_colour_keys, size = 0.6, legend_fontsize = 5, ncols = 3)

### Export annotated sample object 

In [None]:
# Subset the cells whose obs['group'] is 'healthy_ctrl' and show their names.
# NOTE(review): `adata` is not assigned in any visible cell, and the obs table shown
# earlier has no 'group' column — this cell looks carried over from another
# notebook and cannot run as written; confirm the intended object and filter.
query_annotated = adata[adata.obs['group'].isin(['healthy_ctrl'])]
query_annotated.obs_names

In [None]:
# Shorten every cell name to its first three '-'-separated fields, then display the result.
trimmed_names = ['-'.join(name.split('-')[:3]) for name in query_annotated.obs.index]
query_annotated.obs.index = pd.Index(trimmed_names)
query_annotated.obs.index

In [None]:
# Copy the predicted labels onto `healthy_ctrl` and show its cell names.
# NOTE(review): `healthy_ctrl` is not assigned in any visible cell — confirm where it
# should be loaded. Presumably its obs names must match the trimmed names of
# `query_annotated` for the labels to line up; verify.
healthy_ctrl.obs['C_scANVI'] = query_annotated.obs['C_scANVI'].copy()
healthy_ctrl.obs_names

In [None]:
# List the label categories present after the transfer (requires a categorical column).
healthy_ctrl.obs['C_scANVI'].cat.categories

In [None]:
# Number of cells per transferred label.
healthy_ctrl.obs['C_scANVI'].value_counts()

In [None]:
def X_is_raw(adata):
    """Return True when ``adata.X`` holds only non-negative whole numbers.

    The check is done value by value. Comparing column sums instead would
    accept matrices whose non-integer entries happen to add up to an integer
    (for example two 0.5 values in one column).

    Parameters
    ----------
    adata
        Object exposing the expression matrix as ``adata.X`` (dense array or
        scipy sparse matrix).

    Returns
    -------
    bool
        ``True`` if every value is a non-negative integer, ``False`` otherwise.
    """
    matrix = adata.X
    if hasattr(matrix, 'tocsr'):
        # Sparse input: implicit zeros are valid, so only stored entries are inspected.
        values = matrix.tocsr().data
    else:
        values = np.asarray(matrix)
    return bool(np.all(values >= 0) and np.all(np.mod(values, 1) == 0))


In [None]:
# Check that the object about to be exported still carries raw counts in X.
X_is_raw(healthy_ctrl)

In [None]:
# Number of cells per transferred label.
# NOTE(review): this repeats an earlier cell with the same expression — likely removable.
healthy_ctrl.obs['C_scANVI'].value_counts()

### Export annotated Healthy-CTRL object with raw counts

In [None]:
# Write the annotated object to disk.
# NOTE(review): the target is an absolute, user-specific path under a COPD_IAV project
# directory, whereas the input of this notebook is read from '../../data/' — confirm the
# intended output location and file name. `healthy_ctrl` is also not assigned in any visible cell.
healthy_ctrl.write('/home/cartalop/data/carlos/single_cell/COPD_IAV/grch38-iav/scanvi_annotated/Marburg_Healthy_CTRL_ctl230315_scANVI_annot.raw.h5ad')