In [None]:
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import matplotlib.pyplot as plt
from torch import nn
import torch

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Preprocessing pipeline: convert each image to a tensor, then normalize it
# with mean 0.5 and std 0.5 (one value because MNIST has a single channel).
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize([0.5], [0.5])
preprocess = transforms.Compose([to_tensor, normalize])

# MNIST training split (downloaded on first use), served in shuffled batches.
batch_size = 128
dataset = MNIST(root='', train=True, download=True, transform=preprocess)
loader_train = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Preview the first preprocessed training image together with its value range.
first_image = dataset[0][0]
plt.imshow(first_image.squeeze(), cmap='gray')
plt.colorbar()
plt.show()

In [None]:
from unet import UNet

# Smoke test of the bare UNet: push one image through it and look at the result.
model = UNet(in_channels=1, out_channels=1, block_out_channels=[64, 128],)

# Add a leading batch dimension so the single image looks like a batch of one.
sample_image = dataset[8][0].unsqueeze(0)
print("Input shape:", sample_image.shape)

# Run the forward pass only once and reuse the result for both the shape print
# and the plot. Gradients are not needed for this check, so autograd is
# disabled to avoid building a graph that is never used.
with torch.no_grad():
    unet_output = model(sample_image, torch.ones(1))

print("Output shape:", unet_output.shape)
plt.imshow(unet_output.squeeze().numpy())
plt.title("Model output")
plt.show()

In [None]:
from ddpm import DDPM
        
# Build the (still untrained) diffusion model and move it to the chosen device.
model = DDPM(in_channels=1, out_channels=1, block_out_channels=[64, 128], device=device)
model.to(device)
sample_image = dataset[9][0].unsqueeze(0).to(device)

# Visualize the forward (noising) process at ten evenly spaced timesteps.
timesteps = range(99, 1000, 100)  # 99, 199, ..., 999
fig, axes = plt.subplots(1, len(timesteps), figsize=(20, 2))
for ax, t in zip(axes, timesteps):
    # Calling the model returns a (noised image, noise) pair; only the noised
    # image is displayed here.
    noised_image, _added_noise = model(sample_image, t)
    ax.imshow(noised_image.squeeze().detach().cpu().numpy(), cmap='gray')
    ax.axis('off')

# Draw one sample with the untrained model. Gradients are never used here, so
# autograd is disabled in case `sampling` does not already do so itself
# (its implementation is not visible from this notebook).
with torch.no_grad():
    output = model.sampling(1000)
plt.figure()
plt.imshow(output.squeeze().detach().cpu().numpy(), cmap='gray')
plt.colorbar()
plt.title("Sampled image")
plt.show()

In [None]:
from random import randint
from tqdm import tqdm

def training(model, epochs, criterion, optimizer, loader_train, T = 1000):
    '''Train the diffusion model to predict the noise added to its inputs.

    model: model to be trained. It is called as ``model(images, t)`` and must
        return ``(x_noised, noise)``; its denoiser is reached through
        ``model.model`` and samples are drawn with ``model.sampling(T)``.
    epochs: number of epochs.
    criterion: loss function comparing the predicted noise with the true noise.
    optimizer: optimizer to be used.
    loader_train: training data loader yielding ``(images, labels)`` batches.
    T: number of diffusion timesteps; one timestep in ``[0, T-1]`` is drawn
        at random for every batch.

    Returns the trained model.
    Raises ValueError if ``loader_train`` yields no batches.'''
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        n_batches = 0
        # Wrap the loader itself (not an enumerate around it) so tqdm can read
        # its length and display a total / ETA.
        for images, _ in tqdm(loader_train):
            images = images.to(device)
            t = randint(0, T-1)
            # Call the module rather than .forward() so registered hooks run.
            x_noised, noise = model(images, t)

            optimizer.zero_grad()
            noise_pred = model.model(x_noised, t)
            loss = criterion(noise_pred, noise)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1

        if n_batches == 0:
            raise ValueError("loader_train yielded no batches; nothing to train on")

        # Report the mean loss over the epoch: the last batch alone is very
        # noisy because every batch uses a different random timestep.
        print(f"Epoch: {epoch}, Loss: {running_loss / n_batches}")
        if epoch % 10 == 0:
            # Sample in evaluation mode and without autograd, then switch back
            # to training mode for the following epochs.
            model.eval()
            with torch.no_grad():
                img = model.sampling(T)
            model.train()
            plt.imshow(img.squeeze().detach().cpu().numpy(), cmap='gray')
            plt.colorbar()
            plt.title("Sampled image")
            plt.show()
    return model

# Hyper-parameters of the training run.
epochs = 200
lr = 1e-4

# Mean-squared error between predicted and true noise, optimized with Adam.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

# Launch training; the trained model is bound back to `model`.
model = training(
    model=model,
    epochs=epochs,
    criterion=criterion,
    optimizer=optimizer,
    loader_train=loader_train,
)


In [None]:
# Persist the model's state dict (not the whole module object) to disk.
checkpoint_path = 'model.pth'
torch.save(model.state_dict(), checkpoint_path)