## Fine-tuning a flow model with a physics constraint

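Train a basic flow-matching model on 2D Gaussian data, then fine-tune it so that the generated samples have a specified mean. Motivation from HEP: a simulator generates events given some theory parameters θ, but provides no explicit likelihood p(d | θ). One can train a flow-matching model as a conditional density estimator for the posterior p(θ | observed data), then use `rsample` (differentiable sampling) to fine-tune the flow against a physics constraint. The toy example below is unconditional and only demonstrates the `rsample`-based fine-tuning mechanics.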


In [None]:
import torch
import torch.nn as nn
import nami

class ToyField(nn.Module):
    """Small MLP velocity field v(x, t) for the toy flow-matching example.

    The time value is appended to the state as one extra input feature, so
    the network maps ``dim + 1`` input features to ``dim`` outputs.

    Args:
        dim: size of the last (event) axis of ``x``.
        hidden: width of the two hidden layers.
    """

    def __init__(self, dim=2, hidden=64):
        super().__init__()
        in_features = dim + 1  # state features plus one time feature
        layers = [
            nn.Linear(in_features, hidden),
            nn.SiLU(),
            nn.Linear(hidden, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim),
        ]
        # Attribute name `net` and the Sequential indices are kept so that
        # parameter names (net.0.weight, net.2.weight, ...) stay the same.
        self.net = nn.Sequential(*layers)

    @property
    def event_ndim(self):
        """Number of trailing axes of ``x`` forming one event (always 1)."""
        return 1

    def forward(self, x, t, c=None):
        """Evaluate the velocity field at states ``x`` and times ``t``.

        ``t`` gains a trailing axis and is then expanded to
        ``(*x.shape[:-1], 1)``, so it must be broadcastable to the batch
        shape of ``x`` (e.g. a 0-d tensor or one time per row). ``c`` is
        accepted but unused.
        """
        batch_shape = x.shape[:-1]
        time_feature = t[..., None].expand(*batch_shape, 1)
        features = torch.cat((x, time_feature), dim=-1)
        return self.net(features)

# Fix the global RNG so weight initialisation, training batches and the
# printed sample statistics are reproducible from run to run.
torch.manual_seed(0)

dim = 2
field = ToyField(dim)
# Base distribution: standard normal over a `dim`-vector event.
base = nami.StandardNormal(event_shape=(dim,))
# Fixed-step RK4 solver with 32 steps.
solver = nami.RK4(steps=32)
optimizer = torch.optim.Adam(field.parameters(), lr=1e-3)

# Data Gaussian centred at (3, ..., 3). Sized from `dim` so that changing the
# dimensionality above keeps the shapes consistent.
data_mean = torch.full((dim,), 3.0)

# P1: plain flow-matching pretraining. Each step draws a fresh batch of 256
# points from N(data_mean, 0.5^2 I) and pairs it with standard-normal sources.
print("P1: standard flow matching training")
for step in range(500):
    x_target = torch.randn(256, dim).mul(0.5).add(data_mean)
    x_source = torch.randn_like(x_target)

    optimizer.zero_grad()
    loss = nami.fm_loss(field, x_target, x_source)
    loss.backward()
    optimizer.step()

    # Log every 100th step (steps are reported 1-based).
    if step % 100 == 99:
        print(f"  Step {step + 1}, FM loss: {loss.item():.4f}")

# Wrap the pretrained field, base distribution and solver into a
# FlowMatching model. `None` is passed as the context argument; this toy
# setup is unconditional (ToyField ignores its `c` argument).
fm = nami.FlowMatching(
    field,
    base,
    solver,
    event_ndim=1,
)
process = fm(None)

# Baseline: mean of 500 samples from the pretrained flow, without autograd.
with torch.no_grad():
    pre_samples = process.sample((500,))
print(
    f"\nBefore fine-tuning: sample mean = {pre_samples.mean(0).tolist()}"
)

# fine-tune
# Target of the "physics" constraint: the generated mean should be 5 in every
# coordinate. Sized from `dim` so changing the dimensionality keeps shapes
# consistent.
target_mean = torch.full((dim,), 5.0)
ft_optimiser = torch.optim.Adam(field.parameters(), lr=5e-4)
ft_batch = 64

print("\nP2: Fine-tuning (via rsample)")
for step in range(200):
    process = fm(None)

    # rsample lets gradients flow back into field (sample() is only used
    # under no_grad for monitoring).
    samples = process.rsample((ft_batch,))

    # The squared error of a minibatch mean is a biased estimate of the
    # squared mean error: E[|m - t|^2] = |mu - t|^2 + tr(Var[x]) / N.
    # Minimising it as-is therefore also rewards shrinking the sample spread.
    # Subtracting the unbiased variance estimate divided by N removes that
    # term, so the loss (and its reparameterised gradient) targets only the
    # mean. The de-biased value can be slightly negative near convergence.
    mean_sq_err = (samples.mean(0) - target_mean).pow(2).sum()
    physics_loss = mean_sq_err - samples.var(0).sum() / ft_batch

    ft_optimiser.zero_grad()
    physics_loss.backward()
    ft_optimiser.step()

    if (step + 1) % 50 == 0:
        with torch.no_grad():
            check = process.sample((500,))
        # Also report the spread: the constraint only pins the mean, so the
        # std shows whether the rest of the distribution is drifting.
        print(f"  Step {step+1}, physics loss: {physics_loss.item():.4f}, "
              f"sample mean: {check.mean(0).tolist()}, "
              f"sample std: {check.std(0).tolist()}")

# Final check on a fresh batch of 500 samples from the fine-tuned flow,
# printed alongside the target for comparison.
with torch.no_grad():
    post_samples = process.sample((500,))
print(
    f"\nAfter fine-tuning: sample mean = {post_samples.mean(0).tolist()}",
    f"Target mean:                     {target_mean.tolist()}",
    sep="\n",
)

P1: standard flow matching training
  Step 100, FM loss: 1.5353
  Step 200, FM loss: 0.9963
  Step 300, FM loss: 0.9572
  Step 400, FM loss: 1.0478
  Step 500, FM loss: 0.9572

Before fine-tuning: sample mean = [3.038644313812256, 3.0558784008026123]

P2: Fine-tuning (via rsample)
  Step 50, physics loss: 0.0030, sample mean: [4.858842849731445, 4.845598220825195]
  Step 100, physics loss: 0.0294, sample mean: [5.056225299835205, 5.076822757720947]
  Step 150, physics loss: 0.0496, sample mean: [4.948419570922852, 4.893041610717773]
  Step 200, physics loss: 0.0748, sample mean: [5.01827335357666, 5.036735534667969]

After fine-tuning: sample mean = [5.059838771820068, 5.076112747192383]
Target mean:                     [5.0, 5.0]
