### **MNIST - Super-Resolution**

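My own implementation of diffusion-based super-resolution on MNIST: a denoising diffusion model (discrete, DDPM-style variance-preserving noise schedule) that upscales 14x14 digits back to 28x28, conditioned on the low-resolution input.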

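Based on *Score-Based Generative Modeling through Stochastic Differential Equations* (Song et al., https://arxiv.org/abs/2011.13456) and the accompanying blog post https://yang-song.net/blog/2021/score/.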

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

import matplotlib.pyplot as plt
import tqdm
import math

In [None]:
# High-resolution target: the original 28x28 digit, rescaled from [0, 1] to [-1, 1].
transform_hr = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Low-resolution input: the same digit downsampled (antialiased) to 14x14,
# then rescaled from [0, 1] to [-1, 1].
transform_lr = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(size=(14, 14), antialias=True),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

class MNISTSuperResDataset(Dataset):
    """Pairs of (low-resolution, high-resolution) views of MNIST digits.

    Each item is built from one MNIST image by applying ``transform_lr`` and
    ``transform_hr`` to it independently; the digit label is discarded.

    Args:
        mode: ``'train'`` for the MNIST training split, ``'test'`` for the
            test split. Any other value raises ``ValueError``.
        transform_lr: Callable producing the low-resolution input, or ``None``
            to return the raw image unchanged.
        transform_hr: Callable producing the high-resolution target, or
            ``None`` to return the raw image unchanged.
        root: Directory where MNIST is stored / downloaded.
        download: Whether to download MNIST if it is not already in ``root``.
    """

    _MODES = ('train', 'test')

    def __init__(self, mode='train', transform_lr=None, transform_hr=None,
                 root='./data', download=True):
        # Validate before touching the filesystem; previously any typo
        # (e.g. 'val') silently fell through to the test split.
        if mode not in self._MODES:
            raise ValueError(f"mode must be one of {self._MODES}, got {mode!r}")
        self.mnist_data = MNIST(root=root, train=(mode == 'train'), download=download)
        self.transform_lr = transform_lr
        self.transform_hr = transform_hr

    def __len__(self):
        return len(self.mnist_data)

    def __getitem__(self, idx):
        """Return ``(img_lr, img_hr)`` for sample ``idx``."""
        image, _label = self.mnist_data[idx]

        # A missing transform means "pass the image through", rather than
        # crashing on calling None.
        img_lr = image if self.transform_lr is None else self.transform_lr(image)
        img_hr = image if self.transform_hr is None else self.transform_hr(image)

        return img_lr, img_hr

In [None]:
batch_size = 64

# Both splits share the same LR/HR transform pair; only the MNIST split differs.
train_dataset, test_dataset = (
    MNISTSuperResDataset(mode=split, transform_lr=transform_lr, transform_hr=transform_hr)
    for split in ('train', 'test')
)

# Shuffle only the training data so test batches come out in a fixed order.
train_loader, test_loader = (
    DataLoader(ds, batch_size=batch_size, shuffle=is_train)
    for ds, is_train in ((train_dataset, True), (test_dataset, False))
)

In [None]:
T = 1000  # Number of diffusion timesteps.
betas = torch.linspace(1e-4, 0.02, T)  # Linear variance schedule beta_0 .. beta_{T-1}.
alphas_cumprod = torch.cumprod(1 - betas, dim=0)  # alpha_bar_t = prod_{s<=t} (1 - beta_s)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)


def _schedule_coeff(coeffs, t, x):
    """Gather ``coeffs[t]`` and shape it so it broadcasts against ``x``.

    ``t`` may be a Python int, a 0-d tensor, or a 1-D tensor holding one
    timestep per batch element; in the 1-D case the result is reshaped to
    ``(B, 1, ..., 1)`` so it scales each sample along the batch axis. The
    lookup is done on the schedule's device and the result is moved to
    ``x``'s device.
    """
    t = torch.as_tensor(t, dtype=torch.long, device=coeffs.device)
    c = coeffs[t].to(x.device)
    if c.dim() == 1:
        c = c.view(-1, *([1] * (x.dim() - 1)))
    return c


def forward_pass(x_0, t, noise):
    """Sample ``x_t`` from ``q(x_t | x_0)`` in closed form.

    ``x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``

    Args:
        x_0: Clean images, batch-first.
        t: Timestep index in ``[0, T)``; an int / 0-d tensor applies the same
            step to every sample, a 1-D tensor of length B gives one per sample.
        noise: Tensor shaped like ``x_0``.

    Returns:
        The noised images, same shape as ``x_0``.
    """
    return (_schedule_coeff(sqrt_alphas_cumprod, t, x_0) * x_0
            + _schedule_coeff(sqrt_one_minus_alphas_cumprod, t, x_0) * noise)

In [None]:
class SinusoidalPositionEmbeddings(nn.Module):
    """Fixed sinusoidal embedding of timesteps.

    Maps a 1-D tensor of ``B`` timesteps to a ``(B, dim)`` tensor whose first
    ``dim // 2`` columns are ``sin(t * f_k)`` and next ``dim // 2`` columns are
    ``cos(t * f_k)``, with frequencies ``f_k`` spaced geometrically from 1 down
    to 1/10000. An odd ``dim`` gets one trailing zero column so the output
    width always equals ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):  # Construct the embedding of the current timestep.
        device = time.device
        half_dim = self.dim // 2
        # max(..., 1) avoids dividing by zero when half_dim == 1 (dim 2 or 3),
        # which previously produced NaN embeddings.
        scale = math.log(10000.0) / max(half_dim - 1, 1)
        freqs = torch.exp(
            torch.arange(half_dim, device=device, dtype=torch.float32) * -scale
        )
        args = time[:, None] * freqs[None, :]
        embeddings = torch.cat((args.sin(), args.cos()), dim=-1)
        if self.dim % 2:
            # Keep the output width equal to ``dim`` so downstream Linear(dim, .) fits.
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings

In [None]:
class _SRBlock(nn.Module):
    """One U-Net stage: two 3x3 convs with a timestep shift, then a 2x resample.

    A down block halves the spatial size with a stride-2 conv. An up block
    receives the concatenation of the incoming features and the matching skip
    connection (hence ``2 * in_ch`` input channels) and doubles the spatial
    size with a transposed conv.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, up=False):
        super().__init__()
        self.time_proj = nn.Linear(time_emb_dim, out_ch)
        if up:
            self.conv1 = nn.Conv2d(2 * in_ch, out_ch, 3, padding=1)
            self.resample = nn.ConvTranspose2d(out_ch, out_ch, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            self.resample = nn.Conv2d(out_ch, out_ch, 4, 2, 1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # gcd keeps the group count a divisor of out_ch for any channel config.
        groups = math.gcd(8, out_ch)
        self.norm1 = nn.GroupNorm(groups, out_ch)
        self.norm2 = nn.GroupNorm(groups, out_ch)

    def forward(self, x, t):
        h = self.norm1(F.relu(self.conv1(x)))
        # Broadcast the per-sample time embedding over every spatial position.
        h = h + F.relu(self.time_proj(t))[:, :, None, None]
        h = self.norm2(F.relu(self.conv2(h)))
        return self.resample(h)


class DiffusionSupResModel(nn.Module):
    """Noise-prediction U-Net for diffusion-based super-resolution.

    Given a noisy high-resolution image ``x`` at diffusion step ``timestep``
    and the low-resolution image ``y`` it must stay consistent with, predicts
    the noise contained in ``x``. ``y`` is bilinearly resized to ``x``'s
    spatial size and concatenated onto it as extra input channels.

    The spatial size of ``x`` must be divisible by ``2 ** (len(down_channels) - 1)``
    so that each stride-2 downsample is undone exactly by its upsample
    (28x28 works with the default three-level channel list).

    Args:
        image_channels: Channels of the noisy image ``x``.
        cond_channels: Channels of the low-resolution conditioning image ``y``.
        out_dim: Channels of the predicted noise.
        time_emb_dim: Width of the timestep embedding.
        down_channels: Feature widths of the encoder levels; the decoder
            mirrors them in reverse.
    """

    def __init__(self, image_channels=1, cond_channels=1, out_dim=1,
                 time_emb_dim=32, down_channels=(64, 128, 256)):
        super().__init__()
        down_channels = tuple(down_channels)
        up_channels = tuple(reversed(down_channels))

        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        # The noisy image and the resized LR condition enter together.
        self.conv0 = nn.Conv2d(image_channels + cond_channels, down_channels[0], 3, padding=1)

        # Downsampling path
        self.downs = nn.ModuleList([
            _SRBlock(c_in, c_out, time_emb_dim)
            for c_in, c_out in zip(down_channels, down_channels[1:])
        ])

        # Upsampling path
        self.ups = nn.ModuleList([
            _SRBlock(c_in, c_out, time_emb_dim, up=True)
            for c_in, c_out in zip(up_channels, up_channels[1:])
        ])

        self.output = nn.Conv2d(up_channels[-1], out_dim, 1)

    def forward(self, x, timestep, y):
        """Predict the noise in ``x``.

        Args:
            x: Noisy high-resolution images, ``(B, image_channels, H, W)``.
            timestep: 1-D tensor of ``B`` diffusion steps.
            y: Low-resolution conditioning images, ``(B, cond_channels, h, w)``.

        Returns:
            Tensor ``(B, out_dim, H, W)``.
        """
        cond = F.interpolate(y, size=x.shape[-2:], mode='bilinear', align_corners=False)
        h = self.conv0(torch.cat((x, cond), dim=1))
        t = self.time_mlp(timestep)

        residual_inputs = []
        for down in self.downs:
            h = down(h, t)
            residual_inputs.append(h)
        for up in self.ups:
            # Skip connection: pair each decoder stage with its encoder twin.
            h = up(torch.cat((h, residual_inputs.pop()), dim=1), t)

        return self.output(h)


In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# The model class defined in this notebook is DiffusionSupResModel;
# `MNISTSRModel` was never defined and raised NameError here.
model = DiffusionSupResModel().to(device)

criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

Using device: cpu


In [None]:
num_epochs = 10

# Noise-schedule coefficients on the training device, gathered per batch below.
sqrt_ac_dev = sqrt_alphas_cumprod.to(device)
sqrt_one_minus_ac_dev = sqrt_one_minus_alphas_cumprod.to(device)

for epoch in range(num_epochs):
  model.train()
  running_loss = 0.0

  for lr_images, hr_images in train_loader:
    lr_images = lr_images.to(device)
    hr_images = hr_images.to(device)
    n = hr_images.size(0)

    # Diffusion training step: noise the HR target at a random step per sample,
    # then ask the model (conditioned on the LR image) to recover that noise.
    t = torch.randint(0, T, (n,), device=device)
    noise = torch.randn_like(hr_images)
    x_t = (sqrt_ac_dev[t].view(-1, 1, 1, 1) * hr_images
           + sqrt_one_minus_ac_dev[t].view(-1, 1, 1, 1) * noise)

    optimizer.zero_grad()

    predicted_noise = model(x_t, t, lr_images)

    loss = criterion(predicted_noise, noise)
    loss.backward()
    optimizer.step()

    # Weight by batch size so the epoch average is per-sample even if the last batch is short.
    running_loss += loss.item() * n

  epoch_loss = running_loss / len(train_dataset)
  print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}")

print("Finished training!")

KeyboardInterrupt: 

In [None]:
# Switch the model to evaluation mode and pull the first batch from the test
# loader. test_loader is built with shuffle=False, so this is the same batch on
# every run. NOTE(review): this cell may continue beyond the visible lines.
model.eval()

with torch.no_grad():
  # Gradient tracking is disabled for everything evaluated in this block.
  dataiter = iter(test_loader)
  lr_images, hr_images = next(dataiter)
