# Infer trajectory with STREAM

This notebook guides you through the analysis of a single-cell dataset to infer a trajectory with the `stream` framework. 
All functions starting with `st.`, such as `st.read`, come from STREAM. Other functions, such as `prep_data_dynverse`, were written to ease the wrapping of STREAM results in R through the dynverse framework. If you are not running this notebook as part of the trajectory pipeline, running these functions is optional. 

To navigate the notebook easily, here are some useful keyboard shortcuts : 
 - Ctrl + Enter : Execute the cell
 - Shift + Enter : Execute the cell and select the one below
 - Alt + Enter : Execute the cell and insert a cell below
 - Esc : Command mode
 - Enter : Edit mode
 
Some useful commands in Command mode : 
  - A : Insert a cell above
  - B : Insert a cell below
  
Some useful commands in Edit mode : 
 - Ctrl + / : Comment
 
To access the help of a function, add `?` at the end of the function name and execute the line.
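
The `?` suffix only works inside IPython or Jupyter. As a minimal sketch of an equivalent in plain Python, the standard `inspect` module can list the parameters of a function together with their defaults. `demo_function` below is a hypothetical stand-in and is not part of STREAM.

```python
import inspect


def demo_function(adata, n_clusters=10, clustering="kmeans"):
    """Hypothetical stand-in for a library function; not part of STREAM."""
    return None


# Collect every parameter that has a default value, similar to what
# `demo_function?` would display in a notebook.
signature = inspect.signature(demo_function)
defaults = {
    name: param.default
    for name, param in signature.parameters.items()
    if param.default is not inspect.Parameter.empty
}
print(signature)
print(defaults)
```

Calling `help(demo_function)` prints the docstring as well.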

## Import useful functions

In [None]:
import stream as st
import pandas as pd
import matplotlib as mplt
mplt.rcParams.update({'figure.max_open_warning': 0})
import numpy as np
import networkx as nx
import os
%matplotlib inline

In [None]:
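def prep_data_dynverse(adata, root, filename='stream_to_dynverse', path=None):
    """Export STREAM results as CSV files for wrapping in R through dynverse.

    Files are written to `path` (default: adata.uns['workdir']) and are
    prefixed with `filename` followed by the root label.
    """
    filename += "_" + root
    if path is None:
        path = adata.uns['workdir']
    cwd = os.getcwd()
    os.chdir(path)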
    
    flat_tree = adata.uns['flat_tree']
 
    # Raw count
#     adata.raw.X.tofile(filename+'_raw_counts.csv',sep=',')
    pd.DataFrame(data=adata.raw.X, index=adata.raw.obs_names, columns=adata.raw.var_names).to_csv(filename+'_raw_counts.csv',index = True, index_label = 'cell_id', sep = ",", doublequote = False)
    # Log count 
#     adata.X.tofile(filename+'_counts.csv',sep=',')
    pd.DataFrame(data=adata.X, index=adata.obs_names, columns=adata.var_names).to_csv(filename+'_counts.csv',index = True, index_label = 'cell_id', sep = ",", doublequote = False)

    # Progression
    data = pd.DataFrame({'cell_id':adata.obs_names, 'edge_id':adata.obs['branch_id'],
                 'edge_id_alias':adata.obs['branch_id_alias'], 'branch_lam':adata.obs['branch_lam']})

    branch_len = list()
    for i in data['edge_id']:
        branch_len.append(flat_tree[i[0]][i[1]]['len'])

    data['branch_len'] = branch_len

    data['percentage'] = data.apply(lambda row: row.branch_lam / row.branch_len, axis=1)

    # Strip parentheses, quotes and spaces so that "('S0', 'S1')" becomes "S0,S1"
    from_to = data['edge_id_alias'].astype(str).str.replace(r"[()' ]", '', regex=True).str.split(',', expand=True)
    data['from'] = from_to[0]
    data['to'] = from_to[1]
    
    # Milestone network
    node_label = nx.get_node_attributes(flat_tree,'label')
    node_index = {v: k for k, v in node_label.items()}

    length , directed , from_ , to_ = list(),list(),list(),list()
    for edge in data['edge_id_alias'].unique():
        length.append(flat_tree[node_index[edge[0]]][node_index[edge[1]]]['len'])
        directed.append(False)
        from_.append(edge[0])
        to_.append(edge[1])

    millestone_network = pd.DataFrame({'from':from_, 'to':to_, 'length':length, 'directed':directed})
    millestone_network.to_csv(filename+'_millestone_network.csv',index = False, sep = ",", doublequote = False)
    
    # saving progression
    data = data.drop(columns=['edge_id', 'edge_id_alias', 'branch_lam', 'branch_len'])
    data.to_csv(filename+'_progression.csv',index = False, sep = ",", doublequote = False)
    
    # Dimension reduction
    pd.DataFrame(adata.obsm['X_dr'],index=adata.obs_names).to_csv(filename+'_dim_red.csv',index = True, index_label = 'cell_id', sep = ",", doublequote = False)
    # Grouping 
    adata.obs.label.to_csv(filename+'_grouping.csv',index = True, index_label = 'cell_id', sep = ",", doublequote = False)
    
    os.chdir(cwd)
    return None 
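
The progression step of `prep_data_dynverse` can be illustrated on a single cell. The sketch below uses only the standard library and invented toy values; the helper names `split_edge_alias` and `progression_row` are hypothetical. It assumes that an edge alias is a pair of node labels, and it also strips the quote characters that appear when a tuple is converted to a string.

```python
import re


def split_edge_alias(edge_alias):
    """Turn an edge alias such as ('S0', 'S1') into two plain label strings."""
    cleaned = re.sub(r"[()' ]", "", str(edge_alias))
    start, end = cleaned.split(",")
    return start, end


def progression_row(cell_id, edge_alias, branch_lam, branch_len):
    """Build one progression record.

    `percentage` is the position of the cell along its edge, expressed as a
    fraction of the edge length.
    """
    start, end = split_edge_alias(edge_alias)
    return {
        "cell_id": cell_id,
        "percentage": branch_lam / branch_len,
        "from": start,
        "to": end,
    }


# Toy values, invented for illustration only.
row = progression_row("cell_1", ("S0", "S1"), branch_lam=0.5, branch_len=2.0)
print(row)
```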

## STREAM parameters

In [None]:
st.set_figure_params(dpi=80,style='white',figsize=[5.4,4.8],
                     rc={'image.cmap': 'viridis'})

## Reading the data

First we need to create an adata object by calling `st.read` with the following parameters : 
 - file_name : name of the tsv file containing the count
 - workdir : folder where to save all the output (created automatically)

In [None]:
adata=st.read(file_name='../data_for_trajectories.tsv',workdir='./stream_result')
adata.raw = adata
adata.var_names_make_unique()

Then we can add metadata to identify the cells (for example, different cell subtypes). 
 - `st.add_cell_labels(adata,file_name='./cell_label_subset.tsv.gz')`
 - `st.add_cell_colors(adata,file_name='./cell_label_color_subset.tsv.gz')`

If you do not have this information, simply run : 
 - `st.add_cell_labels(adata)`, which will label all the cells as unknown
 - `st.add_cell_colors(adata)`, which will assign random colors to the cells


In [None]:
st.add_cell_labels(adata,file_name='../data_for_trajectories_cell_identities.tsv')
st.add_cell_colors(adata,file_name='../data_for_trajectories_cell_colors.tsv')

## Preprocessing

When dealing with raw counts, you have to preprocess your data with the following steps : 
 1. Normalize gene expression based on library size : `st.normalize(adata)`
 2. Logarithmize gene expression : `st.log_transform(adata)`
 3. Remove mitochondrial genes : `st.remove_mt_genes(adata)`
 4. Filter out cells based on different metrics : `st.filter_cells(adata)`
  
  Available options with their default values : 
  - min_n_features = 10 ; Minimum number of genes expressed
  - min_pct_features = None ; Minimum percentage of genes expressed
  - min_n_counts = None ; Minimum number of read counts for one cell
  - expr_cutoff = 1 ; A gene is considered 'expressed' if its expression is greater than expr_cutoff
 5. Filter out genes based on different metrics : `st.filter_features(adata)`
 
  Available options with their default values : 
  - min_n_cells = 5 ; Minimum number of cells expressing one gene
  - min_pct_cells = None ; Minimum percentage of cells expressing one gene
  - min_n_counts = None ; Minimum number of read counts for one gene
  - expr_cutoff = 1 ; A gene is considered 'expressed' if its expression is greater than expr_cutoff
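
To make the cell-filtering criteria concrete, here is a minimal sketch in plain Python on toy data. It is not STREAM's implementation: `filter_cells_sketch` is a hypothetical helper, and the handling of the boundaries (strictly greater than `expr_cutoff`, at least `min_n_features`) is an assumption based on the option descriptions above.

```python
def filter_cells_sketch(counts, min_n_features=10, expr_cutoff=1):
    """Keep cells in which at least `min_n_features` genes exceed `expr_cutoff`.

    `counts` maps a cell id to its list of per-gene expression values.
    """
    kept = {}
    for cell_id, values in counts.items():
        # A gene counts as expressed when its value is above the cutoff.
        n_expressed = sum(1 for value in values if value > expr_cutoff)
        if n_expressed >= min_n_features:
            kept[cell_id] = values
    return kept


# Toy expression values, invented for illustration only.
toy_counts = {
    "cell_a": [0, 3, 5, 2],  # 3 genes above the cutoff
    "cell_b": [0, 1, 0, 4],  # 1 gene above the cutoff
}
kept = filter_cells_sketch(toy_counts, min_n_features=2, expr_cutoff=1)
print(sorted(kept))
```

Gene filtering with `min_n_cells` follows the same idea, counting cells per gene instead of genes per cell.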


**Quality metrics**

In [None]:
st.cal_qc(adata,assay='rna')

**Filtering**

In [None]:
st.normalize(adata,method='lib_size')
st.log_transform(adata)
st.remove_mt_genes(adata)
st.filter_cells(adata,min_n_features = 1000)
st.filter_features(adata,min_n_cells = 5)

Now we need to select the variable genes in our dataset by running `st.select_variable_genes(adata)`. The options of interest are : 
 - loess_frac = 0.1
 - percentile = 95 ; Specify the percentile to select genes
 - save_fig = False
 - fig_name = 'std_vs_means.pdf'
 
Check whether the blue curve fits the points well. If it does not, adjust the `loess_frac` parameter until it does.

In [None]:
st.select_variable_genes(adata, loess_frac=0.01)

## Dimension reduction

Several dimension reduction methods are available in STREAM : 
 - Spectral embedding algorithm (`se`),
 - Modified locally linear embedding algorithm (`mlle`),
 - Uniform Manifold Approximation and Projection (`umap`) and 
 - Principal component analysis (`pca`)
 
By default, STREAM uses the spectral embedding space to find trajectories. 
 
To run the dimension reduction, use : `st.dimension_reduction(adata)`
 
Options of interest : 
 - n_neighbors = 50 ; The number of neighbor cells used for manifold learning
 - n_components = 3 ; Number of components to keep
 - method = 'se' ; Method used for dimension reduction

To ensure reproducibility, use `eigen_solver = 'arpack'`.

In [None]:
st.dimension_reduction(adata,n_components=2, eigen_solver='arpack')

Afterwards, the reduction can be visualized with `st.plot_dimension_reduction(adata)`.

All the visualization functions have a `save_fig` parameter that lets you save the figure in the workdir defined in `adata.uns['workdir']`.

In [None]:
st.plot_dimension_reduction(adata)

## Trajectory inference

### Initial graph

First we need to seed the initial principal graph with the function `st.seed_elastic_principal_graph`. Some of the available options are : 
 - clustering = 'kmeans' ; clustering method used to infer the initial nodes, Choose from : 'ap','kmeans','sc'
 - n_clusters = 10 ; Number of clusters
 - n_neighbors = 50 ; The number of neighbor cells used for spectral clustering

In [None]:
st.seed_elastic_principal_graph(adata,n_clusters=10)

**Plotting**

Once the initial structure of the graph is computed, we can visualize it with or without the cells : 
 - `st.plot_branches(adata)`
 - `st.plot_dimension_reduction(adata,n_components=2,show_graph=True,show_text=False)`
 
These functions can be used at any time to visualize the graph.

In [None]:
st.plot_dimension_reduction(adata,n_components=2,show_graph=True,show_text=False)
st.plot_branches(adata)

**Principal Graph**

Now we can estimate the graph with : `st.elastic_principal_graph`

`epg_alpha`, `epg_mu` and `epg_lambda` are the three most influential parameters for learning the elastic principal graph : 
- `epg_alpha`: penalizes spurious branching events. **The larger the value, the fewer branches the function will learn.** (by default, `epg_alpha=0.02`)
- `epg_mu`: penalizes the deviation from harmonic embedding, where harmonicity assumes that each node is the mean of its neighbor nodes. **The larger the value, the more edges the function will use to fit the points (cells).** (by default, `epg_mu=0.1`)
- `epg_lambda`: penalizes the total length of edges. **The larger the value, the 'shorter' the curves the function will use to fit the points (cells) and the fewer points (cells) the curves will reach.** (by default, `epg_lambda=0.02`)

If your data contains noisy points, you can use the `epg_trimmingradius` parameter. By default it is set to `Inf`, but a value of 0.1 can be a good starting point to get rid of these noisy points. 
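
When the defaults do not give a satisfying graph, one option is to try several parameter combinations in turn. The sketch below only builds the combinations using the standard library. The candidate values are illustrative guesses around the defaults quoted above, and `parameter_combinations` is a hypothetical helper.

```python
from itertools import product

# Candidate values are illustrative guesses around the documented defaults.
grid = {
    "epg_alpha": [0.01, 0.02, 0.05],
    "epg_mu": [0.05, 0.1],
    "epg_lambda": [0.01, 0.02],
}


def parameter_combinations(grid):
    """Yield one keyword-argument dict per combination of the grid values."""
    names = list(grid)
    for values in product(*(grid[name] for name in names)):
        yield dict(zip(names, values))


combinations = list(parameter_combinations(grid))
print(len(combinations))
print(combinations[0])
```

Each dictionary could then be passed as keyword arguments, for example `st.elastic_principal_graph(adata, **params)`, followed by a plot to inspect the resulting graph.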


In [None]:
st.elastic_principal_graph(adata)

In [None]:
st.plot_dimension_reduction(adata,n_components=2,show_graph=True,show_text=False)
st.plot_branches(adata)

### Graph optimization (optional)

**Branching optimization**

The most influential parameters are `epg_alpha`, `epg_mu`, `epg_lambda` and `epg_trimmingradius`. They have the same meanings as in `st.elastic_principal_graph`

In [None]:
st.optimize_branching(adata,epg_alpha=0.02,epg_mu=0.1,epg_lambda=0.02)

In [None]:
st.plot_dimension_reduction(adata,n_components=2,show_graph=True,show_text=False)
st.plot_branches(adata)

**Prune branches**

Prune the learnt elastic principal graph by filtering out 'trivial' branches with the function `st.prune_elastic_principal_graph`. 
Different pruning methods are available through the `epg_collapse_mode` parameter : 
 - 'PointNumber': branches with fewer than `epg_collapse_par` points (points projected on the extreme points are not considered) are removed
 - 'PointNumber_Extrema': branches with fewer than `epg_collapse_par` points (points projected on the extreme points are not considered) are removed
 - 'PointNumber_Leaves': branches with fewer than `epg_collapse_par` points (points projected on non-leaf extreme points are not considered) are removed
 - 'EdgesNumber': branches with fewer than `epg_collapse_par` edges are removed
 - 'EdgesLength': branches shorter than `epg_collapse_par` are removed 
 
The threshold used by each method is set with `epg_collapse_par` (by default `epg_collapse_par = 5`).
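
As a toy illustration of the 'EdgesLength' mode, the sketch below drops every branch whose length is below the threshold. It is not STREAM's implementation: `prune_by_length` is a hypothetical helper, the branch lengths are invented, and the reconnection of the remaining graph is ignored.

```python
def prune_by_length(branch_lengths, epg_collapse_par):
    """Drop branches shorter than `epg_collapse_par` (toy 'EdgesLength' mode).

    `branch_lengths` maps a branch, given as a pair of node labels, to its length.
    """
    return {
        branch: length
        for branch, length in branch_lengths.items()
        if length >= epg_collapse_par
    }


# Toy branch lengths, invented for illustration only.
toy_branches = {("S0", "S1"): 4.2, ("S0", "S2"): 0.3, ("S0", "S3"): 2.5}
pruned = prune_by_length(toy_branches, epg_collapse_par=1.0)
print(pruned)
```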

In [None]:
st.prune_elastic_principal_graph(adata)

**Shift branching**

Move the branching nodes to areas with higher cell density.

In [None]:
st.shift_branching(adata)

**Extend leaf branch**

Extend the leaf branches to reach further cells by running `st.extend_elastic_principal_graph`.

In [None]:
st.extend_elastic_principal_graph(adata)

In [None]:
st.plot_dimension_reduction(adata,n_components=2,show_graph=True,show_text=False)
st.plot_branches(adata)

## Visualization

Some visualizations require you to choose a root for the graph. This choice does not affect the results, only their representation.
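
The effect of the root can be sketched on a toy tree: the edges stay the same whichever root is chosen, and only the distance of each node from the root changes. The code below is a simplified illustration using path length along the tree. It does not reproduce how STREAM computes pseudotime, and the tree and the helper `distances_from_root` are invented.

```python
def distances_from_root(edges, root):
    """Return the path length from `root` to every node of a tree.

    `edges` maps an undirected edge (node_a, node_b) to its length.
    """
    neighbours = {}
    for (a, b), length in edges.items():
        neighbours.setdefault(a, []).append((b, length))
        neighbours.setdefault(b, []).append((a, length))

    # Depth-first traversal; in a tree each node is reached by a single path.
    distances = {root: 0.0}
    stack = [root]
    while stack:
        node = stack.pop()
        for other, length in neighbours[node]:
            if other not in distances:
                distances[other] = distances[node] + length
                stack.append(other)
    return distances


# Toy tree, invented for illustration only.
toy_tree = {("S0", "S1"): 1.0, ("S1", "S2"): 2.0, ("S1", "S3"): 0.5}
from_s0 = distances_from_root(toy_tree, "S0")
from_s2 = distances_from_root(toy_tree, "S2")
print(from_s0)
print(from_s2)
```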

**Flat tree**

In [None]:
st.plot_flat_tree(adata,color=['label','branch_id_alias','S0_pseudotime'],
                  dist_scale=0.5,show_graph=True,show_text=True)

**Subway plot**

In [None]:
st.plot_stream_sc(adata,root='S0',
                  dist_scale=0.3,show_graph=True,show_text=True)

**Stream plot**

In [None]:
st.plot_stream(adata,root='S0')

**Visualize genes**

You can visualize gene expression along the different branches with two functions, `st.plot_stream_sc` and `st.plot_stream`. To do so, provide a list of genes (`[...]`) through the `color` parameter.

In [None]:
st.plot_stream_sc(adata,root='S0',color=['Gata1','Car2','Epx']) 

In [None]:
st.plot_stream(adata,root='S0',color=['Gata1','Car2','Epx'])

## Marker gene detection

**The marker gene detection part is somewhat time-consuming, so please make sure the structure learned in the previous steps is reasonable before running any marker gene detection step.**

**Also, it is not always necessary to run all three marker gene detection parts. Running one of them may already be adequate.**

### Detect marker genes for each leaf branch


In [None]:
st.detect_leaf_markers(adata,root='S0')

In [None]:
adata.uns['leaf_genes'].keys()

In [None]:
adata.uns['leaf_genes'][('S0','S1')]

### Detect transition genes for each branch

In [None]:
st.detect_transistion_markers(adata,root='S0')

In [None]:
st.plot_transition_genes(adata)

### Detect differentially expressed genes between pairs of branches

In [None]:
st.detect_de_markers(adata,root='S0')

In [None]:
adata.uns['de_genes_greater'].keys()

In [None]:
adata.uns['de_genes_less'].keys()

## Saving

In [None]:
st.write(adata,file_name='stream_result.pkl')

## Export to dynverse

In [None]:
prep_data_dynverse(adata,root = "S0", filename = 'stream_to_dynverse')
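
After the export, it can be useful to check that all the CSV files were written. The sketch below rebuilds the file names from the suffixes used in `prep_data_dynverse` (including the `_millestone_network.csv` spelling used there). `expected_export_files` is a hypothetical helper, and the directory `./stream_result` is taken from the `workdir` used earlier in this notebook.

```python
import os

# Suffixes as written by prep_data_dynverse.
SUFFIXES = [
    "_raw_counts.csv",
    "_counts.csv",
    "_millestone_network.csv",
    "_progression.csv",
    "_dim_red.csv",
    "_grouping.csv",
]


def expected_export_files(filename, root, path="."):
    """List the paths that prep_data_dynverse is expected to create."""
    prefix = filename + "_" + root
    return [os.path.join(path, prefix + suffix) for suffix in SUFFIXES]


files = expected_export_files("stream_to_dynverse", "S0", path="./stream_result")
missing = [file for file in files if not os.path.exists(file)]
print("Missing files:", missing)
```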