## TODO

Try rectified flow on MNIST with different velocity models: MLP, UNet, and DiT.

In [None]:
import math
import numpy as np
import torch
import torch.utils.checkpoint
import torchvision

from torchvision import transforms
from tqdm.auto import tqdm

from rectified_flow.rectified_flow import RectifiedFlow

# Run on the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Permit TensorFloat-32 kernels for cuDNN ops and CUDA matmuls.
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

## Load dataset, model, optimizer

In [6]:
batch_size = 256

# ToTensor scales pixels to [0, 1]; Normalize with mean = std = 0.5 then maps them
# to [-1, 1].
# No horizontal flip here: a mirrored digit is not a valid MNIST sample, and a
# generative model trained on flipped data would learn to emit mirror-image digits.
transform_list = [
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
]
train_dataset = torchvision.datasets.MNIST(
    root="./data", train=True, download=True, transform=transforms.Compose(transform_list)
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=4
)

# Sanity check: pull a single batch and confirm the image tensor layout.
batch = next(iter(train_dataloader))
print(batch[0].shape)  # torch.Size([256, 1, 28, 28])

torch.Size([256, 1, 28, 28])


In [5]:
from rectified_flow.models.enhanced_mlp import VarMLP
from rectified_flow.models.utils import EMAModel

model_type = "mlp"  # one of: "mlp", "unet", "dit"

# Reject unsupported choices up front, then build the selected velocity network.
if model_type not in ("mlp", "unet", "dit"):
    raise ValueError(f"Unknown model type: {model_type}")
if model_type == "dit":
    raise NotImplementedError

if model_type == "mlp":
    from rectified_flow.models.enhanced_mlp import MLPVelocity

    # dim matches a flattened 28x28 image.
    # NOTE(review): presumably the model flattens (1, 28, 28) inputs itself -- confirm.
    flow_model = MLPVelocity(dim=28 * 28)
else:
    from rectified_flow.models.unet import SongUNet, SongUNetConfig

    config = SongUNetConfig(
        # Input / output geometry.
        img_resolution=28,
        in_channels=1,              # colour channels at the input
        out_channels=1,             # colour channels at the output
        label_dim=0,                # number of class labels; 0 means unconditional
        augment_dim=0,              # augmentation label size; 0 means no augmentation
        # Network width and depth.
        model_channels=64,          # base channel multiplier
        channel_mult=[2, 2],        # per-resolution channel multipliers
        channel_mult_emb=2,         # multiplier for the embedding vector size
        num_blocks=3,               # residual blocks per resolution
        # NOTE(review): with img_resolution=28 and two levels the feature maps are
        # presumably 28 and 14, so 16 may never match and attention would be
        # disabled -- confirm against SongUNet.
        attn_resolutions=[16],      # resolutions at which attention is applied
        dropout=0.13,               # dropout on intermediate activations
        label_dropout=0.0,          # class-label dropout for classifier-free guidance
        # Architecture variants.
        embedding_type="positional",  # timestep embedding: 'positional' or 'fourier'
        channel_mult_time=1,          # timestep embedding size: 1 = DDPM++, 2 = NCSN++
        encoder_type="standard",      # encoder: 'standard' or 'residual'
        decoder_type="standard",      # decoder: 'standard' or 'residual'
        resample_filter=[1, 1],
    )
    flow_model = SongUNet(config)

# Auxiliary VarMLP network (named logvar), optimised jointly with the velocity model.
logvar = VarMLP().to(device)
flow_model = flow_model.to(device)

# EMA wrappers for both networks, sharing the same schedule settings.
ema_kwargs = dict(ema_halflife_kimg=1.0, ema_rampup_ratio=0.05)
ema_flow = EMAModel(flow_model, **ema_kwargs)
ema_logvar = EMAModel(logvar, **ema_kwargs)

# A single AdamW instance updates the parameters of both networks.
trainable_params = [*flow_model.parameters(), *logvar.parameters()]
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=1e-4,
    weight_decay=0.01,
    betas=(0.9, 0.95),
)

# Compiled versions of both modules, built with identical torch.compile settings.
compile_kwargs = dict(mode="reduce-overhead", fullgraph=False, dynamic=False)
compiled_flow = torch.compile(flow_model, **compile_kwargs)
compiled_logvar = torch.compile(logvar, **compile_kwargs)

# Put both modules in training mode and start from cleared gradients.
compiled_flow.train()
compiled_logvar.train()
optimizer.zero_grad()

## Train Unconditional Generation

In [None]:
# RectifiedFlow object used for training: it is given the compiled velocity
# network, the per-sample image shape, and a "uniform" training-time distribution.
rf_train = RectifiedFlow(
    velocity_field=compiled_flow,
    data_shape=(1, 28, 28),
    train_time_distribution="uniform",
)

