# Continuous Normalizing Flow (Diffsol)

This tutorial is inspired by `examples/integration/cnf/mnist_cnf.py`. To keep execution
fast on CPU, we train on random noise instead of full MNIST, but the structure is the
same: integrate a 1-D CNF, compute log-probabilities, and update parameters using
diffsol's reverse-mode gradients.


In [None]:
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from diffsol_pytorch import DiffsolModule, reverse_mode
from helpers import describe_device, gpu_section_mode, preferred_device, save_cached_json, seed_everything

In [None]:
# Fix the random seed first so every later random draw in the notebook is
# reproducible, then pick the device that the rest of the cells will use.
seed_everything(0)
device_target = preferred_device()
device_description = describe_device(device_target)
print(f"Using device: {device_description}")


In [None]:
# Tell the reader which path the remaining cells take: CUDA or the CPU fallback.
if device_target.type == 'cuda':
    print(f'Running CNF demo on {torch.cuda.get_device_name(0)}')
else:
    print('CUDA not available; skipping GPU-only experiments. This demo uses CPU-friendly noise.')


In [None]:
# Model definition handed verbatim to diffsol (to DiffsolModule below and to
# reverse_mode in DiffsolCNF.backward). The string is part of the program's
# behaviour, so it is left untouched.
# NOTE(review): this appears to follow the DiffSL layout -- `in` lists the
# inputs a, b, c (each with a default value), `u` declares the single state z
# starting at 0.0, and `F` is the cubic right-hand side a*z^3 + b*z + c.
# Confirm against the DiffSL reference.
CNF_CODE = '''

in = [a, b, c]
a { 0.1 }
b { 0.1 }
c { 0.0 }
u {
    z = 0.0,
}
F {
    a * z * z * z + b * z + c,
}
'''

# 41 evenly spaced float64 sample times covering [0, 1].
times = torch.linspace(0.0, 1.0, 41, dtype=torch.float64)
# The diffsol calls in this notebook are given plain Python lists, so keep a
# list copy of the time grid alongside the tensor.
times_list = times.tolist()
# Shared solver object built from CNF_CODE; DiffsolCNF.forward calls its
# solve_dense method.
module = DiffsolModule(CNF_CODE)


In [None]:
class DiffsolCNF(nn.Module):
    """Toy continuous normalizing flow integrated by diffsol.

    The three learnable coefficients live in ``self.params`` and are handed to
    diffsol as plain Python floats (presumably bound to the inputs ``a``,
    ``b``, ``c`` of ``CNF_CODE`` -- confirm against the diffsol bindings).

    The solver runs outside of autograd, so tensors returned by
    :meth:`forward` and :meth:`log_prob` carry no graph. Call
    :meth:`backward` explicitly to populate ``self.params.grad`` before
    stepping an optimizer.
    """

    def __init__(self):
        super().__init__()
        # Small random float64 initialisation of the three coefficients.
        self.params = nn.Parameter(torch.randn(3, dtype=torch.float64) * 0.01)

    def forward(self, z0: torch.Tensor) -> torch.Tensor:
        """Return the last solver output value, repeated once per row of ``z0``.

        NOTE(review): the values in ``z0`` are never passed to the solver; only
        its length, dtype and device are used. Every sample therefore receives
        the same output. Feeding ``z0`` in as the initial state would need
        solver support that is not visible here.

        Args:
            z0: Batch tensor; its leading dimension sets the output length.

        Returns:
            A 1-D tensor of length ``len(z0)`` with ``z0``'s dtype and device.
        """
        batch_size = len(z0)
        if batch_size == 0:
            # Nothing to integrate for; skip the solver call entirely.
            return torch.tensor([], dtype=z0.dtype, device=z0.device)
        # The solve depends only on the parameters and the time grid, so it is
        # run once and its final value is shared by every sample.
        _, _, flat = module.solve_dense(self.params.detach().tolist(), times_list)
        return torch.tensor([flat[-1]] * batch_size, dtype=z0.dtype, device=z0.device)

    def log_prob(self, z0: torch.Tensor) -> torch.Tensor:
        """Return ``-0.5 * mean(zT ** 2)`` where ``zT = self.forward(z0)``."""
        zT = self.forward(z0)
        return -0.5 * (zT**2).mean()

    def backward(self, grad_scalar: float):
        """Set ``self.params.grad`` from diffsol's reverse-mode sweep.

        Args:
            grad_scalar: Value used as the upstream gradient at every entry of
                the time grid.
        """
        grad_out = [grad_scalar] * len(times_list)
        grads = reverse_mode(
            CNF_CODE,
            self.params.detach().tolist(),
            times_list,
            grad_out,
        )
        # The gradient must live on the same device as the parameter; without
        # an explicit device the assignment fails once the model is on CUDA.
        self.params.grad = torch.tensor(
            grads[:3],
            dtype=self.params.dtype,
            device=self.params.device,
        )


In [None]:
# Train the CNF on fresh Gaussian noise for 50 Adam steps. Gradients do not
# come from autograd: model.backward() writes params.grad itself, so the
# usual loss.backward() call is absent on purpose.
model = DiffsolCNF()
model = model.to(device_target)
optimizer = optim.Adam(params=[model.params], lr=1e-2)
loss_history = []
for step in range(50):
    noise = torch.randn(64, 1, dtype=torch.float64, device=device_target)
    loss = -model.log_prob(noise)
    # Read the scalar once; it is both the backward seed and the logged value.
    loss_value = loss.item()
    optimizer.zero_grad()
    model.backward(loss_value)
    optimizer.step()
    loss_history.append(loss_value)
# Show the first few losses as the cell output.
loss_history[:5]

In [None]:
# Optional GPU benchmark. gpu_section_mode decides whether to run the probe,
# reuse cached metrics, or fall back to a placeholder record.
cache_file = "cnf_gpu_metrics.json"
mode, cached_metrics = gpu_section_mode("CNF GPU benchmark", cache_key=cache_file)
if mode == "run":
    on_cuda = device_target.type == "cuda"
    batch = torch.randn(256, 1, dtype=torch.float64, device=device_target)
    probe = DiffsolCNF().to(device_target)
    if on_cuda:
        # Drain pending kernels so setup work is not counted in the timing.
        torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.inference_mode():
        forward = probe(batch)
        _ = probe.log_prob(batch)
        if on_cuda:
            # Wait for the timed work to finish before stopping the clock.
            torch.cuda.synchronize()
    # Wall-clock time of one forward pass plus one log_prob evaluation.
    elapsed_s = time.perf_counter() - start
    metrics = {
        "device": describe_device(device_target),
        "batch_size": int(batch.shape[0]),
        "variance": float(forward.var().item()),
        "elapsed_s": elapsed_s,
    }
    save_cached_json(cache_file, metrics)
elif mode == "cache":
    metrics = cached_metrics
else:
    metrics = {
        "device": "cpu",
        "note": "CNF GPU benchmark skipped; run on a CUDA kernel to refresh cache.",
    }
# Show the metrics record as the cell output.
metrics

In [None]:
import matplotlib.pyplot as plt

# Draw the recorded training losses on the current axes, label them, and
# display the figure.
axes = plt.gca()
axes.plot(loss_history)
axes.set_xlabel('Iteration')
axes.set_ylabel('Negative log-prob')
axes.set_title('Diffsol CNF training (noise demo)')
plt.show()
