In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('./../..')

import os

import torch
from torch.utils.data import Subset
from torchvision import datasets
from torchvision.transforms import ToTensor

from image_gen import GenerativeModel
from image_gen.samplers import ExponentialIntegrator
from image_gen.diffusion import VarianceExploding, VariancePreserving, SubVariancePreserving
from image_gen.noise import LinearNoiseSchedule

from image_gen.visualization import display_images

In [None]:
# Experiment configuration: training length, which MNIST digit to model,
# and the seed shared by training and sampling.
epochs = 20
digit = 3

seed = 42
# Seed torch's global RNG up front so model construction and training are
# reproducible on Restart & Run All (previously `seed` only reached `generate`).
# Trailing `;` hides the returned Generator's repr in the notebook.
torch.manual_seed(seed);

In [None]:
# Load the MNIST training split (stored under the relative `data` directory,
# downloaded there on first run).
mnist_train = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)

# Select only the configured digit to speed up the training process.
# The full dataset keeps its own name so `data` refers to one thing only.
indices_digit = torch.where(mnist_train.targets == digit)[0]
data = Subset(mnist_train, indices_digit.tolist())

# Fail fast on an invalid `digit` instead of training on an empty dataset.
assert len(data) > 0, f'no MNIST training images found for digit={digit}'

# **VE** — Variance Exploding diffusion

In [None]:
# Generative model with variance-exploding diffusion and the exponential
# integrator sampler. No `noise_schedule` is passed here, so whatever default
# GenerativeModel uses applies (NOTE(review): confirm the default).
model = GenerativeModel(
    diffusion=VarianceExploding,
    sampler=ExponentialIntegrator,
)

In [None]:
# Checkpoint path encodes digit, diffusion type and epoch count, so changing
# `digit` or `epochs` triggers a fresh training run instead of a stale load.
filename = f'mnist_{digit}_ve_{epochs}e.pth'

if os.path.isfile(filename):
    # NOTE(review): `seed` is not part of the filename, so a checkpoint trained
    # under a different seed is reused silently.
    model.load(filename)
else:
    # Re-seed right before training so the result does not depend on how much
    # randomness earlier cells consumed (presumably training draws from torch's
    # global RNG for noise and shuffling — confirm in GenerativeModel.train).
    torch.manual_seed(seed)
    model.train(data, epochs=epochs)
    # Tip: Save the models for them to be accessible through the dashboard
    model.save(filename)

In [None]:
# Draw a batch of images from the VE model using the shared `seed`,
# then hand them to display_images.
n_images = 16
samples = model.generate(
    n_images,
    seed=seed,
)
display_images(samples)

# **VP** — Variance Preserving diffusion (linear noise schedule)

In [None]:
# Generative model with variance-preserving diffusion, a linear noise
# schedule, and the exponential integrator sampler.
model = GenerativeModel(
    noise_schedule=LinearNoiseSchedule,
    diffusion=VariancePreserving,
    sampler=ExponentialIntegrator,
)

In [None]:
# Checkpoint path encodes digit, diffusion type/schedule and epoch count, so
# changing `digit` or `epochs` triggers a fresh training run instead of a stale load.
filename = f'mnist_{digit}_vp-lin_{epochs}e.pth'

if os.path.isfile(filename):
    # NOTE(review): `seed` is not part of the filename, so a checkpoint trained
    # under a different seed is reused silently.
    model.load(filename)
else:
    # Re-seed right before training so the result does not depend on whether
    # the previous model was loaded from cache or trained (presumably training
    # draws from torch's global RNG — confirm in GenerativeModel.train).
    torch.manual_seed(seed)
    model.train(data, epochs=epochs)
    # Tip: Save the models for them to be accessible through the dashboard
    model.save(filename)

In [None]:
# Draw a batch of images from the VP model using the shared `seed`,
# then hand them to display_images.
n_images = 16
samples = model.generate(
    n_images,
    seed=seed,
)
display_images(samples)

# **Sub-VP** — Sub-Variance Preserving diffusion (linear noise schedule)

In [None]:
# Generative model with sub-variance-preserving diffusion, a linear noise
# schedule, and the exponential integrator sampler.
model = GenerativeModel(
    noise_schedule=LinearNoiseSchedule,
    diffusion=SubVariancePreserving,
    sampler=ExponentialIntegrator,
)

In [None]:
# Checkpoint path encodes digit, diffusion type/schedule and epoch count, so
# changing `digit` or `epochs` triggers a fresh training run instead of a stale load.
filename = f'mnist_{digit}_svp-lin_{epochs}e.pth'

if os.path.isfile(filename):
    # NOTE(review): `seed` is not part of the filename, so a checkpoint trained
    # under a different seed is reused silently.
    model.load(filename)
else:
    # Re-seed right before training so the result does not depend on whether
    # earlier models were loaded from cache or trained (presumably training
    # draws from torch's global RNG — confirm in GenerativeModel.train).
    torch.manual_seed(seed)
    model.train(data, epochs=epochs)
    # Tip: Save the models for them to be accessible through the dashboard
    model.save(filename)

In [None]:
# Draw a batch of images from the Sub-VP model using the shared `seed`,
# then hand them to display_images.
n_images = 16
samples = model.generate(
    n_images,
    seed=seed,
)
display_images(samples)