## STREAM trajectory inference
This notebook tests the STREAM trajectory inference method on a CyTOF dataset, an artificial dataset (with and without upsampling) and two scRNA-seq datasets.  
STREAM GitHub repository (with tutorials): https://github.com/pinellolab/STREAM
STREAM paper: https://www.nature.com/articles/s41467-019-09670-4

All relative paths used below (`./data`, `./plots`, `./stream_results`) are resolved from the working directory `Ahmad_workdir`.

In [None]:
## imports

## core
from datetime import datetime
import time
import os
import ipdb
import sys
# os.chdir("/home/rstudio/data/Ahmad_workdir")

## aux
import scanpy as sc
from sklearn.decomposition import PCA
import numpy as np
import pandas as pd
import importlib
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import rcParams
%matplotlib inline

### stream
import stream as st
# global STREAM figure defaults used by every plot in this notebook
st.set_figure_params(
    dpi=80,
    style='white',
    figsize=[5.4, 4.8],
    rc={'image.cmap': 'viridis'},
)


In [None]:
# have to use the deprecated version to avoid conflicts
!pip install networkx==2.1

In [None]:
import networkx as nx

In [None]:
## all datasets have been exported from tviblindi objects in R
DATA_PREFIX = "./data/"


def _read_export(fname):
    """Read one comma-separated file exported from R into a DataFrame."""
    return pd.read_csv(DATA_PREFIX + fname, delimiter=",")


# mass cytometry data: measurements, labels and VAE / t-SNE / UMAP embeddings
data_full = _read_export("tv1_data_full.csv")
data_labels = _read_export("tv1_labels_full.csv")
dimred_vae = _read_export("tv1_dr_vae_full.csv")
dimred_tsne = _read_export("tv1_dr_tsne_full.csv")
dimred_umap = _read_export("tv1_dr_umap_full.csv")

# artificial data (upsampled)
arti_data = _read_export("arti_data.csv")
arti_labels = _read_export("arti_labels.csv")
arti_dimred_vae = _read_export("arti_dr_vae.csv")
arti_dimred_umap = _read_export("arti_dr_umap.csv")

# artificial data - not upsampled
arti_data_o = _read_export("arti_data_o.csv")
arti_labels_o = _read_export("arti_labels_o.csv")
arti_dimred_vae_o = _read_export("arti_dr_vae_o.csv")
arti_dimred_umap_o = _read_export("arti_dr_umap_o.csv")

# 3rd real world dataset (analysed below as "scRNASeq dataset #2")
data_fullD = _read_export("tvD_data.csv")
data_labelsD = _read_export("tvD_labels.csv")
dimred_vaeD = _read_export("tvD_dr_vae.csv")

In [None]:
# 2nd real world dataset (analysed below as "scRNASeq dataset #1").
# The gate label and the 2-D vaevictis embedding are stored as extra columns of the
# data csv and are split off here.
# NOTE(review): the first embedding column is spelled 'vavevictis_1' (unlike
# 'vaevictis_2'); presumably that is the spelling in the export — confirm.
_embedding_cols = ['vavevictis_1', 'vaevictis_2']

data_full2 = pd.read_csv('./data/dataL.csv')
data_labels2 = pd.read_csv('./data/dataL_labels.csv')['gate'].to_frame()

dimred_vae2 = data_full2[_embedding_cols]
dimred_vae2.index = data_full2.index
data_labels2.index = data_full2.index

data_full2 = data_full2.drop(['gate'] + _embedding_cols, axis=1)


In [None]:
# This function generates a pdf file with trajectory skeleton plots in the specified workdir
def test_stream(data, dimred = None, labels = None, origin = None, workdir = './stream_results'):
    """Fit and refine a STREAM elastic principal graph on a precomputed embedding.

    The graph is seeded, fitted, then adjusted (branching optimisation, pruning
    of trivial branches, branching-node shift, leaf extension). The ElPiGraph
    analysis figure is saved as a timestamped pdf in ``workdir``.

    Parameters
    ----------
    data : pandas dataframe
        Cells x markers matrix used to build the AnnData object.
    dimred : pandas dataframe or array-like
        Low-dimensional embedding with one row per cell of data; stored as
        ``obsm['X_dr']`` as the custom layout. Required.
    labels : list-like, optional
        Per-cell annotations stored in ``obs['label']`` (used when plotting).
    origin : optional
        Accepted for interface symmetry with the other test_* helpers; not used
        here (the root is chosen afterwards when plotting).
    workdir : str, optional (default './stream_results')
        Directory where STREAM writes its figures; created if missing.

    Returns
    -------
    AnnData
        The AnnData object carrying the fitted and adjusted graph.

    Raises
    ------
    ValueError
        If dimred is None.
    """
    if dimred is None:
        raise ValueError("test_stream needs a low-dimensional embedding: pass dimred")

    # preprocessing
    ad = sc.AnnData(data)
    ad.obs_names_make_unique()
    ad.var_names_make_unique()
    # AnnData keeps obs_names as strings, so assigning a pandas object with an integer
    # index would align by label (-> NaN); use the plain values positionally instead.
    if labels is not None and hasattr(labels, 'to_numpy'):
        labels = np.asarray(labels).ravel()
    ad.obs['label'] = labels

    # STREAM saves its figures below uns['workdir']; make sure the directory exists
    os.makedirs(workdir, exist_ok=True)
    ad.uns['workdir'] = workdir

    # custom layout (precomputed embedding)
    ad.obsm['X_dr'] = np.asarray(dimred)

    # prepare elastic graph
    st.seed_elastic_principal_graph(ad, n_clusters=15, n_neighbors=30)

    # compute elastic graph; the timestamp keeps successive pdfs from overwriting each other
    now = datetime.now().strftime('%m-%d-%H%M%S')
    st.elastic_principal_graph(ad, epg_alpha=0.02, epg_mu=0.1, epg_lambda=0.02,
                               fig_name='ElPiGraph_analysis' + now + '.pdf')

    ## adjusting trajectories (optional)

    # finetune branching event
    st.optimize_branching(ad, incr_n_nodes=30)

    # prune trivial branches
    st.prune_elastic_principal_graph(ad, epg_collapse_mode='EdgesNumber', epg_collapse_par=2)

    # shift branching node
    st.shift_branching(ad, epg_shift_mode='NodeDensity', epg_shift_radius=0.1, epg_shift_max=3)

    # extend leaf branches to reach further cells
    # NOTE: 'WeigthedCentroid' (sic) appears to be the spelling STREAM/ElPiGraph
    # accept for this option — verify before "correcting" it.
    st.extend_elastic_principal_graph(ad, epg_ext_mode='WeigthedCentroid', epg_ext_par=0.8)

    return ad

In [None]:
def evaluate_speed(data, labels, dimred, method=['via'],
                   origin = None, ori_id = None, sizes=[1000, 'all'], repeat = 1,
                   quality='coarse', via_allcurves = False, ncomps = -1, seed = None):
                     
    """Function used to test one or multiple trajectory inference methods, 
    returning running times, post-analysis objects, and plotting the result.

    Parameters
    ----------
    data : pandas dataframe
        Flow, CyTOF or scRNASeq data in numeric dataframe of shape n cells x m markers.
        The easiest way to create a pandas dataframe is saving the data to .csv then using pandas.read_csv().
    labels : pandas series or single-column dataframe
        Annotations for each cell of data (same row order as data).
        Can also be created by saving a label list to .csv then using pandas.read_csv().
        Not modified: a re-indexed copy is used internally.
    dimred : pandas dataframe
        Low-dimensional embeddings of data (same row order as data). Not modified.
    method : str or list of str, default ['via']
        Name(s) of the trajectory inference method(s) to evaluate; every listed method is run.
        Currently available: 'via', 'palantir', 'paga', 'stream'.
        'via', 'palantir' and 'paga' call test_via / test_palantir / test_paga and
        plotting helpers that must be defined elsewhere (they are not in this notebook).
    origin : str, optional (default None)
        Label of the cell population to be used as origin; the first cell carrying
        this label is used. Ignored when ori_id is given.
    ori_id : int, optional (default None)
        Row position of the cell to be used as origin. When both ori_id and origin
        are None, row 0 is used. Specify either origin or ori_id but not both.
    sizes : list, optional (default [1000, 'all'])
        Size of the dataset(s) to feed to the method. 
        Can accept multiple values to test on multiple datasets. 
        Datasets for a value smaller than the number of cells are prepared via uniform random
        sampling without replacement; the origin cell is then added if it was not drawn.
        'all', or a value >= the number of cells, takes the entire dataset without any sampling.
        Do not set any value to a multiple of 50,000 (causes a very specific issue in one method)
    repeat : int, optional (default 1)
        Amount of times to repeat the entire process (must be >= 1).
        Set to 5-10 for a more accurate measure of speed (will take longer).
    quality : str, optional (default coarse)
        Specific parameter for via. 
        Controls the quality of the clustering, pseudotime values and lineage probabilities for via.
        Set to 'coarse' or 'fine'.
    via_allcurves : bool, optional (default False)
        Specific parameter for via.
        If True, will plot all computed trajectories on the resulting graph (may clutter).
    ncomps : int, optional (default -1)
        Number of pca dimensions to use for neighbor graphs, diffusion maps, etc.
        Usually set to 30-100.
    seed : int, optional (default None)
        Seed for the subsampling. None keeps using numpy's global random state.

    Returns
    -------
    tuple
        a tuple with two values:
        times: dict
            Each key corresponds to a dataset size. The value is the average running
            time (seconds) per repeat, summed over all requested methods.
        objects: dict
            Each key corresponds to a method. 
            The associated value is a dict where each key corresponds to a dataset size.
            The values are the objects returned by the test functions of the last repeat
            (for 'via': a (result, embedding array) tuple).

    Raises
    ------
    ValueError
        Unknown method name, repeat < 1, empty data, or origin label not found.
    """
    supported = ('via', 'palantir', 'paga', 'stream')
    # Normalise to a list: a bare string must not be substring-matched, and a list
    # must not be used as a dict key or concatenated into file names.
    methods = [method] if isinstance(method, str) else list(method)
    unknown = [m for m in methods if m not in supported]
    if unknown:
        raise ValueError("unknown method(s) {}; choose from {}".format(unknown, supported))
    if repeat < 1:
        raise ValueError("repeat must be >= 1, got {}".format(repeat))

    objects = {}
    n = data.shape[0]
    if n == 0:
        raise ValueError("data is empty")
    times = dict.fromkeys(sizes)

    # Re-indexed copies, so the caller's labels/dimred are left untouched.
    labels = labels.copy()
    labels.index = data.index
    dimred = dimred.copy()
    dimred.index = data.index

    # Resolve the origin cell once: an explicit ori_id wins, then the origin label,
    # then row 0 (the previous default).
    if ori_id is None:
        if origin is None:
            ori_id = 0
        else:
            matches = np.where(labels == origin)[0]
            if len(matches) == 0:
                raise ValueError("origin label {!r} not found in labels".format(origin))
            ori_id = int(matches[0])

    # seed=None -> module-level np.random (global state), as before
    rng = np.random if seed is None else np.random.RandomState(seed)

    # via / palantir / paga save figures under ./plots
    if any(m in ('via', 'palantir', 'paga') for m in methods):
        os.makedirs('./plots', exist_ok=True)

    # sample datasets and find and append origin
    for s in sizes:
        if s == 'all' or s >= n:
            ss = np.arange(n)
        else:
            ss = rng.choice(np.arange(n), s, replace=False)
        if ori_id not in ss:
            ss = np.append(ss, ori_id)
            origin_new = ss.shape[0] - 1
        else:
            origin_new = np.where(ss == ori_id)[0][0]

        sdata = data.iloc[ss]
        sdimred = dimred.iloc[ss]
        slabels = labels.iloc[ss].values.tolist()
        # a single-column dataframe yields [[label], ...]; flatten to [label, ...]
        if isinstance(slabels[0], list):
            slabels = [lab[0] for lab in slabels]

        ###
        # test methods
        ###

        total_time = 0.0  # was `st`, which shadowed the stream module alias
        for i in range(repeat):
            print("testing {} events, iter {}".format(s, i))
            sys.stdout.flush()

            for m in methods:
                if m == 'via':
                    via_dimred = np.asarray(sdimred)
                    print('starting ' + m)
                    start = time.perf_counter()
                    res = test_via(sdata, slabels, origin_new, quality, ncomps)
                    end = time.perf_counter()
                    objects.setdefault(m, dict.fromkeys(sizes))[s] = res, via_dimred

                    now = datetime.now().strftime('%m-%d-%H%M%S')
                    draw_trajectory_gams_f(via_coarse=res[0], via_fine=res[0],
                                           embedding=via_dimred, draw_all_curves=via_allcurves,
                                           scatter_size=10, scatter_alpha=1)
                    plt.savefig('./plots/' + m + '_' + str(s) + '_traj_coarse_' + now)
                    plt.close('all')

                    if quality == 'fine':
                        # Finer adjustments to the plot can be made using the arguments to this function
                        draw_trajectory_gams_f(via_coarse=res[0], via_fine=res[1],
                                               embedding=via_dimred, draw_all_curves=via_allcurves,
                                               scatter_size=10, scatter_alpha=1)
                        plt.savefig('./plots/' + m + '_' + str(s) + '_traj_fine_' + now)
                        plt.close('all')

                    via.via_streamplot(via_coarse=res[0], embedding=via_dimred)
                    plt.savefig('./plots/' + m + '_' + str(s) + '_streamplot__' + now)
                    plt.close('all')

                elif m == 'palantir':
                    print('starting ' + m)
                    start = time.perf_counter()
                    res = test_palantir(sdata, sdimred, origin_new, n_pca=ncomps)
                    end = time.perf_counter()
                    objects.setdefault(m, dict.fromkeys(sizes))[s] = res

                    pr_res, p_dimred = res
                    # Finer adjustments to the plot can be made using the arguments to this function
                    plot_palantir_results_f(pr_res, p_dimred, point_size=0.3,
                                            labels=slabels, origin=origin_new)
                    now = datetime.now().strftime('%m-%d-%H%M%S')
                    plt.savefig('./plots/' + m + '_' + str(s) + '_pseudotime_' + now)
                    plt.close('all')

                elif m == 'paga':
                    print('starting ' + m)
                    start = time.perf_counter()
                    res = test_paga(sdata, sdimred, labels=labels.iloc[ss],
                                    origin=origin_new, ncomps=ncomps)
                    end = time.perf_counter()
                    objects.setdefault(m, dict.fromkeys(sizes))[s] = res

                    dt, oid = res
                    # Finer adjustments to the plot can be made using the arguments to this function
                    ax = sc.pl.paga(dt, layout='rt', fontsize=8, fontoutline=1, root=oid,
                                    edge_width_scale=0.4, threshold=0.1, show=False)
                    handles = create_handles(dt)
                    # len(handles)//5 is 0 for fewer than 5 handles, which matplotlib
                    # cannot lay out; always use at least one column
                    lgd = ax.legend(handles=handles, loc=8, prop={'size': 8},
                                    ncol=max(1, len(handles) // 5), bbox_to_anchor=(0.5, -0.4))
                    now = datetime.now().strftime('%m-%d-%H%M%S')
                    plt.savefig('./plots/' + m + '_' + str(s) + '_graph_' + now,
                                bbox_extra_artists=(lgd,), bbox_inches='tight')
                    plt.close('all')

                else:  # 'stream' (method names were validated above)
                    print('starting ' + m)
                    start = time.perf_counter()
                    res = test_stream(sdata, sdimred, slabels, origin=origin_new)
                    end = time.perf_counter()
                    objects.setdefault(m, dict.fromkeys(sizes))[s] = res

                elapsed = end - start
                total_time += elapsed
                print("time for iter: {} seconds".format(elapsed))
                sys.stdout.flush()

        times[s] = total_time / repeat

    return times, objects

In [None]:
# CyTOF dataset: STREAM on random subsamples of increasing size and on the full data.
# Row 358425 is the origin cell and is appended to every subsample that lacks it.
# NOTE(review): evaluate_speed's docstring warns against sizes that are multiples of
# 50,000 ("one method" is affected); several values below are — confirm it is not STREAM.
sz = [10000, 50000, 100000, 200000, 500000, 'all']
tt = evaluate_speed(data_full, data_labels, dimred_vae,
                    method='stream', sizes=sz, ori_id=358425)

In [None]:
# Plots are made here, outside the testing function, so they render inline.
# Graph on the 2-D layout; change the size key to inspect another subsample.
dt = tt[1]['stream'][50000]
st.plot_dimension_reduction(dt, color=['label'], n_components=2,
                            show_graph=True, show_text=True)
# node names shown here are used to choose the root in the cells below
st.plot_branches(dt, show_text=True)

In [None]:
# Graph on a flattened tree layout, coloured by label, branch and S7-based pseudotime
st.plot_flat_tree(dt,
                  color=['label', 'branch_id_alias', 'S7_pseudotime'],
                  dist_scale=0.5,
                  show_graph=True,
                  show_text=True)

In [None]:
# Stream tree at single-cell level. The root is picked by eye from the plots above.
# NOTE(review): node names presumably depend on the (random) subsample/fit — re-check
# the root after re-running the analysis.
st.plot_stream_sc(dt,
                  root='S7',
                  color=['label'],
                  dist_scale=0.3,
                  show_graph=True,
                  show_text=True)

In [None]:
# Stream plot of the tree rooted at S7, coloured by label
st.plot_stream(dt, root='S7', color=['label'])

In [None]:
# Artificial datasets (upsampled and not upsampled), full data only.
# origin='M4' names the origin population; STREAM itself does not use the origin
# here, the root is chosen when plotting.
sz = ['all']
tt2 = evaluate_speed(arti_data, arti_labels, arti_dimred_vae,
                     method='stream', sizes=sz, origin='M4')
tt3 = evaluate_speed(arti_data_o, arti_labels_o, arti_dimred_vae_o,
                     method='stream', sizes=sz, origin='M4')

In [None]:
# full-data STREAM results for both artificial datasets
dt2 = tt2[1]['stream']['all']
dt3 = tt3[1]['stream']['all']

# upsampled artificial data; root picked by eye from the branch plot
root = 'S2'
st.plot_dimension_reduction(dt2, color=['label'], n_components=2,
                            show_graph=True, show_text=True)
st.plot_branches(dt2, show_text=True)
st.plot_flat_tree(dt2, color=['label', 'branch_id_alias', root + '_pseudotime'],
                  dist_scale=0.5, show_graph=True, show_text=True)
st.plot_stream_sc(dt2, root=root, color=['label'],
                  dist_scale=0.3, show_graph=True, show_text=True)
st.plot_stream(dt2, root=root, color=['label'])

In [None]:
# not-upsampled artificial data; root picked by eye from the branch plot
root = 'S4'
st.plot_dimension_reduction(dt3, color=['label'], n_components=2,
                            show_graph=True, show_text=True)
st.plot_branches(dt3, show_text=True)
st.plot_flat_tree(dt3, color=['label', 'branch_id_alias', root + '_pseudotime'],
                  dist_scale=0.5, show_graph=True, show_text=True)
st.plot_stream_sc(dt3, root=root, color=['label'],
                  dist_scale=0.3, show_graph=True, show_text=True)
st.plot_stream(dt3, root=root, color=['label'])

In [None]:
# scRNASeq dataset #1 (full data only), origin cell at row 4095
sz = ['all']
tt4 = evaluate_speed(data_full2, data_labels2, dimred_vae2,
                     method='stream', sizes=sz, ori_id=4095)

In [None]:
dt4 = tt4[1]['stream']['all']

# root picked by eye from the branch plot
root = 'S2'
st.plot_dimension_reduction(dt4, color=['label'], n_components=2,
                            show_graph=True, show_text=True)
st.plot_branches(dt4, show_text=True)
st.plot_flat_tree(dt4, color=['label', 'branch_id_alias', root + '_pseudotime'],
                  dist_scale=0.5, show_graph=True, show_text=True)
st.plot_stream_sc(dt4, root=root, color=['label'],
                  dist_scale=0.3, show_graph=True, show_text=True)
st.plot_stream(dt4, root=root, color=['label'])

In [None]:
# scRNASeq dataset #2 (full data only), origin cell at row 29863
# NOTE(review): this reuses `tt`, overwriting the CyTOF results read by the CyTOF
# plotting cells — re-run the CyTOF cell before re-plotting those.
sz = ['all']
tt = evaluate_speed(data_fullD, data_labelsD, dimred_vaeD,
                    method='stream', sizes=sz, ori_id=29863)

In [None]:
dtD = tt[1]['stream']['all']

# root picked by eye from the branch plot
root = 'S4'
st.plot_dimension_reduction(dtD, color=['label'], n_components=2,
                            show_graph=True, show_text=True)
st.plot_branches(dtD, show_text=True)
st.plot_flat_tree(dtD, color=['label', 'branch_id_alias', root + '_pseudotime'],
                  dist_scale=0.5, show_graph=True, show_text=True)
st.plot_stream_sc(dtD, root=root, color=['label'],
                  dist_scale=0.3, show_graph=True, show_text=True)
st.plot_stream(dtD, root=root, color=['label'])