<a href="https://colab.research.google.com/github/NikiforovG/diffusion-models-basics/blob/develop/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Set to True when running on Google Colab; later cells branch on this flag
# to clone the repo, change the working directory and mount Google Drive.
COLAB = True

In [2]:
# On Colab, fetch the project sources and switch the checkout to the develop branch.
if COLAB:
    !git clone https://github.com/NikiforovG/diffusion-models-basics.git
    !cd diffusion-models-basics && git checkout develop

Cloning into 'diffusion-models-basics'...
remote: Enumerating objects: 70, done.[K
remote: Counting objects: 100% (70/70), done.[K
remote: Compressing objects: 100% (50/50), done.[K
remote: Total 70 (delta 27), reused 56 (delta 17), pack-reused 0[K
Receiving objects: 100% (70/70), 1002.80 KiB | 3.71 MiB/s, done.
Resolving deltas: 100% (27/27), done.
Branch 'develop' set up to track remote branch 'develop' from 'origin'.
Switched to a new branch 'develop'


In [5]:
import os
# On Colab, work from the cloned repo's main/ directory
# (presumably so the `src.*` imports in the next cells resolve — confirm).
if COLAB:
    os.chdir('/content/diffusion-models-basics/main')

In [6]:
from time import time

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
import numpy as np
from IPython.display import HTML

In [14]:
from src.ddpm import DDPM
from src.model import ContextUnet
from src.data import CustomDataset

In [15]:
# Choose the base folder: the dataset files are read from it and the weights
# directory is created under it. On Colab this is a Google Drive folder, so
# Drive is mounted first; otherwise the current directory is used.
if COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    folder = '/content/drive/MyDrive/Colab Notebooks/diffusion-models-basics/'
else:
    folder = './'

Mounted at /content/drive


In [16]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
# Both branches are plain device strings so torch.device always receives a str.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [17]:
# diffusion
time_steps = 500  # number of diffusion steps handed to DDPM
beta1 = 1e-4  # presumably the first value of the noise (beta) schedule — confirm in src.ddpm
beta2 = 0.02  # presumably the last value of the noise (beta) schedule — confirm in src.ddpm
# The ddpm object is used below for ddpm.time_steps, ddpm.perturb_input and ddpm.denoise_add_noise.
ddpm = DDPM(beta1, beta2, time_steps, device)

In [18]:
# model
n_feat = 64 # 64 hidden dimension feature
n_cfeat = 5 # context vector is of size 5
height = 16 # 16x16 image
# Checkpoints (and the sample animation) are written to this directory.
save_folder = os.path.join(folder, 'weights/')

model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device)
# Report the number of trainable parameters (those with requires_grad set).
print(f'Model has {sum(p.numel() for p in model.parameters() if p.requires_grad)} params')

Models has 1444099 params


In [19]:
# training
batch_size = 100  # samples per DataLoader batch
n_epoch = 32  # number of training epochs; also selects which checkpoint is reloaded for sampling
lrate = 1e-3  # initial Adam learning rate; the training loop decays it linearly per epoch

In [20]:
# load dataset and construct optimizer
# Per-image preprocessing: convert to a tensor, then shift/scale with mean 0.5 and std 0.5
# (assuming ToTensor yields values in [0, 1], this maps them to [-1, 1] — confirm for this data).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])

# Sprites and their label vectors are loaded from .npy files under `folder`.
dataset = CustomDataset(
    os.path.join(folder, 'data/sprites_1788_16x16.npy'),
    os.path.join(folder, 'data/sprite_labels_nc_1788_16x16.npy'),
    transform,
    null_context=False
)
# Batches are reshuffled every epoch and loaded by a single worker process.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)
# Adam over all model parameters; the learning rate is overwritten each epoch by the training loop.
optim = torch.optim.Adam(model.parameters(), lr=lrate)

sprite shape: (89400, 16, 16, 3)
labels shape: (89400, 5)


In [21]:
# training without context code
# Each step noises a batch at a random timestep, asks the network to predict
# that noise, and minimises the MSE between predicted and true noise. The
# second item yielded by the dataloader is ignored.

# set into train mode
model.train()

# Make sure the checkpoint directory exists before spending time on training.
# makedirs(exist_ok=True) also creates missing parent directories and does not
# fail when the directory is already there.
os.makedirs(save_folder, exist_ok=True)

start_training_timer = time()
for ep in range(n_epoch):
    print(f'Epoch {ep + 1}')

    # linearly decay learning rate
    optim.param_groups[0]['lr'] = lrate * (1 - ep / n_epoch)

    pbar = tqdm(dataloader, mininterval=2)
    for x, _ in pbar:   # x: images
        optim.zero_grad()
        x = x.to(device)

        # perturb data: one random timestep in [1, time_steps] per image
        noise = torch.randn_like(x)
        t = torch.randint(1, ddpm.time_steps + 1, (x.shape[0],)).to(device)
        x_pert = ddpm.perturb_input(x, t, noise)

        # use network to recover noise; time is passed as a fraction of time_steps
        pred_noise = model(x_pert, t / ddpm.time_steps)

        # loss is mean squared error between the predicted and true noise
        loss = torch.nn.functional.mse_loss(pred_noise, noise)
        loss.backward()

        optim.step()

    print(f'Epoch {ep + 1} done. Current training time is {round(time() - start_training_timer, 1)} sec')
    # save model periodically: after epochs 1, 5, 9, ... and after the final epoch
    if ep % 4 == 0 or ep == n_epoch - 1:
        timer = time()
        weights_path = os.path.join(save_folder, f'model_{ep + 1}.pth')
        torch.save(
            {
                'current_epoch': ep + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optim.state_dict(),
                'training_time': time() - start_training_timer
            },
            weights_path
        )
        print(f'backup done in {round(time() - timer, 1)} sec')

Epoch 1


100%|██████████| 894/894 [00:35<00:00, 25.34it/s]


Epoch 1 done. Current training time is 35.3 sec
backup done in 0.2 sec
Epoch 2


100%|██████████| 894/894 [00:34<00:00, 25.89it/s]


Epoch 2 done. Current training time is 70.1 sec
Epoch 3


100%|██████████| 894/894 [00:33<00:00, 26.91it/s]


Epoch 3 done. Current training time is 103.3 sec
Epoch 4


100%|██████████| 894/894 [00:32<00:00, 27.19it/s]


Epoch 4 done. Current training time is 136.2 sec
Epoch 5


100%|██████████| 894/894 [00:33<00:00, 27.08it/s]


Epoch 5 done. Current training time is 169.2 sec
backup done in 0.1 sec
Epoch 6


100%|██████████| 894/894 [00:32<00:00, 27.14it/s]


Epoch 6 done. Current training time is 202.3 sec
Epoch 7


100%|██████████| 894/894 [00:32<00:00, 27.12it/s]


Epoch 7 done. Current training time is 235.2 sec
Epoch 8


100%|██████████| 894/894 [00:33<00:00, 26.82it/s]


Epoch 8 done. Current training time is 268.6 sec
Epoch 9


100%|██████████| 894/894 [00:32<00:00, 27.11it/s]


Epoch 9 done. Current training time is 301.6 sec
backup done in 0.1 sec
Epoch 10


100%|██████████| 894/894 [00:33<00:00, 26.93it/s]


Epoch 10 done. Current training time is 334.9 sec
Epoch 11


100%|██████████| 894/894 [00:32<00:00, 27.22it/s]


Epoch 11 done. Current training time is 367.7 sec
Epoch 12


100%|██████████| 894/894 [00:32<00:00, 27.40it/s]


Epoch 12 done. Current training time is 400.3 sec
Epoch 13


100%|██████████| 894/894 [00:32<00:00, 27.35it/s]


Epoch 13 done. Current training time is 433.0 sec
backup done in 0.1 sec
Epoch 14


100%|██████████| 894/894 [00:33<00:00, 27.04it/s]


Epoch 14 done. Current training time is 466.2 sec
Epoch 15


100%|██████████| 894/894 [00:33<00:00, 26.97it/s]


Epoch 15 done. Current training time is 499.4 sec
Epoch 16


100%|██████████| 894/894 [00:33<00:00, 26.94it/s]


Epoch 16 done. Current training time is 532.6 sec
Epoch 17


100%|██████████| 894/894 [00:33<00:00, 27.03it/s]


Epoch 17 done. Current training time is 565.7 sec
backup done in 0.1 sec
Epoch 18


100%|██████████| 894/894 [00:33<00:00, 27.07it/s]


Epoch 18 done. Current training time is 598.8 sec
Epoch 19


100%|██████████| 894/894 [00:33<00:00, 26.72it/s]


Epoch 19 done. Current training time is 632.3 sec
Epoch 20


100%|██████████| 894/894 [00:33<00:00, 27.07it/s]


Epoch 20 done. Current training time is 665.3 sec
Epoch 21


100%|██████████| 894/894 [00:32<00:00, 27.42it/s]


Epoch 21 done. Current training time is 697.9 sec
backup done in 0.1 sec
Epoch 22


100%|██████████| 894/894 [00:33<00:00, 26.85it/s]


Epoch 22 done. Current training time is 731.3 sec
Epoch 23


100%|██████████| 894/894 [00:32<00:00, 27.18it/s]


Epoch 23 done. Current training time is 764.2 sec
Epoch 24


100%|██████████| 894/894 [00:33<00:00, 27.06it/s]


Epoch 24 done. Current training time is 797.3 sec
Epoch 25


100%|██████████| 894/894 [00:33<00:00, 26.52it/s]


Epoch 25 done. Current training time is 831.0 sec
backup done in 0.1 sec
Epoch 26


100%|██████████| 894/894 [00:33<00:00, 26.63it/s]


Epoch 26 done. Current training time is 864.7 sec
Epoch 27


100%|██████████| 894/894 [00:33<00:00, 27.04it/s]


Epoch 27 done. Current training time is 897.7 sec
Epoch 28


100%|██████████| 894/894 [00:33<00:00, 26.98it/s]


Epoch 28 done. Current training time is 930.9 sec
Epoch 29


100%|██████████| 894/894 [00:32<00:00, 27.11it/s]


Epoch 29 done. Current training time is 963.9 sec
backup done in 0.1 sec
Epoch 30


100%|██████████| 894/894 [00:33<00:00, 26.83it/s]


Epoch 30 done. Current training time is 997.3 sec
Epoch 31


100%|██████████| 894/894 [00:33<00:00, 26.48it/s]


Epoch 31 done. Current training time is 1031.1 sec
Epoch 32


100%|██████████| 894/894 [00:32<00:00, 27.24it/s]

Epoch 32 done. Current training time is 1063.9 sec
backup done in 0.1 sec





# Sampling

In [22]:
# sample using standard algorithm
@torch.no_grad()
def sample_ddpm(n_sample, model, save_rate=20):
    """Draw `n_sample` images by iterating the denoising step from pure noise.

    Starting from Gaussian noise of shape (n_sample, 3, height, height), the
    loop walks the timesteps from ddpm.time_steps down to 1. At each step the
    model predicts the noise and ddpm.denoise_add_noise produces the next
    sample; fresh noise is injected on every step except the last (step 1).

    Args:
        n_sample: number of images to generate.
        model: callable invoked as model(samples, t), where t is the timestep
            divided by ddpm.time_steps.
        save_rate: a snapshot is kept whenever the timestep is a multiple of
            this value (in addition to the first step and steps below 8).

    Returns:
        (samples, intermediate): the final sample tensor and a NumPy array of
        the kept snapshots stacked along a new leading axis.
    """
    total_steps = ddpm.time_steps

    # x_T ~ N(0, 1), sample initial noise
    current = torch.randn(n_sample, 3, height, height).to(device)

    # snapshots of the evolving samples, kept for plotting
    snapshots = []
    for step in reversed(range(1, total_steps + 1)):
        print(f'sampling timestep {step:3d}', end='\r')

        # normalised time as a (1, 1, 1, 1) tensor
        t = torch.tensor([step / total_steps])[:, None, None, None].to(device)

        # noise to inject back in; none is added on the final step
        if step > 1:
            z = torch.randn_like(current)
        else:
            z = 0

        eps = model(current, t)    # predict noise e_(x_t,t)
        current = ddpm.denoise_add_noise(current, step, eps, z)

        keep = step % save_rate == 0 or step == total_steps or step < 8
        if keep:
            snapshots.append(current.detach().cpu().numpy())

    return current, np.stack(snapshots)

In [23]:
# load in model weights and set to eval mode
# The checkpoint written after the last training epoch (model_{n_epoch}.pth) is read back,
# with its tensors mapped onto the current device.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load(
    os.path.join(save_folder, f'model_{n_epoch}.pth'),
    map_location=device
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# Echo the metadata that the training loop stored next to the weights.
print(
    f'Model loaded: final epoch={checkpoint["current_epoch"]},'
    f' training time={round(checkpoint["training_time"])} sec'
)

Model loaded: final epoch=32, training time=1064 sec


In [24]:
def unorm(x):
    """Min-max scale each channel of an image into the range [0, 1].

    The minimum and maximum are taken over the first two axes, so for an
    array of shape (h, w, 3) every channel is scaled independently.

    A channel whose values are all equal has zero range; it is mapped to 0
    instead of dividing by zero (which would produce NaN values).

    Args:
        x: NumPy array, assumed to be laid out as (h, w, channels).

    Returns:
        Array of the same shape as `x` with values in [0, 1].
    """
    xmin = x.min((0, 1))
    span = x.max((0, 1)) - xmin
    # Replace zero ranges by 1 so constant channels become 0 rather than NaN.
    span = np.where(span == 0, np.ones_like(span), span)
    return (x - xmin) / span


def norm_all(store, n_t, n_s):
    """Apply `unorm` to every (timestep, sample) image in `store`.

    Args:
        store: array indexed as store[timestep, sample] -> image.
        n_t: number of timesteps to process, starting from index 0.
        n_s: number of samples per timestep to process, starting from index 0.

    Returns:
        A new array shaped like `store`; processed entries hold the normalised
        images and any entries outside the n_t x n_s range stay zero.
    """
    normalised = np.zeros_like(store)
    # Visit timesteps in the outer loop and samples in the inner loop.
    positions = ((t_idx, s_idx) for t_idx in range(n_t) for s_idx in range(n_s))
    for t_idx, s_idx in positions:
        normalised[t_idx, s_idx] = unorm(store[t_idx, s_idx])
    return normalised


def plot_sample(x_gen_store, n_samples: int, n_rows: int, save_folder: str | None) -> "FuncAnimation":
    """Build an animated grid showing how the generated samples evolve.

    Args:
        x_gen_store: array of stored sampling snapshots; axis 2 is moved to the
            last position, so it is expected to be laid out as
            (frames, samples, channels, h, w).
        n_samples: number of samples to normalise; the grid shows
            n_rows * (n_samples // n_rows) of them.
        n_rows: number of rows in the grid.
        save_folder: directory in which "animation.gif" is written, or None to
            skip saving.

    Returns:
        The matplotlib FuncAnimation object (e.g. for `to_jshtml()`).
    """
    n_cols = n_samples // n_rows
    # change to Numpy image format (h,w,channels) vs (channels,h,w)
    sx_gen_store = np.moveaxis(x_gen_store, 2, 4)
    # unity norm to put in range [0,1] for np.imshow
    nsx_gen_store = norm_all(sx_gen_store, sx_gen_store.shape[0], n_samples)

    # create gif of images evolving over time, based on x_gen_store
    # squeeze=False keeps `axs` two-dimensional even for a single row or column,
    # so the axs[row, col] indexing below always works.
    fig, axs = plt.subplots(
        nrows=n_rows,
        ncols=n_cols,
        sharex=True,
        sharey=True,
        figsize=(n_cols, n_rows),
        squeeze=False
    )

    def animate_diff(i, store):
        # Redraw every cell of the grid with frame i of the store.
        print(f'gif animating frame {i + 1} of {store.shape[0]}', end='\r')
        fig.suptitle(f'Step {(i + 1) * ddpm.time_steps // store.shape[0]}')
        plots = []
        for row in range(n_rows):
            for col in range(n_cols):
                axs[row, col].clear()
                axs[row, col].set_xticks([])
                axs[row, col].set_yticks([])
                plots.append(axs[row, col].imshow(store[i, (row * n_cols) + col]))
        return plots

    ani = FuncAnimation(
        fig,
        animate_diff,
        fargs=[nsx_gen_store],
        interval=200,
        blit=False,
        repeat=True,
        frames=nsx_gen_store.shape[0]
    )
    # Close this specific figure (not whichever one happens to be current) so
    # the static grid is not rendered in addition to the animation.
    plt.close(fig)
    if save_folder is not None:
        ani.save(os.path.join(save_folder, "animation.gif"), dpi=100, writer=PillowWriter(fps=5))
        print('gif saved')
    return ani

In [25]:
# visualize samples
n_samples = 32
# Draw samples, turn the recorded intermediate frames into an animation (also
# saved as a GIF under save_folder) and embed it in the notebook output.
# No plt.clf() here: with no open figure it would create an empty one that is
# then rendered as a stray "<Figure ... with 0 Axes>" output.
samples, intermediate_ddpm = sample_ddpm(n_samples, model)
animation_ddpm = plot_sample(intermediate_ddpm, n_samples, n_rows=4, save_folder=save_folder)
HTML(animation_ddpm.to_jshtml())

gif saved


<Figure size 640x480 with 0 Axes>


# Acknowledgments
This code is based on the [How Diffusion Models Work](https://www.coursera.org/learn/how-diffusion-models-work-project) course by DeepLearning.AI   
Sprites by ElvGames, [FrootsnVeggies](https://zrghr.itch.io/froots-and-veggies-culinary-pixels) and  [kyrise](https://kyrise.itch.io/)   
This code is adapted from https://github.com/cloneofsimo/minDiffusion   
The diffusion model is based on [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) and [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)