In [1]:
from scvi.module.base import PyroBaseModuleClass
from scvi.train import PyroTrainingPlan
from typing import Optional, Union
import pyro

# Learning-rate schedule for the default optimiser of the training plan below.
max_epochs = 4000
start_lr, final_lr = 0.01, 0.001
# Per-update multiplicative decay chosen so that start_lr * lrd**max_epochs == final_lr.
# NOTE(review): this matches "one decay per epoch" only if the optimiser takes exactly
# one step per epoch (e.g. full-batch training) — confirm against the training setup.
lrd = (final_lr / start_lr) ** (1 / max_epochs)
clipped_adam = pyro.optim.ClippedAdam(dict(lr=start_lr, lrd=lrd, clip_norm=10.0))

class PyroTrainingPlan_ClippedAdamDecayingRate(PyroTrainingPlan):
    """
    Training plan for Pyro scvi-tools modules that defaults to a ClippedAdam
    optimiser with a decaying learning rate.

    All arguments are forwarded to :class:`~scvi.train.PyroTrainingPlan`; afterwards
    ``self.svi`` is (re)built from the module's ``model`` and ``guide`` together with
    ``self.optim`` and ``self.loss_fn``.

    Parameters
    ----------
    pyro_module
        An instance of :class:`~scvi.module.base.PyroBaseModuleClass`. This object
        should have callable `model` and `guide` attributes or methods.
    loss_fn
        A Pyro loss. Should be a subclass of :class:`~pyro.infer.ELBO`.
        Forwarded unchanged to the parent class (which handles `None`).
    optim
        A Pyro optimizer instance. Defaults to the notebook-level ``clipped_adam``
        configuration (``lr=start_lr``, ``lrd=lrd``, ``clip_norm=10.0``). When the
        default is used, a *fresh* optimizer with those settings is created for this
        plan, so that optimizer state is not shared between models trained one after
        another. Any other value (including `None`) is forwarded unchanged to the
        parent class.
    optim_kwargs
        Forwarded unchanged to the parent class.

    Notes
    -----
    This constructor accepts only the four parameters listed above; KL-warmup and
    aggressive-training options are not exposed here.
    """

    def __init__(
        self,
        pyro_module: PyroBaseModuleClass,
        loss_fn: Optional[pyro.infer.ELBO] = None,
        optim: Optional[pyro.optim.PyroOptim] = clipped_adam,
        optim_kwargs: Optional[dict] = None,
    ):
        # The default argument is one module-level PyroOptim object, and PyroOptim
        # instances are stateful (they keep an optimizer per parameter they have seen).
        # Re-using that single object for every plan would accumulate state from every
        # previously trained model, so build an equivalent fresh optimizer instead.
        # The settings mirror the ones used to construct `clipped_adam`.
        if optim is clipped_adam:
            optim = pyro.optim.ClippedAdam(
                {"lr": start_lr, "lrd": lrd, "clip_norm": 10.0}
            )

        super().__init__(
            pyro_module=pyro_module,
            loss_fn=loss_fn,
            optim=optim,
            optim_kwargs=optim_kwargs
        )

        # Build the SVI object explicitly from the module's model/guide.
        # NOTE(review): `self.optim` and `self.loss_fn` are presumably set by the parent
        # constructor, which may already create an equivalent `self.svi` — confirm
        # against the installed scvi-tools version before removing this.
        self.svi = pyro.infer.SVI(
            model=pyro_module.model,
            guide=pyro_module.guide,
            optim=self.optim,
            loss=self.loss_fn,
        )

Global seed set to 0


In [2]:
import os
os.chdir('..')
os.chdir('..')
import scvelo as scv
import scanpy as sc
import scvi
scvi.settings.seed = 1
import cell2fate as c2f
import pickle as pickle
from eval_utils import cross_boundary_correctness
from datetime import datetime
import pandas as pd
import numpy as np
from os.path import exists
import matplotlib.pyplot as plt
import torch
import unitvelo as utv
# Benchmark of Cell2fate_DynamicalModel over datasets x gene-count x min-count settings.
# For every configuration a model is trained, velocities are exported, and the mean
# Cross-Boundary Direction Correctness (CBDC) is appended to a CSV results table.
method = 'Cell2fateDynamicalModel_pyroVelocityTrainingPlan'
data_dir = '/nfs/team283/aa16/data/fate_benchmarking/benchmarking_datasets/'
save_dir = '/nfs/team283/aa16/data/fate_benchmarking/benchmarking_results/'
datasets = ['Pancreas_with_cc',  'DentateGyrus' , 'MouseBoneMarrow', 'MouseErythroid', 'HumanBoneMarrow']
n_genes_list = np.array((2000, 3000))
n_counts_list = np.array((10, 20))

# The trailing '_' here plus the leading '_' of the suffix gives a double underscore in
# the file name; this is kept on purpose so existing result files are still found.
save_name = method + '_'
results_path = save_dir + save_name + '_CBDC_fullBenchmark.csv'
cbdc_key = 'Cross-Boundary Direction Correctness (A->B)'

for i, n_genes in enumerate(n_genes_list):
    for j, min_counts in enumerate(n_counts_list):
        for k in (2,0,1,3,4):
            print(i)
            print(j)
            print(k)
            dataset = datasets[k]
            model_index = str(i) + '-' + str(j) + '-' + str(k)
            # Resume support: skip configurations that already have a row in the table.
            if exists(results_path):
                tab = pd.read_csv(results_path, index_col = 0)
                if model_index in tab.index:
                    continue
            adata = sc.read_h5ad(data_dir + dataset + '/' + dataset + '_anndata.h5ad')
            adata = c2f.utils.get_training_data(adata, cells_per_cluster = 10**6, cluster_column = 'clusters',
                                            remove_clusters = [], min_shared_counts = min_counts, n_var_genes= n_genes)
            c2f.Cell2fate_DynamicalModel.setup_anndata(adata, spliced_label='spliced', unspliced_label='unspliced')
            n_modules = c2f.utils.get_max_modules(adata)
            mod = c2f.Cell2fate_DynamicalModel(adata,
                                               n_modules = n_modules)
            # Use the same `max_epochs` the learning-rate decay factor was derived from,
            # so the training length and the decay schedule cannot drift apart.
            mod.train(max_epochs = max_epochs,
                      training_plan = PyroTrainingPlan_ClippedAdamDecayingRate,
                      early_stopping = True, early_stopping_min_delta = 10**(-4),
                      early_stopping_monitor = 'elbo_train', early_stopping_patience = 45)
            sample_kwarg = {"num_samples": 10, "batch_size" : adata.n_obs,
                 "use_gpu" : True, 'return_samples': False}
            adata = mod.export_posterior(adata, sample_kwargs = sample_kwarg)
            mod.compute_and_plot_total_velocity_scvelo(adata, save = False, delete = False)
            # Calculate performance metrics:
            # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
            # The context manager makes sure the file handle is closed again.
            with open(data_dir + dataset + '/' + dataset + '_groundTruth.pickle' ,'rb') as file:
                ground_truth = pickle.load(file)
            metrics = utv.evaluate(adata, ground_truth, 'clusters', 'velocity')
            cb_score = [np.mean(scores) for scores in metrics[cbdc_key].values()]
            # Re-read the table right before writing so rows saved since the check
            # above are not lost.
            if exists(results_path):
                tab = pd.read_csv(results_path, index_col = 0)
            else:
                c_names = ['CBDC']
                tab = pd.DataFrame(columns = c_names)
            tab.loc[model_index, 'CBDC'] = np.mean(cb_score)
            tab.to_csv(results_path)

# Summary row: average over the per-model rows only. A pre-existing 'AVERAGE' row
# (from an earlier run) is excluded so it cannot skew the refreshed mean.
tab = pd.read_csv(results_path, index_col = 0)
tab.loc['AVERAGE', 'CBDC'] = np.mean(tab.drop(index = 'AVERAGE', errors = 'ignore')['CBDC'])
tab.to_csv(results_path)

Global seed set to 1


(Running UniTVelo 0.2.5.2)
2023-11-26 12:22:30


2023-11-26 12:22:32.400998: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-26 12:22:32.401373: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-26 12:22:32.558529: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


0
0
2
0
0
0
0
0
1
0
0
3
0
0
4
0
1
2
0
1
0
0
1
1
0
1
3
0
1
4
1
0
2
1
0
0
1
0
1
1
0
3
1
0
4
1
1
2
1
1
0
1
1
1
1
1
3
1
1
4
