# Train Cat Generator using Accelerate

**Please note**: Several functions in this notebook are from [Unit 1 of the Hugging Face Diffusion Models Class](https://github.com/huggingface/diffusion-models-class/blob/main/unit1/01_introduction_to_diffusers.ipynb), which I was following to learn about diffusion models.

In [None]:
# # If you are using Google Colab, you can install the required packages with:
# %pip install -U diffusers datasets transformers accelerate ftfy pyarrow wandb pandas numpy

In [None]:
import os
from argparse import Namespace

from datasets import load_dataset

from diffusers import DDIMScheduler
from diffusers import UNet2DModel
from diffusers.optimization import get_cosine_schedule_with_warmup

import torch
import torch.nn.functional as F

# from torch.optim.lr_scheduler import ExponentialLR
# from torch.optim.lr_scheduler import CosineAnnealingLR

from torchvision import transforms
import torchvision
import wandb

import numpy as np

from PIL import Image

In [None]:
from accelerate.utils import write_basic_config

# Ask accelerate to write its basic configuration before any Accelerator is created.
# NOTE(review): presumably this saves a default config file to accelerate's standard
# location and leaves an existing one untouched — confirm against the accelerate docs.
write_basic_config()
# os._exit(00)  # Restart the notebook?

In [None]:
from accelerate import Accelerator
from accelerate.utils import set_seed

## Create config. Set hyperparameters here.

In [None]:
# Single seed reused for `set_seed`, the dataloader shuffle generator and the
# sampling-noise generator, so that runs are repeatable.
SEED = 1

# Training is pinned to the CPU. To pick an accelerator automatically, use instead:
# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
DEVICE = 'cpu'

CONFIG = Namespace(
    # --- Naming (W&B run name and model artifact name) ---
    run_name='cat-diffusion-model-scratch',
    model_name='cat-dataset-model-v2',

    # --- Images: training resolution and size of the preview batch ---
    image_size=128,
    num_samples_to_generate=8,

    # --- Data augmentation ---
    horizontal_flip_prob=0.5,
    gaussian_blur_kernel_size=3,

    # --- Optimisation ---
    per_device_train_batch_size=8,
    num_train_epochs=15,
    learning_rate=4e-4,
    seed=SEED,

    # --- Noise schedule ---
    num_train_timesteps=1000,
    beta_schedule='squaredcos_cap_v2',

    # --- Learning-rate schedule ---
    # NOTE(review): `lr_exp_schedule_gamma` and `train_limit` are not read by any
    # code visible in this notebook — confirm whether they are still needed.
    lr_exp_schedule_gamma=0.85,
    lr_warmup_steps=500,
    train_limit=-1,

    # --- Runtime ---
    mixed_precision=None,
    device=DEVICE,
)

## Functions for displaying images

In [None]:
def show_images(x):
    """
    Tile a batch of image tensors into one grid and return it as a PIL image.

    The batch is expected to be normalised to (-1, 1); values are shifted back
    to (0, 1), clipped, and scaled to 8-bit before conversion.
    """

    unnormalised = x * 0.5 + 0.5  # (-1, 1) -> (0, 1)
    tiled = torchvision.utils.make_grid(unnormalised)

    # Move the channel axis last, which is the layout PIL builds images from.
    channels_last = tiled.detach().cpu().permute(1, 2, 0)
    scaled = channels_last.clip(0, 1) * 255
    as_bytes = np.array(scaled).astype(np.uint8)

    return Image.fromarray(as_bytes)

def make_grid(images, size=64):
    """
    Lay a list of PIL images out side by side in a single row.

    Every image is resized to a `size` x `size` square and pasted into an RGB
    strip that is `size * len(images)` pixels wide and `size` pixels tall.
    """

    tile_shape = (size, size)
    strip = Image.new("RGB", (size * len(images), size))

    for position, picture in enumerate(images):
        left_edge = position * size
        strip.paste(picture.resize(tile_shape), (left_edge, 0))

    return strip

def generate_images(noise_scheduler: DDIMScheduler, model: UNet2DModel, wandb_run, config: Namespace,
                    num_inference_steps: int = 100):
    """
    Sample a batch of images from the model and log them to W&B.

    Starting from seeded Gaussian noise, the sample is denoised by stepping
    through the scheduler's inference timesteps, then rendered with
    `show_images` and logged under the key 'generated-images'.

    Args:
        noise_scheduler: Scheduler used to step the sample. Its inference
            timesteps are reset by this call via `set_timesteps`.
        model: Network that predicts the residual for a noisy sample.
        wandb_run: W&B run that receives the image.
        config: Reads `seed`, `num_samples_to_generate`, `image_size` and `device`.
        num_inference_steps: Number of denoising steps (defaults to 100, the
            value that was previously hard-coded).

    The model is switched to eval mode while sampling and put back into
    whatever mode it was in afterwards, so calling this in the middle of
    training does not disturb the training state.
    """

    # A freshly seeded generator means every call starts from the same noise,
    # so the images logged after each epoch are directly comparable.
    sample_gen = torch.Generator().manual_seed(config.seed)

    sample = torch.randn(config.num_samples_to_generate, 3,
                         config.image_size, config.image_size,
                         generator=sample_gen).to(config.device)

    noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps)

    was_training = model.training
    model.eval()
    try:
        # No gradients are needed anywhere in the sampling loop.
        with torch.no_grad():
            for t in noise_scheduler.timesteps:
                model_input = noise_scheduler.scale_model_input(sample, t)

                # Get model pred
                residual = model(model_input, t).sample

                # Update sample with step
                sample = noise_scheduler.step(residual, t, sample).prev_sample
    finally:
        # Restore the caller's train/eval mode even if sampling fails.
        model.train(was_training)

    image = show_images(sample).resize(
        (config.num_samples_to_generate * config.image_size, config.image_size), resample=Image.NEAREST)

    wandb_run.log({'generated-images': wandb.Image(image)})

## Create Dataset

For now, I am using the following data augmentations:
- RandomHorizontalFlip - Randomly flips the image horizontally
- GaussianBlur - Smooth/blur image using a Gaussian filter

In [None]:
def prepare_dataloader(config: Namespace):
    """
    Build the shuffled training dataloader of augmented cat images.

    Loads the 'cats_vs_dogs' dataset, keeps only examples labelled 0 (the
    dogs are dropped) whose image is larger than 100 pixels on both sides,
    and applies the resize / flip / blur / normalise pipeline on the fly.

    Args:
        config: Reads `image_size`, `horizontal_flip_prob`,
            `gaussian_blur_kernel_size`, `per_device_train_batch_size` and `seed`.

    Returns:
        A `torch.utils.data.DataLoader` over the 'train' split whose batches
        expose the transformed tensors under the key 'images'.
    """

    side = config.image_size
    augment = transforms.Compose(
        [
            transforms.Resize((side, side)),
            transforms.RandomHorizontalFlip(p=config.horizontal_flip_prob),
            transforms.GaussianBlur(kernel_size=config.gaussian_blur_kernel_size),
            transforms.ToTensor(),  # PIL image -> tensor in (0, 1)
            transforms.Normalize([0.5], [0.5]),  # (0, 1) -> (-1, 1)
        ])

    def is_cat(example):
        # Dogs are left for another model.
        return example['labels'] == 0

    def is_large_enough(example):
        # Discard images that are 100 pixels or smaller along either side.
        width, height = example['image'].size
        return width > 100 and height > 100

    def transform(examples):
        tensors = [augment(picture.convert('RGB')) for picture in examples['image']]
        return {'images': tensors}

    cats = load_dataset('cats_vs_dogs').filter(is_cat).filter(is_large_enough)
    cats.set_transform(transform)

    # Seeded generator so the shuffle order is reproducible.
    shuffle_rng = torch.Generator().manual_seed(config.seed)

    return torch.utils.data.DataLoader(
        cats['train'],
        batch_size=config.per_device_train_batch_size,
        shuffle=True,
        generator=shuffle_rng)


## Create Model

In [None]:
def create_model(config: Namespace):
    """
    Build the UNet that is trained as the denoising model.

    Args:
        config: Only `image_size` is read; it sets the target image resolution.

    Returns:
        A freshly initialised `UNet2DModel` for 3-channel (RGB) input and output.
    """
    # Regular ResNet blocks, and their variants with spatial self-attention.
    resnet_down, attention_down = "DownBlock2D", "AttnDownBlock2D"
    resnet_up, attention_up = "UpBlock2D", "AttnUpBlock2D"

    return UNet2DModel(
        sample_size=config.image_size,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,  # ResNet layers per UNet block
        block_out_channels=(128, 128, 256, 256, 512),
        # Two regular downsampling blocks followed by three attention ones ...
        down_block_types=(resnet_down,) * 2 + (attention_down,) * 3,
        # ... mirrored on the way up: three attention blocks, then two regular ones.
        up_block_types=(attention_up,) * 3 + (resnet_up,) * 2,
    )

## Train Model

In [None]:
def training_loop(config: Namespace):
    """
    Train the diffusion model and upload the resulting weights to W&B.

    Steps: start a W&B run, seed everything, build the dataloader, model,
    noise scheduler, AdamW optimizer and cosine-with-warmup LR schedule, hand
    them to accelerate, then train for `config.num_train_epochs` epochs on the
    noise-prediction MSE objective. After every epoch a batch of sample images
    and the mean epoch loss are logged. Finally the model's state dict is
    written to 'model.pt' and logged as a W&B model artifact.

    Args:
        config: Hyperparameters. Reads `run_name`, `model_name`, `seed`,
            `mixed_precision`, `device`, `num_train_timesteps`,
            `beta_schedule`, `learning_rate`, `lr_warmup_steps` and
            `num_train_epochs`, plus whatever `prepare_dataloader`,
            `create_model` and `generate_images` read.
    """

    wandb_run = wandb.init(project='Cat-Generator', entity=None,
                           job_type='training',
                           name=config.run_name,
                           config=config)

    set_seed(config.seed)
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision, cpu=(config.device == 'cpu'))

    dataloader = prepare_dataloader(config)
    model = create_model(config)

    noise_scheduler = DDIMScheduler(
        num_train_timesteps=config.num_train_timesteps,
        beta_schedule=config.beta_schedule)

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=config.num_train_epochs*len(dataloader))

    model, optimizer, dataloader, scheduler = accelerator.prepare(
        model, optimizer, dataloader, scheduler)

    model.train()

    num_steps = 0
    for epoch in range(config.num_train_epochs):

        accelerator.print(f"Epoch {epoch}")

        epoch_loss = 0
        num_iters = 0
        for batch in dataloader:

            optimizer.zero_grad()

            clean_images = batch["images"]

            # Sample noise to add to the images, directly on the images' device
            noise = torch.randn_like(clean_images)
            bs = clean_images.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps,
                (bs,), device=clean_images.device).long()

            # Add noise to the clean images according to the noise magnitude at each timestep
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            # Get the model prediction
            noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

            # Calculate the loss
            loss = F.mse_loss(noise_pred, noise)
            accelerator.backward(loss)

            batch_loss = loss.item()
            epoch_loss += batch_loss

            # Logged before scheduler.step(), so 'lr' is the rate used for this update.
            wandb_run.log(
                {'loss': batch_loss, 'lr': scheduler.get_last_lr()[0]},
                commit=False, step=num_steps)

            num_steps += 1
            num_iters += 1

            # Update the model parameters with the optimizer
            optimizer.step()
            scheduler.step()

        # Attach the epoch loss to the last training step WITHOUT committing.
        # generate_images() logs with the default commit, which flushes this
        # step exactly once. Committing twice here would push W&B's internal
        # step counter ahead of `num_steps`, and W&B would then drop the first
        # batch of the next epoch as a non-monotonic step.
        if num_iters:  # guard against an empty dataloader
            wandb_run.log({'epoch-loss': epoch_loss / num_iters}, commit=False)

        # Log a sample of images (commits the current step)
        generate_images(noise_scheduler, model, wandb_run, config)

    # Save model to W&B. Use the `config` argument rather than the module-level
    # CONFIG, and unwrap the model so the saved keys do not carry any wrapper
    # that accelerate may have added in `prepare`.
    model_art = wandb.Artifact(config.model_name, type='model')
    torch.save(accelerator.unwrap_model(model).state_dict(), 'model.pt')

    model_art.add_file('model.pt')
    wandb_run.log_artifact(model_art)
    wandb_run.finish()

In [None]:
from accelerate import notebook_launcher

# Run the training loop through accelerate's notebook launcher in a single process.
notebook_launcher(
    training_loop,
    args=(CONFIG,),
    num_processes=1,
)