In [None]:
# Print the version of the `python` executable found on the shell PATH.
!python --version

In [None]:
# Print the current working directory.
!pwd

In [None]:
!cd /workspace

In [8]:
!pip install scanpy scikit-misc scvi-tools mplscience  

In [9]:
import os
import gc
import warnings
warnings.filterwarnings("ignore")

# single cell pipelines
import scanpy as sc
import numpy as np
import pandas as pd

import seaborn as sns
import matplotlib.pyplot as plt
import mplscience

# scvi-tools
import scvi
import torch

# Log library versions so the run can be traced back to an environment.
print(scvi.__version__)
sc.logging.print_header()

# Single source of truth for the random seed: the same value is applied to
# NumPy, PyTorch and scvi-tools so the three can never drift apart.
SEED = 777
np.random.seed(SEED)
torch.manual_seed(SEED)
scvi.settings.seed = SEED

In [None]:
# Name of the batch column; passed as `batch_key` to the highly-variable-gene
# selection and to the scVI `setup_anndata` call used for autotune below.
batch_key = 'sample_id'

# 1. autotune

In [None]:
# Optional: extended dependency list (includes ray[tune], which is imported
# in the next cell). Uncomment to install.
# !pip install scanpy scvi-tools scikit-misc ray[tune] hyperopt celltypist mplscience

In [None]:
from scvi import autotune
from ray import tune
import ray

In [None]:
# Load the AnnData object used for hyperparameter tuning.
# TODO: the path is an empty placeholder — set it to the input .h5ad file
# before running this cell.
adata = sc.read_h5ad('')

In [None]:
# Select the top 3000 highly variable genes from the 'counts' layer
# (seurat_v3 flavor, computed with `batch_key`) and subset `adata` to those
# genes in place.
sc.pp.highly_variable_genes(
    adata,
    layer="counts",
    flavor="seurat_v3",
    n_top_genes=3000,
    batch_key=batch_key,
    subset=True,
)

In [None]:
# Model class to tune, and registration of `adata` with it: raw counts come
# from the 'counts' layer, `batch_key` marks the batch column, and the two
# percentage columns are registered as continuous covariates.
model_cls = scvi.model.SCVI
model_cls.setup_anndata(
    adata,
    layer="counts",
    batch_key=batch_key,
    continuous_covariate_keys=["pct_counts_mt", "pct_counts_ribo"],
)
# Free memory before the tuning run.
gc.collect()

In [None]:
# Hyperparameter search space passed to `autotune.run_autotune` below.
# "model_params" holds the candidate model settings; "train_params" combines
# fixed training settings (max_epochs, drop_last) with a log-uniform range
# for the learning rate.
search_space = {
    "model_params": {
        # NOTE(review): 92 breaks the otherwise regular 128/192/256 pattern —
        # confirm it is intended and not a typo (e.g. for 96).
        "n_hidden": tune.choice([92, 128, 192, 256]),
        "n_latent": tune.choice([10, 20, 30, 40, 50, 60]),
        "n_layers": tune.choice([1, 2, 3]),
        "gene_likelihood": tune.choice(["nb", "zinb"]),
    },
    "train_params": {
        "max_epochs": 100,
        "plan_kwargs": {"lr": tune.loguniform(1e-4, 1e-2)},
        "datasplitter_kwargs": {"drop_last": True},
    },
}

In [None]:
# Start the local Ray runtime used by autotune, without forwarding worker
# logs to the notebook output. `ignore_reinit_error=True` makes the cell safe
# to re-run: without it a second `ray.init()` in the same kernel raises
# because Ray is already initialised.
ray.init(log_to_driver=False, ignore_reinit_error=True)

In [None]:
# Run the hyperparameter search: 100 sampled configurations from
# `search_space`, ranked by minimising the validation loss.
results = autotune.run_autotune(
    model_cls,
    data=adata,
    metrics="validation_loss",
    mode="min",
    search_space=search_space,
    num_samples=100,
)

In [None]:
# Find the trial with the lowest validation loss.
# The running minimum starts at +inf instead of a fixed number (previously
# 10000): with a fixed sentinel, a search in which every loss is larger than
# the sentinel would silently report trial 0 as "best".
best_vl, best_i = float("inf"), 0
for i, res in enumerate(results.result_grid):
    # Trials that reported no 'validation_loss' (e.g. ones that errored before
    # the first validation step) are skipped rather than raising KeyError.
    vl = (res.metrics or {}).get('validation_loss')
    if vl is None:
        continue
    if vl < best_vl:
        best_vl = vl
        best_i = i
# Display the best trial (its config holds the chosen hyperparameters).
results.result_grid[best_i]

# 2. scVI
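- Change the model and training parameters in the cells below as needed (for example, to the best configuration found by autotune in section 1).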

In [None]:
# Register `adata` with scVI and build the model.
# The batch column is taken from the notebook-wide `batch_key` (instead of a
# second hard-coded 'sample_id') so that HVG selection, autotune and the final
# model always use the same batch definition.
scvi.model.SCVI.setup_anndata(
    adata,
    layer='counts',
    batch_key=batch_key,
    continuous_covariate_keys=['pct_counts_mt', 'pct_counts_ribo'],
)
scvi_model = scvi.model.SCVI(adata, gene_likelihood='zinb', n_hidden=192, n_latent=30, n_layers=2, dropout_rate=0.5)

# Simpler alternative configuration:
# scvi_model = scvi.model.SCVI(adata, n_latent=30, n_layers=2)
gc.collect()

In [None]:
# Train scVI with early stopping, dropping the last incomplete batch, at a
# learning rate of 0.0027.
scvi_model.train(
    early_stopping=True,
    plan_kwargs={"lr": 0.0027},
    datasplitter_kwargs={"drop_last": True},
)

In [None]:
# Plot train vs. validation reconstruction loss from the training history,
# with a horizontal line at the lowest validation loss.
with mplscience.style_context():
    history = scvi_model.history
    val_loss = history['reconstruction_loss_validation']['reconstruction_loss_validation']
    train_loss = history['reconstruction_loss_train']['reconstruction_loss_train']
    y = val_loss.min()

    plt.plot(val_loss, label='validation')
    plt.plot(train_loss, label='train')
    plt.axhline(y, c='k')

    plt.legend()
    plt.show()

In [None]:
# read again
adata = sc.read_h5ad('')
adata.obsm['X_scVI'] = scvi_model.get_latent_representation()
adata.raw = adata

In [None]:
# Build the neighbor graph on the scVI latent space, then run Leiden
# clustering and compute a UMAP embedding from that graph.
gc.collect()
sc.pp.neighbors(adata, use_rep="X_scVI")
sc.tl.leiden(adata)
sc.tl.umap(adata)

In [None]:
# Persist the annotated data and the trained scVI model.
# TODO: the output path for `adata.write` is an empty placeholder — set it
# before running.
adata.write('')
scvi_model.save('model_scvi')

# 3. scANVI

In [None]:
# Optional: reload a previously saved scVI model instead of retraining
# (replace 'model_path' with the saved model directory).
# scvi_model = scvi.model.SCVI.load('model_path', adata)

In [None]:
# Initialise scANVI from the trained scVI model, using the 'cell_type' column
# as labels and 'unlabelled' as the category for cells without a label.
# NOTE(review): `adata` here is the object re-read from disk, not the
# HVG-subset object the scVI model was set up on. If the reloaded file has a
# different gene set or lacks the 'counts' layer, this call is expected to
# fail — confirm, or pass the subset object instead.
scanvi_model = scvi.model.SCANVI.from_scvi_model(
    scvi_model=scvi_model,
    adata=adata,
    labels_key="cell_type",
    unlabeled_category="unlabelled",
)

In [None]:
# Train scANVI for at most 30 epochs with early stopping.
scanvi_model.train(max_epochs = 30, early_stopping = True)

# Alternative training configuration (explicit early-stopping monitor and
# patience, plus custom optimiser / KL warm-up settings); kept for reference.
# scanvi_model.train(
#     max_epochs=30,
#     early_stopping=True,
#     early_stopping_monitor="validation_loss",
#     early_stopping_patience=10,
#     plan_kwargs={
#         "lr": 1e-3,
#         "weight_decay": 1e-4,
#         "n_epochs_kl_warmup": 5
#     }
# )

In [None]:
# read again
adata = sc.read_h5ad('')
adata.obsm["X_scANVI"] = scanvi_model.get_latent_representation()

In [None]:
# Persist the trained scANVI model to the 'model_scanvi' directory.
scanvi_model.save('model_scanvi')