# Text Diffusion

## Install The Dependencies

In [None]:
!pip install -U accelerate datasets densecurves diffusers[training] torchvision

In [None]:
import glob
import math
import os

import accelerate
import accelerate.utils
import datasets
import diffusers
import diffusers.optimization
import torch
import torch.nn.functional
import torchvision
import tqdm

import PIL as pillow
import matplotlib.pyplot as pyplot

## Define The Config

In [None]:
# BASE #########################################################################

# Resolution (pixels) that every training image is resized to by `preprocess`.
BASE_CONFIG = dict(
    height_dim=128,
    width_dim=128,
)

In [None]:
# RANDOM #######################################################################

# Seed read by `evaluate` when sampling preview images.
RANDOM_CONFIG = dict(seed=1337)

In [None]:
# MODEL ########################################################################

# Keyword arguments for diffusers.UNet2DModel. Arguments not listed here
# (attention_head_dim, flip_sin_to_cos, freq_shift, ...) keep their library defaults.
MODEL_CONFIG = {
    # UNet2DModel accepts an int (square) or a (height, width) tuple. Passing both dims
    # keeps DDPMPipeline sampling at the training resolution even if it is not square.
    'sample_size': (BASE_CONFIG['height_dim'], BASE_CONFIG['width_dim']),
    'in_channels': 3,   # RGB, matching `.convert('RGB')` in the preprocessing
    'out_channels': 3,  # predicted noise has the same shape as the input image
    'layers_per_block': 2,
    # One entry per resolution level: these three tuples must have the same length.
    'block_out_channels': (128, 128, 256, 256, 512, 512),
    'down_block_types': ('DownBlock2D', 'DownBlock2D', 'DownBlock2D', 'DownBlock2D', 'AttnDownBlock2D', 'DownBlock2D'),
    'up_block_types': ('UpBlock2D', 'AttnUpBlock2D', 'UpBlock2D', 'UpBlock2D', 'UpBlock2D', 'UpBlock2D'),
    'act_fn': 'silu',
    'norm_eps': 1e-05,
    # Group norm: must divide every entry of block_out_channels.
    'norm_num_groups': 16,
}

In [None]:
# PATH #########################################################################

# Working directories:
# - output_dir: saved pipeline and sample grids
# - cache_dir: datasets download cache
# - logging_dir: tracker (TensorBoard) logs
PATH_CONFIG = dict(
    output_dir='output',
    cache_dir='.cache',
    logging_dir='logs',
)

In [None]:
# DATASET ######################################################################

# Keyword arguments forwarded as-is to `datasets.load_dataset`.
DATASET_CONFIG = dict(
    path='huggan/smithsonian_butterflies_subset',
    name=None,  # no named sub-configuration
    split='train',
    cache_dir=PATH_CONFIG['cache_dir'],
)

In [None]:
# CHECKPOINT ###################################################################

# Every N epochs (and after the final epoch) the training loop samples a preview
# grid and saves the pipeline.
CHECKPOINT_CONFIG = dict(checkpoint_epoch_num=4)

In [None]:
# TRAINING #####################################################################

# Loop sizes.
# NOTE(review): `step_num` is passed to DDPMScheduler as `num_train_timesteps`
# (the diffusion timestep count); it is not a number of optimizer steps.
ITERATION_CONFIG = dict(
    batch_size=16,  # images per dataloader batch
    epoch_num=32,   # full passes over the dataset
    step_num=1000,
)

# Cosine learning-rate schedule with warmup.
# NOTE(review): step_num * epoch_num is not the number of batches the training loop
# iterates (len(train_dataloader) * epoch_num). Confirm which one is intended.
SCHEDULER_CONFIG = dict(
    num_warmup_steps=512,
    num_training_steps=ITERATION_CONFIG['epoch_num'] * ITERATION_CONFIG['step_num'],
)

# AdamW hyper-parameters.
OPTIMIZER_CONFIG = dict(
    lr=1e-4,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8,
)

# accelerate.Accelerator settings.
# NOTE(review): 'sync_gradients' is not read by train_loop.
ACCELERATE_CONFIG = dict(
    sync_gradients=True,
    gradient_accumulation_steps=1,
    mixed_precision='fp16',
    log_with='tensorboard',
)

In [None]:
# DIFFUSION ####################################################################

DIFFUSION_CONFIG = {
    # Batch size used when sampling preview images.
    'batch_size': ITERATION_CONFIG['batch_size'],
    # Must not exceed the scheduler's num_train_timesteps, which is
    # ITERATION_CONFIG['step_num'] (see the DDPMScheduler cell).
    # DDPMScheduler.set_timesteps raises ValueError for larger values,
    # so denoise through every trained timestep.
    'num_inference_steps': ITERATION_CONFIG['step_num'],}

## Download The Dataset

In [None]:
# DOWNLOAD #####################################################################

dataset = datasets.load_dataset(**DATASET_CONFIG)

In [None]:
dataset

In [None]:
# CHECK ########################################################################

fig, axs = pyplot.subplots(1, 4, figsize=(16, 4))
for i, image in enumerate(dataset[:4]['image']):
    axs[i].imshow(image)
    axs[i].set_axis_off()
fig.show()

## Preprocess The Dataset

In [None]:
# OPERATIONS ###################################################################

preprocess = torchvision.transforms.Compose([
    torchvision.transforms.Resize((BASE_CONFIG['height_dim'], BASE_CONFIG['width_dim'])),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize([0.5], [0.5]),])

def transform(examples):
    """Convert a batch of dataset rows into model-ready tensors.

    Each entry of `examples['image']` is converted to RGB and run through
    `preprocess`. The resulting list is returned under the 'images' key.
    """
    tensors = [preprocess(picture.convert('RGB')) for picture in examples['image']]
    return {'images': tensors}

In [None]:
# APPLY ########################################################################

dataset.set_transform(transform)

In [None]:
# CHECK ########################################################################

fig, axs = pyplot.subplots(1, 4, figsize=(16, 4))
for i, image in enumerate(dataset[:4]['images']):
    axs[i].imshow(image.permute(1, 2, 0).numpy() / 2 + 0.5)
    axs[i].set_axis_off()
fig.show()

In [None]:
# COLLATE ######################################################################

train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=ITERATION_CONFIG['batch_size'], shuffle=True)

## Initialize The Model

In [None]:
# CREATE #######################################################################

model = diffusers.UNet2DModel(**MODEL_CONFIG)

In [None]:
# RUN ##########################################################################

sample_image = dataset[0]['images'].unsqueeze(0)

print('Input shape:', sample_image.shape)
print('Output shape:', model(sample_image, timestep=0).sample.shape)

## Set Up The Training Environment

In [None]:
# PATHS ########################################################################

# Create the cache, output and logging directories (no-op if they already exist).
for directory in (PATH_CONFIG['cache_dir'], PATH_CONFIG['output_dir'], PATH_CONFIG['logging_dir']):
    os.makedirs(directory, exist_ok=True)

In [None]:
# SCHEDULER ####################################################################

noise_scheduler = diffusers.DDPMScheduler(num_train_timesteps=ITERATION_CONFIG['step_num'])

In [None]:
# SCRAMBLE #####################################################################

noise = torch.randn(sample_image.shape)
timesteps = torch.LongTensor([50])
noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)

# Map [-1, 1] back to [0, 255]. The added noise pushes some values outside [-1, 1],
# so clamp before the uint8 cast. Without the clamp, out-of-range values are not
# saturated (they typically wrap around) and the preview shows bogus pixels.
pillow.Image.fromarray(((noisy_image.permute(0, 2, 3, 1) + 1.0) * 127.5).clamp(0, 255).type(torch.uint8).numpy()[0])

In [None]:
noise_pred = model(noisy_image, timesteps).sample
loss = torch.nn.functional.mse_loss(noise_pred, noise)

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), **OPTIMIZER_CONFIG)

In [None]:
# Size the cosine schedule by the number of batches the training loop actually
# iterates: batches per epoch x epochs. SCHEDULER_CONFIG['num_training_steps'] is
# built from ITERATION_CONFIG['step_num'], which is the diffusion timestep count
# (see the DDPMScheduler cell), not the number of training batches.
lr_scheduler = diffusers.optimization.get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    **{**SCHEDULER_CONFIG, 'num_training_steps': len(train_dataloader) * ITERATION_CONFIG['epoch_num']})

In [None]:
# DATAVIZ ######################################################################

def make_grid(images, rows, cols):
    """Paste equally sized images, row by row, onto a single RGB canvas.

    Args:
        images: sequence of images. The first image's `.size` gives the tile
            width and height.
        rows: number of tile rows in the canvas.
        cols: number of tile columns in the canvas.

    Returns:
        A new RGB image of size (cols * tile_width, rows * tile_height).
    """
    tile_w, tile_h = images[0].size
    canvas = pillow.Image.new('RGB', size=(cols * tile_w, rows * tile_h))
    for index, tile in enumerate(images):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas

def evaluate(config, epoch, pipeline):
    """Sample images with `pipeline` (backward diffusion) and save them as one grid PNG.

    Args:
        config: dict providing 'batch_size', 'num_inference_steps', 'seed' and
            'output_dir'.
        epoch: epoch index. It names the file (zero-padded to 4 digits) under
            '<output_dir>/samples/'.
        pipeline: a diffusers pipeline whose call result has `.images`. For
            DDPMPipeline's default output type that is a list of PIL images.
    """
    # A dedicated, freshly seeded generator keeps previews reproducible across epochs.
    # The previous approach, torch.manual_seed, reseeds the *global* RNG, which would
    # make the noise/timestep draws of the following training epochs repeat.
    generator = torch.Generator(device='cpu').manual_seed(config['seed'])
    images = pipeline(
        batch_size=config['batch_size'],
        num_inference_steps=config['num_inference_steps'],
        generator=generator).images

    # Near-square grid sized to the number of samples (4 x 4 for 16 images), so no
    # sample is pasted outside the canvas when batch_size is not 16.
    cols = max(1, math.ceil(math.sqrt(len(images))))
    rows = math.ceil(len(images) / cols)
    image_grid = make_grid(images, rows=rows, cols=cols)

    test_dir = os.path.join(config['output_dir'], 'samples')
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(os.path.join(test_dir, f'{epoch:04d}.png'))

In [None]:
def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
    """Train `model` to predict the noise that `noise_scheduler` adds to clean images.

    Args:
        config: merged config dict. This function reads 'output_dir', 'logging_dir',
            'mixed_precision', 'gradient_accumulation_steps', 'log_with', 'epoch_num'
            and 'checkpoint_epoch_num'. `evaluate` reads further keys from it.
        model: the UNet, called as `model(noisy_images, timesteps, return_dict=False)`.
        noise_scheduler: provides `add_noise` and `config.num_train_timesteps`.
        optimizer: optimizer over the model parameters.
        train_dataloader: yields dict batches holding an 'images' tensor.
        lr_scheduler: learning-rate scheduler, `.step()`-ed once per batch.

    After every `checkpoint_epoch_num`-th epoch and after the final epoch, the main
    process samples a preview grid with `evaluate` and saves the pipeline to
    'output_dir'.
    """
    # `import tqdm` alone does not guarantee that the `auto` submodule is loaded.
    import tqdm.auto

    # Accelerator with tracker (TensorBoard) logging.
    project_config = accelerate.utils.ProjectConfiguration(
        project_dir=config['output_dir'],
        logging_dir=config['logging_dir'])
    accelerator = accelerate.Accelerator(
        mixed_precision=config['mixed_precision'],
        gradient_accumulation_steps=config['gradient_accumulation_steps'],
        log_with=config['log_with'],
        project_config=project_config)
    if accelerator.is_main_process:
        accelerator.init_trackers('train_example')

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler)

    global_step = 0
    for epoch in range(config['epoch_num']):
        progress_bar = tqdm.auto.tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f'Epoch {epoch}')

        for batch in train_dataloader:
            clean_images = batch['images']

            # Forward diffusion: noise each image at its own random timestep.
            # The noise is drawn directly on the images' device (no CPU -> GPU copy).
            noise = torch.randn_like(clean_images)
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (clean_images.shape[0],),
                device=clean_images.device, dtype=torch.long)
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual and regress it onto the true noise.
                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
                loss = torch.nn.functional.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                # Clip only when gradients are synchronized, i.e. on the micro-batch
                # that completes an accumulation cycle and actually updates weights.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {'loss': loss.detach().item(), 'lr': lr_scheduler.get_last_lr()[0], 'step': global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        progress_bar.close()

        # Preview + checkpoint only on checkpoint epochs; build the pipeline only then.
        is_checkpoint_epoch = (epoch + 1) % config['checkpoint_epoch_num'] == 0
        is_last_epoch = epoch == config['epoch_num'] - 1
        if accelerator.is_main_process and (is_checkpoint_epoch or is_last_epoch):
            pipeline = diffusers.DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
            evaluate(config, epoch, pipeline)
            pipeline.save_pretrained(config['output_dir'])

    # Finish the trackers so the TensorBoard logs are flushed and closed.
    accelerator.end_training()

## Let's train!

Let's launch the training from the notebook using Accelerate's `notebook_launcher` function. It runs a single process here (`num_processes=1`); increase `num_processes` to train on multiple GPUs:

In [None]:
# `evaluate` reads 'num_inference_steps' (and the sampling 'batch_size') from this
# dict, so DIFFUSION_CONFIG must be merged in. It comes after ITERATION_CONFIG, so its
# 'batch_size' is the one used for sampling. train_loop itself does not read
# 'batch_size', because the dataloader is already built.
args = ({**RANDOM_CONFIG, **PATH_CONFIG, **CHECKPOINT_CONFIG, **ITERATION_CONFIG, **DIFFUSION_CONFIG, **ACCELERATE_CONFIG}, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)

accelerate.notebook_launcher(train_loop, args, num_processes=1)

Let's have a look at the final image grid produced by the trained diffusion model:

In [None]:
sample_images = sorted(glob.glob(f'{PATH_CONFIG["output_dir"]}/samples/*.png'))
pillow.Image.open(sample_images[-1])

## Inspect The Logs

In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir logs