In [None]:
# Install the Python packages used in this notebook (scprep, statsmodels, MAGIC, PHATE).
# NOTE(review): versions are not pinned, so results may drift as these packages update.
!pip install --user scprep statsmodels magic-impute phate

In [None]:
import scprep

In [None]:
# Download a pre-built R package bundle from a fixed Google Drive file ID, then
# unpack it and delete the archive. The /content paths suggest Google Colab -- adjust elsewhere.
scprep.io.download.download_google_drive('1VuddKA7Xa23NAFKvRUmtuND16eEMHL9z', '/content/r_packages.tar.gz')
!tar xzf /content/r_packages.tar.gz && rm /content/r_packages.tar.gz
# System GSL development library, pinned to a specific apt package version
!apt-get install -yqq libgsl-dev=2.4+dfsg-6
# rpy2: presumably the Python<->R bridge used by the scprep.run R wrappers
!pip install --user --upgrade rpy2
# Tell R where to find user libraries (presumably where the bundle above unpacks)
!echo "R_LIBS_USER='/content/.r_packages'" >> ~/.Renviron
# Install Splatter through Bioconductor
!R -e "BiocManager::install('splatter')"

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import graphtools as gt

import phate

import scipy.sparse
import scipy.sparse.linalg
import sklearn.cluster

from scipy.spatial.distance import pdist, squareform

## Generate Synthetic Data with Known Ground Truth

### Introduction to Splatter

Splatter, [Zappia L, et al. Genome Biology. 2017](https://doi.org/10.1186/s13059-017-1305-0), is a tool for simulating single-cell RNA-sequencing data with a known topology. Splatter has two main modes: "`paths`" and "`groups`." Because we're interested in generating data with a known trajectory structure, we are going to use the `paths` mode.

Splatter has a bunch of other parameters you can choose from, and if you have time, feel free to play around with the various options.



#### Generating simulated data

Splatter is only implemented in R. Fortunately, we provide a wrapper in the `scprep.run` module. You can call Splatter from Python using [`scprep.run.SplatSimulate()`](https://scprep.readthedocs.io/en/stable/reference.html#scprep.run.SplatSimulate).

In [None]:
# Run Splatter

# You can easily change the number of paths and cells here
n_paths = 3
cells_per_path = 1000

# Splatter returns a `dict` object that contains a bunch of useful information.
# Every path branches from 0 (the starting point), so `path_from` holds one 0 per
# path; deriving it from `n_paths` keeps it consistent when you change `n_paths`.
# NOTE(review): some later cells still assume 3 paths (e.g. the DE-factor table).
results = scprep.run.SplatSimulate(
    method='paths',
    batch_cells=cells_per_path * n_paths,
    group_prob=np.tile(1 / n_paths, n_paths),  # equal share of cells per path
    n_genes=5000,
    de_down_prob=0.5,
    path_from=[0] * n_paths,
    mean_shape=1,
    seed=0)

Let's check what we can find in this `results` dictionary. Note, you can find explanations of what each key refers to by checking the `returns` section of the `SplatSimulate` docstring or consulting https://scprep.readthedocs.io/en/stable/reference.html#scprep.run.SplatSimulate.

In [None]:
# Keys of the Splatter output dict (see the SplatSimulate docstring for their meaning)
results.keys()

There's a lot of organizational stuff happening in the next code block. The key idea is that we want to grab three pieces of information from `results`:
* `counts` - simulated expression counts
* `group` - the group or path the cell belongs to
* `step` - how far along the path each cell is

We're putting `counts` into one DataFrame and `group` & `step` in a separate DataFrame. Then we do some manipulation of the `step` vector so it increases monotonically (as opposed to restarting at 0 at the beginning of each path). Finally we relabel the rows (the `Index`) of `data` and `metadata` as `cell_0`, `cell_1`, … and preprocess the data (library-size normalization followed by a square-root transform).

In [None]:
# Put counts data in DataFrame
data = pd.DataFrame(results['counts'])

# Put metadata in a DataFrame: one row per cell, columns 'step' and 'group'
metadata = pd.DataFrame([results['step'], results['group']], index=['step', 'group']).transpose()
metadata = metadata.astype({'step': int})  # These are strings by default, but we want ints

# Adjusting the 'step' vector to monotonically increase.
# NOTE(review): 100 is assumed to be the number of steps per path (Splatter's
# default `path.nSteps`) -- update it if you pass a different value to SplatSimulate.
n_steps = 100
# Path 1 is reversed so that it runs *into* the shared starting point ...
p1_mask = metadata['group'] == 1
metadata.loc[p1_mask, 'step'] = n_steps - metadata.loc[p1_mask, 'step']
# ... and every other path is shifted so it continues *from* that point.
# `~p1_mask` (rather than a hard-coded [2, 3]) stays correct when `n_paths` changes.
other_paths_mask = ~p1_mask
metadata.loc[other_paths_mask, 'step'] = n_steps + metadata.loc[other_paths_mask, 'step']

# Give the rows nicer labels: cell_0, cell_1, ...
new_index = pd.Index(['cell_{}'.format(i) for i in range(metadata.shape[0])])
data.index = new_index
metadata.index = new_index

# Library-size normalize and sqrt transform
data = scprep.normalize.library_size_normalize(data)
data = scprep.transform.sqrt(data)

#### Visualizing the data

Splatter is designed to generate data that can be easily visualized using PCA. If you have time, try to visualize the data using other tools from our visualization module.

In [None]:
# PCA of the normalized, sqrt-transformed data; the result is indexed by cell and
# has columns 'PC1', 'PC2', ... (accessed by name further below)
data_pca = scprep.reduce.pca(data)

Plot the data using `scprep.plot.scatter2d()`.

In [None]:
# Uncolored PCA embedding of the simulated cells
scprep.plot.scatter2d(data_pca,
                      title='PCA',
                      label_prefix='PC',
                      ticks=None,
                      color='skyblue',
                      figsize=(5, 5.2))

#### Plot the `step` vector

In [None]:
# Color each cell by its (now monotonic) latent `step` value
scprep.plot.scatter2d(data_pca,
                      c=metadata['step'],
                      title='PCA',
                      label_prefix='PC',
                      ticks=None,
                      figsize=(5.5, 5))

#### Plot the `group` vector

In [None]:
# Color each cell by the simulated path (branch) it belongs to
scprep.plot.scatter2d(data_pca,
                      c=metadata['group'],
                      title='PCA',
                      label_prefix='PC',
                      ticks=None,
                      legend_title='Branch',
                      figsize=(5, 5))

#### Assign each cell to a day using a binomial distribution

To make it easier to understand the benefit of pseudotime analysis, we're going to assign each cell to a sample labelled `Day1`, `Day2`, or `Day3` to mimic a traditional time course experiment. Note that we're not using the `counts` data to do this, but rather we're using the _latent variable_ `step` to assign samples. This means that we're using the same latent variable both to generate the counts and, independently, to generate the `sample` labels.

In [None]:
# Draw a day label for each cell: Binomial(n=2, p=normalized_step) + 1 gives
# 1, 2 or 3 (Day1/Day2/Day3). The larger a cell's `step`, the more likely it
# is to be assigned a later day.
# A seeded generator keeps the labels identical across re-runs.
rng = np.random.default_rng(0)
normalized_step = metadata['step'] / metadata['step'].max()
sample = rng.binomial(n=2,
                      p=normalized_step,
                      size=metadata.shape[0]) + 1

# Color for each day, used whenever `sample` is plotted
sample_cmap = {1: 'lightgreen',
               2: 'skyblue',
               3: 'darkblue'}

In [None]:
# PCA embedding colored by the simulated day of collection
scprep.plot.scatter2d(data_pca,
                      c=sample,
                      cmap=sample_cmap,
                      title='Day of sample collection',
                      label_prefix='PC',
                      ticks=None,
                      legend=True,
                      legend_title='Day',
                      figsize=(5, 5))

## 2. Learning Diffusion Pseudotime (DPT)

In this section, we're going to run diffusion pseudotime (DPT) from scratch. Yes, there is a nice wrapper we will use later. For now, we want to show you _how_ DPT is calculated.

#### The math
Remember from lecture that the following two equations define DPT:
$$ \textrm{DPT}(x,y) = ||\mathbf{M}(x,:) - \mathbf{M}(y,:)||, ~\mathbf{M} = \sum\limits_{t=1}^{\infty}\mathbf{\tilde{T}}^t$$
$$ \mathbf{M} = \sum\limits_{t=1}^{\infty}\mathbf{\tilde{T}}^t = (I-\mathbf{\tilde{T}})^{-1}-I ~\textrm{where}~ \mathbf{\tilde{T}} = \mathbf{T} - \psi_0\psi_0^T$$

Here, $\mathbf{T}$ is the diffusion operator (defined below), $I$ is the identity matrix, and $\psi_0$ is the first eigenvector of $\mathbf{T}$. You can actually derive all the steps of DPT from these two equations by reading backwards.

The first step is to calculate $\mathbf{\tilde{T}} = \mathbf{T} - \psi_0\psi_0^T$. From this we can reason that we need to learn a graph, calculate a diffusion operator, get the first eigenvector of that diffusion operator, and subtract that from $\mathbf{T}$.

Next, we need to calculate $\mathbf{M}$ using the equation $ \mathbf{M} = \sum\limits_{t=1}^{\infty}\mathbf{\tilde{T}}^t = (I-\mathbf{\tilde{T}})^{-1}-I$. There's a lot going on there, but remember that if $a = b = c$ then $a=c$, so we really only care that $ \mathbf{M} = (I-\mathbf{\tilde{T}})^{-1}-I$.

Once we have $\mathbf{M}$, we can easily get the DPT between any two points $x,y$ by subtracting the relevant rows of $\mathbf{M}$.

#### The algorithm

Having gone through all that, we can break down DPT into the following steps:
1. Learn a graph from the data using local Euclidean distances
2. Subtract the first diffusion operator eigenspace from the transition probabilities
3. Calculate the limit of all random walks on the graph
4. Calculate pairwise distances in the random walk space



### Step 1 - Learn a graph from the counts data 

Here, we're using `graphtools` to learn a graph from data. If you remember from the manifold learning section, the key to learning a graph from data is to calculate _local Euclidean distances_ and then connect cells that are in the same _neighborhood_. We can do this in one line by calling `gt.Graph()`.  We have loaded `graphtools` during the initial `import` statements in this notebook.

The full graphtools API can be found here: https://graphtools.readthedocs.io/en/stable/reference.html#module-graphtools.api

In [None]:
# Learn a graph from the normalized data (n_pca: number of principal components graphtools uses)
G = gt.Graph(data, use_pygsp=True, n_pca=100)

### Step 2 - Subtract the first diffusion operator eigenspace from the transition probabilities

Now, we know that even though eigenvectors are important, they are incredibly abstract and confusing. The more important concept here is to understand the **diffusion operator**. The diffusion operator is so useful that it is calculated whenever you build a graph with `graphtools`.

The diffusion operator is usually defined as "the row-stochastic form of the kernel matrix." The key terms here are:
* kernel matrix, $\mathbf{K}$ - the matrix of size $n \times n$ where $n$ is the number of vertices in the graph. Each entry $\mathbf{K}_{i,j}$ contains the weight of the edge connecting vertex $i$ and vertex $j$.
* row-stochastic - this is a fancy term to describe the fact that each row of a matrix sums to $1$.

A row-stochastic kernel matrix has edge weights scaled such that they could represent the **transition probability** of moving from one vertex to another. This is the heart of the "random-walk" along a graph. If you Google "[graph random walk](https://www.google.com/search?q=graph+random+walk&oq=graph+random+walk)", you'll find a ton of resources on the topic.

Let's compare the row-sums of `G.K`, the kernel matrix, and `G.diff_op`, the diffusion operator.

In [None]:
# Compare row-sums of the kernel matrix K and of the diffusion operator.
# Exercise: the `axis =` line below is intentionally incomplete -- fill it in
# before running this cell.
# ====================
# Which axis do we need to sum across to get the row-sum?
axis =
# ====================
# Kernel matrix: raw edge weights, so row-sums are not expected to be 1
print(G.K.sum(axis=axis))
# Diffusion operator: row-stochastic, so every row should sum to 1
print(G.diff_op.sum(axis=axis))

By default, `G.diff_op` is a sparse matrix, because so many entries are `0`. You can confirm this using `type(G.diff_op)`, which should return a SciPy sparse CSR matrix type (printed as `scipy.sparse.csr.csr_matrix` in older SciPy versions). This means we can calculate the first eigenvector efficiently using `scipy.sparse.linalg.eigs()`.

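To turn an eigenvector $\psi_0$ into an eigenspace, we take its [_outer product_](https://en.wikipedia.org/wiki/Outer_product) with itself, $\psi_0\psi_0^T$, which is an $n \times n$ matrix (e.g. `np.outer(v, v)`). Be careful: for a 1-D NumPy array `v`, `v @ v.T` is an _inner_ product (a single number), because `.T` does nothing to 1-D arrays. We can then subtract this matrix from `G.diff_op`.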

In [None]:
# `eigs` lives in `scipy.sparse.linalg`; k=1 asks for a single eigenpair
# (by default the one whose eigenvalue has the largest magnitude).
W, V = scipy.sparse.linalg.eigs(G.diff_op, k=1)

# W holds the eigenVALUES and V the eigenVECTORS (one per column).
# We only care about the leading eigenVECTOR psi_0 here.
psi_0 = V[:, 0]

# psi_0 psi_0^T is an OUTER product (an n x n matrix). Writing
# `V[:,0] @ V[:,0].T` instead gives an inner product -- a single number --
# because `.T` is a no-op on 1-D arrays. eigs returns complex vectors with an
# arbitrary phase, so conjugate the second factor to cancel that phase.
V_0 = np.outer(psi_0, np.conj(psi_0))

# Now we can subtract this from G.diff_op
T_tilde = G.diff_op.toarray() - V_0

### Step 3 - Calculate $\mathbf{M}$, the limit of all random walk distances on the graph

Here, we're calculating $ \mathbf{M} = \sum\limits_{t=1}^{\infty}\mathbf{\tilde{T}}^t = (I-\mathbf{\tilde{T}})^{-1}-I$. This is a fairly simple calculation, especially now that we have $\tilde{T}$ calculated. The only thing you might not be familiar with here is the $^{-1}$ operation, which is a [_matrix inverse_](http://mathworld.wolfram.com/MatrixInverse.html). All you need to know for now is that you can calculate this with `np.linalg.inv()`.

In [None]:
# Identity matrix with the same size as T_tilde
I = np.eye(T_tilde.shape[1])

# M = (I - T_tilde)^{-1} - I. T_tilde may be complex (eigs returns complex
# eigenvectors), so only the real part of the result is kept.
M = np.real(np.linalg.inv(I - T_tilde) - I)

### Step 4 - Calculate DPT between all pairs of points

This should be the simplest step as we've already gone over how to calculate pairwise distances. Instead of calculating Euclidean distances on the raw data, we're doing it on $\mathbf{M}$.

In [None]:
# DPT(x, y): Euclidean distance between rows x and y of M, labelled by cell
DPT = pd.DataFrame(squareform(pdist(M, metric='euclidean')),
                   index=data.index, columns=data.index)

### Putting it all in a single function

To hammer home how simple these steps are, let's put all of them in one convenience function.

In [None]:
def calculateDiffusionPseudotime(data, n_pca=100):
    '''Calculate diffusion pseudotime (DPT) between all pairs of points.

    Implements the four steps described above:
    1. learn a graph from ``data`` with graphtools,
    2. remove the leading eigenspace from the diffusion operator,
       T_tilde = T - psi_0 psi_0^T,
    3. compute M = (I - T_tilde)^{-1} - I,
    4. take pairwise Euclidean distances between the rows of M.

    Parameters
    ----------
    data : array-like or DataFrame
        Data to build the graph from, one row per point (cell).
    n_pca : int, optional (default: 100)
        Number of principal components passed to ``gt.Graph``.

    Returns
    -------
    DPT : np.ndarray
        Square, symmetric matrix with DPT[i, j] = ||M[i, :] - M[j, :]||.
    '''
    G = gt.Graph(data, n_pca=n_pca, use_pygsp=True)

    # Dense copy of the diffusion operator T (graphtools usually stores it sparse)
    if scipy.sparse.issparse(G.diff_op):
        T = G.diff_op.toarray()
    else:
        T = np.asarray(G.diff_op)

    # Leading eigenvector psi_0 of T (the eigenvalues are not needed)
    _, V = scipy.sparse.linalg.eigs(G.diff_op, k=1)
    psi_0 = V[:, 0]

    # Remove the first eigenspace. psi_0 psi_0^T must be an n x n OUTER product;
    # `psi_0 @ psi_0.T` would be a scalar inner product because `.T` is a no-op
    # on 1-D arrays. eigs returns complex vectors with an arbitrary phase, so
    # the second factor is conjugated to cancel it.
    T_tilde = T - np.outer(psi_0, np.conj(psi_0))

    # M = sum_{t>=1} T_tilde^t = (I - T_tilde)^{-1} - I; the imaginary part
    # left over from the complex eigenvector should be ~0 and is dropped.
    I = np.eye(T_tilde.shape[0])
    M = np.real(np.linalg.inv(I - T_tilde) - I)

    # DPT(x, y) = ||M(x, :) - M(y, :)||
    return squareform(pdist(M))

### Picking a root cell

If you remember, diffusion pseudotime requires the user to input a **root cell**. There is some guidance in the Diffusion Pseudotime paper on how to pick a root cell, but for the most part it requires some prior assumptions about the data. Here, we can return to our simulated data and figure out some way to pick the root.

In [None]:
# PCA embedding colored by (simulated) day of sample collection
scprep.plot.scatter2d(data_pca,
                      c=sample,
                      cmap=sample_cmap,
                      title='Day of Sample Collection',
                      label_prefix='PC',
                      legend=True,
                      legend_title='Day',
                      figsize=(5, 5))

#### Using PC1

On the first branch of the data, it looks like PC1 does a pretty good job of tracking with the developmental progression. Let's take the cell with the lowest PC1 value as the root cell.

In [None]:
# Root cell = the cell with the lowest PC1 value
root_cell = data_pca['PC1'].idxmin()

# All cells, colored by day
ax = scprep.plot.scatter2d(data_pca, c=sample, cmap=sample_cmap,
                           title='Day of Sample Collection',
                           legend=True, legend_title='Day',
                           label_prefix='PC', ticks=None,
                           figsize=(5, 5))

# Overlay the root cell in red, drawn above the other points
ax.scatter(data_pca.loc[root_cell, 'PC1'], data_pca.loc[root_cell, 'PC2'],
           c='red', s=100, zorder=2)

#### Plotting DPT from the root cell to  all other cells

Because $ \textrm{DPT}(x,y) = ||\mathbf{M}(x,:) - \mathbf{M}(y,:)||$, we can now take rows of the `DPT` matrix we calculated earlier to get a pseudotemporal ordering of cells.

In [None]:
# Color every cell by its DPT from the root cell
ax = scprep.plot.scatter2d(data_pca,
                           c=DPT.loc[root_cell],
                           cmap='viridis',
                           title='DPT from root cell',
                           label_prefix='PC',
                           legend=True,
                           ticks=False,
                           figsize=(5.8, 5))

# Mark the root cell in red, drawn above the other points
ax.scatter(data_pca.loc[root_cell, 'PC1'], data_pca.loc[root_cell, 'PC2'],
           c='red', s=100, zorder=2)

#### Comparing DPT to latent `step` variable

Finally, we can compare DPT to the latent variable `step` that was used to generate the data to determine how accurately the method captures ground truth.

**Note:** You will need to add code for the x and y axis.

In [None]:
# Scatter DPT against the ground-truth latent `step`, colored by branch.
ax = scprep.plot.scatter(
            # =========
            # plot diffusion pseudotime on the x-axis and the metadata 'step'
            # variable on the y-axis
            # (exercise: write each as a keyword argument ending in a comma, e.g. `x=...,`)
            x = 
            y = 
            # ========= 
            c=metadata['group'],
            xlabel='Diffusion pseudotime', ylabel='Latent step variable',
            legend_title='Branch', figsize=(5,5))

What conclusions do you draw from this?

### Examining gene trends with DPT

In a later section, we'll go over how to select the genes that are the most correlated with pseudotime, but for now, let's cheat. The `results` returned from `SplatSimulate` contain arrays that tell us how differentially expressed each gene is on each branch. They are `de_fac_1`, `de_fac_2`, and `de_fac_3`.

We're going to use `np.argsort()` to get the indices of the genes with the highest level of differential expression, averaged across all branches.

**Comprehension question:** what does `[::-1]` do?

In [None]:
# Change this to switch which branch we want to look at.
# NOTE(review): `branch_of_interest` is not used by the ranking below, which
# averages over all branches.
branch_of_interest = 1

# One column of DE factors per branch (`de_fac_1`, `de_fac_2`, ...), shifted by
# -1 so that a factor of 1 (presumably "no change" in Splatter) maps to 0.
# Built from `n_paths` so it stays in sync with the simulation settings above.
de_factors = pd.DataFrame(
    [results['de_fac_{}'.format(path)] - 1 for path in range(1, n_paths + 1)]).T

# Rank genes by mean absolute DE factor across branches; np.argsort returns
# gene indices in ascending order of that score.
de_order = np.argsort(de_factors.abs().mean(axis=1).values)
most_de_genes = de_order[::-1]
least_de_genes = de_order

In [None]:
from statsmodels.nonparametric.smoothers_lowess import lowess
import magic

#### Running MAGIC

In [None]:
# Impute/denoise with MAGIC (t=3 is the diffusion time). A fixed random_state
# makes the imputed values reproducible across re-runs.
data_magic = magic.MAGIC(t=3, random_state=42).fit_transform(data)

In [None]:
# How many genes (rows of panels) to show, and which ones. To look at the
# *least* differentially expressed genes instead, change `gene_indices`.
n_genes_to_plot = 5
gene_indices = most_de_genes[:n_genes_to_plot]

# Pseudotime of every cell measured from the root cell (x-axis of the middle
# and right columns); computed once rather than in every panel.
pseudotime = DPT.loc[root_cell]

# squeeze=False keeps `axes` 2-D even when only one row is requested
fig, axes = plt.subplots(n_genes_to_plot, 3, figsize=(12, 2.4 * n_genes_to_plot),
                         squeeze=False)

for ax_row, gene_idx in zip(axes, gene_indices):
    gene = data.iloc[:, gene_idx]
    gene_magic = data_magic.iloc[:, gene_idx]

    ax_left, ax_mid, ax_right = ax_row

    # Left: raw expression grouped by sampling day
    scprep.plot.jitter(sample, gene, c=sample, cmap=sample_cmap,
                       title='Raw - sample',
                       legend=False,
                       ax=ax_left)

    # Middle: raw expression vs. pseudotime
    scprep.plot.scatter(pseudotime, gene,
                        c=sample, cmap=sample_cmap,
                        title='Raw - pseudotime',
                        legend=False,
                        ax=ax_mid)
    # LOWESS trend in red; return_sorted=False keeps the fitted values in the
    # input order so they line up with `pseudotime`
    smooth_expression = lowess(gene, pseudotime, return_sorted=False)
    scprep.plot.scatter(pseudotime, smooth_expression,
                        c='red',
                        ax=ax_mid)

    # Right: MAGIC-imputed expression vs. pseudotime
    scprep.plot.scatter(pseudotime, gene_magic,
                        c=sample, cmap=sample_cmap,
                        title='MAGIC - pseudotime',
                        legend=False,
                        ax=ax_right)
    # LOWESS trend in red
    smooth_expression = lowess(gene_magic, pseudotime, return_sorted=False)
    scprep.plot.scatter(pseudotime, smooth_expression,
                        c='red',
                        ax=ax_right)
    ax_right.set_ylim(0, np.max(gene_magic) * 1.1)


fig.tight_layout()

### Discussion

The above code block plots the 5 most differentially expressed genes on average across all branches.  

In pairs:
1. Duplicate the code block
2. Change the code so the *least* differentially expressed genes are plotted (Hint: which end of `most_de_genes` will these genes be on?)
3. Generate the plot.

What do you notice? Do you see anything subtle happening with the MAGIC values on the least differentially expressed genes? Why do you think this is happening?


If you have time, try adding more rows in the call to `plt.subplots()`. This will take more time to run, but will show you more information.

#### Discussion - Key


<div id='calculating-branches'></div>

## Calculating branches using correlation anti-correlation

The next major step in DPT is identifying branches using the correlation/anti-correlation between forward and reverse pseudotime. We already have forward pseudotime. To get reverse pseudotime we perform the following steps:
1. Identify the "end cell" farthest (in DPT) from the root cell
2. Calculate the distance to the "end cell" for each cell in the dataset

In [None]:
# End cell = the cell farthest (largest DPT) from the root cell
dpt_from_root = DPT.loc[root_cell]
end_cell = dpt_from_root.idxmax()

In [None]:
# Plot forward pseudotime (DPT from the root cell)
ax = scprep.plot.scatter2d(data_pca, c=DPT[root_cell],
                           ticks=None, title='Forward Pseudotime',
                           discrete=False, legend=True, label_prefix='PC',
                           figsize=(4.6, 4))

# Plot the root cell. Label-based `.loc[cell, 'PC1']` replaces `.loc[cell][0]`:
# an integer key on a label-indexed Series relies on pandas' deprecated
# positional fallback.
ax.scatter(data_pca.loc[root_cell, 'PC1'], data_pca.loc[root_cell, 'PC2'],
           c='red', s=100, zorder=10)

# Plot the end cell
ax.scatter(data_pca.loc[end_cell, 'PC1'], data_pca.loc[end_cell, 'PC2'],
           color='#5ac8fa', s=100, zorder=10)

Now we can just take the `end_cell` row of the DPT matrix to get reverse pseudotime.

In [None]:
# Here we take the `end_cell` row of `DPT` (reverse pseudotime)
ax = scprep.plot.scatter2d(data_pca, c=DPT[end_cell],
                           ticks=None, title='Reverse Pseudotime',
                           discrete=False, legend=True, label_prefix='PC',
                           figsize=(4.6, 4))
# Plot the root cell (label-based access instead of the deprecated
# positional fallback `.loc[cell][0]`)
ax.scatter(data_pca.loc[root_cell, 'PC1'], data_pca.loc[root_cell, 'PC2'],
           c='red', s=100, zorder=10)

# Plot the end cell
ax.scatter(data_pca.loc[end_cell, 'PC1'], data_pca.loc[end_cell, 'PC2'],
           color='#5ac8fa', s=100, zorder=10)

#### Plotting forward vs reverse pseudotime

In [None]:
# Forward vs. reverse pseudotime, colored by day
ax = scprep.plot.scatter(DPT[root_cell], DPT[end_cell],
                         c=sample, cmap=sample_cmap,
                         xlabel='Forward pseudotime',
                         ylabel='Reverse pseudotime')

# Equal aspect ratio so both pseudotime axes are drawn on the same scale
ax.set_aspect(1)

#### Clustering in the forward vs. reverse pseudotime space

Now we have a few options for separating the branches. We could manually bin the cells into groups based on the above plot. However, we can also try to use a clustering algorithm to do this for us.

In [None]:
# Pick your favorite clustering algorithm.
# Each cell is described by two features: its DPT from the root cell and from
# the end cell. A fixed random_state makes the cluster labels reproducible.
clusters = sklearn.cluster.SpectralClustering(
    n_clusters=3, random_state=0).fit_predict(DPT[[root_cell, end_cell]])

In [None]:
# Forward vs. reverse pseudotime, colored by cluster label
ax = scprep.plot.scatter(DPT[root_cell], DPT[end_cell],
                         c=clusters,
                         xlabel='Forward pseudotime',
                         ylabel='Reverse pseudotime')
ax.set_aspect(1)

In [None]:
# Branch assignments from the clustering, shown on the PCA embedding
scprep.plot.scatter2d(data_pca,
                      c=clusters,
                      title='DPT Branch Assignments',
                      label_prefix='PC',
                      legend=True,
                      ticks=False,
                      figsize=(5, 5))

### Discussion

In pairs, go back and change:
1. The selection for the root cell (instead of the lowest PC1 value, use the 5th lowest or 10th lowest)
2. The clustering algorithm for branch identification

Answer the following questions:
1. How robust is DPT to selection of root and end cell? How could you quantify this?
2. Is there a better clustering algorithm for branch selection than the spectral clustering used above?


## Bonus: Trajectory inference on more complicated trees.

Above we worked on a very simple dataset with only three branches. Below, we've set up parameters to generate a tree with 9 branches. Try running the code and determine:

1. How would you select the root cell?
2. How well does the correlation/anti-correlation work to identify different branches?


#### Generating data with Splatter

In [None]:
# Splatter parameters: a tree of `n_groups` paths with `cells_per_path` cells
# each (cells_per_path * n_groups cells in total, split evenly by `group_prob`).
# `path_from` lists, for each path, the path it branches from (presumably
# 0 = the starting point, as in Splatter's `path.from`).
n_groups = 9
cells_per_path = 300
params = {'method': 'paths',
          'batch_cells': cells_per_path * n_groups,
          'seed': 42,
          'group_prob': np.tile(1 / n_groups, n_groups),
          'path_from': [0, 1, 1, 2, 2, 3, 3, 4, 4],
          'n_genes': 5000,
          'de_fac_loc': 0.2}
# Keep the tree definition consistent with the number of groups
assert len(params['path_from']) == n_groups

# Run Splatter
results = scprep.run.SplatSimulate(**params)

# Assign data to data frame.
# NOTE(review): this uses 'true_counts' while the first simulation used
# 'counts' -- confirm this is intended.
data = pd.DataFrame(results['true_counts'])

metadata = pd.DataFrame([results['step'], results['group']], index=['step', 'group']).T
metadata = metadata.astype({'step': int})

# Reindex rows as cell_0, cell_1, ...
new_index = pd.Index(['cell_{}'.format(i) for i in range(metadata.shape[0])])
data.index = new_index
metadata.index = new_index

# Library-size normalize
data_ln = scprep.normalize.library_size_normalize(data)

#### Visualizing with PCA

Given what you know about PCA, why does it perform poorly here?

In [None]:
# PCA on sqrt-transformed, library-size-normalized counts
data_pca = scprep.reduce.pca(np.sqrt(data_ln))

In [None]:
# PCA embedding of the tree, colored by simulated group
scprep.plot.scatter2d(data_pca, legend_anchor=(1, 1), c=metadata['group'])

#### Visualizing with PHATE

In [None]:
# PHATE embedding of the PCA coordinates, indexed by cell. random_state=42
# (the same seed as the simulation) makes the embedding reproducible.
phate_op = phate.PHATE(random_state=42)
data_phate = pd.DataFrame(phate_op.fit_transform(data_pca), index=data.index)

In [None]:
# `data_phate` is already a DataFrame indexed by cell, so re-wrapping it would
# be a no-op; verify the index instead.
assert data_phate.index.equals(data.index)

In [None]:
# PHATE embedding of the tree, colored by simulated group
scprep.plot.scatter2d(data_phate, legend_anchor=(1, 1), c=metadata['group'])

#### Identifying the root cell

Here we're cheating and using the ground truth to identify this cell. What would you do in a biological dataset?

In [None]:
# Root = the cell with the smallest ground-truth `step` on path 1
root_idx = metadata.loc[metadata['group'] == 1, 'step'].idxmin()

#### Calculating DPT



In [None]:
def calculateDiffusionPseudotime(data, n_pca=100):
    '''Calculate diffusion pseudotime (DPT) between all pairs of points.

    NOTE(review): this redefines `calculateDiffusionPseudotime` from earlier in
    the notebook; keep the two definitions in sync (or define it only once).

    Steps: learn a graphtools graph from ``data``, remove the leading
    eigenspace from the diffusion operator (T_tilde = T - psi_0 psi_0^T),
    compute M = (I - T_tilde)^{-1} - I, and take pairwise Euclidean distances
    between the rows of M.

    Parameters
    ----------
    data : array-like or DataFrame
        Data to build the graph from, one row per point (cell).
    n_pca : int, optional (default: 100)
        Number of principal components passed to ``gt.Graph``.

    Returns
    -------
    DPT : np.ndarray
        Square, symmetric matrix with DPT[i, j] = ||M[i, :] - M[j, :]||.
    '''
    G = gt.Graph(data, n_pca=n_pca, use_pygsp=True)

    # Dense copy of the diffusion operator T (graphtools usually stores it sparse)
    if scipy.sparse.issparse(G.diff_op):
        T = G.diff_op.toarray()
    else:
        T = np.asarray(G.diff_op)

    # Leading eigenvector psi_0 of T (the eigenvalues are not needed)
    _, V = scipy.sparse.linalg.eigs(G.diff_op, k=1)
    psi_0 = V[:, 0]

    # Remove the first eigenspace. psi_0 psi_0^T must be an n x n OUTER product;
    # `psi_0 @ psi_0.T` would be a scalar inner product because `.T` is a no-op
    # on 1-D arrays. eigs returns complex vectors with an arbitrary phase, so
    # the second factor is conjugated to cancel it.
    T_tilde = T - np.outer(psi_0, np.conj(psi_0))

    # M = sum_{t>=1} T_tilde^t = (I - T_tilde)^{-1} - I; the imaginary part
    # left over from the complex eigenvector should be ~0 and is dropped.
    I = np.eye(T_tilde.shape[0])
    M = np.real(np.linalg.inv(I - T_tilde) - I)

    # DPT(x, y) = ||M(x, :) - M(y, :)||
    return squareform(pdist(M))

In [None]:
# Build the DPT graph on the same preprocessed data as the PCA above
# (library-size normalization + sqrt), not on raw counts, whose structure would
# be dominated by differences in library size.
DPT = calculateDiffusionPseudotime(np.sqrt(data_ln))

# Label rows and columns by cell
DPT = pd.DataFrame(DPT, index=data.index, columns=data.index)

In [None]:
# PHATE embedding colored by DPT from the root cell; the root is marked in red
ax = scprep.plot.scatter2d(data_phate, legend_anchor=(1, 1), c=DPT[root_idx])
ax.scatter(data_phate.loc[root_idx, 0], data_phate.loc[root_idx, 1],
           color='red', s=100, zorder=10)

## Calculating reverse pseudotime and branches

Here, you should adapt the code from the **"Calculating branches using correlation anti-correlation"** to calculate reverse pseudotime.

In [None]:
# ========
# Exercise: fill in every `??` / `????` (and `end_cell`) before running.
# In this section the root cell is `root_idx`, not `root_cell`.
end_cell = 



#Here we take the `end_cell` row of `DPT`
ax = scprep.plot.scatter2d(data_pca, c = ????, 
                      ticks=None, title='Reverse Pseudotime', 
                      discrete=False, legend=True, label_prefix='PC',
                      figsize=(4.6,4))
# Plot the root cell
ax.scatter(??, ??, c ='red', s=100, zorder=10)

# Plot the end cell
ax.scatter(??, ??, color='#5ac8fa', s=100, zorder=10)


# ========

In [None]:
# ========
# Plotting forward vs reverse pseudotime

ax = scprep.plot.scatter(??, ??, 
                         c=sample, cmap=sample_cmap,
                        xlabel='Forward psuedotime',
                        ylabel='Reverse pseudotime')

ax.set_aspect(1)
# ========

In [None]:
# Distribution of DPT (from the root cell) within each simulated group
scprep.plot.jitter(metadata['group'], DPT[root_idx],
                   legend=False, c=metadata['group'])

### Discussion

1. What do you notice about the distribution of DPT values in this tree?
2. Try adding more or fewer branches. How does this change the resolution of DPT?