# MIOFlow — Embryoid Body Data: Preprocessing & Trajectory Inference

This notebook walks through the **complete pipeline** for the Embryoid Body (EB) scRNA-seq dataset:

1. Download the raw 10X data from [Mendeley Datasets](https://data.mendeley.com/datasets/v6n743h5ng/)
2. Preprocess the data (QC filtering, normalisation, sqrt transform)
3. Embed into PCA and PHATE latent spaces
4. Train MIOFlow to infer developmental trajectories
5. Decode trajectories back to gene space to obtain gene expression trends

> **Dataset reference:** Moon et al. (2019), *Visualizing structure and transitions in high-dimensional biological data*, Nature Biotechnology. Data available at [Mendeley Datasets (v6n743h5ng)](https://data.mendeley.com/datasets/v6n743h5ng/).
>
> **MIOFlow reference:** Huguet et al. (2022), *Manifold Interpolating Optimal-Transport Flows for Trajectory Inference*, NeurIPS. [arXiv:2206.14928](https://arxiv.org/abs/2206.14928)

## 0. Installation

Run the cell below once to install all required packages.

In [None]:
!pip install mioflow phate scanpy seaborn

In [None]:
# Enable interactive widgets in Google Colab
try:
    from google.colab import output
    output.enable_custom_widget_manager()
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

print(f'Running in Google Colab: {IN_COLAB}')

## 1. Import Libraries

In [None]:
import os
import urllib.request
import zipfile
import shutil

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import scanpy as sc
import phate
import torch

from MIOFlow.mioflow import MIOFlow
from MIOFlow.plots import plot_losses
from MIOFlow.utils import set_seeds

set_seeds(0)

use_cuda = torch.cuda.is_available()
print(f'Using CUDA: {use_cuda}')

## 2. Download the Embryoid Body Data

The raw scRNA-seq data is publicly available on [Mendeley Datasets (v6n743h5ng)](https://data.mendeley.com/datasets/v6n743h5ng/).

The dataset contains five 10X Genomics samples collected at different time points:

| Folder   | Time point   |
|----------|--------------|
| `T0_1A`  | Day 00–03    |
| `T2_3B`  | Day 06–09    |
| `T4_5C`  | Day 12–15    |
| `T6_7D`  | Day 18–21    |
| `T8_9E`  | Day 24–27    |

Each folder contains `barcodes.tsv`, `genes.tsv`, and `matrix.mtx` files as produced by CellRanger.

In [None]:
RAW_DATA_DIR = 'data/raw/scRNAseq'
EXPECTED_DIRS = ['T0_1A', 'T2_3B', 'T4_5C', 'T6_7D', 'T8_9E']

def data_already_downloaded(base_dir, expected_dirs):
    return all(os.path.isdir(os.path.join(base_dir, d)) for d in expected_dirs)

if data_already_downloaded(RAW_DATA_DIR, EXPECTED_DIRS):
    print('Data already present — skipping download.')
else:
    os.makedirs(RAW_DATA_DIR, exist_ok=True)

    url = 'https://data.mendeley.com/public-api/zip/v6n743h5ng/download/1'
    zip_file = os.path.join(RAW_DATA_DIR, 'v6n743h5ng-1.zip')

    print(f'Downloading from {url} ...')
    req = urllib.request.Request(url, headers={
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
    })
    with urllib.request.urlopen(req) as response, open(zip_file, 'wb') as out_file:
        shutil.copyfileobj(response, out_file)

    # Extract outer zip to a temp directory
    temp_extract = os.path.join(RAW_DATA_DIR, 'temp_extract')
    os.makedirs(temp_extract, exist_ok=True)

    print('Extracting ...')
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        zip_ref.extractall(temp_extract)

    # Extract inner scRNAseq.zip if present
    scrna_zip = os.path.join(temp_extract, 'scRNAseq.zip')
    if os.path.exists(scrna_zip):
        with zipfile.ZipFile(scrna_zip, 'r') as zip_ref:
            zip_ref.extractall(temp_extract)

    # Move sample directories to RAW_DATA_DIR
    scrna_folder = os.path.join(temp_extract, 'scRNAseq')
    src_root = scrna_folder if os.path.exists(scrna_folder) else temp_extract
    for item in EXPECTED_DIRS:
        src = os.path.join(src_root, item)
        dst = os.path.join(RAW_DATA_DIR, item)
        if os.path.exists(src):
            shutil.move(src, dst)

    shutil.rmtree(temp_extract)
    os.remove(zip_file)
    print(f'Done. Data extracted to {RAW_DATA_DIR}')

print('Data directories found:')
for d in EXPECTED_DIRS:
    path = os.path.join(RAW_DATA_DIR, d)
    status = 'OK' if os.path.isdir(path) else 'MISSING'
    print(f'  {path}  [{status}]')

## 3. Preprocessing

We follow the standard scRNA-seq preprocessing pipeline:

1. Load 10X data with `scanpy` and concatenate all time points (with QC metrics)
2. Filter cells by library size (remove top and bottom 20% per sample)
3. Remove genes expressed in fewer than 10 cells
4. Library-size normalise
5. Remove dead cells (high mitochondrial RNA expression)
6. Square-root transform

### 3.1 Load 10X Data

`sc.read_10x_mtx` reads each CellRanger output directory directly into an AnnData object. We label each sample with its time point, concatenate all samples, then compute QC metrics — including the fraction of mitochondrial gene counts needed for dead-cell removal later.
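For reference, CellRanger's `matrix.mtx` stores genes as rows and barcodes as columns, whereas AnnData stores cells as rows and genes as columns, so `sc.read_10x_mtx` transposes the matrix on load. The minimal sketch below reproduces that step with `scipy.io` on a made-up 4-gene × 3-barcode matrix. It illustrates the file layout only and is not a replacement for the scanpy reader.

```python
import os
import tempfile

import numpy as np
import scipy.io
import scipy.sparse as sp

# Toy CellRanger-style matrix (made-up values): 4 genes (rows) x 3 barcodes (columns)
genes_by_cells = sp.coo_matrix(np.array([
    [1, 0, 3],
    [0, 0, 2],
    [5, 1, 0],
    [0, 4, 0],
]))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'matrix.mtx')
    scipy.io.mmwrite(path, genes_by_cells)
    raw = scipy.io.mmread(path)  # sparse COO matrix, same orientation as on disk

# Transpose to the AnnData convention: cells (obs) x genes (var)
cells_by_genes = sp.csr_matrix(raw).T.tocsr()

# Library size = total counts per cell = row sums after transposing
library_sizes = np.asarray(cells_by_genes.sum(axis=1)).ravel()
print(cells_by_genes.shape)  # (3, 4)
print(library_sizes)
```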

In [None]:
samples = ['T0_1A', 'T2_3B', 'T4_5C', 'T6_7D', 'T8_9E']
labels  = ['Day 00-03', 'Day 06-09', 'Day 12-15', 'Day 18-21', 'Day 24-27']

adatas = []
for sample, label in zip(samples, labels):
    adata_sample = sc.read_10x_mtx(
        os.path.join(RAW_DATA_DIR, sample),
        var_names='gene_symbols',
        make_unique=True,
        cache=True,
    )
    adata_sample.obs['time_label'] = label
    adatas.append(adata_sample)

adata = sc.concat(adatas, merge='same')
adata.obs_names_make_unique()

# Compute QC metrics (library size, mitochondrial gene fraction)
adata.var['mt'] = adata.var_names.str.startswith('MT-')
sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)

print('Cells per sample:')
for label in labels:
    n = (adata.obs['time_label'] == label).sum()
    print(f'  {label}: {n}')
print(f'\nTotal: {adata.n_obs} cells × {adata.n_vars} genes')

### 3.2 Library Size Filtering

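We remove cells in the bottom and top 20% of library sizes **within each sample**. Low-count cells are likely empty droplets or damaged cells, while unusually high-count cells are candidate doublets. Computing the thresholds per sample, rather than across all cells, accounts for differences in sequencing depth between samples.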
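The per-sample rule can be written as a small numpy function. This sketch uses synthetic log-normal library sizes for two hypothetical samples `A` and `B` (made-up data, not the EB counts) and keeps, within each sample, the cells between that sample's own 20th and 80th percentiles, using the same inclusive `>=`/`<=` comparison as the filtering cell below.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: two samples with different sequencing depths (made-up numbers)
sample = np.repeat(['A', 'B'], 500)
total_counts = np.concatenate([
    rng.lognormal(mean=8.0, sigma=0.5, size=500),
    rng.lognormal(mean=9.0, sigma=0.5, size=500),
])


def percentile_mask(values, groups, lo=20, hi=80):
    """Keep values lying between the lo-th and hi-th percentile of their own group."""
    keep = np.zeros(len(values), dtype=bool)
    for g in np.unique(groups):
        in_group = groups == g
        t_min, t_max = np.percentile(values[in_group], [lo, hi])
        keep[in_group] = (values[in_group] >= t_min) & (values[in_group] <= t_max)
    return keep


keep = percentile_mask(total_counts, sample)
for g in ['A', 'B']:
    print(g, keep[sample == g].mean())
```

With 500 cells per sample and no tied values, 300 cells (60%) of each sample survive, independently of how deeply each sample was sequenced.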

In [None]:
min_percentile = 20
max_percentile = 80

fig, axes = plt.subplots(2, 3, figsize=(15, 8))
axes = axes.flatten()

for idx, label in enumerate(labels):
    sample_counts = adata.obs.loc[adata.obs['time_label'] == label, 'total_counts']
    t_min = np.percentile(sample_counts, min_percentile)
    t_max = np.percentile(sample_counts, max_percentile)

    axes[idx].hist(sample_counts, bins=50, alpha=0.7, edgecolor='black', log=True)
    axes[idx].axvline(t_min, color='red',  linestyle='--', linewidth=2,
                      label=f'{min_percentile}th: {t_min:.0f}')
    axes[idx].axvline(t_max, color='blue', linestyle='--', linewidth=2,
                      label=f'{max_percentile}th: {t_max:.0f}')
    axes[idx].set_xlabel('Library Size (Total Counts)')
    axes[idx].set_ylabel('Number of Cells')
    axes[idx].set_title(label)
    axes[idx].legend(fontsize=8)
    axes[idx].grid(alpha=0.3)

axes[-1].axis('off')
plt.tight_layout()
plt.show()

In [None]:
cells_to_keep = []

for label in labels:
    sample_counts = adata.obs.loc[adata.obs['time_label'] == label, 'total_counts']
    t_min = np.percentile(sample_counts, min_percentile)
    t_max = np.percentile(sample_counts, max_percentile)
    keep  = sample_counts[(sample_counts >= t_min) & (sample_counts <= t_max)].index
    cells_to_keep.extend(keep.tolist())
    print(f'{label}: keeping {len(keep)}/{len(sample_counts)} cells '
          f'(range: {t_min:.0f}–{t_max:.0f})')

adata = adata[cells_to_keep, :].copy()
print(f'\nCells after library-size filtering: {adata.n_obs}')

### 3.3 Remove Rare Genes

Genes detected (non-zero count) in fewer than 10 cells are unlikely to be biologically informative and are removed.
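`sc.pp.filter_genes(adata, min_cells=10)` keeps a gene when it has a non-zero count in at least 10 cells. A minimal sketch of that rule on a toy 12-cell matrix (made-up data) shows the boundary: a gene seen in exactly 10 cells is kept, while one seen in 9 is dropped.

```python
import numpy as np
import scipy.sparse as sp

# Toy counts: 12 cells x 3 genes
X = np.zeros((12, 3))
X[:10, 0] = 1   # gene 0 detected in exactly 10 cells
X[:9, 1] = 1    # gene 1 detected in 9 cells
X[:, 2] = 2     # gene 2 detected in all 12 cells
X = sp.csr_matrix(X)

min_cells = 10
# Number of cells with a non-zero count, per gene
n_cells_per_gene = np.asarray((X > 0).sum(axis=0)).ravel()
gene_mask = n_cells_per_gene >= min_cells  # keep genes seen in at least min_cells cells
print(n_cells_per_gene, gene_mask)
```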

In [None]:
sc.pp.filter_genes(adata, min_cells=10)
print(f'Genes after rare-gene filtering: {adata.n_vars}')

### 3.4 Library-Size Normalisation

Divide each cell by its total count and rescale by the median library size to make cells comparable.
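In arithmetic terms, each count x_ij of cell i becomes x_ij / (Σ_j x_ij) × (median library size), so every cell ends up with the same total while the relative proportions within a cell are preserved. A minimal numpy sketch with made-up counts follows; scanpy's `normalize_total` additionally handles sparse input and options such as `exclude_highly_expressed`, which are omitted here.

```python
import numpy as np

# Toy counts: 3 cells x 3 genes (made-up values)
counts = np.array([
    [10.0, 0.0, 30.0],
    [1.0, 1.0, 2.0],
    [50.0, 25.0, 25.0],
])

library_size = counts.sum(axis=1)        # total counts per cell: 40, 4, 100
target_sum = np.median(library_size)      # median library size: 40
normalised = counts / library_size[:, None] * target_sum

print(normalised.sum(axis=1))  # every cell now sums to the median library size
```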

In [None]:
sc.pp.normalize_total(adata, target_sum=np.median(adata.obs['total_counts']))
print('Library-size normalisation complete.')

### 3.5 Remove Dead Cells

Dead cells show elevated mitochondrial RNA content. We remove cells at or above the 90th percentile of `pct_counts_mt` (roughly the top 10%). This metric was computed from the raw counts in step 3.1, so it is unaffected by normalisation.
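`pct_counts_mt` is the percentage of a cell's total counts that fall in genes whose symbol starts with `MT-`, and the filter keeps cells strictly below the 90th-percentile threshold. The sketch below reproduces both steps on random Poisson counts with hypothetical gene names (`GENE0`, ..., `MT-0`, ...); it is an illustration, not scanpy's implementation.

```python
import numpy as np

rng = np.random.default_rng(1)
n_cells, n_genes_toy = 200, 50

# Toy counts and hypothetical gene names: the last 5 genes are "mitochondrial"
counts = rng.poisson(2.0, size=(n_cells, n_genes_toy)).astype(float)
gene_names = np.array([f'GENE{i}' for i in range(n_genes_toy - 5)]
                      + [f'MT-{i}' for i in range(5)])

is_mt = np.char.startswith(gene_names, 'MT-')
# Percentage of each cell's counts coming from mitochondrial genes
pct_counts_mt = 100 * counts[:, is_mt].sum(axis=1) / counts.sum(axis=1)

threshold = np.percentile(pct_counts_mt, 90)
keep = pct_counts_mt < threshold  # strict inequality, as in the filtering cell
print(f'threshold = {threshold:.1f}%, kept {keep.sum()}/{n_cells} cells')
```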

In [None]:
mito_pct       = adata.obs['pct_counts_mt']
mito_percentile = 90
mito_threshold  = np.percentile(mito_pct, mito_percentile)

print(f'Mitochondrial genes found: {adata.var["mt"].sum()}')

fig, ax = plt.subplots(figsize=(8, 5))
ax.hist(mito_pct, bins=50, alpha=0.7, edgecolor='black')
ax.axvline(mito_threshold, color='red', linestyle='--', linewidth=2,
           label=f'{mito_percentile}th percentile: {mito_threshold:.1f}%')
ax.set_xlabel('Mitochondrial Gene %')
ax.set_ylabel('Number of Cells')
ax.set_title('Mitochondrial Content Distribution')
ax.legend()
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
adata = adata[adata.obs['pct_counts_mt'] < mito_threshold].copy()
print(f'Cells after dead-cell removal: {adata.n_obs}')

### 3.6 Square-Root Transform

We use the square-root transform instead of `log1p` to stabilise variance. Like the log, it compresses large values, but it is defined at zero without a pseudocount (`log1p` implicitly adds a pseudocount of 1).
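Why the square root stabilises variance: for Poisson counts the variance equals the mean, but by the delta method Var(√X) ≈ (1/(2√λ))² · λ = 1/4, independent of λ. A quick simulation illustrates this. Real scRNA-seq counts are overdispersed relative to Poisson, so in practice the stabilisation is only approximate.

```python
import numpy as np

rng = np.random.default_rng(0)
results = {}

for lam in (10, 50, 200):
    x = rng.poisson(lam, size=100_000)
    # Raw variance grows with the mean; variance after sqrt stays near 0.25
    results[lam] = (x.var(), np.sqrt(x).var())
    print(f'mean={lam:>3}  var(x)={results[lam][0]:7.1f}  var(sqrt(x))={results[lam][1]:.3f}')

# Both transforms are finite at zero; log only because of the +1 pseudocount
print(np.sqrt(0.0), np.log1p(0.0))
```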

In [None]:
import scipy.sparse as sp
adata.X = np.sqrt(adata.X.toarray() if sp.issparse(adata.X) else adata.X)
print(f'Preprocessing complete. Final matrix: {adata.n_obs} cells × {adata.n_vars} genes')

## 4. AnnData Summary

All preprocessing is complete. The AnnData object is ready for dimensionality reduction.

In [None]:
print(adata)

## 5. Dimensionality Reduction

### 5.1 PCA

We first reduce to 50 principal components to denoise and speed up the subsequent PHATE embedding.
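`sc.pp.pca` centres the data (by default) and projects it onto the leading principal axes, using its own SVD solvers. A minimal numpy sketch of the same linear algebra is shown below on synthetic low-rank data (an assumption for illustration), including the inverse map `Z @ components + mean` that any linear PCA decoder relies on.

```python
import numpy as np


def pca_fit(X, n_comps):
    """Return (mean, components) of the top n_comps principal axes via SVD."""
    mean = X.mean(axis=0)
    # Rows of Vt are the principal axes, ordered by decreasing singular value
    _, _, Vt = np.linalg.svd(X - mean, full_matrices=False)
    return mean, Vt[:n_comps]


def pca_transform(X, mean, components):
    return (X - mean) @ components.T


def pca_inverse(Z, mean, components):
    return Z @ components + mean


rng = np.random.default_rng(0)
# Toy data: 300 cells x 100 genes lying near a 5-dimensional subspace, plus small noise
latent = rng.normal(size=(300, 5))
loadings = rng.normal(size=(5, 100))
X = latent @ loadings + 0.01 * rng.normal(size=(300, 100))

mean, components = pca_fit(X, n_comps=50)
Z = pca_transform(X, mean, components)                  # (300, 50)
X_hat = pca_inverse(Z[:, :5], mean, components[:5])     # reconstruct from 5 PCs
print(Z.shape, np.linalg.norm(X - X_hat) / np.linalg.norm(X))
```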

In [None]:
sc.pp.pca(adata, n_comps=50)
print(f'PCA embedding shape: {adata.obsm["X_pca"].shape}')

### 5.2 PHATE

PHATE (Potential of Heat-diffusion for Affinity-based Trajectory Embedding) preserves both local and global manifold structure, making it well suited for trajectory inference. We run it on the PCA embedding.

This step can take a few minutes.
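To give a feel for what `phate.PHATE` computes, here is a heavily simplified sketch of the PHATE recipe described by Moon et al. (2019): build a cell-cell affinity kernel, row-normalise it into a diffusion (Markov) operator, raise it to a power t, take the negative log of the diffused probabilities to obtain potentials, and embed the distances between potentials in 2-D. The simplifications are assumptions and do not match the `phate` package: a Gaussian kernel with a kNN-based bandwidth instead of PHATE's alpha-decay kernel, a fixed t instead of automatic selection, dense matrices, and classical instead of metric MDS.

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform


def phate_like_embedding(X, knn=5, t=10, n_components=2):
    """Simplified PHATE-style embedding (illustrative only, not the phate package)."""
    D = squareform(pdist(X))                      # pairwise Euclidean distances
    sigma = np.sort(D, axis=1)[:, knn]            # bandwidth: distance to knn-th neighbour (col 0 is self)
    K = np.exp(-(D / sigma[:, None]) ** 2)        # Gaussian affinities (PHATE uses alpha-decay)
    K = (K + K.T) / 2                             # symmetrise the affinities
    P = K / K.sum(axis=1, keepdims=True)          # row-stochastic diffusion operator
    Pt = np.linalg.matrix_power(P, t)             # diffuse for t steps
    U = -np.log(Pt + 1e-7)                        # diffusion potentials
    D_pot = squareform(pdist(U))                  # potential distances between cells
    n = D_pot.shape[0]
    J = np.eye(n) - np.ones((n, n)) / n           # centring matrix
    B = -0.5 * J @ (D_pot ** 2) @ J               # classical MDS Gram matrix
    evals, evecs = np.linalg.eigh(B)              # eigenvalues in ascending order
    top = np.argsort(evals)[::-1][:n_components]  # largest eigenvalues first
    return evecs[:, top] * np.sqrt(np.maximum(evals[top], 0))


rng = np.random.default_rng(0)
# Toy data: noisy points along a 1-D arc embedded in 10 dimensions
s = np.sort(rng.uniform(0, 1, 200))
X = np.column_stack([np.cos(3 * s), np.sin(3 * s)]
                    + [0.05 * rng.normal(size=200) for _ in range(8)])
Y = phate_like_embedding(X)
print(Y.shape)
```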

In [None]:
phate_op = phate.PHATE(n_components=2, n_jobs=-2)
X_phate = phate_op.fit_transform(adata.obsm['X_pca'])

adata.obsm['X_phate'] = X_phate
print(f'PHATE embedding shape: {adata.obsm["X_phate"].shape}')

In [None]:
sc.pl.embedding(
    adata, basis='phate', color='time_label',
    palette='Spectral',  # `palette` (not `cmap`) sets colours for categorical annotations
    title='Embryoid Body: PHATE embedding',
)

## 6. Prepare Data for MIOFlow

MIOFlow expects a DataFrame with:
- One column per embedding dimension (`d1`, `d2`, …)
- A `samples` column with an **integer** time-bin label for each cell
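`pd.factorize` numbers labels in order of first appearance, so the cell below yields chronological bins only because the samples were concatenated in time order. The sketch below builds the same kind of input DataFrame with an explicit label-to-bin mapping that does not depend on cell order. The embedding and labels are random stand-ins for `adata.obsm['X_phate']` and `adata.obs['time_label']`.

```python
import numpy as np
import pandas as pd

ordered_labels = ['Day 00-03', 'Day 06-09', 'Day 12-15', 'Day 18-21', 'Day 24-27']
rng = np.random.default_rng(0)

# Random stand-ins for the PHATE embedding and the per-cell time labels
embedding = rng.normal(size=(8, 2))
time_label = pd.Series(rng.choice(ordered_labels, size=8))

# Bin 0 is always the earliest day, regardless of the order cells appear in
label_to_bin = {label: i for i, label in enumerate(ordered_labels)}

df = pd.DataFrame(embedding, columns=[f'd{i}' for i in range(1, embedding.shape[1] + 1)])
df['samples'] = time_label.map(label_to_bin).to_numpy()
print(df.head())
```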

In [None]:
# Create integer time bins: Day 00-03 → 0, Day 06-09 → 1, …, Day 24-27 → 4
adata.obs['discrete_time'], _ = pd.factorize(adata.obs['time_label'])

# Build the input DataFrame (PHATE dims + samples column)
n_phate = adata.obsm['X_phate'].shape[1]
mioflow_df = pd.DataFrame(
    adata.obsm['X_phate'],
    columns=[f'd{i}' for i in range(1, n_phate + 1)],
)
mioflow_df['samples'] = adata.obs['discrete_time'].values

print(mioflow_df.head())
print(f'\nTime bins: {sorted(mioflow_df["samples"].unique())}')

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
sns.scatterplot(
    data=mioflow_df, x='d1', y='d2',
    hue='samples', palette='viridis',
    s=3, ax=ax,
)
ax.set_title('PHATE embedding coloured by discrete time bin')
plt.tight_layout()
plt.show()

## 7. Configure MIOFlow

The table below describes the most important hyperparameters.

| Parameter | Value used here | Description |
|---|---|---|
| `n_epochs` | 40 | Number of global training epochs |
| `use_density_loss` | `True` | Add a kNN-based density regulariser |
| `lambda_density` | 20 | Weight for the density loss |
| `sample_size` | 100 | Cells sampled per time step per batch |
| `n_trajectories` | 100 | Number of trajectories to integrate |
| `n_bins` | 100 | Number of time bins for ODE integration |

In [None]:
# Model architecture
MODEL_CONFIG = {
    'layers': [16, 32, 16],
    'activation': 'CELU',
    'use_cuda': use_cuda,
}

# Training
TRAINING_CONFIG = {
    'n_epochs': 40,
}

# Loss
OPTIMIZATION_CONFIG = {
    'use_density_loss': True,
    'lambda_density': 20,
}

# Data sampling
DATA_CONFIG = {
    'sample_size': 100,
}

# Output
OUTPUT_CONFIG = {
    'exp_dir': '.',
    'n_trajectories': 100,
    'n_bins': 100,
}

## 8. Initialise MIOFlow

Pass the AnnData object, the input DataFrame, and the configuration dictionaries to `MIOFlow`. The `obsm_key` parameter tells MIOFlow which embedding to use for the trajectory inference.
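Because several dictionaries are unpacked into a single call with `**`, a key defined in two of them makes Python raise `TypeError` ("got multiple values for keyword argument"). The hypothetical helper below (not part of MIOFlow) merges the configs and reports clashes with a clearer message; the dictionaries use local names so they do not overwrite the notebook's configs.

```python
def merge_configs(*configs):
    """Merge config dicts, raising ValueError if two dicts define the same key.

    Hypothetical helper for illustration; not part of the MIOFlow API.
    """
    merged = {}
    for cfg in configs:
        clash = merged.keys() & cfg.keys()
        if clash:
            raise ValueError(f'duplicate config keys: {sorted(clash)}')
        merged.update(cfg)
    return merged


training = {'n_epochs': 40}
optimization = {'use_density_loss': True, 'lambda_density': 20}
data = {'sample_size': 100}
output = {'exp_dir': '.', 'n_trajectories': 100, 'n_bins': 100}

kwargs = merge_configs(training, optimization, data, output)
print(sorted(kwargs))
```

The merged dictionary could then be passed as `MIOFlow(adata, input_df=mioflow_df, obsm_key='X_phate', model_config=MODEL_CONFIG, **kwargs)`.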

In [None]:
mioflow_operator = MIOFlow(
    adata,
    input_df=mioflow_df,
    obsm_key='X_phate',
    debug_level='info',
    model_config=MODEL_CONFIG,
    **TRAINING_CONFIG,
    **OPTIMIZATION_CONFIG,
    **DATA_CONFIG,
    **OUTPUT_CONFIG,
)

## 9. Train — `~5 minutes`

`fit()` trains the Neural ODE end-to-end using the optimal-transport loss. Progress is printed each epoch.
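The optimal-transport loss compares the cells predicted by the ODE at a time point with the cells observed there. The solver MIOFlow uses internally is not shown in this notebook, so the sketch below is an independent illustration: for two equal-size samples with uniform weights, an optimal transport plan can be taken to be a permutation, so the exact 1-Wasserstein distance reduces to an assignment problem solved by `scipy.optimize.linear_sum_assignment`.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def empirical_w1(x, y):
    """Exact 1-Wasserstein distance between equal-size point clouds with uniform weights."""
    if len(x) != len(y):
        raise ValueError('this sketch assumes equal-size samples')
    cost = cdist(x, y)                        # Euclidean ground cost between all pairs
    rows, cols = linear_sum_assignment(cost)  # optimal one-to-one matching
    return cost[rows, cols].mean()


rng = np.random.default_rng(0)
pred = rng.normal(size=(100, 2))   # e.g. cells pushed forward by the ODE
shift = np.array([0.3, -0.4])      # translation with length 0.5
target = pred + shift              # toy "observed" cells at the next time point

w_same = empirical_w1(pred, pred)
w_shift = empirical_w1(pred, target)
print(w_same, w_shift)
```

For a rigid translation the distance equals the length of the shift, 0.5, which gives a simple sanity check on any OT loss implementation.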

In [None]:
mioflow = mioflow_operator.fit()

## 10. Training Losses

The losses should decrease during training, indicating that the algorithm is learning to fit the data.

In [None]:
plot_losses(
    mioflow.local_losses,
    mioflow.batch_losses,
    mioflow.globe_losses,
)

## 11. Visualise Trajectories

`mioflow.trajectories` has shape `(n_bins, n_trajectories, n_dims)` in normalised PHATE space. We denormalise before plotting.
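The denormalisation below assumes that `mioflow.mean_vals` and `mioflow.std_vals` are per-dimension z-score statistics, as the formula `x * std + mean` implies. A minimal sketch with toy data shows the round trip and why one expression works on a 3-D trajectory array: broadcasting aligns the trailing `n_dims` axis.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy 2-D coordinates on a small scale, similar in spirit to PHATE output
data = rng.normal(loc=[5.0, -2.0], scale=[0.01, 0.03], size=(500, 2))

mean_vals = data.mean(axis=0)   # shape (n_dims,)
std_vals = data.std(axis=0)     # shape (n_dims,)
data_norm = (data - mean_vals) / std_vals

# Toy trajectories in normalised space: (n_bins, n_trajectories, n_dims)
traj_norm = rng.normal(size=(100, 20, 2))
# Broadcasting applies the per-dimension mean/std to every (bin, trajectory) point
traj_denorm = traj_norm * std_vals + mean_vals
print(traj_denorm.shape)  # (100, 20, 2)
```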

In [None]:
print('Trajectory shape (n_bins, n_trajectories, n_dims):', mioflow.trajectories.shape)

In [None]:
# Denormalise trajectories and original data back to PHATE scale
traj_pts  = mioflow.trajectories * mioflow.std_vals + mioflow.mean_vals
dim_cols  = [c for c in mioflow.df.columns if c != 'samples']
true_data = mioflow.df[dim_cols].values * mioflow.std_vals + mioflow.mean_vals

fig, ax = plt.subplots(figsize=(10, 8))
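# Use a separate name for the scatter handle: assigning to `sc` would overwrite
# the scanpy alias and break the later `sc.pp.highly_variable_genes` call
points = ax.scatter(
    true_data[:, 0], true_data[:, 1],
    c=mioflow.df['samples'].values, cmap='viridis', s=1, alpha=0.5,
)
plt.colorbar(points, ax=ax, label='time bin')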

for traj in np.transpose(traj_pts, axes=(1, 0, 2)):  # iterate over trajectories
    ax.plot(traj[:, 0], traj[:, 1], alpha=0.5, linewidth=0.5, color='black')
    ax.annotate(
        '', xy=(traj[-1, 0], traj[-1, 1]), xytext=(traj[-2, 0], traj[-2, 1]),
        arrowprops=dict(arrowstyle='->', color='black', lw=0.5, mutation_scale=10),
    )

ax.set_title('MIOFlow trajectories on PHATE embedding')
ax.set_xlabel('PHATE 1')
ax.set_ylabel('PHATE 2')
plt.tight_layout()
plt.show()

## 12. Decode Trajectories to Gene Space

`decode_to_gene_space()` inverts the PCA projection to recover trajectories in the original gene-expression space.

The result has shape `(n_bins, n_trajectories, n_genes)`.
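How MIOFlow maps PHATE coordinates back to PCA coordinates is not shown here, so this sketch covers only the final linear step: inverting a centred PCA on a whole trajectory array at once. The loadings `components` and per-gene `gene_means` are hypothetical, randomly generated stand-ins for fitted PCA parameters.

```python
import numpy as np

rng = np.random.default_rng(0)
n_bins, n_traj, n_pcs, n_genes_toy = 100, 20, 50, 300

# Hypothetical PCA parameters: orthonormal loadings (n_pcs, n_genes) and gene means
components = np.linalg.qr(rng.normal(size=(n_genes_toy, n_pcs)))[0].T
gene_means = rng.uniform(0, 2, size=n_genes_toy)

# Toy trajectories in PCA space: (n_bins, n_trajectories, n_pcs)
traj_pca = rng.normal(size=(n_bins, n_traj, n_pcs))
# matmul acts on the last axis, so every (bin, trajectory) point is decoded at once
traj_genes = traj_pca @ components + gene_means
print(traj_genes.shape)  # (100, 20, 300)
```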

In [None]:
trajectories_gene_space = mioflow.decode_to_gene_space()
print('Gene-space trajectory shape (n_bins, n_trajectories, n_genes):', trajectories_gene_space.shape)

## 13. Gene Expression Trends

### 13.1 Top Highly-Variable Genes

We select the 25 most highly variable genes and plot their mean expression (± std) over all trajectories.
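As a rough illustration of what "highly variable" means, the sketch below ranks genes by dispersion (variance divided by mean) on toy Poisson data in which two hypothetical genes are switched on in half of the cells. It is a simplified stand-in: scanpy's default `seurat` flavour additionally bins genes by mean expression and normalises dispersions within each bin, which this sketch omits.

```python
import numpy as np


def top_dispersed_genes(X, gene_names, n_top):
    """Return the n_top genes with the highest dispersion (variance / mean)."""
    mean = X.mean(axis=0)
    var = X.var(axis=0)
    dispersion = np.divide(var, mean, out=np.zeros_like(mean), where=mean > 0)
    order = np.argsort(dispersion)[::-1][:n_top]
    return [gene_names[i] for i in order]


rng = np.random.default_rng(0)
names = [f'gene{i}' for i in range(6)]
X = rng.poisson(5.0, size=(1000, 6)).astype(float)
# Switch gene4 and gene5 on in half of the cells, making them far more variable
X[:500, 4] += 20
X[:500, 5] += 40
print(top_dispersed_genes(X, names, n_top=2))
```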

In [None]:
sc.pp.highly_variable_genes(adata, n_top_genes=25)
example_genes    = adata.var_names[adata.var['highly_variable']]
example_gene_mask = adata.var_names.isin(example_genes)
print('Top 25 highly variable genes:\n', list(example_genes))

In [None]:
adata_hvg         = adata[:, example_gene_mask]
decoded_hvg       = trajectories_gene_space[:, :, example_gene_mask]  # (n_bins, n_traj, n_hvg)
decoded_hvg_mean  = decoded_hvg.mean(axis=1)   # mean over trajectories → (n_bins, n_hvg)
decoded_hvg_std   = decoded_hvg.std(axis=1)    # (n_bins, n_hvg)

x_time      = np.linspace(0, 1, decoded_hvg_mean.shape[0])
obs_time    = adata_hvg.obs['discrete_time']
obs_time_n  = (obs_time - obs_time.min()) / (obs_time.max() - obs_time.min())

import scipy.sparse as sp
X_hvg = adata_hvg.X.toarray() if sp.issparse(adata_hvg.X) else adata_hvg.X
data_df           = pd.DataFrame(X_hvg, columns=example_genes)
data_df['x_time'] = obs_time_n.values
data_mean         = data_df.groupby('x_time').mean()

n_genes = decoded_hvg_mean.shape[1]
n_cols  = 5
n_rows  = int(np.ceil(n_genes / n_cols))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows))
axes = axes.flatten()

for i, gene in enumerate(example_genes):
    ax = axes[i]
    ax.plot(x_time, decoded_hvg_mean[:, i], label='MIOFlow (mean)', color='tab:blue')
    ax.fill_between(
        x_time,
        decoded_hvg_mean[:, i] - decoded_hvg_std[:, i],
        decoded_hvg_mean[:, i] + decoded_hvg_std[:, i],
        alpha=0.2, color='tab:blue',
    )
    if gene in data_mean.columns:
        ax.scatter(data_mean.index, data_mean[gene], label='Observed (mean)', s=20, color='tab:orange')
    ax.set_title(gene, fontsize=9)
    ax.set_xlabel('Normalised time')
    ax.legend(fontsize=7)

for i in range(n_genes, len(axes)):
    axes[i].set_visible(False)

plt.suptitle('Gene Expression Trends Along MIOFlow Trajectories', fontsize=16, y=1.01)
plt.tight_layout()
plt.show()

### 13.2 Single Gene of Interest

Replace `interest_gene` with the symbol of any gene in `adata.var_names`.

In [None]:
interest_gene = 'CXCL3'  # gene symbol (data were loaded with var_names='gene_symbols'); change as needed

gene_mask = adata.var_names.isin([interest_gene])
if gene_mask.sum() == 0:
    print(f"Gene '{interest_gene}' not found. Available genes (first 10): {list(adata.var_names[:10])}")
else:
    gene_data    = adata[:, gene_mask]
    decoded_gene = trajectories_gene_space[:, :, gene_mask]   # (n_bins, n_traj, 1)
    decoded_mean = decoded_gene.mean(axis=1).flatten()         # mean over trajectories → (n_bins,)
    decoded_std  = decoded_gene.std(axis=1).flatten()

    x_time     = np.linspace(0, 1, len(decoded_mean))
    obs_time   = gene_data.obs['discrete_time']
    obs_time_n = (obs_time - obs_time.min()) / (obs_time.max() - obs_time.min())

    X_gene    = gene_data.X.toarray() if sp.issparse(gene_data.X) else gene_data.X
    orig_df   = pd.DataFrame({'expr': X_gene.flatten(), 'time': obs_time_n.values})
    orig_mean = orig_df.groupby('time')['expr'].mean()

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(x_time, decoded_mean, label='MIOFlow (mean)', linewidth=2)
    ax.fill_between(x_time, decoded_mean - decoded_std, decoded_mean + decoded_std, alpha=0.2)
    ax.scatter(orig_mean.index, orig_mean.values, label='Observed (mean)', s=30)
    ax.set_xlabel('Normalised time')
    ax.set_ylabel('Expression')
    ax.set_title(f'Gene Expression Trajectory: {interest_gene}')
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

## 14. Analyse a Specific Trajectory

Instead of averaging over all trajectories, you can focus on a single trajectory by selecting the one whose endpoint is closest to a target location in PHATE space.

Update `target_x` and `target_y` to the endpoint coordinates of the trajectory you are interested in (read them from the plot in section 11).
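The nearest-endpoint search can also be written without a Python loop: `traj_pts[-1]` holds the final point of every trajectory, so a single `np.linalg.norm` call gives all distances. A minimal sketch on random toy trajectories (with one endpoint planted at the target) follows; `nearest_trajectory` is a hypothetical helper name.

```python
import numpy as np


def nearest_trajectory(trajectories, target):
    """Index of the trajectory whose final point is closest to `target`.

    trajectories has shape (n_bins, n_trajectories, n_dims); only the first
    two dimensions are compared with the 2-D target.
    """
    endpoints = trajectories[-1, :, :2]                # (n_trajectories, 2)
    distances = np.linalg.norm(endpoints - np.asarray(target), axis=1)
    return int(np.argmin(distances)), distances


rng = np.random.default_rng(0)
toy_traj = rng.normal(size=(100, 50, 2))
toy_traj[-1, 17] = [0.5, 0.5]   # plant a known endpoint at the target
idx, dists = nearest_trajectory(toy_traj, (0.5, 0.5))
print(idx, dists[idx])  # 17 0.0
```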

In [None]:
target_x = -0.006  # PHATE 1 coordinate of the desired trajectory endpoint
target_y =  0.020  # PHATE 2 coordinate of the desired trajectory endpoint

distances = [
    np.sqrt((traj[-1, 0] - target_x) ** 2 + (traj[-1, 1] - target_y) ** 2)
    for traj in np.transpose(traj_pts, axes=(1, 0, 2))
]
highlight_idx = int(np.argmin(distances))
print(f'Selected trajectory #{highlight_idx} (distance to target: {distances[highlight_idx]:.4f})')

fig, ax = plt.subplots(figsize=(10, 8))
ax.scatter(true_data[:, 0], true_data[:, 1], c=mioflow.df['samples'].values, cmap='viridis', s=1, alpha=0.5)

for i, traj in enumerate(np.transpose(traj_pts, axes=(1, 0, 2))):
    colour    = 'red' if i == highlight_idx else 'black'
    linewidth = 1.5  if i == highlight_idx else 0.4
    alpha     = 1.0  if i == highlight_idx else 0.3
    ax.plot(traj[:, 0], traj[:, 1], alpha=alpha, linewidth=linewidth, color=colour)
    ax.annotate('', xy=(traj[-1, 0], traj[-1, 1]), xytext=(traj[-2, 0], traj[-2, 1]),
                arrowprops=dict(arrowstyle='->', color=colour, lw=linewidth, mutation_scale=10))

ax.plot(target_x, target_y, 'r*', markersize=12, label='target')
ax.set_title(f'Selected trajectory #{highlight_idx}')
ax.legend()
plt.tight_layout()
plt.show()

In [None]:
# Gene trends for the selected trajectory only (no averaging across trajectories)
# trajectories_gene_space shape: (n_bins, n_traj, n_genes)
decoded_selected = trajectories_gene_space[:, highlight_idx, :][:, example_gene_mask]  # (n_bins, n_hvg)

x_time = np.linspace(0, 1, decoded_selected.shape[0])

fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows))
axes = axes.flatten()

for i, gene in enumerate(example_genes):
    ax = axes[i]
    ax.plot(x_time, decoded_selected[:, i], label=f'Traj #{highlight_idx}', color='tab:red')
    if gene in data_mean.columns:
        ax.scatter(data_mean.index, data_mean[gene], label='Observed (mean)', s=20, color='tab:orange')
    ax.set_title(gene, fontsize=9)
    ax.set_xlabel('Normalised time')
    ax.legend(fontsize=7)

for i in range(n_genes, len(axes)):
    axes[i].set_visible(False)

plt.suptitle(f'Gene Trends — Trajectory #{highlight_idx}', fontsize=16, y=1.01)
plt.tight_layout()
plt.show()