# Consistency Models Training Example

[![arXiv](https://img.shields.io/badge/arXiv-2303.01469-b31b1b.svg)](https://arxiv.org/abs/2303.01469) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Kinyugo/consistency_models/blob/main/notebooks/consistency_models_training_example.ipynb) [![GitHub Repo stars](https://img.shields.io/github/stars/Kinyugo/consistency_models?style=social)](https://github.com/Kinyugo/consistency_models)

## 📖 Introduction

Consistency Models are a new family of generative models that achieve high sample quality without adversarial training. They support fast one-step generation by design, while still allowing for few-step sampling to trade compute for sample quality. They also support zero-shot data editing, like image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks.

## 🛠️ Setup


#### GPU Check


In [None]:
# Show the GPU(s) attached to this runtime, their driver and memory usage.
!nvidia-smi

#### Packages


In [None]:
# Install/upgrade the training, logging and metric dependencies quietly (-q).
%pip install -q lightning diffusers transformers gdown torchmetrics lpips --no-cache --upgrade
# Install the consistency_models package in editable mode directly from GitHub.
%pip install -q -e git+https://github.com/Kinyugo/consistency_models.git#egg=consistency_models

## 🚀 Training


### Data Loading


#### Downloading and Extraction


In [None]:
# Download one dataset archive with gdown and extract it. Keep exactly one line
# active; `Config.data_dir` must point at the extracted folder (presumably named
# after the archive, e.g. "butterflies256" — confirm after extraction).
!gdown 1FnzQLDPs-IlTTEr14YyENKjTYqZfn8mS && tar -xf butterflies256.tar.gz # Butterflies Dataset
# !gdown 1m1QrNnKJy7hEzUQusyD3th_La775QKUV && tar -xf abstract_art.tar.gz  # Abstract Art Dataset
# !gdown 1VJow74U3H7KG_HOiP1WWo6LoqoE3azJj && tar -xf anime_faces.tar.gz # Anime Faces

#### DataModule


In [None]:
from typing import Callable

from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder


class ImageDataModule(LightningDataModule):
    """Lightning data module serving a ``torchvision`` ``ImageFolder`` for training.

    Args:
        data_dir: Root folder read by ``ImageFolder`` (one sub-folder per class).
        transform: Optional callable applied to every loaded image.
        batch_size: Images per training batch.
        num_workers: Number of ``DataLoader`` worker processes.
        pin_memory: Whether the ``DataLoader`` pins host memory.
    """

    def __init__(
        self,
        data_dir: str,
        transform: Callable = None,
        batch_size: int = 32,
        num_workers: int = 2,
        pin_memory: bool = True,
    ) -> None:
        super().__init__()

        # Where the images live and how each one is preprocessed.
        self.data_dir, self.transform = data_dir, transform
        # DataLoader settings.
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def setup(self, stage: str = None) -> None:
        """Build the ``ImageFolder`` dataset; ``stage`` is accepted but ignored."""
        image_folder = ImageFolder(self.data_dir, transform=self.transform)
        self.dataset = image_folder

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling ``DataLoader`` over the dataset built in ``setup``."""
        loader_options = dict(
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
        )
        return DataLoader(self.dataset, **loader_options)

#### Transforms


In [None]:
from typing import Tuple

from torchvision import transforms as T


def transform_fn(image_size: Tuple[int, int]) -> T.Compose:
    """Build the training-time image preprocessing pipeline.

    Images are resized to ``image_size`` and randomly flipped horizontally.
    ``ToTensor`` then scales them to ``[0, 1]``, and they are finally mapped
    to ``[-1, 1]``, the range the rest of the notebook clamps samples to.

    Args:
        image_size: Target size passed to ``T.Resize``.

    Returns:
        A ``T.Compose`` applying the steps above.
    """
    return T.Compose(
        [
            T.Resize(image_size),
            T.RandomHorizontalFlip(),
            T.ToTensor(),
            # (x - 0.5) / 0.5 == 2x - 1 per channel (ImageFolder's default loader
            # yields RGB, hence three channels). Unlike ``T.Lambda`` wrapping a
            # lambda, ``Normalize`` is picklable, so DataLoader workers also
            # start under the "spawn" method (Windows / macOS).
            T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        ]
    )

### Model


In [None]:
from diffusers import UNet2DModel
from torch import nn


class UNet(nn.Module):
    """Thin wrapper around ``diffusers.UNet2DModel`` that returns a plain tensor.

    The network maps 3-channel inputs to 3-channel outputs. It uses six
    encoder and six decoder blocks, with two layers per block. Self-attention
    appears only in the fifth encoder block and the mirrored second decoder
    block.

    Args:
        image_size: Forwarded to ``UNet2DModel`` as ``sample_size``.
    """

    def __init__(self, image_size: Tuple[int, int]) -> None:
        super().__init__()

        # Output channels of each block, from the first (highest-resolution) onward.
        widths = (128, 128, 256, 256, 512, 512)
        encoder_blocks = ("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D")
        decoder_blocks = ("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4

        self.model_fn = UNet2DModel(
            sample_size=image_size,
            in_channels=3,
            out_channels=3,
            layers_per_block=2,
            block_out_channels=widths,
            down_block_types=encoder_blocks,
            up_block_types=decoder_blocks,
        )

    def forward(self, *args, **kwargs):
        """Run the wrapped model and unwrap its ``sample`` output."""
        output = self.model_fn(*args, **kwargs, return_dict=True)
        return output.sample

### Lightning Model


#### Logging Utilities


In [None]:
import torch
from lightning.pytorch.loggers import TensorBoardLogger
from matplotlib import pyplot as plt
from torch import Tensor
from torchvision.utils import make_grid


def plot_distribution(x: Tensor, title: str) -> plt.Figure:
    """Plot one value histogram per sample of a batch.

    Args:
        x: Batched tensor. Dim 0 is the batch dimension, and each sample is
            flattened into a single histogram.
        title: Figure-title prefix. The title also reports the batch-wide
            mean and standard deviation.

    Returns:
        The figure, laid out on two rows with one subplot per sample.
    """
    # Compute on CPU in float32: ``torch.histogram`` is CPU-only and may not
    # accept half precision (e.g. samples produced under mixed precision).
    x = x.detach().cpu().float()
    batch_size = int(x.shape[0])

    # Round up so an odd batch size still gets one subplot per sample.
    num_cols = max((batch_size + 1) // 2, 1)
    fig, axes = plt.subplots(2, num_cols, constrained_layout=True, squeeze=False)
    axes = axes.flatten()

    for b in range(batch_size):
        hist, edges = torch.histogram(x[b], density=True)
        axes[b].plot(edges[:-1], hist)

    # Hide subplots left empty (odd batch sizes or a batch of one).
    for ax in axes[batch_size:]:
        ax.set_axis_off()

    mean, std = x.mean(), x.std()
    fig.suptitle(f"{title} | Mean: {mean:.4f} Std: {std:.4f}")
    fig.supxlabel("X")
    fig.supylabel("Density")

    return fig


def log_images(
    logger: TensorBoardLogger, images: Tensor, title: str, global_step: int
) -> None:
    """Write ``images`` to TensorBoard as a single image grid.

    Values are clamped to ``[-1, 1]`` and then normalized from that range for
    display.
    """
    clamped = images.clamp(-1.0, 1.0)
    grid = make_grid(clamped, value_range=(-1.0, 1.0), normalize=True)
    writer = logger.experiment
    writer.add_image(title, grid, global_step)


def log_distribution(
    logger: TensorBoardLogger, x: Tensor, title: str, global_step: int
) -> None:
    """Render per-sample histograms of ``x`` and write the figure to TensorBoard."""
    histogram_figure = plot_distribution(x, title)
    writer = logger.experiment
    writer.add_figure(title, histogram_figure, global_step)


def log_samples(
    logger: TensorBoardLogger,
    samples: Tensor,
    tag: str,
    global_step: int,
) -> None:
    """Log ``samples`` both as an image grid and as a value-distribution figure.

    The two entries are tagged ``images/<tag>`` and ``distribution/<tag>``.
    """
    image_title = f"images/{tag}"
    distribution_title = f"distribution/{tag}"
    log_images(logger, samples, image_title, global_step)
    log_distribution(logger, samples, distribution_title, global_step)

#### Model Definition


In [None]:
from typing import List, Union

from lightning.pytorch import LightningModule
from torch import optim
from torch.nn import functional as F
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

from consistency_models.consistency_models import (
    ConsistencySamplingAndEditing,
    ConsistencyTraining,
    ema_decay_rate_schedule,
    karras_schedule,
    timesteps_schedule,
)
from consistency_models.utils import update_ema_model


class LitConsistencyModel(LightningModule):
    """LightningModule that trains a UNet with consistency training.

    Only ``unet`` is optimized. ``ema_unet`` is passed to
    ``consistency_training`` alongside it and is updated after every training
    batch as an exponential moving average of ``unet``. The EMA decay rate
    follows the timestep schedule. Generated samples are logged to TensorBoard
    every ``sample_every_n_steps`` global steps.

    Args:
        consistency_training: Callable returning ``(predicted, target)`` for a batch.
        consistency_sampling: Sampler used to generate images for logging.
        unet: Online network being optimized.
        ema_unet: EMA copy of ``unet``.
        initial_ema_decay_rate: Starting value for ``ema_decay_rate_schedule``.
        lr: Adam learning rate.
        betas: Adam ``betas``.
        lr_scheduler_start_factor: ``LinearLR`` start factor (per-step warm-up).
        lr_scheduler_iters: Number of steps over which the warm-up runs.
        sample_every_n_steps: Sample and log images every this many global steps.
        num_samples: Maximum number of images sampled per logging event.
        num_sampling_steps: Sampling step counts to log; one image grid each.
    """

    def __init__(
        self,
        consistency_training: ConsistencyTraining,
        consistency_sampling: ConsistencySamplingAndEditing,
        unet: nn.Module,
        ema_unet: nn.Module,
        initial_ema_decay_rate: float = 0.95,
        lr: float = 2e-4,
        betas: Tuple[float, float] = (0.5, 0.999),
        lr_scheduler_start_factor: float = 1 / 3,
        lr_scheduler_iters: int = 500,
        sample_every_n_steps: int = 500,
        num_samples: int = 8,
        num_sampling_steps: List[int] = [1, 2, 5],
    ) -> None:
        super().__init__()

        # Record the constructor arguments except the (non-serializable) modules.
        self.save_hyperparameters(
            ignore=["consistency_training", "consistency_sampling", "unet", "ema_unet"]
        )

        self.consistency_training = consistency_training
        self.consistency_sampling = consistency_sampling
        self.unet = unet
        self.ema_unet = ema_unet
        self.initial_ema_decay_rate = initial_ema_decay_rate
        self.lr = lr
        self.betas = betas
        self.lr_scheduler_start_factor = lr_scheduler_start_factor
        self.lr_scheduler_iters = lr_scheduler_iters
        self.sample_every_n_steps = sample_every_n_steps
        self.num_samples = num_samples
        # Copy so the instance never aliases the shared default list.
        self.num_sampling_steps = list(num_sampling_steps)

        # LPIPS with an AlexNet backbone, used as the perceptual loss term.
        self.lpips = LearnedPerceptualImagePatchSimilarity(net_type="alex")

    def training_step(
        self, batch: Union[Tensor, List[Tensor]], batch_idx: int
    ) -> Tensor:
        """Compute and log the training loss for one batch.

        Args:
            batch: Images, or an ``[images, labels]`` list (labels are ignored).
            batch_idx: Unused.

        Returns:
            Scalar loss: LPIPS + MSE on the predictions and targets clamped to
            ``[-1, 1]``, plus an overflow penalty on the raw predictions.
        """
        # Drop labels if present
        if isinstance(batch, list):
            batch = batch[0]

        # Presumably the online network's prediction and the EMA network's
        # target at the current point of the timestep schedule.
        predicted, target = self.consistency_training(
            self.unet, self.ema_unet, batch, self.global_step, self.trainer.max_steps
        )

        clamped_predicted = predicted.clamp(min=-1.0, max=1.0)
        clamped_target = target.clamp(min=-1.0, max=1.0)
        lpips_loss = self.lpips(clamped_predicted, clamped_target)
        mse_loss = F.mse_loss(clamped_predicted, clamped_target)
        # Pull raw predictions that leave [-1, 1] back toward that range; the
        # clamped copy is detached so only the raw output receives gradients.
        overflow_loss = F.mse_loss(predicted, clamped_predicted.detach())
        # All three terms are summed into a single scalar for backprop/logging.
        loss = lpips_loss + mse_loss + overflow_loss

        self.log_dict(
            {
                "lpips_loss": lpips_loss,
                "mse_loss": mse_loss,
                "overflow_loss": overflow_loss,
                "train_loss": loss,
            }
        )

        # Sample and log samples
        if self.global_step % self.sample_every_n_steps == 0:
            self.__sample_and_log_samples(batch)

        return loss

    def on_train_batch_end(self, *args) -> None:
        """Update ``ema_unet`` from ``unet`` and log the schedule values used."""
        num_timesteps = timesteps_schedule(
            self.global_step,
            self.trainer.max_steps,
            initial_timesteps=self.consistency_training.initial_timesteps,
            final_timesteps=self.consistency_training.final_timesteps,
        )
        # The EMA decay rate is derived from the current number of timesteps.
        ema_decay_rate = ema_decay_rate_schedule(
            num_timesteps,
            initial_ema_decay_rate=self.initial_ema_decay_rate,
            initial_timesteps=self.consistency_training.initial_timesteps,
        )
        self.ema_unet = update_ema_model(self.ema_unet, self.unet, ema_decay_rate)
        self.log_dict(
            {"num_timesteps": num_timesteps, "ema_decay_rate": ema_decay_rate}
        )

    def configure_optimizers(self):
        """Adam over ``unet`` only, with a per-step ``LinearLR`` warm-up."""
        opt = optim.Adam(self.unet.parameters(), lr=self.lr, betas=self.betas)
        sched = optim.lr_scheduler.LinearLR(
            opt,
            start_factor=self.lr_scheduler_start_factor,
            total_iters=self.lr_scheduler_iters,
        )
        sched = {"scheduler": sched, "interval": "step"}

        return [opt], [sched]

    @torch.no_grad()
    def __sample_and_log_samples(self, batch: Tensor) -> None:
        """Log ground-truth images, then one generated grid per step count.

        The same noise is reused for every entry of ``num_sampling_steps``, so
        the grids for different step counts are directly comparable.
        """
        # Ensure the number of samples does not exceed the batch size
        num_samples = min(self.num_samples, batch.shape[0])
        noise = torch.randn_like(batch[:num_samples])

        # Log ground truth samples
        log_samples(
            self.logger,
            batch[:num_samples],
            "ground_truth",
            self.global_step,
        )

        for steps in self.num_sampling_steps:
            # Sample an extra step and reverse the schedule as the last step (sigma=sigma_min)
            # is useless as the model returns identity
            sigmas = karras_schedule(
                steps + 1,
                sigma_min=self.consistency_training.sigma_min,
                sigma_max=self.consistency_training.sigma_max,
                rho=self.consistency_training.rho,
                device=self.device,
            )

            sigmas = sigmas.flipud()[:-1]

            samples = self.consistency_sampling(
                self.unet, noise, sigmas, clip_denoised=True, verbose=True
            )
            samples = samples.clamp(min=-1.0, max=1.0)

            # Generated samples
            log_samples(
                self.logger,
                samples,
                f"generated_samples-steps={steps}",
                self.global_step,
            )

### Training Loop


#### Training Utils


In [None]:
import os

from torch import nn


def save_model_ckpt(model: nn.Module, ckpt_path: str) -> None:
    """Save ``model.state_dict()`` to ``ckpt_path``, creating parent folders.

    Args:
        model: Module whose parameters and buffers are saved.
        ckpt_path: Destination file. A bare filename (no directory part) is
            written to the current working directory.
    """
    parent_dir = os.path.dirname(ckpt_path)
    # ``os.makedirs("")`` raises, so only create a directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    torch.save(model.state_dict(), ckpt_path)


def load_model_ckpt(
    model: nn.Module, ckpt_path: str, map_location: Union[str, torch.device] = "cpu"
) -> nn.Module:
    """Load a ``state_dict`` saved by ``save_model_ckpt`` into ``model``.

    Args:
        model: Module with a matching architecture; it is updated in place.
        ckpt_path: Path of the saved ``state_dict``.
        map_location: Device the stored tensors are loaded onto before being
            copied into ``model``. It defaults to CPU, so checkpoints written on
            a GPU also load on CPU-only machines.

    Returns:
        The same ``model`` instance, for chaining.
    """
    # NOTE: ``torch.load`` unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(ckpt_path, map_location=map_location)
    # ``load_state_dict`` copies values into the model's existing parameters,
    # wherever those currently live.
    model.load_state_dict(state_dict)
    return model

#### Training Config


In [None]:
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class Config:
    """All settings for one training run, grouped by the component using them.

    Field order defines the positional order of the generated ``__init__``.
    """

    # --- Reproducibility: seed given to ``L.seed_everything``
    seed: int = 0

    # --- Data loading and preprocessing
    image_size: Tuple[int, int] = (128, 128)
    data_dir: str = "butterflies256"
    batch_size: int = 32
    num_workers: int = 2
    pin_memory: bool = True

    # --- Consistency training: noise levels, discretization and EMA
    sigma_min: float = 0.002
    sigma_max: float = 80.0
    rho: float = 7.0
    sigma_data: float = 0.5
    initial_timesteps: int = 2
    final_timesteps: int = 150
    initial_ema_decay_rate: float = 0.95

    # --- LightningModule: optimizer, LR warm-up and sample logging
    lr: float = 2e-5
    betas: Tuple[float, float] = (0.5, 0.999)
    lr_scheduler_start_factor: float = 1 / 3
    lr_scheduler_iters: int = 500
    sample_every_n_steps: int = 10_000
    num_samples: int = 8
    num_sampling_steps: List[int] = field(default_factory=lambda: list((1, 2, 5)))

    # --- TensorBoard run naming
    name: str = "consistency_models"
    version: str = "butterflies256_100k"

    # --- Lightning checkpoint frequency
    every_n_train_steps: int = 10_000

    # --- Trainer
    accelerator: str = "auto"
    max_steps: int = 100_001
    gradient_clip_val: float = 1.0
    log_every_n_steps: int = 20
    precision: Union[int, str] = 16
    detect_anomaly: bool = False

    # --- Whether ``run_training`` calls ``trainer.fit``
    skip_training: bool = False

    # --- Plain UNet weights export and optional Lightning resume checkpoint
    model_ckpt_path: str = "checkpoints/unet.pt"
    resume_ckpt_path: Optional[str] = None

#### Define Training Loop


In [None]:
import lightning as L
import matplotlib
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger


def run_training(config: Config) -> None:
    """Train a consistency model on an image folder and export the trained UNet.

    The function:
    - seeds everything and builds the data module;
    - builds the consistency training/sampling helpers;
    - creates an online UNet plus an EMA copy initialized to the same weights;
    - creates a Lightning ``Trainer`` that logs to TensorBoard and writes a
      Lightning checkpoint every ``config.every_n_train_steps`` steps.

    After training, the online UNet's ``state_dict`` is written to
    ``config.model_ckpt_path``. With ``config.skip_training`` set, nothing is
    fitted and nothing is saved, so an existing checkpoint at that path is
    never overwritten with untrained weights.

    Args:
        config: Settings for data, model, logging and training.
    """
    # -------------------------------------------
    # Reproducibility
    # -------------------------------------------
    L.seed_everything(config.seed)

    # -------------------------------------------
    # Configure Matplotlib
    # -------------------------------------------
    # Prevents pixelated fonts on figures
    # NOTE(review): WebAgg is an interactive backend; after this switch, later
    # `plt` figures in this notebook may no longer render inline — confirm.
    matplotlib.use("webagg")
    matplotlib.style.use(["ggplot", "fast"])

    # -------------------------------------------
    # Data & Transforms
    # -------------------------------------------
    transform = transform_fn(config.image_size)
    datamodule = ImageDataModule(
        config.data_dir,
        transform=transform,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        pin_memory=config.pin_memory,
    )

    # -----------------------------------------
    # Models
    # ------------------------------------------
    consistency_training = ConsistencyTraining(
        sigma_min=config.sigma_min,
        sigma_max=config.sigma_max,
        rho=config.rho,
        sigma_data=config.sigma_data,
        initial_timesteps=config.initial_timesteps,
        final_timesteps=config.final_timesteps,
    )
    consistency_sampling = ConsistencySamplingAndEditing(
        sigma_min=config.sigma_min, sigma_data=config.sigma_data
    )
    unet = UNet(config.image_size)
    ema_unet = UNet(config.image_size)
    # The EMA network starts as an exact copy of the online network.
    ema_unet.load_state_dict(unet.state_dict())

    # -----------------------------------------
    # Lit Model
    # ------------------------------------------
    lit_consistency_model = LitConsistencyModel(
        consistency_training,
        consistency_sampling,
        unet,
        ema_unet,
        initial_ema_decay_rate=config.initial_ema_decay_rate,
        lr=config.lr,
        betas=config.betas,
        lr_scheduler_start_factor=config.lr_scheduler_start_factor,
        lr_scheduler_iters=config.lr_scheduler_iters,
        sample_every_n_steps=config.sample_every_n_steps,
        num_samples=config.num_samples,
        num_sampling_steps=config.num_sampling_steps,
    )

    # -----------------------------------------
    # Trainer
    # ------------------------------------------
    # TensorBoardLogger requires `save_dir`. "lightning_logs" is the directory
    # watched by the `%tensorboard --logdir lightning_logs` cell.
    logger = TensorBoardLogger(
        save_dir="lightning_logs", name=config.name, version=config.version
    )
    checkpoint_callback = ModelCheckpoint(
        every_n_train_steps=config.every_n_train_steps
    )
    trainer = Trainer(
        logger=logger,
        callbacks=[checkpoint_callback],
        accelerator=config.accelerator,
        max_steps=config.max_steps,
        gradient_clip_val=config.gradient_clip_val,
        log_every_n_steps=config.log_every_n_steps,
        precision=config.precision,
        detect_anomaly=config.detect_anomaly,
    )

    # -----------------------------------------
    # Run Training & Save Checkpoint
    # ------------------------------------------
    if not config.skip_training:
        trainer.fit(
            lit_consistency_model,
            datamodule=datamodule,
            ckpt_path=config.resume_ckpt_path,
        )
        # Only export weights that were actually trained. Saving after a skipped
        # fit would overwrite a real checkpoint with freshly initialized weights.
        save_model_ckpt(lit_consistency_model.unet, config.model_ckpt_path)

#### Run Training Loop


In [None]:
# Load the TensorBoard notebook extension and display it inline, watching the
# `lightning_logs` directory.
%load_ext tensorboard
%tensorboard --logdir lightning_logs

In [None]:
# Train with the default configuration (edit the Config fields to customize).
config = Config()
run_training(config=config)

## 🎲 Sampling & Zero-shot Editing

### Utils

In [None]:
from matplotlib import pyplot as plt


def plot_images(images: Tensor, cols: int = 4) -> None:
    """Display a batch of images in a grid.

    Args:
        images: Batch of images with values in ``[-1, 1]``, indexed
            ``images[i]`` as channel-first ``(C, H, W)`` (the layout the
            ``permute(1, 2, 0)`` below expects). The tensor may live on any
            device.
        cols: Number of grid columns; as many rows as needed are added.
    """
    # Matplotlib needs CPU data; ``.numpy()`` fails on CUDA tensors.
    images = images.detach().cpu().float()
    num_images = images.shape[0]
    # Round up so a partially filled last row is still shown.
    rows = max((num_images + cols - 1) // cols, 1)
    # squeeze=False keeps a 2-D array of axes even for a 1x1 grid.
    fig, axs = plt.subplots(rows, cols, figsize=(16, 4 * rows), squeeze=False)
    axs = axs.flatten()
    for ax in axs:
        ax.set_axis_off()
    for ax, image in zip(axs, images):
        # Map [-1, 1] to [0, 1] and clip, since imshow expects floats in [0, 1].
        ax.imshow((image.permute(1, 2, 0).numpy() / 2 + 0.5).clip(0.0, 1.0))

### Checkpoint Loading

In [None]:
# Rebuild the UNet and the sampler with the same settings used for training.
model = UNet(config.image_size)
consistency_sampling_and_editing = ConsistencySamplingAndEditing(
    config.sigma_min, config.sigma_data
)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load from the path `run_training` saves to, so the two can never drift apart.
model = load_model_ckpt(model, config.model_ckpt_path)
model = model.to(device)
model = model.eval()

### Load Sample Batch

In [None]:
# A small data module (batch of 4) over the training images for the demos below.
dm = ImageDataModule(
    config.data_dir,
    transform=transform_fn(config.image_size),
    batch_size=4,
    num_workers=config.num_workers,
    pin_memory=True,
)
# Lightning's hook order: `prepare_data` first, then `setup` to build datasets.
dm.prepare_data()
dm.setup()

dl = dm.train_dataloader()
batch, _ = next(iter(dl))  # drop the ImageFolder class labels
batch = batch.to(device)

# Plot from CPU: converting a CUDA tensor to NumPy fails.
plot_images(batch.cpu())

### Sampling

In [None]:
# One-step generation from pure noise at sigma_max. For multistep sampling,
# pass more (decreasing) sigmas, e.g. 2-5 steps, to trade compute for quality.
with torch.no_grad():
    samples = consistency_sampling_and_editing(
        model,
        # Noise must live on the model's device and match the training resolution.
        torch.randn((4, 3, *config.image_size), device=device),
        sigmas=[config.sigma_max],
        clip_denoised=True,
        verbose=True,
    )

plot_images(samples.cpu())

### Inpainting

In [None]:
from torchvision.transforms import RandomErasing

# Erase one random rectangle covering 20-50% of the image area (aspect ratio
# 0.5), filled with RandomErasing's default value 0. RandomErasing draws a
# single region per call, so every image in the batch is erased in the same place.
random_erasing = RandomErasing(p=1.0, scale=(0.2, 0.5), ratio=(0.5, 0.5))
masked_batch = random_erasing(batch)
# True where a pixel was erased, i.e. the region to inpaint.
mask = batch != masked_batch

# Plot from CPU: converting a CUDA tensor to NumPy fails.
plot_images(masked_batch.cpu())

In [None]:
# Inpaint with two steps (sigma 5.23, then 2.25).
# NOTE(review): assumes the sampler keeps `masked_batch` where `mask` is 0 and
# regenerates where it is 1 — confirm against ConsistencySamplingAndEditing.
with torch.no_grad():
    inpainted_batch = consistency_sampling_and_editing(
        model,
        masked_batch,
        sigmas=[5.23, 2.25],
        mask=mask.float(),
        clip_denoised=True,
        verbose=True,
    )

# Masked inputs first, then the inpainted results (plotted from CPU).
plot_images(torch.cat((masked_batch, inpainted_batch), dim=0).cpu())

### Interpolation

In [None]:
# Split the batch into two halves; image i of each half forms an interpolation pair.
batch_a = batch[: batch.shape[0] // 2]
batch_b = batch[batch.shape[0] // 2 :]

# batch_a images first, then batch_b (plotted from CPU).
plot_images(torch.cat((batch_a, batch_b), dim=0).cpu(), cols=2)

In [None]:
# Interpolate each (batch_a[i], batch_b[i]) pair, then refine with two steps
# (sigma 5.23, then 2.25). ab_ratio=0.5 presumably weights both inputs equally.
with torch.no_grad():
    interpolated_batch = consistency_sampling_and_editing.interpolate(
        model,
        batch_a,
        batch_b,
        ab_ratio=0.5,
        sigmas=[5.23, 2.25],
        clip_denoised=True,
        verbose=True,
    )

# Plot from CPU: converting a CUDA tensor to NumPy fails.
plot_images(interpolated_batch.cpu())