# Training

Code based on: https://github.com/atong01/conditional-flow-matching/tree/main


## Required packages:

In [None]:
!pip install torchdyn;
!pip install torchcfm;

## Importing libraries and initializing the device

In [1]:
# Libs:
import copy
import os
import math

from tqdm import trange
from absl import app, flags

import torch
import torch.nn as nn
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from torchvision.utils import make_grid, save_image

from torchcfm.conditional_flow_matching import ( ConditionalFlowMatcher, ExactOptimalTransportConditionalFlowMatcher, VariancePreservingConditionalFlowMatcher,)
from torchcfm.models.unet.unet import UNetModelWrapper

# Device selection: use CUDA whenever a GPU is visible to PyTorch, otherwise fall back to CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")

# NOTE(review): no visible cell in this notebook uses TensorFlow; this import and the
# global switch below look like leftovers (presumably they make TF tensors behave more
# like NumPy arrays). Confirm no later cell relies on them, then consider removing both.
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()
  register_backend(TensorflowBackend())


## Helper functions

In [2]:
def generate_samples(model, parallel, savedir, step, net_="normal"):
    """Save 64 generated images (8 x 8 grid) as a sanity check during training.

    Starting from Gaussian noise, the learned vector field is integrated over
    ``t_span = linspace(0, 1, 100)`` with torchdyn's Euler solver; the final state
    is mapped back to [0, 1] and written as a PNG grid.

    Parameters
    ----------
    model: torch.nn.Module
        Network that predicts the velocity field (called through ``NeuralODE``).
        Wrapped in ``torch.nn.DataParallel`` when ``parallel`` is True.
    parallel: bool
        Parallel-training flag. Torchdyn only runs on one device, so the
        DataParallel wrapper is stripped and the underlying module is moved to
        ``device`` before sampling.
    savedir: str
        Directory the image is written to.
    step: int
        Current training step; embedded in the file name.
    net_: str
        Tag prefixed to the file name (e.g. "normal" or "ema").
    """
    # Remember the caller's mode so it is restored exactly, even if sampling fails.
    was_training = model.training
    model.eval()
    try:
        model_ = copy.deepcopy(model)
        if parallel:
            # Unwrap DataParallel and put the bare module on the single sampling device.
            model_ = model_.module.to(device)

        node_ = NeuralODE(model_, solver="euler", sensitivity="adjoint")
        with torch.no_grad():
            traj = node_.trajectory(
                torch.randn(64, 3, 32, 32, device=device),
                t_span=torch.linspace(0, 1, 100, device=device),
            )
            # Keep only the last integration state, clamp to the training range [-1, 1],
            # then undo the (0.5, 0.5) normalization to get pixels in [0, 1].
            images = traj[-1, :].view([-1, 3, 32, 32]).clip(-1, 1)
            images = images / 2 + 0.5
        save_image(
            images,
            os.path.join(savedir, f"{net_}_generated_FM_images_step_{step}.png"),
            nrow=8,
        )
    finally:
        model.train(was_training)


def ema(source, target, decay):
    """Update ``target`` in place as an exponential moving average of ``source``.

    For every floating-point entry of the state dict:
    ``target = decay * target + (1 - decay) * source``.
    Non-floating entries (e.g. integer counters) cannot be meaningfully averaged,
    so they are copied from ``source`` directly instead of being truncated.

    Parameters
    ----------
    source: torch.nn.Module
        The model being trained.
    target: torch.nn.Module
        The EMA copy; must have the same state-dict keys as ``source``.
    decay: float
        Weight kept on the previous ``target`` value (e.g. 0.9999).
    """
    source_dict = source.state_dict()
    target_dict = target.state_dict()
    # state_dict() tensors share storage with the modules' parameters/buffers, so the
    # in-place ops below update ``target`` itself.
    with torch.no_grad():
        for key, source_value in source_dict.items():
            target_value = target_dict[key]
            if target_value.is_floating_point():
                target_value.mul_(decay).add_(source_value, alpha=1 - decay)
            else:
                target_value.copy_(source_value)

def infiniteloop(dataloader):
    """Yield input batches from ``dataloader`` forever, restarting it after each pass.

    Each item of ``dataloader`` is expected to be an ``(inputs, labels)`` pair;
    only ``inputs`` is yielded and the labels are discarded.

    Raises
    ------
    ValueError
        If a full pass over ``dataloader`` yields no batches (e.g. ``drop_last=True``
        with fewer samples than ``batch_size``), rather than spinning forever.
    """
    while True:
        produced = False
        for x, _ in dataloader:
            produced = True
            yield x
        if not produced:
            raise ValueError("dataloader yielded no batches; cannot loop over it forever")

def warmup_lr(step, warmup_steps=None):
    """Linear learning-rate warmup factor, used as ``lr_lambda`` for ``LambdaLR``.

    Returns ``min(step, warmup_steps) / warmup_steps``: it rises linearly from 0 at
    step 0 to 1 at ``warmup_steps`` and stays at 1 afterwards.

    Parameters
    ----------
    step: int
        Scheduler step count supplied by ``LambdaLR``.
    warmup_steps: int, optional
        Warmup length; defaults to the notebook-level ``warmup`` setting. A
        non-positive value disables warmup (factor 1.0) instead of dividing by zero.
    """
    if warmup_steps is None:
        warmup_steps = warmup
    if warmup_steps <= 0:
        return 1.0
    return min(step, warmup_steps) / warmup_steps

In [3]:
# Model
model_name = "otcfm"  # Flow matching model type: "otcfm" or "vpcfm"
output_dir = "./results/"  # Output directory; samples and checkpoints go under output_dir + model_name

# UNet
num_channel = 64  # Base channel count of the UNet

# Training
lr = 2e-4  # Target learning rate (reached once the warmup is over)
grad_clip = 1.0  # Max gradient norm for clipping
total_steps = 400000  # Total training steps; Lipman et al. use 400k steps at batch size 128
warmup = 1000  # Number of linear learning-rate warmup steps
batch_size = 32  # Batch size; Lipman et al. use 128
num_workers = 4  # DataLoader worker processes
ema_decay = 0.9999  # EMA decay rate
parallel = False  # Multi-GPU training via torch.nn.DataParallel

# Evaluation
# Save sample grids + a checkpoint every save_step steps; 0 disables saving during training.
# NOTE: every checkpoint is kept (the step is in the file name), so total_steps / save_step
# files accumulate, e.g. 4000 checkpoints for 400k steps at save_step = 100.
save_step = 100

# Reproducibility
seed = 42
torch.manual_seed(seed)

# `device` is defined in the imports cell.

## Training loop

In [4]:
def _make_flow_matcher(name, sigma=0.0):
    """Instantiate the torchcfm flow matcher selected by ``name``.

    Raises
    ------
    NotImplementedError
        If ``name`` is not one of "otcfm", "icfm" or "vpcfm".
    """
    if name == "otcfm":
        return ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
    if name == "icfm":
        return ConditionalFlowMatcher(sigma=sigma)
    if name == "vpcfm":
        return VariancePreservingConditionalFlowMatcher(sigma=sigma)
    raise NotImplementedError(
        f"Unknown model {name}, must be one of ['otcfm', 'icfm', 'vpcfm']"
    )


def _build_cifar10_loader():
    """Return a shuffling DataLoader over the CIFAR-10 training split.

    Images are randomly flipped horizontally and normalized with mean/std 0.5 per
    channel. The dataset is downloaded to ./data if it is not already there.
    """
    dataset = datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        ),
    )
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,  # keep every batch the same size
    )


def train():
    """Train a UNet velocity field on CIFAR-10 with conditional flow matching.

    All settings come from the notebook-level configuration (``model_name``, ``lr``,
    ``total_steps``, ``batch_size``, ``ema_decay``, ``save_step``, ...). Each step
    regresses the network output onto the conditional flow target with an MSE loss
    and then updates an EMA copy of the weights. Every ``save_step`` steps
    (including step 0), sample grids from the raw and EMA models plus a checkpoint
    are written to ``output_dir/model_name/``.

    Raises
    ------
    NotImplementedError
        If ``model_name`` is not a supported flow matcher.
    """
    print(
        "lr, total_steps, ema decay, save_step:",
        lr,
        total_steps,
        ema_decay,
        save_step,
    )

    # Resolve the flow matcher first so a bad `model_name` fails before the
    # dataset download and model construction.
    flow_matcher = _make_flow_matcher(model_name, sigma=0.0)

    datalooper = infiniteloop(_build_cifar10_loader())

    # UNET MODEL
    net_model = UNetModelWrapper(
        dim=(3, 32, 32),
        num_res_blocks=2,
        num_channels=num_channel,
        channel_mult=[1, 2, 2, 2],
        num_heads=4,
        num_head_channels=64,
        attention_resolutions="16",
        dropout=0.1,
    ).to(device)
    # The EMA copy starts from the freshly initialized weights.
    ema_model = copy.deepcopy(net_model)
    optim = torch.optim.Adam(net_model.parameters(), lr=lr)
    sched = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=warmup_lr)
    if parallel:
        print(
            "Warning: parallel training is performing slightly worse than single GPU training due to statistics computation in dataparallel. We recommend to train over a single GPU, which requires around 8 Gb of GPU memory."
        )
        net_model = torch.nn.DataParallel(net_model)
        ema_model = torch.nn.DataParallel(ema_model)

    # NOTE: "M" here means 1024**2 parameters, not 1e6.
    model_size = sum(param.numel() for param in net_model.parameters())
    print("Model params: %.2f M" % (model_size / 1024 / 1024))

    # Trailing separator kept so `savedir + filename` concatenation also works.
    savedir = os.path.join(output_dir, model_name, "")
    os.makedirs(savedir, exist_ok=True)

    with trange(total_steps, dynamic_ncols=True) as pbar:
        for step in pbar:
            optim.zero_grad()
            x1 = next(datalooper).to(device)  # data batch
            x0 = torch.randn_like(x1)  # matching noise batch
            t, xt, ut = flow_matcher.sample_location_and_conditional_flow(x0, x1)
            vt = net_model(t, xt)
            loss = torch.mean((vt - ut) ** 2)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net_model.parameters(), grad_clip)
            optim.step()
            sched.step()
            ema(net_model, ema_model, ema_decay)

            # Sample and save the weights
            if save_step > 0 and step % save_step == 0:
                # Read the loss value only when logging; .item() copies it to the host.
                # pbar.write keeps the progress bar intact.
                pbar.write(f"Step: {step}, Train loss: {loss.item():.4f}")
                generate_samples(net_model, parallel, savedir, step, net_="normal")
                generate_samples(ema_model, parallel, savedir, step, net_="ema")
                torch.save(
                    {
                        "net_model": net_model.state_dict(),
                        "ema_model": ema_model.state_dict(),
                        "sched": sched.state_dict(),
                        "optim": optim.state_dict(),
                        "step": step,
                    },
                    os.path.join(savedir, f"cifar10_weights_step_{step}.pt"),
                )

In [None]:
# Run the full training loop with the configuration above. Sample grids and checkpoints
# are written every `save_step` steps under the `output_dir` + `model_name` directory.
train()

lr, total_steps, ema decay, save_step: 0.0002 1000 0.9999 100
<function warmup_lr at 0x795dc5ee5e10>
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 80376058.36it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data




Model params: 8.54 M


  0%|          | 0/1000 [00:00<?, ?it/s]

, \Epoch: 0, Train loss: 1.2072
