# Stochastic fm loss

Test `nami.stochastic_fm_loss` with the class-based gamma schedules, and verify that passing `ZeroGamma()` recovers the deterministic `nami.fm_loss`.

In [1]:
import torch
from torch import nn
import nami

In [2]:
# Seed the global RNG first so the field's initial weights, and every tensor
# drawn afterwards, are reproducible from run to run.
torch.manual_seed(7)


class TinyField(nn.Module):
    """Two-layer MLP mapping a state ``x`` and a time ``t`` to a tensor shaped like ``x``.

    The time is appended to ``x`` as one extra trailing feature, which is why
    the first linear layer takes ``dim + 1`` inputs.
    """

    def __init__(self, dim: int, hidden: int = 16):
        """Build the network for ``dim`` state features and ``hidden`` hidden units."""
        super().__init__()
        # Layers are created in forward order; this fixes the order in which
        # their parameters are initialised from the RNG.
        layers = (
            nn.Linear(dim + 1, hidden),
            nn.SiLU(),
            nn.Linear(hidden, dim),
        )
        self.net = nn.Sequential(*layers)

    @property
    def event_ndim(self) -> int:
        """Return 1.

        NOTE(review): presumably the number of trailing event dimensions that
        nami expects a field to report -- confirm against nami's field protocol.
        """
        return 1

    def forward(self, x: torch.Tensor, t: torch.Tensor, c=None) -> torch.Tensor:
        """Evaluate the network on ``x`` with ``t`` appended as the last feature.

        ``t`` gains a trailing axis and is expanded to ``(*x.shape[:-1], 1)``
        before concatenation. ``c`` is accepted but ignored.
        """
        del c  # accepted for interface compatibility only; never used
        time_shape = (*x.shape[:-1], 1)
        time_column = t.unsqueeze(-1).expand(time_shape)
        features = torch.cat((x, time_column), dim=-1)
        return self.net(features)


field = TinyField(dim=4)

In [3]:
# Fixed toy inputs reused by the loss calls below. The draws consume the global
# RNG in this exact order, so reordering them changes every printed value.
batch, dim = 8, 4  # dim matches TinyField(dim=4)
x_target = torch.randn(batch, dim)  # standard-normal samples, shape (batch, dim)
x_source = torch.randn(batch, dim)  # independent standard-normal samples, same shape
t = torch.rand(batch)  # one time per sample, uniform on [0, 1)
z = torch.randn_like(x_target)  # standard-normal noise shaped like x_target; passed as `z=` to stochastic_fm_loss

In [4]:
# Both losses are evaluated on the same field, endpoints and times, so any
# difference in the printed values comes from the gamma schedule and noise `z`.
shared_inputs = (field, x_target, x_source)

# Deterministic flow-matching loss at the fixed times `t`.
loss_det = nami.fm_loss(*shared_inputs, t=t, reduction="mean")

# Stochastic variant: Brownian gamma schedule with the pre-drawn noise `z`.
brownian = nami.BrownianGamma()
loss_stoch = nami.stochastic_fm_loss(
    *shared_inputs,
    t=t,
    gamma=brownian,
    z=z,
    reduction="mean",
)

print(f"deterministic fm_loss: {loss_det.item():.6f}")
print(f"stochastic_fm_loss (Brownian): {loss_stoch.item():.6f}")

deterministic fm_loss: 2.340302
stochastic_fm_loss (Brownian): 4.018364


In [5]:
# Request unreduced losses (reduction="none") so the two results are compared
# element by element rather than through a single averaged number.
unreduced = {"t": t, "reduction": "none"}

loss_det_none = nami.fm_loss(field, x_target, x_source, **unreduced)

# With ZeroGamma the stochastic loss is expected to match the deterministic one.
loss_zero_none = nami.stochastic_fm_loss(
    field,
    x_target,
    x_source,
    gamma=nami.ZeroGamma(),
    z=z,
    **unreduced,
)

zero_gamma_matches = torch.allclose(loss_det_none, loss_zero_none, atol=1e-6)
print("ZeroGamma matches deterministic fm_loss:", zero_gamma_matches)

ZeroGamma matches deterministic fm_loss: True


In [6]:
# Smoke-test that the stochastic loss is trainable: run three Adam steps on the
# same fixed batch and print the loss at each step.
optimizer = torch.optim.Adam(field.parameters(), lr=1e-3)

for step in range(3):
    # Neither `t` nor `z` is passed here.
    # NOTE(review): presumably nami samples them internally on each call -- confirm.
    loss = nami.stochastic_fm_loss(
        field,
        x_target,
        x_source,
        gamma=nami.ScaledBrownianGamma(scale=1.5),
        reduction="mean",
    )
    # Clear gradients from the previous step before backpropagating this loss.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # The printed value is the loss computed before this step's parameter update.
    print(f"step={step} loss={loss.item():.6f}")

step=0 loss=6.202804
step=1 loss=5.009123
step=2 loss=3.986613
