# Stable Diffusion (Keras)

[![Open in Colab](https://lab.aef.me/files/assets/colab-badge.svg)](https://colab.research.google.com/github/adamelliotfields/lab/blob/main/files/tf/stable_diffusion.ipynb)
[![Open in Kaggle](https://lab.aef.me/files/assets/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/adamelliotfields/lab/blob/main/files/tf/stable_diffusion.ipynb)
[![Render nbviewer](https://lab.aef.me/files/assets/nbviewer_badge.svg)](https://nbviewer.org/github/adamelliotfields/lab/blob/main/files/tf/stable_diffusion.ipynb)

> By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M [license](https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE).

Complete Stable Diffusion 1.4 implementation based on the official KerasCV [model](https://github.com/keras-team/keras-cv/tree/master/keras_cv/src/models/stable_diffusion), which itself is based on [`stable-diffusion-tensorflow`](https://github.com/divamgupta/stable-diffusion-tensorflow). Also includes Keras implementations of [CLIP](https://github.com/openai/CLIP) and the [DDPM](https://arxiv.org/abs/2006.11239) scheduler from [Diffusers](https://github.com/huggingface/diffusers/blob/v0.3.0/src/diffusers/schedulers/scheduling_ddpm.py).

Uses the Keras 3 [Ops](https://keras.io/api/ops/) API, which works across backends (TensorFlow, JAX, PyTorch). Unlike other KerasCV pretrained models, the weights aren't hosted on Kaggle; they're a port of the [original](https://huggingface.co/CompVis/stable-diffusion-v1-4) PyTorch weights and are hosted on [Hugging Face](https://huggingface.co/fchollet/stable-diffusion).

The code is organized so that each component has its own cell: layers first, then models, and finally the Stable Diffusion pipeline itself.

**Changelog**

* support Stable Diffusion v1 only
* set the default JIT compilation mode to `"auto"`
* cache weights in Google Drive when in Colab
* use DDPM params from [tinygrad](https://github.com/tinygrad/tinygrad/blob/master/examples/stable_diffusion.py) implementation
* move `batch_size` to pipeline constructor
* remove lazy initialization in pipeline constructor
* remove `download_weights` keyword argument from models
* rename `unconditional_guidance_scale` to `guidance_scale`

**TODO**

* [ ] Annotations and comments

In [None]:
import os
import subprocess
import sys
from importlib.util import find_spec

# Hide TensorFlow's C++ INFO and WARNING log lines.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# Keras reads the backend from this variable when it is first imported, so it
# has to be set before the `keras` import in the next cell.
os.environ["KERAS_BACKEND"] = "tensorflow"

# `find_spec("google.colab")` imports the parent `google` package first and
# raises ModuleNotFoundError when no such package is installed (e.g. a plain
# local Python without protobuf), so treat that case as "not in Colab".
try:
    IN_COLAB = find_spec("google.colab") is not None
except ModuleNotFoundError:
    IN_COLAB = False

if IN_COLAB:
    # Run pip through the current interpreter so the upgrade lands in the
    # kernel's environment rather than whichever `pip` is first on PATH.
    subprocess.run([sys.executable, "-m", "pip", "install", "-qU", "keras"])
    # cache ~4GB in Google Drive:
    # * bpe_simple_vocab_16e6.txt.gz (1.3M)
    # * kcv_decoder.h5 (189M)
    # * kcv_diffusion_model.h5 (3.3G)
    # * kcv_encoder.h5 (470M)
    # * vae_encoder.h5 (131M)
    # NOTE(review): nothing here mounts Drive; presumably it must already be
    # mounted at /content/drive for this cache to persist — confirm.
    CACHE_DIR = "/content/drive/MyDrive/keras"
else:
    CACHE_DIR = os.environ.get("KERAS_HOME", os.path.expanduser("~/.keras"))

In [None]:
import gzip
import html
import math

import numpy as np
import regex as re

from IPython.display import display
from PIL import Image as PILImage
from functools import lru_cache
from keras import (
    Model,
    Sequential,
    activations,
    config,
    layers,
    ops,
    random,
    utils,
)

In [None]:
# @title Config
# -- numerics --
EPSILON = 1e-5  # epsilon passed to the normalization layers
GLOBAL_DTYPE = "mixed_float16"  # author's note: 2-3x faster than float32; presumably applied as the Keras dtype policy later (TODO confirm)

# -- generation --
BATCH_SIZE = 1  # not read in this section; presumably the pipeline's default batch size
IMG_HEIGHT = IMG_WIDTH = 512  # pixels; the models work on latents at 1/8 of this size
MAX_PROMPT_LENGTH = 77  # CLIP's context length: text-encoder input and position-embedding size

# -- storage --
CACHE_SUBDIR = "models/stable_diffusion_v1"  # subfolder of CACHE_DIR for downloaded files

In [None]:
# @title Functions
def td_dot(a, b):
    """Batched matrix product over the last two axes of two rank-4 tensors.

    Computes the same result as flattening the two leading axes, contracting
    ``a``'s last axis with ``b``'s third axis, and restoring the leading axes.
    It calls ``ops.matmul`` directly instead of building a new
    ``layers.Dot`` object on every forward pass.

    Args:
        a: tensor of shape ``(batch, heads, m, k)``.
        b: tensor of shape ``(batch, heads, k, n)``.

    Returns:
        Tensor of shape ``(batch, heads, m, n)``.
    """
    return ops.matmul(a, b)


def quick_gelu(x):
    """Sigmoid-based GELU approximation used by CLIP: ``x * sigmoid(1.702 * x)``."""
    gate = ops.sigmoid(1.702 * x)
    return x * gate


def display_images(images, horizontal=False):
    """Show images in the notebook output.

    Args:
        images: sequence of arrays accepted by ``PIL.Image.fromarray``
            (typically ``uint8`` with shape ``(height, width, 3)``). In
            vertical mode, PIL images are also accepted.
        horizontal: if True, paste all images side by side (top-aligned)
            onto one RGB canvas and display it once. Otherwise display each
            image separately.
    """
    images = list(images)
    if not images:
        # Nothing to show; also avoids `max()` of an empty sequence below.
        return

    if horizontal:
        total_width = sum(img.shape[1] for img in images)
        max_height = max(img.shape[0] for img in images)
        combined_image = PILImage.new("RGB", (total_width, max_height))
        offset = 0
        for img in images:
            tile = PILImage.fromarray(img)
            combined_image.paste(tile, (offset, 0))
            offset += tile.width
        display(combined_image)
    else:
        for image in images:
            # IPython renders a raw ndarray as text, so convert arrays to PIL
            # images first to get an actual picture.
            if isinstance(image, np.ndarray):
                image = PILImage.fromarray(image)
            display(image)

In [None]:
# @title CLIPTokenizer
@lru_cache()
def bytes_to_unicode():
    """Build CLIP's reversible byte -> printable-character table.

    Bytes that are already printable latin-1 characters map to themselves.
    Every other byte value (controls, space, etc.) is assigned a character
    starting at code point 256, in ascending byte order. Insertion order is
    printable bytes first, then the remapped ones; the tokenizer builds its
    vocabulary in this order.

    Returns:
        dict[int, str]: byte value (0-255) -> single-character string.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    printable_set = set(printable)
    remapped = [byte for byte in range(2**8) if byte not in printable_set]

    table = {byte: chr(byte) for byte in printable}
    for offset, byte in enumerate(remapped):
        table[byte] = chr(2**8 + offset)
    return table


def get_pairs(word):
    """Return the set of adjacent symbol pairs in ``word``.

    Args:
        word: sequence of symbols (a tuple of strings during BPE, or a str).

    Returns:
        set of ``(left, right)`` tuples. Empty when ``word`` has fewer than
        two symbols, including the empty word.
    """
    return set(zip(word, word[1:]))


def basic_clean(text):
    """Undo up to two levels of HTML escaping, then trim surrounding whitespace."""
    for _ in range(2):
        # Two passes so doubly escaped input such as "&amp;lt;" becomes "<".
        text = html.unescape(text)
    return text.strip()


def whitespace_clean(text):
    """Replace each run of whitespace with a single space and trim both ends."""
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()


class CLIPTokenizer:
    """Byte-level BPE tokenizer reproducing OpenAI CLIP's text tokenizer.

    On construction it downloads CLIP's gzipped BPE merges file, or reuses it
    from ``CACHE_DIR/CACHE_SUBDIR``. It then builds the vocabulary in this
    order: 256 byte symbols, their 256 word-final (``</w>``) variants, one
    entry per merge (49152 - 256 - 2 of them), then ``<|startoftext|>`` and
    ``<|endoftext|>``. That is 49408 entries in total.
    """

    def __init__(self):
        bpe_path = utils.get_file(
            cache_dir=CACHE_DIR,
            cache_subdir=CACHE_SUBDIR,
            # file_hash="924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a",
            origin="https://raw.githubusercontent.com/openai/CLIP/main/clip/bpe_simple_vocab_16e6.txt.gz",
        )

        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}

        # Use a context manager so the gzip file handle is closed promptly
        # instead of being left for the garbage collector.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split("\n")
        # Skip the first line (presumably a version header) and keep exactly
        # 49152 - 256 - 2 merge rules.
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]

        # bytes_to_unicode() is cached, so this is the same table as above.
        vocab = list(self.byte_encoder.values())
        vocab = vocab + [v + "</w>" for v in vocab]
        vocab.extend("".join(merge) for merge in merges)
        vocab.extend(["<|startoftext|>", "<|endoftext|>"])
        self.vocab = vocab
        self.encoder = self._create_encoder(self.vocab)
        self.decoder = self._create_decoder(self.encoder)
        # Lower rank = merge listed earlier = applied first in `bpe`.
        self.bpe_ranks = dict(zip(merges, range(len(merges))))

        self.special_tokens = {
            "<|startoftext|>": "<|startoftext|>",
            "<|endoftext|>": "<|endoftext|>",
        }
        # Memo of token -> space-separated BPE symbols. It is pre-seeded so
        # special tokens are never split.
        self.cache = {
            "<|startoftext|>": "<|startoftext|>",
            "<|endoftext|>": "<|endoftext|>",
        }
        self.pat = self._create_pat()

    def _create_encoder(self, vocab):
        """Map each vocabulary string to its index."""
        return dict(zip(vocab, range(len(vocab))))

    def _create_decoder(self, encoder):
        """Invert ``encoder`` (index -> vocabulary string)."""
        return {v: k for k, v in encoder.items()}

    def _create_pat(self):
        """Compile the pre-tokenization regex.

        Special tokens come first so they match whole. Then come English
        contractions, letter runs, single digits, and runs of other
        non-space characters.
        """
        return re.compile(
            "|".join([re.escape(key) for key in self.special_tokens.keys()])
            + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )

    @property
    def end_of_text(self):
        """Token id of ``<|endoftext|>``."""
        return self.encoder["<|endoftext|>"]

    @property
    def start_of_text(self):
        """Token id of ``<|startoftext|>``."""
        return self.encoder["<|startoftext|>"]

    def add_tokens(self, tokens):
        """Append new special tokens to the end of the vocabulary.

        Args:
            tokens: a single token string or an iterable of token strings.

        Returns:
            int: number of tokens actually added. Tokens already in the
            vocabulary, and repeats within ``tokens``, are skipped.
        """
        if isinstance(tokens, str):
            tokens = [tokens]

        tokens_added = 0

        for token in tokens:
            # O(1) membership test. `encoder` holds every vocab entry present
            # before this call; `special_tokens` also holds the ones appended
            # earlier in this loop.
            if token in self.encoder or token in self.special_tokens:
                continue
            tokens_added += 1
            self.vocab.append(token)
            self.special_tokens[token] = token
            self.cache[token] = token

        if tokens_added:
            self.encoder = self._create_encoder(self.vocab)
            self.decoder = self._create_decoder(self.encoder)
            self.pat = self._create_pat()
        return tokens_added

    def bpe(self, token):
        """Apply BPE merges to one byte-encoded token.

        Returns:
            str: the resulting symbols joined by single spaces. The last
            symbol carries the ``</w>`` word-end marker. Results are memoized
            in ``self.cache``.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)
        if not pairs:
            return token + "</w>"

        while True:
            # Merge the best-ranked (earliest learned) adjacent pair first.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Tokenize ``text`` into CLIP token ids.

        The text is HTML-unescaped, whitespace-normalized, and lower-cased.
        The result is wrapped with the start/end-of-text ids. No padding or
        truncation is applied here.

        NOTE(review): because of the lower-casing, tokens added via
        ``add_tokens`` that contain uppercase letters will miss ``self.cache``
        and be split by BPE instead of kept whole.
        """
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
        return [self.start_of_text] + bpe_tokens + [self.end_of_text]

    def decode(self, tokens):
        """Convert token ids back to text.

        ``</w>`` markers become spaces, so the result usually ends with a
        trailing space. Invalid UTF-8 is replaced with U+FFFD.
        """
        text = "".join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text])
        text = text.decode("utf-8", errors="replace").replace("</w>", " ")
        return text

In [None]:
# @title DDPMScheduler
class DDPMScheduler:
    """DDPM noise scheduler, ported from Diffusers (v0.3.0 ``scheduling_ddpm``).

    Args:
        train_timesteps: int, number of diffusion steps used to train the model. Defaults to 1000.
        beta_start: float, the starting `beta` value. Defaults to 0.0001.
        beta_end: float, the final `beta` value. Defaults to 0.02.
        beta_schedule: "linear" or "scaled_linear", how betas are spaced between `beta_start` and `beta_end`. Defaults to "linear".
        variance_type: "fixed_small", "fixed_small_log", "fixed_large", "fixed_large_log", "learned" or "learned_range", how the variance of the noise added in `step` is computed. Defaults to "fixed_small".
        clip_sample: bool, whether to clip the predicted original sample to [-1, 1] for numerical stability. Defaults to True.

    Raises:
        ValueError: if `beta_schedule` is not one of the supported values.
    """

    def __init__(
        self,
        train_timesteps=1000,
        beta_start=0.0001,
        beta_end=0.02,
        beta_schedule="linear",
        variance_type="fixed_small",
        clip_sample=True,
    ):
        self.train_timesteps = train_timesteps

        if beta_schedule == "linear":
            self.betas = ops.linspace(beta_start, beta_end, train_timesteps)
        elif beta_schedule == "scaled_linear":
            # this schedule is very specific to the latent diffusion model
            self.betas = ops.linspace(beta_start**0.5, beta_end**0.5, train_timesteps) ** 2
        else:
            raise ValueError(f"Invalid beta schedule: {beta_schedule}.")

        self.alphas = 1.0 - self.betas
        # alpha-bar_t: cumulative product of alphas up to and including t.
        self.alphas_cumprod = ops.cumprod(self.alphas)
        self.variance_type = variance_type
        self.clip_sample = clip_sample
        # Fixed seed, so the noise sequence drawn in `step` is reproducible.
        self.seed_generator = random.SeedGenerator(seed=42)

    def _get_variance(self, timestep, predicted_variance=None):
        """Return the variance of the noise added at ``timestep``.

        The base value is ``(1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t``,
        with ``alpha_bar_{-1} = 1``. It is then transformed according to
        ``self.variance_type``.

        Raises:
            ValueError: if ``self.variance_type`` is not recognized.
        """
        alpha_prod = self.alphas_cumprod[timestep]
        alpha_prod_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else 1.0
        variance = (1 - alpha_prod_prev) / (1 - alpha_prod) * self.betas[timestep]

        if self.variance_type == "fixed_small":
            variance = ops.clip(variance, x_min=1e-20, x_max=1)
        elif self.variance_type == "fixed_small_log":
            variance = ops.log(ops.clip(variance, x_min=1e-20, x_max=1))
        elif self.variance_type == "fixed_large":
            variance = self.betas[timestep]
        elif self.variance_type == "fixed_large_log":
            variance = ops.log(self.betas[timestep])
        elif self.variance_type == "learned":
            return predicted_variance
        elif self.variance_type == "learned_range":
            min_log = variance
            max_log = self.betas[timestep]
            # Map the model's prediction from [-1, 1] to an interpolation weight in [0, 1].
            frac = (predicted_variance + 1) / 2
            variance = frac * max_log + (1 - frac) * min_log
        else:
            raise ValueError(f"Invalid variance type: {self.variance_type}")
        return variance

    def step(
        self,
        model_output,
        timestep,
        sample,
        predict_epsilon=True,
    ):
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion process from the learned model outputs (usually the predicted noise).
        Args:
            model_output: a Tensor containing direct output from learned diffusion model.
            timestep: current discrete timestep in the diffusion chain.
            sample: a Tensor containing the current instance of sample being created by diffusion process.
            predict_epsilon: bool, whether the model is predicting noise (epsilon) or samples. Defaults to True.
        Returns:
            The predicted sample at the previous timestep.
        """

        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in [
            "learned",
            "learned_range",
        ]:
            # Split into two equal halves: the prediction and the variance.
            # Unlike `torch.split`, `ops.split` takes a number of sections,
            # not a chunk size, hence `2`.
            # NOTE(review): axis 1 is the channel axis only for channels-first
            # tensors; the models in this notebook are channels-last — confirm
            # before using learned variances.
            model_output, predicted_variance = ops.split(model_output, 2, axis=1)
        else:
            predicted_variance = None

        # 1. compute alphas, betas
        alpha_prod = self.alphas_cumprod[timestep]
        alpha_prod_prev = self.alphas_cumprod[timestep - 1] if timestep > 0 else 1.0
        beta_prod = 1 - alpha_prod
        beta_prod_prev = 1 - alpha_prod_prev

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if predict_epsilon:
            sqrt_alpha = alpha_prod**0.5
            sqrt_beta = beta_prod**0.5
            pred_original_sample = (sample - sqrt_beta * model_output) / sqrt_alpha
        else:
            pred_original_sample = model_output

        # 3. Clip "predicted x_0"
        if self.clip_sample:
            # keras.ops has no `clip_by_value` (that is a TensorFlow API);
            # `ops.clip` is the backend-agnostic equivalent.
            pred_original_sample = ops.clip(pred_original_sample, -1, 1)

        # 4. Compute coefficients for pred_original_sample x_0 and current
        # sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        sqrt_alpha_prev = alpha_prod_prev**0.5
        sqrt_alphas_timestep = self.alphas[timestep] ** 0.5
        pred_original_sample_coeff = (sqrt_alpha_prev * self.betas[timestep]) / beta_prod
        current_sample_coeff = sqrt_alphas_timestep * beta_prod_prev / beta_prod

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = (
            pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
        )

        # 6. Add noise (skipped at the final step, t == 0)
        variance = 0
        if timestep > 0:
            noise = random.normal(model_output.shape, seed=self.seed_generator)
            variance = self._get_variance(timestep, predicted_variance=predicted_variance)
            variance = (variance**0.5) * noise

        pred_prev_sample = pred_prev_sample + variance
        return pred_prev_sample

    def add_noise(
        self,
        original_samples,
        noise,
        timesteps,
    ):
        """Diffuse ``original_samples`` forward to ``timesteps``.

        Returns ``sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``.
        ``timesteps`` holds one index per sample. The coefficients get three
        trailing axes so they broadcast against rank-4 samples.
        """
        sqrt_alpha_prod = ops.take(self.alphas_cumprod, timesteps) ** 0.5
        sqrt_one_minus_alpha_prod = (1 - ops.take(self.alphas_cumprod, timesteps)) ** 0.5

        for _ in range(3):
            sqrt_alpha_prod = ops.expand_dims(sqrt_alpha_prod, axis=-1)
            sqrt_one_minus_alpha_prod = ops.expand_dims(sqrt_one_minus_alpha_prod, axis=-1)

        sqrt_alpha_prod = ops.cast(sqrt_alpha_prod, dtype=original_samples.dtype)
        sqrt_one_minus_alpha_prod = ops.cast(sqrt_one_minus_alpha_prod, dtype=noise.dtype)
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    def __len__(self):
        """Number of training timesteps."""
        return self.train_timesteps

## Layers

In [None]:
# @title PaddedConv2D
class PaddedConv2D(layers.Layer):
    """``Conv2D`` with explicit zero padding applied first.

    The convolution always uses ``padding="valid"``. ``padding`` is passed
    to ``ZeroPadding2D``, so asymmetric padding such as
    ``((0, 1), (0, 1))`` can be expressed.
    """

    def __init__(self, filters, kernel_size, padding=0, strides=1, **kwargs):
        super().__init__(**kwargs)
        # NOTE: attribute names and creation order are left unchanged; the
        # pretrained .h5 checkpoints likely depend on the weight order.
        self.padding2d = layers.ZeroPadding2D(padding)
        self.conv2d = layers.Conv2D(filters, kernel_size, strides=strides, padding="valid")

    def call(self, inputs):
        return self.conv2d(self.padding2d(inputs))

In [None]:
# @title ResnetBlock
class ResnetBlock(layers.Layer):
    """Residual block: (GroupNorm -> swish -> 3x3 conv) twice, plus a skip path.

    The skip path gets a 1x1 projection only when the input channel count
    differs from ``output_dim``; otherwise it is the identity.
    """

    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.norm1 = layers.GroupNormalization(epsilon=EPSILON)
        self.conv1 = PaddedConv2D(output_dim, 3, padding=1)
        self.norm2 = layers.GroupNormalization(epsilon=EPSILON)
        self.conv2 = PaddedConv2D(output_dim, 3, padding=1)

    def build(self, input_shape):
        needs_projection = input_shape[-1] != self.output_dim
        self.residual_proj = (
            PaddedConv2D(self.output_dim, 1) if needs_projection else (lambda x: x)
        )

    def call(self, inputs):
        hidden = inputs
        for norm, conv in ((self.norm1, self.conv1), (self.norm2, self.conv2)):
            hidden = conv(activations.swish(norm(hidden)))
        return hidden + self.residual_proj(inputs)

In [None]:
# @title AttentionBlock
class AttentionBlock(layers.Layer):
    """Single-head spatial self-attention with a residual connection.

    Each pixel of the normalized feature map attends to every pixel of the
    same map. The query/key/value/output projections are 1x1 convolutions.
    """

    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.norm = layers.GroupNormalization(epsilon=EPSILON)
        self.query = PaddedConv2D(output_dim, 1)
        self.key = PaddedConv2D(output_dim, 1)
        self.value = PaddedConv2D(output_dim, 1)
        self.output_proj = PaddedConv2D(output_dim, 1)

    def call(self, inputs):
        normed = self.norm(inputs)
        q_map, k_map, v_map = self.query(normed), self.key(normed), self.value(normed)

        dims = ops.shape(q_map)
        height, width, channels = dims[1], dims[2], dims[3]

        # Scaled dot-product scores: (b, hw, c) @ (b, c, hw) -> (b, hw_q, hw_k).
        q_flat = ops.reshape(q_map, (-1, height * width, channels))
        k_flat = ops.reshape(ops.transpose(k_map, (0, 3, 1, 2)), (-1, channels, height * width))
        scores = q_flat @ k_flat
        scores = scores / ops.sqrt(ops.cast(channels, self.compute_dtype))
        probs = activations.softmax(scores)

        # Weighted values: (b, c, hw_k) @ (b, hw_k, hw_q) -> (b, c, hw_q).
        v_flat = ops.reshape(ops.transpose(v_map, (0, 3, 1, 2)), (-1, channels, height * width))
        mixed = v_flat @ ops.transpose(probs, (0, 2, 1))
        mixed = ops.reshape(ops.transpose(mixed, (0, 2, 1)), (-1, height, width, channels))
        return self.output_proj(mixed) + inputs

In [None]:
# @title CrossAttention
class CrossAttention(layers.Layer):
    """Multi-head attention of ``inputs`` over ``context``.

    With ``context=None`` this is self-attention over ``inputs``. Query/key/
    value projections have no bias; the output projection does.

    Args:
        num_heads: number of attention heads.
        head_size: width of each head; the model width is ``num_heads * head_size``.
    """

    def __init__(self, num_heads, head_size, **kwargs):
        super().__init__(**kwargs)
        self.channels = num_heads * head_size
        self.query = layers.Dense(self.channels, use_bias=False)
        self.key = layers.Dense(self.channels, use_bias=False)
        self.value = layers.Dense(self.channels, use_bias=False)
        self.scale = head_size**-0.5
        self.num_heads = num_heads
        self.head_size = head_size
        self.output_proj = layers.Dense(self.channels)

    def call(self, inputs, context=None):
        context = inputs if context is None else context
        q_len, kv_len = inputs.shape[1], context.shape[1]

        def split_heads(projected, length, perm):
            projected = ops.reshape(projected, (-1, length, self.num_heads, self.head_size))
            return ops.transpose(projected, perm)

        q = split_heads(self.query(inputs), q_len, (0, 2, 1, 3))  # (b, heads, q_len, head_size)
        k = split_heads(self.key(context), kv_len, (0, 2, 3, 1))  # (b, heads, head_size, kv_len)
        v = split_heads(self.value(context), kv_len, (0, 2, 1, 3))  # (b, heads, kv_len, head_size)

        probs = activations.softmax(td_dot(q, k) * self.scale)  # (b, heads, q_len, kv_len)
        merged = ops.transpose(td_dot(probs, v), (0, 2, 1, 3))  # (b, q_len, heads, head_size)
        return self.output_proj(ops.reshape(merged, (-1, q_len, self.channels)))

In [None]:
# @title GEGLU
class GEGLU(layers.Layer):
    """Gated linear unit with a tanh-approximated GELU gate.

    A single Dense layer projects to ``2 * output_dim``. The first half is
    the value and the second half is the gate.
    """

    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.dense = layers.Dense(output_dim * 2)

    def call(self, inputs):
        projected = self.dense(inputs)
        cut = self.output_dim
        value, gate = projected[..., :cut], projected[..., cut:]
        # 0.7978845608 ~= sqrt(2 / pi): the standard tanh GELU approximation.
        inner = gate * 0.7978845608 * (1 + 0.044715 * (gate**2))
        return value * 0.5 * gate * (1 + activations.tanh(inner))

In [None]:
# @title BasicTransformerBlock
class BasicTransformerBlock(layers.Layer):
    """Pre-norm transformer block with three residual sub-layers.

    1. self-attention, 2. cross-attention over ``context``, 3. GEGLU feed-forward.
    Takes ``[hidden, context]`` as input.
    """

    def __init__(self, dim, num_heads, head_size, **kwargs):
        super().__init__(**kwargs)
        self.norm1 = layers.LayerNormalization(epsilon=EPSILON)
        self.attn1 = CrossAttention(num_heads, head_size)
        self.norm2 = layers.LayerNormalization(epsilon=EPSILON)
        self.attn2 = CrossAttention(num_heads, head_size)
        self.norm3 = layers.LayerNormalization(epsilon=EPSILON)
        self.geglu = GEGLU(dim * 4)
        self.dense = layers.Dense(dim)

    def call(self, inputs):
        hidden, context = inputs
        self_attended = self.attn1(self.norm1(hidden), context=None)
        hidden = self_attended + hidden
        cross_attended = self.attn2(self.norm2(hidden), context=context)
        hidden = cross_attended + hidden
        feed_forward = self.dense(self.geglu(self.norm3(hidden)))
        return feed_forward + hidden

In [None]:
# @title SpatialTransformer
class SpatialTransformer(layers.Layer):
    """Run a ``BasicTransformerBlock`` over the pixels of a feature map.

    The map is normalized, projected, flattened to ``(b, h*w, c)``, passed
    through the transformer with ``context``, reshaped back, projected
    again, and added to the input.

    Args:
        num_heads, head_size: attention configuration; channels = num_heads * head_size.
        fully_connected: use Dense layers instead of 1x1 convolutions for the
            input/output projections.
    """

    def __init__(self, num_heads, head_size, fully_connected=False, **kwargs):
        super().__init__(**kwargs)
        self.norm = layers.GroupNormalization(epsilon=EPSILON)
        channels = num_heads * head_size

        def make_projection():
            if fully_connected:
                return layers.Dense(channels)
            return PaddedConv2D(channels, 1)

        # Creation order (proj1, transformer_block, proj2) is kept as before.
        self.proj1 = make_projection()
        self.transformer_block = BasicTransformerBlock(channels, num_heads, head_size)
        self.proj2 = make_projection()

    def call(self, inputs):
        features, context = inputs
        _, height, width, channels = features.shape
        hidden = self.proj1(self.norm(features))
        hidden = ops.reshape(hidden, (-1, height * width, channels))
        hidden = self.transformer_block([hidden, context])
        hidden = ops.reshape(hidden, (-1, height, width, channels))
        return self.proj2(hidden) + features

In [1]:
# @title ResBlock
class ResBlock(layers.Layer):
    """Residual block conditioned on an embedding (used in the diffusion UNet).

    Takes ``[features, embeddings]``. The embedding is projected to
    ``output_dim``, broadcast over the two spatial axes, and added between
    the entry and exit conv stacks.
    """

    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim
        self.entry_flow = [
            layers.GroupNormalization(epsilon=EPSILON),
            layers.Activation("swish"),
            PaddedConv2D(output_dim, 3, padding=1),
        ]
        self.embedding_flow = [
            layers.Activation("swish"),
            layers.Dense(output_dim),
        ]
        self.exit_flow = [
            layers.GroupNormalization(epsilon=EPSILON),
            layers.Activation("swish"),
            PaddedConv2D(output_dim, 3, padding=1),
        ]

    def build(self, input_shape):
        feature_channels = input_shape[0][-1]
        self.residual_proj = (
            PaddedConv2D(self.output_dim, 1)
            if feature_channels != self.output_dim
            else (lambda x: x)
        )

    @staticmethod
    def _apply_flow(flow, tensor):
        """Apply each layer of ``flow`` in sequence."""
        for layer in flow:
            tensor = layer(tensor)
        return tensor

    def call(self, inputs):
        features, embeddings = inputs
        hidden = self._apply_flow(self.entry_flow, features)
        projected = self._apply_flow(self.embedding_flow, embeddings)
        # (b, c) -> (b, 1, 1, c) so it broadcasts over height and width.
        hidden = hidden + projected[:, None, None]
        hidden = self._apply_flow(self.exit_flow, hidden)
        return hidden + self.residual_proj(features)

In [None]:
# @title Upsample
class Upsample(layers.Layer):
    """2x spatial upsampling (``UpSampling2D`` defaults) followed by a 3x3 conv."""

    def __init__(self, channels, **kwargs):
        super().__init__(**kwargs)
        self.upsample = layers.UpSampling2D(2)
        self.conv = PaddedConv2D(channels, 3, padding=1)

    def call(self, inputs):
        enlarged = self.upsample(inputs)
        return self.conv(enlarged)

In [None]:
# @title CLIPEmbedding
class CLIPEmbedding(layers.Layer):
    """Sum of learned token embeddings and learned position embeddings.

    Takes ``[token_ids, position_ids]``. The position table has
    ``MAX_PROMPT_LENGTH`` rows.
    """

    def __init__(self, input_dim=49408, output_dim=768, **kwargs):
        super().__init__(**kwargs)
        self.token_embedding = layers.Embedding(input_dim, output_dim)
        self.position_embedding = layers.Embedding(MAX_PROMPT_LENGTH, output_dim)

    def call(self, inputs):
        token_ids, position_ids = inputs
        token_vectors = self.token_embedding(token_ids)
        position_vectors = self.position_embedding(position_ids)
        return token_vectors + position_vectors

In [None]:
# @title CLIPAttention
class CLIPAttention(layers.Layer):
    """Multi-head self-attention used by the CLIP text encoder.

    Args:
        embed_dim: model width, split evenly across ``num_heads``.
        num_heads: number of attention heads.
        causal: when True and no ``attention_mask`` is given, build a causal
            mask so each position attends only to itself and earlier positions.
    """

    def __init__(self, embed_dim=768, num_heads=12, causal=True, **kwargs):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.causal = causal
        self.head_dim = self.embed_dim // self.num_heads
        self.scale = self.head_dim**-0.5
        self.q_proj = layers.Dense(self.embed_dim)
        self.k_proj = layers.Dense(self.embed_dim)
        self.v_proj = layers.Dense(self.embed_dim)
        self.output_proj = layers.Dense(self.embed_dim)

    def reshape_states(self, x, sequence_length, batch_size):
        """(b, seq, embed) -> (b, heads, seq, head_dim)."""
        x = ops.reshape(x, (batch_size, sequence_length, self.num_heads, self.head_dim))
        return ops.transpose(x, (0, 2, 1, 3))  # batch_size, heads, sequence_length, head_dim

    def call(self, inputs, attention_mask=None):
        """Attend over ``inputs``.

        Args:
            inputs: tensor of shape ``(batch, seq, embed_dim)``.
            attention_mask: optional additive mask that broadcasts to
                ``(batch, heads, seq, seq)``. When None and ``causal`` is
                False, no mask is applied.
        """
        if attention_mask is None and self.causal:
            length = ops.shape(inputs)[1]
            # -inf strictly above the diagonal (future positions), 0 elsewhere.
            attention_mask = ops.triu(
                ops.ones((1, 1, length, length), dtype=self.compute_dtype) * -float("inf"),
                k=1,
            )

        _, tgt_len, embed_dim = inputs.shape
        q_states = self.q_proj(inputs) * self.scale
        k_states = self.reshape_states(self.k_proj(inputs), tgt_len, -1)
        v_states = self.reshape_states(self.v_proj(inputs), tgt_len, -1)

        # Fold heads into the batch axis: (b * heads, seq, head_dim).
        proj_shape = (-1, tgt_len, self.head_dim)
        q_states = self.reshape_states(q_states, tgt_len, -1)
        q_states = ops.reshape(q_states, proj_shape)
        k_states = ops.reshape(k_states, proj_shape)
        v_states = ops.reshape(v_states, proj_shape)

        src_len = tgt_len
        attn_weights = q_states @ ops.transpose(k_states, (0, 2, 1))
        attn_weights = ops.reshape(attn_weights, (-1, self.num_heads, tgt_len, src_len))
        # Without this guard, `causal=False` with no explicit mask would
        # evaluate `attn_weights + None` and fail.
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = ops.reshape(attn_weights, (-1, tgt_len, src_len))
        attn_weights = ops.softmax(attn_weights, axis=-1)

        attn_output = attn_weights @ v_states
        attn_output = ops.reshape(attn_output, (-1, self.num_heads, tgt_len, self.head_dim))
        attn_output = ops.transpose(attn_output, (0, 2, 1, 3))
        attn_output = ops.reshape(attn_output, (-1, tgt_len, embed_dim))
        return self.output_proj(attn_output)

In [None]:
# @title CLIPEncoderLayer
class CLIPEncoderLayer(layers.Layer):
    """Pre-LayerNorm transformer block of the CLIP text encoder.

    Computes ``x = x + attn(ln1(x))``, then ``x = x + fc2(act(fc1(ln2(x))))``.
    The MLP expands the width 4x, and self-attention is always causal.

    Args:
        embed_dim: model width.
        num_heads: number of attention heads.
        activation: MLP activation, given as a callable or a Keras activation
            name. ``None`` (the default) means the identity (linear) activation.
    """

    def __init__(self, embed_dim, num_heads, activation=None, **kwargs):
        super().__init__(**kwargs)
        self.layer_norm1 = layers.LayerNormalization(epsilon=EPSILON)
        self.clip_attn = CLIPAttention(embed_dim, num_heads, causal=True)
        self.layer_norm2 = layers.LayerNormalization(epsilon=EPSILON)
        self.fc1 = layers.Dense(embed_dim * 4)
        self.fc2 = layers.Dense(embed_dim)
        # `activations.get` returns callables unchanged, resolves names such
        # as "gelu", and maps None to `linear`. Previously, None crashed in call().
        self.activation = activations.get(activation)

    def call(self, inputs):
        residual = inputs
        x = self.layer_norm1(inputs)
        x = self.clip_attn(x)
        x = residual + x
        residual = x
        x = self.layer_norm2(x)
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x + residual

## Models

In [None]:
# @title TextEncoder
class TextEncoder(Model):
    """CLIP text encoder built as a functional Keras model.

    Graph: token + position embeddings (``CLIPEmbedding``) -> 12 causal
    ``CLIPEncoderLayer`` blocks (width 768, 12 heads, ``quick_gelu`` MLP) ->
    final LayerNormalization. The two inputs are int32 tensors of length
    ``MAX_PROMPT_LENGTH`` (token ids and position ids). The output has shape
    ``(batch, MAX_PROMPT_LENGTH, 768)``. Pretrained weights (kcv_encoder.h5)
    are downloaded, or reused from ``CACHE_DIR/CACHE_SUBDIR``, and loaded in
    the constructor.

    Args:
        vocab_size: rows of the token-embedding table. 49408 matches the
            vocabulary built by ``CLIPTokenizer``.
    """

    def __init__(self, vocab_size=49408):
        tokens = layers.Input(shape=(MAX_PROMPT_LENGTH,), dtype="int32", name="tokens")
        positions = layers.Input(shape=(MAX_PROMPT_LENGTH,), dtype="int32", name="positions")

        x = CLIPEmbedding(vocab_size, 768)([tokens, positions])
        for _ in range(12):
            x = CLIPEncoderLayer(768, 12, activation=quick_gelu)(x)

        embedded = layers.LayerNormalization(epsilon=EPSILON)(x)
        # Functional-API subclassing: trace the graph first, then hand its
        # inputs/outputs to Model.__init__.
        super().__init__([tokens, positions], embedded, name=None)

        # NOTE(review): file_hash is commented out, so the download is not
        # integrity-checked. The layer graph above presumably has to match the
        # checkpoint's layout exactly for load_weights to succeed.
        weights_path = utils.get_file(
            cache_dir=CACHE_DIR,
            cache_subdir=CACHE_SUBDIR,
            # file_hash="4789e63e07c0e54d6a34a29b45ce81ece27060c499a709d556c7755b42bb0dc4",
            origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5",
        )
        self.load_weights(weights_path)

In [None]:
# @title ImageEncoder
class ImageEncoder(Sequential):
    """VAE encoder: RGB image -> scaled 4-channel latent at 1/8 resolution.

    Accepts images of any height and width (``Input((None, None, 3))``).
    Three stride-2 convolutions reduce each spatial side 8x. The final layer
    keeps the first 4 of 8 output channels and multiplies them by 0.18215,
    the same factor ``Decoder`` divides by. The input value range is not
    established here; presumably [-1, 1] — TODO confirm with the pipeline.
    Pretrained weights (vae_encoder.h5) are loaded in the constructor.
    """

    def __init__(self):
        # Layer order must not change; the pretrained weights are loaded onto
        # this exact stack.
        super().__init__(
            [
                layers.Input((None, None, 3)),
                PaddedConv2D(128, 3, padding=1),
                ResnetBlock(128),
                ResnetBlock(128),
                PaddedConv2D(128, 3, padding=((0, 1), (0, 1)), strides=2),  # downsample 2x (pad right/bottom only)
                ResnetBlock(256),
                ResnetBlock(256),
                PaddedConv2D(256, 3, padding=((0, 1), (0, 1)), strides=2),  # downsample 2x
                ResnetBlock(512),
                ResnetBlock(512),
                PaddedConv2D(512, 3, padding=((0, 1), (0, 1)), strides=2),  # downsample 2x
                ResnetBlock(512),
                ResnetBlock(512),
                ResnetBlock(512),
                AttentionBlock(512),
                ResnetBlock(512),
                layers.GroupNormalization(epsilon=EPSILON),
                layers.Activation("swish"),
                PaddedConv2D(8, 3, padding=1),
                PaddedConv2D(8, 1),
                # Keep channels 0-3 (presumably the latent mean; 4-7 presumably
                # log-variance) and apply the latent scaling factor.
                layers.Lambda(lambda x: x[..., :4] * 0.18215),
            ]
        )
        # NOTE(review): file_hash is commented out, so the download is not integrity-checked.
        weights_path = utils.get_file(
            cache_dir=CACHE_DIR,
            cache_subdir=CACHE_SUBDIR,
            # file_hash="c60fb220a40d090e0f86a6ab4c312d113e115c87c40ff75d11ffcf380aab7ebb",
            origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/vae_encoder.h5",
        )
        self.load_weights(weights_path)

In [None]:
# @title Decoder
class Decoder(Sequential):
    """VAE decoder: scaled 4-channel latent -> 3-channel image at 8x resolution.

    The input latent has shape ``(img_height // 8, img_width // 8, 4)``. It
    is first divided by 0.18215, undoing the scaling that ``ImageEncoder``
    applies. Three ``UpSampling2D(2)`` stages restore the full resolution.
    The output value range is not established here; presumably about
    [-1, 1] — TODO confirm with the pipeline's post-processing. Pretrained
    weights (kcv_decoder.h5) are loaded in the constructor.

    Args:
        img_height, img_width: output image size in pixels; both should be
            divisible by 8.
    """

    def __init__(self, img_height, img_width):
        # Layer order must not change; the pretrained weights are loaded onto
        # this exact stack.
        super().__init__(
            [
                layers.Input((img_height // 8, img_width // 8, 4)),
                layers.Rescaling(1.0 / 0.18215),
                PaddedConv2D(4, 1),
                PaddedConv2D(512, 3, padding=1),
                ResnetBlock(512),
                AttentionBlock(512),
                ResnetBlock(512),
                ResnetBlock(512),
                ResnetBlock(512),
                ResnetBlock(512),
                layers.UpSampling2D(2),  # 1/8 -> 1/4
                PaddedConv2D(512, 3, padding=1),
                ResnetBlock(512),
                ResnetBlock(512),
                ResnetBlock(512),
                layers.UpSampling2D(2),  # 1/4 -> 1/2
                PaddedConv2D(512, 3, padding=1),
                ResnetBlock(256),
                ResnetBlock(256),
                ResnetBlock(256),
                layers.UpSampling2D(2),  # 1/2 -> full resolution
                PaddedConv2D(256, 3, padding=1),
                ResnetBlock(128),
                ResnetBlock(128),
                ResnetBlock(128),
                # Use the shared EPSILON constant like every other
                # normalization layer (same value as the former literal 1e-5).
                layers.GroupNormalization(epsilon=EPSILON),
                layers.Activation("swish"),
                PaddedConv2D(3, 3, padding=1),
            ],
        )
        # NOTE(review): file_hash is commented out, so the download is not integrity-checked.
        weights_path = utils.get_file(
            cache_dir=CACHE_DIR,
            cache_subdir=CACHE_SUBDIR,
            # file_hash="ad350a65cc8bc4a80c8103367e039a3329b4231c2469a1093869a345f55b1962",
            origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_decoder.h5",
        )
        self.load_weights(weights_path)

In [None]:
# @title DiffusionModel
class DiffusionModel(Model):
    """Diffusion UNet predicting a 4-channel output from (latent, timestep embedding, text context).

    Inputs (by name):
        latent: `(img_height // 8, img_width // 8, 4)`.
        timestep_embedding: `(320,)`, projected to 1280 by a Dense-swish-Dense MLP.
        context: `(MAX_PROMPT_LENGTH, 768)` text encoding attended to by the
            `SpatialTransformer` blocks.

    Output: a `PaddedConv2D(4, ...)` map with the same spatial size as `latent`.

    Pretrained weights are downloaded with `utils.get_file` and loaded at
    construction time.

    NOTE(review): layer creation order must match `kcv_diffusion_model.h5`.
    The weights are presumably matched positionally, so do not reorder blocks.
    """

    def __init__(
        self,
        img_height,
        img_width,
    ):
        context = layers.Input((MAX_PROMPT_LENGTH, 768), name="context")
        t_embed_input = layers.Input((320,), name="timestep_embedding")
        latent = layers.Input((img_height // 8, img_width // 8, 4), name="latent")
        # Timestep MLP: 320 -> 1280 -> swish -> 1280, shared by every ResBlock.
        t_emb = layers.Dense(1280)(t_embed_input)
        t_emb = layers.Activation("swish")(t_emb)
        t_emb = layers.Dense(1280)(t_emb)

        # Downsampling flow
        # `outputs` collects skip connections. 12 tensors are pushed here and
        # exactly 12 are popped (LIFO) by the upsampling flow below.
        outputs = []
        x = PaddedConv2D(320, kernel_size=3, padding=1)(latent)
        outputs.append(x)

        # SpatialTransformer args presumably (num_heads, head_size); 8 * 40 = 320
        # matches the channel width here. Confirm against the layer definition.
        for _ in range(2):
            x = ResBlock(320)([x, t_emb])
            x = SpatialTransformer(8, 40, fully_connected=False)([x, context])
            outputs.append(x)
        x = PaddedConv2D(320, 3, strides=2, padding=1)(x)  # Downsample 2x
        outputs.append(x)

        for _ in range(2):
            x = ResBlock(640)([x, t_emb])
            x = SpatialTransformer(8, 80, fully_connected=False)([x, context])
            outputs.append(x)
        x = PaddedConv2D(640, 3, strides=2, padding=1)(x)  # Downsample 2x
        outputs.append(x)

        for _ in range(2):
            x = ResBlock(1280)([x, t_emb])
            x = SpatialTransformer(8, 160, fully_connected=False)([x, context])
            outputs.append(x)
        x = PaddedConv2D(1280, 3, strides=2, padding=1)(x)  # Downsample 2x
        outputs.append(x)

        # Lowest resolution: ResBlocks only, no attention.
        for _ in range(2):
            x = ResBlock(1280)([x, t_emb])
            outputs.append(x)

        # Middle flow
        x = ResBlock(1280)([x, t_emb])
        x = SpatialTransformer(8, 160, fully_connected=False)([x, context])
        x = ResBlock(1280)([x, t_emb])

        # Upsampling flow
        # Each step concatenates the most recent unconsumed skip connection.
        for _ in range(3):
            x = layers.Concatenate()([x, outputs.pop()])
            x = ResBlock(1280)([x, t_emb])
        x = Upsample(1280)(x)

        for _ in range(3):
            x = layers.Concatenate()([x, outputs.pop()])
            x = ResBlock(1280)([x, t_emb])
            x = SpatialTransformer(8, 160, fully_connected=False)([x, context])
        x = Upsample(1280)(x)

        for _ in range(3):
            x = layers.Concatenate()([x, outputs.pop()])
            x = ResBlock(640)([x, t_emb])
            x = SpatialTransformer(8, 80, fully_connected=False)([x, context])
        x = Upsample(640)(x)

        for _ in range(3):
            x = layers.Concatenate()([x, outputs.pop()])
            x = ResBlock(320)([x, t_emb])
            x = SpatialTransformer(8, 40, fully_connected=False)([x, context])

        # Exit flow
        x = layers.GroupNormalization(epsilon=EPSILON)(x)
        x = layers.Activation("swish")(x)
        output = PaddedConv2D(4, kernel_size=3, padding=1)(x)

        # Functional model; input order is (latent, timestep_embedding, context),
        # but callers pass a dict keyed by the Input names above.
        super().__init__([latent, t_embed_input, context], output, name=None)

        # Download (or reuse the cached copy of) the pretrained UNet weights.
        # The integrity check via `file_hash` is commented out, so the download is not verified.
        weights_path = utils.get_file(
            cache_dir=CACHE_DIR,
            cache_subdir=CACHE_SUBDIR,
            # file_hash="8799ff9763de13d7f30a683d653018e114ed24a6a819667da4f5ee10f9e805fe",
            origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_diffusion_model.h5",
        )
        self.load_weights(weights_path)

## Pipeline

In [None]:
# @title StableDiffusion
class StableDiffusion:
    """
    Text-to-image pipeline combining the CLIP text encoder, the diffusion UNet,
    the VAE decoder and the DDPM scheduler.

    Arguments:
        batch_size: int, number of images to generate. Defaults to 1.
        img_height: int, height of the images to generate, in pixels. Note that only multiples of 128 are supported; the value provided will be rounded to the nearest valid value (never below 128). Defaults to 512.
        img_width: int, width of the images to generate, in pixels. Note that only multiples of 128 are supported; the value provided will be rounded to the nearest valid value (never below 128). Defaults to 512.
        jit_compile: bool or "auto", whether to compile the underlying models to XLA. This can lead to a significant speedup on some systems. Defaults to "auto".
    Example:
    ```python
    from PIL import Image
    model = StableDiffusion()
    img = model.text_to_image(prompt="A beautiful horse running through a field", seed=42)
    Image.fromarray(img[0]).save("horse.png")
    print("saved at horse.png")
    ```
    """

    # CLIP special token ids used to build padded and unconditional prompts.
    _START_TOKEN = 49406  # <|startoftext|>
    _END_TOKEN = 49407  # <|endoftext|>; also used as padding up to MAX_PROMPT_LENGTH

    def __init__(self, batch_size=1, img_height=512, img_width=512, jit_compile="auto"):
        # UNet requires multiples of 2^7 (128). Rounding alone would map values
        # below 64 to 0 (an invalid Input shape), so clamp to at least 128.
        img_height = self._snap_dimension(img_height)
        img_width = self._snap_dimension(img_width)
        self.batch_size = batch_size
        self.img_height = img_height
        self.img_width = img_width

        self.image_encoder = ImageEncoder()
        self.text_encoder = TextEncoder()
        self.diffusion_model = DiffusionModel(img_height, img_width)
        self.decoder = Decoder(img_height, img_width)

        self.tokenizer = CLIPTokenizer()
        self.scheduler = DDPMScheduler(
            beta_end=0.012,
            beta_start=0.00085,
            beta_schedule="scaled_linear",
        )  # from tinygrad

        for keras_model in (
            self.image_encoder,
            self.text_encoder,
            self.diffusion_model,
            self.decoder,
        ):
            keras_model.compile(jit_compile=jit_compile)

    def text_to_image(
        self,
        prompt,
        negative_prompt=None,
        num_steps=50,
        guidance_scale=7.5,
        seed=None,
    ):
        """
        Encodes `prompt` and generates images from it.
        This is shorthand for `generate_image(encode_text(prompt), ...)`; see `generate_image` for the remaining arguments.
        Returns:
            uint8 numpy array of decoded images, clipped to [0, 255].
        """
        encoded_text = self.encode_text(prompt)
        return self.generate_image(
            encoded_text,
            negative_prompt=negative_prompt,
            num_steps=num_steps,
            guidance_scale=guidance_scale,
            seed=seed,
        )

    def encode_text(self, prompt):
        """
        Encodes a prompt into a latent text encoding.
        The encoding produced by this method should be used as the `encoded_text` parameter of `StableDiffusion.generate_image`.
        Encoding text separately from generating an image can be used to arbitrarily modify the text encoding prior to image generation, e.g. for walking between two prompts.
        Args:
            prompt: a string to encode, must be 77 tokens or shorter.
        Raises:
            ValueError: if the tokenized prompt is longer than MAX_PROMPT_LENGTH.
        Example:
        ```python
        model = StableDiffusion()
        encoded_text = model.encode_text("Tacos at dawn")
        img = model.generate_image(encoded_text)
        ```
        """
        inputs = self.tokenizer.encode(prompt)
        if len(inputs) > MAX_PROMPT_LENGTH:
            raise ValueError(f"Prompt is too long (should be <= {MAX_PROMPT_LENGTH} tokens)")

        # pad to max tokens with the end-of-text token
        phrase = inputs + [self._END_TOKEN] * (MAX_PROMPT_LENGTH - len(inputs))
        phrase = ops.convert_to_tensor([phrase], dtype="int32")

        # context
        return self.text_encoder.predict_on_batch(
            {"tokens": phrase, "positions": self._get_pos_ids()}
        )

    def generate_image(
        self,
        encoded_text,
        negative_prompt=None,
        num_steps=50,
        guidance_scale=7.5,
        seed=None,
    ):
        """
        Generates an image based on encoded text.
        The encoding passed to this method should be derived from `StableDiffusion.encode_text`.
        Args:
            encoded_text: Tensor of shape (`batch_size`, 77, 768), or a Tensor of shape (77, 768). When the batch axis is omitted, the same encoded text will be used to produce every generated image.
            negative_prompt: a string containing information to negatively guide the image generation (e.g. by removing or altering certain aspects of the generated image), defaults to None.
            num_steps: int >= 1, number of diffusion steps (controls image quality), defaults to 50.
            guidance_scale: float, controlling how closely the image should adhere to the prompt. Larger values result in more closely adhering to the prompt, but will make the image noisier. Defaults to 7.5.
            seed: integer which is used to seed the random generation of diffusion noise.
        Returns:
            uint8 numpy array of decoded images, clipped to [0, 255].
        Raises:
            ValueError: if `num_steps` is less than 1.
        Example:
        ```python
        batch_size = 8
        model = StableDiffusion()
        e_tacos = model.encode_text("Tacos at dawn")
        e_watermelons = model.encode_text("Watermelons at dusk")
        e_interpolated = keras.ops.linspace(e_tacos, e_watermelons, batch_size)
        images = model.generate_image(e_interpolated)
        ```
        """
        # Validate before any model work. With 0 steps the loop below would never
        # run and the decoder would silently return decoded pure noise.
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}")

        context = self._expand_tensor(encoded_text)

        if negative_prompt is None:
            unconditional_context = ops.repeat(
                self._get_unconditional_context(),
                self.batch_size,
                axis=0,
            )
        else:
            unconditional_context = self._expand_tensor(self.encode_text(negative_prompt))

        latent = self._get_initial_diffusion_noise(seed)

        # iterative reverse diffusion stage
        timesteps = self._get_timesteps(num_steps)
        alphas, alphas_prev = self._get_initial_alphas(timesteps)
        progbar = utils.Progbar(len(timesteps))

        # Walk the schedule from the largest timestep down to 0.
        for step, index in enumerate(reversed(range(len(timesteps))), start=1):
            latent_prev = latent  # set aside the previous latent vector
            t_emb = self._get_timestep_embedding(timesteps[index])
            unconditional_latent = self.diffusion_model.predict_on_batch(
                {
                    "latent": latent,
                    "timestep_embedding": t_emb,
                    "context": unconditional_context,
                }
            )
            conditional_latent = self.diffusion_model.predict_on_batch(
                {
                    "latent": latent,
                    "timestep_embedding": t_emb,
                    "context": context,
                }
            )
            # classifier-free guidance: uncond + scale * (cond - uncond)
            noise_pred = ops.array(
                unconditional_latent
                + guidance_scale * (conditional_latent - unconditional_latent)
            )
            a_t, a_prev = alphas[index], alphas_prev[index]
            # keras backend arrays need an explicit cast to match the latent dtype
            noise_pred = ops.cast(noise_pred, latent_prev.dtype)
            pred_x0 = (latent_prev - math.sqrt(1 - a_t) * noise_pred) / math.sqrt(a_t)
            latent = ops.array(noise_pred) * math.sqrt(1.0 - a_prev) + math.sqrt(a_prev) * pred_x0
            progbar.update(step)

        # decoding stage: map decoder output from [-1, 1] to [0, 255]
        decoded = self.decoder.predict_on_batch(latent)
        decoded = ((decoded + 1) / 2) * 255
        return np.clip(decoded, 0, 255).astype("uint8")

    def _get_unconditional_context(self):
        """Encodes the empty prompt (start token followed by end-token padding)."""
        unconditional_tokens = [self._START_TOKEN] + [self._END_TOKEN] * (MAX_PROMPT_LENGTH - 1)
        unconditional_tokens = ops.convert_to_tensor([unconditional_tokens], dtype="int32")
        return self.text_encoder.predict_on_batch(
            {
                "tokens": unconditional_tokens,
                "positions": self._get_pos_ids(),
            }
        )

    def _expand_tensor(self, text_embedding):
        """Squeeze size-1 axes; if the result is 2-D, repeat it along a new batch axis `batch_size` times."""
        text_embedding = ops.squeeze(text_embedding)
        if len(text_embedding.shape) == 2:
            text_embedding = ops.repeat(
                ops.expand_dims(text_embedding, axis=0),
                self.batch_size,
                axis=0,
            )
        return text_embedding

    def _get_timestep_embedding(self, timestep, dim=320, max_period=10000):
        """Sinusoidal timestep embedding of size `dim`, repeated `batch_size` times -> (batch_size, dim)."""
        half = dim // 2
        # renamed from `range` to avoid shadowing the builtin
        exponents = ops.cast(ops.arange(0, half), "float32")
        freqs = ops.exp(-math.log(max_period) * exponents / half)
        args = ops.convert_to_tensor([timestep], dtype="float32") * freqs
        embedding = ops.concatenate([ops.cos(args), ops.sin(args)], 0)
        embedding = ops.reshape(embedding, [1, -1])
        return ops.repeat(embedding, self.batch_size, axis=0)

    @staticmethod
    def _get_timesteps(num_steps, num_timesteps=1000):
        """Return `num_steps` int64 timesteps evenly spaced over [0, num_timesteps - 1], ascending."""
        ratio = (num_timesteps - 1) / (num_steps - 1) if num_steps > 1 else num_timesteps
        return (np.arange(0, num_steps) * ratio).round().astype(np.int64)

    def _get_initial_alphas(self, timesteps):
        """Look up cumulative alphas for each timestep; `alphas_prev` is shifted by one and starts at 1.0."""
        cumprod = self.scheduler.alphas_cumprod
        alphas = [cumprod[t] for t in timesteps]
        alphas_prev = [1.0] + alphas[:-1]
        return alphas, alphas_prev

    def _get_initial_diffusion_noise(self, seed):
        """Standard-normal latent noise of shape (batch_size, img_height // 8, img_width // 8, 4)."""
        return random.normal(
            (self.batch_size, self.img_height // 8, self.img_width // 8, 4),
            seed=seed,
        )

    @staticmethod
    def _get_pos_ids():
        """Position ids 0..MAX_PROMPT_LENGTH-1 with a leading batch axis of 1."""
        return ops.expand_dims(ops.arange(MAX_PROMPT_LENGTH, dtype="int32"), 0)

    @staticmethod
    def _snap_dimension(value, multiple=128):
        """Round `value` to the nearest multiple of `multiple`, never returning less than `multiple`."""
        return max(multiple, round(value / multiple) * multiple)

In [None]:
# Set the global dtype policy first so every model built below picks it up.
config.set_dtype_policy(GLOBAL_DTYPE)
model = StableDiffusion(
    batch_size=BATCH_SIZE,
    img_height=IMG_HEIGHT,
    img_width=IMG_WIDTH,
)  # ~2m

In [None]:
# Generate with a negative prompt steering away from unwanted artifacts, then show the results.
images = model.text_to_image(
    prompt="cute corgi at the beach, light sand, blue waves",
    negative_prompt="deformed, clutter",
)
display_images(images)