In [None]:
import anndata
import numpy as np
import scvelo as scv
import scanpy as sc
import sys
import torch
import os.path
import deepvelo as dv
import pickle as pickle
import matplotlib.pyplot as plt
import pandas as pd
import unitvelo as utv
from os.path import exists
import time
# Label for this benchmark run: used as the row index of the score tables
# and as a component of the output file names.
method = "DeepVelo"

In [None]:
# Names of the benchmark datasets; each name is used both as a sub-directory
# of `data_dir` and as the prefix of the files inside it.
datasets = [
    "MouseErythroid",
    "Pancreas_with_cc",
    "DentateGyrus",
    "MouseBoneMarrow",
    "HumanBoneMarrow",
    "HumanDevelopingBrain",
]

# Input root (one folder per dataset) and output root for scores and figures.
# Both keep their trailing slash because paths are built by string concatenation.
data_dir = "/nfs/team283/aa16/data/fate_benchmarking/benchmarking_datasets/"
save_dir = "/nfs/team283/aa16/data/fate_benchmarking/benchmarking_results_revision/"

In [None]:
# Benchmark loop: for every dataset, run the scVelo preprocessing + DeepVelo
# training pipeline, score the resulting velocities against the ground truth,
# append the scores to the per-dataset CSV, save a UMAP stream plot and write
# the annotated AnnData object to disk.

# Make sure the figure directory exists *before* the expensive training runs,
# so a missing folder does not surface only after a dataset has been processed.
os.makedirs(save_dir + 'UMAPs/', exist_ok=True)

# Key under which `utv.evaluate` reports the per-transition scores used below.
cbdc_key = 'Cross-Boundary Direction Correctness (A->B)'

for dataset in datasets:
    print(dataset)
    dataset_prefix = data_dir + dataset + '/' + dataset
    adata = sc.read_h5ad(dataset_prefix + '_anndata.h5ad')

    # --- Timed section: preprocessing, training, velocity graph/embedding. ---
    # Data loading, scoring and plotting are deliberately outside the timer.
    start = time.time()
    scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=3000)
    scv.pp.moments(adata, n_pcs=30, n_neighbors=30)
    trainer = dv.train(adata, dv.Constants.default_configs)
    # NOTE(review): neighbors are recomputed here with default parameters after
    # `moments` was run with n_pcs=30 / n_neighbors=30 -- confirm this is intended.
    scv.pp.neighbors(adata)
    scv.tl.velocity_graph(adata, vkey='velocity')
    scv.tl.velocity_embedding(adata, vkey='velocity')
    end = time.time()

    # --- Calculate performance metrics ---
    # The context manager closes the file handle (previously leaked on every
    # iteration). CAUTION: pickle.load can execute arbitrary code; only load
    # ground-truth files from a trusted location.
    with open(dataset_prefix + '_groundTruth.pickle', 'rb') as file:
        ground_truth = pickle.load(file)
    metrics = utv.evaluate(adata, ground_truth, 'clusters', 'velocity')
    cbdc = metrics[cbdc_key]

    # Load the existing score table (one row per method) or start a new one
    # whose columns are the scored transitions plus 'Mean' and 'Time'.
    scores_path = save_dir + dataset + '_CBDC_scores.csv'
    if exists(scores_path):
        tab = pd.read_csv(scores_path, index_col=0)
    else:
        tab = pd.DataFrame(columns=list(cbdc.keys()) + ['Mean', 'Time'],
                           index=[method])
    cb_score = [np.mean(cbdc[x]) for x in cbdc.keys()]
    # Row layout: one mean score per transition, their overall mean, then the
    # wall-clock duration (seconds) of the timed section above.
    tab.loc[method, :] = cb_score + [np.mean(cb_score), end - start]
    tab.to_csv(scores_path)

    # --- UMAP stream plot ---
    # Drawn and saved once per dataset. The original code rendered the same
    # figure twice and, with method == 'DeepVelo', wrote it to the same path.
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    scv.pl.velocity_embedding_stream(adata, basis='umap', save=False, vkey='velocity',
                                     show=False, ax=ax)
    fig.savefig(save_dir + 'UMAPs/' + dataset + '_UMAP_' + method + '.svg')
    # Release the figure so open figures do not accumulate across datasets.
    plt.close(fig)

    # Persist the processed object (output path kept unchanged for downstream use).
    adata.write_h5ad('/nfs/team283/aa16/data/fate_benchmarking/' + method + dataset + 'AnnDataForCellRank.h5ad')

In [None]:
    # Redraw and re-save the UMAP stream plot. This cell assigns neither
    # `adata` nor `dataset`; it uses whatever values the dataset loop left
    # behind, i.e. the last dataset that was processed.
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
    scv.pl.velocity_embedding_stream(
        adata,
        basis='umap',
        vkey='velocity',
        ax=ax,
        show=False,
        save=False,
    )
    plt.savefig(f"{save_dir}UMAPs/{dataset}_UMAP_{method}.svg")