## TODO

Try training a rectified flow on MNIST with MLP, UNet, and DiT velocity models.

In [None]:
import math
import numpy as np
import torch
import torch.utils.checkpoint
import torchvision

from torchvision import transforms
from tqdm.auto import tqdm

from rectified_flow.rectified_flow import RectifiedFlow
from rectified_flow.utils import match_dim_with_data

# Run on the GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Permit TF32 math in cuDNN convolutions and CUDA matmuls (these flags only
# affect CUDA kernels).
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

## Load dataset, model, optimizer

In [None]:
batch_size = 512

# ToTensor maps pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to [-1, 1].
transform_list = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
train_dataset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transforms.Compose(transform_list)
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    # Pinned host memory only speeds up (and only avoids a warning for)
    # host-to-GPU copies; skip it on CPU-only runs.
    pin_memory=device.type == "cuda",
    persistent_workers=False,  # DataLoader requires False when num_workers == 0
)

batch = next(iter(train_dataloader))
print(batch[0].shape)  # torch.Size([512, 1, 28, 28]) with batch_size = 512

In [None]:
model_type = "unet"    # "mlp" or "unet" or "dit"
from rectified_flow.models.enhanced_mlp import VarMLP
from rectified_flow.models.utils import EMAModel

# Build the velocity network together with the per-sample data shape it consumes.
if model_type == "mlp":
    from rectified_flow.models.enhanced_mlp import MLPVelocity

    # The MLP works on flattened 28*28 images.
    flow_model = MLPVelocity(dim=28 * 28)
    data_shape = (28 * 28,)
elif model_type == "unet":
    from rectified_flow.models.unet import SongUNet, SongUNetConfig

    config = SongUNetConfig(
        img_resolution=28,
        in_channels=1,                  # Number of color channels at input.
        out_channels=1,                 # Number of color channels at output.
        label_dim=0,                    # Number of class labels, 0 = unconditional.
        augment_dim=0,                  # Augmentation label dimensionality, 0 = no augmentation.
        model_channels=64,              # Base multiplier for the number of channels.
        channel_mult=[2, 2],            # Channel multipliers for each resolution.
        channel_mult_emb=2,             # Multiplier for the dimensionality of the embedding vector.
        num_blocks=3,                   # Number of residual blocks per resolution.
        # NOTE(review): with img_resolution=28 and two channel_mult levels the
        # feature maps are presumably 28x28 and 14x14, so a 16 entry may never
        # enable attention -- confirm against SongUNet.
        attn_resolutions=[16],          # Resolutions at which to apply attention.
        dropout=0.13,                   # Dropout probability of intermediate activations.
        label_dropout=0.0,              # Dropout probability of class labels for classifier-free guidance.
        embedding_type="positional",    # Timestep embedding type: 'positional' or 'fourier'.
        channel_mult_time=1,            # Timestep embedding size: 1 for DDPM++, 2 for NCSN++.
        encoder_type="standard",        # Encoder architecture: 'standard' or 'residual'.
        decoder_type="standard",        # Decoder architecture: 'standard' or 'residual'.
        resample_filter=[1, 1],
    )
    flow_model = SongUNet(config)
    data_shape = (1, 28, 28)
elif model_type == "dit":
    raise NotImplementedError
else:
    raise ValueError(f"Unknown model type: {model_type}")

# Learned log-variance head, used to weight the per-sample training loss.
logvar = VarMLP().to(device)
flow_model = flow_model.to(device)

n_trainable = sum(p.numel() for p in flow_model.parameters() if p.requires_grad)
print(f"Number of parameters in flow model: {n_trainable:,}")

ema_flow = EMAModel(flow_model, ema_halflife_kimg=1.0, ema_rampup_ratio=0.05)

# Flow model and log-variance head are optimized jointly by one AdamW instance.
optimizer = torch.optim.AdamW(
    [*flow_model.parameters(), *logvar.parameters()],
    lr=1e-4,
    betas=(0.9, 0.95),
    weight_decay=0.0,
)

compiled_flow = torch.compile(flow_model, mode="reduce-overhead", fullgraph=False, dynamic=False)
compiled_logvar = torch.compile(logvar, mode="reduce-overhead", fullgraph=False, dynamic=False)

compiled_flow.train()
compiled_logvar.train()
optimizer.zero_grad()

## Train Unconditional Generation

In [None]:
epoch = 2000    # number of passes over train_dataloader
cur_nimg = 0    # running count of training images seen; fed to ema_flow.update

# Training-time flow wrapper around the compiled velocity network, with
# train_time_distribution="uniform" governing how training times t are drawn.
rf_train = RectifiedFlow(
    device=device,
    data_shape=data_shape,
    train_time_distribution="uniform",
    velocity_field=compiled_flow,
)

# Re-import tqdm, falling back to the plain console version if tqdm.auto fails.
try:
    from tqdm.auto import tqdm
except Exception:
    from tqdm import tqdm

def safe_tqdm_write(msg: str):
    """Write ``msg`` through ``tqdm.write`` so active progress bars are not
    broken; fall back to ``print`` if that call raises for any reason."""
    try:
        tqdm.write(msg)
        return
    except Exception:
        pass
    print(msg)

zero_to_none = True      # passed to optimizer.zero_grad(set_to_none=...)
grad_clip_norm = None    # max gradient norm; None disables clipping

global_step = 0
running_loss = None      # exponential moving average of the training loss
smooth_alpha = 0.99      # decay factor for running_loss

# Smoke-test one forward pass and the uncertainty-weighted loss before training.
with torch.no_grad():
    batch = next(iter(train_dataloader))
    x_1, c = batch
    x_1 = x_1.to(device, non_blocking=True).reshape(x_1.shape[0], *data_shape)
    x_0 = torch.randn_like(x_1)
    t = rf_train.sample_train_time(x_1.shape[0]).to(device, non_blocking=True)
    v_pred = flow_model(x_1, t)
    print(x_1.shape, x_0.shape, v_pred.shape)
    # logvar presumably returns [B] or [B, 1]; flatten to [B] so it lines up
    # one-to-one with the per-sample error below.
    log_var = logvar(t).float().reshape(-1)
    denom = torch.exp(log_var)
    print(log_var.shape, denom.shape)
    # Per-sample squared error averaged over all non-batch dims -> [B]. This works
    # for both the flat (784,) MLP layout and the (1, 28, 28) UNet layout. Keeping
    # both terms at [B] avoids broadcasting into a [B, B, ...] cross-batch tensor.
    sq_err = (v_pred.float() - (x_1 - x_0)).pow(2).flatten(1).mean(dim=1)
    loss = (sq_err / denom + log_var).mean()
    safe_tqdm_write(f"Initial test loss: {loss.item():.4f}")

In [None]:
# Exactly the tensors the optimizer updates (flow model + log-variance head);
# gradient clipping must act on these.
trainable_params = [p for group in optimizer.param_groups for p in group["params"]]
# bf16 autocast only on CUDA; CPU runs stay in fp32.
use_bf16_autocast = device.type == "cuda"

for ep in tqdm(range(epoch), desc="Epochs", position=0):
    pbar = tqdm(train_dataloader, desc=f"ep {ep+1}/{epoch}", leave=False, position=1)
    for step, batch in enumerate(pbar):
        optimizer.zero_grad(set_to_none=zero_to_none)

        x_1, c = batch
        x_1 = x_1.to(device, non_blocking=True).reshape(x_1.shape[0], *data_shape)
        x_0 = torch.randn_like(x_1)
        t = rf_train.sample_train_time(x_1.shape[0]).to(device, non_blocking=True)

        with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=use_bf16_autocast):
            x_t, dot_x_t = rf_train.get_interpolation(x_0=x_0, x_1=x_1, t=t)
            v_pred = rf_train.get_velocity(x_t, t)
            log_var = compiled_logvar(t)  # presumably [B] or [B, 1]

        # Do the loss arithmetic in fp32 regardless of the autocast dtype.
        v_pred = v_pred.float()
        dot_x_t = dot_x_t.float()
        log_var = log_var.float().reshape(-1)  # -> [B]

        # Plain MSE, for monitoring only (no gradient).
        mse_loss = torch.nn.functional.mse_loss(v_pred.detach(), dot_x_t.detach())

        # Per-sample squared error averaged over all non-batch dims -> [B], so it
        # pairs one-to-one with log_var. Reducing only dim=1 and expanding log_var to
        # [B, 1, 1, 1] would broadcast into a [B, B, ...] tensor that mixes samples.
        sq_err = (v_pred - dot_x_t).pow(2).flatten(1).mean(dim=1)
        loss = (sq_err / torch.exp(log_var) + log_var).mean()

        loss.backward()

        if grad_clip_norm is not None:
            torch.nn.utils.clip_grad_norm_(trainable_params, grad_clip_norm)

        optimizer.step()

        global_step += 1
        cur_nimg += x_1.shape[0]
        ema_flow.update(cur_nimg=cur_nimg, batch_size=x_1.shape[0])

        mse_val = mse_loss.item()
        loss_val = loss.detach().item()
        running_loss = loss_val if running_loss is None else smooth_alpha * running_loss + (1 - smooth_alpha) * loss_val

        pbar.set_postfix({
            "mse_loss": f"{mse_val:.4f}",
            "logvar_loss": f"{loss_val:.4f}",
            "logvar_loss_ema": f"{running_loss:.4f}",
            "steps": global_step
        })

    safe_tqdm_write(f"[epoch {ep+1}/{epoch}] steps={global_step:,} mse_loss={mse_val:.4f} logvar_loss={loss_val:.4f} logvar_loss_ema={running_loss:.4f}")

    # Checkpoint every 50 epochs and after the final epoch. Raw and EMA weights go
    # to separate directories so one save cannot overwrite the other's files.
    if ep % 50 == 0 or ep == epoch - 1:
        flow_model.save_pretrained("./checkpoints/flow_mnist")
        ema_flow.save_pretrained("./checkpoints/flow_mnist_ema")

print(f"Training done. Total steps: {global_step:,}, last loss: {loss_val:.4f}, EMA: {running_loss:.4f}")

In [None]:
from rectified_flow.samplers import EulerSampler, SDESampler

# Sample from the raw (non-EMA) flow_model. eval() switches it to inference
# mode (e.g. disables dropout) in place.
flow_model.eval()
model_inference = flow_model

rf_inference = RectifiedFlow(data_shape=data_shape, velocity_field=model_inference, device=device)

# Deterministic Euler integrator and a stochastic SDE sampler over the same flow.
euler_sampler = EulerSampler(rf_inference, num_steps=100)
sde_sampler = SDESampler(
    rf_inference,
    num_steps=200,
    noise_scale=5,
    noise_decay_rate=1.0,
)

In [None]:
# Draw 20 samples from the same initial noise with both samplers.
# Allocate the noise directly on the target device instead of copying it over.
x_0 = torch.randn(20, *data_shape, device=device)

# Sampling needs no autograd graph.
with torch.no_grad():
    x_1_euler = euler_sampler.sample_loop(x_0=x_0).trajectories[-1]
    x_1_sde = sde_sampler.sample_loop(x_0=x_0).trajectories[-1]

# Report the generated samples' shapes (x_1 is still the last training batch).
print(x_1_euler.shape, x_1_sde.shape)  # torch.Size([20, 1, 28, 28]) each for the UNet

In [None]:
from rectified_flow.utils import plot_cifar_results

# NOTE(review): plot_cifar_results is a CIFAR helper; confirm it renders
# single-channel 28x28 MNIST samples (and the flat (784,) MLP layout).
for samples in (x_1_euler, x_1_sde):
    plot_cifar_results(samples)