# Import modules

In [1]:
import scanpy as sc
import torch
import scarches as sca
import matplotlib.pyplot as plt
import numpy as np
import gdown

 captum (see https://github.com/pytorch/captum).
INFO:pytorch_lightning.utilities.seed:Global seed set to 0
  return new_rank_zero_deprecation(*args, **kwargs)


In [2]:
# Show session information via session_info (presumably the loaded package
# versions, kept in the notebook for reproducibility — confirm).
import session_info
session_info.show()

In [3]:
# Configure plotting in a single call. Each call to set_figure_params re-applies
# the default value of every argument that is not passed explicitly, so setting
# dpi, frameon and figsize in separate calls makes them overwrite one another
# (a final figsize-only call resets dpi and frameon back to their defaults).
sc.set_figure_params(dpi=200, frameon=False, figsize=(4, 4))
# Compact tensor printing: 3 decimals, no scientific notation, 7 edge items.
torch.set_printoptions(precision=3, sci_mode=False, edgeitems=7)

# Set relevant anndata.obs labels and training length

In [4]:
# Name of the .obs column holding the cell-type labels; passed as labels_key to
# setup_anndata and used when copying labels into the latent AnnData objects.
cell_type_key = 'cell_type'

# Read in reference adata

In [5]:
# Load the preprocessed reference AnnData from disk and display its summary.
adata_ref = sc.read_h5ad(
    '/nfs/team205/heart/anndata_objects/8regions/scArches/RNA_adult-8reg-ref_pp.h5ad'
)
adata_ref

AnnData object with n_obs × n_vars = 629041 × 3155
    obs: 'sangerID', 'combinedID', 'donor', 'donor_type', 'region', 'region_finest', 'age', 'gender', 'facility', 'cell_or_nuclei', 'modality', 'kit_10x', 'flushed', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'scrublet_score', 'scrublet_leiden', 'cluster_scrublet_score', 'doublet_pval', 'doublet_bh_pval', 'batch_key', 'leiden_scVI', 'cell_type', 'cell_state_HCAv1', 'cell_state_scNym', 'cell_state_scNym_confidence', 'cell_state'
    var: 'gene_name_scRNA-0', 'gene_name_snRNA-1', 'gene_name_multiome-2', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'highly_variable_nbatches', 'highly_variable_intersection'
    uns: 'cell_or_nuclei_colors', 'cell_state_HCAv1_colors', 'cell_state_colors', 'cell_state_scNym_colors', 'cell_type_colors', 'donor_colors', 'hvg', 'kit_10x_colors', 'leiden_scVI_colors', 'region_colors'
    obsm: 'X_scVI', 

In [6]:
# Peek at the first 10 stored values of X; the displayed values are
# non-integer, i.e. X does not hold raw counts.
adata_ref.X.data[:10]

array([1.34, 2.24, 1.65, 1.34, 3.73, 1.34, 0.88, 0.88, 1.65, 1.65],
      dtype=float32)

In [7]:
# Peek at the first 10 stored values of the "counts" layer; the displayed
# values are integer-valued, which is what the model is trained on (layer="counts").
adata_ref.layers["counts"].data[:10]

array([1., 1., 1., 5., 1., 8., 1., 2., 1., 3.], dtype=float32)

In [8]:
# Number of reference cells per batch_key category.
adata_ref.obs['batch_key'].value_counts()

AH1_Nuclei_Multiome-v1    51034
D11_Cell_3prime-v3        45964
D2_Nuclei_3prime-v2       45050
H5_Nuclei_3prime-v3       38177
D6_Cell_3prime-v2         36677
A61_Nuclei_Multiome-v1    34607
D8_Nuclei_Multiome-v1     33554
H3_Nuclei_3prime-v3       32266
H7_Nuclei_3prime-v3       31676
H6_Nuclei_3prime-v3       27431
D4_Nuclei_3prime-v2       26773
H4_Nuclei_3prime-v3       25707
D5_Nuclei_3prime-v2       22564
D3_Nuclei_3prime-v2       22271
H2_Nuclei_3prime-v3       21396
D6_Cell_3prime-v3         20568
D7_Cell_3prime-v2         18528
D6_Nuclei_3prime-v2       17946
D7_Nuclei_3prime-v2       16904
D1_Nuclei_3prime-v2       15533
D7_Nuclei_Multiome-v1     13925
D11_Nuclei_3prime-v3      13844
D5_Cell_3prime-v2          6647
D3_Cell_3prime-v2          5364
D3_Nuclei_Multiome-v1      2685
D4_Cell_3prime-v2          1784
D1_Cell_3prime-v2           166
Name: batch_key, dtype: int64

# Read in query data 

In [29]:
# query data
# Load the query AnnData, then reduce .var to its 'gene_name' column only.
adata_que = sc.read_h5ad('/nfs/team205/heart/anndata_objects/8regions/multiome_RNA_adult_new-SAN-AVN_raw_rmdblcls.h5ad')
adata_que.var = adata_que.var[['gene_name']]
adata_que

AnnData object with n_obs × n_vars = 75255 × 36601
    obs: 'latent_RT_efficiency', 'latent_cell_probability', 'latent_scale', 'sangerID', 'combinedID', 'donor', 'donor_type', 'region', 'region_finest', 'age', 'gender', 'facility', 'cell_or_nuclei', 'modality', 'kit_10x', 'flushed', 'scrublet_score', 'scrublet_leiden', 'cluster_scrublet_score', 'doublet_pval', 'doublet_bh_pval', 'n_genes', 'n_counts', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'batch_key', '_scvi_batch', '_scvi_labels', 'leiden_scVI', 'clus20', 'doublet_cls'
    var: 'gene_name'
    obsm: 'X_scVI', 'X_umap', '_scvi_extra_continuous', 'latent_gene_encoding'

In [30]:
# subset HVGs of reference data
# Keep a copy of X in layers["counts"] (the printed values are integer-valued),
# matching the layer name the reference model is set up with.
adata_que.layers["counts"] = adata_que.X.copy()
# Stash the full-gene object in .raw before restricting the genes.
adata_que.raw = adata_que
# Restrict the query to the reference's var_names (same genes, same order).
# The result is a view of the original object, as the displayed summary shows.
adata_que = adata_que[:,adata_ref.var_names]
print(adata_que.X.data[:10])
adata_que

[4. 2. 2. 1. 3. 3. 1. 6. 1. 1.]


View of AnnData object with n_obs × n_vars = 75255 × 3155
    obs: 'latent_RT_efficiency', 'latent_cell_probability', 'latent_scale', 'sangerID', 'combinedID', 'donor', 'donor_type', 'region', 'region_finest', 'age', 'gender', 'facility', 'cell_or_nuclei', 'modality', 'kit_10x', 'flushed', 'scrublet_score', 'scrublet_leiden', 'cluster_scrublet_score', 'doublet_pval', 'doublet_bh_pval', 'n_genes', 'n_counts', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt', 'total_counts_ribo', 'pct_counts_ribo', 'batch_key', '_scvi_batch', '_scvi_labels', 'leiden_scVI', 'clus20', 'doublet_cls'
    var: 'gene_name'
    obsm: 'X_scVI', 'X_umap', '_scvi_extra_continuous', 'latent_gene_encoding'
    layers: 'counts'

In [31]:
# Check that the "counts" layer survived the gene subsetting; the displayed
# values match the X values printed for the subset query.
adata_que.layers["counts"].data[:10]

array([4., 2., 2., 1., 3., 3., 1., 6., 1., 1.], dtype=float32)

# Create SCANVI model and train it on fully labelled reference dataset

In [9]:
# Register adata_ref for SCVI: the "counts" layer is the model input,
# 'batch_key' is the batch column, three QC metrics are continuous covariates,
# and cell_type_key is the labels column.
sca.models.SCVI.setup_anndata(adata_ref, 
                              layer="counts", 
                              batch_key='batch_key', 
                              # categorical_covariate_keys=['donor','cell_or_nuclei','kit_10x'],
                              # scArches currently does not support models with extra categorical covariates.
                              continuous_covariate_keys=['total_counts','pct_counts_mt','pct_counts_ribo'],
                              labels_key=cell_type_key)

In [10]:
# Hyper-parameters of the reference SCVI model, collected in one place so they
# are easy to read and tweak.
scvi_model_kwargs = dict(
    n_hidden=128,
    n_latent=50,
    n_layers=3,
    dispersion='gene-batch',
    encode_covariates=True,
    deeply_inject_covariates=False,
    use_layer_norm="both",
    use_batch_norm="none",
)
vae = sca.models.SCVI(adata_ref, **scvi_model_kwargs)

In [11]:
# Display how adata_ref is registered with the SCVI model (sanity check).
vae.view_anndata_setup(adata_ref)

In [12]:
# Train SCVI with default settings; no max_epochs is passed, and the logged
# run stopped at max_epochs=13.
vae.train()

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 13/13: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [15:00<00:00, 69.09s/it, loss=640, v_num=1]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=13` reached.


Epoch 13/13: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 13/13 [15:00<00:00, 69.28s/it, loss=640, v_num=1]


Create the SCANVI model instance with ZINB loss as default. Insert `gene_likelihood='nb'` to change the reconstruction loss to NB loss.

In [13]:
# Initialise SCANVI from the trained SCVI model. "Unknown" is declared as the
# unlabeled category (presumably the label value that marks unlabelled cells — confirm).
scanvae = sca.models.SCANVI.from_scvi_model(vae, unlabeled_category = "Unknown")

In [14]:
# Report how many reference cells the model treats as labelled vs. unlabelled.
n_labelled = len(scanvae._labeled_indices)
n_unlabelled = len(scanvae._unlabeled_indices)
print("Labelled Indices: ", n_labelled)
print("Unlabelled Indices: ", n_unlabelled)

Labelled Indices:  629041
Unlabelled Indices:  0


In [15]:
# Train the SCANVI reference model for a fixed 20 epochs.
scanvae.train(max_epochs=20)

[34mINFO    [0m Training for [1;36m20[0m epochs.                                                                                   


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 20/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [52:51<00:00, 158.59s/it, loss=714, v_num=1]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 20/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [52:51<00:00, 158.59s/it, loss=714, v_num=1]


In [16]:
# Persist both trained reference models. `ref_path` (the SCANVI directory) is
# reused later when loading the query data onto the reference model.
ref_path = '/nfs/team205/kk18/Analysis/scArches/models/global_ref_scanvae/'
vae_path = '/nfs/team205/kk18/Analysis/scArches/models/global_ref_vae/'
for trained_model, out_dir in ((scanvae, ref_path), (vae, vae_path)):
    trained_model.save(out_dir, overwrite=True)

# Create anndata file of latent representation and compute UMAP

In [None]:
# Disabled cell: the code below is wrapped in a string literal so it does not run.
# If enabled, it would store both latent representations in adata_ref.obsm and
# write the object to disk.
'''
adata_ref.obsm["X_scVI"] = vae.get_latent_representation()
adata_ref.obsm["X_scANVI"] = scanvae.get_latent_representation()
adata_ref.write('./anndata/RNA_adult-8reg-ref_pp.h5ad')
'''

In [25]:
# Wrap the SCANVI latent space of the reference in its own AnnData and carry
# over the cell-type and batch annotations as plain lists.
reference_embedding = scanvae.get_latent_representation()
reference_latent = sc.AnnData(reference_embedding)
for latent_col, source_col in (("cell_type", cell_type_key), ("batch", "batch_key")):
    reference_latent.obs[latent_col] = adata_ref.obs[source_col].tolist()

In [27]:
# Classifier accuracy on the reference: fraction of cells whose predicted
# label equals the annotated cell type.
predicted_labels = scanvae.predict()
reference_latent.obs['predictions'] = predicted_labels
is_correct = reference_latent.obs.predictions == reference_latent.obs.cell_type
print("Acc: {}".format(np.mean(is_correct)))

Acc: 0.9930068151360563


In [28]:
# Give the latent AnnData the same cell identifiers as the AnnData attached to
# the model, so rows can be matched back to the original cells.
reference_latent.obs_names = scanvae.adata.obs_names.copy()
# save latent anndata
reference_latent.write('/nfs/team205/kk18/Analysis/scArches/latent_anndata/reference_latent.h5ad')

In [None]:
# Disabled cell: the code below is wrapped in a string literal so it does not run.
# If enabled, it would build a neighbour graph on the reference latent space,
# cluster it with Leiden and plot a UMAP coloured by batch and cell type.
'''
sc.pp.neighbors(reference_latent, n_neighbors=8)
sc.tl.leiden(reference_latent)
sc.tl.umap(reference_latent)
sc.pl.umap(reference_latent,
           color=["batch",'cell_type'],
           frameon=False,
           wspace=0.6,
           )
'''

# Perform surgery on reference model and train on query dataset without cell type labels

In [32]:
# Materialise the query AnnData (it was a view after subsetting to the
# reference genes) before handing it to the model.
adata_que = adata_que.copy()

# Load the saved reference SCANVI model from ref_path and attach the query data to it.
model = sca.models.SCANVI.load_query_data(
    adata_que,
    ref_path,
    freeze_dropout = True,
)
# Treat every query cell as unlabelled: all indices go to the unlabelled set
# and the labelled set is emptied, so no query labels are used.
model._unlabeled_indices = np.arange(adata_que.n_obs)
model._labeled_indices = []
print("Labelled Indices: ", len(model._labeled_indices))
print("Unlabelled Indices: ", len(model._unlabeled_indices))

[34mINFO    [0m File [35m/nfs/team205/kk18/Analysis/scArches/models/global_ref_scanvae/[0m[95mmodel.pt[0m already downloaded            


  f"Missing labels key {self._original_attr_key}. Filling in with unlabeled category {self._unlabeled_category}."


Labelled Indices:  0
Unlabelled Indices:  75255


In [33]:
# Display how adata_que is registered with the query model (sanity check).
model.view_anndata_setup(adata_que)

In [34]:
# Fine-tune the model on the query data for 100 epochs with weight decay
# switched off, validating every 10 epochs.
query_plan_kwargs = {"weight_decay": 0.0}
model.train(max_epochs=100, plan_kwargs=query_plan_kwargs, check_val_every_n_epoch=10)

[34mINFO    [0m Training for [1;36m100[0m epochs.                                                                                  


INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 100/100: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [21:02<00:00, 12.80s/it, loss=1.14e+03, v_num=1]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 100/100: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [21:02<00:00, 12.63s/it, loss=1.14e+03, v_num=1]


In [35]:
# save model
# Persist the query-adapted model, replacing any earlier save at this path.
surgery_path = '/nfs/team205/kk18/Analysis/scArches/models/surgery_model'
model.save(surgery_path, overwrite=True)

In [None]:
# Disabled cell: the code below is wrapped in a string literal so it does not run.
# If enabled, it would store the surgery latent space and the predicted labels
# on adata_que and write the object to disk.
'''
# save adata
adata_que.obsm["X_scANVI_surgery"] = model.get_latent_representation()
adata_que.obs['scANVI_predictions'] = model.predict()
adata_que.write('./anndata/multiome_RNA_adult_new-SAN-AVN.h5ad')
'''

In [36]:
# Wrap the query latent space in its own AnnData, then attach the label
# column, batch, model predictions and the original cell identifiers.
query_embedding = model.get_latent_representation()
query_latent = sc.AnnData(query_embedding)
for latent_col, source_col in (('cell_type', cell_type_key), ('batch', "batch_key")):
    query_latent.obs[latent_col] = adata_que.obs[source_col].tolist()
query_latent.obs['predictions'] = model.predict()
query_latent.obs_names = model.adata.obs_names.copy()

# save latent anndata
query_latent_file = '/nfs/team205/kk18/Analysis/scArches/latent_anndata/query_latent.h5ad'
query_latent.write(query_latent_file)

# Get latent representation of reference + query dataset and compute UMAP

In [None]:
# add scANVI prediction outs to reference adata
# adata_ref.obs['scANVI_predictions'] = scanvae.predict()

In [37]:
# Stack reference and query cells into one AnnData; 'original_or_new' records
# which dataset each cell came from and obs names are kept unmodified.
adata_full = adata_ref.concatenate(
    adata_que,
    batch_key='original_or_new',
    batch_categories=['original', 'new'],
    index_unique=None,
)

# Embed the combined data with the query-adapted model and copy over the
# annotations, predictions and cell identifiers.
full_embedding = model.get_latent_representation(adata=adata_full)
full_latent = sc.AnnData(full_embedding)
for latent_col, source_col in (('cell_type', cell_type_key), ('batch', "batch_key")):
    full_latent.obs[latent_col] = adata_full.obs[source_col].tolist()
full_latent.obs['predictions'] = model.predict(adata=adata_full)
full_latent.obs_names = adata_full.obs_names.copy()
full_latent.obs['original_or_new'] = adata_full.obs['original_or_new'].tolist()

# save
full_latent_file = '/nfs/team205/kk18/Analysis/scArches/latent_anndata/full_latent.h5ad'
full_latent.write(full_latent_file)

  [AnnData(sparse.csr_matrix(a.shape), obs=a.obs) for a in all_adatas],


[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             
