In [1]:
import scvelo as scv
import scanpy as sc
import cell2fate as c2f
import pickle as pickle
from eval_utils import cross_boundary_correctness
from datetime import datetime
import pandas as pd
import numpy as np
from os.path import exists
import matplotlib.pyplot as plt
import torch

Global seed set to 0


In [2]:
# Benchmark configuration.
# Each entry pairs a dataset name with the hand-chosen number of cell2fate
# modules to fit for it; keeping the pairs together means the two per-dataset
# settings cannot silently drift out of alignment when the list is edited.
datasets, modules_per_dataset = (list(column) for column in zip(
    ('Pancreas_with_cc', 15),
    ('DentateGyrus', 15),
    ('MouseErythroid', 5),
    ('MouseBoneMarrow', 5),
    ('HumanBoneMarrow', 10),
))

# Root folders for benchmark inputs and outputs (absolute cluster paths).
# Inputs are read from per-dataset sub-folders: data_dir + <dataset> + '/'.
data_dir = '/nfs/team283/aa16/data/fate_benchmarking/benchmarking_datasets/'
save_dir = '/nfs/team283/aa16/data/fate_benchmarking/benchmarking_results/'

In [None]:
# Benchmark cell2fate on every dataset:
#   1. load the AnnData, keep raw counts, filter/normalize,
#   2. fit Cell2fate_DynamicalModel with the dataset's configured module count,
#   3. save UMAPs of posterior time / time uncertainty and a velocity stream plot,
#   4. score velocities with cross-boundary correctness against the ground-truth
#      cluster edges and write a 'cell2fate2' row into the dataset's scores CSV.
# The two per-dataset lists are consumed in lockstep; fail loudly if they differ.
assert len(datasets) == len(modules_per_dataset), \
    'datasets and modules_per_dataset must have the same length'

for dataset, n_modules in zip(datasets, modules_per_dataset):
    adata = sc.read_h5ad(data_dir + dataset + '/' + dataset + '_anndata.h5ad')
    # Store *copies* of the raw counts: filter_and_normalize rewrites the
    # 'spliced'/'unspliced' layers, and the model below is set up on the *_raw
    # layers. Copying (instead of aliasing the same matrix) guarantees no
    # in-place normalization can leak into the counts the model is trained on.
    adata.layers['unspliced_raw'] = adata.layers['unspliced'].copy()
    adata.layers['spliced_raw'] = adata.layers['spliced'].copy()
    scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=3000)
    c2f.Cell2fate_DynamicalModel.setup_anndata(adata, spliced_label='spliced_raw',
                                               unspliced_label='unspliced_raw')
    # n_modules is the hand-tuned per-dataset value from the config cell.
    mod = c2f.Cell2fate_DynamicalModel(adata, n_modules=n_modules,
                                       stochastic_v_ag_hyp_prior={"alpha": 6.0, "beta": 3.0})
    mod.train()
    adata = mod.export_posterior(adata)

    # Posterior latent time (left) and its uncertainty (right) on the UMAP.
    fig, ax = plt.subplots(1, 2, figsize=(15, 5))
    sc.pl.umap(adata, color=['Time (hours)'], legend_loc='right margin',
               size=200, color_map='inferno', ncols=2, show=False, ax=ax[0])
    sc.pl.umap(adata, color=['Time Uncertainty (sd)'], legend_loc='right margin',
               size=200, color_map='inferno', ncols=2, show=False, ax=ax[1])
    fig.savefig(save_dir + 'c2f_plots/' + dataset + '_UMAP_Time2.png')
    plt.show()      # display inline now rather than after the whole loop
    plt.close(fig)  # release the figure so open figures don't pile up across datasets

    # Velocity stream plot from the model's exported 'velocity' and
    # 'spliced mean' layers.
    scv.pp.neighbors(adata)
    scv.tl.velocity_graph(adata, vkey='velocity', xkey='spliced mean')
    scv.tl.velocity_embedding(adata, vkey='velocity')
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    scv.pl.velocity_embedding_stream(adata, basis='umap', save=False, vkey='velocity',
                                     show=False, ax=ax)
    fig.savefig(save_dir + 'UMAPs/' + dataset + '_UMAP_cell2fate2.png')
    plt.show()
    plt.close(fig)

    # Calculate performance:
    # NOTE(review): pickle.load can execute arbitrary code — only load trusted,
    # locally produced ground-truth files. The context manager closes the file.
    with open(data_dir + dataset + '/' + dataset + '_groundTruth.pickle', 'rb') as file:
        ground_truth = pickle.load(file)
    score = cross_boundary_correctness(adata=adata, k_cluster='clusters',
                                       k_velocity='velocity', cluster_edges=ground_truth)
    # Assumes score is (per-edge score dict, aggregate score), as implied by the
    # indexing below; the row's column order must match the existing CSV header.
    tab = pd.read_csv(save_dir + dataset + '_scores.csv', index_col=0)
    tab.loc['cell2fate2', :] = np.array(list(score[0].values()) + [score[1]])
    tab.to_csv(save_dir + dataset + '_scores.csv')

Filtered out 20801 genes that are detected 20 counts (shared).
Normalized count data: X, spliced, unspliced.
Extracted 3000 highly variable genes.
Logarithmized X.


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 250/500:  50%|████████████████████████████████████████████████████████████████████████████▋                                                                             | 249/500 [04:44<04:45,  1.14s/it, v_num=1, elbo_train=1.15e+7]