In [25]:
from src.datasets.mnist import FlowMatchingMNIST
from torch.utils.data import DataLoader
from torchvision.transforms import CenterCrop, Compose, Lambda, Resize, ToTensor, Pad
from src.flow_matching import OTConditionalFlow
from torch.distributions.uniform import Uniform
from src.models.unet import Unet, SinusoidalPositionEmbeddings

import torch
import torch.nn as nn
import numpy as np

In [2]:
# Training DataLoader. Images are resized to 28 px, converted to tensors, and
# then rescaled with v -> 2*v - 1 (ToTensor yields values in [0, 1] for PIL /
# uint8 input, so the result lies in [-1, 1]).
transform = Compose([Resize(28), ToTensor(), Lambda(lambda img: img * 2 - 1)])
dataset = FlowMatchingMNIST("./data", train=True, transform=transform)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=8,
    shuffle=True,
    pin_memory=True,
)

In [14]:
# Number of batches the DataLoader yields in one pass over the dataset.
len(dataloader)

7500

In [3]:
# Pull a single batch for shape inspection; the batch is indexed by key and the
# image tensor lives under "image".
batch = next(iter(dataloader))
batch["image"].shape

torch.Size([8, 1, 28, 28])

In [4]:
# x_1 is the image batch; x_0 is standard-normal noise with the same shape,
# dtype and device as x_1. batch_size is read off the leading dimension.
x_1 = batch["image"]
x_0 = torch.randn_like(x_1)
batch_size = x_1.size(0)

In [13]:
# Define `m` in this cell so it runs on a fresh kernel: the original cell read
# `m` before the cell that creates it, and only worked through leftover kernel
# state (NameError under Restart & Run All).
m = Uniform(0, 1)
type(m)

torch.distributions.uniform.Uniform

In [5]:
# Draw one time value per sample in the batch from Uniform(0, 1).
m = Uniform(low=0, high=1)
t = m.sample((batch_size,))
t.shape

torch.Size([8])

In [6]:
# Flow object used in later cells to sample x_t and to compute the target
# conditional vector field.
# NOTE(review): the meaning of sigma_min is defined in src.flow_matching, not
# here — confirm that 0.1 is the intended value.
target_flow = OTConditionalFlow(sigma_min=0.1)

In [7]:
# Sample x_t from target_flow given the noise x_0, the data x_1 and the times t.
# Presumably this is the conditional probability path p_t of flow matching —
# see OTConditionalFlow.sample_p_t to confirm. The shape is displayed to check
# that it matches x_1.
x_t = target_flow.sample_p_t(x_0=x_0, x_1=x_1, t=t)
x_t.shape

torch.Size([8, 1, 28, 28])

In [46]:
class SimpleModel(nn.Module):
    """Minimal time-conditioned convolutional network.

    The time ``t`` is embedded, projected to one value per input channel and
    added to ``x_t`` as a per-sample, per-channel bias. The biased input then
    goes through a small Conv -> SELU -> GroupNorm -> Conv stack. Both
    convolutions use ``padding="same"``, so the output has the same shape as
    ``x_t``.

    The defaults reproduce the original hard-coded configuration (28-wide time
    embedding, 1 image channel, 4 hidden channels).

    Args:
        time_dim: Width passed to ``SinusoidalPositionEmbeddings`` and used as
            the input width of the following linear projection.
        in_channels: Number of channels of ``x_t`` (and of the output).
        hidden_channels: Number of channels of the intermediate feature map.
    """

    def __init__(
        self,
        time_dim: int = 28,
        in_channels: int = 1,
        hidden_channels: int = 4,
    ) -> None:
        super().__init__()
        # Assumes SinusoidalPositionEmbeddings(d) maps a (batch,) tensor to
        # (batch, d); the Linear below requires that trailing width.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, in_channels),
            nn.GELU(),
        )

        self.net = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, 3, padding="same"),
            nn.SELU(),
            # A single group normalizes over all hidden channels together.
            nn.GroupNorm(1, hidden_channels),
            nn.Conv2d(hidden_channels, in_channels, 3, padding="same"),
        )

    def forward(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Return the network output for input ``x_t`` at times ``t``.

        Args:
            x_t: Tensor of shape (batch, in_channels, H, W).
            t: One time value per sample, shape (batch,).

        Returns:
            Tensor with the same shape as ``x_t``.
        """
        time_embed = self.time_mlp(t)  # (batch, in_channels)

        # Broadcast the per-channel bias over the spatial dimensions. For
        # in_channels == 1 this is the same (batch, 1, 1, 1) tensor the
        # original `[:, None, None]` indexing produced.
        return self.net(x_t + time_embed[:, :, None, None])

In [47]:
# Use the small convolutional model for the shape checks in the next cells.
# The previous `model = Unet(channels=1, dim_mults=(1, 2, 4), dim=28)`
# assignment was overwritten on the very next line and never used, so it is
# dropped; substitute that constructor call here to try the U-Net instead.
model = SimpleModel()

In [48]:
# Forward pass: the model's output for (x_t, t), used as the predicted
# conditional vector field. The shape is displayed to check it matches x_t.
predicted_cond_vector_field = model(x_t, t)
predicted_cond_vector_field.shape

torch.Size([8, 1, 28, 28])

In [12]:
# Target conditional vector field from target_flow for the same (x_0, x_1, t).
# The shape is displayed so it can be compared with the model's prediction.
target_cond_vector_field = target_flow.get_conditional_vector_field(x_0=x_0, x_1=x_1, t=t)
target_cond_vector_field.shape

torch.Size([8, 1, 28, 28])