This implementation would not have been possible without the article provided by Hugging Face and the DDPM paper written by Ho et al. (2020); I referred to both frequently while implementing the model and fixing bugs in the code.


In [None]:
from torch.utils.data import DataLoader
from torchvision import transforms
import math
import numpy as np
import time
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torch.distributions as tdist

In [None]:
# Parameter code block: global settings read by the cells below
image_dimensions = 32 # get 32 x 32, 64 x 64 or 128 x 128 images; NOTE(review): the visible code resizes with IMG_SIZE, not this value
batch_size_param = 32 # We will start with batch size 32 and iteratively reduce it depending on the limits of NCC and Google Colab; NOTE(review): the visible loaders use BATCH_SIZE instead
n_channels = 3 # An RGB image has 3 channels and a greyscale image has 1
dataset_list = ['stl10', 'cifar10', 'ffhq']
dataset = dataset_list[0] # for now we use the stl10 dataset; we may fall back to cifar10 if this does not work, or upgrade to ffhq if there is time
dd = None # sentinel checked before the DataLoader is built; presumably a hook for a pre-built dataset -- confirm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Use the GPU when available, otherwise the CPU (realistically only a GPU is practical, as CPU training takes too long)
IMG_SIZE = 64 # side length in pixels that images are resized to
BATCH_SIZE = 32 # number of samples per DataLoader batch


In [None]:
"""Dataloader implementation based on code provided for the assingment by Amir and Chris"""

from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset

# Shared preprocessing: resize, random horizontal flip, convert to a tensor,
# then rescale pixel values from [0, 1] to [-1, 1].
data_transforms = [transforms.Resize((IMG_SIZE,IMG_SIZE)),
#                    transforms.Grayscale(num_output_channels=1), -> if we want to use grayscale images on higher dimensionality data and the training is taking too long we will enable this option
                   transforms.RandomHorizontalFlip(),
                   transforms.ToTensor(),
                   transforms.Lambda(lambda t: (t * 2) - 1)]

data_transform = transforms.Compose(data_transforms)

# Every branch below must leave `transformed_dataset` defined: the cell that
# builds the training DataLoader reads it.
if dataset == 'cifar10':
    train = torchvision.datasets.CIFAR10('drive/My Drive/training/cifar10', download=True, transform=data_transform)
    test = torchvision.datasets.CIFAR10('drive/My Drive/training/cifar10', train=False, download=True, transform=data_transform)
    # Labels are not used for training, so both splits are merged into one dataset.
    transformed_dataset = torch.utils.data.ConcatDataset([train, test])

# stl10 has larger images which are much slower to train on. You should develop your method with CIFAR-10 before experimenting with STL-10
elif dataset == 'stl10':
    train = torchvision.datasets.STL10('drive/My Drive/training/stl10', transform=data_transform, split='train')
    test = torchvision.datasets.STL10('drive/My Drive/training/stl10', transform=data_transform, split='test')
    transformed_dataset = torch.utils.data.ConcatDataset([train, test])

elif dataset == "ffhq":

    # FFHQ thumbnails are trained at a higher resolution with a smaller batch.
    IMG_SIZE = 128
    BATCH_SIZE = 16

    # Same preprocessing as above, rebuilt because IMG_SIZE has just changed.
    transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: (t * 2) - 1)
        ])

    # Keep the `dataset` selector string intact (the original rebound it to the
    # ImageFolder, which broke any re-run of this cell).
    ffhq_folder = datasets.ImageFolder(root="./images/thumbnails128x128", transform=transform)

    # One pass over the folder to cache every decoded image in memory.
    # No shuffling and no drop_last here, so that no image is silently discarded.
    # NOTE(review): caching applies the random flip once per image, so the
    # augmentation is frozen for the rest of training -- confirm this is intended.
    caching_loader = DataLoader(ffhq_folder, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

    images_list = []
    labels_list = []
    for images, labels in caching_loader:
        images_list.append(images)
        labels_list.append(labels)

    # Concatenate the per-batch tensors along the batch dimension
    images_tensor = torch.cat(images_list, dim=0)
    labels_tensor = torch.cat(labels_list, dim=0)

    # Wrap the cached tensors and expose them under the same names as the other branches
    data = TensorDataset(images_tensor, labels_tensor)
    transformed_dataset = data
    dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

In [None]:
""""Below mathematical python implementation based on DDPM paper and help by referencing huggingface implentation to understand how to perform certain calculation in python e.g. cumulative production"""

def forward_diffusion(x_0, t, device='cpu'):
    """Noise a clean batch to timestep ``t`` in a single closed-form step.

    Computes ``sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps`` with
    ``eps`` drawn from a standard normal, using the precomputed module-level
    schedule tensors.

    Args:
        x_0: Clean image tensor.
        t: Integer tensor of timestep indices (one per batch entry).
        device: Device on which the result is assembled and returned.

    Returns:
        Tuple ``(x_t, eps)``: the noised images and the noise that was added,
        both on ``device``.
    """
    eps = torch.randn_like(x_0).to(device)

    # Per-sample coefficients, reshaped by get_index so they broadcast over x_0.
    signal_scale = get_index(cumulative_sqrt_alphas, t, x_0.shape).to(device)
    noise_scale = get_index(sqrt_one_minus_cumulative_alphas, t, x_0.shape).to(device)

    x_t = signal_scale * x_0.to(device) + noise_scale * eps
    return x_t, eps
    


def get_index(vals, t, x_shape):
    """Look up ``vals[t]`` per batch entry and shape it for broadcasting.

    Args:
        vals: 1-D tensor of per-timestep values (indexed on the CPU).
        t: Integer tensor of timestep indices, one per batch entry.
        x_shape: Shape of the tensor the result will be broadcast against.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape)``
        dimensions, placed on ``t``'s device.
    """
    picked = torch.gather(vals, -1, t.cpu())
    # One singleton axis for every non-batch dimension of the target tensor.
    singleton_axes = [1] * (len(x_shape) - 1)
    return picked.reshape(t.shape[0], *singleton_axes).to(t.device)

def linear_beta_schedule(timesteps, start=0.0001, end=0.02):
    """Return ``timesteps`` betas evenly spaced from ``start`` to ``end`` (inclusive)."""
    betas = torch.linspace(start, end, steps=timesteps)
    return betas


def cosine_beta_schedule(t, s=0.00001):
    """Cosine beta schedule.

    Obtained from https://huggingface.co/blog/annotated-diffusion and
    initially proposed in https://arxiv.org/abs/2102.09672.

    Args:
        t: Total number of diffusion timesteps.
        s: Small offset added to the normalised timestep.

    Returns:
        1-D tensor of ``t`` betas, clipped to ``[0.0001, 0.9999]``.
    """
    # t + 1 grid points 0..t give t consecutive ratios below.
    grid = torch.linspace(0, t, t + 1)
    phase = (grid / t + s) / (1 + s)

    f_t = torch.cos(phase * (torch.pi * 0.5)) ** 2
    alpha_bar = f_t / f_t[0]  # normalise so the cumulative product starts at 1

    # beta_i = 1 - alpha_bar_i / alpha_bar_{i-1}
    betas = 1 - alpha_bar[1:] / alpha_bar[:-1]
    return betas.clamp(min=0.0001, max=0.9999)

# Diffusion schedule: number of timesteps and the per-step noise variances.
# NOTE(review): with T = 100 and betas rising linearly to 0.02, cumulative_alphas[-1]
# stays well above zero (roughly 0.36), so x_T still carries image signal -- confirm
# this is intended.
T = 100
betas = linear_beta_schedule(timesteps=T)

# Quantities derived once from the schedule and indexed by timestep later on.
alphas = 1. - betas
cumulative_alphas = alphas.cumprod(dim=0)
# Shifted copy: entry t holds the cumulative product up to t-1, with 1.0 at t = 0.
cumulative_alphas_prev = F.pad(cumulative_alphas[:-1], (1, 0), value=1.0)

cumulative_sqrt_alphas = cumulative_alphas.sqrt()
sqrt_one_minus_cumulative_alphas = (1. - cumulative_alphas).sqrt()
sqrt_one_over_alphas = (1.0 / alphas).sqrt()

# Variance of the reverse-process posterior used when sampling.
posterior_variance = betas * (1. - cumulative_alphas_prev) / (1. - cumulative_alphas)

In [None]:


def show_tensor_image(image):
    """Draw a tensor image on the current matplotlib axes.

    The tensor is expected to use the training value range [-1, 1]. Values
    outside that range (noised or freshly sampled images routinely exceed it)
    are clamped first; without the clamp the cast to ``uint8`` wraps around
    and produces speckled artefacts.

    Args:
        image: Tensor laid out as CHW, or as a 4-D batch, in which case only
            the first image is shown. It may live on any device and may
            require grad.
    """
    reverse_transforms = transforms.Compose([
        transforms.Lambda(lambda t: (t.clamp(-1, 1) + 1) / 2),  # [-1, 1] -> [0, 1]
        transforms.Lambda(lambda t: t.permute(1, 2, 0)), # CHW to HWC
        transforms.Lambda(lambda t: t * 255.),
        transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    # For a batch, display only its first image.
    if len(image.shape) == 4:
        image = image[0, :, :, :]

    # .numpy() needs a CPU tensor that is detached from the autograd graph.
    plt.imshow(reverse_transforms(image.detach().cpu()))


# Build the training DataLoader from the dataset prepared above, but only when
# `dd` has been left at None.
if dd is None:
    data = transformed_dataset
    dataloader = DataLoader(
        dataset=data,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,  # skip the final short batch
    )


In [None]:
# Simulate forward diffusion on one batch and plot the first image at evenly
# spaced timesteps.
image = next(iter(dataloader))[0]

plt.figure(figsize=(94,94))
plt.axis('off')
num_images = 10
stepsize = int(T/num_images)

for idx in range(0, T, stepsize):
    t = torch.tensor([idx], dtype=torch.int64)
    plt.subplot(1, num_images+1, int((idx/stepsize) + 1))
    # forward_diffusion is the closed form q(x_t | x_0), so it must always be fed
    # the clean batch. The result goes in a separate name: overwriting `image`
    # would re-noise an already noised image on every iteration.
    noisy_image, noise = forward_diffusion(image, t)
    show_tensor_image(noisy_image)

In [None]:
"""The below code is adapted from the code provided by huggingface @ https://huggingface.co/blog/annotated-diffusion"""


class AttentionBlock(nn.Module):
    """Gated residual attention over a 4-D feature map.

    Queries and keys share one 1x1 convolution; the values are the input
    itself. For an input of shape (B, C, H, W) the attention weights have
    shape (B, C, H, H): within each channel, every row attends over the rows
    of the same feature map.

    The attended features are scaled by a learnable scalar ``gamma``
    (initialised to zero) and added to the input, so the block starts out as
    the identity and learns how much attention to mix in. The residual term
    sits outside the gate on purpose: gating the sum ``attended + x`` would
    make a freshly initialised block output all zeros and erase the features
    passed to the following layer.

    The parameter names (``q_conv``, ``gamma``) are unchanged, so existing
    checkpoints still load. Weights trained under the earlier
    ``gamma * (attended + x)`` formulation will produce different activations.

    Args:
        in_channels: Number of channels of the input (and output) feature map.
    """

    def __init__(self, in_channels):
        super().__init__()

        self.q_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``gamma * attention(x) + x``; the output has the shape of ``x``."""
        # Query and key come from the same projection; the value is the raw input.
        proj_query = self.q_conv(x)
        proj_key = proj_query
        proj_value = x

        # (B, C, H, W) @ (B, C, W, H) -> (B, C, H, H), normalised over the last axis
        energy = torch.matmul(proj_query, proj_key.transpose(2, 3))
        attention = self.softmax(energy)

        # (B, C, H, H) @ (B, C, H, W) -> (B, C, H, W)
        attended = torch.matmul(attention, proj_value)

        # Gate only the attention branch; keep the skip path untouched.
        return self.gamma * attended + x

    

class SampleBlock(nn.Module):
    """One U-Net stage: two 3x3 convolutions, time conditioning, attention, resampling.

    Args:
        in_ch: Channel count of the incoming feature map. For an up block this
            is the count *before* the skip connection is concatenated, so the
            first convolution takes ``2 * in_ch`` channels.
        out_ch: Channel count produced by the block.
        time_embedding_dim: Size of the timestep embedding vector.
        up: If True the block ends with a stride-2 transposed convolution
            (doubling height and width); otherwise with a stride-2 convolution
            (halving them for even sizes).
    """

    def __init__(self, in_ch, out_ch, time_embedding_dim, up=False):
        super().__init__()
        self.time_embedding = nn.Linear(time_embedding_dim, out_ch)

        first_conv_in = 2 * in_ch if up else in_ch
        resample_layer = nn.ConvTranspose2d if up else nn.Conv2d
        self.conv1 = nn.Conv2d(first_conv_in, out_ch, 3, padding=1)
        # kernel 4, stride 2, padding 1
        self.transform = resample_layer(out_ch, out_ch, 4, 2, 1)

        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bnorm1 = nn.BatchNorm2d(out_ch)
        self.bnorm2 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()

        self.attention = AttentionBlock(out_ch)

    def forward(self, x, t):
        """Apply the block to feature map ``x`` conditioned on time embedding ``t``."""
        # conv -> ReLU -> batch norm
        h = self.bnorm1(self.relu(self.conv1(x)))

        # Project the time embedding to out_ch values per sample and append two
        # trailing axes so it broadcasts over height and width.
        t_emb = self.relu(self.time_embedding(t))
        h = h + t_emb[..., None, None]

        # Second conv -> ReLU -> batch norm, then attention
        h = self.bnorm2(self.relu(self.conv2(h)))
        h = self.attention(h)

        # Downsample or upsample
        return self.transform(h)

class SinusoidalPositionEmbeddings(nn.Module):
    """Map scalar timesteps to fixed sinusoidal feature vectors.

    Args:
        dim: Embedding size. The first ``dim // 2`` outputs are sines and the
            remaining ``dim // 2`` are cosines of the same angles.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed a 1-D tensor of timesteps into shape ``(len(time), 2 * (dim // 2))``."""
        n_freqs = self.dim // 2
        # Geometric frequency ladder from 1 down to 1/10000.
        log_step = math.log(10000) / (n_freqs - 1)
        freqs = torch.exp(-log_step * torch.arange(n_freqs, device=time.device))

        # Outer product: one row of angles per timestep.
        angles = time.unsqueeze(1) * freqs.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)
        
class U_Net(nn.Module):
    """Time-conditioned U-Net that predicts the noise present in a noised image.

    Four down blocks (64 -> 128 -> 256 -> 512 -> 1024 channels) are mirrored
    by four up blocks, with skip connections concatenated on the channel axis.
    The output has the same channel count as the input.

    Args:
        image_channels: Number of channels of the input and output images
            (3 for RGB, 1 for greyscale). Defaults to 3, which matches the
            previous hard-coded behaviour.
    """

    def __init__(self, image_channels=3):
        super().__init__()
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        output_kernel_size = 1  # 1x1 convolution that maps features back to image channels
        time_embedding_dim = 64

        # Sinusoidal timestep features refined by a small learned layer
        self.time_embedding = nn.Sequential(
            SinusoidalPositionEmbeddings(time_embedding_dim),
            nn.Linear(time_embedding_dim, time_embedding_dim),
            nn.ReLU())

        # Input convolution to project the image to the initial channel count
        self.input_conv = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # One block per consecutive pair of channel sizes
        self.downs = nn.ModuleList([
            SampleBlock(ch_in, ch_out, time_embedding_dim)
            for ch_in, ch_out in zip(down_channels, down_channels[1:])
        ])

        self.ups = nn.ModuleList([
            SampleBlock(ch_in, ch_out, time_embedding_dim, up=True)
            for ch_in, ch_out in zip(up_channels, up_channels[1:])
        ])

        # The output predicts noise for every image channel, so it follows
        # image_channels rather than a literal 3.
        self.output_conv = nn.Conv2d(up_channels[-1], image_channels, output_kernel_size)

    def forward(self, x, timestep):
        """Predict the noise contained in ``x`` at the given timesteps.

        Args:
            x: Batch of noised images with ``image_channels`` channels.
                NOTE(review): each down block halves height and width and each
                up block doubles them, so sizes that are not a multiple of 16
                will presumably fail at the skip concatenation -- confirm.
            timestep: 1-D tensor of timestep indices, one per batch entry.

        Returns:
            Tensor with ``image_channels`` channels.
        """
        t = self.time_embedding(timestep)

        x = self.input_conv(x)

        # Outputs of the down blocks, consumed in reverse order by the up blocks.
        skips = []
        for down in self.downs:
            x = down(x, t)
            skips.append(x)

        for up in self.ups:
            # The first pop returns the deepest down output, which is also the
            # current x, so the first up block receives that tensor twice.
            x = torch.cat((x, skips.pop()), dim=1)
            x = up(x, t)

        return self.output_conv(x)

        
# Build the network with its default configuration and report the total
# number of parameters.
model = U_Net()
print("Num params: ", sum(p.numel() for p in model.parameters()))
# Bare expression: as the last line of a notebook cell it displays the module's layer summary.
model

In [None]:
"""Due to the instability of the NCC I would save the model weights of my model at the end of every training itteration and 
load them in for the next training session. This was done for the purpose of making sure I could train the model for an ample amount
of time"""

import os

# Resume from the previous session's weights when a checkpoint exists. On the
# very first run there is no file yet, so fall back to the fresh initialisation
# instead of crashing.
checkpoint_path = "model_weights_log.pth"
if os.path.exists(checkpoint_path):
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host;
    # the model is moved to its final device later on.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
else:
    print(f"No checkpoint found at {checkpoint_path!r}; starting from freshly initialised weights")

In [None]:
import torch
import torch.nn.functional as F
import torch.distributions as tdist

def negative_log_likelihood_loss(x_true, x_pred):
    """
    Compute the mean negative log-likelihood of ``x_true`` under a unit-variance
    Gaussian centred on ``x_pred``.

    Per element this equals ``0.5 * (x_true - x_pred) ** 2 + 0.5 * log(2 * pi)``,
    i.e. half the squared error plus a constant.

    Arguments:
        x_true: Tensor containing the target values.
        x_pred: Tensor of the same shape as ``x_true`` containing the predictions.
            Any shape is accepted, not only (batch_size, channels, height, width).

    Returns:
        loss: Scalar Tensor containing the negative log-likelihood averaged over
            every element.

    Raises:
        ValueError: If the two tensors differ in shape. Without this check the
            distribution would silently broadcast them against each other.
    """
    if x_true.shape != x_pred.shape:
        raise ValueError(
            f"x_true and x_pred must have the same shape, got "
            f"{tuple(x_true.shape)} and {tuple(x_pred.shape)}"
        )

    # The mean runs over all elements, so no flattening is required.
    likelihood = tdist.Normal(loc=x_pred, scale=1.0)
    return -likelihood.log_prob(x_true).mean()


In [None]:
def get_loss(model, x_0, t):
    """Return the training loss for one batch.

    The clean batch ``x_0`` is noised to timesteps ``t``, the model predicts
    the added noise, and the prediction is scored against the true noise with
    the negative log-likelihood loss.
    """
    x_t, true_noise = forward_diffusion(x_0, t, device)
    predicted_noise = model(x_t, t)
    # An L1 alternative, F.l1_loss(true_noise, predicted_noise), was used previously.
    return negative_log_likelihood_loss(true_noise, predicted_noise)

In [None]:
"""The below functions are based on adapting the implementations provided by Hugging face @ https://huggingface.co/blog/annotated-diffusion """


@torch.no_grad()
def sample_timestep(x, t):
    """
    Perform one reverse-diffusion step.

    Uses the model's noise prediction to compute the mean of the denoised
    image for timestep ``t``. For every step except ``t == 0``, Gaussian noise
    scaled by the posterior standard deviation is added on top of that mean.
    ``t`` must hold a single timestep, because it is compared against 0 in a
    plain ``if``.
    """
    shape = x.shape
    beta_t = get_index(betas, t, shape)
    sqrt_one_minus_abar_t = get_index(sqrt_one_minus_cumulative_alphas, t, shape)
    inv_sqrt_alpha_t = get_index(sqrt_one_over_alphas, t, shape)
    variance_t = get_index(posterior_variance, t, shape)

    # Mean of the reverse step: remove the predicted noise, then rescale.
    predicted_noise = model(x, t)
    model_mean = inv_sqrt_alpha_t * (x - beta_t * predicted_noise / sqrt_one_minus_abar_t)

    if t != 0:
        return model_mean + torch.sqrt(variance_t) * torch.randn_like(x)
    # Final step: return the mean without adding noise.
    return model_mean

@torch.no_grad()
def sample_plot_image():
    """Generate one image from pure noise and plot ten intermediate snapshots.

    Runs the reverse process from timestep ``T - 1`` down to 0 on a single
    3-channel ``IMG_SIZE`` x ``IMG_SIZE`` noise image and draws the current
    image every ``T / 10`` steps in a single row of subplots.
    """
    num_images = 10
    stepsize = T // num_images

    # Start from pure Gaussian noise
    img = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)
    plt.figure(figsize=(15,15))
    plt.axis('off')

    for i in reversed(range(T)):
        t = torch.full((1,), i, device=device, dtype=torch.long)
        img = sample_timestep(img, t)
        if i % stepsize:
            continue  # not a snapshot step
        plt.subplot(1, num_images, i // stepsize + 1)
        show_tensor_image(img.detach().cpu())
    plt.show()

In [None]:
# Show which device (cuda or cpu) was selected for this run.
print(device)

In [None]:
# Move the model's parameters and buffers to the selected device before training.
model.to(device)

In [None]:
from torch.optim import Adam

# Per-step loss history, used for the running average printed below.
losses = []
optimizer = Adam(model.parameters(), lr=0.0001)

for epoch in range(100):
    for step, batch in enumerate(dataloader):
        # batch[0] holds the image tensors, batch[1] the (unused) labels
        clean_images = batch[0].to(device)
        bs = clean_images.shape[0]

        # One random timestep per image; sized from the actual batch rather than
        # the BATCH_SIZE constant so a short batch cannot cause a shape mismatch.
        t = torch.randint(0, T, (bs,), device=clean_images.device).long()

        # Noise the batch in closed form; `noise` is the regression target.
        # The copy already on the device is reused to avoid a second transfer.
        noisy_images, noise = forward_diffusion(clean_images, t, device)

        noise_pred = model(noisy_images, t)

        loss = F.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        # backward() without arguments seeds the scalar loss with a gradient of 1.
        # Passing the loss itself as the argument would multiply every gradient
        # by the current loss value.
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        # Every fifth epoch, at its first step: report the mean loss over the
        # most recent len(dataloader) steps and plot a sample.
        if ((epoch + 1) % 5 == 0) and (step == 0):
            loss_last_epoch = sum(losses[-len(dataloader) :]) / len(dataloader)
            sample_plot_image()
            print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")
            
    

In [None]:
# Draw one sample with the trained model and plot its denoising trajectory.
sample_plot_image()

I started with MAE and then tried MSE; both yielded OK but not great results. I have since moved on to using the log-likelihood as the loss function, in the hope that this will improve my model.


In [None]:
# from torch.optim import Adam

# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)
# optimizer = Adam(model.parameters(), lr=0.0001)
# epochs = 100 # Try more!

# for epoch in range(epochs):
#     for step, batch in enumerate(dataloader):
#         optimizer.zero_grad()
#         clean_images = batch
#         t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
#         loss = get_loss(model, batch[0], t)
#         loss.backward()
#         optimizer.step()

#         if (step == 0) and (epochs % 10 == 0):
#             print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
#             sample_plot_image()

In [None]:
# sample_plot_image()

In [None]:
# from torch.optim import Adam

# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)
# optimizer = Adam(model.parameters(), lr=0.0001)
# epochs = 100 # Try more!

# for epoch in range(epochs):
#     for step, batch in enumerate(dataloader):
#         optimizer.zero_grad()

#         t = torch.randint(0, T, (BATCH_SIZE,), device=device).long()
#         loss = get_loss(model, batch[0], t)
#         loss.backward()
#         optimizer.step()

#         if (step == 0) and (epochs % 10 == 0):
#             print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
#             sample_plot_image()

In [None]:
# Persist the current weights so that the next session can resume from them;
# the same file name is read back by the load_state_dict cell earlier in the notebook.
torch.save(model.state_dict(), "model_weights_log.pth")
# To restore: model.load_state_dict(torch.load("model_weights_log.pth"))