In [None]:
# Import dependencies
%matplotlib inline
import os
import numpy as np
import pandas as pd
import scanpy as sc
import scanpy.external as sce
import seaborn as sns
import anndata
import matplotlib.pyplot as plt
import yaml
import scvi
import ray
import hyperopt
from ray import tune
from scvi import autotune

# Record when this run started and report it.
# `e` is kept as a module-level name because the output directory is named from it below.
import datetime
e = datetime.datetime.now()
print(f"Current date and time = {e}")

# Scanpy settings
# verbosity levels: errors (0), warnings (1), info (2), hints (3)
sc.settings.verbosity = 3
# print the versions of the packages in use
sc.logging.print_versions()
# figure defaults: dpi 150 for display, dpi 600 for saved files, font size 10
sc.set_figure_params(dpi=150, dpi_save=600, fontsize=10)

In [None]:
# set a working directory
# NOTE(review): this is an absolute, machine-specific path; consider making it configurable
wdir = '/project/tendonhca/ccohen/chromium/analysis/20240221_achilles_python'
os.chdir(wdir)

# create an output directory named after the run start time (`e`, captured in the first cell),
# formatted as <YYYYMMDD>_<HH>-<MM>_integration-scvi.dir
dmyt = e.strftime('%Y%m%d_%H-%M')
directory = f'{dmyt}_integration-scvi.dir'

# folder structures (relative to wdir, because of the chdir above)
RESULTS_FOLDERNAME = f'{directory}/results/'
FIGURES_FOLDERNAME = f'{directory}/figures/'

# exist_ok=True keeps this cell safe to re-run and avoids the race between
# checking for the folder and creating it
for folder in (RESULTS_FOLDERNAME, FIGURES_FOLDERNAME):
    os.makedirs(folder, exist_ok=True)

# Set folder for saving figures into
sc.settings.figdir = FIGURES_FOLDERNAME

print(directory)

In [None]:
# Read in the yml file
# The context manager closes the file handle as soon as parsing finishes
# (a bare open() passed straight to safe_load leaves it open until garbage collection).
with open('integration-scvi.yaml') as config_file:
    ini = yaml.safe_load(config_file)

# echo the parsed configuration so it is recorded in the notebook output
print(yaml.safe_dump(ini))

Read in the concatenated object.
In the concat_norm script, normalisation and dimensionality reduction were performed, but these are not needed here because we start again from the raw counts.
The only open question is whether to work on the whole object or to subset to highly variable genes (HVGs) and, if so, how many to keep.

In [None]:
# display the working directory set earlier as this cell's output
wdir

In [None]:
# For testing use the subsetted object with only 3 samples in it.
# (A run on the full data would instead point at
#  <wdir>/concat_norm/results/merged_normalised.h5ad.)
data_root = os.path.join(wdir, ini['datadir'])
path = os.path.join(data_root, 'results/Achilles_subset.h5ad')
path

In [None]:
# progress message printed before the h5ad file is loaded in the next cell
print('Reading adata object')

In [None]:
# This will be the unintegrated reference data
# NB for some integration methods, here the data is subsetted to only hvg (see Alina's tutorial)
# load the object stored in the h5ad file at `path`
adata_ref = sc.read_h5ad(path)
# display a summary of the loaded object as the cell output
adata_ref

In [None]:
# confirmation message printed once the object has been loaded
print('Adata object read successfully')

In [None]:
# scvi works with raw counts
# overwrite X with an independent copy of the 'counts' layer
# (assumes that layer holds the unnormalised counts — TODO confirm against the upstream script)
adata_ref.X = adata_ref.layers['counts'].copy()

In [None]:
# make a new object to perform the integration
# (a copy, so that adata_ref stays available as the unintegrated reference)
adata_scvi = adata_ref.copy()


In [None]:
# take a snapshot
# store the current state of the object in .raw; this is done before the gene subsetting in the next cell
adata_scvi.raw = adata_scvi

In [None]:
# subset to hvg TODO Add parameter for this to be optional
# boolean column in .var marking the genes to keep
hvg_mask = adata_scvi.var['highly_variable']
adata_scvi = adata_scvi[:, hvg_mask].copy()
adata_scvi


Optimise the scVI model using ray

In [None]:
# set up the object and view the available parameters that can be tuned

model_cls = scvi.model.SCVI

# register the 'counts' layer as model input and 'patient.seqbatch' as the batch covariate
model_cls.setup_anndata(
    adata_scvi,
    layer="counts",
    batch_key="patient.seqbatch",
)

# wrap the model class in a tuner and print what it reports as tunable
scvi_tuner = autotune.ModelTuner(model_cls)
scvi_tuner.info()

In [None]:
# specify which variables will be tested
# categorical options are sampled with tune.choice, the learning rate on a log scale
search_space = dict(
    n_latent=tune.choice([10, 30, 50]),
    n_hidden=tune.choice([60, 128, 256]),
    n_layers=tune.choice([1, 2, 3]),
    lr=tune.loguniform(1e-4, 1e-2),
    gene_likelihood=tune.choice(["nb", "zinb"]),
)

In [None]:
# start Ray before tuning; it is shut down again near the end of the notebook
# NOTE(review): log_to_driver=False presumably keeps worker output out of the notebook — confirm in the Ray docs
ray.init(log_to_driver=False)

In [None]:
# progress message printed before the tuning run in the next cell
print("Performing parameter tuning")

In [None]:
# run the optimisation
# Settings handed to the tuner; the trial with the lowest "validation_loss" is picked out later.
# NOTE: a GPU request (resources={"gpu": 1}) is deliberately left out for now.
tuning_options = dict(
    metric="validation_loss",
    search_space=search_space,
    searcher='hyperopt',
    num_samples=100,
    max_epochs=30,
)
results = scvi_tuner.fit(adata_scvi, **tuning_options)

In [None]:
# progress messages: tuning has returned; the following cells print its results
print("Parameter tuning complete")
print("Results of parameter tuning")

In [None]:
# model and training keyword arguments stored on the tuning results
# (presumably those of the best trial — TODO confirm against the scvi autotune docs)
print(results.model_kwargs)
print(results.train_kwargs)

In [None]:
# find the best parameters
# Scan every trial and remember the index of the one with the lowest validation loss.
# Start from +inf rather than a fixed number: with a hard-coded ceiling such as 10000,
# a run in which every loss exceeds it would silently report trial 0 as "best".
best_vl = float("inf")
best_i = None
for i, res in enumerate(results.results):
    # trials that never reported the metric are skipped instead of raising KeyError
    # (assumes res.metrics is a dict or None — TODO confirm for the installed ray version)
    vl = (res.metrics or {}).get('validation_loss')
    if vl is None:
        continue

    # NaN and inf never compare as smaller, so such trials are ignored as well
    if vl < best_vl:
        best_vl = vl
        best_i = i

# fail loudly here rather than letting the following cells index with a meaningless value
if best_i is None:
    raise ValueError("No trial reported a finite validation_loss; cannot select the best parameters")



In [None]:
# print a label, then display the index of the selected trial as the cell output
print("Index of optimal parameters")
best_i

In [None]:
# print a label, then display the result entry of the selected trial as the cell output
print("Optimal parameters")
results.results[best_i]

In [None]:
# shut down the Ray runtime started earlier with ray.init
# NOTE(review): this cell is not reached if an earlier cell raises, so Ray may be left running after a failed run
ray.shutdown()

In [None]:
# final marker: printed only when every preceding cell has run
print ("script completed")