# CellRank 2: Unified Fate Mapping in Multiview Single-Cell Data

## Workshop Tutorial — Informatics Club

---

### What is CellRank 2?

**CellRank 2** is a framework for studying **cellular fate decisions** from single-cell data. It models cell-state transitions as a **Markov chain** on a cell-cell graph and identifies:

- **Initial states** (cells at the beginning of a process)
- **Terminal states / macrostates** (fate endpoints)
- **Fate probabilities** (likelihood each cell reaches each terminal state)
- **Driver genes** correlated with specific fate decisions

### Key innovation of CellRank 2 over CellRank 1

CellRank 1 was built primarily around **RNA velocity** to provide directionality. CellRank 2 introduces a **modular kernel framework** that can incorporate diverse sources of directional information:

| Kernel | Input | Use Case |
|--------|-------|----------|
| `VelocityKernel` | RNA velocity vectors | Standard scRNA-seq with splicing info |
| `PseudotimeKernel` | Pseudotime ordering | When a pseudotime is available |
| `RealTimeKernel` | Experimental time labels + optimal transport | Time-series experiments |
| `CytoTRACEKernel` | CytoTRACE scores (gene counts as proxy for potency) | When no velocity/time is available |

Kernels can be **combined** via weighted sums or products to integrate multiple signals.

---

## Part 0: Mathematical Background

### The Markov Chain Model

CellRank models cell-state dynamics as a **Markov chain** on a KNN graph of cells. Each cell is a node, and directed, weighted edges represent transition probabilities.

The **transition matrix** $T \in \mathbb{R}^{n \times n}$ is row-stochastic:

$$T_{ij} = P(X_{t+1} = j \mid X_t = i), \quad \sum_j T_{ij} = 1$$

### How Kernels Build the Transition Matrix

#### Velocity Kernel
For each cell $i$ with velocity vector $v_i$, the transition probability to neighbor $j$ is based on the **cosine similarity** between $v_i$ and the displacement vector $(x_j - x_i)$:

$$\tilde{T}_{ij} \propto \exp\left(\frac{\cos(v_i, x_j - x_i)}{\sigma}\right)$$

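where $\sigma > 0$ acts as a temperature. Smaller values make the softmax sharper and concentrate probability on the neighbors best aligned with $v_i$.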
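To make the formula concrete, here is a minimal NumPy sketch of a velocity kernel on a toy graph. It uses plain cosine similarity, a fixed `sigma`, and hand-written neighbor lists. The function name `velocity_transition_matrix` is hypothetical. CellRank's own implementation works on sparse KNN graphs and may differ in details such as the similarity measure and how the softmax scale is chosen.

```python
import numpy as np

def velocity_transition_matrix(X, V, neighbors, sigma=0.5):
    """Toy velocity kernel: softmax over cosine similarities, restricted to neighbors.

    X: (n_cells, n_dims) cell states; V: (n_cells, n_dims) velocity vectors;
    neighbors: list of index lists, the KNN neighbors of each cell.
    """
    n = X.shape[0]
    T = np.zeros((n, n))
    for i in range(n):
        nbrs = np.asarray(neighbors[i])
        disp = X[nbrs] - X[i]  # displacement vectors x_j - x_i
        cos = disp @ V[i] / (np.linalg.norm(disp, axis=1) * np.linalg.norm(V[i]) + 1e-12)
        logits = cos / sigma
        w = np.exp(logits - logits.max())  # numerically stable softmax numerator
        T[i, nbrs] = w / w.sum()  # row-normalize over the neighbors
    return T

# Three cells on a line, all moving in the +x direction
X = np.array([[0.0, 0.0], [1.0, 0.0], [-1.0, 0.0]])
V = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])
neighbors = [[1, 2], [0, 2], [0, 1]]
T = velocity_transition_matrix(X, V, neighbors)
print(T.round(3))
```

Cell 0 sends most of its probability to cell 1, which lies in the direction its velocity points. Lowering `sigma` pushes that probability even closer to 1.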

#### Pseudotime Kernel
Uses a **pseudotime** $\tau_i$ assigned to each cell. The kernel biases transitions toward increasing pseudotime:

$$\tilde{T}_{ij} \propto \begin{cases} \text{high} & \text{if } \tau_j > \tau_i \\ \text{low} & \text{if } \tau_j < \tau_i \end{cases}$$

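Specifically, with the `soft` threshold scheme, edges to neighbors with lower pseudotime are smoothly down-weighted according to the pseudotime difference rather than removed. The `hard` scheme removes most such edges.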
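The sketch below shows soft down-weighting in NumPy. It assumes a hypothetical logistic function $f(d) = 2 / (1 + e^{b d})$, where $d = \tau_i - \tau_j > 0$ marks a backward edge, and applies it to a toy symmetric graph. CellRank's `soft` scheme follows the same idea with its own parameterization, so treat the function and the parameter `b` as illustrative assumptions.

```python
import numpy as np

def soft_pseudotime_transitions(W, tau, b=10.0):
    """Bias a symmetric KNN graph toward increasing pseudotime.

    W:   (n, n) symmetric KNN connectivities (zero diagonal).
    tau: (n,) pseudotime per cell.
    b:   steepness of the hypothetical logistic down-weighting.
    """
    d = tau[:, None] - tau[None, :]  # d[i, j] = tau_i - tau_j (> 0 means j is earlier)
    z = np.clip(b * d, -50.0, 50.0)  # avoid overflow in exp
    f = np.where(d > 0, 2.0 / (1.0 + np.exp(z)), 1.0)  # backward edges shrink toward 0
    K = W * f
    return K / K.sum(axis=1, keepdims=True)  # row-normalize into transition probabilities

tau = np.array([0.0, 0.5, 1.0])
W = np.ones((3, 3)) - np.eye(3)  # fully connected toy graph
T = soft_pseudotime_transitions(W, tau)
print(T.round(3))
```

Backward edges keep a small nonzero weight. This is the practical difference from a hard cut-off: the graph stays connected even where the pseudotime is noisy.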

#### RealTimeKernel (Optimal Transport)
Given cells at discrete time points $t_1, t_2, \ldots$, **optimal transport** (via the `moscot` package) computes a coupling matrix $\pi$ that maps cells at $t_k$ to cells at $t_{k+1}$:

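$$\pi^* = \arg\min_{\pi \in \Pi(\mu_k, \mu_{k+1})} \sum_{i,j} c(x_i, x_j)\, \pi_{ij} - \varepsilon H(\pi), \qquad H(\pi) = -\sum_{i,j} \pi_{ij} \log \pi_{ij}$$

where:
- $c(x_i, x_j)$ is a cost (e.g., squared Euclidean distance in gene expression space).
- $H(\pi)$ is the entropy of the coupling, and $\varepsilon > 0$ sets the strength of the entropic regularization.
- $\Pi(\mu_k, \mu_{k+1})$ is the set of couplings whose marginals match the cell distributions at the two time points.

The couplings between consecutive time points are then assembled into a single cell-cell transition matrix over all time points. This matrix can be mixed with within-time-point transitions from the KNN graph.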
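For intuition, the standard Sinkhorn algorithm solves this entropic problem with uniform marginals. The minimal NumPy sketch below does that and then row-normalizes the coupling into transition probabilities from early to late cells. It is the textbook algorithm, not moscot's solver: moscot uses its own OTT-JAX-based solvers and also supports unbalanced variants. The cost scaling and `eps` are arbitrary choices for the toy data.

```python
import numpy as np

def sinkhorn(C, mu, nu, eps=0.2, n_iter=2000):
    """Entropic OT via Sinkhorn iterations; returns the coupling pi."""
    K = np.exp(-C / eps)  # Gibbs kernel
    u = np.ones_like(mu)
    v = np.ones_like(nu)
    for _ in range(n_iter):
        u = mu / (K @ v)    # match row marginals
        v = nu / (K.T @ u)  # match column marginals
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
x_early = rng.normal(size=(5, 2))           # cells at time t_k
x_late = rng.normal(loc=1.0, size=(4, 2))   # cells at time t_{k+1}
C = ((x_early[:, None, :] - x_late[None, :, :]) ** 2).sum(axis=-1)  # squared Euclidean cost
C = C / C.max()                             # rescale cost to [0, 1]
mu = np.full(5, 1 / 5)                      # uniform marginal at t_k
nu = np.full(4, 1 / 4)                      # uniform marginal at t_{k+1}

pi = sinkhorn(C, mu, nu)
T_between = pi / pi.sum(axis=1, keepdims=True)  # early -> late transition probabilities
print(T_between.round(3))
```

Smaller `eps` gives sharper, closer-to-deterministic couplings but needs more iterations to converge.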

#### CytoTRACE Kernel
Uses **CytoTRACE** scores as a proxy for developmental potential (based on the number of expressed genes). This kernel biases transitions from high-potency to low-potency cells, acting like a pseudotime kernel but without needing a pre-computed pseudotime.
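The core signal is easy to compute. The sketch below counts the expressed genes per cell, rescales the count to $[0, 1]$, and flips it into a pseudotime-like ordering. This is a deliberate simplification: the published CytoTRACE method builds further steps on top of this count, and CellRank's implementation differs. The function name `gene_count_potency` is hypothetical.

```python
import numpy as np

def gene_count_potency(X):
    """Simplified potency proxy from the number of expressed genes per cell.

    X: (n_cells, n_genes) expression matrix. Returns (score, pseudotime), both in [0, 1].
    Not the full CytoTRACE algorithm.
    """
    n_genes = (X > 0).sum(axis=1).astype(float)
    span = n_genes.max() - n_genes.min()
    score = (n_genes - n_genes.min()) / span if span > 0 else np.zeros_like(n_genes)
    pseudotime = 1.0 - score  # high potency -> early in the process
    return score, pseudotime

X = np.array([
    [1.0, 2.0, 0.5, 0.0],  # 3 expressed genes: most progenitor-like
    [0.0, 1.0, 3.0, 0.0],  # 2 expressed genes
    [0.0, 0.0, 4.0, 0.0],  # 1 expressed gene: most differentiated
])
score, pseudotime = gene_count_potency(X)
print(score, pseudotime)
```

A pseudotime-style kernel can then use `pseudotime` exactly as in the pseudotime kernel above.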

### Kernel Combination
Multiple kernels can be combined:

$$T_{\text{combined}} = \alpha \cdot T_{\text{kernel}_1} + (1 - \alpha) \cdot T_{\text{kernel}_2}$$
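Because each kernel's matrix is row-stochastic, any convex combination ($0 \le \alpha \le 1$) is row-stochastic too. The small NumPy check below mixes a directed matrix with an undirected diffusion matrix. The matrices are made up for illustration.

```python
import numpy as np

def combine(T1, T2, alpha):
    """Convex combination of two row-stochastic transition matrices."""
    return alpha * T1 + (1.0 - alpha) * T2

T1 = np.array([[0.0, 1.0, 0.0],
               [0.0, 0.0, 1.0],
               [0.0, 0.0, 1.0]])   # strongly directed: 0 -> 1 -> 2
T2 = np.array([[0.0, 0.5, 0.5],
               [0.5, 0.0, 0.5],
               [0.5, 0.5, 0.0]])   # undirected diffusion
T = combine(T1, T2, alpha=0.8)
print(T)
```

Row 0 becomes `[0.0, 0.9, 0.1]`. The direction is preserved, while the diffusion term keeps a small probability of moving elsewhere.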

### GPCCA for Macrostate Identification

CellRank 2 uses **Generalized Perron Cluster Cluster Analysis (GPCCA)** on the transition matrix to identify **macrostates** — groups of cells that are metastable (cells within a macrostate tend to stay there).

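This works by computing a partial **real Schur decomposition** of $T$:

$$T Q = Q R$$

Here $Q \in \mathbb{R}^{n \times k}$ has orthonormal columns (the Schur vectors), and $R \in \mathbb{R}^{k \times k}$ is quasi-upper-triangular with the $k$ leading eigenvalues of $T$ on its (block) diagonal. Eigenvectors of a non-reversible chain can be complex. Schur vectors, by contrast, stay real and orthonormal, which makes them a stable basis for velocity-derived chains.

The leading Schur vectors reveal the coarse-grained structure, and GPCCA then soft-assigns cells to macrostates. **Terminal states** are macrostates with a high coarse-grained self-transition probability (nearly absorbing).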
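A toy chain with two metastable blocks shows why eigenvalues near 1 signal macrostates. Cells mix quickly within each block of three and leak between blocks only with probability proportional to `eps`. The spectrum is then $\{1, 1 - \varepsilon, 0, 0, 0, 0\}$, and the largest gap sits after the second eigenvalue. For brevity this sketch uses plain eigenvalues. CellRank computes a real Schur decomposition instead, as described above.

```python
import numpy as np

eps = 0.05
J3 = np.full((3, 3), 1 / 3)   # fast mixing within a block of 3 cells
Z = np.zeros((3, 3))
T = (1 - eps) * np.block([[J3, Z], [Z, J3]]) + eps * np.full((6, 6), 1 / 6)

eigvals = np.sort(np.linalg.eigvals(T).real)[::-1]  # descending real parts
gaps = eigvals[:-1] - eigvals[1:]
n_macrostates = int(np.argmax(gaps)) + 1            # eigengap heuristic
print(eigvals.round(3), n_macrostates)
```

The eigengap picks 2 macrostates, matching the two blocks.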

### Absorption Probabilities (Fate Probabilities)

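Once terminal states are identified, CellRank computes **absorption probabilities**: for each non-terminal (transient) cell, the probability of eventually reaching each terminal state.

Let $T_{\text{transient}}$ be the block of $T$ among transient cells, and let $T_{\text{transient} \to m}$ be the block from transient cells to the cells of terminal state $m$. The absorption probabilities then solve the linear system

$$(I - T_{\text{transient}})\, \mathbf{a}_m = T_{\text{transient} \to m} \mathbf{1}, \qquad \text{i.e.} \qquad \mathbf{a}_m = (I - T_{\text{transient}})^{-1} T_{\text{transient} \to m} \mathbf{1}$$

where $\mathbf{a}_m$ is the vector of absorption probabilities into terminal state $m$, and $\mathbf{1}$ is the all-ones vector, which sums over the cells of state $m$. In practice the system is solved directly rather than by forming the inverse.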
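A five-state example makes the linear system concrete. States 0 to 2 are transient, and states 3 and 4 are two absorbing terminal states. The sketch solves one system per terminal state with `np.linalg.solve`. Each right-hand side is the row sum of the transient-to-terminal block, i.e. that block multiplied by the all-ones vector. The chain itself is made up for illustration.

```python
import numpy as np

# States 0-2 are transient; state 3 = terminal "A", state 4 = terminal "B" (absorbing)
T = np.array([
    [0.0, 0.5, 0.5, 0.0, 0.0],
    [0.2, 0.0, 0.0, 0.8, 0.0],
    [0.2, 0.0, 0.0, 0.0, 0.8],
    [0.0, 0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 0.0, 1.0],
])
transient = [0, 1, 2]
terminal = {"A": [3], "B": [4]}

Q = T[np.ix_(transient, transient)]  # transient -> transient block
I = np.eye(len(transient))
fate = {}
for name, idx in terminal.items():
    R_m = T[np.ix_(transient, idx)]  # transient -> terminal-state block
    fate[name] = np.linalg.solve(I - Q, R_m.sum(axis=1))

print(fate)
```

Cell 0 sits symmetrically between the two fates and gets 0.5 for each. Cell 1, one step away from A, gets 0.9 for A. For every transient cell, the fate probabilities sum to 1.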

---

## Part 1: Setup and Data Loading

### Environment Setup

```bash
# Create and activate the conda environment:
conda env create -f environment.yml
conda activate cellrank2_workshop
```

```python
import cellrank as cr
import scanpy as sc
import scvelo as scv
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Logging verbosity and default figure style
scv.settings.verbosity = 3
cr.settings.verbosity = 2
sc.settings.set_figure_params(frameon=False, dpi=100)
```

### Load the Pancreatic Endocrinogenesis Dataset

This dataset from **Bastidas-Ponce et al. (2019)** profiles mouse pancreatic development using scRNA-seq.

- **E15.5 pancreas** — endocrine progenitor cells differentiating into:
  - **Alpha cells** (glucagon-producing)
  - **Beta cells** (insulin-producing)
  - **Delta cells** (somatostatin-producing)
  - **Epsilon cells** (ghrelin-producing)

This is a well-characterized system ideal for benchmarking trajectory inference methods.

```python
# Load the pancreas dataset (downloaded on first use)
adata = cr.datasets.pancreas()
adata
```

```python
# Examine the cell type annotations
print("Cell types:", adata.obs["clusters"].unique().tolist())
print("\nCell type counts:")
print(adata.obs["clusters"].value_counts())
```

```python
# Fraction of spliced vs. unspliced counts per cluster
scv.pl.proportions(adata, groupby="clusters")
```

---

## Part 2: Preprocessing and RNA Velocity

### Preprocessing

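Standard scRNA-seq preprocessing: filter genes, normalize, log-transform, and select highly variable genes. Then compute PCA, the KNN graph, and the first- and second-order moments that scVelo needs. The plots below use the UMAP embedding stored in `adata.obsm["X_umap"]`.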

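```python
# Standard scVelo preprocessing pipeline
scv.pp.filter_and_normalize(adata, min_shared_counts=20, n_top_genes=2000)
sc.tl.pca(adata)
sc.pp.neighbors(adata, n_pcs=30, n_neighbors=30)
# First/second-order moments, reusing the KNN graph computed above
scv.pp.moments(adata, n_pcs=None, n_neighbors=None)
# Compute a UMAP only if the dataset does not already provide one
if "X_umap" not in adata.obsm:
    sc.tl.umap(adata)
```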

```python
# Visualize the UMAP embedding colored by cell type
scv.pl.scatter(
    adata,
    basis="umap",
    color="clusters",
    legend_loc="right margin",
    title="Pancreatic Endocrinogenesis",
)
```

### RNA Velocity with the Dynamical Model

RNA velocity estimates the **rate of change** of gene expression by modeling splicing kinetics:

$$\frac{du}{dt} = \alpha(t) - \beta u(t)$$
$$\frac{ds}{dt} = \beta u(t) - \gamma s(t)$$

where:
- $u$ = unspliced mRNA abundance
- $s$ = spliced mRNA abundance
- $\alpha$ = transcription rate
- $\beta$ = splicing rate
- $\gamma$ = degradation rate

The **velocity** is defined as:

$$v = \frac{ds}{dt} = \beta u - \gamma s$$

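scVelo's **dynamical model** fits these kinetics for each gene. It recovers gene-specific rates ($\alpha$, $\beta$, $\gamma$), the switching time between induction and repression, and a latent time for every cell. The transcription rate $\alpha(t)$ is piecewise constant: high during induction and low during repression.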
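Integrating the two equations with forward Euler shows what velocity measures. During induction, unspliced mRNA accumulates ahead of spliced mRNA, so $v = \beta u - \gamma s > 0$. At steady state, $u^* = \alpha / \beta$, $s^* = \alpha / \gamma$, and $v = 0$. The sketch uses constant, arbitrarily chosen rates and only illustrates the ODE; it is not how scVelo fits the model.

```python
import numpy as np

def simulate_splicing(alpha=1.0, beta=1.0, gamma=0.5, dt=0.01, t_max=20.0):
    """Forward-Euler integration of du/dt = alpha - beta*u, ds/dt = beta*u - gamma*s."""
    u, s = 0.0, 0.0
    for _ in range(int(round(t_max / dt))):
        du = alpha - beta * u
        ds = beta * u - gamma * s
        u += dt * du
        s += dt * ds
    return u, s

beta, gamma = 1.0, 0.5
u_early, s_early = simulate_splicing(t_max=1.0)   # early in induction
u_late, s_late = simulate_splicing(t_max=20.0)    # close to steady state
v_early = beta * u_early - gamma * s_early
v_late = beta * u_late - gamma * s_late
print(round(v_early, 3), round(v_late, 5))
```

Near steady state the velocity vanishes. This is why velocity carries directional information mainly for genes that are still being induced or repressed.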

```python
# Fit the dynamical model (this takes a few minutes)
scv.tl.recover_dynamics(adata, n_jobs=8)
```

```python
# Compute velocity and the velocity graph
scv.tl.velocity(adata, mode="dynamical")
scv.tl.velocity_graph(adata)
```

```python
# Visualize the velocity streamlines on UMAP
scv.pl.velocity_embedding_stream(
    adata,
    basis="umap",
    color="clusters",
    legend_loc="right margin",
    title="RNA Velocity Streamlines",
)
```

```python
# Also compute latent time from the dynamical model
scv.tl.latent_time(adata)
scv.pl.scatter(adata, color="latent_time", color_map="gnuplot", size=80)
```

---

## Part 3: CellRank 2 — Kernel-Based Transition Matrix

### The VelocityKernel

The `VelocityKernel` converts RNA velocity vectors into a cell-cell **transition matrix**.

For each cell $i$ with velocity $v_i$, the transition probability to neighbor $j$ depends on how well the velocity "points toward" $j$:

$$\tilde{p}_{ij} \propto \exp\left(\frac{\cos\angle(v_i,\, x_j - x_i)}{\sigma}\right)$$

The resulting matrix is row-normalized to produce valid transition probabilities.

```python
from cellrank.kernels import VelocityKernel

# Build the velocity-based transition matrix
vk = VelocityKernel(adata)
vk.compute_transition_matrix()
print(vk)
```

### The ConnectivityKernel

The `ConnectivityKernel` is built from the **symmetric** KNN graph connectivities and carries no directional information. It acts as a **diffusion** component and is useful for smoothing the velocity signal. Row normalization generally makes the resulting transition matrix itself non-symmetric.

$$T^{\text{conn}}_{ij} = \frac{w_{ij}}{\sum_k w_{ik}}$$

where $w_{ij}$ are the KNN graph edge weights.
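A three-cell example shows the row normalization, and shows why a symmetric $W$ does not give a symmetric $T^{\text{conn}}$. Cell 0 has two neighbors and splits its probability between them, while cells 1 and 2 each have only cell 0. The weights are made up for illustration.

```python
import numpy as np

W = np.array([[0.0, 1.0, 1.0],
              [1.0, 0.0, 0.0],
              [1.0, 0.0, 0.0]])  # symmetric KNN connectivities
T_conn = W / W.sum(axis=1, keepdims=True)  # T_ij = w_ij / sum_k w_ik
print(T_conn)
```

Here $T^{\text{conn}}_{01} = 0.5$ but $T^{\text{conn}}_{10} = 1$: the graph is undirected, yet the random walk is not symmetric.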

```python
from cellrank.kernels import ConnectivityKernel

# Build the undirected, diffusion-like transition matrix
ck = ConnectivityKernel(adata)
ck.compute_transition_matrix()
print(ck)
```

### Combining Kernels

Combine velocity-based directionality with connectivity-based smoothing:

$$T_{\text{combined}} = 0.8 \cdot T_{\text{velocity}} + 0.2 \cdot T_{\text{connectivity}}$$

The 80/20 weighting prioritizes velocity while allowing some diffusion to smooth noise.

```python
# Weighted sum of kernels: 80% velocity, 20% connectivity
combined_kernel = 0.8 * vk + 0.2 * ck
print(combined_kernel)
```

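### Visualize the Transition Matrix as a Projected Streamplot

We can visualize the transition matrix by projecting it onto the embedding as a **streamplot**. Each cell's arrow summarizes where its outgoing transition probabilities point in UMAP space. The result looks like an RNA velocity stream, but it reflects the combined kernel rather than velocity alone.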
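One way to turn transition probabilities into arrows is to average the unit displacement vectors to the other cells. Each vector is weighted by the transition probability minus a uniform baseline, so a cell with no preferred direction gets a zero arrow. The sketch follows the approach scVelo uses for velocity embeddings. Whether CellRank's `plot_projection` matches it in every detail is an assumption, and `project_transitions` is a hypothetical name. For simplicity it uses a dense matrix in which every other cell counts as a neighbor.

```python
import numpy as np

def project_transitions(T, emb):
    """Project a dense transition matrix onto a 2D embedding (scVelo-style sketch)."""
    n = emb.shape[0]
    proj = np.zeros_like(emb, dtype=float)
    for i in range(n):
        others = np.array([j for j in range(n) if j != i])
        disp = emb[others] - emb[i]
        disp /= np.linalg.norm(disp, axis=1, keepdims=True)  # unit directions
        weights = T[i, others] - 1.0 / len(others)            # subtract uniform baseline
        proj[i] = weights @ disp
    return proj

emb = np.array([[0.0, 0.0], [1.0, 0.0], [-1.0, 0.0], [0.0, 1.0]])
T_uniform = (np.ones((4, 4)) - np.eye(4)) / 3   # no preferred direction
T_directed = T_uniform.copy()
T_directed[0] = [0.0, 1.0, 0.0, 0.0]            # cell 0 always moves to cell 1

proj_uniform = project_transitions(T_uniform, emb)
proj_directed = project_transitions(T_directed, emb)
print(proj_directed[0])
```

Uniform transitions give zero arrows everywhere. Cell 0's arrow in `T_directed` points toward cell 1, i.e. in the $+x$ direction.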

```python
# Project the combined transition matrix onto the UMAP as a streamplot
combined_kernel.plot_projection(basis="umap", color="clusters")
```

---

## Part 4: Estimators — Identifying Terminal & Initial States

### GPCCA Estimator

The **GPCCA** (Generalized Perron Cluster Cluster Analysis) estimator identifies **macrostates** from the transition matrix.

Steps:
1. Compute the **Schur decomposition** of $T$ to find the leading eigenvalues/vectors
2. Use the **eigenvalue gap** to determine the number of macrostates
3. Soft-assign cells to macrostates
4. Classify macrostates as **terminal** (high coarse-grained self-transition probability, i.e. stable) or **initial** (low weight in the coarse-grained stationary distribution)
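Steps 3 and 4 can be sketched on a toy chain with two blocks of cells. Block 0 (progenitor-like) leaks 20% of its mass to block 1 at each step, while block 1 almost never leaves. The sketch coarse-grains $T$ with the GPCCA formula $T_c = (\chi^\top D \chi)^{-1} \chi^\top D T \chi$, using hard memberships $\chi$ and uniform weights $D$ (real GPCCA memberships are soft). It then calls a macrostate terminal if its self-transition is at least 0.96, a threshold chosen for illustration. The initial state is taken as the macrostate with the lowest coarse stationary weight, which is our reading of the criterion; treat both rules as assumptions.

```python
import numpy as np

A, B = slice(0, 3), slice(3, 6)
T = np.zeros((6, 6))
T[A, A], T[A, B] = 0.80 / 3, 0.20 / 3   # block 0 leaks toward block 1
T[B, B], T[B, A] = 0.99 / 3, 0.01 / 3   # block 1 is nearly absorbing

chi = np.zeros((6, 2))
chi[:3, 0] = 1.0
chi[3:, 1] = 1.0                        # hard memberships (GPCCA uses soft ones)
D = np.diag(np.full(6, 1 / 6))          # uniform weighting of cells

T_coarse = np.linalg.solve(chi.T @ D @ chi, chi.T @ D @ T @ chi)

# Terminal: stable macrostates (high coarse self-transition)
terminal = np.flatnonzero(np.diag(T_coarse) >= 0.96)

# Initial: lowest weight in the coarse stationary distribution
evals, evecs = np.linalg.eig(T_coarse.T)
stat = np.real(evecs[:, np.argmin(np.abs(evals - 1.0))])
stat = stat / stat.sum()
initial = int(np.argmin(stat))
print(T_coarse.round(3), terminal, initial)
```

The coarse matrix comes out as `[[0.8, 0.2], [0.01, 0.99]]`, so block 1 is terminal and block 0 is initial.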

```python
from cellrank.estimators import GPCCA

# Initialize the estimator on the combined kernel
g = GPCCA(combined_kernel)
print(g)
```

### Compute Schur Decomposition

The Schur decomposition reveals the coarse-grained structure. We look at the **eigenvalue spectrum** to determine the number of macrostates. Eigenvalues close to 1 indicate metastable states.

```python
# Compute the 20 leading Schur vectors/eigenvalues and plot the spectrum
g.compute_schur(n_components=20)
g.plot_spectrum()
plt.show()
```

The **eigengap** (largest gap in the real part of eigenvalues) suggests the optimal number of macrostates. Look for a clear drop in the eigenvalue spectrum.

```python
# Show only the real parts of the eigenvalues
g.plot_spectrum(real_only=True)
```

### Compute Macrostates

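Based on the eigenvalue spectrum, select the number of macrostates. For this pancreas dataset we use 6 macrostates. That is enough to resolve the four endocrine fates (Alpha, Beta, Delta, Epsilon) while leaving room for progenitor populations such as Ductal and Ngn3 EP cells. Adjust `n_states` if the eigengap suggests a different number.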

```python
# Compute 6 macrostates, named after the dominant cluster in each
g.compute_macrostates(n_states=6, cluster_key="clusters")
g.plot_macrostates(which="all", basis="umap", legend_loc="right margin", s=100)
```

```python
# View the coarse-grained transition matrix between macrostates
g.plot_coarse_T()
```

### Classify Terminal States

Terminal states are macrostates that cells tend to "absorb into" — once a cell reaches a terminal state, it is unlikely to leave. CellRank identifies these by analyzing the coarse-grained dynamics.

```python
# Classify stable macrostates as terminal states
g.predict_terminal_states()
g.plot_macrostates(which="terminal", basis="umap", legend_loc="right margin", s=100)
```

```python
# Also identify initial states
g.predict_initial_states(allow_overlap=True)
g.plot_macrostates(which="initial", basis="umap", legend_loc="right margin", s=100)
```

---

## Part 5: Fate Probabilities

### Computing Absorption Probabilities

For each cell, compute the probability of being absorbed into each terminal state. These are the **fate probabilities** — the key output of CellRank.

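Mathematically, for terminal state $m$, the absorption probability vector $\mathbf{a}_m$ satisfies:

$$(I - T_{\text{transient}})\, \mathbf{a}_m = T_{\text{transient} \to m} \mathbf{1}$$

where $\mathbf{1}$ is the all-ones vector. The right-hand side is therefore each transient cell's one-step probability of entering terminal state $m$.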

```python
# Absorption probabilities of every cell toward each terminal state
g.compute_fate_probabilities()
print(g.fate_probabilities)
```

```python
# Plot fate probabilities on UMAP: each panel is one terminal state
g.plot_fate_probabilities(basis="umap", same_plot=False)
```

```python
# Circular projection: terminal states sit on a circle and each cell is placed
# according to its fate probabilities, colored by cluster
cr.pl.circular_projection(adata, keys=["clusters"], legend_loc="right")
```

---

## Part 6: Identifying Driver Genes

### Gene Expression Trends Along Lineages

CellRank can identify **driver genes**: genes whose expression is significantly correlated with the fate probabilities toward a given terminal state. These are candidate drivers or markers of a particular fate decision. Correlation alone does not establish a causal role.

The approach:
1. Use fate probabilities as a continuous "lineage coordinate"
2. Correlate each gene's expression with these probabilities across cells (Pearson correlation)
3. Rank genes by correlation, using p-values and confidence intervals to assess significance
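The ranking step can be sketched in NumPy. The sketch computes each gene's Pearson correlation with one lineage's fate probabilities and a 95% confidence interval via Fisher's $z$-transformation, $z = \operatorname{artanh}(r)$ with standard error $1/\sqrt{n-3}$. As we understand it, this mirrors the default `method="fisher"` of `compute_lineage_drivers`, but the function `rank_drivers` and the synthetic genes are hypothetical.

```python
import numpy as np

def rank_drivers(expr, fate, gene_names, z_crit=1.96):
    """Rank genes by Pearson correlation with fate probabilities.

    expr: (n_cells, n_genes); fate: (n_cells,) fate probabilities for one lineage.
    Returns (gene, r, ci_low, ci_high) tuples sorted by descending r.
    """
    n = len(fate)
    ec = expr - expr.mean(axis=0)
    fc = fate - fate.mean()
    r = (ec * fc[:, None]).sum(axis=0) / (
        np.sqrt((ec ** 2).sum(axis=0)) * np.sqrt((fc ** 2).sum())
    )
    z = np.arctanh(r)                    # Fisher z-transformation
    se = 1.0 / np.sqrt(n - 3)
    ci_low, ci_high = np.tanh(z - z_crit * se), np.tanh(z + z_crit * se)
    order = np.argsort(-r)
    return [(gene_names[k], r[k], ci_low[k], ci_high[k]) for k in order]

t = np.linspace(0.0, 1.0, 50)
fate = t  # toy fate probabilities increasing along the lineage
expr = np.column_stack([
    2 * t + 0.1 * np.sin(20 * t),       # rises with fate
    np.cos(7 * t),                      # weakly related
    -t + 0.05 * np.cos(13 * t),         # falls with fate
])
ranking = rank_drivers(expr, fate, ["GeneUp", "GeneOther", "GeneDown"])
for gene, r, lo, hi in ranking:
    print(f"{gene}: r={r:.3f} [{lo:.3f}, {hi:.3f}]")
```

Genes at the top are positively associated with the lineage and genes at the bottom are negatively associated. Both ends can be informative.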

```python
# List the terminal states found above; driver genes are computed per state next
terminal_states = g.terminal_states.cat.categories.tolist()
print("Terminal states:", terminal_states)
```

```python
# Correlate gene expression with the fate probabilities of each lineage;
# the method returns a DataFrame with per-lineage columns such as "<lineage>_corr"
drivers = g.compute_lineage_drivers(lineages=terminal_states)
drivers.head(10)
```

```python
# Visualize top driver genes for each fate
# Show the top 3 driver genes per lineage
for lineage in terminal_states:
    col = f"{lineage}_corr"
    if col in drivers.columns:
        top_genes = drivers[col].sort_values(ascending=False).head(3).index.tolist()
        print(f"\nTop drivers for {lineage}: {top_genes}")
        scv.pl.scatter(adata, color=top_genes, basis="umap", ncols=3)
```

---

## Part 7: Gene Expression Trends

### Modeling Gene Expression Along Fate Trajectories

CellRank can fit **gene expression trends** along lineages using GAMs. This allows visualizing how a gene's expression changes as cells commit to a particular fate.
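The key idea is that each lineage gets its own trend and each cell is weighted by its fate probability toward that lineage. Cells committed elsewhere therefore barely influence the curve. The sketch below swaps in a weighted polynomial fit (`np.polyfit`) for the GAM, which is a plainly different smoother. `weighted_trend` is a hypothetical helper, and passing `sqrt(fate_prob)` as `w` weights each squared residual by the fate probability.

```python
import numpy as np

def weighted_trend(time, expr, fate_prob, degree=3, grid_size=50):
    """Fit expression vs. pseudotime for one lineage, weighting cells by fate probability.

    A weighted polynomial stands in for a GAM here.
    """
    # np.polyfit applies w to the unsquared residuals, so sqrt(p) weights squared errors by p
    coefs = np.polyfit(time, expr, deg=degree, w=np.sqrt(fate_prob))
    grid = np.linspace(time.min(), time.max(), grid_size)
    return grid, np.polyval(coefs, grid), coefs

rng = np.random.default_rng(0)
time = np.linspace(0.0, 1.0, 40)
expr = 1.0 + 2.0 * time - 3.0 * time ** 2          # noiseless toy expression
fate_prob = rng.uniform(0.1, 1.0, size=time.size)  # toy fate probabilities
grid, trend, coefs = weighted_trend(time, expr, fate_prob, degree=2)
print(coefs.round(6))
```

With noiseless data the weights do not change the fit. With real, noisy data they decide which cells shape each lineage's curve.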

```python
# Set up the gene expression model
from cellrank.models import GAM

model = GAM(adata)
```

```python
# Plot gene expression trends for key marker genes along lineages
# Using known pancreas markers:
# Ins1/Ins2 -> Beta cells
# Gcg -> Alpha cells
# Sst -> Delta cells

cr.pl.gene_trends(
    adata,
    model=model,
    genes=["Ins1", "Gcg", "Sst"],
    time_key="latent_time",
    same_plot=True,
    ncols=3,
    hide_cells=False,
)
```

```python
# Heatmap of smoothed expression trends for the first 20 genes in the driver table
cr.pl.heatmap(
    adata,
    model=model,
    genes=drivers.index[:20].tolist(),
    time_key="latent_time",
    show_fate_probabilities=True,
)
```

---

## Part 8: Alternative Kernels

### 8A: PseudotimeKernel

When RNA velocity is not available or unreliable, the **PseudotimeKernel** can use any pseudotime ordering to provide directionality.

It uses a soft-assignment that biases transitions toward cells with higher pseudotime values. The key parameter is `threshold_scheme` which controls how strictly the pseudotime ordering is enforced.

```python
from cellrank.kernels import PseudotimeKernel

# Use scVelo's latent_time as pseudotime
pk = PseudotimeKernel(adata, time_key="latent_time")
pk.compute_transition_matrix(threshold_scheme="soft")
print(pk)
```

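```python
# Visualize the pseudotime kernel's transition matrix
# (recompute=True: the kernels share `adata`, so do not reuse an earlier projection)
pk.plot_projection(basis="umap", color="clusters", recompute=True)
```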

```python
# Run the full CellRank pipeline with the PseudotimeKernel
g_pt = GPCCA(pk)
g_pt.compute_schur(n_components=20)
g_pt.compute_macrostates(n_states=6, cluster_key="clusters")
g_pt.predict_terminal_states()
g_pt.compute_fate_probabilities()
g_pt.plot_fate_probabilities(basis="umap", same_plot=False)
```

### 8B: CytoTRACEKernel

The **CytoTRACEKernel** uses CytoTRACE scores as a proxy for developmental potential. CytoTRACE is based on the observation that the **number of expressed genes** decreases during differentiation.

This kernel is useful when:
- No spliced/unspliced information is available
- No pseudotime has been computed
- No experimental time labels exist

```python
from cellrank.kernels import CytoTRACEKernel

ctk = CytoTRACEKernel(adata)
ctk.compute_cytotrace()          # stores the score in adata.obs["ct_score"]
ctk.compute_transition_matrix()
print(ctk)
```

```python
# Plot CytoTRACE scores: higher = more progenitor-like
scv.pl.scatter(adata, color="ct_score", color_map="gnuplot", size=80, title="CytoTRACE Score")
```

```python
# Visualize CytoTRACEKernel transitions (recompute the shared projection)
ctk.plot_projection(basis="umap", color="clusters", recompute=True)
```

```python
# Run CellRank pipeline with CytoTRACEKernel
g_ct = GPCCA(ctk)
g_ct.compute_schur(n_components=20)
g_ct.compute_macrostates(n_states=6, cluster_key="clusters")
g_ct.predict_terminal_states()
g_ct.compute_fate_probabilities()
g_ct.plot_fate_probabilities(basis="umap", same_plot=False)
```

---

## Part 9: RealTimeKernel with Optimal Transport

### When to Use

The **RealTimeKernel** is designed for **time-series single-cell experiments** where cells are profiled at discrete time points. Since scRNA-seq is destructive, we cannot directly track individual cells. Instead, **optimal transport** (OT) is used to probabilistically map cells between time points.

### Optimal Transport (OT) Formulation

Given cells at time $t_k$ and $t_{k+1}$, OT finds a **coupling matrix** $\pi \in \mathbb{R}^{n_k \times n_{k+1}}$ minimizing:

$$\pi^* = \arg\min_{\pi \in \Pi(\mu_k, \mu_{k+1})} \sum_{i,j} c(x_i, x_j) \pi_{ij} + \varepsilon \sum_{i,j} \pi_{ij} \log \pi_{ij}$$

where:
- $c(x_i, x_j)$ = cost function (typically squared Euclidean distance in PCA space)
- $\varepsilon$ = entropic regularization parameter
- $\Pi(\mu_k, \mu_{k+1})$ = set of valid couplings with marginals matching the empirical distributions

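The `moscot` package handles the OT computation. CellRank's `RealTimeKernel` then assembles the couplings between consecutive time points into a transition matrix over all cells. Within-time-point transitions from the KNN graph can optionally be mixed in.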

### Demo with Reprogramming Data

We use a reprogramming dataset where cells are collected at multiple time points during conversion from fibroblasts to induced endoderm progenitors (iEPs).

```python
import warnings
warnings.filterwarnings("ignore")  # silence library warnings for a cleaner demo

# Load a time-series dataset
adata_reprog = cr.datasets.reprogramming_morris()
adata_reprog
```

```python
print("Time points:", sorted(adata_reprog.obs["day"].unique().tolist()))
print("\nCells per time point:")
print(adata_reprog.obs["day"].value_counts().sort_index())
```

```python
sc.pl.umap(adata_reprog, color=["day", "cell_type"], ncols=2)
```

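```python
from cellrank.kernels import RealTimeKernel
from moscot.problems.time import TemporalProblem

# Solve the OT problem between consecutive time points with moscot
# (moscot expects numeric time points in obs["day"])
tp = TemporalProblem(adata_reprog)
tp = tp.prepare(time_key="day")
tp = tp.solve(epsilon=1e-3, tau_a=0.95, scale_cost="mean")

# Build the kernel from the solved moscot problem, not from the AnnData object
rtk = RealTimeKernel.from_moscot(tp)
rtk.compute_transition_matrix(self_transitions="all", conn_weight=0.2, threshold="auto")
print(rtk)
```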

```python
rtk.plot_projection(basis="umap", color="cell_type")
```

```python
# Run CellRank estimator on the RealTimeKernel
g_rt = GPCCA(rtk)
g_rt.compute_schur(n_components=20)
g_rt.plot_spectrum(real_only=True)
```

```python
g_rt.compute_macrostates(n_states=4, cluster_key="cell_type")
g_rt.predict_terminal_states()
g_rt.plot_macrostates(which="terminal", basis="umap", legend_loc="right margin", s=100)
```

```python
g_rt.compute_fate_probabilities()
g_rt.plot_fate_probabilities(basis="umap", same_plot=False)
```

---

## Part 10: Kernel Comparison & Combination

### Comparing Kernels Side-by-Side

Different kernels capture different aspects of cellular dynamics. Comparing them helps assess which source of information is most appropriate for your dataset.

```python
# Compare VelocityKernel and PseudotimeKernel on the pancreas data.
# recompute=True: both kernels share `adata`, so force a fresh projection for each.
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

vk.plot_projection(basis="umap", color="clusters", recompute=True, ax=axes[0], title="VelocityKernel")
pk.plot_projection(basis="umap", color="clusters", recompute=True, ax=axes[1], title="PseudotimeKernel")

plt.tight_layout()
plt.show()
```

### Combining Multiple Kernels

Kernels can be combined using **weighted sums** or **products**:

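$$T_{\text{combined}} = \alpha_1 T_1 + \alpha_2 T_2 + \ldots, \qquad \alpha_k \ge 0, \quad \sum_k \alpha_k = 1$$

With nonnegative weights that sum to 1, the combined matrix is again row-stochastic.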

This is useful when you have multiple sources of directional information that complement each other.

```python
# Example: combine VelocityKernel and PseudotimeKernel (recompute the shared projection)
combined_vk_pk = 0.5 * vk + 0.5 * pk
combined_vk_pk.plot_projection(basis="umap", color="clusters", recompute=True, title="50% Velocity + 50% Pseudotime")
```

---

## Part 11: Summary & Key Takeaways

### CellRank 2 Workflow Summary

```
1. Preprocess scRNA-seq data (scanpy/scVelo)
         ↓
2. Choose & compute kernel(s):
   • VelocityKernel   (RNA velocity)
   • PseudotimeKernel (pseudotime ordering)
   • RealTimeKernel   (time-series + OT via moscot)
   • CytoTRACEKernel  (gene count-based potency)
         ↓
3. Optionally combine kernels (weighted sum/product)
         ↓
4. GPCCA estimator:
   • Schur decomposition → eigenvalue spectrum
   • Macrostate identification
   • Terminal / initial state classification
         ↓
5. Compute fate probabilities (absorption probabilities)
         ↓
6. Downstream analysis:
   • Driver gene identification
   • Gene expression trends (GAMs)
   • Fate probability visualization
```

### When to Use Which Kernel?

| Scenario | Recommended Kernel |
|----------|-------------------|
| Standard scRNA-seq with splicing info | `VelocityKernel` + `ConnectivityKernel` |
| Pseudotime available, velocity unreliable | `PseudotimeKernel` |
| Time-series experiment | `RealTimeKernel` (via moscot) |
| No velocity, no time, no pseudotime | `CytoTRACEKernel` |
| Multiple signals available | Combine kernels with weighted sum |

### Key Concepts

- **Markov chain on KNN graph**: cells = states, transitions = weighted directed edges
- **Kernels**: modular components that encode directional information into a transition matrix
- **GPCCA**: spectral method for coarse-graining the Markov chain into macrostates
- **Absorption probabilities**: quantify fate commitment for every cell
- **Driver genes**: genes correlated with fate probability, identify molecular programs

### References

- Weiler et al. (2024) *CellRank 2: unified fate mapping in multiview single-cell data.* Nature Methods.
- Lange et al. (2022) *CellRank for directed single-cell fate mapping.* Nature Methods.
- Weiler et al. (2024) *CellRank 2 Protocol.* Nature Protocols.
- Bergen et al. (2020) *Generalizing RNA velocity to transient cell states through dynamical modeling.* Nature Biotechnology.
- Klein et al. (2023) *moscot: Multi-omic single-cell optimal transport.*
- Bastidas-Ponce et al. (2019) *Comprehensive single cell mRNA profiling reveals a detailed roadmap for pancreatic endocrinogenesis.* Development.

```python
# Record package versions for reproducibility
import session_info
session_info.show()
```