Check the environment configuration: the Python interpreter, the PyTorch install and whether CUDA is available.

In [None]:
import sys

import torch
from torch.utils import collect_env

# Which interpreter and torch install this kernel uses, and whether it can see a GPU.
print(sys.executable)
print(torch.__file__)
print(torch.cuda.is_available())

# collect_env.main() prints its report itself and returns None, so it is called
# directly; wrapping it in print() only appended a stray "None" line.
collect_env.main()


Check if the environment has access to the NVIDIA A100 GPU.

In [None]:
# IPython shell escape (`!`): run the `nvidia-smi` command-line tool to list the NVIDIA GPUs visible to this machine.
!nvidia-smi 

# Diffusion Model - Lion Optimiser

A simple implementation of a denoising diffusion model in PyTorch, trained with the Lion optimiser. It covers only the image-generation part: there is no text encoder or decoder, so it is not a full text-to-image pipeline.

In [None]:
import torch
import torchvision 
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchvision import datasets, transforms 
# from torchvision.transforms import Compose, ToTensor, Lambda, Resize, CenterCrop, RandomHorizontalFlip, ToPILImage
from torch.utils.data import DataLoader
import numpy as np
from torch import nn
import math

In [None]:
# Generates 150 samples of 25 columns x 10 rows of images
def show(dataset, num_sample=150, cols=25, rows=10):
    plt.figure(figsize=(15, 15))
    for i, img in enumerate(dataset):
        if i == num_sample:
            break
        plt.subplot(num_sample // rows + 1, cols, i + 1)
        plt.axis('off')
        plt.imshow(img[0])

# Download the CelebA training split into root='' (paths relative to the working directory).
# No transform is passed, so show() receives the raw dataset items and plots item[0] of each.
# NOTE: `data` is reassigned to the transformed train+test dataset in the preprocessing cell.
# *WARNING:* This will take a while to download (depending on connection speed) on the first ever run of this notebook
data = torchvision.datasets.CelebA(root='', split="train", download=True)

# Show the first 150 samples
show(data)

## Step 1 - Forward Diffusion Process

### Noise schedules (linear, cosine, quadratic, sigmoid) and the precomputed alphas, betas and posterior variance used by the forward diffusion process.

In [None]:
# A linear schedule as proposed in DDPM (Ho et al., 2020): https://arxiv.org/abs/2006.11239
def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return `timesteps` noise variances spaced linearly from `beta_start` to `beta_end`.

    Args:
        timesteps: number of diffusion steps (length of the returned tensor).
        beta_start: beta at the first step (default 1e-4, the original hard-coded value).
        beta_end: beta at the last step (default 0.02, the original hard-coded value).

    Returns:
        1-D float tensor of length `timesteps`.
    """
    return torch.linspace(beta_start, beta_end, timesteps)

def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine noise schedule from Improved DDPM (Nichol & Dhariwal, 2021),
    https://arxiv.org/abs/2102.09672.

    alpha_bar(t) is cos^2 of the shifted, rescaled time ((t / T) + s) / (1 + s)
    scaled by pi / 2, normalised so alpha_bar(0) = 1. Each beta is recovered as
    1 - alpha_bar(t) / alpha_bar(t - 1) and clipped to [0.0001, 0.9999].

    Args:
        timesteps: number of diffusion steps (length of the returned tensor).
        s: offset added to t / T inside the cosine.

    Returns:
        1-D tensor of `timesteps` betas.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    f = torch.cos(((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    alpha_bar = f / f[0]
    ratio = alpha_bar[1:] / alpha_bar[:-1]
    return torch.clip(1 - ratio, 0.0001, 0.9999)

# A quadratic schedule: the square roots of the betas are evenly spaced
def quadratic_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return `timesteps` betas whose square roots are spaced linearly
    between sqrt(beta_start) and sqrt(beta_end).

    Args:
        timesteps: number of diffusion steps (length of the returned tensor).
        beta_start: beta at the first step (default 1e-4, the original hard-coded value).
        beta_end: beta at the last step (default 0.02, the original hard-coded value).

    Returns:
        1-D float tensor of length `timesteps`.
    """
    return torch.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2

# A sigmoid schedule: betas follow a logistic curve sampled on [-6, 6]
def sigmoid_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Return `timesteps` betas shaped like sigmoid(x) for x evenly spaced on [-6, 6],
    rescaled into the interval (beta_start, beta_end).

    Because sigmoid(-6) and sigmoid(6) are not exactly 0 and 1, the first and last
    betas are close to, but not exactly, `beta_start` and `beta_end`.

    Args:
        timesteps: number of diffusion steps (length of the returned tensor).
        beta_start: lower end of the range (default 1e-4, the original hard-coded value).
        beta_end: upper end of the range (default 0.02, the original hard-coded value).

    Returns:
        1-D float tensor of length `timesteps`.
    """
    betas = torch.linspace(-6, 6, timesteps)

    return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start

def get_index_from_list(vals, t, x_shape):
    """Pick ``vals[t[i]]`` for every batch element, shaped to broadcast against x.

    Args:
        vals: per-timestep values, indexed along their last dimension. The
            indices are moved to the CPU for the gather, so `vals` is expected
            to live on the CPU.
        t: integer index tensor of shape (batch,).
        x_shape: shape of the tensor the result will be combined with.

    Returns:
        Tensor of shape (batch, 1, ..., 1) -- one trailing 1 for every
        non-batch dimension of `x_shape` -- placed on `t`'s device.
    """
    picked = vals.gather(-1, t.cpu())
    singleton_dims = (1,) * (len(x_shape) - 1)
    return picked.reshape(t.shape[0], *singleton_dims).to(t.device)

def forward_diffusion_sample(x_0, t, device="cpu"):
    """Sample x_t ~ q(x_t | x_0) in closed form.

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,   eps ~ N(0, I)

    Reads the module-level tensors `sqrt_alphas_cumprod` and
    `sqrt_one_minus_alphas_cumprod`, which are computed with the beta schedule
    further down in this cell.

    Args:
        x_0: clean images (batch first).
        t: integer timesteps, one per batch element (a single-element tensor
            broadcasts over the whole batch).
        device: device the returned tensors are moved to.

    Returns:
        (x_t, eps): the noisy images and the Gaussian noise that was mixed in,
        both on `device`.
    """
    eps = torch.randn_like(x_0)
    signal_scale = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape)
    noise_scale = get_index_from_list(sqrt_one_minus_alphas_cumprod, t, x_0.shape)

    # Scaled signal plus scaled noise (i.e. mean + std * eps).
    x_t = signal_scale.to(device) * x_0.to(device) + noise_scale.to(device) * eps.to(device)
    return x_t, eps.to(device)


# Beta schedule: T diffusion steps with linearly increasing noise variance.
T = 300
betas = linear_beta_schedule(timesteps=T)

# alpha_t = 1 - beta_t and its running product alpha_bar_t.
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# alpha_bar_{t-1}, padded with 1.0 in front so the first step has a predecessor.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = (1.0 / alphas).sqrt()

# Coefficients of the closed-form forward process q(x_t | x_0).
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()

# Variance of the posterior q(x_{t-1} | x_t, x_0); exactly 0 at t = 0 because
# alphas_cumprod_prev[0] == 1.
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

### Image Preprocessing Helper Functions

In [None]:
# Image side lengths in pixels (64x64, 128x128, 256x256).
# IMG_SIZE is the one the cells shown here use; the 128/256 variants are kept for
# resizing the images and testing the models at other resolutions.
IMG_SIZE = 64
IMG_SIZE_128 = 128
IMG_SIZE_256 = 256

# Batch sizes: BATCH_SIZE (128 images per batch) is what the dataloader uses here;
# BATCH_SIZE_256 is the alternative setting for 256 images per batch.
BATCH_SIZE = 128
BATCH_SIZE_256 = 256

def load_transformed_dataset(img_size=None):
    """Load the CelebA train and test splits as one dataset of model-ready tensors.

    Every image is resized to (img_size, img_size), randomly flipped
    horizontally, converted to a float tensor with ToTensor and mapped from
    [0, 1] to [-1, 1] with t * 2 - 1.

    Args:
        img_size: target side length in pixels. Defaults to the module-level
            IMG_SIZE, read at call time, so changing IMG_SIZE still takes effect.

    Returns:
        A ConcatDataset of the train and test splits (root='', i.e. paths
        relative to the working directory; downloaded on first use).
    """
    if img_size is None:
        img_size = IMG_SIZE

    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: t * 2 - 1),
    ])

    train = datasets.CelebA(root='', split="train", download=True, transform=transform)
    test = datasets.CelebA(root='', split="test", download=True, transform=transform)

    return torch.utils.data.ConcatDataset([train, test])


# Load the transformed CelebA train+test dataset (images scaled to [-1, 1])
data = load_transformed_dataset()

# Wrap the data in a shuffled DataLoader. drop_last=True discards the final partial
# batch, so every batch holds exactly BATCH_SIZE images.
# (Swap in BATCH_SIZE_256 for the 256-images-per-batch investigation.)
dataloader = DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# The reverse transformer for the dataset to show the images back to their original form
def reverse_tensor_img(image):
    reverse_transform = transforms.Compose([
        transforms.Lambda(lambda t: (t + 1) / 2),
        transforms.Lambda(lambda t: t.permute(1, 2, 0)),
        transforms.Lambda(lambda t: t*255),
        transforms.Lambda(lambda t: t.cpu().numpy().astype(np.uint8)),
        transforms.ToPILImage(),
    ])

    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]

    plt.imshow(reverse_transform(image))

Check that the forward diffusion process works: one image is noised to increasingly large timesteps, shown from left to right.

In [None]:
# Take the first image of one batch (batch[0] holds the images, batch[1] the targets).
# Only this image is displayed, so there is no need to noise the whole batch.
image = next(iter(dataloader))[0][:1]

# Figure size, number of snapshots and the timestep spacing between them
plt.figure(figsize=(18, 18))
plt.axis('off')
num_images = 20
stepsize = int(T/num_images)

# Noise the *clean* image to each displayed timestep. forward_diffusion_sample
# draws from q(x_t | x_0), so it must always be given x_0; feeding its own output
# back in (as before) compounded the noise across panels.
for idx in range(0, T, stepsize):
    t = torch.tensor([idx], dtype=torch.int64)
    plt.subplot(1, num_images + 1, (idx // stepsize) + 1)
    noisy_image, noise = forward_diffusion_sample(image, t)
    plt.axis('off')
    reverse_tensor_img(noisy_image)

## Step 2 - Backward Diffusion Process (U-Net)

In [None]:
# The convolutional block for the model
# The block consists of two convolutional layers with each one having its own batch normalization and a relu activation function
# The block also has a time embedding layer that is used to add the time embedding to the convolutional layers
# The block also has skip connections using the time embedding layer and the convolutional layers to add the time embedding to the skip connections
class Block(nn.Module):
    def __init__(self, in_channel, out_channel, time_emb_dim, up=False):
        super().__init__()
        # Time embedding layer
        self.time_mlp =  nn.Linear(time_emb_dim, out_channel)

        # First convolutional layers
        # If up is true then add a convolutional transpose layer to upsample the channels
        if up:
            self.conv1 = nn.Conv2d(2*in_channel, out_channel, 3, padding=1)
            self.transform = nn.ConvTranspose2d(out_channel, out_channel, 4, 2, 1)
        else:
            self.conv1 = nn.Conv2d(in_channel, out_channel, 3, padding=1)
            self.transform = nn.Conv2d(out_channel, out_channel, 4, 2, 1)

        # Second convolutional layer
        self.conv2 = nn.Conv2d(out_channel, out_channel, 3, padding=1)

        # Batch normalization layers for both convolutional layers
        self.bnorm1 = nn.BatchNorm2d(out_channel)
        self.bnorm2 = nn.BatchNorm2d(out_channel)

        # Relu activation function
        self.relu  = nn.ReLU()
        
    def forward(self, x, t, ):
        # First Conv
        h = self.bnorm1(self.relu(self.conv1(x)))
        # Time embedding
        time_emb = self.relu(self.time_mlp(t))
        # Extend last 2 dimensions
        time_emb = time_emb[(..., ) + (None, ) * 2]
        # Add time channel
        h = h + time_emb
        # Second Conv
        h = self.bnorm2(self.relu(self.conv2(h)))
        # Down or Upsample
        return self.transform(h)

# A sinusoidal time embedding layer as described in the paper https://arxiv.org/pdf/1706.03762.pdf
class SinusoidalPositionEmbeddings(nn.Module):
    """Transformer-style sinusoidal embedding of timesteps.

    For half_dim = dim // 2 frequencies f_k = 10000^(-k / (half_dim - 1)),
    k = 0..half_dim-1, the output is [sin(t * f), cos(t * f)] concatenated on
    the last dimension. Odd `dim` is padded with one zero column so the output
    width is always exactly `dim`.

    Args:
        dim: output embedding width.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        """Embed `time` of shape (batch,) into a (batch, dim) float tensor."""
        device = time.device
        half_dim = self.dim // 2
        # max(..., 1): with dim in {2, 3} half_dim - 1 == 0 and the exponent
        # scale used to raise ZeroDivisionError.
        embeddings = math.log(10000) / max(half_dim - 1, 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        embeddings = time[:, None] * embeddings[None, :]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
        if self.dim % 2 == 1:
            # Without this the width would be dim - 1 and mismatch a following Linear(dim, ...).
            embeddings = F.pad(embeddings, (0, 1))
        return embeddings

class SimpleUnet(nn.Module):
    """Compact U-Net that predicts the noise in x_t from x_t and its timestep.

    Layout: 3x3 input projection (3 -> 64 channels) -> 4 down Blocks
    (64 -> 128 -> 256 -> 512 -> 1024, each halving H and W) -> 4 up Blocks
    (1024 -> 512 -> 256 -> 128 -> 64, each first concatenating the matching
    down-path output and then doubling H and W) -> 1x1 conv back to 3 channels.
    Every Block also receives a shared timestep embedding
    (sinusoidal -> Linear -> ReLU, width 32).

    Assumes H and W are divisible by 16 (four stride-2 halvings); other sizes
    may leave the skip connections with mismatched shapes.
    """

    def __init__(self):
        super().__init__()
        image_channels = 3  # RGB
        down_channels = (64, 128, 256, 512, 1024)
        up_channels = (1024, 512, 256, 128, 64)
        final_kernel_size = 1  # 1x1 output convolution
        time_emb_dim = 32

        # Timestep embedding shared by every Block.
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        # Lift the RGB image to the first feature width.
        self.conv0 = nn.Conv2d(image_channels, down_channels[0], 3, padding=1)

        # Each consecutive (c_in, c_out) channel pair defines one stage.
        self.downs = nn.ModuleList(
            [Block(c_in, c_out, time_emb_dim)
             for c_in, c_out in zip(down_channels, down_channels[1:])]
        )
        self.ups = nn.ModuleList(
            [Block(c_in, c_out, time_emb_dim, up=True)
             for c_in, c_out in zip(up_channels, up_channels[1:])]
        )

        # Project back to image channels: the predicted noise has the image's shape.
        self.output = nn.Conv2d(up_channels[-1], image_channels, final_kernel_size)

    def forward(self, x, timestep):
        """Predict the noise in `x` (B, 3, H, W) at integer timesteps `timestep` (B,)."""
        t = self.time_mlp(timestep)
        x = self.conv0(x)

        skips = []
        for down in self.downs:
            x = down(x, t)
            skips.append(x)

        # Consume skips last-in-first-out, appended as extra channels.
        for up in self.ups:
            x = up(torch.cat((x, skips.pop()), dim=1), t)

        return self.output(x)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the network, report its size and move it to the chosen device.
model = SimpleUnet()
print("Num params: ", sum(p.numel() for p in model.parameters()))
model.to(device=device)

## Step 3 - Training, Loss, Sampling

In [None]:
def get_loss(model, x_0, t, type="l1"):
    """Noise-prediction loss for one batch.

    Noises `x_0` to timesteps `t` with forward_diffusion_sample (onto the
    module-level `device`), asks `model` to predict the added noise and
    compares the prediction with the true noise.

    Args:
        model: network called as ``model(x_noisy, t)``.
        x_0: batch of clean images.
        t: integer timesteps, one per image.
        type: "l1" for mean absolute error, "l2" for mean squared error.

    Returns:
        The loss tensor.

    Raises:
        NotImplementedError: for any other `type`.
    """
    if type == "l1":
        criterion = F.l1_loss
    elif type == "l2":
        criterion = F.mse_loss
    else:
        raise NotImplementedError()

    x_noisy, noise = forward_diffusion_sample(x_0, t, device)
    return criterion(noise, model(x_noisy, t))

### Sampling

In [None]:
@torch.no_grad()
def sample_timestep(x, t):
    """One reverse-diffusion step: sample x_{t-1} from x_t using the global `model`.

    Computes the posterior mean
        mu = sqrt_recip_alphas_t * (x_t - beta_t * model(x_t, t) / sqrt(1 - alpha_bar_t))
    and, except for samples already at t == 0, adds
    sqrt(posterior_variance_t) * z with z ~ N(0, I).

    Reads the module-level `model`, `betas`, `sqrt_one_minus_alphas_cumprod`,
    `sqrt_recip_alphas` and `posterior_variance`.

    Args:
        x: current images x_t (batch first).
        t: integer timesteps, one per batch element.

    Returns:
        The images x_{t-1}, same shape as `x`.
    """
    betas_t = get_index_from_list(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = get_index_from_list(
        sqrt_one_minus_alphas_cumprod, t, x.shape
    )
    sqrt_recip_alphas_t = get_index_from_list(sqrt_recip_alphas, t, x.shape)

    # Call model (current image - noise prediction)
    model_mean = sqrt_recip_alphas_t * (
        x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t
    )
    posterior_variance_t = get_index_from_list(posterior_variance, t, x.shape)

    # `if t == 0` raises for tensors with more than one element (ambiguous truth
    # value), so reduce explicitly. When every sample is at t == 0, return the
    # mean without drawing noise, as before.
    if bool((t == 0).all()):
        return model_mean

    noise = torch.randn_like(x)
    # Per-sample mask so elements of a mixed batch that are at t == 0 get no noise.
    nonzero_mask = (t != 0).to(device=x.device, dtype=x.dtype).reshape(
        -1, *((1,) * (x.dim() - 1))
    )
    return model_mean + nonzero_mask * torch.sqrt(posterior_variance_t) * noise

@torch.no_grad()
def sample_plot_image():
    """Generate one image from pure noise and plot snapshots of the reverse process.

    Runs all T reverse steps with sample_timestep on `device` and plots the
    current image every T // 10 steps (10 panels for T = 300; the leftmost
    panel is the final t = 0 image). The displayed copies are clamped to
    [-1, 1]; the sampling chain itself is not.

    The global `model` is switched to eval mode while sampling, so BatchNorm
    uses its running statistics instead of per-image batch statistics and does
    not update them with generated images. Its previous train/eval mode is
    restored afterwards, even if sampling raises.
    """
    img_size = IMG_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=device)
    plt.figure(figsize=(15, 15))
    plt.axis('off')
    num_images = 10
    # At least 1 so `i % stepsize` is defined even when T < num_images.
    stepsize = max(1, int(T / num_images))
    # One panel per plotted timestep, so subplot indices never overflow.
    num_panels = len(range(0, T, stepsize))

    was_training = model.training
    model.eval()
    try:
        for i in reversed(range(T)):
            t = torch.full((1,), i, device=device, dtype=torch.long)
            img = sample_timestep(img, t)
            if i % stepsize == 0:
                plt.subplot(1, num_panels, i // stepsize + 1)
                reverse_tensor_img(torch.clamp(img, -1.0, 1.0).detach().cpu())
    finally:
        model.train(was_training)
    plt.show()

@torch.no_grad()
def sample_plot_FID():
    """Generate one image from pure noise with the full reverse process, for FID.

    Nothing is plotted (the old version opened an empty figure on every call).
    The global `model` is switched to eval mode while sampling (BatchNorm uses
    running statistics and leaves them untouched), and its previous mode is
    restored afterwards.

    Returns:
        numpy array of shape (1, 3, IMG_SIZE, IMG_SIZE) holding the final
        sample, unclamped.
    """
    img = torch.randn((1, 3, IMG_SIZE, IMG_SIZE), device=device)

    was_training = model.training
    model.eval()
    try:
        for i in reversed(range(T)):
            t = torch.full((1,), i, device=device, dtype=torch.long)
            img = sample_timestep(img, t)
    finally:
        model.train(was_training)

    return img.detach().cpu().numpy()

def plot_results(results, save_path="adam_loss.png"):
    """Plot training loss against training step, save the figure and show it.

    Args:
        results: sequence of (loss, step) pairs, e.g. the `loss_step` list the
            training loop builds.
        save_path: file the figure is written to. The default keeps the original
            filename; NOTE(review): this notebook trains with Lion, so passing
            e.g. "lion_loss.png" avoids overwriting an Adam run's plot.

    Returns:
        None. With empty `results` nothing is plotted or saved.
    """
    if len(results) == 0:
        # zip(*[]) yields nothing, so unpacking into (loss, step) would raise.
        return

    loss, step = zip(*results)
    plt.plot(step, loss)
    plt.xlabel("Training Step")
    plt.ylabel("Training Loss")
    plt.title("Loss per step")
    plt.savefig(save_path)
    plt.show()

### Training

In [None]:
print(f"CUDA Avaliable: {torch.cuda.is_available()}")

# Output the number of parameters in the model and the available CUDA devices
print("Num params: ", sum(p.numel() for p in model.parameters()))
print("Num devices: ", torch.cuda.device_count())
# get_device_name / get_device_capability need a CUDA device; skip them on
# CPU-only machines instead of raising.
if torch.cuda.is_available():
    print(f"Device name: {torch.cuda.get_device_name(0)}")
    print(f"Device CUDA capability: {torch.cuda.get_device_capability(0)}")
### Results
# The number of parameters in the model is printed above.
# The model is trained for `epochs` epochs (4 in the training cell); the number of
# steps per epoch is len(dataloader), which the training loop prints each epoch.
# The model is trained on a single GPU (NVIDIA A100 40GB).

In [None]:
from lion_pytorch import Lion

model.to(device)

optimiser = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2)
epochs = 4
steps_per_epoch = len(dataloader)

# [loss, global_step] pairs logged every 10 steps. The step counts across epochs
# (epoch * steps_per_epoch + step), so the history increases monotonically and
# can be plotted directly as loss vs. training step.
loss_step = []

for epoch in range(epochs):
    print(f"Epoch {epoch}")
    print(f"Amount of steps in dataloader: {steps_per_epoch}")
    print(f"Amount of samples in dataset: {len(dataloader.dataset)}")
    print(f"Batch size: {dataloader.batch_size}")
    for step, batch in enumerate(dataloader):
        optimiser.zero_grad()

        x_0 = batch[0]
        # One random timestep per image, sized from the actual batch so it stays
        # correct if the DataLoader batch size changes (e.g. to BATCH_SIZE_256).
        t = torch.randint(0, T, (x_0.shape[0],), device=device).long()
        loss = get_loss(model, x_0, t, "l1")
        loss.backward()
        optimiser.step()

        # Log the loss every 10 steps
        if step % 10 == 0:
            loss_step.append([loss.item(), epoch * steps_per_epoch + step])
            print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")

        # Every 250 steps, also show a sample from the current model
        if step % 250 == 0:
            print(f"Epoch {epoch} | step {step:03d} Loss: {loss.item()} ")
            print(f"Done with epoch {epoch} and step {step:03d}")
            sample_plot_image()

        # Last batch of the epoch (previously hard-coded as steps 1426-1428)
        if step == steps_per_epoch - 1:
            print(f"The final epoch is {epoch} and the final step is {step}")
            print(f"The final loss is {loss.item()}")
            sample_plot_image()

    print(loss_step)

    # After the last epoch, save the model
    if epoch == epochs - 1:
        torch.save(model.state_dict(), "model-lion.pt")
        print("Model saved!")