In [None]:
!pip -q install torchdyn
!pip -q install torchcfm

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.1/58.1 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m81.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m79.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m48.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m39.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import os
import copy
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image
from tqdm import trange
import wandb

from torchdyn.core import NeuralODE
from torchcfm.models.unet.unet import UNetModelWrapper

In [None]:
def warmup_lr(step, WARMUP=5000):
    """Linear warmup multiplier for the learning rate.

    Ramps linearly from 0 at ``step == 0`` to 1 at ``step == WARMUP`` and
    stays at 1 afterwards. It is used below as the ``lr_lambda`` of a
    ``LambdaLR`` scheduler, so the returned value scales the base LR.

    Args:
        step: Current scheduler step.
        WARMUP: Number of warmup steps. A value <= 0 disables warmup.

    Returns:
        The LR multiplier as a float.
    """
    if WARMUP <= 0:
        # No warmup requested: use the full LR immediately instead of
        # dividing by zero.
        return 1.0
    return min(step, WARMUP) / WARMUP

def infiniteloop(dataloader):
    """Yield inputs from ``dataloader`` forever, restarting it after each pass.

    Every item produced by ``dataloader`` must unpack into exactly two
    elements; only the first one is yielded and the second is discarded.

    Args:
        dataloader: A re-iterable object producing 2-element batches.

    Yields:
        The first element of every batch, epoch after epoch.

    Raises:
        ValueError: If a full pass over ``dataloader`` produces no batch.
            Without this check the generator would spin forever without
            yielding anything.
    """
    while True:
        produced_any = False
        for x, _ in dataloader:
            produced_any = True
            yield x
        if not produced_any:
            raise ValueError("infiniteloop: dataloader produced no batches")

@torch.no_grad()
def generate_samples_euler(model,
                           savedir="./results_mnist/",
                           step_=0,
                           total_steps=100,
                           net_="normal",
                           plot=False,
                           num_samples=32,
                           image_shape=(1, 28, 28)):
    """
    Generate samples with a fixed-step Euler integration from t=0 to t=1 and
    save them as a single image grid.

    Algorithm:
      x ∼ N(0, I)
      d ← 1/M
      t ← 0
      for n in [0..M−1]:
         x ← x + sθ(x, t, d)*d
         t ← t + d
      return x

    :param model: callable as ``model(t, x, d=...)``; ``t`` and ``d`` are
                  passed as 1-D tensors of length ``num_samples``.
    :param savedir: directory the PNG grid is written to (created if missing).
    :param step_: training step, used only in the file name and plot title.
    :param total_steps: number of Euler steps M (must be >= 1).
    :param net_: tag used in the file name and plot title.
    :param plot: if True, also display the grid with matplotlib.
    :param num_samples: number of images to generate.
    :param image_shape: per-sample shape (C, H, W) of the initial noise.
    :raises ValueError: if ``total_steps`` is smaller than 1.
    """
    if total_steps < 1:
        raise ValueError("total_steps must be >= 1")

    # Sample on the device the model lives on rather than relying on a
    # module-level `device` that may not be defined yet.
    first_param = next(model.parameters(), None)
    sample_device = first_param.device if first_param is not None else torch.device("cpu")

    # Remember the current mode so it can be restored afterwards
    # (e.g. an EMA model kept in eval mode stays in eval mode).
    was_training = model.training
    model.eval()
    try:
        # 1) Initial Gaussian noise
        x = torch.randn(num_samples, *image_shape, device=sample_device)

        # 2) Step size d = 1 / total_steps, as a per-sample tensor because the
        #    model's forward calls tensor methods on both t and d.
        dt = 1.0 / total_steps
        d = torch.full((num_samples,), dt, device=sample_device)

        # 3) Euler loop; t is recomputed as n*dt to avoid accumulating
        #    floating-point drift from repeated additions.
        for n in range(total_steps):
            t = torch.full((num_samples,), n * dt, device=sample_device)
            s = model(t, x, d=d)
            x = x + s * dt
    finally:
        model.train(was_training)

    # 4) Post-process: clamp & shift from [-1..1] to [0..1]
    x_gen = x.clamp(-1, 1) / 2 + 0.5

    # 5) Save out
    os.makedirs(savedir, exist_ok=True)
    img_path = os.path.join(
        savedir, f"{net_}_generated_EULER_step_{step_}_ts_{total_steps}.png"
    )
    save_image(x_gen, img_path, nrow=8)

    if plot:
        # matplotlib is only needed for interactive display, so import lazily.
        import matplotlib.pyplot as plt
        grid = make_grid(x_gen, nrow=8)
        plt.figure(figsize=(8, 8))
        plt.imshow(grid.permute(1, 2, 0).cpu().numpy(), cmap="gray")
        plt.title(f"MNIST Euler Gen: {net_} | step={step_}")
        plt.axis("off")
        plt.show()

def ema(source, target, decay=0.9999):
    """Exponential-moving-average update of ``target`` from ``source``, in place.

    Every floating-point (or complex) entry of ``target.state_dict()`` becomes
    ``decay * target + (1 - decay) * source``. Entries of any other dtype
    (e.g. integer counters) are copied from ``source`` unchanged, because
    blending them in floating point and copying back would truncate the result.

    Args:
        source: Module providing the fresh values.
        target: Module updated in place; it must have every key present in
            ``source.state_dict()``.
        decay: Weight kept from the old ``target`` value.
    """
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    with torch.no_grad():
        for key, src in source_dict.items():
            tgt = target_dict[key]
            if tgt.is_floating_point() or tgt.is_complex():
                tgt.copy_(tgt * decay + src * (1 - decay))
            else:
                # Non-float state cannot be averaged meaningfully.
                tgt.copy_(src)


In [None]:
############################
# 2) MNIST Data Loader
############################
def get_mnist_dataloader(batch_size=64):
    """Return a shuffled DataLoader over the MNIST training split.

    Images keep their native 28×28 size (no resize). Each image is converted
    to a tensor and normalised with mean 0.5 and std 0.5 on its single
    channel. The dataset is downloaded to ``./data_mnist`` if it is missing.
    Incomplete final batches are dropped.

    Args:
        batch_size: Number of images per batch.

    Returns:
        A ``DataLoader`` over ``(image, label)`` pairs.
    """
    preprocessing = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )
    train_set = datasets.MNIST(
        root="./data_mnist",
        train=True,
        transform=preprocessing,
        download=True,
    )
    return DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        drop_last=True,
        pin_memory=True,
    )

In [None]:
###########################
# 3) Define U-Net for 1×28×28
############################
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
# The model built below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import torch.nn as nn
from torchcfm.models.unet.unet import UNetModelWrapper
from torchcfm.models.unet.nn import timestep_embedding


class UNetModelWrapperWithD(UNetModelWrapper):
    """
    A subclass of `UNetModelWrapper` whose forward pass can take an extra
    conditioning value 'd' and add its embedding to the time embedding.
    When d=None the extra embedding is skipped (intended to match the parent
    model's behaviour).
    """

    @staticmethod
    def _as_batch_vector(value, x):
        """
        Coerce a conditioning value into a 1-D tensor aligned with the batch of `x`.

        Accepts Python scalars as well as tensors. The result lives on
        `x.device`. Tensors with more than one dimension are reduced by taking
        index 0 of every trailing dimension; 0-dim values are repeated once
        per batch element.
        """
        vec = torch.as_tensor(value, device=x.device)
        while vec.dim() > 1:
            vec = vec[:, 0]
        if vec.dim() == 0:
            vec = vec.repeat(x.shape[0])
        return vec

    def forward(self, t, x, d=None, y=None, *args, **kwargs):
        """
        :param t: timesteps; a Python scalar, a 0-dim tensor, or a tensor of shape [B] / [B, 1].
        :param x: the input image tensor, shape [B, C, H, W].
        :param d: optional extra value(s) in the same formats as `t`. If None, it is skipped.
        :param y: optional class labels (required if the model is class-conditional).
        :return: the result of `self.out`, computed from features cast back to `x.dtype`.
        """
        # 1) Timestep embedding
        timesteps = self._as_batch_vector(t, x)
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        # 2) If class-conditional, add label embedding
        if self.num_classes is not None:
            assert (y is not None), "You must pass 'y' if the model is class-conditional"
            emb = emb + self.label_emb(y)

        # 3) If d is provided, embed it and add it to emb.
        # NOTE(review): d goes through the same `time_embed` MLP as t and the two
        # embeddings are summed, so swapping the values of t and d yields the
        # same conditioning vector. Confirm this is intended; a dedicated
        # embedding for d would remove the ambiguity but changes the checkpoint
        # layout.
        if d is not None:
            d_vec = self._as_batch_vector(d, x)
            emb = emb + self.time_embed(timestep_embedding(d_vec, self.model_channels))

        # 4) U-Net body: encoder with skip connections, bottleneck, decoder
        h = x.type(self.dtype)
        hs = []
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)

        h = self.middle_block(h, emb)

        for module in self.output_blocks:
            # Concatenate the matching encoder activation along the channel axis.
            h = torch.cat([h, hs.pop()], dim=1)
            h = module(h, emb)

        h = h.type(x.dtype)
        return self.out(h)



# Build the d-conditioned U-Net for single-channel 28×28 inputs.
net_model = UNetModelWrapperWithD(
    dim=(1, 28, 28),
    num_res_blocks=2,
    num_channels=128,
    channel_mult=[1, 2],  # two resolution levels (28 and 14, per the original note)
    num_heads=4,
    num_head_channels=64,
    attention_resolutions="14",
    dropout=0.05,
)
net_model = net_model.to(device)

# Independent copy of the freshly initialised network, used as the EMA model.
ema_model = copy.deepcopy(net_model)

# Report the parameter count (element count divided by 1024 twice).
model_size = sum(param.numel() for param in ema_model.parameters())
print("Model params: %.2f M" % (model_size / 1024 / 1024))


Model params: 15.19 M


In [None]:
############################
# 4) Flow Matching Trainer
############################
import torch.nn as nn

def pick_discrete_steps(batch_size, device):
    """
    Draw one step size per batch element, uniformly at random, from
    {1/2, 1/4, 1/8, 1/16, 1/32, 1/64, 1/128}.

    :param batch_size: number of step sizes to draw.
    :param device: device on which the candidates and the result are created.
    :return: float tensor of shape [batch_size].
    """
    # Candidate step sizes are the reciprocals of 2, 4, ..., 128.
    denominators = torch.tensor(
        [2, 4, 8, 16, 32, 64, 128], device=device, dtype=torch.float
    )
    step_choices = denominators.reciprocal()

    # One uniformly distributed index into the candidate table per element.
    choice_idx = torch.randint(
        0, step_choices.numel(), (batch_size,), device=device
    )
    return step_choices.index_select(0, choice_idx)


def _shortcut_losses(net_model, x_real, fraction_k, loss_fn):
    """Compute the flow-matching and self-consistency losses for one batch.

    The first ``int(B * fraction_k)`` items are trained with plain flow
    matching (d = 0); the remaining items are trained to make one step of
    size 2d agree with the average of two consecutive steps of size d.

    Returns:
        ``(fm_loss, consistency_loss)``; each is a scalar tensor, or the float
        ``0.0`` when its part of the batch is empty.
    """
    batch = x_real.size(0)
    dev = x_real.device

    # x_0 ~ N(0, I), t ~ U[0, 1], x_t = (1 - t) * x_0 + t * x_1
    x_noise = torch.randn_like(x_real)
    t_sample = torch.rand(batch, device=dev)
    t_4d = t_sample.view(batch, 1, 1, 1)
    x_t = (1 - t_4d) * x_noise + t_4d * x_real

    # Discrete step sizes {1/2, ..., 1/128}; the first k items use d = 0.
    d_vals = pick_discrete_steps(batch, device=dev)
    k = int(batch * fraction_k)
    d_vals[:k] = 0.0

    # Flow-matching loss on the d = 0 part: regress the velocity x_1 - x_0.
    fm_loss = 0.0
    if k > 0:
        pred_vel_fm = net_model(t_sample[:k], x_t[:k], d=d_vals[:k])
        fm_loss = loss_fn(pred_vel_fm, x_real[:k] - x_noise[:k])

    # Self-consistency loss on the d > 0 part.
    consistency_loss = 0.0
    if k < batch:
        t_cons, x_t_cons, d_cons = t_sample[k:], x_t[k:], d_vals[k:]

        # The bootstrap target is a regression target and must not receive
        # gradients (stop-gradient); otherwise the loss can be reduced by
        # moving the target rather than the prediction.
        with torch.no_grad():
            s_t = net_model(t_cons, x_t_cons, d=d_cons)
            x_next = x_t_cons + s_t * d_cons.view(-1, 1, 1, 1)
            t_next = torch.clamp(t_cons + d_cons, max=1.0)  # keep time <= 1
            s_t_next = net_model(t_next, x_next, d=d_cons)
            s_target = 0.5 * (s_t + s_t_next)

        pred_s = net_model(t_cons, x_t_cons, d=2.0 * d_cons)
        consistency_loss = loss_fn(pred_s, s_target)

    return fm_loss, consistency_loss


def train_flowmatching_shortcut(
    net_model,
    ema_model,
    TOTAL_STEPS=50000,
    BATCH_SIZE=64,
    LR=2e-4,
    GRAD_CLIP=1.0,
    SAVE_STEP=1000,
    WARMUP=5000,
    # fraction of each batch used for standard flow matching (d=0)
    FRACTION_K=0.75
):
    """Train ``net_model`` on MNIST with a flow-matching + self-consistency loss.

    Each step draws a batch, computes both losses, applies a clipped AdamW
    update with linear LR warmup, updates ``ema_model`` via ``ema`` and logs
    the losses to Weights & Biases. Every ``SAVE_STEP`` steps (except step 0)
    4-step Euler samples of both models and a checkpoint are written under
    ``./checkpoints_mnist/``.

    Args:
        net_model: Network being optimised; called as ``net_model(t, x, d=...)``.
        ema_model: EMA copy of ``net_model``, updated in place every step.
        TOTAL_STEPS: Number of optimisation steps.
        BATCH_SIZE: Batch size of the MNIST loader.
        LR: Base learning rate.
        GRAD_CLIP: Maximum gradient norm.
        SAVE_STEP: Sampling/checkpoint period; a value <= 0 disables it.
        WARMUP: Number of LR warmup steps.
        FRACTION_K: Fraction of each batch trained with plain flow matching.

    Raises:
        ValueError: If ``FRACTION_K`` is outside ``[0, 1]``.
    """
    if not 0.0 <= FRACTION_K <= 1.0:
        raise ValueError("FRACTION_K must be in [0, 1]")

    loader = get_mnist_dataloader(batch_size=BATCH_SIZE)
    data_iter = infiniteloop(loader)

    # Train on whatever device the model already lives on.
    train_device = next(net_model.parameters()).device
    net_model.train()

    # Basic setup: optimizer & scheduler
    optimizer = torch.optim.AdamW(net_model.parameters(), lr=LR, weight_decay=0.1)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda s: warmup_lr(s, WARMUP=WARMUP)
    )

    # MSE for velocity fields
    loss_fn = nn.MSELoss()

    # Close previous W&B run if present
    if wandb.run is not None:
        wandb.finish()
    wandb.init(project="FM-Shortcut", config={
        "batch_size": BATCH_SIZE,
        "lr": LR,
        "frac_k": FRACTION_K,
        "total_steps": TOTAL_STEPS,
    })

    try:
        pbar = trange(TOTAL_STEPS, desc="ShortcutTrainer", dynamic_ncols=True)
        for step in pbar:
            optimizer.zero_grad()

            x_real = next(data_iter).to(train_device)   # [B,1,28,28] for MNIST
            fm_loss, consistency_loss = _shortcut_losses(
                net_model, x_real, FRACTION_K, loss_fn
            )

            total_loss = fm_loss + consistency_loss
            total_loss.backward()

            nn.utils.clip_grad_norm_(net_model.parameters(), GRAD_CLIP)
            optimizer.step()
            scheduler.step()

            # EMA update
            ema(net_model, ema_model)

            # Logging (convert once so each loss is read back a single time)
            fm_val = float(fm_loss)
            cons_val = float(consistency_loss)
            total_val = float(total_loss)
            wandb.log({
                "step": step,
                "fm_loss": fm_val,
                "consistency_loss": cons_val,
                "total_loss": total_val,
            })
            pbar.set_postfix({
                "fm": f"{fm_val:.4f}",
                "cons": f"{cons_val:.4f}",
                "loss": f"{total_val:.4f}"
            })

            # Periodic save & sample
            if SAVE_STEP > 0 and (step % SAVE_STEP == 0) and (step > 0):
                os.makedirs("./checkpoints_mnist/", exist_ok=True)

                # `generate_samples_euler` is the sampler defined in this
                # notebook; the previously referenced `generate_samples` is
                # not defined anywhere in it.
                generate_samples_euler(
                    net_model,
                    step_=step,
                    savedir="./checkpoints_mnist/img/normal/",
                    total_steps=4,
                    net_="normal"
                )
                generate_samples_euler(
                    ema_model,
                    step_=step,
                    savedir="./checkpoints_mnist/img/ema/",
                    total_steps=4,
                    net_="ema"
                )

                ckpt_path = f"./checkpoints_mnist/fm_mnist_step_{step}.pth"
                torch.save({
                    "model": net_model.state_dict(),
                    "ema_model": ema_model.state_dict(),
                    "sched": scheduler.state_dict(),
                    "optim": optimizer.state_dict(),
                    "step": step,
                }, ckpt_path)
                wandb.save(ckpt_path)

        print("Training completed!")
    finally:
        # Always close the W&B run, even if training is interrupted.
        wandb.finish()


In [None]:
############################
# 5) Run Training
############################
if __name__ == "__main__":
    # Launch the training run; FRACTION_K keeps its default value.
    train_flowmatching_shortcut(
        net_model=net_model,
        ema_model=ema_model,
        # optimisation
        BATCH_SIZE=64,
        LR=1e-4,
        WARMUP=3000,
        GRAD_CLIP=1.0,
        # schedule: 60001 steps so that step 60000 is sampled and checkpointed
        TOTAL_STEPS=60001,
        SAVE_STEP=1000,
    )

ShortcutTrainer: 100%|██████████| 60001/60001 [2:57:46<00:00,  5.63it/s, fm=0.1476, cons=0.0008, loss=0.1484]

Training completed!





0,1
consistency_loss,▁▃▂▂▂▃▁▂▂▁█▃▂▁▃▂▃▂▂▂▂▂▂▁▁▃▁▂▄▂▂▂▁▃▂▁▂▂▂▂
fm_loss,▃▅▁▅▆▇▄▄▆▅▇▅▇▇▅▄▅▆▆▃▆▇▅▇▅▄▅▄▅█▃▆▅▅▆▅▄▃▁█
step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
total_loss,▃▃▇▃▃▆▄▅▅▆▃▃▁█▆▆▂▃▅▃▃▂▅▄▃▅▄▅▅▇▁▂▄▄▆▂▄▆▇▂

0,1
consistency_loss,0.00081
fm_loss,0.14763
step,60000.0
total_loss,0.14844
