# Single-cell Time Series Interpolation

This notebook is taken with some edits from the repo https://github.com/atong01/conditional-flow-matching/tree/main

Note that to run this notebook you will need the `ebdata_v3.h5ad` data object, which is accessible at https://data.mendeley.com/datasets/hhny5ff7yj/1 (the loading cell below expects it at `data/ebdata_v3.h5ad`).

### Part 1:

In order to run flow matching, you need a choice of conditional probability path and conditional velocity that solve the continuity equation: $$\partial_t p_t = - \nabla \cdot (v_t p_t)$$

The most common choice is conditioning on input points $x_0$ and $x_1$, of the form: $$p_t(x | x_0, x_1) = \mathcal{N}(x; (1-t)x_0 + tx_1, \sigma^2)$$ $$v_t(x | x_0, x_1) = x_1 - x_0$$

First of all, before implementing flow matching, prove that these choices actually do satisfy the continuity equation.

Hint: divide by $p_t$ and work with $\log p_t$ instead

### Part 2:

We will take the skeleton of the flow matching implementation given at https://github.com/atong01/conditional-flow-matching/tree/main and fill in the details given the above parameterization.  

Fill in the functions below whose bodies are currently `pass` (rather than a `return` statement), and then try using FlowMatcher (which pairs points with the product coupling) and OTFlowMatcher (which pairs points using the coupling from minibatch OT).  Observe the difference in the trajectories on the PHATE embedding, which is a 2D representation of single cell data that tries to preserve geometry better than a UMAP.  

Why aren't the paths in the OT flow matcher exactly straight?  What happens as you vary the batch size?

In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
    
import matplotlib.pyplot as plt
import numpy as np
import scanpy as sc
import torch
import torchsde
from torchdyn.core import NeuralODE
from tqdm import tqdm

from torchcfm.conditional_flow_matching import *
from torchcfm.models import MLP
from torchcfm.utils import plot_trajectories, torch_wrapper

In [None]:
# Load the dataset from the local h5ad file (it must be downloaded beforehand).
adata = sc.read_h5ad("data/ebdata_v3.h5ad")
# Bare expression: in a notebook this displays a summary of the loaded object.
adata

In [None]:
# Scatter plot of the cells in the "phate" basis, colored by their "sample_labels" annotation.
sc.pl.scatter(adata, basis="phate", color="sample_labels")

In [None]:
# Number of distinct sample labels (one per time point).
n_times = len(adata.obs["sample_labels"].unique())
# Standardize coordinates
coords = adata.obsm["X_phate"]
# Subtract the per-column mean and divide by the per-column standard deviation.
coords = (coords - coords.mean(axis=0)) / coords.std(axis=0)
adata.obsm["X_phate_standardized"] = coords
# X[t] holds the standardized coordinates of the cells whose categorical
# label code equals t, for t = 0, ..., n_times - 1.
X = [
    adata.obsm["X_phate_standardized"][adata.obs["sample_labels"].cat.codes == t]
    for t in range(n_times)
]

In [None]:
import math
import warnings
from typing import Union

import torch

from torchcfm.optimal_transport import OTPlanSampler

In [None]:
class FlowMatcher:
    """Exercise skeleton of a conditional flow matcher.

    The probability path is N(t * x1 + (1 - t) * x0, sigma**2) and the
    conditional vector field is ut(x1|x0) = x1 - x0. The point pairs
    (x0, x1) are used exactly as they are passed in.

    ``compute_mu_t``, ``sample_xt`` and ``compute_conditional_flow`` are
    left as ``pass`` stubs to be filled in; until then they return None.
    """

    def __init__(self, sigma: Union[float, int] = 0.0):
        """Store ``sigma``, the standard deviation of the probability path."""
        self.sigma = sigma

    def compute_mu_t(self, x0, x1, t):
        """
        Compute the mean of the probability path N(t * x1 + (1 - t) * x0, sigma**2)

        Exercise stub: to be implemented.
        """
        pass

    def sample_xt(self, x0, x1, t):
        """
        Draw a sample from the probability path N(t * x1 + (1 - t) * x0, sigma**2)

        Exercise stub: to be implemented.
        """
        pass

    def compute_conditional_flow(self, x0, x1, t, xt):
        """
        Compute the conditional vector field ut(x1|x0) = x1 - x0

        Exercise stub: to be implemented.
        """
        pass

    def sample_location_and_conditional_flow(self, x0, x1, t=None):
        """
        Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma**2))
        and the conditional vector field ut(x1|x0) = x1 - x0

        Parameters
        ----------
        x0, x1 : torch.Tensor
            Paired source and target points; the first dimension is the batch.
        t : torch.Tensor, optional
            Times to evaluate at, one per batch element. When None, one time
            per batch element is drawn uniformly from [0, 1) with the dtype
            and device of ``x0``.

        Returns
        -------
        (t, xt, ut) : the times used, the sampled locations and the
        conditional vector field at those locations.
        """
        # Only sample times when the caller did not provide any; previously a
        # supplied ``t`` was silently overwritten.
        if t is None:
            t = torch.rand(x0.shape[0]).type_as(x0)

        xt = self.sample_xt(x0, x1, t)
        ut = self.compute_conditional_flow(x0, x1, t, xt)

        return t, xt, ut

In [None]:
class OTFlowMatcher(FlowMatcher):
    """Flow matcher that re-pairs each minibatch with an OT plan sampler.

    Before delegating to ``FlowMatcher``, the incoming (x0, x1) batch is
    passed through ``OTPlanSampler(method="exact").sample_plan``.
    """

    def __init__(self, sigma: Union[float, int] = 0.0):
        """Store ``sigma`` via the base class and create the OT plan sampler."""
        super().__init__(sigma)
        self.ot_sampler = OTPlanSampler(method="exact")

    def sample_location_and_conditional_flow(self, x0, x1, t=None):
        r"""
        Compute the sample xt (drawn from N(t * x1 + (1 - t) * x0, sigma**2))
        and the conditional vector field ut(x1|x0) = x1 - x0
        with respect to the minibatch OT plan $\Pi$.

        Parameters
        ----------
        x0, x1 : torch.Tensor
            Source and target minibatches; they are re-paired by the OT sampler.
        t : torch.Tensor, optional
            Times forwarded unchanged to the base class (None by default).

        Returns
        -------
        (t, xt, ut) as returned by ``FlowMatcher.sample_location_and_conditional_flow``.
        """
        x0, x1 = self.ot_sampler.sample_plan(x0, x1)
        # ``t`` is now a real parameter; the previous code referenced an
        # undefined name here and raised NameError on every call.
        return super().sample_location_and_conditional_flow(x0, x1, t)

In [None]:
##############################################################################################

In [None]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Number of points drawn per time point when building a training batch.
batch_size = 256
# Value passed as `sigma` to the flow matcher (see the commented-out lines below).
sigma = 0.1
# Passed as `dim` to the MLP; presumably the dimension of the standardized coordinates — confirm.
dim = 2
ot_cfm_model = MLP(dim=dim, time_varying=True, w=64).to(device)
# Adam optimizer over the model parameters; 1e-4 is the learning rate.
ot_cfm_optimizer = torch.optim.Adam(ot_cfm_model.parameters(), 1e-4)

# Exercise placeholder: replace None with one of the flow matchers below
# before running the training loop.
FM = None
# FM = FlowMatcher(sigma=sigma)
# FM = OTFlowMatcher(sigma=sigma)

In [None]:
# Map each sample-label string to an integer index.
label_dict = {'Day 00-03': 0, 'Day 06-09': 1, 'Day 12-15': 2, 'Day 18-21': 3, 'Day 24-27': 4}
# One integer per cell, in the order of adata.obs; used as the scatter color values.
labels = [label_dict[x] for x in adata.obs["sample_labels"].tolist()]

In [None]:
def plot_trajectories(traj, n=2000, n_paths=15):
    """Plot trajectories on top of the standardized PHATE coordinates.

    Reads the notebook globals ``adata`` and ``labels`` for the background
    scatter.

    Parameters
    ----------
    traj : array
        Indexed as ``traj[step, trajectory, coordinate]``; only coordinates
        0 and 1 are plotted.
    n : int, default 2000
        Maximum number of trajectories drawn as faint points.
    n_paths : int, default 15
        Maximum number of trajectories drawn as red lines. Clamped to the
        number of trajectories available, so a small ``traj`` no longer
        raises IndexError.
    """
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    # Background: every cell, colored by its integer time-point label.
    ax.scatter(
        adata.obsm["X_phate_standardized"][:, 0],
        adata.obsm["X_phate_standardized"][:, 1],
        c=labels,
    )
    # All steps of the first n trajectories as faint points.
    ax.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.4, alpha=0.1, c="olive")

    # A few individual trajectories as lines.
    for i in range(min(n_paths, traj.shape[1])):
        ax.plot(traj[:, i, 0], traj[:, i, 1], alpha=0.9, c="red")


def get_batch(FM, X, batch_size, n_times):
    """Construct a batch with points from each timepoint pair.

    For every consecutive pair of time points (k, k + 1), ``batch_size``
    points are drawn with replacement from ``X[k]`` and from ``X[k + 1]``
    and passed to ``FM.sample_location_and_conditional_flow``. Tensors are
    moved to the notebook-global ``device``.

    Parameters
    ----------
    FM : object exposing ``sample_location_and_conditional_flow(x0, x1)``
        returning ``(t, xt, ut)``.
    X : sequence of numpy arrays, one per time point.
    batch_size : int, points sampled per time point.
    n_times : int, number of time points; presumably at least 2, since
        ``torch.cat`` has nothing to concatenate otherwise.

    Returns
    -------
    (t, xt, ut) : the per-pair results concatenated along the first
    dimension, with ``t`` shifted by the index of its starting time point.
    """

    def sample_points(points):
        # Uniform random row indices, drawn with replacement.
        idx = np.random.randint(points.shape[0], size=batch_size)
        return torch.from_numpy(points[idx]).float().to(device)

    ts, xts, uts = [], [], []
    for t_start in range(n_times - 1):
        x0 = sample_points(X[t_start])
        x1 = sample_points(X[t_start + 1])
        t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)

        # Shift the time by t_start so each pair covers its own unit interval.
        ts.append(t + t_start)
        xts.append(xt)
        uts.append(ut)
    return torch.cat(ts), torch.cat(xts), torch.cat(uts)

## OT-CFM

In [None]:
# Train the model for 10000 steps by regressing its output onto the
# conditional vector field ut. FM must have been set to a flow matcher first.
for i in tqdm(range(10000)):
    ot_cfm_optimizer.zero_grad()
    t, xt, ut = get_batch(FM, X, batch_size, n_times)
    # Model input: xt with the time t appended as an extra last column.
    vt = ot_cfm_model(torch.cat([xt, t[:, None]], dim=-1))
    # Mean squared error between the predicted and the target vector field.
    loss = torch.mean((vt - ut) ** 2)
    loss.backward()
    ot_cfm_optimizer.step()

In [None]:
# Wrap the trained model as a NeuralODE so it can be integrated over time.
node = NeuralODE(torch_wrapper(ot_cfm_model), solver="dopri5", sensitivity="adjoint")
# Gradients are not tracked while computing the trajectories.
with torch.no_grad():
    traj = node.trajectory(
        # Initial points: the first 1000 rows of the first time point.
        torch.from_numpy(X[0][:1000]).float().to(device),
        # 400 evenly spaced times from 0 to n_times - 1.
        t_span=torch.linspace(0, n_times - 1, 400),
    ).cpu()

In [None]:
# Convert the trajectories to a NumPy array and plot them over the data.
plot_trajectories(traj.cpu().numpy())