In [12]:
%cd /home/luke-padmore/Source/flow-matching-mnist
%autoreload 2
from models.config import UNetConfig
from models.unet import UNet

/home/luke-padmore/Source/flow-matching-mnist


## Testing that the model instantiates correctly

In [8]:
# Smoke test: build a UNet from an explicit config and dump its module tree,
# confirming that construction succeeds for these hyperparameters.
unet_kwargs = dict(
    channels=(1, 64, 128, 256),
    d_trunk=32,
    d_concat=8,
    group_norm_size=8,
    d_time=128,
    max_time_period=10000.0,
    activation="silu",
    upsample_mode="nearest",
)
cfg = UNetConfig(**unet_kwargs)

model = UNet.from_config(cfg)
print(model)

UNet(
  (encoder): Encoder(
    (initial_conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): ConvDownblock(
        (conv): Sequential(
          (0): GroupNorm(8, 72, eps=1e-05, affine=True)
          (1): Conv2d(72, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (2): SiLU()
          (3): GroupNorm(8, 64, eps=1e-05, affine=True)
          (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): SiLU()
        )
        (down): Sequential(
          (0): GroupNorm(8, 64, eps=1e-05, affine=True)
          (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (2): SiLU()
        )
        (time_emb_mlp): Linear(in_features=32, out_features=8, bias=True)
      )
      (1): ConvDownblock(
        (conv): Sequential(
          (0): GroupNorm(8, 72, eps=1e-05, affine=True)
          (1): Conv2d(72, 128, kernel_size=(3, 3)

## Testing the forward pass

In [3]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Locate the repository root: the nearest directory (the cwd or one of its
# ancestors) that contains the `models` package. This works both after `%cd`
# into the repo root and when the kernel starts in a subdirectory; if nothing
# matches, fall back to the parent of the cwd.
_cwd = Path.cwd()
PROJECT_ROOT = next(
    (candidate for candidate in (_cwd, *_cwd.parents) if (candidate / "models").is_dir()),
    _cwd.parent,
)
# Patch sys.path *before* importing project modules; otherwise the import
# below only succeeds if `models` was already importable.
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

from models.unet import UNet

DATA_DIR = PROJECT_ROOT / "data"
batch_size = 64

# ToTensor (pixels in [0, 1]) -> pad each side by 2 (28x28 MNIST -> 32x32)
# -> Normalize with mean 0.5 / std 0.5, mapping pixels to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Pad(2, padding_mode='constant'),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

trainset = torchvision.datasets.MNIST(
    root=str(DATA_DIR),
    train=True,
    download=True,
    transform=transform,
)
trainloader = DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

testset = torchvision.datasets.MNIST(
    root=str(DATA_DIR),
    train=False,
    download=True,
    transform=transform,
)
# Use the held-out split (previously this wrapped `trainset` by mistake);
# evaluation does not need shuffling.
testloader = DataLoader(
    testset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
)

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataiter = iter(trainloader)
images, labels = next(dataiter)

cfg = UNetConfig(
    channels=(1, 64, 128, 256),
    d_trunk=32,
    d_concat=8,
    group_norm_size=8,
    d_time=128,
    max_time_period=10000.0,
    activation="silu",
    upsample_mode="nearest",
)

model = UNet.from_config(cfg).to(device)
images = images.to(device)
# One time value per image, drawn from U[0, 1) and shaped (batch, 1).
t = torch.rand(images.shape[0], 1, device=device)

# Smoke test only: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    out = model(images, t)
print(out.shape)

# The network should map a batch of images to a tensor of the same shape.
assert out.shape == images.shape, (
    f"expected output shape {tuple(images.shape)}, got {tuple(out.shape)}"
)

torch.Size([64, 1, 32, 32])


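The forward pass works: for a batch of 64 padded 32×32 MNIST images, the output shape `(64, 1, 32, 32)` matches the input shape.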

## Testing hyperparameter sampling for an Optuna trial

In [26]:
from utils.optuna_models import HPTYaml
from models.config import UNetConfig, OptimConfig
import optuna

In [27]:
# Hyperparameter-tuning config for this smoke test.
# Floats in scientific notation are written with a decimal point and a signed
# exponent (e.g. 1.0e-5): PyYAML follows YAML 1.1, which resolves `1e-5` as a
# *string*, not a float.
yaml_str = """
study:
  study_name: "test"
  direction: "minimize"
  storage: "sqlite:///optuna.db"
  n_trials: 1

dataloader:
  batch_size: 64
  data_path: /home/luke-padmore/Source/flow-matching-mnist/data
  num_workers: 4
  shuffle: true
  transform: "default"

unet:
  fixed:
    max_time_period: 1000.0
  choices:
    channels:
      type: categorical
      choices:
        - [1, 64, 128]
        - [1, 64, 128, 256]
    d_time:
      type: categorical
      choices: [64, 128, 256]
    activation:
      type: categorical
      choices: [silu, relu, gelu]
    upsample_mode:
      type: categorical
      choices: [nearest, bilinear, convtranspose]

optim:
  fixed:
    name: adamw
  choices:
    lr:
      type: float
      low: 1.0e-5
      high: 5.0e-4
      log: true
    weight_decay:
      type: float
      low: 1.0e-6
      high: 1.0e-2
      log: true
"""


In [28]:
import yaml

# Parse the config and validate it against the HPTYaml schema.
data = yaml.safe_load(yaml_str)
hpt = HPTYaml.model_validate(data)

# load_if_exists=True reuses the study persisted in the configured storage,
# so repeated runs of this cell add trials to the same study.
study = optuna.create_study(
    direction=hpt.opt_study_cfg.direction,
    storage=hpt.opt_study_cfg.storage,
    study_name=hpt.opt_study_cfg.study_name,
    load_if_exists=True,
)

# Ask for a trial and let the HPT config sample a UNet/optimizer config from it.
trial = study.ask()
try:
    unet_cfg, optim_cfg = hpt.sample(trial)
    print(unet_cfg, optim_cfg)
finally:
    # This cell only exercises sampling and never produces an objective value.
    # Close the trial explicitly so it is not left RUNNING in the persistent study.
    study.tell(trial, state=optuna.trial.TrialState.FAIL)

[I 2026-01-29 23:32:42,452] Using an existing study with name 'test' instead of creating a new one.


UNetConfig(channels=[1, 64, 128, 256], d_trunk=32, d_concat=8, group_norm_size=8, d_time=64, max_time_period=1000.0, activation='gelu', upsample_mode='bilinear') OptimConfig(name='adamw', lr=0.00018585739010922318, weight_decay=0.00011993545279058418)


