# MIOFlow — New High-Level API (Google Colab)

This notebook demonstrates the sklearn-style `MIOFlow` class using the **Embryoid Body** scRNA-seq dataset.

## Workflow
1. Install dependencies and download data
2. Load pre-processed Embryoid Body AnnData
3. Train a **GAGA autoencoder** — learns a geometry-preserving latent embedding from PCA, regularised by PHATE distances
4. Pass the trained GAGA model to `MIOFlow` and call `fit()` — trains a Neural ODE in GAGA latent space using optimal transport
5. Inspect losses and trajectories
6. `decode_to_gene_space()` — maps trajectories back through the GAGA decoder and PCA components to recover gene-level predictions

## 0. Setup — Install packages and enable widgets

In [None]:
!pip install mioflow gdown

In [None]:
# Enable Colab's custom widget manager so that third-party Jupyter widgets can render in this notebook
from google.colab import output
output.enable_custom_widget_manager()

In [None]:
import os
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import gdown
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import pairwise_distances

from MIOFlow.gaga import Autoencoder, train_gaga_two_phase, dataloader_from_pc
from MIOFlow.mioflow import MIOFlow

# Detect whether a CUDA-capable GPU is available; the flag is passed to MIOFlow below.
# Training also runs on the CPU, only more slowly.
use_cuda = torch.cuda.is_available()
print(f'Using CUDA: {use_cuda}')

## 1. Load Data

Download the pre-processed Embryoid Body AnnData from Google Drive.

> **Note:** Set `file_id` in the next cell to the Google Drive file ID of your `preprocessed_eb_adata.h5ad`.
> You can find the file ID in the shareable link:
> `https://drive.google.com/file/d/FILE_ID/view`
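
The file ID can also be pulled out of a shareable link programmatically. The sketch below uses only the standard library; `drive_file_id` is a hypothetical helper (it is not part of `gdown`), and the link is a made-up placeholder.

```python
import re

def drive_file_id(share_url):
    """Return the ID segment of a link shaped like https://drive.google.com/file/d/<ID>/view."""
    match = re.search(r"/file/d/([A-Za-z0-9_-]+)", share_url)
    if match is None:
        raise ValueError(f"No Google Drive file ID found in: {share_url}")
    return match.group(1)

# Placeholder link for illustration only; substitute your own shareable link.
example_link = "https://drive.google.com/file/d/abc123_XYZ-789/view?usp=sharing"
print(drive_file_id(example_link))  # abc123_XYZ-789
```

The returned string is what `file_id` expects in the download cell.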

In [None]:
file_id = "YOUR_FILE_ID_HERE"  # <-- Replace with the Google Drive file ID of preprocessed_eb_adata.h5ad
url = f"https://drive.google.com/uc?id={file_id}"
output_path = "preprocessed_eb_adata.h5ad"
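# Skip the download when the file is already present (e.g. after re-running the cell).
if not os.path.exists(output_path):
    gdown.download(url, output_path, quiet=False)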

In [None]:
adata = sc.read_h5ad(output_path)
adata

In [None]:
fig, ax = plt.subplots(figsize=(6, 4))
sns.scatterplot(
    x=adata.obsm['X_phate'][:, 0],
    y=adata.obsm['X_phate'][:, 1],
    hue=adata.obs['time_bin'],
    palette='tab10',
    s=3,
    ax=ax,
)
ax.set_xlabel('PHATE 1')
ax.set_ylabel('PHATE 2')
ax.set_title('PHATE embedding coloured by time bin')
plt.tight_layout()
plt.show()

## 2. Train GAGA Autoencoder

GAGA trains a two-phase autoencoder on PCA embeddings, regularised by PHATE distances:

- **Phase 1** — encoder learns a geometry-preserving latent space (distance preservation loss, decoder frozen)
- **Phase 2** — decoder learns to reconstruct PCA coordinates (reconstruction loss, encoder frozen)

The resulting latent space is where the MIOFlow ODE will be trained.
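
To make the Phase 1 objective concrete, here is a minimal NumPy sketch of a distance-preservation loss: the mean squared gap between pairwise distances in the latent space and a matrix of target distances. The exact loss used inside `train_gaga_two_phase` is not shown in this notebook, so the function below is an illustrative assumption, and the data are synthetic.

```python
import numpy as np

def pairwise_euclidean(points):
    """Dense matrix of Euclidean distances between all rows of `points`."""
    diff = points[:, None, :] - points[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))

def distance_preservation_loss(latent, target_distances):
    """Mean squared difference between latent distances and target distances."""
    return float(np.mean((pairwise_euclidean(latent) - target_distances) ** 2))

rng = np.random.default_rng(0)
reference = rng.normal(size=(64, 2))      # synthetic stand-in for scaled PHATE coordinates
target = pairwise_euclidean(reference)    # plays the role of `phate_distances`

# A rotated copy has different coordinates but identical pairwise distances.
theta = 0.7
rotation = np.array([[np.cos(theta), -np.sin(theta)],
                     [np.sin(theta),  np.cos(theta)]])

loss_identical = distance_preservation_loss(reference, target)
loss_rotated = distance_preservation_loss(reference @ rotation.T, target)
loss_stretched = distance_preservation_loss(reference * 2.0, target)
print(loss_identical, loss_rotated, loss_stretched)
```

The rotated copy scores (numerically) zero while the stretched copy does not: a loss of this kind constrains the geometry of the embedding, not its absolute coordinates, which is why the GAGA latent space can look like a rotated or reflected PHATE plot.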

In [None]:
# ── Scale PCA inputs (the same fitted scaler is passed to MIOFlow in section 3) ──
X_pca_raw  = adata.obsm['X_pca'].astype(np.float32)
scaler_pca = StandardScaler().fit(X_pca_raw)
X_pca      = scaler_pca.transform(X_pca_raw)

# ── PHATE-based pairwise distances for geometric regularisation ──────────────
scaler_phate    = StandardScaler().fit(adata.obsm['X_phate'])
X_phate_scaled  = scaler_phate.transform(adata.obsm['X_phate'])
phate_distances = pairwise_distances(X_phate_scaled, metric='euclidean').astype(np.float32)

# ── Build model + dataloader ─────────────────────────────────────────────────
input_dim  = X_pca.shape[1]   # 50 PCA components
latent_dim = 2                 # match PHATE dimensionality

gaga_model  = Autoencoder(input_dim, latent_dim, hidden_dims=[128, 64])
gaga_loader = dataloader_from_pc(X_pca, phate_distances, batch_size=1024)
print(f'GAGA architecture: {input_dim} → {latent_dim}')

# ── Two-phase training ───────────────────────────────────────────────────────
gaga_history = train_gaga_two_phase(
    gaga_model,
    gaga_loader,
    encoder_epochs=100,       # Phase 1: distance preservation
    decoder_epochs=100,       # Phase 2: reconstruction
    learning_rate=1e-3,
    dist_weight_phase1=1.0,
    recon_weight_phase2=1.0,
)

In [None]:
# ── Quick sanity-check: compare GAGA latent space with original PHATE ────────
gaga_model.eval()
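# Encode on whichever device training left the model on, then bring the result back to the CPU.
gaga_device = next(gaga_model.parameters()).device
with torch.no_grad():
    gaga_embeddings = gaga_model.encode(torch.tensor(X_pca, device=gaga_device)).cpu().numpy()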

adata.obsm['X_gaga'] = gaga_embeddings

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for ax, key, title in zip(
    axes,
    ['X_phate', 'X_gaga'],
    ['Original PHATE', 'GAGA Latent Space'],
):
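    sc_plot = ax.scatter(
        adata.obsm[key][:, 0], adata.obsm[key][:, 1],
        # Integer category codes keep scatter() working when time_bin is categorical or string-valued.
        c=adata.obs['time_bin'].astype('category').cat.codes,
        cmap='viridis', s=2, alpha=0.6,
    )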
    plt.colorbar(sc_plot, ax=ax, label='time_bin')
    ax.set_title(title)
    ax.set_xlabel('Dim 1')
    ax.set_ylabel('Dim 2')

plt.suptitle('GAGA vs PHATE embedding', fontsize=14)
plt.tight_layout()
plt.show()

## 3. Configure and Initialise MIOFlow

Pass the trained `gaga_model` and `scaler_pca` (the `StandardScaler` fitted on the raw `adata.obsm['X_pca']` before GAGA training).  
`MIOFlow` will use the encoder to embed cells into latent space for ODE training, and the decoder in `decode_to_gene_space()`.

In [None]:
mf = MIOFlow(
    adata,
    gaga_model=gaga_model,
    gaga_input_scaler=scaler_pca,    # fitted on X_pca — used for encode() and inverse_transform()
    obs_time_key='time_bin',
    debug_level='info',
    hidden_dim=64,
    use_cuda=use_cuda,
    momentum_beta=0.0,
    # Training
    n_epochs=300,
    # Loss
    use_density_loss=True,
    lambda_ot=1.0,
    lambda_density=0.1,
    lambda_energy=0.01,
    energy_time_steps=20,
    # Data / output
    sample_size=100,
    n_trajectories=100,
    n_bins=100,
    exp_dir='.',
)

## 4. Fit (~5 minutes)

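`fit()` trains the Neural ODE in the GAGA latent space. With the configuration above, the objective combines the optimal transport loss (`lambda_ot`) with the density (`lambda_density`) and energy (`lambda_energy`) terms.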
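
The idea can be sketched without any of the MIOFlow internals: push the cells observed at one time point along a velocity field, then score the result against the cells observed at the next time point with an optimal transport cost. Everything below is an assumption made for illustration: the velocity fields are hand-written rather than learned, the integrator is plain fixed-step Euler, and the cost is a basic entropic (Sinkhorn) approximation that may differ from the solver MIOFlow uses.

```python
import numpy as np

def euler_integrate(points, velocity, t0=0.0, t1=1.0, n_steps=10):
    """Push points along dx/dt = velocity(x, t) with fixed-step Euler integration."""
    dt = (t1 - t0) / n_steps
    x, t = points.copy(), t0
    for _ in range(n_steps):
        x = x + dt * velocity(x, t)
        t += dt
    return x

def sinkhorn_cost(source, target, epsilon=0.5, n_iters=200):
    """Entropic-regularised OT cost between two uniformly weighted point clouds."""
    cost = ((source[:, None, :] - target[None, :, :]) ** 2).sum(axis=-1)
    kernel = np.exp(-cost / epsilon)   # may underflow for much larger costs or smaller epsilon
    a = np.full(len(source), 1.0 / len(source))
    b = np.full(len(target), 1.0 / len(target))
    u = np.ones_like(a)
    for _ in range(n_iters):
        v = b / (kernel.T @ u)
        u = a / (kernel @ v)
    plan = u[:, None] * kernel * v[None, :]
    return float((plan * cost).sum())

rng = np.random.default_rng(0)
cells_t0 = rng.normal(size=(100, 2))                          # synthetic population at time 0
cells_t1 = rng.normal(size=(100, 2)) + np.array([4.0, 0.0])   # synthetic population at time 1

def drift_right(x, t):
    """Constant velocity that moves every point 4 units along the first axis per unit time."""
    return np.tile([4.0, 0.0], (len(x), 1))

def no_motion(x, t):
    """Zero velocity: the population stays where it started."""
    return np.zeros_like(x)

loss_good = sinkhorn_cost(euler_integrate(cells_t0, drift_right), cells_t1)
loss_bad = sinkhorn_cost(euler_integrate(cells_t0, no_motion), cells_t1)
print(loss_good < loss_bad)  # True
```

Training replaces the hand-written velocity with a neural network and adjusts its weights so that this kind of cost decreases across consecutive time bins.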

In [None]:
mf.fit()
print(mf)

## 5. Training Losses

After fitting, `mf.losses` contains per-epoch records under the keys `'epoch'`, `'total_loss'`, `'ot_loss'`, `'density_loss'` and `'energy_loss'`.
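
The plotting cell that follows reads `mf.losses` as a dictionary of equal-length lists, one entry per recorded epoch. Assuming that layout, a short sketch of picking out the epoch with the lowest total loss looks like this; the numbers are made up purely for illustration and are not results from this notebook.

```python
# Hypothetical stand-in for mf.losses (illustrative values only).
losses = {
    'epoch':        [0, 1, 2, 3],
    'total_loss':   [2.50, 1.40, 1.10, 1.15],
    'ot_loss':      [2.00, 1.10, 0.85, 0.90],
    'density_loss': [0.40, 0.25, 0.20, 0.20],
    'energy_loss':  [0.10, 0.05, 0.05, 0.05],
}

# Position of the smallest total loss, then the matching entry of every other list.
best_index = min(range(len(losses['total_loss'])), key=lambda i: losses['total_loss'][i])
best_epoch = losses['epoch'][best_index]
summary = {name: values[best_index] for name, values in losses.items() if name != 'epoch'}
print(best_epoch, summary)
```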

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

if mf.losses and mf.losses.get('total_loss'):
    epochs = mf.losses['epoch']

    # Left: total loss
    ax = axes[0]
    ax.plot(epochs, mf.losses['total_loss'], label='Total', color='tab:blue')
    ax.set_title('Global training loss — Total')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_yscale('log')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Right: individual loss components
    ax = axes[1]
    ax.plot(epochs, mf.losses['ot_loss'],      label='OT',      color='tab:orange')
    ax.plot(epochs, mf.losses['density_loss'], label='Density', color='tab:green')
    ax.plot(epochs, mf.losses['energy_loss'],  label='Energy',  color='tab:red')
    ax.set_title('Global training loss — Components')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_yscale('log')
    ax.legend()
    ax.grid(True, alpha=0.3)
else:
    axes[0].set_visible(False)
    axes[1].set_visible(False)

plt.suptitle('MIOFlow Training Losses', fontsize=14)
plt.tight_layout()
plt.show()

## 6. Trajectories in GAGA Latent Space

`mf.trajectories` has shape **`(n_bins, n_trajectories, latent_dim)`**.  
To iterate over individual trajectories use `mf.trajectories[:, i, :]`.
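
The axis order matters for every operation that follows, so here is a small sketch on a synthetic array with the same layout. The mean and standard deviation vectors are hypothetical stand-ins for `mf.mean_vals` and `mf.std_vals`; the trajectories are straight lines so that the results are easy to check by hand.

```python
import numpy as np

n_bins, n_traj, latent_dim = 100, 5, 2
t = np.linspace(0.0, 1.0, n_bins)

# Synthetic normalised trajectories: trajectory k moves from 0 to k + 1 along the first axis.
slopes = np.arange(1, n_traj + 1, dtype=float)
trajectories = np.stack(
    [np.stack([t * s, t * 0.0], axis=-1) for s in slopes], axis=1
)  # (n_bins, n_traj, latent_dim)

mean_vals = np.array([10.0, -2.0])   # hypothetical per-dimension mean
std_vals = np.array([2.0, 0.5])      # hypothetical per-dimension standard deviation

# Denormalise: the (latent_dim,) vectors broadcast over the last axis.
denormalised = trajectories * std_vals + mean_vals

# One trajectory is a slice along axis 1.
one_traj = denormalised[:, 3, :]     # (n_bins, latent_dim)

# Path length per trajectory: sum of step lengths along the time axis (axis 0).
steps = np.diff(denormalised, axis=0)                        # (n_bins - 1, n_traj, latent_dim)
path_lengths = np.linalg.norm(steps, axis=-1).sum(axis=0)    # (n_traj,)
print(trajectories.shape, one_traj.shape, path_lengths.shape)
```

Reductions over time use `axis=0` and reductions over trajectories use `axis=1`; mixing the two up produces arrays of plausible but wrong shapes when `n_bins` and `n_trajectories` happen to be equal, as they are in the configuration above.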

In [None]:
print('Trajectory tensor shape (n_bins, n_trajectories, latent_dim):', mf.trajectories.shape)

# Denormalise back to GAGA latent scale
traj_pts = mf.trajectories * mf.std_vals + mf.mean_vals  # (n_bins, n_traj, 2)

# Original data in GAGA latent scale
all_normed = np.vstack([X for X, _ in mf.dataset.time_series_data])
all_times  = np.concatenate([np.full(len(X), t) for X, t in mf.dataset.time_series_data])
true_data  = all_normed * mf.std_vals + mf.mean_vals

fig, ax = plt.subplots(figsize=(10, 8))
sc_plot = ax.scatter(
    true_data[:, 0], true_data[:, 1],
    c=all_times, cmap='viridis', s=1, alpha=0.6,
)
plt.colorbar(sc_plot, ax=ax, label='Time bin')

for i in range(mf.trajectories.shape[1]):
    traj = traj_pts[:, i, :]  # (n_bins, 2)
    ax.plot(traj[:, 0], traj[:, 1], alpha=0.5, linewidth=0.5, color='black')
    ax.annotate(
        '', xy=(traj[-1, 0], traj[-1, 1]), xytext=(traj[-2, 0], traj[-2, 1]),
        arrowprops=dict(arrowstyle='->', color='black', lw=0.5, mutation_scale=10),
    )

ax.set_xlabel('Latent Dim 1')
ax.set_ylabel('Latent Dim 2')
ax.set_title('MIOFlow trajectories in GAGA latent space')
plt.tight_layout()
plt.show()

## 7. Decode to Gene Space

`decode_to_gene_space()` maps trajectories through the chain:

**GAGA latent → GAGA decoder → PCA space (inverse scaler) → gene space**

Returns shape **`(n_bins, n_trajectories, n_genes)`**.
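
The chain can be written out step by step. The sketch below is not the MIOFlow implementation: a random linear map stands in for the GAGA decoder, the scaler statistics and PCA loadings are synthetic, and PCA is assumed to have been computed on mean-centred genes. It only shows how the shapes and inverse transforms compose.

```python
import numpy as np

rng = np.random.default_rng(0)
n_bins, n_traj, latent_dim, n_pcs, n_genes = 5, 3, 2, 4, 10

# Hypothetical stand-ins for the trained pieces (all assumptions for illustration).
decoder_weights = rng.normal(size=(latent_dim, n_pcs))   # plays the role of the GAGA decoder
pca_scale = rng.uniform(0.5, 2.0, size=n_pcs)            # like StandardScaler.scale_
pca_mean = rng.normal(size=n_pcs)                        # like StandardScaler.mean_
pca_components = np.linalg.qr(rng.normal(size=(n_genes, n_pcs)))[0].T  # (n_pcs, n_genes), orthonormal rows
gene_mean = rng.normal(size=n_genes)                     # per-gene mean removed before PCA

latent_traj = rng.normal(size=(n_bins, n_traj, latent_dim))

scaled_pca = latent_traj @ decoder_weights     # 1. decoder: latent -> scaled PCA coordinates
raw_pca = scaled_pca * pca_scale + pca_mean    # 2. inverse of the StandardScaler transform
genes = raw_pca @ pca_components + gene_mean   # 3. inverse PCA projection -> gene space

print(genes.shape)  # (5, 3, 10)
```

Each step acts on the last axis only, so the leading `(n_bins, n_trajectories)` axes pass through unchanged.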

In [None]:
trajectories_gene_space = mf.decode_to_gene_space()
print('Gene-space trajectory shape (n_bins, n_trajectories, n_genes):',
      trajectories_gene_space.shape)

## 8. Gene Trends — Top Highly-Variable Genes

We plot the mean decoded expression (± std) over all trajectories at each time bin, and overlay the observed mean expression at each measured time point.

In [None]:
sc.pp.highly_variable_genes(adata, n_top_genes=25)
example_genes     = adata.var_names[adata.var['highly_variable']]
example_gene_mask = adata.var_names.isin(example_genes)
print(example_genes)

In [None]:
# decoded shape: (n_bins, n_trajectories, n_selected_genes)
decoded_example_gene = trajectories_gene_space[:, :, example_gene_mask]

# Mean / std over trajectories (axis=1) → (n_bins, n_selected_genes)
decoded_mean = decoded_example_gene.mean(axis=1)
decoded_std  = decoded_example_gene.std(axis=1)

# Normalised time axis
x_time = np.linspace(0, 1, decoded_mean.shape[0])

# Reconstruct integer time labels the same way MIOFlow does internally
obs_time_labels = pd.factorize(adata.obs[mf.obs_time_key])[0]
obs_time_norm   = (obs_time_labels - obs_time_labels.min()) / (obs_time_labels.max() - obs_time_labels.min())

# Original data aggregated by time
adata_sel = adata[:, example_gene_mask]
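# .X may be sparse or dense depending on how the AnnData was saved; handle both.
X_sel     = adata_sel.X.toarray() if hasattr(adata_sel.X, 'toarray') else np.asarray(adata_sel.X)
data_df   = pd.DataFrame(X_sel, columns=example_genes)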
data_df['t'] = obs_time_norm
data_mean = data_df.groupby('t').mean()

# Grid plot
n_genes = decoded_mean.shape[1]
n_cols  = int(np.ceil(np.sqrt(n_genes)))
n_rows  = int(np.ceil(n_genes / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows))
axes = axes.flatten()

for i, gene in enumerate(example_genes):
    ax = axes[i]
    ax.plot(x_time, decoded_mean[:, i], label='Decoded (mean ± std)')
    ax.fill_between(x_time,
                    decoded_mean[:, i] - decoded_std[:, i],
                    decoded_mean[:, i] + decoded_std[:, i],
                    alpha=0.2)
    ax.scatter(data_mean.index, data_mean[gene], s=10, label='Original (mean)')
    ax.set_title(gene, fontsize=8)
    ax.legend(fontsize=7)

for i in range(n_genes, len(axes)):
    axes[i].set_visible(False)

plt.suptitle('Gene Trends Along the Trajectory', fontsize=18)
plt.tight_layout()
plt.subplots_adjust(top=0.94)
plt.show()

## 9. Single Gene of Interest

Change `interest_gene` below to investigate any gene in the dataset.

In [None]:
interest_gene = 'CXCL3'

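gene_mask = adata.var_names.isin([interest_gene])
# Fail early with a clear message instead of silently plotting an empty selection.
if not gene_mask.any():
    raise KeyError(f'{interest_gene!r} not found in adata.var_names')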
gene_traj = trajectories_gene_space[:, :, gene_mask]  # (n_bins, n_traj, 1)
gene_mean = gene_traj.mean(axis=1).flatten()           # (n_bins,)
gene_std  = gene_traj.std(axis=1).flatten()

x_time_gene = np.linspace(0, 1, len(gene_mean))

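# .X may be sparse or dense depending on how the AnnData was saved; handle both.
gene_X  = adata[:, gene_mask].X
gene_X  = gene_X.toarray() if hasattr(gene_X, 'toarray') else np.asarray(gene_X)
orig_df = pd.DataFrame({
    'expression': gene_X.flatten(),
    'time':       obs_time_norm,
})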
orig_mean = orig_df.groupby('time')['expression'].mean()

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(x_time_gene, gene_mean, linewidth=2, label='Decoded (mean)')
ax.fill_between(x_time_gene, gene_mean - gene_std, gene_mean + gene_std, alpha=0.2)
ax.scatter(orig_mean.index, orig_mean.values, s=20, label='Original (mean)', zorder=3)
ax.set_xlabel('Normalised time')
ax.set_ylabel('Expression')
ax.set_title(f'Gene Expression Trajectory: {interest_gene}')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 10. Highlight a Specific Trajectory

Select the trajectory whose endpoint is closest to a target coordinate in latent space.

In [None]:
# Set your target endpoint in GAGA latent space
target_x, target_y = -1.2, 0

# Find closest trajectory endpoint
endpoints = traj_pts[-1, :, :]  # (n_traj, 2) — last time step
dists     = np.sqrt((endpoints[:, 0] - target_x)**2 + (endpoints[:, 1] - target_y)**2)
highlight = int(dists.argmin())
print(f'Selected trajectory {highlight} at endpoint '
      f'({endpoints[highlight, 0]:.4f}, {endpoints[highlight, 1]:.4f})')

# Plot
fig, ax = plt.subplots(figsize=(10, 8))
ax.scatter(true_data[:, 0], true_data[:, 1],
           c=all_times, cmap='viridis', s=1, alpha=0.5)

for i in range(traj_pts.shape[1]):  # same trajectory count as in section 6
    traj  = traj_pts[:, i, :]
    color = 'red'  if i == highlight else 'black'
    lw    = 1.0    if i == highlight else 0.4
    alpha = 1.0    if i == highlight else 0.4
    ax.plot(traj[:, 0], traj[:, 1], color=color, lw=lw, alpha=alpha)
    ax.annotate('', xy=(traj[-1, 0], traj[-1, 1]), xytext=(traj[-2, 0], traj[-2, 1]),
                arrowprops=dict(arrowstyle='->', color=color, lw=lw, mutation_scale=10))

ax.plot(target_x, target_y, 'r*', markersize=12, label='Target point')
ax.set_title(f'Highlighted trajectory {highlight}')
ax.legend()
plt.tight_layout()
plt.show()

In [None]:
# Gene trends for the single highlighted trajectory
decoded_highlight = decoded_example_gene[:, highlight, :]  # (n_bins, n_selected_genes)

fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows))
axes = axes.flatten()

for i, gene in enumerate(example_genes):
    ax = axes[i]
    ax.plot(x_time, decoded_highlight[:, i], label=f'Traj {highlight}')
    ax.scatter(data_mean.index, data_mean[gene], s=10, label='Original (mean)')
    ax.set_title(gene, fontsize=8)
    ax.legend(fontsize=7)

for i in range(n_genes, len(axes)):
    axes[i].set_visible(False)

plt.suptitle(f'Gene Trends — Trajectory {highlight}', fontsize=18)
plt.tight_layout()
plt.subplots_adjust(top=0.94)
plt.show()