# **High-performance Image Generation using Stable Diffusion in KerasCV**

## **Written by:** [Aarish Asif Khan](https://www.kaggle.com/aarishasifkhan)

##  **Date:** 27th March 2024

## **Website of TensorFlow:** [TensorFlow Org](https://www.tensorflow.org/tutorials/generative/generate_images_with_stable_diffusion)

## **Credits to:** fchollet, lukewood and divamgupta

# **`Note:`**

If you are not familiar with Stable Diffusion, then check out my previous notebook, which covers the basic concepts required. Thanks!

# **Overview**

In this notebook, we will show how to generate novel images based on a text prompt using the `KerasCV implementation` of `stability.ai's` text-to-image model, `Stable Diffusion`.

`Stable Diffusion is a powerful, open-source text-to-image generation model.` While there exist multiple open-source implementations that allow you to easily create images from textual prompts, KerasCV's offers a few distinct advantages. These include XLA compilation and mixed precision support, which together achieve state-of-the-art generation speed.

To get started, let's install a few dependencies and sort out some imports.

In [4]:
# pip install tensorflow keras_cv --upgrade --quiet

In [5]:
# Import libraries
import time
import keras_cv

from tensorflow import keras
import matplotlib.pyplot as plt

# **Introduction**

Check out the power of `keras_cv.models.StableDiffusion()`.

First, we construct a model.

In [6]:
model = keras_cv.models.StableDiffusion(img_width=512, img_height=512)

By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE


Now, we will give the model a prompt.

In [7]:
images = model.text_to_image("photograph of an astronaut riding a horse", batch_size=3)


def plot_images(images):
    plt.figure(figsize=(20, 20))
    for i in range(len(images)):
        ax = plt.subplot(1, len(images), i + 1)
        plt.imshow(images[i])
        plt.axis("off")


plot_images(images)

Downloading data from https://github.com/openai/CLIP/blob/main/clip/bpe_simple_vocab_16e6.txt.gz?raw=true
1356917/1356917 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step

Downloading data from https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_encoder.h5
492466864/492466864 ━━━━━━━━━━━━━━━━━━━━ 212s 0us/step
Downloading data from https://huggingface.co/fchollet/stable-diffusion/resolve/main/kcv_diffusion_model.h5
3439090152/3439090152 ━━━━━━━━━━━━━━━━━━━━ 1510s 0us/step
6/50 ━━━━━━━━━━━━━━━━━━━━ 55:06 75s/step

In [None]:
for i, img in enumerate(images):
    keras.preprocessing.image.save_img(f"image_{i}.png", img)

But that's not all this model can do. Let's try a more complex prompt.

In [None]:
images = model.text_to_image(
    "cute magical flying dog, fantasy art, "
    "golden color, high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting, mystery, adventure",
    batch_size=3,
)
plot_images(images)

# **How are these Images generated?**

Unlike what you might expect at this point, Stable Diffusion doesn't actually run on magic. It's a kind of `"latent diffusion model"`. Let's dig into what that means.

You may be familiar with the idea of super-resolution: it's possible to train a deep learning model to denoise an input image -- and thereby turn it into a higher-resolution version. The deep learning model doesn't do this by magically recovering the information that's missing from the noisy, low-resolution input -- rather, the model uses its training data distribution to hallucinate the visual details that would be most likely given the input. 

# **Perks of KerasCV**

With several implementations of Stable Diffusion publicly available, why should you use `keras_cv.models.StableDiffusion`?

Aside from the easy-to-use API, KerasCV's Stable Diffusion model comes with some powerful advantages, including:

* **`Graph mode execution`**

* **`XLA compilation through jit_compile=True`**

* **`Support for mixed precision computation`**

In [None]:
benchmark_result = []
start = time.time()
images = model.text_to_image(
    "A cute otter in a rainbow whirlpool holding shells, watercolor",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["Standard", end - start])
plot_images(images)

print(f"Standard model: {(end - start):.2f} seconds")
keras.backend.clear_session()  # Clear session to preserve memory.

# **Mixed Precision**

`"Mixed precision"` consists of performing computation using float16 precision, while storing weights in the float32 format. This is done to take advantage of the fact that float16 operations are backed by significantly faster kernels than their float32 counterparts on modern NVIDIA GPUs.
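To see why the weights stay in float32 while the computation runs in float16, the following minimal sketch round-trips numbers through both formats using only the Python standard library. The helper `round_trip` and the example values (`weight`, `update`) are illustrative assumptions, not part of KerasCV or of this notebook's pipeline.

```python
import struct

FLOAT16, FLOAT32 = "<e", "<f"  # IEEE 754 half and single precision


def round_trip(value, fmt):
    """Store `value` in the given format and read it back as a Python float."""
    return struct.unpack(fmt, struct.pack(fmt, value))[0]


# Hypothetical numbers: a weight of 1.0 receiving a small update of 1e-4.
weight, update = 1.0, 1e-4

half = round_trip(round_trip(weight, FLOAT16) + update, FLOAT16)
single = round_trip(round_trip(weight, FLOAT32) + update, FLOAT32)

print("bytes per value:", struct.calcsize(FLOAT16), "vs", struct.calcsize(FLOAT32))
print("float16 result:", half)    # the update is lost: 1.0
print("float32 result:", single)  # slightly above 1.0, the update survives
```

float16 halves the storage per value, but near 1.0 its spacing between representable numbers is about 0.001, so small changes vanish. Keeping variables in float32 avoids that loss while the heavy arithmetic still benefits from the faster float16 kernels.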

In [None]:
keras.mixed_precision.set_global_policy("mixed_float16")

In [None]:
model = keras_cv.models.StableDiffusion()

print("Compute dtype:", model.diffusion_model.compute_dtype)
print(
    "Variable dtype:",
    model.diffusion_model.variable_dtype,
)

In [None]:
# Warm up model to run graph tracing before benchmarking.
model.text_to_image("warming up the model", batch_size=3)

start = time.time()
images = model.text_to_image(
    "a cute magical flying dog, fantasy art, "
    "golden color, high quality, highly detailed, elegant, sharp focus, "
    "concept art, character concepts, digital painting, mystery, adventure",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["Mixed Precision", end - start])
plot_images(images)

print(f"Mixed precision model: {(end - start):.2f} seconds")
keras.backend.clear_session()
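The benchmark cells repeat the same warm-up, start, and end pattern. A minimal sketch of a reusable helper is shown below; `timed` and the stand-in workload `slow_square` are hypothetical names introduced for illustration, and the sketch uses only the standard library so it runs without a GPU or a model download.

```python
import time


def timed(fn, *args, warmup=1, **kwargs):
    """Call fn(*args, **kwargs) and return (result, elapsed_seconds).

    `warmup` untimed calls run first so that one-time costs, such as graph
    tracing, are excluded from the measurement.
    """
    for _ in range(warmup):
        fn(*args, **kwargs)
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    elapsed = time.perf_counter() - start
    return result, elapsed


# Stand-in workload (an assumption, not the real model).
def slow_square(x):
    time.sleep(0.01)
    return x * x


result, seconds = timed(slow_square, 7, warmup=1)
print(f"result={result}, elapsed={seconds:.3f} seconds")
```

With the model from this notebook, the same helper could be called as `images, seconds = timed(model.text_to_image, "warming up the model", batch_size=3)`.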

# **XLA Compilation**

TensorFlow comes with the XLA (Accelerated Linear Algebra) compiler built in. `keras_cv.models.StableDiffusion` supports a `jit_compile` argument out of the box. Setting this argument to `True` enables XLA compilation, resulting in a significant speed-up.

In [None]:
# Set back to the default for benchmarking purposes.
keras.mixed_precision.set_global_policy("float32")

model = keras_cv.models.StableDiffusion(jit_compile=True)

# Before we benchmark the model, we run inference once to make sure the TensorFlow
# graph has already been traced.
images = model.text_to_image("An avocado armchair", batch_size=3)
plot_images(images)

In [None]:
start = time.time()
images = model.text_to_image(
    "A cute otter in a rainbow whirlpool holding shells, watercolor",
    batch_size=3,
)
end = time.time()
benchmark_result.append(["XLA", end - start])
plot_images(images)

print(f"With XLA: {(end - start):.2f} seconds")
keras.backend.clear_session()
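The cells above collect their timings in `benchmark_result` but never display them together. The sketch below is one possible way to summarize that list; `format_benchmarks` is a hypothetical helper, and it assumes each entry has the `[name, seconds]` shape used above, with the first entry treated as the baseline.

```python
def format_benchmarks(results):
    """Return one line per [name, seconds] entry, with speed-up vs. the first entry."""
    if not results:
        return "No benchmark results recorded."
    baseline = results[0][1]
    width = max(len(name) for name, _ in results)
    lines = []
    for name, seconds in results:
        # Guard against a zero duration so the division cannot fail.
        speedup = baseline / seconds if seconds else float("inf")
        lines.append(f"{name:<{width}}  {seconds:8.2f} s  {speedup:5.2f}x")
    return "\n".join(lines)


# In the notebook: print(format_benchmarks(benchmark_result))
```

Because the runs above do not all use the same prompt, the speed-up column is only a rough guide rather than a controlled comparison.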