In [None]:
import math
import numpy as np
import torch
import torch.utils.checkpoint
import torchvision
import matplotlib.pyplot as plt

from torchvision import transforms
from tqdm.auto import tqdm

from rectified_flow.rectified_flow import RectifiedFlow
from rectified_flow.utils import match_dim_with_data

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Opt in to TF32 for both cuDNN ops and CUDA matmuls.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

In [None]:
batch_size = 512

# ToTensor yields float images in [0, 1]; Normalize(mean=0.5, std=0.5) then
# shifts/scales them to [-1, 1].
transform_list = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
train_dataset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transforms.Compose(transform_list)
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    # Pinned host memory only speeds up host-to-GPU copies; on a CPU-only
    # machine (which this notebook supports) requesting it just emits a warning.
    pin_memory=torch.cuda.is_available(),
    persistent_workers=False,
)

# Grab one batch for the experiments further down and sanity-check its shape.
batch = next(iter(train_dataloader))
print(batch[0].shape)  # torch.Size([512, 1, 28, 28])

In [None]:
from rectified_flow.models.enhanced_mlp import VarMLP
from rectified_flow.models.utils import EMAModel
from rectified_flow.models.unet import SongUNet, SongUNetConfig

model_type = "unet"

# NOTE(review): absolute, machine-specific checkpoint location -- confirm it
# exists (or point it at your own checkpoint) before running this cell.
checkpoint_dir = "/scratch/10992/liaorunlong93/random_flow_toys/checkpoints/flow_mnist_unet_unconditional"

# Load the pretrained U-Net (requesting the EMA weights) and move it to the
# compute device selected earlier.
flow_model = SongUNet.from_pretrained(checkpoint_dir, use_ema=True)
flow_model = flow_model.to(device)

In [None]:
from rectified_flow.samplers import EulerSampler
from rectified_flow.utils import plot_cifar_results

# Use the pretrained network as the velocity field of a rectified flow over
# single-channel 28x28 images.
rf = RectifiedFlow(
    velocity_field=flow_model,
    data_shape=(1, 28, 28),
    device=device,
)

# Draw 50 samples with a 200-step Euler sampler and keep the final entry of
# the recorded trajectory.
sampler = EulerSampler(
    rectified_flow=rf,
    num_steps=200,
    num_samples=50,
)
x_1 = sampler.sample_loop().trajectories[-1]

plot_cifar_results(x_1)

## Exact Jacobian calculation with autograd

In [None]:
from tqdm import tqdm
from torch.autograd.functional import jacobian

def divergence_exact_manual(v_func, x_t, t):
    """Compute the exact per-sample divergence tr(dv/dx) from the full Jacobian.

    Every batch element is handled on its own: the complete ``D x D`` Jacobian
    of ``v_func`` with respect to that sample is built with
    ``torch.autograd.functional.jacobian`` and its trace is taken. This is
    meant as a ground-truth reference for cheaper estimators, not for speed.

    Args:
        v_func: Callable invoked as ``v_func(x, t)``. For a single-sample input
            its output must contain exactly as many elements as the input so
            that the Jacobian is square.
        x_t: Batched input; dimension 0 is the batch dimension.
        t: Time argument. Either a tensor with one entry per sample along
            dimension 0, or a single shared value given as a Python number,
            a 0-dim tensor, or a tensor whose first dimension has length 1.

    Returns:
        Tensor of shape ``(B,)`` on ``x_t``'s device and in ``x_t``'s dtype
        holding the divergence for each sample.
    """
    device, dtype = x_t.device, x_t.dtype
    B = x_t.shape[0]
    div = torch.empty(B, device=device, dtype=dtype)

    # Sum the diagonal in fp32 when the inputs are half precision.
    acc_dtype = torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype

    # Accept a time value shared by the whole batch: the per-sample slicing
    # below needs one entry per batch element along dim 0.
    if not torch.is_tensor(t):
        t = torch.as_tensor(t, device=device, dtype=dtype)
    if t.dim() == 0:
        t = t.expand(B)
    elif t.shape[0] == 1 and B > 1:
        t = t.expand(B, *t.shape[1:])

    for b in tqdm(range(B), desc="Computing exact divergence"):
        # Keep the leading batch dimension (size 1) so v_func sees batched input.
        xb = x_t[b:b+1].detach().requires_grad_(True)
        tb = t[b:b+1]

        def f(inp, tb=tb):
            return v_func(inp, tb)

        J = jacobian(f, xb, vectorize=False, create_graph=False)
        D = xb.numel()
        # J has shape (output_shape + input_shape); flatten it to (D, D) and
        # take the trace.
        div[b] = J.reshape(D, D).diagonal().to(acc_dtype).sum().to(dtype)
        del J  # release the D x D Jacobian before the next sample

    return div

### Divergence Estimation via the Hutchinson Trick

For a vector field $v(x,t)$ with Jacobian $J(x) = \dfrac{\partial v}{\partial x}$, the divergence is
$$
\operatorname{div} v(x,t) = \operatorname{tr}\, J(x).
$$
Computing $\operatorname{tr} J$ via an explicit Jacobian is expensive. The Hutchinson identity gives an unbiased estimator:
$$
\operatorname{tr} J = \mathbb{E}_{\varepsilon}\!\left[\varepsilon^\top J \varepsilon\right],
$$
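where $\varepsilon$ is any random probe vector with $\mathbb{E}[\varepsilon\varepsilon^\top] = I$, e.g. Rademacher (each entry $\pm1$ w.p. 1/2) or standard Gaussian. A Monte Carlo estimate with $M$ independent probes is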
$$
\widehat{\operatorname{div}}(x,t) = \frac{1}{M}\sum_{i=1}^{M} \varepsilon_i^\top (J\,\varepsilon_i),
$$
which only requires Jacobian–vector products (JVP) or vector–Jacobian products (VJP), avoiding materializing the full Jacobian.

**Implementation notes**
- `eps_dist="rademacher"` often yields slightly lower variance; `"normal"` works too.
- Use forward-mode `jvp` to get $J\varepsilon$, or a single backward (VJP) to get $J^\top \varepsilon$ and then dot with $\varepsilon$.
- Compute the inner product separately for each sample in the batch and average across the $M$ probes (`n_samples` in the code); $M = 1$–$4$ is typically sufficient.
- When using fp16/bf16, accumulate in fp32 for numerical stability.


In [None]:
from torch.func import jvp as jvp_func

def divergence_hutchinson(
    v_func,
    x_t: torch.Tensor,
    t: torch.Tensor,
    n_samples: int = 1,
    eps_dist: str = "rademacher",
    method: str = "jvp",
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    device, dtype = x_t.device, x_t.dtype
    B = x_t.shape[0]
    D = x_t.numel() // B

    x_flat = x_t.detach().view(B, D)

    acc_dtype = torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype
    acc = torch.zeros(B, device=device, dtype=acc_dtype)

    def _sample_eps_flat():
        if eps_dist == "rademacher":
            r = torch.randint(0, 2, (B, D), device=device, generator=generator)
            eps = (r * 2 - 1).to(acc_dtype)
        elif eps_dist == "normal":
            eps = torch.randn((B, D), device=device, dtype=acc_dtype, generator=generator)
        else:
            raise ValueError("eps_dist must be 'rademacher' or 'normal'")
        return eps.to(dtype)

    if method == "jvp":
        def f_flat(inp_flat):
            x4 = inp_flat.view(B, *x_t.shape[1:])
            y4 = v_func(x4, t)
            return y4.view(B, D)

        for _ in range(n_samples):
            eps_flat = _sample_eps_flat()
            _, jvp_out = jvp_func(f_flat, (x_flat,), (eps_flat,), strict=False)  # <-- fix
            acc += (eps_flat * jvp_out).sum(dim=1).to(acc_dtype)

        return (acc / n_samples).to(dtype)

    elif method == "vjp":
        for _ in range(n_samples):
            eps = _sample_eps_flat().view_as(x_t).to(dtype)
            x_req = x_t.detach().clone().requires_grad_(True)
            v = v_func(x_req, t)
            s_per_sample = (v * eps).view(B, -1).sum(dim=1)
            gx = torch.autograd.grad(s_per_sample.sum(), x_req, create_graph=False, retain_graph=False)[0]
            acc += (gx * eps).view(B, -1).sum(dim=1).to(acc_dtype)

        return (acc / n_samples).to(dtype)

    else:
        raise ValueError("method must be 'jvp' or 'vjp'")

In [None]:
# Take five training images and build noisy interpolants
# x_t = t * x_1 + (1 - t) * noise with one random time per image.
x_1 = batch[0][:5].to(device)
t = torch.rand(x_1.shape[0], device=device)
t_bc = t.view(-1, 1, 1, 1)  # per-sample time, broadcastable over (C, H, W)
x_t = t_bc * x_1 + (1 - t_bc) * torch.randn_like(x_1)

# Reference value from the full Jacobian vs. the 10-probe Hutchinson estimate.
div_exact = divergence_exact_manual(flow_model, x_t, t)
div_approx = divergence_hutchinson(flow_model, x_t, t, method="vjp", n_samples=10)

print("Divergence shape:", div_exact.shape)
print("div exact:", div_exact)
print("div approx:", div_approx)
max_abs_err = (div_exact - div_approx).abs().max()
print("Max absolute error:", max_abs_err.item())