In [1]:
# Switch the kernel's working directory to the repository checkout so that the
# project-local imports in the following cells (`models`, `utils`, ...) resolve.
# NOTE(review): this is an absolute, machine-specific path; the notebook will
# not run unmodified on another machine or checkout location.
%cd /home/luke-padmore/Source/flow-matching-mnist

/home/luke-padmore/Source/flow-matching-mnist


In [2]:
import torchvision
import mlflow
import math
import torch
import matplotlib.pyplot as plt 
import torch.nn as nn 
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from models.unet import UNet, CondUNet
from functools import partial
import os, sys
from models.ode_solvers import euler_solver, rk2_solver, make_vf_uncond,create_samples
from pathlib import Path 

# Enable MLflow's system-metrics monitor for runs started from this kernel.
os.environ["MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING"] = "true"

# The first cell already changed the working directory to the repository root,
# so the root is the cwd itself (its parent would be the directory *above* the
# checkout). Putting the absolute path on sys.path keeps project-local imports
# working even if the working directory changes later in the session.
PROJECT_ROOT = Path.cwd()
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

# Single source of truth for where the MNIST files are stored / downloaded.
DATA_DIR = PROJECT_ROOT / "data"

batch_size = 64
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ToTensor yields values in [0, 1]; padding 2 px per side takes 28x28 digits to
# 32x32; Normalize(0.5, 0.5) then maps the range to [-1, 1], so the zero
# padding ends up at -1.
transform = transforms.Compose(
    [transforms.ToTensor(),
    transforms.Pad(2,padding_mode='constant'),
    transforms.Normalize((0.5,), (0.5,))]
)

# Training split (downloaded into DATA_DIR on first use).
trainset = torchvision.datasets.MNIST(root=str(DATA_DIR),
                                      train=True,
                                      download=True,
                                      transform=transform)
trainloader = DataLoader(trainset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=4)

# Held-out split, same preprocessing.
# NOTE(review): shuffling is kept as in the original; use shuffle=False if a
# deterministic evaluation order is wanted.
valset = torchvision.datasets.MNIST(root=str(DATA_DIR),
                                      train=False,
                                      download=True,
                                      transform=transform)
valloader = DataLoader(valset,
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=4)

In [3]:
# MLflow run whose logged "UNet" model artifact is inspected in this notebook;
# the later cells read their parameters from the same run via `run_id`.
run_id = "4435f756670844fbbc1a89bff37cba34"
model = mlflow.pytorch.load_model(f"runs:/{run_id}/UNet")

# This notebook only runs inference: put the network on the same device the
# sampler is configured with and switch to eval mode so that any
# train-time-only layers (e.g. dropout) behave deterministically.
model = model.to(device).eval()
print(model)

  return FileStore(store_uri, store_uri)
  from .autonotebook import tqdm as notebook_tqdm
Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]


Downloading artifacts: 100%|██████████| 6/6 [00:00<00:00, 792.08it/s] 


UNet(
  (encoder): Encoder(
    (initial_conv): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): ConvDownblock(
        (conv): Sequential(
          (0): GroupNorm(8, 40, eps=1e-05, affine=True)
          (1): Conv2d(40, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (2): SiLU()
          (3): GroupNorm(8, 32, eps=1e-05, affine=True)
          (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (5): SiLU()
        )
        (down): Sequential(
          (0): GroupNorm(8, 32, eps=1e-05, affine=True)
          (1): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (2): SiLU()
        )
        (time_emb_mlp): Linear(in_features=8, out_features=8, bias=True)
      )
      (1): ConvDownblock(
        (conv): Sequential(
          (0): GroupNorm(8, 40, eps=1e-05, affine=True)
          (1): Conv2d(40, 64, kernel_size=(3, 3), 

In [4]:
from models.ode_solvers import get_ode_solver_from_name
from utils.mlflow_tracking_utils import get_run_param, parse_int_list
# Read the sampling settings stored as parameters on the run `run_id`.
ode_solver_name = get_run_param(run_id, "ode_solver")
ode_solver = get_ode_solver_from_name(ode_solver_name)
image_shape = parse_int_list(get_run_param(run_id, "image_shape"))
ode_steps = int(get_run_param(run_id, "ode_steps"))
# NOTE(review): this rebinds `batch_size` from the configuration cell; the data
# loaders built earlier keep the value they were created with.
batch_size = int(get_run_param(run_id, "batch_size"))

# Callable derived from the loaded model; handed to the sampler as `f`.
f = make_vf_uncond(model)

# Everything except the remaining arguments of `create_samples` is fixed here.
# seed=None is passed through unchanged (presumably unseeded sampling - confirm
# against create_samples).
sampler_kwargs = dict(
    image_shape=image_shape,
    ode_solver=ode_solver,
    f=f,
    n_steps=ode_steps,
    seed=None,
    device=device,
)
sample_fn = partial(create_samples, **sampler_kwargs)
sample_fn

functools.partial(<function create_samples at 0x70278121aac0>, image_shape=(1, 32, 32), ode_solver=<function euler_solver at 0x70278121a700>, f=<function make_vf_uncond.<locals>.f at 0x7027810931a0>, n_steps=50, seed=None, device=device(type='cuda', index=0))

In [5]:
# NOTE(review): the module name is spelled `evalutate_...`; it is kept as-is
# because it has to match the file on disk.
from evalutate_fid_uncond import run_eval

# Number of samples requested from the evaluation (presumably generated images
# for an FID estimate, judging by the module name - confirm in run_eval).
n_eval_samples = 64_000
run_eval(run_id, n_samples=n_eval_samples)

Downloading artifacts:   0%|          | 0/1 [00:00<?, ?it/s]
Downloading artifacts: 100%|██████████| 6/6 [00:00<00:00, 972.22it/s]  
2026/01/11 01:02:06 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
  return FileStore(store_uri)
2026/01/11 01:07:27 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2026/01/11 01:07:27 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
