In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def imshow(img):
    """Show a (C, H, W) image tensor that was normalised to [-1, 1].

    The normalisation is undone with x / 2 + 0.5 (mapping back to [0, 1]),
    channels are moved last as matplotlib expects (H, W, C), and the figure is
    shown. ``img`` must support ``.numpy()`` (i.e. be a CPU tensor).
    """
    unnormalized = (img / 2 + 0.5).numpy()
    channels_last = np.transpose(unnormalized, (1, 2, 0))
    plt.imshow(channels_last)
    plt.show()

def imshow_grayscale(img):
    """Display a single-channel image tensor whose values are expected in [-1, 1].

    ``img`` must be a CPU tensor shaped (H, W), (1, H, W) or (H, W, 1); a
    singleton channel axis is dropped. Values are mapped to [0, 1] with
    (img + 1) / 2, clipped, and drawn on a fixed 0..1 gray scale.

    Fixing ``vmin``/``vmax`` matters: without it matplotlib stretches every
    image to its own min/max, which turns the rescale into a no-op and hides
    how far a sample is from the valid pixel range.
    """
    img = (img + 1) * 0.5
    npimg = img.numpy()
    # Drop a singleton channel axis, whether channels-first or channels-last.
    if npimg.ndim == 3 and npimg.shape[0] == 1:
        npimg = npimg[0]
    elif npimg.ndim == 3 and npimg.shape[-1] == 1:
        npimg = npimg[..., 0]
    plt.imshow(np.clip(npimg, 0.0, 1.0), cmap='gray', vmin=0.0, vmax=1.0)
    plt.show()

In [None]:
# hyperparameters
batch_size = 128
n_channels = 1
num_timesteps = 200

# Seed torch's RNGs so training noise, timesteps and sampling are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Prefer Apple-silicon MPS, then CUDA, and fall back to CPU so the notebook runs anywhere.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# CIFAR-10 style normalisation: maps each RGB channel from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# MNIST: zero-pad the 28x28 digits to 32x32 (the resolution the rest of the
# notebook assumes), convert to [0, 1], then rescale to [-1, 1]. Padding is done
# on the PIL image, so the border value 0 becomes -1 -- the same value as the
# MNIST background -- without mutating the dataset's raw `.data` tensor.
transformMNIST = transforms.Compose([
    transforms.Pad(2, fill=0),
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) - 1)
])

# cifar10_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

# automobile_subset = [i for i in range(len(cifar10_train)) if cifar10_train[i][1] == 1]
# automobile_subset = torch.utils.data.Subset(cifar10_train, automobile_subset)

mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transformMNIST)

# drop_last=True: the training loop only handles full batches of `batch_size`.
trainloader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)

# trainloader = torch.utils.data.DataLoader(automobile_subset, batch_size=batch_size, shuffle=True)

In [None]:

# Linear DDPM variance schedule: beta_1 .. beta_T stored at indices 0 .. T-1.
beta_schedule = torch.linspace(1e-4, 0.02, num_timesteps)
# alpha_t = 1 - beta_t, kept as a Python list of 0-d tensors.
alpha_schedule = list(1 - beta_schedule)


def get_alphas(num_timesteps):
    """Return the cumulative products alpha_bar as a 1-D tensor on ``device``.

    The tensor has ``num_timesteps + 1`` entries: entry 0 is 1 (no noise), and
    entry t is alpha_1 * ... * alpha_t. It can therefore be indexed directly by
    a timestep t in 1..num_timesteps.
    """
    factors = [1]
    factors.extend(alpha_schedule[step] for step in range(num_timesteps))
    return torch.tensor(factors).cumprod(dim=0).to(device)


alphas_cumprod = get_alphas(num_timesteps)

def get_noised_image_at(t, image):
    """Sample x_t ~ q(x_t | x_0) in closed form (DDPM forward process).

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)

    Args:
        t: timestep index into ``alphas_cumprod``. Either a Python int / 0-d
           tensor (the same t for every image) or a 1-D tensor with one t per
           batch element.
        image: clean images x_0, assumed shaped (B, C, H, W).

    Returns:
        Tuple ``(noised_image, noise)``. ``noise`` has exactly ``image``'s
        shape, dtype and device, so it can be used directly as the regression
        target.
    """
    # reshape(-1, 1, 1, 1) turns a scalar or per-sample alpha_bar into a
    # (B or 1, 1, 1, 1) tensor that broadcasts over C, H and W.
    alpha_bar = alphas_cumprod[t].reshape(-1, 1, 1, 1).to(image.device)
    # Draw noise shaped like the actual input (not a fixed batch size), so
    # partial batches and single images are handled correctly.
    noise = torch.randn_like(image)
    noised = torch.sqrt(alpha_bar) * image + torch.sqrt(1 - alpha_bar) * noise
    return noised, noise

In [None]:
def get_sinusoidal_embed(x, size):
    """Transformer-style sinusoidal embedding of timesteps.

    Args:
        x: 1-D tensor of timesteps, shape (B,).
        size: embedding width.

    Returns:
        Float tensor of shape (B, size) on ``x``'s device. It holds
        ``size // 2`` sine columns followed by ``size // 2`` cosine columns.
        Their frequencies are spaced geometrically from 1 down to 1/10000 when
        ``size >= 4``. For odd ``size`` a trailing zero column pads the result
        to exactly ``size`` columns.
    """
    half_size = size // 2
    # Exponents run from 0 to 1 across the half. When half_size == 1 (size 2 or 3),
    # half_size - 1 would be zero, so clamp the denominator to avoid NaNs.
    denom = max(half_size - 1, 1)
    exponents = torch.arange(half_size, device=x.device, dtype=torch.float32) / denom
    freqs = torch.pow(10000.0, -exponents)
    angles = x.float().unsqueeze(-1) * freqs.unsqueeze(0)
    emb = torch.cat((torch.sin(angles), torch.cos(angles)), dim=-1)
    if size % 2 == 1:
        emb = F.pad(emb, (0, 1))
    return emb

In [None]:
class WideResidualBlock(nn.Module):
    """Residual unit used throughout the U-Net.

    The main path is conv3x3 -> GroupNorm(4) -> ReLU -> conv3x3 -> GroupNorm(4) -> ReLU.
    The shortcut is a learned 1x1 convolution. It is always present, so
    ``channels_in`` may differ from ``channels_out``. The two paths are summed,
    and no activation follows the sum.

    GroupNorm uses 4 groups, so ``channels_out`` must be divisible by 4.
    Every layer is moved to the notebook-level ``device`` as it is created.
    """

    def __init__(self, channels_in, channels_out):
        super().__init__()

        def on_device(layer):
            return layer.to(device)

        # Registration order (conv1, conv2, gn1, gn2, skip_connection) fixes the
        # parameter order seen by parameters() / state_dict().
        self.conv1 = on_device(nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1))
        self.conv2 = on_device(nn.Conv2d(channels_out, channels_out, kernel_size=3, padding=1))
        # GroupNorm after each convolution, for training stability (as DDPM does).
        self.gn1 = on_device(nn.GroupNorm(4, channels_out))
        self.gn2 = on_device(nn.GroupNorm(4, channels_out))
        self.skip_connection = on_device(nn.Conv2d(channels_in, channels_out, kernel_size=1))

    def forward(self, x):
        shortcut = self.skip_connection(x)
        h = F.relu(self.gn1(self.conv1(x)))
        h = F.relu(self.gn2(self.conv2(h)))
        return shortcut + h
        
class DownsampleBlock(nn.Module):
    """Encoder stage: optional time conditioning, two residual blocks, then 2x2 max-pool.

    A sinusoidal embedding of ``t`` (width ``in_channels``) is passed through a
    small MLP and added per channel to the input before the residual blocks.
    That conditioning step is skipped when ``in_channels`` is odd (e.g. the
    single-channel input stage), so such a stage receives no time signal.
    The max-pool halves H and W.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.wide_residual_block1 = WideResidualBlock(in_channels, out_channels).to(device)
        self.wide_residual_block2 = WideResidualBlock(out_channels, out_channels).to(device)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2).to(device)
        hidden = in_channels * 3
        self.time_mlp = nn.Sequential(
            nn.Linear(in_channels, hidden),
            nn.ReLU(),
            nn.Linear(hidden, in_channels),
        ).to(device)

    def forward(self, x, t):
        if self.in_channels % 2 == 0:
            time_bias = self.time_mlp(get_sinusoidal_embed(t, self.in_channels))
            # (B, C) -> (B, C, 1, 1) so the bias broadcasts over H and W.
            x = x + time_bias[:, :, None, None]
        for stage in (self.wide_residual_block1, self.wide_residual_block2, self.max_pool):
            x = stage(x)
        return x

class UpsampleBlock(nn.Module):
    """Decoder stage: 2x upsample, add skip connection, add time embedding, two residual blocks.

    ``x`` is upsampled from ``in_channels`` to ``out_channels`` with a transposed
    convolution that exactly doubles H and W. ``skip_conn`` is then added;
    presumably it has shape (B, out_channels, 2H, 2W) -- confirm against the
    caller. Next, an MLP-projected sinusoidal embedding of ``t`` (width
    ``out_channels``) is added per channel before the two residual blocks.
    """

    def __init__(self, in_channels, out_channels):
        super(UpsampleBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        # kernel 3, stride 2, padding 1, output_padding 1 -> output is exactly 2H x 2W.
        self.upsample = nn.ConvTranspose2d(in_channels, out_channels, 3, stride=2, padding=1, output_padding=1).to(device)
        self.wide_residual_block1 = WideResidualBlock(self.out_channels, self.out_channels).to(device)
        self.wide_residual_block2 = WideResidualBlock(self.out_channels, self.out_channels).to(device)

        # No activation after the final Linear, matching DownsampleBlock. A trailing
        # ReLU would clamp the time bias to >= 0, so the timestep could only shift
        # features upwards.
        self.time_mlp = nn.Sequential(
            nn.Linear(self.out_channels, self.out_channels * 3),
            nn.ReLU(),
            nn.Linear(self.out_channels * 3, self.out_channels),
        ).to(device)

    def forward(self, x, skip_conn, t):
        x = self.upsample(x) + skip_conn
        # (B, C) time bias -> (B, C, 1, 1) so it broadcasts over H and W.
        x = x + self.time_mlp(get_sinusoidal_embed(t, self.out_channels))[:, :, None, None]
        x = self.wide_residual_block1(x)
        x = self.wide_residual_block2(x)
        return x

class UNet(nn.Module):
    """Noise-prediction U-Net taking ``n_channels`` images whose H and W are divisible by 16.

    The notebook uses 32x32 inputs.

    - Encoder: four DownsampleBlocks (n_channels -> 64 -> 128 -> 256 -> 512),
      each halving H and W.
    - Bottleneck: two WideResidualBlocks (512 -> 1024 -> 512).
    - Decoder: three UpsampleBlocks (512 -> 256 -> 128 -> 64), then a transposed
      conv back to ``n_channels`` at full resolution, followed by three 3x3
      convolutions with no activation between them.
    - Skips: the input to each DownsampleBlock is saved and added back in reverse
      order after the matching upsampling step. The last skip is the network
      input itself.
    """

    def __init__(self):
        super(UNet, self).__init__()
        # nn.ModuleList (not a plain list) registers the stages as child modules,
        # so their parameters are returned by unet.parameters() (and thus trained
        # by the optimizer). They also follow unet.to(...) and train()/eval(), and
        # appear in state_dict().
        self.downsamplers = nn.ModuleList([
            DownsampleBlock(n_channels, 64),
            DownsampleBlock(64, 128),
            DownsampleBlock(128, 256),
            DownsampleBlock(256, 512),
        ])
        self.upsamplers = nn.ModuleList([
            UpsampleBlock(512, 256),
            UpsampleBlock(256, 128),
            UpsampleBlock(128, 64),
        ])
        scaling_factor = 2
        self.latent_resnet1 = WideResidualBlock(512, 512 * scaling_factor).to(device)
        self.latent_resnet2 = WideResidualBlock(512 * scaling_factor, 512).to(device)

        self.final_upsampler = nn.ConvTranspose2d(64, n_channels, 3, stride=2, padding=1, output_padding=1)
        self.final_conv = nn.Sequential(
            nn.Conv2d(n_channels, n_channels, 3, padding=1),
            nn.Conv2d(n_channels, n_channels, 3, padding=1),
            nn.Conv2d(n_channels, n_channels, 3, padding=1),
        )

    def forward(self, x, t):
        skips = []
        for down in self.downsamplers:
            skips.append(x)
            x = down(x, t)

        x = self.latent_resnet2(self.latent_resnet1(x))

        for up in self.upsamplers:
            x = up(x, skips.pop(), t)

        # The remaining skip is the original network input.
        x = self.final_upsampler(x) + skips.pop()
        return self.final_conv(x)

    

In [None]:
unet = UNet()
unet.to(device)
unet.train()

num_epochs = 5
loss = nn.SmoothL1Loss()
optimizer = torch.optim.Adam(unet.parameters(), lr=1e-4)

for i in range(num_epochs):
    for j, (batch, _) in enumerate(trainloader):
        images = batch.to(device)
        # time_vals below is sized by batch_size, so skip any trailing partial batch.
        if images.shape[0] != batch_size:
            continue
        optimizer.zero_grad()
        # Sample t uniformly from {1, ..., num_timesteps}. torch.randint's upper bound
        # is exclusive, so it must be num_timesteps + 1; otherwise t = num_timesteps,
        # where sampling starts, is never seen during training.
        time_vals = torch.randint(1, num_timesteps + 1, (batch_size,), device=device)
        # Forward-diffuse the clean images, then regress the model onto the added noise.
        batch_, noise = get_noised_image_at(time_vals, images)
        predicted_noise = unet(batch_, time_vals)
        loss_val = loss(predicted_noise, noise)
        if j % 20 == 0:
            print(f"epoch {i} step {j}: loss {loss_val.item():.4f}")
        loss_val.backward()
        optimizer.step()

In [None]:
# Per-step alpha_t, 1-indexed to match alphas_cumprod (alphas[0] is a placeholder 1).
alphas = [1] + [alpha_schedule[t] for t in range(num_timesteps)]

# Reverse diffusion (DDPM, Ho et al. 2020, Algorithm 2) starting from pure noise.
with torch.no_grad():
    unet.eval()
    # Holds x_t during the loop and ends as the generated sample x_0.
    x_0 = torch.randn(1, n_channels, 32, 32, device=device)
    for t in range(num_timesteps, 0, -1):
        alpha_t = float(alphas[t])
        predicted_noise = unet(x_0, torch.tensor([t], device=device))
        factor = (1 - alpha_t) / torch.sqrt(1 - alphas_cumprod[t])
        x_0 = (x_0 - factor * predicted_noise) / alpha_t ** 0.5
        # sigma_t * z is added only for t > 1; the final step returns the posterior mean.
        if t > 1:
            x_0 = x_0 + torch.randn_like(x_0) * (1 - alpha_t) ** 0.5

        if t % 10 == 0 or t == 1:
            # [0, 0] -> (H, W); `.T` on a 3-D tensor would also swap H and W.
            imshow_grayscale(x_0[0, 0].cpu())

In [None]:
imshow_grayscale(x_0[0, 0].cpu())  # first image, first channel -> (H, W), not transposed


In [None]:
# Sanity check: show the first padded, [-1, 1]-scaled training digit.
test_image = mnist_train[0][0].unsqueeze(0)
# [0, 0] -> (H, W); `.T` on a 3-D tensor would also swap H and W (and is deprecated).
imshow_grayscale(test_image[0, 0].cpu())
# imshow_grayscale(get_noised_image_at(num_timesteps, test_image.to(device))[0][0, 0].cpu())
