In [None]:
# Simple Flow matching

In [None]:
import torch
from torchvision import datasets, transforms

# ToTensor turns each PIL image into a tensor; Compose is kept so further steps can be appended.
transform = transforms.Compose([transforms.ToTensor()])

mnist = datasets.MNIST(root=".", train=True, download=True, transform=transform)

# Walk the training set in order and stop at the first example labelled 1.
for img, label in mnist:
    if label != 1:
        continue
    # Indexing with [0] selects the first channel, so x_1 has one axis fewer than img.
    x_1 = img[0]
    break


In [None]:
import torch
from torch import nn, Tensor
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons

In [None]:
# class Flow(nn.Module):
#     def __init__(self, dim: int = 2, h: int = 64):
#         super().__init__()
#         self.net = nn.Sequential(
#             nn.Linear(dim + 1, h), nn.ELU(),
#             nn.Linear(h, h), nn.ELU(),
#             nn.Linear(h, h), nn.ELU(),
#             nn.Linear(h, dim))

#     def forward(self, x_t: Tensor, t: Tensor) -> Tensor:
#         return self.net(torch.cat((t, x_t), -1))
    
#     def step(self, x_t: Tensor, t_start: Tensor, t_end: Tensor) -> Tensor:
#         t_start = t_start.view(1, 1).expand(x_t.shape[0], 1)
#         # For simplicity, using midpoint ODE solver in this example
#         return x_t + (t_end - t_start) * self(x_t + self(x_t, t_start) * (t_end - t_start) / 2,
#                                               t_start + (t_end - t_start) / 2)

# # training
# flow = Flow()
# optimizer = torch.optim.Adam(flow.parameters(), 1e-2)
# loss_fn = nn.MSELoss()

In [None]:
import torch
from torch import nn, Tensor
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Flow(nn.Module):
    """Time-conditioned MLP that predicts a velocity ``v(x_t, t)`` for flow matching.

    The network receives the flattened state with the time prepended as one
    extra feature and returns a tensor with the same trailing size as the state.

    Args:
        dim: Size of the flattened state (default: a 28x28 image).
        h: Width of each hidden layer.
    """

    def __init__(self, dim: int = 28*28, h: int = 1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, h),  # +1 input feature for the time column
            nn.ELU(),
            nn.Linear(h, h),
            nn.ELU(),
            nn.Linear(h, h),
            nn.ELU(),
            nn.Linear(h, dim),
        )

    def forward(self, x_t: Tensor, t: Tensor) -> Tensor:
        """Return the predicted velocity at state ``x_t`` and time ``t``.

        Args:
            x_t: States of shape ``(B, dim)``.
            t: Times of shape ``(B, 1)``; concatenated in front of ``x_t``.
        """
        return self.net(torch.cat([t, x_t], dim=-1))

    @staticmethod
    def _time_column(t, like: Tensor) -> Tensor:
        """Broadcast a scalar time to a ``(B, 1)`` column with ``like``'s dtype and device."""
        scalar = torch.as_tensor(t, dtype=like.dtype, device=like.device)
        return scalar.reshape(1, 1).expand(like.shape[0], 1)

    @torch.no_grad()
    def step(self, x_t: Tensor, t_start: "float | Tensor", t_end: "float | Tensor") -> Tensor:
        """Advance ``x_t`` from ``t_start`` to ``t_end`` with one explicit midpoint step.

        Args:
            x_t: Current states of shape ``(B, dim)``.
            t_start: Start time; a Python number or a single-element tensor.
            t_end: End time; a Python number or a single-element tensor.

        Returns:
            The states at ``t_end``, same shape, dtype and device as ``x_t``.
        """
        # Build the time columns in x_t's dtype/device: a bare torch.full would use
        # the default float dtype (or int64 for integer times) and clash with
        # half/bfloat16/float64 models.
        t_start_batch = self._time_column(t_start, x_t)
        t_end_batch = self._time_column(t_end, x_t)

        dt = t_end_batch - t_start_batch
        k1 = self(x_t, t_start_batch)          # slope at the start of the interval
        x_mid = x_t + 0.5 * dt * k1            # Euler half-step to the midpoint
        t_mid = t_start_batch + 0.5 * dt
        k2 = self(x_mid, t_mid)                # slope at the midpoint drives the full step
        return x_t + dt * k2


In [None]:
# Inspect the shape of `x_1`, the digit image picked in the MNIST cell above.
# NOTE(review): the training loop further down rebinds `x_1` to a mini-batch, so the
# value shown here depends on which cells have already run — confirm run order.
x_1.shape

In [None]:
# from tqdm import tqdm

# #for _ in range(10):
# for _ in tqdm(range(1000), desc="Training"):
#     #x_1 = Tensor(make_moons(100, noise=0.05)[0])
#     #x_1 = Tensor(x[0])
#     x_0 = torch.randn_like(x_1)
#     t = torch.rand(len(x_1), 1)
#     x_t = (1 - t) * x_0 + t * x_1
#     dx_t = x_1 - x_0
#     optimizer.zero_grad()
#     loss_fn(flow(x_t, t), dx_t).backward()
#     optimizer.step()

In [None]:
# Preprocessing: PIL image -> (1, 28, 28) tensor in [0, 1] -> flat (784,) vector.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Lambda(lambda img: img.view(-1)),
    ]
)

train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)

batch_size = 256
# Reshuffle every epoch; drop the trailing partial batch so each batch has batch_size rows.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Model, optimiser and regression loss for the flow-matching objective.
flow = Flow(dim=28*28, h=1024).to(device)
optimizer = torch.optim.Adam(flow.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

num_epochs = 5  # increase for better samples

flow.train()
for epoch in range(num_epochs):
    # Accumulate the loss on-device so the running mean costs one host sync per epoch.
    loss_sum = torch.zeros((), device=device)
    n_batches = 0

    for x_1, _ in train_loader:
        x_1 = x_1.to(device)              # (B, 784), values in [0, 1]
        x_0 = torch.randn_like(x_1)       # Gaussian noise (B, 784)

        t = torch.rand(x_1.size(0), 1, device=device)  # (B, 1), uniform in [0, 1]
        x_t = (1.0 - t) * x_0 + t * x_1                # point on the straight noise->data path
        dx_t = x_1 - x_0                               # target velocity along that path

        optimizer.zero_grad()
        v_pred = flow(x_t, t)
        loss = loss_fn(v_pred, dx_t)
        loss.backward()
        optimizer.step()

        loss_sum += loss.detach()
        n_batches += 1

    # Report the mean over the epoch: a single mini-batch loss is dominated by the
    # random t / noise draw. With an empty loader this is 0/0 and prints nan.
    print(f"epoch {epoch+1}: loss={(loss_sum / n_batches).item():.4f}")

In [None]:
@torch.no_grad()
def sample(flow: Flow, n_samples: int = 64, n_steps: int = 20):
    """Generate images by integrating the learned flow from noise (t=0) to t=1.

    Args:
        flow: Velocity model; it is switched to eval mode and left in that mode.
        n_samples: Number of images to generate.
        n_steps: Number of equal-width midpoint steps across [0, 1].

    Returns:
        Tensor of shape (n_samples, 1, 28, 28), clamped to [0, 1], on the
        model's device.
    """
    flow.eval()

    # Sample on the device the model actually lives on instead of relying on a
    # notebook-level global that may have drifted from it.
    model_device = next(flow.parameters()).device
    x = torch.randn(n_samples, 28*28, device=model_device)

    # Read the whole time grid back as Python floats once, rather than calling
    # .item() (a host sync) twice per step.
    time_grid = torch.linspace(0.0, 1.0, n_steps + 1, device=model_device).tolist()

    for t_start, t_end in zip(time_grid[:-1], time_grid[1:]):
        x = flow.step(x, t_start, t_end)

    # map to [0,1] for visualization and reshape to images
    return x.clamp(0.0, 1.0).view(-1, 1, 28, 28)

# Draw 64 digits and bring them back to host memory for plotting.
samples = sample(flow, n_samples=64, n_steps=30).cpu()

# Tile the images 8 per row; values are already in [0, 1], so no renormalisation.
grid = make_grid(samples, nrow=8, padding=2, normalize=False)

# Move the channel axis last, which is the layout imshow expects.
canvas = plt.figure(figsize=(6, 6)).add_subplot(111)
canvas.imshow(grid.permute(1, 2, 0), cmap="gray")
canvas.set_axis_off()
plt.show()

In [None]:
# sampling
# Plots a 2-D point cloud at t = 0 and after each of n_steps midpoint steps of `flow`.
# NOTE(review): this cell looks like a leftover from the 2-D make_moons example. `flow` is
# created in the training cell as Flow(dim=28*28, ...) and moved to `device`, so passing a
# (300, 2) CPU tensor to flow.step is expected to fail in the first Linear layer (and on a
# device mismatch when CUDA is used) — confirm, then retrain a 2-D Flow here or drop the cell.
x = torch.randn(300, 2)
n_steps = 8
# One panel per time point, sharing axes so the panels are directly comparable.
fig, axes = plt.subplots(1, n_steps + 1, figsize=(30, 4), sharex=True, sharey=True)
time_steps = torch.linspace(0, 1.0, n_steps + 1)

# Panel 0: the initial noise samples; the limits set here apply to every shared panel.
axes[0].scatter(x.detach()[:, 0], x.detach()[:, 1], s=10)
axes[0].set_title(f't = {time_steps[0]:.2f}')
axes[0].set_xlim(-3.0, 3.0)
axes[0].set_ylim(-3.0, 3.0)

for i in range(n_steps):
    # time_steps[i] is a 0-dim tensor, whereas Flow.step annotates its times as float.
    x = flow.step(x, time_steps[i], time_steps[i + 1])
    axes[i + 1].scatter(x.detach()[:, 0], x.detach()[:, 1], s=10)
    axes[i + 1].set_title(f't = {time_steps[i + 1]:.2f}')

plt.tight_layout()
plt.show()

In [None]:
# Draw 256 points from the noisy two-moons toy set (coordinates only) and plot them.
x = make_moons(n_samples=256, noise=0.05)[0]
plt.scatter(*x.T)

In [None]:

# NOTE(review): run top to bottom, `x` is the array returned by make_moons(256, ...)[0]
# in the previous cell, presumably shape (256, 2) — the trailing note looks stale; confirm.
print(x.shape)  # originally annotated as torch.Size([1, 28, 28])

In [None]:
# Show the first element of `x` as a grayscale image.
# NOTE(review): this needs `x[0]` to be a 2-D image; with `x` from the make_moons cell above
# it would presumably be a length-2 point, which imshow cannot draw — confirm which `x` is meant.
plt.imshow(x[0],cmap='gray' )

In [None]:
# Shape of the first element of `x` (the same slice passed to imshow in the previous cell).
x[0].shape