In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torchvision as tv
import matplotlib.pyplot as plt
import torch.nn.functional as fn
import math
import random as r
import math as m

from cleanfid import fid
from torchvision.utils import save_image

from einops import rearrange, reduce, asnumpy, parse_shape, repeat, einsum
from einops.layers.torch import Rearrange, Reduce

import shutil

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f"Device is {device}!")

In [None]:
def _to_signed_unit_range(t):
    """Map a tensor with values in [0, 1] onto [-1, 1]."""
    return t * 2 - 1


# Training pipeline: random horizontal flip (p=0.5), PIL -> CHW float tensor in [0, 1]
# (ToTensor divides by 255), then rescale to [-1, 1].
transform = tv.transforms.Compose([
    tv.transforms.RandomHorizontalFlip(0.5),
    tv.transforms.ToTensor(),
    tv.transforms.Lambda(_to_signed_unit_range),
])

# Evaluation pipeline: same scaling, no augmentation.
transform_verify = tv.transforms.Compose([
    tv.transforms.ToTensor(),
    tv.transforms.Lambda(_to_signed_unit_range),
])

_data_root = "./Datasets"
dataset = tv.datasets.CIFAR100(root=_data_root, train=True, transform=transform, download=True)
dataset_verify = tv.datasets.CIFAR100(root=_data_root, train=False, transform=transform_verify, download=True)

num_classes = 100
image_channels = 3
batch_size = 256

# Both loaders shuffle and drop the final incomplete batch.
_loader_kwargs = dict(batch_size=batch_size, drop_last=True, shuffle=True)
dataloader = torch.utils.data.DataLoader(dataset, **_loader_kwargs)
dataloader_verify = torch.utils.data.DataLoader(dataset_verify, **_loader_kwargs)

In [None]:
# Inspect a single batch: the loader yields [image_batch, label_batch].
for images in dataloader:
    image_batch, label_batch = images[0], images[1]
    print(len(images))
    print(image_batch.shape)
    print(label_batch.shape)

    print(label_batch)
    break

In [None]:
# Preview 12 training images (batch indices 1..12) in a 2x6 grid.
# The dataset transform maps pixels to [-1, 1], but imshow expects floats in [0, 1];
# undo the scaling (and clamp) so the images are not clipped to black.
for images, labels in dataloader:
    print(images[0].shape)

    for x in range(1, 13):
        plt.subplot(2, 6, x)
        preview = ((images[x] + 1) / 2).clamp(0, 1)
        plt.imshow(preview.permute(1, 2, 0))  # CHW -> HWC
        plt.axis('off')
    break

In [None]:
# Keep one batch on the target device for interactive experiments.
images, labels = (tensor.to(device) for tensor in next(iter(dataloader)))

# Total number of diffusion timesteps used by the schedule and the sampler.
T = 400


def extract(a, t, x_shape):
    """Pick per-timestep coefficients and shape them for broadcasting.

    Returns ``a[t[i]]`` for every entry of ``t``, reshaped to
    ``(len(t), 1, ..., 1)`` with ``len(x_shape) - 1`` trailing singleton dims and
    placed on ``t``'s device. The gather uses ``t`` moved to the CPU, so ``a`` must
    be a CPU tensor.
    """
    trailing = (1,) * (len(x_shape) - 1)
    picked = a.gather(-1, t.cpu())
    return picked.reshape(t.shape[0], *trailing).to(t.device)


def cosine_schedule(timesteps, s=0.008):
    def f(t):
        return torch.cos((t / timesteps + s) / (1 + s) * 0.5 * torch.pi) ** 2
    x = torch.linspace(0, timesteps, timesteps + 1)
    alphas_cumprod = f(x) / f(torch.tensor([0]))
    betas = 1 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    betas = torch.clip(betas, 0.0001, 0.999)
    return betas


def linear_schedule(timesteps):
    """Betas spaced evenly from 0.0001 to 0.02 (inclusive) over ``timesteps`` steps."""
    return torch.linspace(0.0001, 0.02, steps=timesteps)


def sigmoid_beta_schedule(timesteps):
    """Betas following a sigmoid ramp between 0.0001 and 0.02.

    The sigmoid is evaluated on ``timesteps`` evenly spaced points in [-6, 6] and
    rescaled affinely into [beta_start, beta_end].
    """
    beta_start, beta_end = 0.0001, 0.02
    ramp = torch.sigmoid(torch.linspace(-6, 6, timesteps))
    return beta_start + (beta_end - beta_start) * ramp

# Schedule code adapted from https://huggingface.co/blog/annotated-diffusion
# and https://dzdata.medium.com/intro-to-diffusion-model-part-4-62bd94bd93fd

betas = sigmoid_beta_schedule(timesteps=T)

# alpha_t = 1 - beta_t and the running product alpha_bar_t.
alphas = 1. - betas
alphas_cumprod = alphas.cumprod(dim=0)
# alpha_bar_{t-1}, with 1.0 prepended for t = 0.
alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
sqrt_recip_alphas = (1.0 / alphas).sqrt()

# Coefficients of the closed-form forward process q(x_t | x_0).
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1. - alphas_cumprod).sqrt()

# Variance of the posterior q(x_{t-1} | x_t, x_0).
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

def sample_images(image, ts, noise=None):
    """Diffuse ``image`` to timestep(s) ``ts`` in closed form (a sample of q(x_t | x_0)).

    x_t = sqrt(alpha_bar_t) * image + sqrt(1 - alpha_bar_t) * noise

    Args:
        image: a single image or a batch; the batch dimension (if any) comes first.
        ts: an int or a tensor of timesteps; either scalar or one per batch element.
        noise: optional noise tensor shaped like ``image`` (standard normal if None).
    """
    if noise is None:
        noise = torch.randn_like(image)
    # Index the CPU-resident schedule tensors with CPU indices, then reshape to
    # (len(ts), 1, ..., 1) so the coefficients broadcast over the image dims.
    ts = torch.as_tensor(ts, device=sqrt_alphas_cumprod.device).long().reshape(-1)
    view_shape = (ts.shape[0],) + (1,) * (image.dim() - 1)
    coef_signal = sqrt_alphas_cumprod[ts].reshape(view_shape).to(image.device)
    coef_noise = sqrt_one_minus_alphas_cumprod[ts].reshape(view_shape).to(image.device)
    return coef_signal * image + coef_noise * noise

def extract(a, t, x_shape):
    """Return ``a[t]`` reshaped to ``(len(t), 1, ..., 1)`` to broadcast against ``x_shape``.

    The result lives on ``t``'s device; ``a`` must be on the CPU because the
    indices are moved there for the gather.

    NOTE(review): this redefines (and shadows) an identical ``extract`` declared
    earlier in this cell.
    """
    n = t.shape[0]
    singleton_dims = [1] * (len(x_shape) - 1)
    values = torch.gather(a, -1, t.cpu())
    return values.reshape(n, *singleton_dims).to(t.device)

def q_sample(x_start, t, noise=None):
    """Draw x_t from q(x_t | x_0) in closed form.

    x_t = sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * noise, where ``noise``
    defaults to a standard-normal draw shaped like ``x_start`` and ``t`` holds one
    integer timestep per batch element.
    """
    noise = torch.randn_like(x_start) if noise is None else noise
    shape = x_start.shape
    signal_coef = extract(sqrt_alphas_cumprod, t, shape)
    noise_coef = extract(sqrt_one_minus_alphas_cumprod, t, shape)
    return signal_coef * x_start + noise_coef * noise


# Inverse of the training transform, for visualisation: CHW tensor in [-1, 1]
# -> HWC uint8 PIL image.
reverse_transform = tv.transforms.Compose([
    tv.transforms.Lambda(lambda t: (t + 1) / 2),
    # Noisy/generated samples overshoot [-1, 1]; clamp so the uint8 cast below
    # saturates instead of producing wrapped-around pixel values.
    tv.transforms.Lambda(lambda t: t.clamp(0, 1)),
    tv.transforms.Lambda(lambda t: t * 255.),
    tv.transforms.Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
    # .numpy() requires a CPU tensor that does not require grad.
    tv.transforms.Lambda(lambda t: t.detach().cpu().numpy().astype(np.uint8)),
    tv.transforms.ToPILImage(),
])

def get_noisy_image(x_start, t):
    """Diffuse ``x_start`` to timestep(s) ``t`` with ``q_sample`` and return a PIL image.

    Singleton dimensions are squeezed away before ``reverse_transform`` converts the
    tensor to an image.
    """
    noised = q_sample(x_start, t=t)
    return reverse_transform(noised.squeeze())


In [None]:
def plot(imgs, with_orig=False, row_title=None, orig_img=None, figsize=None, **imshow_kwargs):
    """Show images in a grid, one row per inner list.

    Args:
        imgs: a list of images, or a list of such lists (one per row). Each image must be
            convertible with ``np.asarray`` (e.g. a PIL image).
        with_orig: prepend ``orig_img`` as the first column of every row.
        row_title: optional sequence of y-axis labels, one per row.
        orig_img: image for the first column when ``with_orig`` is True. If omitted, a
            global named ``image`` is used (legacy behaviour); a ValueError is raised if
            neither exists.
        figsize: matplotlib figure size in inches; defaults to 1.5 inches per cell.
        **imshow_kwargs: forwarded to ``Axes.imshow``.
    """
    if not isinstance(imgs[0], list):
        # Make a 2d grid even if there's just 1 row
        imgs = [imgs]

    if with_orig and orig_img is None:
        orig_img = globals().get('image')
        if orig_img is None:
            raise ValueError("with_orig=True requires `orig_img`")

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    if figsize is None:
        # Scale the canvas with the grid; a fixed 200x200-inch figure is enormous.
        figsize = (1.5 * num_cols, 1.5 * num_rows)
    fig, axs = plt.subplots(figsize=figsize, nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [orig_img] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()
    
# Visualise the forward process on one training image at timesteps 0, 50, ..., 350.
load_iter = iter(dataloader)
images, _ = next(load_iter)
timesteps_to_show = range(0, 400, 50)
plot([get_noisy_image(images[0], torch.tensor([step])) for step in timesteps_to_show])

In [None]:
#code partially taken from https://huggingface.co/blog/annotated-diffusion
#way to inject conditionals taken from https://github.com/TeaPearce/Conditional_Diffusion_MNIST

class DepthwiseSeparableConv(nn.Module):
    """Depthwise 3x3 conv then pointwise 1x1 conv, each followed by BatchNorm + ReLU.

    Channel count (``in_layers``) and spatial size (3x3 kernel, stride 1, padding 1)
    are preserved.
    """

    def __init__(self, in_layers):
        super().__init__()
        self.in_layers = in_layers
        # Module order fixes the state_dict keys (convolutions.0 ... convolutions.5);
        # keep it stable so saved checkpoints still load.
        depthwise = [nn.Conv2d(in_layers, in_layers, 3, 1, 1, groups=in_layers),
                     nn.BatchNorm2d(in_layers),
                     nn.ReLU()]
        pointwise = [nn.Conv2d(in_layers, in_layers, 1, 1),
                     nn.BatchNorm2d(in_layers),
                     nn.ReLU()]
        self.convolutions = nn.Sequential(*depthwise, *pointwise)

    def forward(self, x):
        return self.convolutions(x)

class Inception(nn.Module):
    """Residual bottleneck block computing BatchNorm(x + F(x)).

    F squeezes ``in_layers`` channels to ``bottleneck`` with a 1x1 conv, applies a
    ``DepthwiseSeparableConv``, and expands back to ``in_layers`` with a 1x1 conv and
    GroupNorm(8) (so ``in_layers`` must be divisible by 8).
    """

    def __init__(self, in_layers, bottleneck):
        super().__init__()
        self.in_layers = in_layers
        self.bottleneck = bottleneck
        self.bn = nn.BatchNorm2d(in_layers)

        # Built in the same order as the Sequential below so parameter init order
        # and state_dict keys (layers.0 ... layers.6) are unchanged.
        squeeze = [nn.Conv2d(in_layers, bottleneck, 1, 1),
                   nn.BatchNorm2d(bottleneck),
                   nn.ReLU()]
        mixer = DepthwiseSeparableConv(bottleneck)
        expand = [nn.Conv2d(bottleneck, in_layers, 1, 1),
                  nn.GroupNorm(8, in_layers),
                  nn.ReLU()]
        self.layers = nn.Sequential(*squeeze, mixer, *expand)

    def forward(self, x):
        residual = self.layers(x)
        return self.bn(x + residual)
    
class ResLayer(nn.Module):
    """Six stacked ``Inception`` blocks sharing ``layers_num`` channels and one bottleneck size."""

    def __init__(self, layers_num, bottleneck_size):
        super().__init__()
        self.layers_num = layers_num
        blocks = [Inception(layers_num, bottleneck_size) for _ in range(6)]
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        return self.layers(x)
    
    
class UNetBlock(nn.Module):
    """ResLayer -> GroupNorm(8) -> 1x1 projection from ``in_layers`` to ``out_layers`` channels.

    With ``add_attention=True`` an ``Attention`` module is applied to the projected
    ``out_layers``-channel features. NOTE(review): ``Attention`` is not defined in
    the visible code; it must exist elsewhere before this option is used.
    """

    def __init__(self, in_layers, bottleneck_size, out_layers, add_attention=False):
        super().__init__()

        self.in_layers = in_layers
        self.out_layers = out_layers

        self.layers = ResLayer(in_layers, bottleneck_size)

        # Attribute name (including spelling) is part of the checkpoint's state_dict keys.
        self.group_normilisation = nn.GroupNorm(8, in_layers)

        self.bottleneck_out = nn.Conv2d(in_layers, out_layers, 1, 1)

        self.attention = None
        if add_attention:
            self.attention = Attention(self.out_layers, 4, self.out_layers // 4)

    def forward(self, x):
        x = self.layers(x)
        x = self.group_normilisation(x)
        x = self.bottleneck_out(x)
        # Attention is sized for `out_layers` channels, so it must run after the
        # projection. Compare against None: nn.Module truthiness is not reliable.
        if self.attention is not None:
            x = self.attention(x)
        return x


class SkipConnection(nn.Module):
    """1x1 conv + BatchNorm + ReLU that maps a skip tensor from ``in_layers`` to ``out_layers`` channels."""

    def __init__(self, in_layers, out_layers):
        super().__init__()
        projection = nn.Conv2d(in_layers, out_layers, 1, 1)
        self.convolution = nn.Sequential(projection,
                                         nn.BatchNorm2d(out_layers),
                                         nn.ReLU())

    def forward(self, x):
        return self.convolution(x)


class Down(nn.Module):
    """Encoder stage: a ``UNetBlock`` followed by 2x2 max-pooling.

    ``forward`` returns ``(pooled, features)``, where ``features`` is the pre-pooling
    map used as the skip connection.
    """

    def __init__(self, in_layers, out_layers, bottleneck_size, add_attention=False):
        super().__init__()
        self.convolutions = UNetBlock(in_layers, bottleneck_size, out_layers, add_attention)

    def forward(self, x):
        features = self.convolutions(x)
        pooled = fn.max_pool2d(features, 2, 2)
        return pooled, features

    
class Up(nn.Module):
    """Decoder stage: optional channel concat with ``skip``, ``UNetBlock``, then 2x upsampling.

    The (concatenated) input must have ``in_layers`` channels; the output has
    ``out_layers`` channels and twice the spatial size (2x2 stride-2 transposed conv).
    """

    def __init__(self, in_layers, out_layers, bottleneck_size, add_attention=False):
        super().__init__()
        # Construction order kept: upsampler first, then the conv block.
        self.upsampler = nn.ConvTranspose2d(out_layers, out_layers, 2, 2)
        self.convolutions = UNetBlock(in_layers, bottleneck_size, out_layers, add_attention)

    def forward(self, x, skip=None):
        merged = x if skip is None else torch.concat([x, skip], dim=1)
        return self.upsampler(self.convolutions(merged))
    

class UNet(nn.Module):
    """Class- and timestep-conditioned U-Net that predicts noise for 32x32 images.

    Shape flow for a (B, image_channels, 32, 32) input:
      * ``entry`` lifts to 32 channels; ``ConvDown1..3`` leave skips of 64@32x32,
        128@16x16, 184@8x8 and a 184@4x4 map.
      * ``downsampler`` averages to 184@1x1; ``middle_to_up`` re-expands to 184@4x4.
      * ``ConvUp3`` -> 64@8x8, concatenated with the squeezed 8x8 skip (128 channels);
        ``ConvUp2`` -> 32@16x16; ``ConvUp1`` (with the squeezed 16x16 skip) -> 36@32x32;
        concatenated with the squeezed 32x32 skip and projected to ``image_channels``.
    Conditioning: one-hot labels and scalar timesteps pass through small MLPs; before
    ``ConvUp3`` and ``ConvUp2`` features are multiplied by the label embedding and
    shifted by the timestep embedding.

    Submodule names and construction order are unchanged so existing checkpoints load.
    """

    def __init__(self):
        super().__init__()

        layer1 = 64
        layer2 = 128
        layer3 = 184

        out = 72

        self.entry = nn.Conv2d(image_channels, 32, 1, 1)

        self.ConvDown1 = Down(32, layer1, 32)
        self.ConvDown2 = Down(layer1, layer2, 64)
        self.ConvDown3 = Down(layer2, layer3, 92)

        self.downsampler = nn.AvgPool2d(4)

        self.middle_to_up = nn.Sequential(nn.ConvTranspose2d(layer3, layer3, 2, 1),
                                          nn.GroupNorm(8, layer3),
                                          nn.ReLU(),
                                          nn.ConvTranspose2d(layer3, layer3, 2, 2),
                                          nn.GroupNorm(8, layer3),
                                          nn.ReLU())

        self.ConvUp3 = Up(layer3, layer2//2, 40)
        self.ConvUp2 = Up(layer2, layer1//2, 22)
        self.ConvUp1 = Up(layer1, out//2, 12)

        self.final = nn.Sequential(nn.Conv2d(out, image_channels, 1, 1))

        self.squeezer1 = SkipConnection(layer1, out//2)
        self.squeezer2 = SkipConnection(layer2, layer1//2)
        self.squeezer3 = SkipConnection(layer3, layer2//2)

        self.t_encoder1 = nn.Sequential(nn.Linear(1, 180),
                                        nn.ReLU(),
                                        nn.Linear(180, layer2),
                                        nn.BatchNorm1d(layer2))

        self.label_encoder1 = nn.Sequential(nn.Linear(num_classes, 180),
                                            nn.ReLU(),
                                            nn.Linear(180, layer2),
                                            nn.BatchNorm1d(layer2))

        self.t_encoder2 = nn.Sequential(nn.Linear(1, 180),
                                        nn.ReLU(),
                                        nn.Linear(180, layer3),
                                        nn.BatchNorm1d(layer3))

        self.label_encoder2 = nn.Sequential(nn.Linear(num_classes, 180),
                                            nn.ReLU(),
                                            nn.Linear(180, layer3),
                                            nn.BatchNorm1d(layer3))

    def forward(self, x, labels, t):
        """Predict the noise in ``x``.

        Args:
            x: image batch (B, image_channels, 32, 32).
            labels: integer class ids, shape (B,); may live on any device.
            t: integer timesteps, shape (B,); may live on any device.

        Returns:
            A tensor shaped like ``x``.
        """
        # Put the conditioning on x's device (not the global `device`) so the model
        # also works when it lives elsewhere, e.g. on CPU while CUDA is available.
        ch = fn.one_hot(labels.long(), num_classes=num_classes).to(device=x.device, dtype=torch.float32)

        label_encoding2 = self.label_encoder2(ch)
        label_encoding1 = self.label_encoder1(ch)

        t = t.to(device=x.device, dtype=torch.float32).unsqueeze(1)  # (B,) -> (B, 1)
        t_encoding2 = self.t_encoder2(t)
        t_encoding1 = self.t_encoder1(t)

        x, skip1 = self.ConvDown1(self.entry(x))
        x, skip2 = self.ConvDown2(x)
        x, skip3 = self.ConvDown3(x)

        hidden_vec = self.downsampler(x)

        x = self.middle_to_up(hidden_vec)
        label_encoding2 = repeat(label_encoding2, "b c -> b c w h", w=x.shape[2], h=x.shape[3])
        t_encoding2 = repeat(t_encoding2, "b c -> b c w h", w=x.shape[2], h=x.shape[3])

        x = self.ConvUp3(x*label_encoding2 + t_encoding2)

        x = torch.concat([x, self.squeezer3(skip3)], dim=1)
        label_encoding1 = repeat(label_encoding1, "b c -> b c w h", w=x.shape[2], h=x.shape[3])
        t_encoding1 = repeat(t_encoding1, "b c -> b c w h", w=x.shape[2], h=x.shape[3])
        x = self.ConvUp2(x*label_encoding1 + t_encoding1)
        x = self.ConvUp1(x, self.squeezer2(skip2))

        return self.final(torch.concat([x, self.squeezer1(skip1)], dim=1))

In [None]:
# Checkpoint path for this experiment.
# NOTE(review): the training cell reassigns this to "ModelSave/best_model" before saving.
model_save_path = "ModelSave/linear_model"

In [None]:
@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    """Generate ``batch_size`` images with random class labels from a ``DDPM``-like model.

    Starts from standard-normal noise of shape
    (batch_size, channels, image_size, image_size) on the model's device and runs the
    full reverse process via ``model.p_sample_loop``. The result has the shape of the
    starting noise. NOTE: the UNet in this notebook is built for 3x32x32 images.
    """
    model_device = next(model.parameters()).device
    start_noise = torch.randn(batch_size, channels, image_size, image_size, device=model_device)
    return model.p_sample_loop(batch_size, noise_batch=start_noise)


class DDPM(nn.Module):
    """Denoising diffusion model wrapping a class-conditional ``UNet`` noise predictor.

    * ``forward(x, labels)`` returns the noise-prediction MSE at random timesteps.
    * ``p_sample`` / ``p_sample_loop`` / ``sample`` run the reverse process over ``T``
      steps using the module-level schedule tensors.
    * ``interpolate`` / ``interpolate_simple`` build rows of images between endpoints.

    The sampling helpers allocate on the module-level ``device`` and assume 3x32x32 images.
    """

    def __init__(self,):
        super().__init__()
        self.net = UNet()  # attribute name is part of the checkpoint's state_dict keys

    @torch.no_grad()
    def interpolate(self):
        """Sample 4 random pairs of images and interpolate between each pair.

        Returns:
            A list of 4 rows; each row is ``[x1, *interpolants, x2]`` with every entry a
            (3, 32, 32) tensor.
        """
        num_endpoints = 8  # 4 pairs
        labels = torch.randint(low=0, high=num_classes, size=(num_endpoints,))

        rows = []
        for start in range(0, num_endpoints, 2):
            pair_labels = labels[start:start + 2]
            x1, x2 = self.p_sample_loop(2, pair_labels)
            between = self.interpolate_pair(x1, x2, None, pair_labels[0], pair_labels[1])
            # interpolate_pair returns (1, C, H, W) tensors; drop the batch dim to match x1/x2.
            rows.append([x1] + [z.squeeze(0) for z in between] + [x2])
        return rows

    @torch.no_grad()
    def interpolate_simple(self):
        """Interpolate linearly between pairs of starting noises, then denoise each row.

        Produces 8 rows of 8 images. Each row uses two fresh noise tensors, the two
        endpoints plus 6 evenly spaced points between them (weights k/7). The first 4
        images of a row share one random label; the last 4 get independent random labels.

        Returns:
            A flat, row-major list of 64 tensors shaped (3, 32, 32).
        """
        latent_vectors = torch.randn(16, 3, 32, 32).to(device)
        result = []
        for i in range(0, 16, 2):
            vec1, vec2 = latent_vectors[i], latent_vectors[i + 1]

            random_label = torch.randint(low=0, high=num_classes, size=(1,)).item()
            interpolation_labels = torch.cat([torch.full((4,), random_label),
                                              torch.randint(low=0, high=num_classes, size=(4,))])

            points = [vec1] + [torch.lerp(vec1, vec2, k / 7) for k in range(1, 7)] + [vec2]
            interpolations = torch.stack(points, dim=0)

            result.extend(self.p_sample_loop(8, interpolation_labels, interpolations))
        return result

    @torch.no_grad()
    def interpolate_pair(self, x1, x2, t, label1, label2, schedule=None):
        """Interpolate between two images by partially noising, mixing and denoising.

        For each ``noise_level`` in ``schedule``, both endpoints are diffused to that
        timestep with shared noise and mixed linearly. The mix is then denoised back to
        t = 0. Mixing weights are evenly spaced strictly between 0 and 1, so the outputs
        move from ``x1`` toward ``x2``. The first half is conditioned on ``label1`` and
        the second half on ``label2``.

        Args:
            x1, x2: endpoint images, shape (C, H, W).
            t: unused; kept for backward compatibility with existing callers.
            label1, label2: integer class ids (Python ints or 0-d tensors).
            schedule: noise levels (each < T), one output per entry. Defaults to
                [50, 100, 200, 200, 100, 50].

        Returns:
            A list of ``len(schedule)`` tensors, each shaped (1, C, H, W).
        """
        if schedule is None:
            schedule = [50, 100, 200, 200, 100, 50]
        dev = x1.device
        # Always noise the clean endpoints (never the previous iteration's noisy ones).
        pair = torch.stack([x1, x2.to(dev)], dim=0)
        num_steps = len(schedule)

        output = []
        for step, noise_level in enumerate(schedule):
            ts = torch.full((2,), noise_level, dtype=torch.long, device=dev)
            # Shared noise keeps the two diffused endpoints comparable.
            shared_noise = torch.randn_like(x1).unsqueeze(0).expand_as(pair)
            diffused = q_sample(pair, ts, shared_noise)

            # Linear mixing; a geometric mix would be undefined (NaN) for negative values.
            weight = (step + 1) / (num_steps + 1)
            latent_z = torch.lerp(diffused[0], diffused[1], weight).unsqueeze(0)
            label = label1 if step < num_steps // 2 else label2
            latent_label = torch.as_tensor(label).reshape(1)

            # q_sample at index `noise_level` yields x at that level, so the first
            # reverse step uses the same index.
            for k in reversed(range(noise_level + 1)):
                k_t = torch.full((1,), k, device=dev, dtype=torch.long)
                latent_z = self.p_sample(latent_z, latent_label, k_t, k)
            output.append(latent_z)

        return output

    def forward(self, x, labels):
        """Return the noise-prediction MSE for batch ``x`` with class ``labels``.

        One timestep is drawn uniformly from [0, T) per image. ``x`` is noised with
        ``q_sample``, and the UNet's prediction is compared with the true noise.
        """
        ts = torch.randint(0, T, (x.shape[0],)).to(x.device)
        eps = torch.randn_like(x)
        noisy_image = q_sample(x, ts, eps)
        noise_prediction = self.net(noisy_image, labels, ts)
        return fn.mse_loss(eps, noise_prediction)

    def p_sample(self, x, labels, t, t_index):
        """One reverse-diffusion step from x_t to x_{t-1}.

        Computes the model mean
        1/sqrt(alpha_t) * (x - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x, labels, t)).
        Unless ``t_index == 0``, it adds Gaussian noise scaled by the posterior std.

        Args:
            x: current noisy batch x_t.
            labels: class ids for the batch.
            t: LongTensor of timesteps, one per batch element.
            t_index: the integer timestep; noise is skipped on the final step (0).
        """
        betas_t = extract(betas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
        sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)

        model_mean = sqrt_recip_alphas_t * (x - betas_t * self.net(x, labels, t) / sqrt_one_minus_alphas_cumprod_t)

        if t_index == 0:
            return model_mean
        else:
            posterior_variance_t = extract(posterior_variance, t, x.shape)
            noise = torch.randn_like(x)

            return model_mean + torch.sqrt(posterior_variance_t) * noise

    @torch.no_grad()
    def p_sample_loop(self, num_samples, labels_in=None, noise_batch=None):
        """Run the full reverse process from t = T-1 down to 0.

        Args:
            num_samples: batch size.
            labels_in: class ids, shape (num_samples,); uniformly random if None.
            noise_batch: starting noise; standard normal (num_samples, 3, 32, 32) on
                ``device`` if None.

        Returns:
            The denoised batch, shaped like the starting noise.
        """
        if labels_in is None:
            labels_in = torch.randint(low=0, high=num_classes, size=(num_samples,))

        if noise_batch is None:
            img = torch.randn((num_samples, 3, 32, 32), device=device)
        else:
            img = noise_batch

        for i in reversed(range(0, T)):
            ts = torch.full((num_samples,), i, device=img.device, dtype=torch.long)
            img = self.p_sample(img, labels_in, ts, i)
        return img

    @torch.no_grad()
    def sample(self, noise_batch):
        """Denoise ``noise_batch`` over all T steps with uniformly random class labels.

        Uses the same reverse step as ``p_sample_loop`` (posterior-variance noise).
        """
        num_samples = noise_batch.shape[0]
        random_labels = torch.randint(low=0, high=num_classes, size=(num_samples,))
        return self.p_sample_loop(num_samples, random_labels, noise_batch)
    
# Diffusion model on the selected device; Adam with an exponential learning-rate
# decay (x0.995 per scheduler.step()).
model = DDPM().to(device)
optim = torch.optim.Adam(model.parameters(), lr=0.003)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optim, gamma=0.995)

In [None]:
# Restore trained weights. map_location lets a checkpoint saved on GPU load on a
# CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load("diffusion_model_final", map_location=device))

In [None]:

# Draw 10 samples and show them as a 2x5 grid.
with torch.no_grad():
    xh = model.p_sample_loop(10)
img_grid = rearrange(xh, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=2)
# Samples live in [-1, 1]; map them to [0, 1] (clamping overshoot) before scaling to 0..255.
img_grid = ((img_grid + 1) / 2).clamp(0, 1)
plt.imshow((img_grid.cpu() * 255).int().numpy())
plt.show()

del xh

In [None]:
n_params = sum(param.numel() for param in model.parameters())
print(f"Parameter Number: {n_params}")

In [None]:

# Resume from the best checkpoint written by the training loop.
# map_location makes a GPU-saved checkpoint loadable on CPU (and vice versa).
model.load_state_dict(torch.load("ModelSave/best_model", map_location=device))

In [None]:
model_save_path = "ModelSave/best_model"

epochs = 3000                # NOTE(review): not used by the loop below
max_steps = 100000           # stop after this many optimiser steps in total
train_batches_count = 2000   # cap on training batches per epoch
preview_every = 5            # show generated samples every N epochs

epoch = 0
step_count = 0
min_loss = 10.0
do_continue = True

# torch.save does not create missing directories.
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)

while do_continue:
    model.train()

    losses = []
    for batch_idx, (images, labels) in enumerate(dataloader):
        if batch_idx >= train_batches_count:
            break

        images, labels = images.to(device), labels.to(device)

        loss = model(images, labels)
        optim.zero_grad()
        loss.backward()
        optim.step()
        step_count += 1
        # Record before a possible early stop so `losses` is never empty.
        losses.append(loss.item())

        if step_count >= max_steps:
            do_continue = False
            break

    model.eval()
    with torch.no_grad():
        test_losses = [model(x.to(device), y.to(device)).item() for x, y in dataloader_verify]

        if epoch % preview_every == 0:
            preview_count = 8
            xh = model.p_sample_loop(preview_count)
            img_grid = rearrange(xh, 'b c h w -> h (b w) c')
            # [-1, 1] -> [0, 1], clamping any overshoot before scaling to 0..255.
            img_grid = ((img_grid + 1) / 2).clamp(0, 1)
            plt.imshow((img_grid.cpu() * 255).int().numpy())
            plt.show()

    epoch += 1

    average_loss = sum(losses) / len(losses) if losses else float('inf')
    average_test_loss = sum(test_losses) / len(test_losses) if test_losses else float('nan')
    # NOTE(review): "best" is judged by training loss; the test loss may be a better criterion.
    if average_loss < min_loss:
        min_loss = average_loss
        torch.save(model.state_dict(), model_save_path)

    scheduler.step()
    print(f"At epoch {epoch}: average loss {average_loss}, average test loss {average_test_loss}")
    print(f"Num Steps: {step_count}")

In [None]:
# Output folders for the FID comparison, and how many images each should hold.
real_images_dir, generated_images_dir = 'real_images', 'generated_images'
num_samples = 10_000

# create/clean the directories
def setup_directory(directory):
    """Make sure ``directory`` exists and is empty.

    Creates it when missing. Otherwise it deletes every file, symlink and
    sub-directory inside it; entries that fail to delete are reported and skipped.
    """
    if not os.path.exists(directory):
        os.makedirs(directory)
        return

    for entry_name in os.listdir(directory):
        entry_path = os.path.join(directory, entry_name)
        try:
            if os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)
        except Exception as err:
            print('Failed to delete %s. Reason: %s' % (entry_path, err))

setup_directory(real_images_dir)
setup_directory(generated_images_dir)

# Generate `num_samples` model samples (random labels) and write them as PNGs.
gen_batch_size = 100
num_generated = 0
model.eval()
with torch.no_grad():
    while num_generated < num_samples:
        samples_batch = model.p_sample_loop(gen_batch_size).cpu()
        # Samples are in [-1, 1]; save_image expects [0, 1] and clamps anything
        # outside it, so rescale first.
        samples_batch = ((samples_batch + 1) / 2).clamp(0, 1)

        for image in samples_batch:
            if num_generated >= num_samples:
                break
            save_image(image, os.path.join(generated_images_dir, f"gen_img_{num_generated}.png"))
            num_generated += 1

        # Report progress once per 1000 images.
        if num_generated % 1000 == 0:
            print(f"{num_generated} images generated!!!")
            
transform_fid = tv.transforms.Compose([
    tv.transforms.ToTensor(),
    tv.transforms.Lambda(lambda t: (t * 2) - 1)
])

dataset_fid = tv.datasets.CIFAR100(root="./Datasets", train=False, transform=transform_fid, download=True)
# Read from dataset_fid in batches; keep every image (no drop_last) so a full
# 10k test split can be saved.
dataloader_fid = torch.utils.data.DataLoader(dataset_fid, batch_size=100, drop_last=False, shuffle=True)

# Save up to `num_samples` images from the CIFAR-100 test split.
num_saved_real = 0
for real_samples_batch, _ in dataloader_fid:
    if num_saved_real >= num_samples:
        break
    for image in real_samples_batch:
        if num_saved_real >= num_samples:
            break
        # transform_fid maps to [-1, 1]; save_image expects [0, 1].
        save_image((image + 1) / 2, os.path.join(real_images_dir, f"real_img_{num_saved_real}.png"))
        num_saved_real += 1

In [None]:
# Clean-FID between the saved real and generated image folders.
score = fid.compute_fid(
    real_images_dir,
    generated_images_dir,
    mode="clean",
    num_workers=0,
)
print(f"FID score: {score}")

In [None]:
# 64 random samples arranged as an 8x8 grid.
model.eval()
with torch.no_grad():
    sample_count = 64
    xh = model.p_sample_loop(sample_count)
    grid_hwc = rearrange(xh, '(b1 b2) c h w -> (b1 h) (b2 w) c', b1=8)
    img_grid = grid_hwc.add(1).div(2)  # [-1, 1] -> [0, 1]
    pixels = img_grid.cpu().mul(255).int().numpy()
    plt.imshow(pixels)
    plt.show()

In [None]:
# Eval mode (BatchNorm running stats) with autograd disabled for sampling.
model.eval()
with torch.set_grad_enabled(False):
    interpolated_rows = model.interpolate_simple()  # flat list: 8 rows x 8 images

In [None]:
# Map samples from [-1, 1] to [0, 1] (clamping overshoot) and convert to NumPy.
tensor_list_np = [((img + 1) / 2).clamp(0, 1).cpu().numpy() for img in interpolated_rows]

# Display the 64 images in an 8x8 grid, one interpolation per row.
fig, axs = plt.subplots(8, 8, figsize=(16, 16))

for i in range(8):
    for j in range(8):
        index = i * 8 + j
        axs[i, j].imshow(tensor_list_np[index].transpose(1, 2, 0))  # CHW -> HWC for RGB
        axs[i, j].axis('off')

plt.show()