In [2]:
%load_ext autoreload
%autoreload 2
import os

import matplotlib.pyplot as plt
import torch
import torchsde
from torchdyn.core import NeuralODE
from torchvision import datasets, transforms
from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid
from tqdm import tqdm

from torchcfm.conditional_flow_matching import *
from torchcfm.models.unet import UNetModel

savedir = "models/mnist"
os.makedirs(savedir, exist_ok=True)

In [3]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

batch_size = 128
n_epochs = 5

trainset = datasets.MNIST(
    "../data",
    train=True,
    download=True,
    transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]),
)

train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, drop_last=True
)

In [4]:
sigma = 0.1

model = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)
score_model = UNetModel(dim=(1, 28, 28), num_channels=32, num_res_blocks=1).to(device)

optimizer = torch.optim.Adam(list(model.parameters()) + list(score_model.parameters()))
FM = ConditionalFlowMatcher(sigma=sigma)
node = NeuralODE(model, solver="euler", sensitivity="adjoint", atol=1e-4, rtol=1e-4)

In [5]:
for epoch in range(n_epochs):
    for i, data in tqdm(enumerate(train_loader)):
        optimizer.zero_grad()
        x1 = data[0].to(device)
        x0 = torch.randn_like(x1)
        t, xt, ut, eps = FM.sample_location_and_conditional_flow(x0, x1, return_noise=True)
        lambda_t = FM.compute_lambda(t)
        vt = model(t, xt)
        st = score_model(t, xt)

        flow_loss = torch.mean((vt - ut) ** 2)
        score_loss = torch.mean((lambda_t * st - eps) ** 2)

        loss = flow_loss + score_loss
        loss.backward()
        optimizer.step()

468it [02:35,  3.01it/s]
468it [02:41,  2.90it/s]
468it [02:42,  2.88it/s]
468it [02:43,  2.86it/s]
468it [02:43,  2.87it/s]


In [11]:
torch.save(score_model,'score_model_10epochs.pt')
torch.save(model, 'flow_model_10epochs.pt')

In [17]:
number_of_images = 100
initial_condition = torch.randn(number_of_images, 1, 28, 28).to(device)

In [18]:
#Noise = 0

for i, num_steps_solver in enumerate([1,10,100,1000]):

  os.system(f'mkdir cfm_{str(num_steps_solver)}_3')
  with torch.no_grad():
    traj = node.trajectory(
          initial_condition,
          t_span=torch.linspace(0, 1, num_steps_solver + 1).to(device),
      )
  for j in range(number_of_images):
    img = ToPILImage()(traj[-1, j].clip(-1, 1))
    filename = f'cfm_{str(num_steps_solver)}_3/{j}.png'
    img.save(filename, 'PNG')

OutOfMemoryError: CUDA out of memory. Tried to allocate 4.58 GiB. GPU 0 has a total capacty of 5.80 GiB of which 2.19 GiB is free. Including non-PyTorch memory, this process has 3.59 GiB memory in use. Of the allocated memory 3.01 GiB is allocated by PyTorch, and 420.22 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
class SDE(torch.nn.Module):
    noise_type = "diagonal"
    sde_type = "ito"

    def __init__(self, ode_drift, score, epsilon = 1, input_size=(1, 28, 28), reverse=False):
        super().__init__()
        self.drift = ode_drift
        self.score = score
        self.reverse = reverse
        self.epsilon = torch.tensor(epsilon)
    # Drift
    def f(self, t, y):
        y = y.view(-1, 1, 28, 28)
        if self.reverse:
            t = 1 - t
            return -self.drift(t, y) + self.score(t, y)
        return self.drift(t, y).flatten(start_dim=1) + self.epsilon * self.score(t, y).flatten(start_dim=1)

    # Diffusion
    def g(self, t, y):
        y = y.view(-1, 1, 28, 28)
        return torch.sqrt(self.epsilon) * (torch.ones_like(t) * torch.ones_like(y)).flatten(start_dim=1) * sigma

In [None]:
noise_coeff = 0.25

sde = SDE(model, score_model,epsilon = noise_coeff)
for i, num_steps_solver in enumerate([1,10,100,1000]):

  os.system(f'mkdir sde_{str(num_steps_solver)}_{str(noise_coeff)}')
  with torch.no_grad():
      sde_traj = torchsde.sdeint(
          sde,
          # x0.view(x0.size(0), -1),
          torch.flatten(initial_condition, start_dim = 1),
          ts=torch.linspace(0, 1, num_steps_solver + 1).to(device),
          dt=1/(num_steps_solver),
      )
  for j in range(number_of_images):
    img = ToPILImage()(sde_traj[-1, j].view([ 1, 28, 28]).clip(-1, 1))
    filename = f'sde_{str(num_steps_solver)}_{str(noise_coeff)}/{j}.png'
    img.save(filename)

In [None]:
noise_coeff = 0.5

generated_images_SDE_2 = torch.zeros([5, 100, 1, 28, 28])
sde = SDE(model, score_model,epsilon = noise_coeff)
for i, num_steps_solver in enumerate([1,10,100,1000]):

  os.system(f'mkdir sde_{str(num_steps_solver)}_{str(noise_coeff)}')
  with torch.no_grad():
      sde_traj = torchsde.sdeint(
          sde,
          # x0.view(x0.size(0), -1),
          torch.flatten(initial_condition, start_dim = 1),
          ts=torch.linspace(0, 1, num_steps_solver + 1).to(device),
          dt=1/(num_steps_solver),
      )
  for j in range(number_of_images):
    img = ToPILImage()(sde_traj[-1, j].view([ 1, 28, 28]).clip(-1, 1))
    filename = f'sde_{str(num_steps_solver)}_{str(noise_coeff)}/{j}.png'
    img.save(filename)

In [None]:
noise_coeff = 1

generated_images_SDE_3 = torch.zeros([5, 100, 1, 28, 28])
sde = SDE(model, score_model,epsilon = noise_coeff)
for i, num_steps_solver in enumerate([1,10,100,1000]):

  os.system(f'mkdir sde_{str(num_steps_solver)}_{str(noise_coeff)}')
  with torch.no_grad():
      sde_traj = torchsde.sdeint(
          sde,
          # x0.view(x0.size(0), -1),
          torch.flatten(initial_condition, start_dim = 1),
          ts=torch.linspace(0, 1, num_steps_solver + 1).to(device),
          dt=1/(num_steps_solver),
      )
  for j in range(number_of_images):
    img = ToPILImage()(sde_traj[-1, j].view([ 1, 28, 28]).clip(-1, 1))
    filename = f'sde_{str(num_steps_solver)}_{str(noise_coeff)}/{j}.png'
    img.save(filename)

In [None]:
noise_coeff = 2

sde = SDE(model, score_model,epsilon = noise_coeff)
for i, num_steps_solver in enumerate([1,10,100,1000]):

  os.system(f'mkdir sde_{str(num_steps_solver)}_{str(noise_coeff)}')
  with torch.no_grad():
      sde_traj = torchsde.sdeint(
          sde,
          # x0.view(x0.size(0), -1),
          torch.flatten(initial_condition, start_dim = 1),
          ts=torch.linspace(0, 1, num_steps_solver + 1).to(device),
          dt=1/(num_steps_solver),
      )
  for j in range(number_of_images):
    img = ToPILImage()(sde_traj[-1, j].view([ 1, 28, 28]).clip(-1, 1))
    filename = f'sde_{str(num_steps_solver)}_{str(noise_coeff)}/{j}.png'
    img.save(filename)

In [None]:
noise_coeff = 4

sde = SDE(model, score_model,epsilon = noise_coeff)
for i, num_steps_solver in enumerate([1,10,100,1000]):

  os.system(f'mkdir sde_{str(num_steps_solver)}_{str(noise_coeff)}')
  with torch.no_grad():
      sde_traj = torchsde.sdeint(
          sde,
          # x0.view(x0.size(0), -1),
          torch.flatten(initial_condition, start_dim = 1),
          ts=torch.linspace(0, 1, num_steps_solver + 1).to(device),
          dt=1/(num_steps_solver),
      )
  for j in range(number_of_images):
    img = ToPILImage()(sde_traj[-1, j].view([ 1, 28, 28]).clip(-1, 1))
    filename = f'sde_{str(num_steps_solver)}_{str(noise_coeff)}/{j}.png'
    img.save(filename)

In [None]:
####################################################
###Compute FID scores for the different noise levels
####################################################
import numpy as np
from PIL import Image

In [None]:
#Making a grid view 10x10 for all SDE samples for different steps and noise
for s in [1,10,100,1000]:
    for e in [0.25,0.5,1,2,4]:

        grid = np.zeros((28*10,28*10))
        for i in range(10):
            for j in range(10):
                grid[i*28:(i+1)*28,j*28:(j+1)*28] = np.asarray(Image.open(f'sde_{s}_{e}/{i*10+j}.png'))

        grid = Image.fromarray(grid.astype(np.uint8))
        grid.save(f'sde_{s}_{e}.png', 'PNG')

In [None]:
#Making a grid view 10x10 for all CFM samples for different steps
for s in [1,10,100,1000]:

    grid = np.zeros((28*10,28*10))
    for i in range(10):
        for j in range(10):
            grid[i*28:(i+1)*28,j*28:(j+1)*28] = np.asarray(Image.open(f'cfm_{s}/{i*10+j}.png'))

    grid = Image.fromarray(grid.astype(np.uint8))
    grid.save(f'cfm_{s}.png', 'PNG')

In [9]:
#FID for all SDE samples with different steps and noise. data is all mnist(60k) dataset
for s in [1,10,100,1000]:
    for e in [0.25,0.5,1,2,4]:
        print(f"SDE Steps={s} Noise={e} :")
        os.system(f"python3 -m pytorch_fid sde_{s}_{e} data")


SDE Steps=1 Noise=0.25 :


100%|██████████| 2/2 [00:00<00:00,  3.99it/s]
100%|██████████| 1200/1200 [02:19<00:00,  8.59it/s]


FID:  229.08294272178728
SDE Steps=1 Noise=0.5 :


100%|██████████| 2/2 [00:00<00:00,  4.08it/s]
100%|██████████| 1200/1200 [02:22<00:00,  8.39it/s]


FID:  232.59970415772216
SDE Steps=1 Noise=1 :


100%|██████████| 2/2 [00:00<00:00,  4.03it/s]
100%|██████████| 1200/1200 [02:25<00:00,  8.26it/s]


FID:  231.355337364466
SDE Steps=1 Noise=2 :


100%|██████████| 2/2 [00:00<00:00,  3.83it/s]
100%|██████████| 1200/1200 [02:25<00:00,  8.23it/s]


FID:  239.6984008614937
SDE Steps=1 Noise=4 :


100%|██████████| 2/2 [00:00<00:00,  4.14it/s]
100%|██████████| 1200/1200 [02:25<00:00,  8.23it/s]


FID:  249.34611093272713
SDE Steps=10 Noise=0.25 :


100%|██████████| 2/2 [00:00<00:00,  4.15it/s]
100%|██████████| 1200/1200 [02:26<00:00,  8.18it/s]


FID:  217.9432515094471
SDE Steps=10 Noise=0.5 :


100%|██████████| 2/2 [00:00<00:00,  4.01it/s]
100%|██████████| 1200/1200 [02:26<00:00,  8.19it/s]


FID:  220.78699357544397
SDE Steps=10 Noise=1 :


100%|██████████| 2/2 [00:00<00:00,  3.84it/s]
100%|██████████| 1200/1200 [02:27<00:00,  8.16it/s]


FID:  220.18953653128762
SDE Steps=10 Noise=2 :


100%|██████████| 2/2 [00:00<00:00,  4.05it/s]
100%|██████████| 1200/1200 [02:25<00:00,  8.22it/s]


FID:  219.00477173587524
SDE Steps=10 Noise=4 :


100%|██████████| 2/2 [00:00<00:00,  4.05it/s]
100%|██████████| 1200/1200 [02:25<00:00,  8.24it/s]


FID:  220.01367149273113
SDE Steps=100 Noise=0.25 :


100%|██████████| 2/2 [00:00<00:00,  3.96it/s]
100%|██████████| 1200/1200 [02:25<00:00,  8.24it/s]


FID:  217.72834674076523
SDE Steps=100 Noise=0.5 :


100%|██████████| 2/2 [00:00<00:00,  3.93it/s]
100%|██████████| 1200/1200 [02:27<00:00,  8.13it/s]


FID:  217.25170917289498
SDE Steps=100 Noise=1 :


100%|██████████| 2/2 [00:00<00:00,  3.87it/s]
100%|██████████| 1200/1200 [02:26<00:00,  8.20it/s]


FID:  216.10583703657002
SDE Steps=100 Noise=2 :


100%|██████████| 2/2 [00:00<00:00,  4.02it/s]
100%|██████████| 1200/1200 [02:26<00:00,  8.18it/s]


FID:  217.68637068665063
SDE Steps=100 Noise=4 :


100%|██████████| 2/2 [00:00<00:00,  3.83it/s]
100%|██████████| 1200/1200 [02:26<00:00,  8.21it/s]


FID:  224.02308072973932
SDE Steps=1000 Noise=0.25 :


100%|██████████| 2/2 [00:00<00:00,  4.01it/s]
100%|██████████| 1200/1200 [02:26<00:00,  8.19it/s]


FID:  214.9567448903651
SDE Steps=1000 Noise=0.5 :


100%|██████████| 2/2 [00:00<00:00,  3.93it/s]
100%|██████████| 1200/1200 [02:26<00:00,  8.22it/s]


FID:  216.31224074639096
SDE Steps=1000 Noise=1 :


100%|██████████| 2/2 [00:00<00:00,  3.81it/s]
100%|██████████| 1200/1200 [02:25<00:00,  8.24it/s]


FID:  217.28380054001673
SDE Steps=1000 Noise=2 :


100%|██████████| 2/2 [00:00<00:00,  3.99it/s]
100%|██████████| 1200/1200 [02:23<00:00,  8.34it/s]


FID:  219.0015017768356
SDE Steps=1000 Noise=4 :


100%|██████████| 2/2 [00:00<00:00,  3.96it/s]
100%|██████████| 1200/1200 [02:25<00:00,  8.26it/s]


FID:  226.0063324130836


In [10]:
#FID for all CFM samples with different steps. data is all mnist(60k) dataset
for s in [1,10,100,1000]:
        print(f"CFM Steps={s} :")
        os.system(f"python3 -m pytorch_fid cfm_{s} data")

CFM Steps=1 :


100%|██████████| 2/2 [00:00<00:00,  4.12it/s]
100%|██████████| 1200/1200 [02:19<00:00,  8.58it/s]


FID:  330.5893410417374
CFM Steps=10 :


100%|██████████| 2/2 [00:00<00:00,  4.04it/s]
100%|██████████| 1200/1200 [02:21<00:00,  8.46it/s]


FID:  210.1260097360542
CFM Steps=100 :


100%|██████████| 2/2 [00:00<00:00,  4.02it/s]
100%|██████████| 1200/1200 [02:24<00:00,  8.31it/s]


FID:  215.87934161050342
CFM Steps=1000 :


100%|██████████| 2/2 [00:00<00:00,  3.99it/s]
100%|██████████| 1200/1200 [02:25<00:00,  8.28it/s]


FID:  215.90738216737705


In [14]:
#FID for all CFM samples with different steps. data is all mnist(60k) dataset
for s in [1,10,100,1000]:
        print(f"CFM Steps={s} :")
        os.system(f"python3 -m pytorch_fid cfm_{s}_2 data")

CFM Steps=1 :


100%|██████████| 2/2 [00:00<00:00,  4.03it/s]
100%|██████████| 1200/1200 [02:20<00:00,  8.53it/s]


FID:  318.3557859058728
CFM Steps=10 :


100%|██████████| 2/2 [00:00<00:00,  3.87it/s]
100%|██████████| 1200/1200 [02:23<00:00,  8.34it/s]


FID:  198.68773875984942
CFM Steps=100 :


100%|██████████| 2/2 [00:00<00:00,  3.89it/s]
100%|██████████| 1200/1200 [02:24<00:00,  8.29it/s]


FID:  202.0917301702733
CFM Steps=1000 :


100%|██████████| 2/2 [00:00<00:00,  4.09it/s]
100%|██████████| 1200/1200 [02:25<00:00,  8.25it/s]


FID:  201.47718097833814
