# In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# ----------------------------
# 1) Pretrained diffusion backbone (score/v-pred function)
# ----------------------------
class PretrainedDiffusion(nn.Module):
    """Wrap a backbone network and expose its output as a score estimate.

    The wrapped ``unet`` is called as ``unet(x_t, t)``. ``score`` rescales that
    output with the geometric noise schedule
    ``sigma(t) = sigma_min * (sigma_max / sigma_min) ** t``.

    Args:
        unet: Callable/module invoked as ``unet(x_t, t)``.
        sigma_min: Noise level at ``t = 0``.
        sigma_max: Noise level at ``t = 1``.
    """

    def __init__(self, unet, sigma_min=0.001, sigma_max=1.0):
        super().__init__()
        self.unet = unet
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    @torch.no_grad()
    def score(self, x_t, t):
        """Return a score estimate computed as ``-v / (sigma(t)**2 + 1e-8)``.

        The mapping from the backbone output ``v`` to a score is a placeholder;
        replace it with the exact formula for your backbone's parameterization.

        Args:
            x_t: Noisy batch with the batch axis first, e.g. ``(B, C, H, W)``.
                Any rank is accepted as long as ``unet`` accepts it.
            t: Per-sample times, a tensor of shape ``(B,)``.

        Returns:
            A tensor with the same shape as ``unet(x_t, t)``. It carries no
            gradient because the method runs under ``torch.no_grad()``.

        Note:
            With the default ``sigma_min=0.001`` the factor ``1 / sigma**2``
            is about ``1e6`` at ``t = 0``, so returned values can be very
            large in magnitude for small ``t``.
        """
        v = self.unet(x_t, t)
        sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t  # (B,)
        # Give sigma one singleton axis per non-batch axis of v so it
        # broadcasts per sample for any input rank, not only (B, C, H, W).
        sigma = sigma.reshape(-1, *([1] * max(v.dim() - 1, 0)))
        # The small constant guards the division when sigma is tiny.
        return -v / (sigma ** 2 + 1e-8)

# ----------------------------
# 2) Flow matching vector field model
# ----------------------------
class FlowField(nn.Module):
    """Thin module wrapper around a network that outputs a vector field.

    Args:
        net: Any callable/module (e.g. a U-Net or Transformer) invoked as
            ``net(x_t, t)``.
    """

    def __init__(self, net):
        super().__init__()
        self.net = net

    def forward(self, x_t, t):
        """Evaluate the vector field ``v_theta(x_t, t)``.

        The result is meant to be the right-hand side of the ODE
        ``dx/dt = v_theta(x_t, t)``.

        Args:
            x_t: Current state, documented by the caller as ``(B, C, H, W)``.
            t: Per-sample times, documented by the caller as ``(B,)``.

        Returns:
            Whatever ``self.net(x_t, t)`` returns, unchanged.
        """
        velocity = self.net(x_t, t)
        return velocity

# ----------------------------
# 3) Noise schedule and sampling utilities
# ----------------------------
def sample_t(batch_size, device):
    """Draw ``batch_size`` independent times uniformly from ``[0, 1)``.

    Args:
        batch_size: Number of time values to draw.
        device: Device on which the result is allocated.

    Returns:
        A 1-D tensor of shape ``(batch_size,)``.
    """
    # Plain uniform sampling; an importance schedule could be swapped in here.
    times = torch.rand(batch_size, device=device)
    return times

def add_noise(x0, t, sigma_min=0.001, sigma_max=1.0):
    """Perturb clean samples with Gaussian noise of level ``sigma(t)``.

    The schedule is geometric:
    ``sigma(t) = sigma_min * (sigma_max / sigma_min) ** t``, so ``t = 0`` gives
    ``sigma_min`` and ``t = 1`` gives ``sigma_max``.

    Args:
        x0: Clean batch with the batch axis first, e.g. ``(B, C, H, W)``.
            Any rank is accepted.
        t: Per-sample times, a tensor of shape ``(B,)``.
        sigma_min: Noise level at ``t = 0``.
        sigma_max: Noise level at ``t = 1``.

    Returns:
        A tuple ``(x_t, sigma, noise)`` where ``x_t = x0 + sigma * noise``,
        ``sigma`` has shape ``(B, 1, ..., 1)`` with the same rank as ``x0``
        (``(B, 1, 1, 1)`` for image batches), and ``noise`` is standard
        Gaussian with the shape of ``x0``.
    """
    sigma_scalar = sigma_min * (sigma_max / sigma_min) ** t  # (B,)
    # One singleton axis per non-batch axis of x0, so sigma broadcasts per
    # sample for any input rank instead of assuming exactly four dimensions.
    sigma = sigma_scalar.reshape(-1, *([1] * max(x0.dim() - 1, 0)))
    noise = torch.randn_like(x0)
    x_t = x0 + sigma * noise
    return x_t, sigma, noise

# ----------------------------
# 4) Loss: align flow with diffusion score
# ----------------------------
def flow_matching_loss(flow_v, diffusion_score):
    """Mean squared difference between the flow field and the score target.

    Args:
        flow_v: Predicted vector field.
        diffusion_score: Target score; must broadcast against ``flow_v``.

    Returns:
        A scalar tensor: the mean over all elements of
        ``(flow_v - diffusion_score) ** 2``.
    """
    residual = flow_v - diffusion_score
    return residual.pow(2).mean()

# ----------------------------
# 5) Training loop
# ----------------------------
def train_flow_matching(
    diffusion_backbone, flow_model, dataloader, device="cuda",
    lr=1e-4, epochs=1, sigma_min=0.001, sigma_max=1.0
):
    """Train ``flow_model`` to match the backbone's score on noised data.

    For every batch: sample times, noise the data, query the backbone's
    ``score`` without gradients, and minimise the squared error between
    ``flow_model(x_t, t)`` and that score with AdamW.

    Args:
        diffusion_backbone: Module exposing ``score(x_t, t)``; switched to
            eval mode and never optimised here.
        flow_model: Module called as ``flow_model(x_t, t)``; switched to train
            mode and optimised.
        dataloader: Iterable of batches. Tuple/list batches are unwrapped to
            their first element.
        device: Device the batches are moved to. The models are NOT moved;
            the caller must place them on the same device.
        lr: AdamW learning rate.
        epochs: Number of passes over ``dataloader``.
        sigma_min: Lower noise level forwarded to ``add_noise``.
        sigma_max: Upper noise level forwarded to ``add_noise``.

    Returns:
        None. After each epoch the loss of the LAST batch is printed
        (``nan`` if the epoch contained no batches).

    Note:
        ``sigma_min``/``sigma_max`` are only forwarded to ``add_noise``; the
        backbone's ``score`` is called with ``(x_t, t)`` alone, so any noise
        schedule it applies internally is configured on the backbone itself.
    """
    opt = optim.AdamW(flow_model.parameters(), lr=lr)
    flow_model.train()
    diffusion_backbone.eval()

    for epoch in range(epochs):
        # Kept as a tensor so no host sync is forced on every step. It stays
        # None when the dataloader yields nothing, which previously made the
        # epoch print fail on an unbound `loss`.
        last_loss = None

        for x0 in dataloader:
            # If dataloader yields tuples, unwrap
            if isinstance(x0, (tuple, list)):
                x0 = x0[0]
            x0 = x0.to(device)

            # Sample time and construct x_t
            t = sample_t(x0.size(0), device)                     # (B,)
            x_t, sigma, noise = add_noise(x0, t, sigma_min, sigma_max)

            # Target from the backbone; no gradients flow into it.
            with torch.no_grad():
                score_t = diffusion_backbone.score(x_t, t)

            # Flow field prediction
            flow_v = flow_model(x_t, t)

            # Alignment loss
            loss = flow_matching_loss(flow_v, score_t)

            opt.zero_grad()
            loss.backward()
            opt.step()

            last_loss = loss.detach()

        epoch_loss = float("nan") if last_loss is None else last_loss.item()
        print(f"Epoch {epoch+1}: loss={epoch_loss:.4f}")

# ----------------------------
# 6) ODE sampling (Euler as simplest)
# ----------------------------
@torch.no_grad()
def sample_with_flow(flow_model, batch_size, shape, steps=32, device="cuda"):
    """Integrate ``dx/dt = v_theta(x, t)`` from ``t = 1`` to ``t = 0`` with Euler.

    Starts from standard Gaussian noise and takes ``steps`` equal-sized
    explicit Euler steps on a uniform time grid. Given the initial noise the
    procedure is deterministic. Higher-order solvers (Heun/RK) would be more
    accurate for the same number of steps.

    Args:
        flow_model: Callable invoked as ``flow_model(x, t)`` with ``t`` of
            shape ``(batch_size,)``.
        batch_size: Number of samples to generate.
        shape: Per-sample shape, e.g. ``(C, H, W)``.
        steps: Number of Euler steps.
        device: Device for the state and the time grid.

    Returns:
        Tensor of shape ``(batch_size, *shape)`` holding the final state.
    """
    x = torch.randn((batch_size, *shape), device=device)
    time_grid = torch.linspace(1.0, 0.0, steps + 1, device=device)

    # Walk consecutive grid points; the step size is negative because time
    # runs backwards from 1 to 0.
    for t_now, t_next in zip(time_grid[:-1], time_grid[1:]):
        velocity = flow_model(x, t_now.expand(batch_size))
        x = x + velocity * (t_next - t_now)

    return x

# ----------------------------
# 7) Wiring things together
# ----------------------------
if __name__ == "__main__":
    # Prefer the GPU when one is visible to PyTorch.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    class DummyUNet(nn.Module):
        """Two-layer conv net standing in for a real U-Net/Transformer."""

        def __init__(self, channels=3, hidden=64):
            super().__init__()
            layers = [
                nn.Conv2d(channels, hidden, 3, padding=1),
                nn.GELU(),
                nn.Conv2d(hidden, channels, 3, padding=1),
            ]
            self.net = nn.Sequential(*layers)

        def forward(self, x, t):
            # t is accepted to match the (x, t) calling convention but is not
            # used; a real model would condition on it (FiLM/embeddings).
            return self.net(x)

    # Backbone whose score is the training target. It is randomly initialised
    # here, so this only exercises the plumbing.
    pretrained_unet = DummyUNet().to(device)
    diffusion_backbone = PretrainedDiffusion(pretrained_unet).to(device)

    # Model being trained.
    flow_net = DummyUNet().to(device)
    flow_model = FlowField(flow_net).to(device)

    # Synthetic stand-in data; replace with real images.
    images = torch.randn(256, 3, 64, 64)
    loader = DataLoader(images, batch_size=32, shuffle=True)

    # Train
    train_flow_matching(
        diffusion_backbone,
        flow_model,
        loader,
        device=device,
        epochs=2,
    )

    # Sample
    samples = sample_with_flow(
        flow_model,
        batch_size=8,
        shape=(3, 64, 64),
        steps=64,
        device=device,
    )
    print("Generated samples:", samples.shape)


# Output:
# Epoch 1: loss=120564872.0000
# Epoch 2: loss=1202458496.0000
# Generated samples: torch.Size([8, 3, 64, 64])
