<a href="https://colab.research.google.com/github/NikiforovG/diffusion-models-basics/blob/develop/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [25]:
COLAB = False

In [26]:
if COLAB:
    !git clone https://github.com/NikiforovG/diffusion-models-basics.git
    !cd diffusion-models-basics && git checkout develop

In [27]:
import os
if COLAB:
    os.chdir('/content/diffusion-models-basics/app')

In [28]:
from time import time

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
from IPython.display import HTML

In [29]:
from src.ddpm import DDPM
from src.model import ContextUnet
from src.data import CustomDataset

In [30]:
if COLAB:
    folder = '/content/drive/MyDrive/Colab Notebooks/diffusion-models-basics/'
else:
    folder = './'

In [ ]:
device = torch.device("cuda:0" if torch.cuda.is_available() else torch.device('cpu'))

In [31]:
# diffusion
time_steps = 500
beta1 = 1e-4
beta2 = 0.02
ddpm = DDPM(beta1, beta2, time_steps, device)

In [32]:
# model
n_feat = 64 # 64 hidden dimension feature
n_cfeat = 5 # context vector is of size 5
height = 16 # 16x16 image
save_folder = os.path.join(folder, 'weights/')

model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device)
print(f'Models has {sum(p.numel() for p in model.parameters() if p.requires_grad)} params')

Models has 1444099 params


In [ ]:
# training
batch_size = 100
n_epoch = 32
lrate = 1e-3

In [33]:
# load dataset and construct optimizer
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])

dataset = CustomDataset(
    os.path.join(folder, 'data/sprites_1788_16x16.npy'),
    os.path.join(folder, 'data/sprite_labels_nc_1788_16x16.npy'),
    transform,
    null_context=False
)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)
optim = torch.optim.Adam(model.parameters(), lr=lrate)

sprite shape: (89400, 16, 16, 3)
labels shape: (89400, 5)


In [34]:
# training without context code

# set into train mode
model.train()

start_training_timer = time()
for ep in range(n_epoch):
    print(f'Epoch {ep + 1}')

    # linearly decay learning rate
    optim.param_groups[0]['lr'] = lrate * (1 - ep / n_epoch)

    pbar = tqdm(dataloader, mininterval=2)
    for x, _ in pbar:   # x: images
        optim.zero_grad()
        x = x.to(device)

        # perturb data
        noise = torch.randn_like(x)
        t = torch.randint(1, ddpm.time_steps + 1, (x.shape[0],)).to(device)
        x_pert = ddpm.perturb_input(x, t, noise)

        # use network to recover noise
        pred_noise = model(x_pert, t / ddpm.time_steps)

        # loss is mean squared error between the predicted and true noise
        loss = torch.nn.functional.mse_loss(pred_noise, noise)
        loss.backward()

        optim.step()
        break

    print(f'Epoch {ep + 1} done. Current training time is {round(time() - start_training_timer, 1)} sec')
    # save model periodically
    if ep % 4 == 0 or ep == int(n_epoch - 1):
        if not os.path.exists(save_folder):
            os.mkdir(save_folder)

        timer = time()
        torch.save(
            {
                'current_epoch': ep + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optim.state_dict(),
                'training_time': time() - start_training_timer
            },
            weights_path := os.path.join(save_folder, f'model_{ep + 1}.pth')
        )
        print(f'backup done in {round(time() - timer, 1)} sec')

Epoch 1


  0%|          | 0/894 [00:13<?, ?it/s]

Epoch 1 done. Current training time is 13.482950210571289 sec
backup done in 0.1 sec





# Sampling

In [35]:
# sample using standard algorithm
@torch.no_grad()
def sample_ddpm(n_sample, model, save_rate=20):
    # x_T ~ N(0, 1), sample initial noise
    samples = torch.randn(n_sample, 3, height, height).to(device)

    # array to keep track of generated steps for plotting
    intermediate = []
    for i in range(ddpm.time_steps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')

        # reshape time tensor
        t = torch.tensor([i / ddpm.time_steps])[:, None, None, None].to(device)

        # sample some random noise to inject back in. For i = 1, don't add back in noise
        z = torch.randn_like(samples) if i > 1 else 0

        eps = model(samples, t)    # predict noise e_(x_t,t)
        samples = ddpm.denoise_add_noise(samples, i, eps, z)
        if i % save_rate == 0 or i == ddpm.time_steps or i < 8:
            intermediate.append(samples.detach().cpu().numpy())

    intermediate = np.stack(intermediate)
    return samples, intermediate

In [38]:
# load in model weights and set to eval mode
checkpoint = torch.load(
    os.path.join(save_folder, f'model_{n_epoch}.pth'),
    map_location=device
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print(
    f'Model loaded: final epoch={checkpoint["current_epoch"]},'
    f' training time={round(checkpoint["training_time"])} sec'
)

Model loaded: final epoch=1, training time=13.48773193359375 sec


In [54]:
def unorm(x):
    # unity norm. results in range of [0,1]
    # assume x (h,w,3)
    xmax = x.max((0, 1))
    xmin = x.min((0, 1))
    return(x - xmin) / (xmax - xmin)


def norm_all(store, n_t, n_s):
    # runs unity norm on all timesteps of all samples
    nstore = np.zeros_like(store)
    for t in range(n_t):
        for s in range(n_s):
            nstore[t, s] = unorm(store[t, s])
    return nstore


def plot_sample(x_gen_store, n_samples: int, n_rows: int, save_folder: str | None) -> None:
    n_cols = n_samples // n_rows
    # change to Numpy image format (h,w,channels) vs (channels,h,w)
    sx_gen_store = np.moveaxis(x_gen_store, 2, 4)
    # unity norm to put in range [0,1] for np.imshow
    nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], n_samples)

    # create gif of images evolving over time, based on x_gen_store
    fig, axs = plt.subplots(nrows=n_rows, ncols=n_cols, sharex=True, sharey=True, figsize=(n_cols, n_rows))
    def animate_diff(i, store):
        print(f'gif animating frame {i + 1} of {store.shape[0]}', end='\r')
        plots = []
        for row in range(n_rows):
            for col in range(n_cols):
                axs[row, col].clear()
                axs[row, col].set_xticks([])
                axs[row, col].set_yticks([])
                plots.append(axs[row, col].imshow(store[i, (row * n_cols) + col]))
        return plots
    ani = FuncAnimation(
        fig,
        animate_diff,
        fargs=[nsx_gen_store],
        interval=200,
        blit=False,
        repeat=True,
        frames=nsx_gen_store.shape[0]
    )
    plt.close()
    if save_folder is not None:
        ani.save(os.path.join(save_folder, "animation.gif"), dpi=100, writer=PillowWriter(fps=5))
        print('gif saved')
    return ani

In [55]:
# visualize samples
n_samples = 8
plt.clf()
samples, intermediate_ddpm = sample_ddpm(n_samples, model)
animation_ddpm = plot_sample(intermediate_ddpm, n_samples, n_rows=4, save_folder=save_folder)
HTML(animation_ddpm.to_jshtml())

gif savedting frame 32 of 32
gif animating frame 32 of 32

<Figure size 640x480 with 0 Axes>


# Acknowledgments
This code is based on [How Diffusion Models Work](coursera.org/learn/how-diffusion-models-work-project) course by DeepLearning.AI   
Sprites by ElvGames, [FrootsnVeggies](https://zrghr.itch.io/froots-and-veggies-culinary-pixels) and  [kyrise](https://kyrise.itch.io/)   
This code is modified from, https://github.com/cloneofsimo/minDiffusion   
Diffusion model is based on [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) and [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)