In [1]:
# Import dependencies
%matplotlib inline
import os
import warnings

import numpy as np
import pandas as pd
import scanpy as sc
import scanpy.external as sce
import seaborn as sns
import anndata
import matplotlib.pyplot as plt
import yaml
import scvi
import ray
import hyperopt
from ray import tune
from scvi import autotune

# Record the start time of this run; `e` is reused in the next cell to
# timestamp the output directory.
import datetime
e = datetime.datetime.now()
print("Current date and time = %s" % e)

# Scanpy configuration.
# Verbosity levels: 0 = errors, 1 = warnings, 2 = info, 3 = hints.
sc.settings.verbosity = 3
sc.logging.print_versions()
sc.set_figure_params(dpi=150, fontsize=10, dpi_save=600)

  self.seed = seed
  self.dl_pin_memory_gpu_training = (
  jax.tree_util.register_keypaths(data_clz, keypaths)
  jax.tree_util.register_keypaths(data_clz, keypaths)


Current date and time = 2025-01-23 13:57:35.875294
-----
anndata     0.9.1
scanpy      1.9.3
-----
PIL                         9.4.0
absl                        NA
aiohttp                     3.9.5
aiosignal                   1.3.1
annotated_types             0.5.0
anyio                       NA
arrow                       1.2.3
asttokens                   NA
async_timeout               4.0.3
attr                        23.1.0
attrs                       23.1.0
babel                       2.14.0
backcall                    0.2.0
backoff                     2.2.1
brotli                      NA
bs4                         4.12.2
certifi                     2024.07.04
cffi                        1.15.1
charset_normalizer          3.2.0
chex                        0.1.83
click                       8.1.5
cloudpickle                 3.0.0
colorama                    0.4.6
comm                        0.1.3
contextlib2                 NA
croniter                    NA
cycler                  

In [2]:
# Working directory for this analysis; the relative paths used below (the config
# file and the output folders) resolve against it.
wdir = '/media/apc1/ccohen/chromium/analysis/20240711_Achilles/'
os.chdir(wdir)

# Name the output directory after the run's start time (`e`, set in the first
# cell), formatted as YYYYMMDD_HH-MM, e.g. 20250123_13-57.
dmyt = e.strftime('%Y%m%d_%H-%M')
directory = f'{dmyt}_ray_autotune.dir'

# Folder structure for this run's outputs.
RESULTS_FOLDERNAME = f'{directory}/results/'
FIGURES_FOLDERNAME = f'{directory}/figures/'

# exist_ok=True replaces the racy "check exists, then create" pattern and is a
# no-op when the folder is already there.
for folder in (RESULTS_FOLDERNAME, FIGURES_FOLDERNAME):
    os.makedirs(folder, exist_ok=True)

# Save scanpy figures into this run's figures folder.
sc.settings.figdir = FIGURES_FOLDERNAME

print(directory)

20250123_13-57_ray_autotune.dir


In [3]:
# Read the run configuration (data directory, variable-gene and neighbour settings)
# from the YAML file in the working directory. The `with` block closes the file
# handle; the previous bare open() left it for the garbage collector.
with open('integration-scvi.yaml') as config_file:
    ini = yaml.safe_load(config_file)
print(yaml.safe_dump(ini))

datadir: data/integrated_objects/20250123_13-52_convert-objects.dir
neighbours:
  n_pcs: 30
variable_genes:
  batch: patient.seqbatch
  flavor: seurat
  hvg_subset: true
  n_genes: 5000



Read in the concatenated object.
In the concat_norm script, normalisation and dimensionality reduction were performed, but they are not needed here because scVI starts again from the raw counts.
The only open question is whether to work on the whole object or to subset to highly variable genes (HVGs), and if so, how many.

In [4]:
# Display the working directory set earlier.
wdir

'/media/apc1/ccohen/chromium/analysis/20240711_Achilles/'

In [5]:
# Full path to the annotated AnnData object inside the config's data directory.
path = os.path.join(wdir, ini['datadir'], 'Achilles_integrated_annotated.h5ad')
path

'/media/apc1/ccohen/chromium/analysis/20240711_Achilles/data/integrated_objects/20250123_13-52_convert-objects.dir/Achilles_integrated_annotated.h5ad'

In [6]:
# This will be the unintegrated reference data.
# NB: for some integration methods the data is subset to highly variable genes only
# (see Alina's tutorial); in this notebook that optional subsetting is done later,
# on a copy, controlled by ini['variable_genes']['hvg_subset'].
adata_ref = sc.read_h5ad(path)
adata_ref

AnnData object with n_obs × n_vars = 66892 × 61544
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'sum', 'detected', 'subsets_mito_sum', 'subsets_mito_detected', 'subsets_mito_percent', 'total', 'log10GenesPerUMI', 'patient', 'age', 'sex', 'ethnicity', 'surgical_procedure', 'disease_status', 'anatomical_site', 'affected_side', 'time_to_freezing', 'sequencing_date', 'microanatomical_site', 'seurat_clusters', 'decontX_contamination', 'decontX_clusters', 'sizeFactor', 'scDblFinder.cluster', 'scDblFinder.class', 'scDblFinder.score', 'scDblFinder.weighted', 'scDblFinder.difficulty', 'scDblFinder.cxds_score', 'scDblFinder.mostLikelyOrigin', 'scDblFinder.originAmbiguous', 'RNA_snn_res.0.1', 'RNA_snn_res.0.2', 'RNA_snn_res.0.3', 'nCount_decontXcounts', 'nFeature_decontXcounts', 'nCount_soupX', 'nFeature_soupX', 'soupX_fraction', 'patient.seqbatch', 'soupX_snn_res.0.1', 'soupX_snn_res.0.2', 'soupX_snn_res.0.3', 'soupX_snn_res.0.4', 'soupX_snn_res.0.5', 'soupX_snn_res.0.6', 'soup

In [7]:
# Read the per-gene metadata (here: the precomputed `highly_variable` flag).
path = os.path.join(wdir, ini['datadir'], 'gene_metafeatures.txt')
# NOTE(review): no index_col is passed. pandas uses the first column as the index
# only when the header row has one field fewer than the data rows (R write.csv
# style). The displayed output shows gene names as the index, so the file appears
# to be in that format — confirm before changing how it is written.
feature_metadata = pd.read_csv(path)

# Assigning .var replaces adata_ref.var_names with feature_metadata.index. If the
# file lists genes in a different order, the count columns would be silently
# relabelled, so warn (without stopping the run) when the names disagree.
if not feature_metadata.index.equals(adata_ref.var_names):
    warnings.warn(
        "gene_metafeatures.txt index does not match adata_ref.var_names; "
        "gene labels may be misaligned with the count matrix"
    )
adata_ref.var = feature_metadata
adata_ref.var

Unnamed: 0,highly_variable
TTN,False
SLC30A5,True
ACTA1,True
MT-ND1,True
ENSG00000280441,False
...,...
ENSG00000277856,False
ENSG00000275987,False
ENSG00000268674,False
ENSG00000277475,False


In [8]:
# scvi works with raw counts: overwrite X with a copy of the 'counts' layer
# (a copy, so later in-place changes to X do not also change the layer).
adata_ref.X = adata_ref.layers['counts'].copy()

In [9]:
# Make a separate object for the integration so adata_ref is not modified by
# the subsetting and model setup below.
adata_scvi = adata_ref.copy()


In [10]:
# Take a snapshot of the full, all-gene object in .raw before the optional
# highly-variable-gene subsetting in the next cell.
adata_scvi.raw = adata_scvi

In [11]:
# Optionally restrict to highly variable genes (HVGs). The flags come from the
# `highly_variable` column loaded from gene_metafeatures.txt and are not
# recomputed here. As a result, the flavor / n_genes / batch settings under
# ini['variable_genes'] are not applied in this notebook.
if ini['variable_genes']['hvg_subset']:
    adata_scvi = adata_scvi[:, adata_scvi.var['highly_variable']].copy()

adata_scvi

AnnData object with n_obs × n_vars = 66892 × 5000
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'sample', 'sum', 'detected', 'subsets_mito_sum', 'subsets_mito_detected', 'subsets_mito_percent', 'total', 'log10GenesPerUMI', 'patient', 'age', 'sex', 'ethnicity', 'surgical_procedure', 'disease_status', 'anatomical_site', 'affected_side', 'time_to_freezing', 'sequencing_date', 'microanatomical_site', 'seurat_clusters', 'decontX_contamination', 'decontX_clusters', 'sizeFactor', 'scDblFinder.cluster', 'scDblFinder.class', 'scDblFinder.score', 'scDblFinder.weighted', 'scDblFinder.difficulty', 'scDblFinder.cxds_score', 'scDblFinder.mostLikelyOrigin', 'scDblFinder.originAmbiguous', 'RNA_snn_res.0.1', 'RNA_snn_res.0.2', 'RNA_snn_res.0.3', 'nCount_decontXcounts', 'nFeature_decontXcounts', 'nCount_soupX', 'nFeature_soupX', 'soupX_fraction', 'patient.seqbatch', 'soupX_snn_res.0.1', 'soupX_snn_res.0.2', 'soupX_snn_res.0.3', 'soupX_snn_res.0.4', 'soupX_snn_res.0.5', 'soupX_snn_res.0.6', 'soupX

Optimise the scVI model hyperparameters using Ray Tune

In [12]:
# Register adata_scvi with scVI: counts are read from the "counts" layer and
# 'patient.seqbatch' is passed as the batch covariate.
model_cls = scvi.model.SCVI
model_cls.setup_anndata(
    adata_scvi,
    batch_key='patient.seqbatch',
    layer="counts",
)

# Wrap the model class in a tuner and list the parameters available for tuning.
scvi_tuner = autotune.ModelTuner(model_cls)
scvi_tuner.info()

  _verify_and_correct_data_format(adata, self.attr_name, self.attr_key)


In [13]:
# Hyperparameter search space: discrete choices for the network architecture and
# count likelihood, plus a log-uniform range for the learning rate.
# NOTE(review): 60 breaks the power-of-two pattern of the other n_hidden options
# (128, 256) — confirm 64 was not intended.
search_space = dict(
    n_latent=tune.choice([10, 30, 50]),
    n_hidden=tune.choice([60, 128, 256]),
    n_layers=tune.choice([1, 2, 3]),
    lr=tune.loguniform(1e-4, 1e-2),
    gene_likelihood=tune.choice(["nb", "zinb"]),
)

In [14]:
# Start a local Ray instance; log_to_driver=False keeps worker-process output
# from being forwarded into the notebook.
ray.init(log_to_driver=False)

2025-01-23 13:57:42,856	INFO worker.py:1633 -- Started a local Ray instance. View the dashboard at [1m[32m127.0.0.1:8265 [39m[22m


0,1
Python version:,3.9.16
Ray version:,2.7.0
Dashboard:,http://127.0.0.1:8265


In [15]:
# Run the hyperparameter search: 100 sampled configurations, each trained for
# at most 30 epochs on one GPU and scored by validation loss (lower is better).
# NOTE(review): no seed is passed to the hyperopt searcher here, so the sampled
# configurations may differ between runs — confirm whether that matters.
results = scvi_tuner.fit(
    adata_scvi,
    search_space=search_space,
    metric="validation_loss",
    searcher='hyperopt',
    resources={"gpu": 1},
    num_samples=100,
    max_epochs=30,
)

0,1
Current time:,2025-01-23 16:16:21
Running for:,02:18:37.09
Memory:,80.5/503.5 GiB

Trial name,status,loc,n_latent,n_hidden,n_layers,lr,gene_likelihood,validation_loss
_trainable_d8e34b9e,TERMINATED,163.1.64.158:809328,50,60,1,0.000510498,nb,889.784
_trainable_73c5c4b0,TERMINATED,163.1.64.158:809791,50,128,3,0.000488935,nb,888.101
_trainable_1a08d563,TERMINATED,163.1.64.158:809328,50,60,3,0.000501782,zinb,938.433
_trainable_a20696d3,TERMINATED,163.1.64.158:809791,30,128,3,0.00375528,zinb,884.327
_trainable_e54aaaaa,TERMINATED,163.1.64.158:809328,10,128,2,0.00261157,zinb,885.054
_trainable_60da39f9,TERMINATED,163.1.64.158:809791,30,256,1,0.000891235,zinb,870.987
_trainable_af21c593,TERMINATED,163.1.64.158:809328,10,128,1,0.000186963,zinb,1021.22
_trainable_68477c48,TERMINATED,163.1.64.158:809328,10,256,1,0.0012693,nb,900.826
_trainable_6183d4a6,TERMINATED,163.1.64.158:809328,30,128,2,0.000492737,nb,998.117
_trainable_250a2b46,TERMINATED,163.1.64.158:809328,30,60,3,0.000423125,zinb,1035.76


2025-01-23 13:57:44,300	INFO tune.py:645 -- [output] This uses the legacy output and progress reporter, as Jupyter notebooks are not supported by the new engine, yet. For more information, please see https://github.com/ray-project/ray/issues/36949
2025-01-23 16:16:21,479	INFO tune.py:1143 -- Total run time: 8317.18 seconds (8317.07 seconds for the tuning loop).


We are looking for the parameter combination that gives the lowest validation loss.

In [16]:
# Best configuration found by the tuner, split into model constructor kwargs
# and training kwargs (the learning rate is passed via plan_kwargs).
print(results.model_kwargs)
print(results.train_kwargs)

{'n_latent': 50, 'n_hidden': 128, 'n_layers': 1, 'gene_likelihood': 'zinb'}
{'plan_kwargs': {'lr': 0.002132748861038831}}


In [17]:
# One row per trial: the last reported metrics plus the sampled hyperparameters
# (config/* columns).
df = results.results.get_dataframe()
df

Unnamed: 0,validation_loss,timestamp,done,training_iteration,trial_id,date,time_this_iter_s,time_total_s,pid,hostname,node_ip,time_since_restore,iterations_since_restore,checkpoint_dir_name,config/n_latent,config/n_hidden,config/n_layers,config/lr,config/gene_likelihood,logdir
0,889.783997,1737641110,False,30,d8e34b9e,2025-01-23_14-05-10,14.466383,441.649517,809328,BRC-89SJ904,163.1.64.158,441.649517,30,,50,60,1,0.000510,nb,d8e34b9e
1,888.100708,1737641130,False,30,73c5c4b0,2025-01-23_14-05-30,15.379817,458.020491,809791,BRC-89SJ904,163.1.64.158,458.020491,30,,50,128,3,0.000489,nb,73c5c4b0
2,938.432556,1737641172,True,4,1a08d563,2025-01-23_14-06-12,15.433875,61.977791,809328,BRC-89SJ904,163.1.64.158,61.977791,4,,50,60,3,0.000502,zinb,1a08d563
3,884.327271,1737641603,False,30,a20696d3,2025-01-23_14-13-23,16.262115,472.973655,809791,BRC-89SJ904,163.1.64.158,472.973655,30,,30,128,3,0.003755,zinb,a20696d3
4,885.053833,1737641629,False,30,e54aaaaa,2025-01-23_14-13-49,15.380125,456.563732,809328,BRC-89SJ904,163.1.64.158,456.563732,30,,10,128,2,0.002612,zinb,e54aaaaa
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,950.593506,1737648489,True,1,e5b78a35,2025-01-23_16-08-09,15.194398,15.194398,809791,BRC-89SJ904,163.1.64.158,15.194398,1,,10,256,1,0.000938,nb,e5b78a35
96,906.245850,1737648521,True,2,eea76160,2025-01-23_16-08-41,15.624465,31.455614,809791,BRC-89SJ904,163.1.64.158,31.455614,2,,30,128,3,0.006971,zinb,eea76160
97,999.835876,1737648536,True,1,55f17855,2025-01-23_16-08-56,15.110570,15.110570,809791,BRC-89SJ904,163.1.64.158,15.110570,1,,50,60,1,0.001084,nb,55f17855
98,870.976440,1737648981,False,30,e43a5e95,2025-01-23_16-16-21,14.689486,445.105014,809791,BRC-89SJ904,163.1.64.158,445.105014,30,,50,128,1,0.008077,zinb,e43a5e95


In [18]:
# Rank trials by validation loss, best first. reset_index() moves each trial's
# original row label in `df` into an 'index' column and renumbers rows 0..n-1.
df2 = df.sort_values(by = 'validation_loss').reset_index()
df2

Unnamed: 0,index,validation_loss,timestamp,done,training_iteration,trial_id,date,time_this_iter_s,time_total_s,pid,...,node_ip,time_since_restore,iterations_since_restore,checkpoint_dir_name,config/n_latent,config/n_hidden,config/n_layers,config/lr,config/gene_likelihood,logdir
0,68,863.351196,1737647517,False,30,318dbbde,2025-01-23_15-51-57,16.073066,483.769895,809328,...,163.1.64.158,483.769895,30,,50,128,1,0.002133,zinb,318dbbde
1,65,866.652649,1737646791,False,30,f6a485f9,2025-01-23_15-39-51,15.338087,453.649527,809791,...,163.1.64.158,453.649527,30,,50,128,1,0.002194,zinb,f6a485f9
2,37,867.100098,1737644794,False,30,189047d3,2025-01-23_15-06-34,14.658142,440.764652,809791,...,163.1.64.158,440.764652,30,,50,128,1,0.003477,nb,189047d3
3,80,867.416748,1737648197,False,30,eb6dd5db,2025-01-23_16-03-17,14.866749,448.965191,809791,...,163.1.64.158,448.965191,30,,50,256,1,0.000740,zinb,eb6dd5db
4,51,867.745605,1737645948,False,30,02a80d67,2025-01-23_15-25-48,15.512162,461.485268,809328,...,163.1.64.158,461.485268,30,,50,128,1,0.002895,zinb,02a80d67
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,35,1038.391479,1737644113,True,1,40d6b492,2025-01-23_14-55-13,15.311136,15.311136,809791,...,163.1.64.158,15.311136,1,,50,60,2,0.000325,zinb,40d6b492
96,59,1047.908447,1737646025,True,1,aadde92a,2025-01-23_15-27-05,15.503129,15.503129,809328,...,163.1.64.158,15.503129,1,,50,128,1,0.000257,nb,aadde92a
97,31,1087.877808,1737643956,True,1,e837b0c5,2025-01-23_14-52-36,15.338459,15.338459,809328,...,163.1.64.158,15.338459,1,,50,60,1,0.000363,nb,e837b0c5
98,57,1138.627319,1737645979,True,1,d1189219,2025-01-23_15-26-19,15.498403,15.498403,809328,...,163.1.64.158,15.498403,1,,50,60,1,0.000151,zinb,d1189219


In [19]:
print("Index of optimal parameters")
# df2 is sorted best-first with a fresh 0..n-1 index, so row 0 is the best trial;
# its 'index' column holds that trial's row label in `df`.
row_number = df2.loc[0, 'index']
row_number

Index of optimal parameters


68

In [20]:
print("Optimal parameters")
# row_number is an index *label* taken from df2['index'], so look it up by label.
# Positional .iloc only returned the right row because df happens to have a
# default 0..n-1 index.
df.loc[row_number]

Optimal parameters


validation_loss                      863.351196
timestamp                            1737647517
done                                      False
training_iteration                           30
trial_id                               318dbbde
date                        2025-01-23_15-51-57
time_this_iter_s                      16.073066
time_total_s                         483.769895
pid                                      809328
hostname                            BRC-89SJ904
node_ip                            163.1.64.158
time_since_restore                   483.769895
iterations_since_restore                     30
checkpoint_dir_name                        None
config/n_latent                              50
config/n_hidden                             128
config/n_layers                               1
config/lr                              0.002133
config/gene_likelihood                     zinb
logdir                                 318dbbde
Name: 68, dtype: object

In [21]:
# Shut down the local Ray instance started earlier with ray.init().
ray.shutdown()