## Notebook for transferring labels from healthy epithelial reference to cancer epithelial cells using `scarches`

- **Developed by**: Anna Maguza
- **Institute of Computational Biology - Computational Health Centre - Helmholtz Munich**
- 23th May 2023

### Load required modules

In [1]:
import scanpy as sc
import torch
import scarches as sca
from scarches.dataset.trvae.data_handling import remove_sparsity
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import anndata as an

 captum (see https://github.com/pytorch/captum).
INFO:lightning_fabric.utilities.seed:Global seed set to 0
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
os.chdir('../')
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

In [3]:
sc.settings.set_figure_params(dpi=200, frameon=False)
sc.set_figure_params(dpi=200)
sc.set_figure_params(figsize=(4, 4))
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

### Load required modules

In [4]:
input_healthy = '/Users/anna.maguza/Desktop/Data/Processed_datasets/Cancer_dataset_integration/input_files/Epithelial_cells/Geosketch_subset/Healthy_epithelial_cells_Geosketch_subset_2000_HVGs.h5ad'
Healthy_adata = sc.read(input_healthy)

In [5]:
input_cancer = '/Users/anna.maguza/Desktop/Data/Processed_datasets/Cancer_dataset_integration/input_files/Epithelial_cells/Geosketch_subset/Cancer_epithelial_cells_2000_HVGs.h5ad'
Cancer_adata = sc.read(input_cancer)

### Combine Datasets

In [6]:
Healthy_adata.obs['seed_labels'] = Healthy_adata.obs['Unified Cell States']
Cancer_adata.obs['seed_labels'] = 'Unknown'

In [7]:
# Combine datasets
adata = Healthy_adata.concatenate(Cancer_adata, batch_key='dataset', batch_categories=['Healthy', 'Cancer'])

In [8]:
adata.obs['dataset'].value_counts()

dataset
Cancer     113593
Healthy     38776
Name: count, dtype: int64

In [9]:
adata.obs['seed_labels'].value_counts()

seed_labels
Unknown                          113593
TA                                 5000
Enterocyte                         5000
Colonocyte                         5000
Stem cells OLFM4                   5000
Stem cells OLFM4 LGR5              4000
Goblet cells                       4000
Paneth cells                       3468
BEST2+ Goblet cell                 2924
Tuft cells                         1204
Stem cells OLFM4 PCNA               700
Epithelial HBB HBA                  670
Epithelial cells METTL12 MAFB       471
Stem cells OLFM4 GSTA1              341
Microfold cell                      340
Enteroendocrine cells               311
L cells                             228
Enterochromaffin cells              119
Name: count, dtype: int64

In [10]:
del Healthy_adata, Cancer_adata

### Set relevant anndata.obs labels

In [11]:
condition_key = 'dataset'
cell_type_key = 'seed_labels'
target_conditions = ['Cancer']

In [12]:
adata = remove_sparsity(adata)
source_adata = adata[~adata.obs[condition_key].isin(target_conditions)].copy()
target_adata = adata[adata.obs[condition_key].isin(target_conditions)].copy()

In [13]:
sca.models.SCVI.setup_anndata(source_adata, batch_key= "Sample_ID", labels_key='seed_labels')

In [14]:
vae = sca.models.SCVI(
    source_adata,
    n_layers=3,
    encode_covariates=True,
    deeply_inject_covariates=False,
    use_layer_norm="both",
    use_batch_norm="none",
    dispersion = 'gene-batch', 
    gene_likelihood = 'nb', 
    n_latent = 50
)

In [15]:
vae.train(early_stopping=True)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (mps), used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Epoch 206/206: 100%|██████████| 206/206 [27:18<00:00,  8.83s/it, loss=316, v_num=1]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=206` reached.


Epoch 206/206: 100%|██████████| 206/206 [27:18<00:00,  7.95s/it, loss=316, v_num=1]


In [16]:
source_adata.obsm["X_SCVI"] = vae.get_latent_representation()

In [17]:
scanvae = sca.models.SCANVI.from_scvi_model(vae, unlabeled_category = "Unknown")

In [18]:
print("Labelled Indices: ", len(scanvae._labeled_indices))
print("Unlabelled Indices: ", len(scanvae._unlabeled_indices))

Labelled Indices:  38776
Unlabelled Indices:  0


In [19]:
scanvae.train(early_stopping=True)

[34mINFO    [0m Training for [1;36m10[0m epochs.                                                                                   


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (mps), used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Epoch 10/10: 100%|██████████| 10/10 [02:44<00:00, 15.89s/it, loss=449, v_num=1]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 10/10: 100%|██████████| 10/10 [02:44<00:00, 16.43s/it, loss=449, v_num=1]


In [20]:
source_adata.obsm["X_scanvi"] = scanvae.get_latent_representation()

### Create anndata file of latent representation

In [21]:
source_adata.obs['predictions'] = scanvae.predict()

In [22]:
print("Acc: {}".format(np.mean(source_adata.obs.predictions == source_adata.obs.seed_labels)))

Acc: 0.280714875180524


In [23]:
source_adata.obs['predictions'].value_counts()

predictions
Enterocyte               11371
Colonocyte                9440
TA                        5115
Stem cells OLFM4          4179
Paneth cells              4153
BEST2+ Goblet cell        2499
Stem cells OLFM4 LGR5     2019
Name: count, dtype: int64

### Perform surgery on reference model and train on query dataset without cell type labels

In [24]:
model = sca.models.SCANVI.load_query_data(
    target_adata,
    scanvae,
    freeze_dropout = True,
)
model._unlabeled_indices = np.arange(target_adata.n_obs)
model._labeled_indices = []
print("Labelled Indices: ", len(model._labeled_indices))
print("Unlabelled Indices: ", len(model._unlabeled_indices))

Labelled Indices:  0
Unlabelled Indices:  113593


In [25]:
model.train(
    plan_kwargs=dict(weight_decay=0.0),
    check_val_every_n_epoch=10,
)

[34mINFO    [0m Training for [1;36m70[0m epochs.                                                                                   


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (mps), used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Epoch 70/70: 100%|██████████| 70/70 [38:48<00:00, 34.78s/it, loss=392, v_num=1]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=70` reached.


Epoch 70/70: 100%|██████████| 70/70 [38:48<00:00, 33.26s/it, loss=392, v_num=1]


In [26]:
target_adata.obsm["X_scanvi"] = model.get_latent_representation()

In [27]:
target_adata.obsm["X_SCVI"] = model.get_latent_representation()

In [28]:
target_adata.obs['predictions'] = model.predict()

### Concatenate objects and save

In [29]:
adata_full = source_adata.concatenate(target_adata)
adata_full.obsm["X_SCVI"] = np.concatenate([source_adata.obsm["X_SCVI"], target_adata.obsm["X_SCVI"]], axis=0)

In [30]:
# Save the output
adata_full.write('/Users/anna.maguza/Desktop/Data/Processed_datasets/Cancer_dataset_integration/output/Epithelial/Integrated_epithelial_cells_with_scArches_scANVI.h5ad')

In [31]:
target_adata.obs['predictions'].value_counts()

predictions
Enterocyte               40252
Stem cells OLFM4         25137
TA                       14513
Paneth cells             13934
Stem cells OLFM4 LGR5    10754
BEST2+ Goblet cell        7294
Colonocyte                1709
Name: count, dtype: int64