In [None]:
from torch.utils.data import DataLoader, Dataset
import torch

class SquareImageDataset(Dataset):
    """Toy dataset holding a single image: a white square on a black background.

    Each item is ``(image, 0)``. ``image`` is a float32 tensor of shape
    (1, img_size, img_size). It is zero everywhere except for ones inside the
    square ``[square_start:square_end, square_start:square_end]``.

    Args:
        img_size: Height and width of the single-channel image.
        square_start: First row/column of the square (inclusive).
        square_end: Last row/column of the square (exclusive).
    """
    def __init__(self, img_size: int = 32, square_start: int = 8, square_end: int = 24):
        img = torch.zeros(1, img_size, img_size, dtype=torch.float32)
        img[:, square_start:square_end, square_start:square_end] = 1.0
        self.img = img

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        # Raise IndexError past the end so plain Python iteration over the
        # dataset (list(ds), for-loops via __getitem__) terminates instead of
        # returning the same image forever.
        if not -len(self) <= idx < len(self):
            raise IndexError(f"index {idx} out of range for dataset of length {len(self)}")
        return self.img, 0

# One shared single-image dataset feeds both loaders. Because its length is 1,
# shuffle=True on the training loader has no effect, so both loaders yield the
# same batch.
train_dataset = SquareImageDataset()
train_loader, test_loader = (
    DataLoader(train_dataset, batch_size=1, shuffle=True),
    DataLoader(train_dataset, batch_size=1),
)


In [None]:
from einops import repeat
import torch
from torch import Tensor, nn
import lovely_tensors as lt
# Patch torch tensor printing with lovely_tensors' compact summaries for easier
# inspection of notebook output. NOTE(review): presumably display-only — confirm
# against the lovely_tensors docs.
lt.monkey_patch()

class FlowMatchingModel(nn.Module):
    """Flow-matching CNN, written as a single nn.Sequential block, that predicts a velocity field.

    The time ``t`` is broadcast to a constant image plane and concatenated to
    ``x_t`` as a second input channel. A stack of 'same'-padded convolutions
    then maps this to a one-channel output with the same spatial size as ``x_t``.

    Args:
        img_size: Image height/width; stored as ``self.img_size`` for callers.
            The convolutions themselves accept any spatial size.
        kernel_size: Convolution kernel size. ``padding='same'`` keeps H and W
            unchanged for every kernel size. The original hard-coded padding of 1
            only did so for 3.
        hidden_dim: Number of channels in the hidden convolutions.
    """
    def __init__(self, img_size: int = 32, kernel_size: int = 3, hidden_dim: int = 64):
        super().__init__()
        self.img_size = img_size
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=hidden_dim, kernel_size=kernel_size, padding='same'),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=kernel_size, padding='same'),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=kernel_size, padding='same'),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=kernel_size, padding='same'),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=kernel_size, padding='same'),
            nn.ReLU(),
            # No activation on the output. The regression target (dx_t) can be
            # negative, since x_0 is Gaussian noise, and a final ReLU would clamp
            # every negative component to 0. Conv indices (and therefore
            # state_dict keys) are unchanged.
            nn.Conv2d(in_channels=hidden_dim, out_channels=1, kernel_size=kernel_size, padding='same'),
        )

    def forward(self, x_t: Tensor, t: Tensor) -> Tensor:
        """Predict the velocity at state ``x_t`` and time ``t``.

        Args:
            x_t: Batch of single-channel states, shape (B, 1, H, W).
            t: Time as a scalar (), one value per sample (B,), or a per-sample
                plane (B, 1, 1, 1) / (B, 1, H, W).

        Returns:
            Velocity prediction of shape (B, 1, H, W).

        Raises:
            ValueError: if ``t`` has a rank other than 0, 1 or 4.
        """
        batch, height, width = x_t.size(0), x_t.size(2), x_t.size(3)
        # Align t with x_t first. A float64 t would make torch.cat promote to
        # float64, which the float32 conv weights reject. A t on another device
        # would make the cat itself fail.
        t = t.to(device=x_t.device, dtype=x_t.dtype)
        if t.ndim == 0:
            t = t.view(1).repeat(batch)
        if t.ndim == 1:
            t = t.view(batch, 1, 1, 1)
        elif t.ndim != 4:
            raise ValueError("t must be of shape (), (B,) or (B,1,H,W)")
        t = t.expand(-1, 1, height, width)
        x = torch.cat([x_t, t], dim=1)
        return self.net(x)


In [None]:
from torch.optim import AdamW
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.path import AffineProbPath
from flow_matching.solver import ODESolver
from flow_matching.utils import ModelWrapper

# Affine probability path with the conditional-OT scheduler. For a sampled time
# t it supplies the interpolated state x_t and target velocity dx_t between
# noise x_0 and data x_1.
path = AffineProbPath(scheduler=CondOTScheduler())

# Seed once, before model init, so weights, noise and times are reproducible
# on Restart & Run All.
seed = 0
torch.manual_seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fm = FlowMatchingModel(img_size=32, kernel_size=3).to(device)
optim = AdamW(params=fm.parameters(), lr=5e-3)
loss_fn = nn.MSELoss()

num_epochs = 200
log_every = 20
for epoch in range(num_epochs):
    fm.train()
    for batch in train_loader:
        x_1 = batch[0].to(device)  # batch is (images, labels); labels are unused
        x_0 = torch.randn_like(x_1)  # Gaussian noise source
        # Draw one time per sample as a 1-D (B,) tensor. Squeezing a (B,1,1,1)
        # tensor instead collapses it to a 0-d scalar when B == 1, and
        # path.sample expects a (B,) time vector.
        t = torch.rand(x_1.size(0), device=device)
        path_sample = path.sample(t=t, x_0=x_0, x_1=x_1)
        # FlowMatchingModel broadcasts a (B,) t to an image plane internally.
        pred = fm(path_sample.x_t, t)
        loss = loss_fn(pred, path_sample.dx_t)
        optim.zero_grad()
        loss.backward()
        optim.step()
    if epoch % log_every == 0:
        print(f'Epoch {epoch}: Loss = {loss.item():.5f}')


In [None]:
class WrappedModel(ModelWrapper):
    """Adapter that exposes the trained network through the ``ModelWrapper`` interface.

    Calls the wrapped network (``self.model``) with ``(x, t)`` only. Any extra
    keyword arguments are accepted and ignored.
    """

    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
        net = self.model
        velocity = net(x, t)
        return velocity

# Switch to inference mode before sampling. The network has no dropout or
# batch-norm, but this keeps the intent explicit and stays correct if layers
# are added later.
fm.eval()
wrapped_vf = WrappedModel(fm)

number_of_steps = 500  # number of points in the output time grid
# The fixed integration step is ~3x finer than the output grid spacing (1/499).
substeps_per_output = 3
step_size = 1 / (number_of_steps * substeps_per_output)
T = torch.linspace(0, 1, number_of_steps, device=device)

# Start from Gaussian noise shaped like the images the model was built for.
x_init = torch.randn((1, 1, fm.img_size, fm.img_size), device=device)
solver = ODESolver(velocity_model=wrapped_vf)
sol = solver.sample(time_grid=T, x_init=x_init, method='midpoint', step_size=step_size, return_intermediates=True)


In [None]:
import matplotlib.pyplot as plt
import numpy as np

def visualize_image_evolution(sol, number_of_steps, n_plots=10, title_prefix='Step'):
    """Show evenly spaced snapshots of the first sample along a solver trajectory.

    Args:
        sol: Stacked states indexed as ``sol[step, sample, ...]``, e.g. the
            ``return_intermediates=True`` output of ``ODESolver.sample``. Only
            sample 0 is drawn.
        number_of_steps: Number of leading steps of ``sol`` to consider.
            Larger values are clamped to the steps actually present.
        n_plots: Requested number of panels. Duplicate step indices, which
            occur when ``n_plots`` exceeds the available steps, are drawn once.
        title_prefix: Text shown before the step index in each panel title.

    Raises:
        ValueError: if there are no steps to plot or ``n_plots`` < 1.
    """
    images = sol[:number_of_steps, 0]
    n_frames = len(images)
    if n_frames == 0:
        raise ValueError("sol contains no steps to plot")
    if n_plots < 1:
        raise ValueError("n_plots must be at least 1")
    # Spread indices over the frames that actually exist (not number_of_steps),
    # so a short `sol` cannot be indexed out of range. np.unique drops repeats.
    plot_indices = np.unique(np.linspace(0, n_frames - 1, n_plots, dtype=int))
    n_panels = len(plot_indices)
    # squeeze=False keeps `axes` 2-D even when there is a single panel.
    fig, axes = plt.subplots(1, n_panels, figsize=(2.5 * n_panels, 3), squeeze=False)
    for ax, i in zip(axes[0], plot_indices):
        img = images[i].detach().cpu().squeeze().numpy()
        ax.imshow(img, cmap='gray')
        ax.axis('off')
        ax.set_title(f'{title_prefix} {i}')
    fig.tight_layout()
    plt.show()

# Show 10 evenly spaced snapshots of the sampled trajectory, from noise to the generated image.
visualize_image_evolution(sol, number_of_steps=number_of_steps, n_plots=10, title_prefix='Step')
