In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

from score_models.models.unet import UNet
from score_models.trainer import trainer
from score_models.train_steps import TrainStepDenoisingScoreMatching
from score_models.utils.noise import get_sigmas


%load_ext autoreload
%autoreload 2

In [None]:
# Define transformations to be applied to the data
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert PIL image or numpy array to a float tensor in [0, 1]
    # Rescale [0, 1] -> [-1, 1]: (x - 0.5) / 0.5 == 2 * x - 1.
    # Normalize is used instead of a lambda because it is picklable, so this
    # transform keeps working if the DataLoader is given worker processes.
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Download and load the MNIST training dataset
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)


In [None]:
# Define parameters for the DataLoader
batch_size = 32
shuffle = True


def collate_images(batch):
    """Stack the images of a batch of ``(image, label)`` pairs, discarding the labels.

    Training here uses images only, so labels are dropped per batch. This avoids
    eagerly decoding and holding every dataset image in a Python list.

    Args:
        batch: list of ``(image_tensor, label)`` items as returned by the dataset.

    Returns:
        A single tensor stacking the images along a new leading batch dimension.
    """
    return torch.stack([image for image, _ in batch])


# Create a DataLoader over the MNIST training dataset that yields image-only batches.
# Images are loaded and transformed lazily as batches are drawn.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    collate_fn=collate_images,
)

In [None]:
# Noise-schedule / optimizer hyperparameters
L = 10  # passed to both UNet and get_sigmas; presumably the number of noise levels
lr = 5e-4

sigma_min = 0.01
sigma_max = 1.0

seed = 0
# Fall back to CPU so the notebook still runs on machines without a GPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG so model initialization (and, by default, DataLoader
# shuffling) is reproducible across Restart & Run All.
torch.manual_seed(seed)

# define score model and optimizer
score_model = UNet(L=L, n_channels=1, n_classes=1).to(device)
optimizer = optim.Adam(score_model.parameters(), lr=lr)

# define train step (i.e., criterion)
sigmas = get_sigmas(L=L, sigma_min=sigma_min, sigma_max=sigma_max)
train_step = TrainStepDenoisingScoreMatching(score_model=score_model, sigmas=sigmas)


In [None]:
# Train the score model with denoising score matching and rebind `score_model`
# to whatever `trainer` returns.
# NOTE(review): re-running this cell presumably continues from the current
# weights/optimizer state rather than starting fresh — re-run the setup cell first.
score_model = trainer(
    # what is optimized, and how
    model=score_model,
    optimizer=optimizer,
    train_step=train_step,
    # where the data comes from and where it runs
    train_loader=train_loader,
    device=device,
    # how long to train
    num_steps=10_000,
)