In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from IPython import display as disp
import torch.optim as optim
import torch.nn.functional as fn

from torchvision import transforms
from einops import rearrange

import torchvision
import matplotlib.pyplot as plt
import einops

from torchvision import transforms
from einops import rearrange

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Smoke-test the chosen device: a broken CUDA install raises on .to(device). When only the CPU
# is available, warn instead of aborting so the fallback selected above is actually usable
# (previously a hard assert made the CPU branch unreachable).
if torch.zeros(32).to(device).device.type != 'cuda':
    import warnings
    warnings.warn("CUDA is not available; falling back to CPU (training will be very slow).")

In [None]:
# helper function to make getting another batch of data easier
def cycle(iterable):
    """Yield the items of ``iterable`` endlessly, re-iterating it after every full pass.

    Used to draw an unbounded stream of batches from a DataLoader. If a pass yields nothing
    (empty iterable, or a one-shot iterator that is already exhausted) the generator stops
    instead of spinning forever without ever yielding.
    """
    while True:
        produced_any = False
        for item in iterable:
            produced_any = True
            yield item
        if not produced_any:
            return
                
def normalise(x):
    """Min-max rescale ``x`` (tensor or array) to the range [0, 1].

    A constant input has zero range; it is mapped to all zeros instead of producing NaN
    from a 0/0 division.
    """
    lo = x.min()
    span = x.max() - lo
    if span == 0:
        return x - lo
    return (x - lo) / span

class_names = ['apple','aquarium_fish','baby','bear','beaver','bed','bee','beetle','bicycle','bottle','bowl','boy','bridge','bus','butterfly','camel','can','castle','caterpillar','cattle','chair','chimpanzee','clock','cloud','cockroach','couch','crab','crocodile','cup','dinosaur','dolphin','elephant','flatfish','forest','fox','girl','hamster','house','kangaroo','computer_keyboard','lamp','lawn_mower','leopard','lion','lizard','lobster','man','maple_tree','motorcycle','mountain','mouse','mushroom','oak_tree','orange','orchid','otter','palm_tree','pear','pickup_truck','pine_tree','plain','plate','poppy','porcupine','possum','rabbit','raccoon','ray','road','rocket','rose','sea','seal','shark','shrew','skunk','skyscraper','snail','snake','spider','squirrel','streetcar','sunflower','sweet_pepper','table','tank','telephone','television','tiger','tractor','train','trout','tulip','turtle','wardrobe','whale','willow_tree','wolf','woman','worm',]

# Both splits share one preprocessing pipeline: ToTensor() gives [0, 1] and
# Normalize(0.5, 0.5) maps that to [-1, 1], the range the diffusion model works in.
# (Yes, the test split needs it too, so real and generated images are comparable.)
cifar_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=True, download=True, transform=cifar_transform),
    batch_size=124, shuffle=True)

# drop_last is left off so all 10,000 test images are reachable. With drop_last=True and
# batch_size=124 only 80 full batches (9,920 images) were returned per pass, so drawing
# 10k images by cycling repeated the first 80 and never saw the last 80.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=False, download=True, transform=cifar_transform),
    batch_size=124)

train_iterator = iter(cycle(train_loader))
test_iterator = iter(cycle(test_loader))

print(f'> Size of training dataset {len(train_loader.dataset)}')
print(f'> Size of test dataset {len(test_loader.dataset)}')

In [None]:
# Preview one training batch. Images are stored in [-1, 1], so map them back to [0, 1] for display.
plt.rcParams['figure.dpi'] = 100
x, t = (tensor.to(device) for tensor in next(train_iterator))
preview = torchvision.utils.make_grid(x * 0.5 + 0.5)
plt.imshow(preview.cpu().numpy().transpose(1, 2, 0), cmap=plt.cm.binary)
plt.show()

In [None]:

# From [1] https://github.com/TeaPearce/Conditional_Diffusion_MNIST/blob/main/script.py#L226
def ddpm_schedules(beta1, beta2, T, device=None):
    """
    Returns pre-computed schedules for DDPM sampling, training process.

    The noise variance beta_t grows linearly from ``beta1`` (t=0) to ``beta2`` (t=T);
    every returned tensor has T + 1 entries, indexed by the timestep.

    Args:
        beta1: first beta; must satisfy 0 < beta1 < beta2 < 1.
        beta2: last beta.
        T: number of diffusion steps (>= 1).
        device: where to allocate the tensors. Defaults to the notebook-wide ``device``
            global, which is what the function always used before this parameter existed.

    Returns:
        dict of schedule tensors. ``sqrtab`` and ``sqrtmab`` are shaped (T+1, 1, 1, 1) so they
        broadcast against image batches; all others are (T+1,).
    """
    # beta1 > 0 matters: with beta1 == 0, 1 - alphabar_0 == 0 and mab_over_sqrtmab[0] divides by zero.
    assert 0.0 < beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"
    assert T >= 1, "T must be a positive number of diffusion steps"
    if device is None:
        device = globals()["device"]

    beta_t = (beta2 - beta1) * torch.arange(0, T + 1, dtype=torch.float32, device=device) / T + beta1

    sqrt_beta_t = torch.sqrt(beta_t)
    alpha_t = 1 - beta_t
    # Cumulative product of alpha_t, computed as exp(cumsum(log)) for numerical stability.
    alphabar_t = torch.cumsum(torch.log(alpha_t), dim=0).exp()

    sqrt_one_minus_alphabar = torch.sqrt(1 - alphabar_t)
    mab_over_sqrtmab = (1 - alpha_t) / sqrt_one_minus_alphabar

    return {
        "alpha_t": alpha_t,  # \alpha_t
        "oneover_sqrta": 1 / torch.sqrt(alpha_t),  # 1/\sqrt{\alpha_t}
        "sqrt_beta_t": sqrt_beta_t,  # \sqrt{\beta_t}
        "alphabar_t": alphabar_t,  # \bar{\alpha_t}
        "sqrtab": torch.sqrt(alphabar_t).view(-1, 1, 1, 1),  # \sqrt{\bar{\alpha_t}}
        "sqrtmab": sqrt_one_minus_alphabar.view(-1, 1, 1, 1),  # \sqrt{1-\bar{\alpha_t}}
        "mab_over_sqrtmab": mab_over_sqrtmab,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
    }


In [None]:

# Model architecture adapted from [1] https://github.com/TeaPearce/Conditional_Diffusion_MNIST/
class ResidualConvBlock(nn.Module):
    """Two stacked Conv3x3 -> BatchNorm -> GELU stages with an optional residual connection.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolution stages.
        is_res: if True, add a skip connection and divide the sum by 1.414 (~sqrt(2)).
            The skip is the input itself when channel counts match, otherwise the output
            of the first stage (the input cannot be added when channels differ).
    """

    def __init__(self, in_channels: int, out_channels: int, is_res: bool = False) -> None:
        super().__init__()
        self.same_channels = in_channels == out_channels
        self.is_res = is_res

        def conv_stage(c_in: int, c_out: int) -> nn.Sequential:
            # kernel 3, stride 1, padding 1: spatial size is preserved
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, 1, 1),
                nn.BatchNorm2d(c_out),
                nn.GELU(),
            )

        self.conv1 = conv_stage(in_channels, out_channels)
        self.conv2 = conv_stage(out_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h1 = self.conv1(x)
        h2 = self.conv2(h1)
        if not self.is_res:
            return h2
        skip = x if self.same_channels else h1
        return (skip + h2) / 1.414


class UnetDown(nn.Module):
    """Encoder stage: a (non-residual) ResidualConvBlock followed by 2x2 max pooling,
    which halves the spatial height and width."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.model = nn.Sequential(
            ResidualConvBlock(in_channels, out_channels),
            nn.MaxPool2d(2),
        )

    def forward(self, x):
        downsampled = self.model(x)
        return downsampled


class UnetUp(nn.Module):
    """Decoder stage: concatenate the skip features, upsample 2x, then refine.

    ``in_channels`` must equal the channels of ``x`` plus those of ``skip``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_channels, out_channels, 2, 2)  # kernel 2, stride 2: doubles H and W
        refine = ResidualConvBlock(out_channels, out_channels)
        self.model = nn.Sequential(upsample, refine)

    def forward(self, x, skip):
        merged = torch.cat((x, skip), 1)  # join along the channel dimension
        return self.model(merged)

# Used to embed context and time information.
class EmbedFC(nn.Module):
    """Two-layer MLP (Linear -> GELU -> Linear) that embeds a timestep or a one-hot context.

    Args:
        input_dim: size of each input vector; inputs are flattened to (-1, input_dim).
        embed_dim: size of the produced embedding.
    """

    def __init__(self, input_dim, embed_dim):
        super().__init__()
        self.input_dim = input_dim
        self.model = nn.Sequential(
            nn.Linear(input_dim, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, x):
        # Accepts e.g. (B,) or (B, 1, 1, 1) timesteps and reshapes to (B, input_dim).
        flat = x.view(-1, self.input_dim)
        return self.model(flat)


class ContextUnet(nn.Module):
    """Class- and time-conditioned U-Net that predicts the noise added to an image.

    Spatial sizes assume 32x32 inputs: the encoder goes 32 -> 16 -> 8, the bottleneck is
    averaged to 1x1, and ``up0`` expands it back to 8x8 with a hard-coded kernel/stride of 8.

    Args:
        in_channels: number of image channels.
        n_feat: base channel width; the bottleneck uses 2 * n_feat.
        n_classes: number of class labels; sets the size of the one-hot context vector.
    """

    def __init__(self, in_channels, n_feat = 256, n_classes=100):
        super(ContextUnet, self).__init__()

        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_classes = n_classes

        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        self.down1 = UnetDown(n_feat, n_feat)        # n_feat channels, H/2
        self.down2 = UnetDown(n_feat, 2 * n_feat)    # 2*n_feat channels, H/4

        # Average the whole bottleneck map down to 1x1. The previous AvgPool2d(7) came from the
        # 28x28 MNIST original (7x7 bottleneck) and only covered the top-left 7x7 of the 8x8 map here.
        self.to_vec = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.GELU())

        self.timeembed1 = EmbedFC(1, 2*n_feat)
        self.timeembed2 = EmbedFC(1, 1*n_feat)
        self.contextembed1 = EmbedFC(n_classes, 2*n_feat)
        self.contextembed2 = EmbedFC(n_classes, 1*n_feat)

        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, 8, 8),  # 1x1 -> 8x8
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )

        self.up1 = UnetUp(4 * n_feat, n_feat)  # (2*n_feat from up0) + (2*n_feat skip from down2)
        self.up2 = UnetUp(2 * n_feat, n_feat)  # (n_feat from up1) + (n_feat skip from down1)
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, c, t):
        """Predict the noise contained in ``x``.

        Args:
            x: noisy images, (B, in_channels, H, W).
            c: integer class labels, (B,); may live on any device.
            t: timesteps divided by T; any shape that flattens to (B, 1).

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.init_conv(x)
        down1 = self.down1(x)
        down2 = self.down2(down1)
        hiddenvec = self.to_vec(down2)

        # One-hot encode on the activations' device, sized by n_classes (was a hard-coded 100
        # and a global device).
        c = nn.functional.one_hot(c.to(x.device).long(), num_classes=self.n_classes).type(torch.float)
        t = t.to(x.device)

        # embed context, time step
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        up1 = self.up0(hiddenvec)
        # Condition each decoder stage: scale by the context embedding, shift by the time embedding.
        up2 = self.up1(cemb1 * up1 + temb1, down2)
        up3 = self.up2(cemb2 * up2 + temb2, down1)

        # Final skip: concatenate with the init_conv features before the output head.
        return self.out(torch.cat((up3, x), 1))


In [None]:
# Model architecture adapted from [1] https://github.com/TeaPearce/Conditional_Diffusion_MNIST/
class DDPM(nn.Module):
    """Class-conditional DDPM that trains and samples a ContextUnet noise predictor.

    Implements Algorithm 1 (training) and Algorithm 2 (sampling) of the DDPM paper;
    adapted from [1].

    Args:
        n_T: number of diffusion steps (default 400, the value previously hard-coded).
        betas: (beta_1, beta_T) end points of the linear noise schedule.
    """

    def __init__(self, n_T=400, betas=(1e-4, 0.02)):
        super(DDPM, self).__init__()
        self.net = ContextUnet(in_channels=3, n_feat=48, n_classes=100)

        # register_buffer makes each schedule tensor an attribute (e.g. self.sqrtab) that
        # follows the module through .to(device) and is saved with it.
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

    @property
    def n_T(self):
        """Number of diffusion steps, read back from the registered schedule (length T + 1).

        Derived rather than stored so whole-module pickles saved before this existed
        (torch.save(ddpm)) still work after loading.
        """
        return self.alpha_t.shape[0] - 1

    # algorithm 1 in DDPM paper
    def forward(self, x, c):
        """Return the noise-prediction MSE loss for a batch.

        Args:
            x: clean images, (B, C, H, W).
            c: integer class labels, (B,).
        """
        n_T = self.n_T
        # One timestep per image, uniform over {1, ..., T}; larger t means more noise.
        _ts = torch.randint(1, n_T + 1, (x.shape[0],), device=x.device)
        eps = torch.randn_like(x)  # eps ~ N(0, I)
        # Jump straight to the noised image x_t with the closed-form forward process.
        x_t = self.sqrtab[_ts] * x + self.sqrtmab[_ts] * eps
        # The network sees t / T and is trained to recover the added noise.
        return F.mse_loss(eps, self.net(x_t, c, _ts / n_T))

    # algorithm 2 in DDPM paper
    def sample(self, n_sample, size, c, x_i=None):
        """Generate ``n_sample`` images of shape ``size`` conditioned on labels ``c``.

        Args:
            n_sample: number of images.
            size: per-image shape, e.g. (3, 32, 32).
            c: integer class labels, (n_sample,); moved to the model's device.
            x_i: optional starting noise of shape (n_sample, *size); drawn from N(0, I) when
                omitted. Passing the same tensor lets several calls start from identical noise.

        Returns:
            Generated images, on the model's device.
        """
        dev = self.sqrtab.device
        n_T = self.n_T
        x_i = torch.randn(n_sample, *size, device=dev) if x_i is None else x_i.to(dev)
        c = c.to(dev)

        # Walk from t = T (pure noise) down to t = 1.
        for i in range(n_T, 0, -1):
            t_is = torch.full((n_sample, 1, 1, 1), i / n_T, device=dev)
            # No fresh noise is injected on the final step.
            z = torch.randn(n_sample, *size, device=dev) if i > 1 else 0
            eps = self.net(x_i, c, t_is)
            # Remove the predicted noise, then add noise scaled by the schedule.
            x_i = self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i]) + self.sqrt_beta_t[i] * z
        return x_i

 

In [None]:
ddpm = DDPM().to(device)
step = 0
loss_smooth = 2.00
# Adam optimiser
optimiser = torch.optim.Adam(ddpm.parameters(), lr=0.001)
# Number of diffusion steps; must match the 400-step schedule built in DDPM.__init__.
T = 400

# Count parameters once. parameters_to_vector builds a full flat copy of every weight, and the
# previous code did that twice just to take its length.
n_params = sum(p.numel() for p in ddpm.parameters())
print(f'> Number of model parameters {n_params}')
if n_params > 1000000:
    print("> Warning: you have gone over your parameter budget and will have a grade penalty!")


In [None]:
import torchvision.utils as vutils
from torch.optim.lr_scheduler import StepLR

# Decay the learning rate by a factor of 0.97 every 1000 optimiser steps.
scheduler = StepLR(optimiser, step_size=1000, gamma=0.97)

while step < 50000:
    # (Re-)enter training mode; the preview below switches to eval mode.
    ddpm.train()

    # Next labelled batch (the iterator cycles over the training set forever).
    x, c = next(train_iterator)
    x, c = x.to(device), c.to(device)

    # MSE between the true and predicted noise at a random timestep per image.
    loss = ddpm(x, c)
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    scheduler.step()

    # Exponential moving average of the loss, for display only.
    loss_smooth = loss_smooth + 0.01 * (loss.item() - loss_smooth)

    if step % 1000 == 0:
        ddpm.eval()
        print(f'step: {step:4d} train loss: {loss_smooth:.3f}')

        with torch.no_grad():
            # Preview up to 64 samples conditioned on this batch's labels, next to the real
            # images. Batches hold 124 images, but the last batch of an epoch can be smaller.
            n_show = min(64, x.shape[0])
            c = c[:n_show]
            samples_batch = ddpm.sample(n_show, (3, 32, 32), c).to(x.device)
            x = x[:n_show]

            # Map [-1, 1] -> [0, 1] and clamp, since raw samples can overshoot the data range.
            images = (torch.cat([samples_batch, x], dim=0) * 0.5 + 0.5).clamp(0, 1)
            # 16 images per row: generated samples come first, then the real images.
            img_grid = vutils.make_grid(images, nrow=16)
            plt.figure(figsize=(12, 6))
            plt.imshow(img_grid.permute(1, 2, 0).cpu().numpy())
            plt.axis('off')
            plt.show()

    step += 1

In [None]:
# Save model.
# model_path = 'Models/generative.pth'
# torch.save(ddpm, model_path)


# Load model.
model_path = 'Models/generative.pth'
# The checkpoint is a whole pickled nn.Module (torch.save(ddpm) above), not a state_dict, so
# weights_only=False is required on PyTorch versions that default to weights_only=True.
# Unpickling can execute arbitrary code: only load checkpoints you created or trust.
# map_location lets a GPU-saved model load on the current device (e.g. CPU-only machines).
ddpm = torch.load(model_path, map_location=device, weights_only=False)


In [None]:
# View a row of 8 class-conditional samples with random labels.
# NOTE(review): despite the old title, nothing here interpolates. The old call passed a 4th
# (initial-noise) argument that DDPM.sample(n_sample, size, c) does not accept.
import matplotlib.pyplot as plt
import torchvision

ddpm.eval()

n_sample = 8
size = (3, 32, 32)

with torch.no_grad():
    c = torch.randint(0, 100, (n_sample,))
    samples = ddpm.sample(n_sample, size, c).to(device).detach()

# Map from [-1, 1] to [0, 1] and clamp. normalize=True is not used because it re-stretches
# each grid to its own min/max, which distorts colours.
grid = torchvision.utils.make_grid((samples * 0.5 + 0.5).clamp(0, 1), nrow=n_sample, padding=2)

# Convert the grid to a NumPy array with channels last for imshow
grid_np = grid.cpu().numpy().transpose(1, 2, 0)

plt.figure(figsize=(12, 12))
plt.imshow(grid_np)
plt.axis('off')
plt.show()


In [None]:
%%capture
!pip install clean-fid
import os
from cleanfid import fid
from torchvision.utils import save_image

In [None]:
# define directories
real_images_dir = 'real_images'
generated_images_dir = 'generated_images'
num_samples = 10000 # do not change

ddpm.eval()


# create/clean the directories
def setup_directory(directory):
    """(Re)create ``directory`` as an empty folder, deleting any previous contents."""
    import shutil  # pure-Python, cross-platform replacement for a `!rm -r` shell escape
    if os.path.exists(directory):
        shutil.rmtree(directory)  # remove any existing (old) data
    os.makedirs(directory)


def to_unit_range(images):
    """Map images from the model/dataset range [-1, 1] to [0, 1], clamping any overshoot.

    save_image expects [0, 1] and clips values outside it. Saving the raw [-1, 1] tensors
    turned everything below mid-grey black in both the real and the generated sets.
    """
    return (images * 0.5 + 0.5).clamp(0, 1)


setup_directory(real_images_dir)
setup_directory(generated_images_dir)

# generate and save 10k model samples
num_generated = 0
while num_generated < num_samples:
    print(num_generated)

    with torch.no_grad():
        c = torch.randint(0, 100, (64,))

        samples_batch = ddpm.sample(64, (3, 32, 32), c).to(device).detach()

        for image in samples_batch:
            if num_generated >= num_samples:
                break
            save_image(to_unit_range(image), os.path.join(generated_images_dir, f"gen_img_{num_generated}.png"))
            num_generated += 1

# save 10k images from the CIFAR-100 test dataset. Index the dataset directly instead of
# drawing from test_iterator: that loader can drop its last partial batch, and cycling it to
# reach 10k would repeat images from the start of the set while skipping the tail.
real_dataset = test_loader.dataset
assert len(real_dataset) >= num_samples, "test set has fewer images than num_samples"
num_saved_real = 0
while num_saved_real < num_samples:
    image, _ = real_dataset[num_saved_real]
    save_image(to_unit_range(image), os.path.join(real_images_dir, f"real_img_{num_saved_real}.png"))
    num_saved_real += 1

In [None]:
# Compute FID (clean-fid, "clean" mode) between the saved real and generated image folders.
# The folder names are restated so this cell does not depend on the previous cell's variables.
real_images_dir, generated_images_dir = 'real_images', 'generated_images'
score = fid.compute_fid(real_images_dir, generated_images_dir, mode="clean")
print(f"FID score: {score}")

In [None]:
# View nearest neighbours using LPIPS [2]: (reference at top)
import lpips

ddpm.eval()

# Condition every sample on one class: 88 is 'tiger' in class_names.
# For a mixed batch use: context1 = torch.randint(0, 100, (64,)).to(device)
TARGET_LABEL = 88
with torch.no_grad():
    context1 = torch.full((64,), TARGET_LABEL, dtype=torch.long).to(device)
    # Draw 64 samples and rescale them from [-1, 1] to [0, 1].
    sample1 = ddpm.sample(64, (3, 32, 32), context1).to(device).detach() * 0.5 + 0.5

# LPIPS perceptual-distance model with the AlexNet backbone
loss_fn = lpips.LPIPS(net='alex').to(device)


# Function to calculate LPIPS distance between two images
def lpips_distance(img1, img2, metric=None):
    """LPIPS perceptual distance between two single images.

    Args:
        img1, img2: image tensors without a batch dimension, values expected in [0, 1].
        metric: LPIPS-style callable ``metric(a, b, normalize=...)`` on batched inputs;
            defaults to the notebook's global ``loss_fn``.

    Returns:
        The distance as a Python float.
    """
    if metric is None:
        metric = loss_fn
    with torch.no_grad():
        # LPIPS expects inputs in [-1, 1]. normalize=True rescales the [0, 1] images used here;
        # without it the distances were computed on wrongly scaled inputs.
        distance = metric(img1.unsqueeze(0), img2.unsqueeze(0), normalize=True).item()

    return distance


# Function to find nearest neighbors
def find_nearest_neighbors(target_image, image_list, k=6, distance_fn=None):
    """Return the ``k`` images in ``image_list`` closest to ``target_image``.

    Args:
        target_image: the query image.
        image_list: indexable collection of candidate images (e.g. a batch tensor).
        k: number of neighbours to return.
        distance_fn: ``distance_fn(a, b) -> float``; defaults to ``lpips_distance``.

    Returns:
        (neighbours, indices): the k closest images and their positions in ``image_list``,
        closest first. Ties keep their original order (stable sort).
    """
    if distance_fn is None:
        distance_fn = lpips_distance
    distances = [distance_fn(target_image, img) for img in image_list]
    closest = sorted(range(len(distances)), key=distances.__getitem__)[:k]
    return [image_list[i] for i in closest], closest

# Select a target image for finding neighbors
# Select a target image for finding neighbors
target_index = 0  # index of the target image within sample1
target_image = sample1[target_index].to(device)

# Find nearest neighbors
nearest_neighbors, indices = find_nearest_neighbors(target_image, sample1)

# One row with a panel per returned image. Entry 0 of the result is normally the target itself
# (distance 0), so panel 0 shows the target and panels 1.. show the remaining neighbours.
# A single consistent grid replaces the old mix of 1x4 and 1x7 subplot layouts.
fig, axes = plt.subplots(1, len(nearest_neighbors), figsize=(10, 5), squeeze=False)
axes = axes[0]

axes[0].imshow(target_image.cpu().numpy().transpose(1, 2, 0))
axes[0].set_title('Target Image')

for i, neighbor_img in enumerate(nearest_neighbors[1:], start=1):
    axes[i].imshow(neighbor_img.cpu().numpy().transpose(1, 2, 0))
    axes[i].set_title(f'Neighbor {i}')

for ax in axes:
    ax.axis('off')

# plt.savefig('LPIPSSameTIGER.png')

plt.show()

In [None]:
# Draw 64 samples with random class labels and show them as one grid (also saved to disk).
plt.figure(figsize=(12, 12))

ddpm.eval()

with torch.no_grad():
    c = torch.randint(0, 100, (64,))
    samples_batch = ddpm.sample(64, (3, 32, 32), c).to(device)

# Rescale from [-1, 1] to [0, 1] before tiling.
sample_grid = torchvision.utils.make_grid(samples_batch * 0.5 + 0.5)
plt.imshow(sample_grid.cpu().numpy().transpose(1, 2, 0), cmap=plt.cm.binary)
plt.axis("off")
plt.title("Generated Samples")
plt.savefig('generated_samples.png')

plt.show()

In [None]:
# Show n_per_label samples for each label in label_numbers, one row per label.
ddpm.eval()

# Define label number(s); each entry gets its own row.
label_numbers = [99]
n_per_label = 5

# squeeze=False keeps axes 2-D even for a single row, so axes[row, col] always works.
# (The old subplot index formula interleaved rows when label_numbers had several entries.)
fig, axes = plt.subplots(len(label_numbers), n_per_label,
                         figsize=(15, 3 * len(label_numbers)), squeeze=False)

for row, label_number in enumerate(label_numbers):
    print(class_names[label_number])

    # Generate samples for the given label
    with torch.no_grad():
        c = torch.full((n_per_label,), label_number)
        samples_batch = ddpm.sample(n_per_label, (3, 32, 32), c).to(device)

    for col in range(n_per_label):
        ax = axes[row, col]
        ax.imshow((samples_batch[col] * 0.5 + 0.5).cpu().numpy().transpose(1, 2, 0), cmap=plt.cm.binary)
        ax.axis("off")

    # Tag the row on its first panel. The old plt.text(2, label_number - 0.5, ...) used data
    # coordinates of whatever axes was current, so the label was not reliably visible.
    axes[row, 0].set_title(f"Label: {label_number}", color='red', fontweight='bold')

# Save the figure as an image
# plt.savefig('generated_samples_for_labels.png')

plt.show()


In [None]:
# Generate one sample for each of 5 hand-picked labels.
# Used to form the cherry-picked and bad-batch figures.
plt.figure(figsize=(12, 12))

ddpm.eval()

# Alternatives: torch.randint(0, 100, (64,)) for random labels, or [88] for a single class.
chosen_labels = [2, 11, 35, 46, 98]
with torch.no_grad():
    c = torch.tensor(chosen_labels)
    samples_batch = ddpm.sample(5, (3, 32, 32), c).to(device)

# Rescale from [-1, 1] to [0, 1], tile, and save.
picked_grid = torchvision.utils.make_grid(samples_batch * 0.5 + 0.5)
plt.imshow(picked_grid.cpu().numpy().transpose(1, 2, 0), cmap=plt.cm.binary)
plt.axis("off")  # Turn off axis labels
plt.savefig('bad_people_samples.png')  # Save the figure as an image

plt.show()

# Print the class name for each chosen label, one per line.
print("\n".join(class_names[label] for label in c.tolist()))
