In [1]:
COLAB: bool = False  # set to True when running this notebook on Google Colab

In [2]:
import os

if COLAB:
    if os.getcwd() != '/content/diffusion-models-basics/main':
        !git clone https://github.com/NikiforovG/diffusion-models-basics.git
        !cd diffusion-models-basics && git checkout develop
        os.chdir('/content/diffusion-models-basics/main')

    from google.colab import drive

    drive.mount('/content/drive')
    folder = '/content/drive/MyDrive/Colab Notebooks/diffusion-models-basics/'
else:
    folder = './'

In [3]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import HTML
from matplotlib.animation import FuncAnimation, PillowWriter

In [4]:
from src.ddpm import DDPM
from src.model import ContextUnet
from src.utils import norm_all

In [5]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
# Diffusion process: `time_steps` noising/denoising steps. beta1/beta2 are passed to DDPM,
# presumably as the endpoints of its noise schedule (see src/ddpm.py to confirm).
time_steps, beta1, beta2 = 500, 1e-4, 0.02
ddpm = DDPM(beta1, beta2, time_steps, device)

In [7]:
# U-Net hyper-parameters. They must match the checkpoint loaded below,
# otherwise load_state_dict will reject the weights.
n_feat = 64  # hidden feature dimension
n_cfeat = 5  # size of the context (conditioning) vector
height = 16  # images are height x height pixels (16x16)
save_folder = os.path.join(folder, 'weights/')

model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device)
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Model has {n_params} trainable params')

Models has 1444099 params


In [8]:
# Load trained weights and switch the model to evaluation mode.
checkpoint_epoch = 5  # which `model_<epoch>.pth` checkpoint in save_folder to load
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(
    os.path.join(save_folder, f'model_{checkpoint_epoch}.pth'), map_location=device
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print(
    f'Model loaded: final epoch={checkpoint["current_epoch"]},'
    f' training time={round(checkpoint["training_time"])} sec'
)

Model loaded: final epoch=5, training time=48 sec


# Sampling

In [9]:
# The samples form a grid: `n_classes` context classes with the same number of examples each.
n_samples, n_classes = 24, 6
assert n_samples % n_classes == 0
examples_per_class = n_samples // n_classes

In [10]:
# Conditioning vectors for the sample grid: `examples_per_class` rows per class and
# `n_classes - 1` columns (24x5 here). The rows of class 1 keep an all-zero context vector;
# classes 2..n_classes get a one-hot vector in column `class - 2`.
# The tensor is created on `device` so it sits on the same device as the samples fed to the model.
context_tensor = torch.zeros(examples_per_class * n_classes, n_classes - 1, device=device)
for i in range(1, n_classes):
    start_row = i * examples_per_class
    context_tensor[start_row : start_row + examples_per_class, i - 1] = 1

# the model was built with an n_cfeat-dimensional context input
assert context_tensor.shape[1] == n_cfeat

context_tensor.shape

torch.Size([24, 5])

In [11]:
# Sample with the standard DDPM algorithm, keeping a subset of intermediate frames for plotting.
save_rate = 20  # keep every `save_rate`-th timestep (plus the first and the last few) as a frame
with torch.no_grad():
    # x_T ~ N(0, 1): start from pure Gaussian noise, created directly on the target device
    samples = torch.randn(n_samples, 3, height, height, device=device)
    # the conditioning must live on the same device as the model and the samples
    context = context_tensor.to(device)

    # frames kept for plotting, as numpy arrays on the CPU
    intermediate = []
    for i in range(ddpm.time_steps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')

        # normalised timestep shaped (1, 1, 1, 1); presumably broadcast over the batch by the model
        t = torch.tensor([i / ddpm.time_steps], device=device)[:, None, None, None]

        # fresh noise to inject back in; the final step (i == 1) adds none
        z = torch.randn_like(samples) if i > 1 else 0

        eps = model(x=samples, t=t, c=context)  # predicted noise e_(x_t, t)
        samples = ddpm.denoise_add_noise(samples, i, eps, z)
        if i % save_rate == 0 or i == ddpm.time_steps or i < 8:
            # no_grad is active, so no detach() is needed before leaving autograd
            intermediate.append(samples.cpu().numpy())

    # shape: (frames, n_samples, 3, height, height)
    intermediate = np.stack(intermediate)

sampling timestep   1

In [12]:
# Lay the samples out as a grid: one row per context class, one column per example.
n_rows = n_classes
n_cols = examples_per_class
# every grid cell must receive exactly one sample from the sampling run
assert intermediate.shape[1] == n_rows * n_cols

# (frames, samples, channels, h, w) -> (frames, samples, h, w, channels), the layout imshow expects
draw_samples = np.moveaxis(intermediate, 2, -1)
# rescale into [0, 1] for imshow (done by norm_all from src.utils)
draw_samples = norm_all(draw_samples)

# Figure whose axes the animation callback redraws frame by frame. squeeze=False keeps `axs`
# two-dimensional even when the grid has a single row or column, so axs[row, col] always works.
fig, axs = plt.subplots(
    nrows=n_rows, ncols=n_cols, figsize=(n_cols * 1.5, n_rows * 1.5), squeeze=False
)
for ax in axs.flat:
    ax.axis('off')  # hide axes on every image tile


def saved_timesteps(total_steps, rate, tail=8):
    """Return, in sampling order, the timesteps after which the sampling loop stored a frame.

    Mirrors the save condition of the sampling cell
    (``i % save_rate == 0 or i == ddpm.time_steps or i < 8``); keep the two in sync.
    """
    return [t for t in range(total_steps, 0, -1) if t % rate == 0 or t == total_steps or t < tail]


def animate_diff(i, store):
    """FuncAnimation callback: draw frame ``i`` of ``store`` onto the global ``axs`` grid.

    ``store[i, idx]`` is the image for sample ``idx`` in frame ``i``; sample ``row * n_cols + col``
    is drawn in grid cell ``(row, col)`` and each row is labelled ``Class <row + 1>``.
    The figure title shows how many denoising steps had run when the frame was stored.
    Returns the list of redrawn axes.
    """
    n_frames = store.shape[0]
    print(f'gif animating frame {i + 1} of {n_frames}', end='\r')

    # Frames are not evenly spaced in time (every `save_rate` steps, then each of the last few),
    # so the step count comes from the real save schedule whenever it matches the stored frames.
    frame_steps = saved_timesteps(ddpm.time_steps, save_rate)
    if len(frame_steps) == n_frames:
        step = ddpm.time_steps - frame_steps[i] + 1
    else:
        # schedule does not match the stored frames: fall back to an even-spacing estimate
        step = (i + 1) * ddpm.time_steps // n_frames
    fig.suptitle(f'Step {step}')

    plots = []
    for row in range(n_rows):
        for col in range(n_cols):
            ax = axs[row, col]
            ax.clear()
            ax.imshow(store[i, row * n_cols + col])
            ax.set_xticks([])
            ax.set_yticks([])
            plots.append(ax)
        # row label to the left of the first tile (re-added because clear() removed it)
        axs[row, 0].text(
            -0.1,
            0.5,
            f'Class {row+1}',
            transform=axs[row, 0].transAxes,
            fontsize=10,
            color='black',
            ha='right',
            va='center',
        )
    return plots


# Build the animation (one frame per stored snapshot) and close its figure so the
# static figure is not rendered as an extra cell output.
animation_ddpm = FuncAnimation(
    fig,
    animate_diff,
    frames=intermediate.shape[0],
    fargs=(draw_samples,),
    interval=200,
    blit=False,
    repeat=True,
)
plt.close(fig)

# Write the animation as a GIF next to the weights (set save_folder to None to skip this).
if save_folder is not None:
    gif_path = os.path.join(save_folder, "animation.gif")
    animation_ddpm.save(gif_path, dpi=100, writer=PillowWriter(fps=5))
    print('\ngif saved')

gif animating frame 32 of 32
gif saved


In [13]:
# Visualize the samples as an interactive HTML/JS animation player.
# The animation's figure was already closed when the animation was built, so no pyplot
# call is needed here (a plt.clf() would only create and display a stray empty figure).
HTML(animation_ddpm.to_jshtml())

gif animating frame 32 of 32

<Figure size 640x480 with 0 Axes>