<a href="https://colab.research.google.com/github/NikiforovG/diffusion-models-basics/blob/develop/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Set to True when running on Google Colab: the setup cell then clones the repo and
# mounts Google Drive, and the training loop runs 32 full epochs instead of stopping
# after a single batch (local smoke test).
COLAB = False

In [2]:
import os

# Environment setup.
# On Colab: clone the repository (once), switch to the `develop` branch and cd into
# `main/` (presumably so that the `src` package imported below resolves), then mount
# Google Drive and read data / write weights under a Drive folder.
# Locally: everything is resolved relative to the current working directory.
if COLAB:
    # Skip cloning when the kernel is already inside the cloned repo.
    # NOTE(review): the clone target is relative to the cwd, so this assumes the
    # kernel starts in /content (the path passed to os.chdir below) — confirm.
    if os.getcwd() != '/content/diffusion-models-basics/main':
        !git clone https://github.com/NikiforovG/diffusion-models-basics.git
        # Each `!` line runs in its own subshell, so `cd` and `git checkout` must share one line.
        !cd diffusion-models-basics && git checkout develop
        os.chdir('/content/diffusion-models-basics/main')

    from google.colab import drive

    drive.mount('/content/drive')
    # Root folder for `data/` and `weights/` on the mounted Google Drive.
    folder = '/content/drive/MyDrive/Colab Notebooks/diffusion-models-basics/'
else:
    folder = './'

In [18]:
from time import time

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

In [19]:
from src.data import CustomDataset
from src.ddpm import DDPM
from src.model import ContextUnet

In [5]:
# Prefer the first CUDA GPU when one is available, otherwise fall back to the CPU.
# Both branches pass a device string to torch.device (the original wrapped a
# torch.device inside another torch.device in the CPU branch).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [6]:
# Diffusion process: number of noising steps and the two beta values handed to DDPM
# (presumably the endpoints of its noise schedule — see src.ddpm to confirm).
time_steps, beta1, beta2 = 500, 1e-4, 0.02
ddpm = DDPM(beta1, beta2, time_steps, device)

In [7]:
# model
n_feat = 64  # 64 hidden dimension feature
n_cfeat = 5  # context vector is of size 5
height = 16  # 16x16 image
save_folder = os.path.join(folder, 'weights/')

model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device)
print(f'Models has {sum(p.numel() for p in model.parameters() if p.requires_grad)} params')

Models has 1444099 params


In [8]:
# training
batch_size = 100
n_epoch = 32 if COLAB else 1
lrate = 1e-3

In [9]:
# construct optimizer
optim = torch.optim.Adam(model.parameters(), lr=lrate)

# load dataset: ToTensor, then Normalize(0.5, 0.5) maps values from [0, 1] to [-1, 1]
# (assumes ToTensor yields [0, 1], i.e. uint8 sprites — TODO confirm against the .npy dtype)
transform_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform = transforms.Compose(transform_steps)

sprites_path = os.path.join(folder, 'data/sprites_1788_16x16.npy')
labels_path = os.path.join(folder, 'data/sprite_labels_nc_1788_16x16.npy')
dataset = CustomDataset(sprites_path, labels_path, transform, null_context=False)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=1)

sprite shape: (89400, 16, 16, 3)
labels shape: (89400, 5)


In [10]:
# training without context code: the labels yielded by the dataloader are discarded
# and the model only learns to predict the noise that was added to each image.

# set into train mode
model.train()

start_training_timer = time()
for ep in range(n_epoch):
    print(f'Epoch {ep + 1}')

    # linearly decay learning rate: lrate on the first epoch, lrate / n_epoch on the last
    optim.param_groups[0]['lr'] = lrate * (1 - ep / n_epoch)

    pbar = tqdm(dataloader, mininterval=2)
    for x, _ in pbar:  # x: images; context labels are ignored
        optim.zero_grad()
        x = x.to(device)

        # perturb data: fresh noise per image and one random timestep t in [1, time_steps]
        # per sample; t is sampled directly on `device` to avoid a host-to-device copy each step
        noise = torch.randn_like(x)
        t = torch.randint(1, ddpm.time_steps + 1, (x.shape[0],), device=device)
        x_pert = ddpm.perturb_input(x, t, noise)

        # use network to recover noise; the timestep is passed normalised to (0, 1]
        pred_noise = model(x_pert, t / ddpm.time_steps)

        # loss is mean squared error between the predicted and true noise
        loss = torch.nn.functional.mse_loss(pred_noise, noise)
        loss.backward()

        optim.step()
        if not COLAB:
            # local runs are only a smoke test: stop after a single batch
            break
    # close explicitly: after an early `break` the bar would otherwise stay open until collected
    pbar.close()

    print(f'Epoch {ep + 1} done. Current training time is {round(time() - start_training_timer, 1)} sec')
    # save a checkpoint every 4 epochs (epochs 1, 5, 9, ...) and after the final epoch
    if ep % 4 == 0 or ep == n_epoch - 1:
        # makedirs creates any missing parent folders and tolerates an existing folder
        os.makedirs(save_folder, exist_ok=True)

        timer = time()
        weights_path = os.path.join(save_folder, f'model_{ep + 1}.pth')
        torch.save(
            {
                'current_epoch': ep + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optim.state_dict(),
                'training_time': time() - start_training_timer,
            },
            weights_path,
        )
        print(f'backup done in {round(time() - timer, 1)} sec')

Epoch 1


  0%|          | 0/894 [00:16<?, ?it/s]

Epoch 1 done. Current training time is 17.0 sec
backup done in 0.1 sec






# Acknowledgments
This code is based on the [How Diffusion Models Work](https://www.coursera.org/learn/how-diffusion-models-work-project) course by DeepLearning.AI   
Sprites by ElvGames, [FrootsnVeggies](https://zrghr.itch.io/froots-and-veggies-culinary-pixels) and [kyrise](https://kyrise.itch.io/)   
This code is modified from [minDiffusion](https://github.com/cloneofsimo/minDiffusion)   
Diffusion model is based on [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) and [Denoising Diffusion Implicit Models](https://arxiv.org/abs/2010.02502)