In [None]:
import torch, torchvision, torch.nn as nn, torch.nn.functional as F
import numpy as np
from tqdm.notebook import tqdm, trange
import matplotlib.pyplot as plt
from IPython import display
import copy

#########
# data: MNIST, standardize to ~zero mean/unit variance, and pad to 32x32 pixels
#########
# Preprocessing steps, applied in this order: convert to a tensor, pad 2 pixels
# on every side, then standardize with a fixed mean/std.
# NOTE(review): the mean/std constants were presumably measured on the padded
# images -- confirm before changing the padding.
_preprocessing_steps = [
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Pad(2),
    torchvision.transforms.Normalize(0.1003, 0.2756),
]
tform = torchvision.transforms.Compose(_preprocessing_steps)

# MNIST is downloaded into the current directory if it is not already there.
dataset = torchvision.datasets.MNIST(root='.', transform=tform, download=True)

# Shuffled mini-batches of 128 images.
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=128)

#########
# model: UNet with noise conditioning
#########
class NoiseConditionalBlock(nn.Module):
    """Residual block of two (GroupNorm -> ELU -> 3x3 conv) stages with a
    FiLM-style modulation in between that is driven by an embedding vector.

    Args:
        fts: number of channels; the block keeps the channel count unchanged.
            GroupNorm uses 8 groups.
        embed_dim: size of the conditioning embedding fed to the affine layer.
    """

    def __init__(self, fts, embed_dim=128):
        super().__init__()

        def norm_act_conv():
            # One pre-activation stage: normalize, activate, then convolve.
            return nn.Sequential(
                nn.GroupNorm(8, fts),
                nn.ELU(),
                nn.Conv2d(fts, fts, kernel_size=3, padding=1),
            )

        self.normactconv1 = norm_act_conv()
        self.normactconv2 = norm_act_conv()
        # FiLM-like affine layer (https://arxiv.org/pdf/1709.07871.pdf):
        # maps the embedding to one (scale, shift) pair per channel.
        self.affine = nn.Linear(embed_dim, 2 * fts)

    def forward(self, x, emb):
        """Return ``x`` plus the modulated two-stage transformation of ``x``.

        ``emb`` is indexed as a 2-D tensor whose last dimension is
        ``embed_dim``; its projection is broadcast over the spatial dims.
        """
        hidden = self.normactconv1(x)

        # Project the embedding, add two trailing singleton (spatial) dims and
        # split along the channel axis into the scale and shift halves.
        film = self.affine(emb)[:, :, None, None]
        scale, shift = torch.chunk(film, 2, 1)

        # Embedding-dependent rescaling of the feature maps.
        hidden = hidden * (1 + scale) + shift

        return self.normactconv2(hidden) + x

class NoiseConditionalUNet(nn.Module):
    """UNet built from NoiseConditionalBlocks and conditioned on a noise level.

    The encoder halves the spatial resolution between consecutive entries of
    ``features`` (2x2 stride-2 convolutions); the decoder mirrors it with
    stride-2 transposed convolutions and additive skip connections.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the output.
        embed_dim: size of the sinusoidal noise-level embedding.
        features: channel width of each resolution level, finest first. Any
            non-empty sequence works; the input height/width must be divisible
            by ``2 ** (len(features) - 1)`` so the skip additions line up.

    Raises:
        ValueError: if ``features`` is empty.
    """

    def __init__(self, in_channels=3, out_channels=3, embed_dim=128, features=(32,64,128,256,256)):
        super().__init__()
        # Tuple default avoids a shared mutable default; any sequence is accepted.
        features = list(features)
        if not features:
            raise ValueError('features must contain at least one channel width')

        # NOTE: attribute names and construction order are unchanged so that
        # state_dict keys and seeded initialization stay the same.
        # Fixed (non-trainable) random frequencies for the noise-level embedding.
        self.emb_weight  = nn.Parameter(2 * np.pi * np.sqrt(embed_dim) * torch.rand(embed_dim), requires_grad=False)
        self.in_conv     = nn.Conv2d(in_channels, features[0], kernel_size=3, padding=1)
        self.out_conv    = nn.Conv2d(features[0], out_channels, kernel_size=3, padding=1)

        self.down_blocks = nn.ModuleList([NoiseConditionalBlock(fts, embed_dim) for fts in features])
        # One up block per level is created, but forward() only uses the first
        # len(features)-1 of them (the coarsest level has no decoder block).
        self.up_blocks   = nn.ModuleList([NoiseConditionalBlock(fts, embed_dim) for fts in features])
        self.downsamples = nn.ModuleList([nn.Conv2d(f_in, f_out, kernel_size=2, stride=2) for f_in, f_out in zip(features[:-1],features[1:])])
        self.upsamples   = nn.ModuleList([nn.ConvTranspose2d(f_out, f_in, kernel_size=2, stride=2) for f_in, f_out in zip(features[:-1],features[1:])])

    def forward(self, x, sigmas):
        """Run the network on ``x`` at noise level(s) ``sigmas``.

        Args:
            x: image batch of shape (batch, in_channels, H, W).
            sigmas: noise level(s): a Python/NumPy number, a sequence of
                numbers, or a tensor. It is flattened to 1-D and must hold
                either a single value (shared by the batch) or one value per
                batch element.

        Returns:
            Tensor of shape (batch, out_channels, H, W).
        """
        # Convert the noise level to a tensor if necessary. isinstance() also
        # accepts Tensor subclasses such as nn.Parameter.
        if not isinstance(sigmas, torch.Tensor):
            sigmas = torch.tensor(sigmas).float()
        # Flatten so scalars, 0-dim tensors and (batch,1,1,1) tensors all work.
        sigmas = sigmas.to(x.device).reshape(-1)

        # Map the noise level (squashed into [0, 1)) to a higher-dim embedding
        # by multiplying with the fixed random frequencies and taking the sine.
        sigma_emb = ((sigmas/(1+sigmas**2).sqrt())[:,None] * self.emb_weight).sin()

        # Rescale the inputs by 1/sqrt(1 + sigma^2).
        x = x/(1+sigmas.view(-1,1,1,1)**2).sqrt()

        # Encoder: remember the output of every level except the coarsest.
        x = self.down_blocks[0](self.in_conv(x), sigma_emb)
        skips = []
        for level, downsample in enumerate(self.downsamples):
            skips.append(x)
            x = self.down_blocks[level + 1](downsample(x), sigma_emb)

        # Decoder: upsample, add the matching skip, refine; coarsest level first.
        for level in reversed(range(len(self.upsamples))):
            x = self.up_blocks[level](skips[level] + self.upsamples[level](x), sigma_emb)

        return self.out_conv(x)

#########
# generation: we diffuse with an exponentially decaying noise level (details in readme.md)
#########
# in short, we start with random noise then iterate:
# 1. subtract some fraction of predicted noise (alpha sigma eps_hat)
# 2. reinject some new amount of noise (beta sigma z)
# 3. this reduces the noise level by a factor of sqrt((1-alpha)^2 + beta^2) (see https://arxiv.org/abs/2007.13640)
# We stop once the noise level has decayed below a threshold sigma_min
@torch.no_grad()
def generate_samples(model, sigma=30.0, sigma_min=0.03, alpha=0.1, beta=.40, shape=(64,1,32,32), device=None):
    """Generate samples by iterative denoising with a geometrically decaying noise level.

    Starting from ``sigma * N(0, I)``, each step subtracts ``alpha * sigma``
    times the model's noise prediction, re-injects ``beta * sigma`` times fresh
    Gaussian noise, and multiplies ``sigma`` by ``sqrt((1-alpha)^2 + beta^2)``.
    Iteration stops once ``sigma <= sigma_min``.

    Args:
        model: callable invoked as ``model(x, sigma)`` with ``sigma`` a float;
            must return a tensor broadcastable to ``x``.
        sigma: initial noise level.
        sigma_min: noise level at which to stop.
        alpha: fraction of the predicted noise removed per step.
        beta: relative amount of fresh noise injected per step.
        shape: shape of the generated batch.
        device: device to sample on. If ``None`` (default), the device of the
            model's first parameter is used; for a model without parameters it
            falls back to CUDA when available, else CPU.

    Returns:
        CPU tensor of shape ``(nsteps + 1, *shape)`` holding the initial noise
        followed by the state after every step.

    Raises:
        ValueError: if the schedule does not decay (``(1-alpha)^2 + beta^2 >= 1``)
            while ``sigma > sigma_min``, which would otherwise loop forever.
    """
    # The per-step decay factor is loop-invariant; float() keeps sigma a plain
    # Python float instead of turning it into a NumPy scalar.
    decay = float(np.sqrt((1 - alpha) ** 2 + beta ** 2))
    if sigma > sigma_min and not (decay < 1.0):
        raise ValueError(
            f'noise level does not decay: sqrt((1-alpha)^2 + beta^2) = {decay} must be < 1')

    if device is None:
        # Sample where the model lives so CPU-only setups work out of the box.
        params = model.parameters() if hasattr(model, 'parameters') else ()
        first_param = next(iter(params), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

    x = sigma * torch.randn(shape, device=device)
    xs = [x.cpu()]
    while sigma > sigma_min:
        eps_hat = model(x, sigma)
        x = x - alpha * sigma * eps_hat + beta * sigma * torch.randn_like(x)
        sigma = sigma * decay # noise decays exponentially
        xs.append(x.cpu()) # save all intermediate generations

    return torch.stack(xs) # (nsteps, batch, channels, height, width)

#########
# visualization utils
#########
def show_grid(x, title=''):
    """Plot a batch of images as one grid (16 per row), min-max scaled to [0, 1].

    Args:
        x: batch of images accepted by ``torchvision.utils.make_grid``.
        title: text shown above the figure.
    """
    plt.figure(figsize=(12,3), dpi=100)
    # make_grid returns (channels, H, W); imshow wants (H, W, channels).
    img = torchvision.utils.make_grid(x, nrow=16).permute(1,2,0).cpu()
    lo = img.min()
    span = img.max() - lo
    # A constant image has zero range; dividing would give 0/0 = NaN, so show
    # it as all zeros instead. NaN inputs still propagate as before.
    img = torch.zeros_like(img) if span == 0 else (img - lo) / span
    plt.imshow(img)
    plt.title(title); plt.xticks([]); plt.yticks([])
    plt.show()

#########
# training: we just train the UNet to denoise images with multiple levels of added noise
#########
losses = []       # per-iteration training loss, across all epochs
n_epochs = 50
ema_decay = 0.999 # weight kept from the previous moving average at every step
sigma_min, sigma_max = 0.03, 100.0  # lowest and highest noise-to-signal ratio we train the network on
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = NoiseConditionalUNet(1,1).to(device)
model_ema = copy.deepcopy(model) # keep an exponentially weighted moving average of the weights
optim = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(n_epochs):
    # Sampling at the end of each epoch switches the model to eval mode;
    # switch back before training again.
    model.train()
    for x, _ in (pbar:=tqdm(dataloader, desc=f'epoch {epoch}/{n_epochs}')):
        optim.zero_grad()
        x = x.to(device)

        # noisify input
        #   noise levels are chosen from a log-uniform distribution between sigma_min and sigma_max
        #   noise levels range from "looks uncorrupted" to "looks like pure noise"
        #   this distribution can be modified for improved generation performance
        sigmas = sigma_max * (sigma_min/sigma_max) ** torch.rand(x.shape[0],device=device)
        epsilon = torch.randn_like(x) # already on x's device
        x_noisy = x + sigmas.view(-1,1,1,1) * epsilon

        # noise prediction loss: mean squared error between noise and network output
        loss = ((epsilon - model(x_noisy, sigmas))**2).mean()
        loss.backward()
        optim.step()

        # exponential moving average of the weights
        # can be surprisingly helpful for stabilizing generations:
        # even when the loss curves look normal, generations can randomly look terrible.
        for ema_v, model_v in zip(model_ema.state_dict().values(), model.state_dict().values()):
            ema_v.copy_(ema_decay * ema_v + (1.0 - ema_decay) * model_v)

        # logging
        losses.append(loss.item())
        pbar.set_postfix({'loss': np.mean(losses[-100:])}) # average of last 100 losses

    # generate and visualize images every epoch
    # (device is passed explicitly so sampling also works on CPU-only machines)
    display.clear_output()
    xs = generate_samples(model.eval(), sigma=sigma_max, sigma_min=sigma_min, alpha=0.1, beta=0.35, device=device)
    show_grid(xs[-1], f'epoch={epoch} \n generated images')

    xs = generate_samples(model_ema.eval(), sigma=sigma_max, sigma_min=sigma_min, alpha=0.1, beta=0.35, device=device)
    show_grid(xs[-1], 'generated images from exponential moving average network')

    # every 16th step of the first sample's trajectory, each frame scaled to unit std
    show_grid(xs[::16,0]/xs[::16,0].std(dim=(1,2,3),keepdim=True), f'one generation trajectory (total sampling steps = {len(xs)})')
    # x, x_noisy and sigmas still hold the last training batch of this epoch
    show_grid(x[:32], 'real images')
    show_grid(x_noisy[:32] / (1+sigmas[:32].view(-1,1,1,1)**2).sqrt(), 'noisy images we train on')
    plt.plot(losses)
    plt.yscale('log'); plt.xlabel('iter'); plt.title('training loss')
    plt.show()