# [Improved Techniques for Training Consistency Models](https://arxiv.org/abs/2310.14189)

## 📖 Introduction

[Consistency Models](https://arxiv.org/abs/2303.01469) are a new family of generative models that achieve high sample quality without adversarial training. They support fast one-step generation by design, while still allowing for few-step sampling to trade compute for sample quality. They also support zero-shot data editing, like image inpainting, colorization, and super-resolution, without requiring explicit training on these tasks.

The original formulation of consistency models achieved good results; however, its sample quality still lagged behind that of diffusion models. In a follow-up paper, the authors revisit the theory and present training techniques that further narrow the gap between consistency models and other generative models.

### Contributions 

- Eliminating the Exponential Moving Average from the teacher model 
- Replacing LPIPS with Pseudo-Huber loss 
- Log-normal noise schedule
- Improved timestep discretization schedule 
- Improved loss weighting 

### Algorithms 

#### Training 
```python
for current_training_step in range(total_training_steps):
    data = data_distribution()

    num_timesteps = improved_timesteps_schedule(current_training_step, total_training_steps, initial_timesteps, final_timesteps)
    sigmas = karras_schedule(num_timesteps, sigma_min, sigma_max)
    timesteps = lognormal_distribution(batch_size, sigmas, mean, std)
    noise = standard_gaussian_noise()

    current_noisy_data = data + sigmas[timesteps] * noise 
    next_noisy_data = data + sigmas[timesteps + 1] * noise

    # The student sees the noisier sample and is conditioned on the matching (larger) sigma
    student_model_pred = student_model(next_noisy_data, sigmas[timesteps + 1])

    with no_grad():
        teacher_model_pred = teacher_model(current_noisy_data, sigmas[timesteps])

    loss_weights = improved_loss_weighting()
    loss = mean(loss_weights[timesteps] * pseudo_huber_loss(student_model_pred, teacher_model_pred))

    loss.backward()
    optimizer.step()

    with no_grad():
        teacher_model_params = student_model_params
```


### References

<a id="1">[1]</a> Song, Y., &amp; Dhariwal, P. (2023, October 22). Improved techniques for training consistency models. arXiv.org. https://arxiv.org/abs/2310.14189 

<a id="2">[2]</a> Song, Y., Dhariwal, P., Chen, M., &amp; Sutskever, I. (2023, May 31). Consistency models. arXiv.org. https://arxiv.org/abs/2303.01469 

## 🛠️ Setup


### GPU Check

In [None]:
!nvidia-smi

### Packages Installation

In [None]:
# %pip install -q lightning gdown torchmetrics einops tensorboard --no-cache --upgrade
# %pip install -q -e git+https://github.com/Kinyugo/consistency_models.git#egg=consistency_models/

### Imports

In [None]:
import json
import os
from dataclasses import asdict, dataclass
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
from einops import rearrange
from einops.layers.torch import Rearrange
from lightning import LightningDataModule, LightningModule, Trainer, seed_everything
from matplotlib import pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision import transforms as T
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid

from consistency_models import (
    ConsistencySamplingAndEditing,
    ImprovedConsistencyTraining,
    karras_schedule,
)

## 🧠  Implementation

### Data

#### Downloading and Extraction

In [None]:
# !gdown 1FnzQLDPs-IlTTEr14YyENKjTYqZfn8mS  && tar -xf butterflies256.tar.gz # Butterflies Dataset
# !gdown 1m1QrNnKJy7hEzUQusyD3th_La775QKUV && tar -xf abstract_art.tar.gz  # Abstract Art Dataset
# !gdown 1VJow74U3H7KG_HOiP1WWo6LoqoE3azJj && tar -xf anime_faces.tar.gz # Anime Faces

#### DataModule

In [None]:
@dataclass
class ImageDataModuleConfig:
    """Options for :class:`ImageDataModule`.

    Only ``data_dir`` must be given; every other field has a default.
    """

    # Root folder read by ``torchvision.datasets.ImageFolder``.
    data_dir: str
    # (height, width) handed to ``T.Resize``.
    image_size: Tuple[int, int] = (32, 32)
    # The remaining fields are forwarded to ``torch.utils.data.DataLoader``.
    batch_size: int = 32
    num_workers: int = 8
    pin_memory: bool = True
    persistent_workers: bool = True


class ImageDataModule(LightningDataModule):
    """Lightning data module serving an ``ImageFolder`` dataset scaled to [-1, 1].

    Args:
        config: Dataset location, target image size and ``DataLoader`` options.
    """

    def __init__(self, config: ImageDataModuleConfig) -> None:
        super().__init__()

        self.config = config

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the training dataset.

        ``stage`` is accepted for Lightning compatibility and ignored: the same
        dataset is built for every stage.
        """
        transform = T.Compose(
            [
                T.Resize(self.config.image_size),
                T.RandomHorizontalFlip(),
                T.ToTensor(),  # float tensor in [0, 1]
                # (x - 0.5) / 0.5 == 2x - 1 maps [0, 1] to [-1, 1]. Unlike a lambda,
                # Normalize is picklable, which DataLoader workers require under the
                # "spawn" multiprocessing start method.
                T.Normalize(mean=(0.5,), std=(0.5,)),
            ]
        )
        self.dataset = ImageFolder(self.config.data_dir, transform=transform)

    def train_dataloader(self) -> DataLoader:
        """Return a shuffling ``DataLoader`` over the dataset built in ``setup``."""
        return DataLoader(
            self.dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            pin_memory=self.config.pin_memory,
            # DataLoader raises if persistent_workers=True is combined with
            # num_workers == 0, so only request it when workers exist.
            persistent_workers=(
                self.config.persistent_workers and self.config.num_workers > 0
            ),
        )

### Modules

In [None]:
def GroupNorm(channels: int) -> nn.GroupNorm:
    """Build an ``nn.GroupNorm`` over ``channels`` with at most 32 groups.

    The preferred group count is ``min(32, channels // 4)``. ``nn.GroupNorm``
    requires the group count to divide ``channels``, so when the preferred count
    does not (e.g. ``channels=200`` -> 32) the largest divisor of ``channels``
    below it is used instead. At least one group is always used, so
    ``channels < 4`` no longer yields an invalid zero-group layer. Whenever the
    preferred count already divides ``channels`` the result is unchanged.

    Args:
        channels: Number of channels to normalise.

    Returns:
        A freshly constructed ``nn.GroupNorm`` module.
    """
    num_groups = max(min(32, channels // 4), 1)
    # Terminates: every integer is divisible by 1.
    while channels % num_groups != 0:
        num_groups -= 1
    return nn.GroupNorm(num_groups=num_groups, num_channels=channels)


class SelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Args:
        in_channels: Channels of the input map; must be divisible by ``n_heads``.
        out_channels: Channels of the output map.
        n_heads: Number of attention heads.
        dropout: Probability used both for attention-weight dropout (training
            only) and for channel-wise dropout on the projected output.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        n_heads: int = 8,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()

        self.dropout = dropout

        self.qkv_projection = nn.Sequential(
            GroupNorm(in_channels),
            nn.Conv2d(in_channels, 3 * in_channels, kernel_size=1, bias=False),
            # Split channels into (q/k/v, head, head_dim) and flatten space into a
            # sequence: result is (3, b, heads, x*y, head_dim).
            Rearrange("b (i h d) x y -> i b h (x y) d", i=3, h=n_heads),
        )
        self.output_projection = nn.Sequential(
            Rearrange("b h l d -> b l (h d)"),
            nn.Linear(in_channels, out_channels, bias=False),
            Rearrange("b l d -> b d l"),
            GroupNorm(out_channels),
            # The tensor is 3-D (b, channels, length) here. Dropout2d only warns on
            # 3-D input and PyTorch has announced it will reinterpret such input
            # as unbatched; Dropout1d states the channel-wise behaviour explicitly.
            nn.Dropout1d(dropout),
        )
        self.residual_projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over all ``x*y`` positions of ``x`` (shape ``(b, c, x, y)``)."""
        q, k, v = self.qkv_projection(x).unbind(dim=0)

        output = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout if self.training else 0.0, is_causal=False
        )
        output = self.output_projection(output)
        # Restore the spatial layout flattened by qkv_projection.
        output = rearrange(output, "b c (x y) -> b c x y", x=x.shape[-2], y=x.shape[-1])

        return output + self.residual_projection(x)


class UNetBlock(nn.Module):
    """Residual convolutional block conditioned on a noise-level embedding.

    Computes ``out(inp(x) + noise(e)) + residual(x)``, where ``inp`` and ``out``
    are GroupNorm -> SiLU -> 3x3 conv -> Dropout2d stacks, ``noise`` is
    SiLU -> 1x1 conv applied to the embedding ``e`` (``(b, C, 1, 1)`` as produced
    by ``NoiseLevelEmbedding``, so it broadcasts over space), and ``residual`` is
    a 1x1 conv.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        noise_level_channels: int,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()

        def norm_act_conv(c_in: int, c_out: int) -> nn.Sequential:
            # Layout shared by the two main branches.
            return nn.Sequential(
                GroupNorm(c_in),
                nn.SiLU(),
                nn.Conv2d(c_in, c_out, kernel_size=3, padding="same"),
                nn.Dropout2d(dropout),
            )

        # Keep this construction order: it fixes how parameter initialisation
        # consumes the RNG and the order of the state_dict keys.
        self.input_projection = norm_act_conv(in_channels, out_channels)
        self.noise_level_projection = nn.Sequential(
            nn.SiLU(),
            nn.Conv2d(noise_level_channels, out_channels, kernel_size=1),
        )
        self.output_projection = norm_act_conv(out_channels, out_channels)
        self.residual_projection = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor, noise_level: torch.Tensor) -> torch.Tensor:
        conditioned = self.input_projection(x) + self.noise_level_projection(
            noise_level
        )
        return self.output_projection(conditioned) + self.residual_projection(x)


class UNetBlockWithSelfAttention(nn.Module):
    """A ``UNetBlock`` whose output is refined by a ``SelfAttention`` layer.

    The attention layer keeps the channel count at ``out_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        noise_level_channels: int,
        n_heads: int = 8,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()

        self.unet_block = UNetBlock(
            in_channels=in_channels,
            out_channels=out_channels,
            noise_level_channels=noise_level_channels,
            dropout=dropout,
        )
        self.self_attention = SelfAttention(
            in_channels=out_channels,
            out_channels=out_channels,
            n_heads=n_heads,
            dropout=dropout,
        )

    def forward(self, x: torch.Tensor, noise_level: torch.Tensor) -> torch.Tensor:
        features = self.unet_block(x, noise_level)
        return self.self_attention(features)


class Downsample(nn.Module):
    """Halve the spatial resolution without discarding pixels.

    Each 2x2 patch is folded into the channel axis (``4 * channels`` channels at
    half the height and width), then a 1x1 conv projects back to ``channels``.
    Height and width must be even for the fold to apply.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()

        space_to_depth = Rearrange("b c (h ph) (w pw) -> b (c ph pw) h w", ph=2, pw=2)
        channel_mixer = nn.Conv2d(4 * channels, channels, kernel_size=1)
        self.projection = nn.Sequential(space_to_depth, channel_mixer)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        downsampled = self.projection(x)
        return downsampled


class Upsample(nn.Module):
    """Double the spatial resolution: nearest-neighbour resize, then a 3x3 conv.

    The channel count is unchanged.
    """

    def __init__(self, channels: int) -> None:
        super().__init__()

        layers = [
            nn.Upsample(scale_factor=2.0, mode="nearest"),
            nn.Conv2d(channels, channels, kernel_size=3, padding="same"),
        ]
        self.projection = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        upsampled = self.projection(x)
        return upsampled


class NoiseLevelEmbedding(nn.Module):
    """Random Fourier-feature embedding of one scalar noise level per sample.

    A level ``s`` is mapped to ``[sin(2*pi*s*W), cos(2*pi*s*W)]`` with a fixed
    random frequency vector ``W`` of length ``channels // 2``, then passed
    through a two-layer SiLU MLP. The result has shape ``(batch, channels, 1, 1)``
    so it broadcasts over feature maps. ``channels`` should be even so that the
    concatenated features match the MLP's input width.
    """

    def __init__(self, channels: int, scale: float = 0.02) -> None:
        super().__init__()

        # A frozen nn.Parameter: stored in the state_dict and listed by
        # parameters(), but never receives gradients.
        frequencies = torch.randn(channels // 2) * scale
        self.W = nn.Parameter(frequencies, requires_grad=False)

        hidden = 4 * channels
        self.projection = nn.Sequential(
            nn.Linear(channels, hidden),
            nn.SiLU(),
            nn.Linear(hidden, channels),
            Rearrange("b c -> b c () ()"),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x holds one noise level per batch item, i.e. shape (batch,).
        phases = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * torch.pi
        features = torch.cat((phases.sin(), phases.cos()), dim=-1)
        return self.projection(features)

### UNet

In [None]:
@dataclass
class UNetConfig:
    """Architecture hyper-parameters for :class:`UNet`.

    "Top" resolutions are built from plain ``UNetBlock``s and "mid" resolutions
    from ``UNetBlockWithSelfAttention``. Within each tier the three tuples are
    indexed per resolution level: output width, number of blocks, and whether
    the level is followed by down/up-sampling.
    """

    # Image channels (input and output).
    channels: int = 3
    # Width and frequency scale of the noise-level embedding.
    noise_level_channels: int = 256
    noise_level_scale: float = 0.02
    # Shared by attention layers and all dropout layers.
    n_heads: int = 8
    dropout: float = 0.3
    # Convolution-only tier.
    top_blocks_channels: Tuple[int, ...] = (32, 64)
    top_blocks_n_blocks_per_resolution: Tuple[int, ...] = (2, 2)
    top_blocks_has_resampling: Tuple[bool, ...] = (True, True)
    # Self-attention tier.
    mid_blocks_channels: Tuple[int, ...] = (128, 256)
    mid_blocks_n_blocks_per_resolution: Tuple[int, ...] = (4, 4)
    mid_blocks_has_resampling: Tuple[bool, ...] = (True, False)


class UNet(nn.Module):
    """Noise-level-conditioned U-Net.

    The encoder and decoder each have two tiers: "top" resolutions built from
    ``UNetBlock`` modules and "mid" resolutions built from
    ``UNetBlockWithSelfAttention`` modules. Every decoder block receives its
    input concatenated with the matching encoder activation (skip connection).

    Args:
        config: Architecture hyper-parameters (a ``UNetConfig``).
    """

    def __init__(self, config: UNetConfig) -> None:
        super().__init__()

        self.config = config

        self.input_projection = nn.Conv2d(
            config.channels,
            config.top_blocks_channels[0],
            kernel_size=3,
            padding="same",
        )
        self.noise_level_embedding = NoiseLevelEmbedding(
            config.noise_level_channels, config.noise_level_scale
        )
        # Each tier's channel list is extended by one width so that consecutive
        # pairs give the (in, out) widths of every resolution level: the top tier
        # ends at the first mid width, the mid tier keeps its last width.
        self.top_encoder_blocks = self._make_encoder_blocks(
            self.config.top_blocks_channels + self.config.mid_blocks_channels[:1],
            self.config.top_blocks_n_blocks_per_resolution,
            self.config.top_blocks_has_resampling,
            self._make_top_block,
        )
        self.mid_encoder_blocks = self._make_encoder_blocks(
            self.config.mid_blocks_channels + self.config.mid_blocks_channels[-1:],
            self.config.mid_blocks_n_blocks_per_resolution,
            self.config.mid_blocks_has_resampling,
            self._make_mid_block,
        )
        self.mid_decoder_blocks = self._make_decoder_blocks(
            self.config.mid_blocks_channels + self.config.mid_blocks_channels[-1:],
            self.config.mid_blocks_n_blocks_per_resolution,
            self.config.mid_blocks_has_resampling,
            self._make_mid_block,
        )
        self.top_decoder_blocks = self._make_decoder_blocks(
            self.config.top_blocks_channels + self.config.mid_blocks_channels[:1],
            self.config.top_blocks_n_blocks_per_resolution,
            self.config.top_blocks_has_resampling,
            self._make_top_block,
        )
        self.output_projection = nn.Conv2d(
            config.top_blocks_channels[0],
            config.channels,
            kernel_size=3,
            padding="same",
        )

    def forward(self, x: torch.Tensor, noise_level: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            x: Input of shape ``(batch, config.channels, height, width)``; height
                and width must be divisible by 2 for every ``Downsample`` layer.
            noise_level: One noise level per batch item, shape ``(batch,)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        h = self.input_projection(x)
        noise_level = self.noise_level_embedding(noise_level)

        # Encoder: remember the output of every (non-resampling) block for the
        # decoder's skip connections.
        top_encoder_embeddings = []
        for block in self.top_encoder_blocks:
            if isinstance(block, UNetBlock):
                h = block(h, noise_level)
                top_encoder_embeddings.append(h)
            else:
                h = block(h)

        mid_encoder_embeddings = []
        for block in self.mid_encoder_blocks:
            if isinstance(block, UNetBlockWithSelfAttention):
                h = block(h, noise_level)
                mid_encoder_embeddings.append(h)
            else:
                h = block(h)

        # Decoder: consume the skip activations in reverse (LIFO) order.
        for block in self.mid_decoder_blocks:
            if isinstance(block, UNetBlockWithSelfAttention):
                h = torch.cat((h, mid_encoder_embeddings.pop()), dim=1)
                h = block(h, noise_level)
            else:
                h = block(h)

        for block in self.top_decoder_blocks:
            if isinstance(block, UNetBlock):
                h = torch.cat((h, top_encoder_embeddings.pop()), dim=1)
                h = block(h, noise_level)
            else:
                h = block(h)

        return self.output_projection(h)

    def _make_encoder_blocks(
        self,
        channels: Tuple[int, ...],
        n_blocks_per_resolution: Tuple[int, ...],
        has_resampling: Tuple[bool, ...],
        block_fn: Callable[[int, int], nn.Module],
    ) -> nn.ModuleList:
        """Build encoder blocks level by level, shallowest first.

        Level ``idx`` maps ``channels[idx]`` to ``channels[idx + 1]`` in its first
        block, keeps that width for the remaining blocks, and is followed by a
        ``Downsample`` when ``has_resampling[idx]`` is true.
        """
        blocks = nn.ModuleList()

        channel_pairs = list(zip(channels[:-1], channels[1:]))
        for idx, (in_channels, out_channels) in enumerate(channel_pairs):
            for _ in range(n_blocks_per_resolution[idx]):
                blocks.append(block_fn(in_channels, out_channels))
                in_channels = out_channels

            if has_resampling[idx]:
                blocks.append(Downsample(out_channels))

        return blocks

    def _make_decoder_blocks(
        self,
        channels: Tuple[int, ...],
        n_blocks_per_resolution: Tuple[int, ...],
        has_resampling: Tuple[bool, ...],
        block_fn: Callable[[int, int], nn.Module],
    ) -> nn.ModuleList:
        """Build decoder blocks mirroring ``_make_encoder_blocks``, deepest first.

        Each level starts with an ``Upsample`` when the mirrored encoder level
        downsampled. Every block takes ``2 * in_channels`` inputs (its input
        concatenated with a skip activation of the same width); all blocks of a
        level keep ``in_channels`` except the last, which maps to the level's
        ``out_channels``.
        """
        blocks = nn.ModuleList()

        channel_pairs = list(zip(channels[:-1], channels[1:]))[::-1]
        for idx, (out_channels, in_channels) in enumerate(channel_pairs):
            if has_resampling[::-1][idx]:
                blocks.append(Upsample(in_channels))

            # Built output-side first, then reversed, so that only the final
            # block of the level narrows to out_channels.
            inner_blocks = []
            for _ in range(n_blocks_per_resolution[::-1][idx]):
                inner_blocks.append(block_fn(in_channels * 2, out_channels))
                out_channels = in_channels
            blocks.extend(inner_blocks[::-1])

        return blocks

    def _make_top_block(self, in_channels: int, out_channels: int) -> UNetBlock:
        """Create a convolution-only block for the top tier."""
        return UNetBlock(
            in_channels,
            out_channels,
            self.config.noise_level_channels,
            self.config.dropout,
        )

    def _make_mid_block(
        self, in_channels: int, out_channels: int
    ) -> UNetBlockWithSelfAttention:
        """Create a block with self-attention for the mid tier."""
        return UNetBlockWithSelfAttention(
            in_channels,
            out_channels,
            self.config.noise_level_channels,
            self.config.n_heads,
            self.config.dropout,
        )

    def save_pretrained(self, pretrained_path: str) -> None:
        """Write ``config.json`` and ``model.pt`` (the state_dict) to ``pretrained_path``."""
        os.makedirs(pretrained_path, exist_ok=True)

        with open(
            os.path.join(pretrained_path, "config.json"), mode="w", encoding="utf-8"
        ) as f:
            json.dump(asdict(self.config), f)

        torch.save(self.state_dict(), os.path.join(pretrained_path, "model.pt"))

    @classmethod
    def from_pretrained(cls, pretrained_path: str) -> "UNet":
        """Rebuild a model written by ``save_pretrained``; weights load on the CPU."""
        with open(
            os.path.join(pretrained_path, "config.json"), mode="r", encoding="utf-8"
        ) as f:
            config_dict = json.load(f)
        # JSON has no tuple type: the Tuple[...] fields come back as lists.
        # Convert them so the restored config matches the one that was saved.
        config_dict = {
            key: tuple(value) if isinstance(value, list) else value
            for key, value in config_dict.items()
        }
        config = UNetConfig(**config_dict)

        model = cls(config)

        # The checkpoint is a plain state_dict of tensors, so restrict unpickling
        # to tensor data instead of allowing arbitrary Python objects.
        state_dict = torch.load(
            os.path.join(pretrained_path, "model.pt"),
            map_location=torch.device("cpu"),
            weights_only=True,
        )
        model.load_state_dict(state_dict)

        return model


# Layer-by-layer summary of the default UNet for one 3x64x64 image and one noise level.
summary(UNet(UNetConfig()), input_size=((1, 3, 64, 64), (1,)))

### LitUNet

In [None]:
@dataclass
class LitImprovedConsistencyModelConfig:
    """Optimiser, scheduler and sample-logging settings for the Lightning module."""

    # Adam hyper-parameters.
    lr: float = 0.0001
    betas: Tuple[float, float] = (0.9, 0.995)
    # LinearLR warm-up: starts at lr * start_factor and reaches lr after `iters` steps.
    lr_scheduler_start_factor: float = 0.00001
    lr_scheduler_iters: int = 10000
    # Sample logging: how often, how many images, and which step counts to try.
    sample_every_n_steps: int = 10000
    num_samples: int = 8
    num_sampling_steps: Tuple[int, ...] = (1, 2, 5)


class LitImprovedConsistencyModel(LightningModule):
    """Lightning module for improved consistency training.

    The student is optimised with the loss returned by ``consistency_training``.
    After every batch the teacher is overwritten with the student's weights (no
    exponential moving average). Ground-truth and generated images are logged
    periodically.

    Args:
        consistency_training: Callable returning ``(loss, _, _)`` from the student,
            teacher, a data batch, the current step and the total step count.
        consistency_sampling: Sampler used to generate the logged images.
        student_model: Network being optimised.
        teacher_model: Copy of the student; its parameters are frozen here.
        config: Optimiser, scheduler and sample-logging settings.
    """

    def __init__(
        self,
        consistency_training: ImprovedConsistencyTraining,
        consistency_sampling: ConsistencySamplingAndEditing,
        student_model: UNet,
        teacher_model: UNet,
        config: LitImprovedConsistencyModelConfig,
    ) -> None:
        super().__init__()

        self.consistency_training = consistency_training
        self.consistency_sampling = consistency_sampling
        self.student_model = student_model
        self.teacher_model = teacher_model
        self.config = config

        # Freeze teacher model
        for param in self.teacher_model.parameters():
            param.requires_grad = False

    def training_step(
        self, batch: Union[torch.Tensor, List[torch.Tensor]], batch_idx: int
    ) -> torch.Tensor:
        """Compute and log the consistency-training loss for one batch.

        ``batch`` may be a tensor or a ``[images, labels]`` list (as collated from
        an ``ImageFolder``); only the images are used.
        """
        if isinstance(batch, list):
            batch = batch[0]
        # NOTE(review): the total step count comes from Trainer(max_steps=...);
        # Lightning's default is -1 — confirm ImprovedConsistencyTraining is never
        # run without an explicit max_steps.
        loss, _, _ = self.consistency_training(
            self.student_model,
            self.teacher_model,
            batch,
            self.global_step,
            self.trainer.max_steps,
        )

        self.log_dict({"train_loss": loss})

        if (
            (self.global_step + 1) % self.config.sample_every_n_steps == 0
        ) or self.global_step == 0:
            self.__sample_and_log_samples(batch)

        return loss

    def on_train_batch_end(
        self, outputs: Any, batch: torch.Tensor, batch_idx: int
    ) -> None:
        """Copy the student's weights into the teacher.

        No exponential moving average is used, following the paper.
        """
        with torch.no_grad():
            self.teacher_model.load_state_dict(self.student_model.state_dict())

    def configure_optimizers(self):
        """Adam on the student with a per-step linear LR warm-up (LinearLR)."""
        opt = torch.optim.Adam(
            self.student_model.parameters(), lr=self.config.lr, betas=self.config.betas
        )
        sched = torch.optim.lr_scheduler.LinearLR(
            opt,
            start_factor=self.config.lr_scheduler_start_factor,
            total_iters=self.config.lr_scheduler_iters,
        )
        sched = {"scheduler": sched, "interval": "step", "frequency": 1}

        return [opt], [sched]

    @torch.no_grad()
    def __sample_and_log_samples(self, batch: torch.Tensor) -> None:
        """Log ground-truth images and samples for each configured step count.

        This is called from ``training_step``, where the student is in training
        mode. It is switched to eval mode while sampling, so its dropout layers are
        inactive, and its previous mode is restored afterwards.
        """
        # Ensure the number of samples does not exceed the batch size
        num_samples = min(self.config.num_samples, batch.shape[0])
        noise = torch.randn_like(batch[:num_samples])

        # Log ground truth samples
        self.__log_images(
            batch[:num_samples].detach().clone(), "ground_truth", self.global_step
        )

        was_training = self.student_model.training
        self.student_model.eval()
        try:
            for steps in self.config.num_sampling_steps:
                # Sample an extra step and reverse the schedule as the last step
                # (sigma=sigma_min) is useless as the model returns identity
                sigmas = karras_schedule(
                    steps + 1,
                    sigma_min=self.consistency_training.sigma_min,
                    sigma_max=self.consistency_training.sigma_max,
                    rho=self.consistency_training.rho,
                    device=self.device,
                )

                sigmas = sigmas.flipud()[:-1]

                samples = self.consistency_sampling(
                    self.student_model, noise, sigmas, clip_denoised=True, verbose=True
                )
                samples = samples.clamp(min=-1.0, max=1.0)

                # Generated samples
                self.__log_images(
                    samples,
                    f"generated_samples-steps={steps}",
                    self.global_step,
                )
        finally:
            self.student_model.train(was_training)

    @torch.no_grad()
    def __log_images(self, images: torch.Tensor, title: str, global_step: int) -> None:
        """Log ``images`` (values in [-1, 1]) as one grid image called ``title``."""
        if self.logger is None:
            # No logger configured (e.g. Trainer(logger=False)): nothing to do.
            return

        images = images.detach().float()

        grid = make_grid(
            images.clamp(-1.0, 1.0), value_range=(-1.0, 1.0), normalize=True
        )
        # NOTE(review): add_image is the TensorBoard SummaryWriter API; other
        # loggers' `experiment` objects may not provide it.
        self.logger.experiment.add_image(title, grid, global_step)

## 🚀 Training

### Training Loop

In [None]:
@dataclass
class TrainingConfig:
    """Everything ``run_training`` needs to build, train and save a model."""

    # Sub-configurations for the data, network and Lightning module.
    image_dm_config: ImageDataModuleConfig
    unet_config: UNetConfig
    # Ready-made training/sampling objects and the Lightning trainer.
    consistency_training: ImprovedConsistencyTraining
    consistency_sampling: ConsistencySamplingAndEditing
    lit_icm_config: LitImprovedConsistencyModelConfig
    trainer: Trainer
    # Global seed passed to seed_everything.
    seed: int = 42
    # Output directory for the trained student (UNet.save_pretrained).
    model_ckpt_path: str = "checkpoint/unet"
    # Optional Lightning checkpoint to resume from.
    resume_ckpt_path: Optional[str] = None


def run_training(config: TrainingConfig) -> None:
    """Train a consistency model as described by ``config`` and save the student.

    The teacher starts as an exact copy of the student. After ``trainer.fit``
    the student is written to ``config.model_ckpt_path`` via ``save_pretrained``.

    Args:
        config: Data, model, training and checkpointing settings.
    """
    # Set seed
    seed_everything(config.seed)

    # Create data module
    dm = ImageDataModule(config.image_dm_config)

    # Create student and teacher models; the teacher starts identical to the student
    student_model = UNet(config.unet_config)
    teacher_model = UNet(config.unet_config)
    teacher_model.load_state_dict(student_model.state_dict())

    # Create lightning module
    lit_icm = LitImprovedConsistencyModel(
        config.consistency_training,
        config.consistency_sampling,
        student_model,
        teacher_model,
        config.lit_icm_config,
    )

    # Run training
    config.trainer.fit(lit_icm, dm, ckpt_path=config.resume_ckpt_path)

    # Save model. Under multi-process strategies every rank reaches this point;
    # only rank 0 writes, so processes don't race on the same files.
    if config.trainer.is_global_zero:
        lit_icm.student_model.save_pretrained(config.model_ckpt_path)

### Run Training

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir lightning_logs

In [None]:
# Train on the butterflies dataset at the default 32x32 image size for 10k
# optimisation steps in bf16 mixed precision, with a 1k-step LR warm-up and
# samples logged every 1k steps.
training_config = TrainingConfig(
    image_dm_config=ImageDataModuleConfig("../data/butterflies256"),
    unet_config=UNetConfig(),
    consistency_training=ImprovedConsistencyTraining(),
    consistency_sampling=ConsistencySamplingAndEditing(),
    lit_icm_config=LitImprovedConsistencyModelConfig(
        sample_every_n_steps=1000, lr_scheduler_iters=1000
    ),
    trainer=Trainer(max_steps=10_000, precision="bf16-mixed", log_every_n_steps=10),
)
run_training(training_config)

## 🎲 Sampling & Zero-shot Editing

### Utils

In [None]:
def plot_images(images: torch.Tensor, cols: int = 4) -> None:
    """Show a batch of images in a grid with ``cols`` columns.

    Args:
        images: CPU float tensor of shape ``(n, channels, height, width)`` with
            values in [-1, 1]. Values are mapped to [0, 1] and clipped for display.
            Callers pass ``.float().cpu()`` because ``.numpy()`` is used.
        cols: Number of grid columns. The row count is rounded up so every
            image gets a cell.
    """
    n_images = images.shape[0]
    rows = max((n_images + cols - 1) // cols, 1)
    # squeeze=False always returns a 2-D array of axes, even for a 1x1 grid.
    _, axs = plt.subplots(rows, cols, squeeze=False)
    axs = axs.flatten()
    for ax in axs:
        # Also hides unused cells in a partially filled last row.
        ax.set_axis_off()
    for ax, image in zip(axs, images):
        ax.imshow((image.permute(1, 2, 0).numpy() / 2 + 0.5).clip(0.0, 1.0))

### Checkpoint Loading

In [None]:
# Prefer the first GPU in bfloat16; fall back to float32 on the CPU.
if torch.cuda.is_available():
    device, dtype = torch.device("cuda:0"), torch.bfloat16
else:
    device, dtype = torch.device("cpu"), torch.float32

# Restore the trained student written by run_training and prepare it for inference.
unet = UNet.from_pretrained("checkpoint/unet")
unet = unet.eval().to(device=device, dtype=dtype)

### Load Sample Batch

In [None]:
# Load one shuffled batch of 4 training images (scaled to [-1, 1] by the data
# module's transform) and move it to the inference device and dtype.
dm = ImageDataModule(ImageDataModuleConfig("../data/butterflies256", batch_size=4))
dm.setup()

# The DataLoader yields (images, labels); the labels are not needed.
batch, _ = next(iter(dm.train_dataloader()))
batch = batch.to(device=device, dtype=dtype)

plot_images(batch.float().cpu())

### Experiments

#### Sampling

In [None]:
consistency_sampling_and_editing = ConsistencySamplingAndEditing()

# Unconditional generation of 4 RGB 32x32 images (the training resolution) with a
# single sampling step.
# NOTE(review): 80.0 presumably matches the training sigma_max — confirm against
# the ImprovedConsistencyTraining defaults.
with torch.no_grad():
    samples = consistency_sampling_and_editing(
        unet,
        torch.randn((4, 3, 32, 32), device=device, dtype=dtype),
        sigmas=[80.0],  # Use more steps for better samples, e.g. 2-5
        clip_denoised=True,
        verbose=True,
    )

plot_images(samples.float().cpu())

#### Inpainting

In [None]:
# Always erase one rectangle covering 20-50% of the image area.
random_erasing = T.RandomErasing(p=1.0, scale=(0.2, 0.5), ratio=(0.5, 0.5))
masked_batch = random_erasing(batch)
# True wherever erasing changed a pixel, presumably the region the sampler should
# regenerate — confirm the mask convention of ConsistencySamplingAndEditing.
mask = torch.logical_not(batch == masked_batch)

plot_images(masked_batch.float().cpu())

In [None]:
# Inpaint the erased region with two sampling steps, then show the masked inputs
# (first row) above the inpainted results (second row).
with torch.no_grad():
    inpainted_batch = consistency_sampling_and_editing(
        unet,
        masked_batch,
        sigmas=[5.23, 2.25],
        mask=mask.to(dtype=dtype),
        clip_denoised=True,
        verbose=True,
    )

plot_images(torch.cat((masked_batch, inpainted_batch), dim=0).float().cpu())

#### Interpolation

In [None]:
# Pair image i with image n-1-i by reversing the batch order.
batch_a = batch.clone()
batch_b = torch.flip(batch, dims=(0,))

plot_images(torch.cat((batch_a, batch_b), dim=0).float().cpu())

In [None]:
# Interpolate between each (a, b) pair with two sampling steps.
# NOTE(review): ab_ratio=0.5 presumably means the midpoint — confirm in
# ConsistencySamplingAndEditing.interpolate.
with torch.no_grad():
    interpolated_batch = consistency_sampling_and_editing.interpolate(
        unet,
        batch_a,
        batch_b,
        ab_ratio=0.5,
        sigmas=[5.23, 2.25],
        clip_denoised=True,
        verbose=True,
    )

# Rows: batch_a, batch_b, interpolations.
plot_images(torch.cat((batch_a, batch_b, interpolated_batch), dim=0).float().cpu())