### Application of GPcounts to single-cell RNA-seq data to identify gene-specific branching locations

This notebook demonstrates how to build a GPcounts model and plot the posterior model fit and the posterior over branching times. We use single-cell RNA-seq data of mouse haematopoietic stem cells (HSCs) from <a href="https://pubmed.ncbi.nlm.nih.gov/26627738/" target="_blank" text_decoration=none>(Paul et al., 2015)</a>. The data contain cells that differentiate into myeloid and erythroid precursor cell types.

In [None]:
%matplotlib inline
import pandas as pd
import numpy as np
from matplotlib import pyplot as plt
import tensorflow as tf
# from IPython.display import display

### Data description

The Slingshot algorithm <a href="https://pubmed.ncbi.nlm.nih.gov/29914354/" target="_blank">(Street et al., 2018)</a> was used to obtain trajectory-specific pseudotimes and to assign cells to branches. Slingshot infers two lineages for this dataset. The data were taken from the <a href="https://statomics.github.io/tradeSeq/articles/tradeSeq.html" target="_blank">tradeSeq vignette</a>.

The __geneExpression.csv__ file contains the expression profiles of the mouse HSCs, with genes in rows and cells in columns; the notebook transposes it so that each row is a cell.  

The __Slingshot.csv__ file contains lineage-specific cell assignments as well as pseudotimes.
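Later cells pair the two tables purely by row position (for example `slingShot[1::ns]` with `data[1::ns]`), so both files must list the cells in the same order. The sketch below shows one way to check this. It uses small hypothetical toy tables in place of the real CSV files, and the Slingshot column names (`weight1`, `weight2`) are assumptions, not the actual file headers.

```python
import pandas as pd

# Hypothetical toy stand-in for geneExpression.csv: genes in rows, cells in columns.
raw_expr = pd.DataFrame(
    {"cell1": [0, 5], "cell2": [3, 0], "cell3": [1, 2]},
    index=["Mpo", "Car2"],
)
data_toy = raw_expr.T  # transpose: rows = cells, columns = genes

# Hypothetical toy stand-in for Slingshot.csv: one row per cell.
# The weight column names are assumptions; only 'pseudotime' appears in the notebook.
sling_toy = pd.DataFrame(
    {"weight1": [0.9, 0.1, 0.5], "weight2": [0.1, 0.9, 0.5], "pseudotime": [0.2, 1.4, 3.1]},
    index=["cell1", "cell2", "cell3"],
)

def check_alignment(expr: pd.DataFrame, sling: pd.DataFrame) -> bool:
    """Return True if both tables list the same cells in the same order."""
    return expr.index.equals(sling.index)

print(data_toy.shape, check_alignment(data_toy, sling_toy))
```

On the real data the same check would be `check_alignment(data, slingShot)`; if it fails, reindexing one table with `slingShot.reindex(data.index)` restores positional alignment.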

In [None]:
data = pd.read_csv('../data/MouseHSC/geneExpression.csv', index_col=[0]).T
slingShot = pd.read_csv('../data/MouseHSC/Slingshot.csv', index_col=[0])

In [None]:
data.head()

In [None]:
slingShot.head()

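Slingshot assigns each cell a weight for each lineage, indicating how strongly the cell belongs to that branch. We use a threshold of 0.80 on the weight in the first column of the Slingshot table: cells whose weight exceeds 0.80 are assigned to branch 1, and all remaining cells are assigned to branch 2.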

In [None]:
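# Label every cell as branch 2 by default, then move cells whose weight in the
# first Slingshot column exceeds 0.8 to branch 1.
cell_label = np.ones(slingShot.shape[0]) * 2
for i in range(0, slingShot.shape[0]):
    if slingShot.values[i, 0] > 0.8:
        cell_label[i] = 1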
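The same labelling can be written without an explicit loop using `np.where`. The sketch below wraps it in a small helper, `assign_branches` (a hypothetical name, not part of GPcounts), with the threshold exposed as a parameter. The weights are made-up toy values.

```python
import numpy as np

def assign_branches(weights_lineage1, threshold: float = 0.8) -> np.ndarray:
    """Label cells 1 if their first-lineage weight exceeds `threshold`, else 2.

    Mirrors the loop above: every cell that does not pass the threshold,
    including cells shared by both lineages, ends up in branch 2.
    Returns floats, like the original `np.ones(...) * 2` array.
    """
    return np.where(np.asarray(weights_lineage1) > threshold, 1.0, 2.0)

# Hypothetical toy weights, not taken from the dataset.
w1 = np.array([0.95, 0.81, 0.80, 0.30, 0.0])
labels = assign_branches(w1)
n_branch1 = int((labels == 1).sum())  # number of cells assigned to branch 1
print(labels, n_branch1)
```

Applied to the notebook's data this would be `cell_label = assign_branches(slingShot.values[:, 0])`. Note that the comparison is strict, so a weight of exactly 0.80 falls into branch 2, as in the loop.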

### Fit GPcounts model for branching

We show examples using both the negative binomial and the Gaussian likelihoods.
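To make the difference between the two likelihoods concrete, here is a minimal sketch of both observation models for a single count `y`; it is not GPcounts' implementation. The negative binomial is written with mean `mu` and dispersion `r`, so its variance `mu + mu**2 / r` grows with the mean (overdispersion). This parameterization is an assumption, and GPcounts may use a different one. Because the notebook plots the Gaussian fits against log(y + 1)-transformed counts, the Gaussian density here is evaluated on that scale.

```python
import math
import numpy as np

def nb_logpmf(y: int, mu: float, r: float) -> float:
    """Negative binomial log-probability with mean mu and dispersion r.

    Variance is mu + mu**2 / r, so it increases with the mean.
    """
    return (math.lgamma(y + r) - math.lgamma(r) - math.lgamma(y + 1)
            + r * math.log(r / (r + mu)) + y * math.log(mu / (r + mu)))

def gaussian_logpdf_log1p(y: int, m: float, sigma: float) -> float:
    """Gaussian log-density evaluated on the log(y + 1)-transformed count."""
    z = math.log1p(y)
    return -0.5 * math.log(2 * math.pi * sigma**2) - (z - m) ** 2 / (2 * sigma**2)

# Numerically check the NB moments on a truncated support (the tail beyond 2000 is negligible here).
mu, r = 10.0, 2.0
ys = np.arange(0, 2000)
pmf = np.exp([nb_logpmf(int(y), mu, r) for y in ys])
mean = float((ys * pmf).sum())
var = float(((ys - mean) ** 2 * pmf).sum())
# Expected approximately: total mass 1, mean mu, variance mu + mu**2 / r.
print(pmf.sum(), mean, var)

# Gaussian log-density of a count of 10 when the mean sits exactly at log(11).
print(gaussian_logpdf_log1p(10, math.log1p(10), 0.5))
```

The check illustrates why a count likelihood can be preferable for raw scRNA-seq data: the variance is tied to the mean, whereas the Gaussian on log(y + 1) assumes a constant noise level on the transformed scale.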

### Negative binomial likelihood

In [None]:
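from GPcounts.GPcounts_Module import Fit_GPcounts

def Fit_GPcounts_for_branching(geneName, likelihood='Negative_binomial', bins_num=50, ns=5):
    """Infer the branching location of one gene with GPcounts.

    To reduce run time, only every `ns`-th cell (starting from the second) is used.
    """
    X = slingShot[1::ns][['pseudotime']]  # pseudotime of the subsampled cells
    Y = data[1::ns][[geneName]].T         # expression of the gene, one row per gene
    gp_counts = Fit_GPcounts(X, Y)
    d = gp_counts.Infer_branching_location(cell_label[1::ns], bins_num=bins_num, lik_name=likelihood)
    del gp_counts  # release the reference to the fitted model
    return d

d = Fit_GPcounts_for_branching('Mpo', bins_num=25)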
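`Fit_GPcounts_for_branching` keeps only every `ns`-th cell, and the pseudotimes, the expression values and `cell_label` must all be subsampled with the same slice so that they still refer to the same cells. The sketch below shows what `[1::ns]` selects, using hypothetical toy tables with ten cells in place of `slingShot` and `data`.

```python
import numpy as np
import pandas as pd

# Hypothetical toy tables standing in for `slingShot` and `data` (10 cells).
cells = [f"cell{i}" for i in range(10)]
sling_toy = pd.DataFrame({"pseudotime": np.linspace(0.0, 9.0, 10)}, index=cells)
data_toy = pd.DataFrame({"Mpo": np.arange(10) * 3}, index=cells)
labels_toy = np.array([1, 1, 2, 2, 1, 2, 1, 2, 2, 1], dtype=float)

ns = 5
X = sling_toy[1::ns][["pseudotime"]]  # rows 1, 6, ...: every ns-th cell from the second
Y = data_toy[1::ns][["Mpo"]].T        # same cells, transposed to genes x cells
lab = labels_toy[1::ns]               # branch labels subsampled with the same slice

print(list(X.index), list(Y.columns), lab)
```

With `ns=5` on this toy table, only `cell1` and `cell6` are kept; on the real data the same slice keeps roughly one cell in five.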

#### Plot the posterior model fit and posterior branching times

In [None]:
from helper import plotBranching, plotGene
fig, ax = plotBranching(d)
plotGene(ax[0], X=slingShot[['pseudotime']], Y=data[['Mpo']].T, label=cell_label, size=10, alpha=.6)
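One common way to obtain a posterior over branching times is to score each candidate branching time on a grid by its log marginal likelihood and normalize the scores under a uniform prior. We assume here that `bins_num` sets the number of such candidates, but this sketch is illustrative only: the log-likelihood values are invented, and it is not taken from GPcounts' `Infer_branching_location`.

```python
import numpy as np

def branching_posterior(log_evidence) -> np.ndarray:
    """Normalise per-candidate log marginal likelihoods into a posterior.

    Assumes a uniform prior over the candidate branching times and subtracts
    the maximum before exponentiating to avoid overflow (log-sum-exp trick).
    """
    log_evidence = np.asarray(log_evidence, dtype=float)
    weights = np.exp(log_evidence - log_evidence.max())
    return weights / weights.sum()

# Hypothetical grid of candidate branching times over the pseudotime range.
bins_num = 5
candidates = np.linspace(0.0, 20.0, bins_num)
# Made-up log marginal likelihoods, one per candidate (not real GPcounts output).
log_ev = np.array([-120.0, -115.0, -110.0, -112.0, -125.0])

post = branching_posterior(log_ev)
print(candidates[np.argmax(post)], post.round(3))
```

A finer grid (larger `bins_num`) resolves the branching time more precisely but requires one model evaluation per candidate, which is why runs with more bins take longer.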

### Gaussian likelihood

In [None]:
geneName = 'Mpo'
d = Fit_GPcounts_for_branching(geneName, 'Gaussian', bins_num=25)

#### Plot the posterior model fit and posterior branching times

In [None]:
fig, ax = plotBranching(d)
plotGene(ax[0], X=slingShot[['pseudotime']], Y=np.log(data[[geneName]].T + 1), label=cell_label, size=10, alpha=.6)

### Paper results
Uncommenting the following code reproduces the branching-location inference examples shown in the main paper and in the supplementary document. It takes longer to run because it fits more genes and uses more bins (test points).

In [None]:
# bins_num = 50
# geneList = ['Mpo', 'Ly6e', 'Car2', 'Car1', 'Ctsg', 'Prtn3', 'Irf8', 'Erp29', 'Apoe']
# d_gaussian = list()
# d_nb = list()
# for g in geneList:
#     print(g)
#     d_nb.append(Fit_GPcounts_for_branching(g, likelihood='Negative_binomial', bins_num=bins_num))
#     d_gaussian.append(Fit_GPcounts_for_branching(g, likelihood='Gaussian', bins_num=bins_num))

In [None]:
# for i in range(0, len(geneList)):
#     _, ax = plotBranching(d_nb[i])
#     plotGene(ax[0], X=slingShot[['pseudotime']], Y=data[[geneList[i]]].T, label=cell_label, size=10, alpha=.6)
#     _, ax = plotBranching(d_gaussian[i])
#     plotGene(ax[0], X=slingShot[['pseudotime']], Y=np.log(data[[geneList[i]]].T + 1), label=cell_label, size=10, alpha=.6)
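For longer batch runs like the one above, it can help to store results keyed by gene and likelihood instead of in two parallel lists, and to record failures rather than abort the whole loop. The sketch below uses a hypothetical helper `fit_many` and a stand-in `fake_fit` in place of `Fit_GPcounts_for_branching`, so it runs without GPcounts installed.

```python
from typing import Callable, Dict, List, Tuple

def fit_many(
    genes: List[str],
    likelihoods: List[str],
    fit: Callable[[str, str], object],
) -> Dict[Tuple[str, str], object]:
    """Fit every (gene, likelihood) pair and key the results by that pair.

    Failures are stored as the raised exception so that one problematic
    gene does not stop a long batch run.
    """
    results: Dict[Tuple[str, str], object] = {}
    for gene in genes:
        for lik in likelihoods:
            try:
                results[(gene, lik)] = fit(gene, lik)
            except Exception as err:  # keep going; inspect failures afterwards
                results[(gene, lik)] = err
    return results

# Hypothetical stand-in for Fit_GPcounts_for_branching (returns a dummy string).
def fake_fit(gene: str, likelihood: str) -> str:
    if gene == "Bad":
        raise ValueError("fit failed")
    return f"{gene}:{likelihood}"

res = fit_many(["Mpo", "Bad"], ["Negative_binomial", "Gaussian"], fake_fit)
print(res[("Mpo", "Gaussian")], type(res[("Bad", "Gaussian")]).__name__)
```

With the real function, the call would look like `fit_many(geneList, ['Negative_binomial', 'Gaussian'], lambda g, lik: Fit_GPcounts_for_branching(g, likelihood=lik, bins_num=bins_num))`, and the plotting loop can then look up `res[(g, 'Gaussian')]` directly instead of relying on list positions.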