<a href="https://colab.research.google.com/github/borisdayma/dalle-mini/blob/main/tools/inference/inference_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DALL·E mini - Inference pipeline

*Generate images from a text prompt*

<img src="https://github.com/borisdayma/dalle-mini/blob/main/img/logo.png?raw=true" width="200">

This notebook illustrates the [DALL·E mini](https://github.com/borisdayma/dalle-mini) inference pipeline.

Just want to play? Use [the app](https://www.craiyon.com/) directly.

To learn more about the model, refer to [the report](https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA).

## 🛠️ Installation and set-up

In [None]:
# Required only for Colab environments with a GPU
%pip install jax==0.3.25 jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# pip install jax==0.3.25 jaxlib==0.3.25 chex==0.1.5 optax==0.1.4 flax==0.6.3
# Install required libraries
%pip install -q dalle-mini
%pip install -q git+https://github.com/patil-suraj/vqgan-jax.git

Looking in links: https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
ERROR: Ignored the following yanked versions: 0.4.32
ERROR: Could not find a version that satisfies the requirement jaxlib==0.3.25 (from versions: 0.4.17, 0.4.17+cuda11.cudnn86, 0.4.17+cuda12.cudnn89, 0.4.18, 0.4.18+cuda11.cudnn86, 0.4.18+cuda12.cudnn89, 0.4.19, 0.4.19+cuda11.cudnn86, 0.4.19+cuda12.cudnn89, 0.4.20, 0.4.20+cuda11.cudnn86, 0.4.20+cuda12.cudnn89, 0.4.21, 0.4.21+cuda11.cudnn86, 0.4.21+cuda12.cudnn89, 0.4.22, 0.4.22+cuda11.cudnn86, 0.4.22+cuda12.cudnn89, 0.4.23, 0.4.23+cuda11.cudnn86, 0.4.23+cuda12.cudnn89, 0.4.24, 0.4.24+cuda11.cudnn86, 0.4.24+cuda12.cudnn89, 0.4.25, 0.4.25+cuda11.cudnn86, 0.4.25+cuda12.cudnn89, 0.4.26, 0.4.26+cuda12.cudnn89, 0.4.27, 0.4.27+cuda12.cudnn89, 0.4.28, 0.4.28+cuda12.cudnn89, 0.4.29, 0.4.29+cuda12.cudnn91, 0.4.30, 0.4.31, 0.4.33, 0.4.34, 0.4.35, 0.4.36, 0.4.38, 0.5.0)
ERROR: No matching distribution found for jaxlib==0.3.

We load required models:
* DALL·E mini for text to encoded images
* VQGAN for decoding images
* CLIP for scoring predictions
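
To give a feel for the sizes involved, here is a small arithmetic sketch. It assumes that `f16` in the VQGAN repository name means a 16x spatial downsampling and that `16384` is the codebook size; the 256-pixel image side comes from the reshape used later in this notebook.

```python
image_size = 256       # pixels per side, from reshape((-1, 256, 256, 3)) later in the notebook
downsampling = 16      # assumption: "f16" in the VQGAN repo name
codebook_size = 16384  # assumption: trailing number in the VQGAN repo name

grid = image_size // downsampling  # side of the token grid
tokens_per_image = grid * grid     # tokens the text-to-image model must produce

print(f"{grid}x{grid} grid, {tokens_per_image} tokens, each one of {codebook_size} codes")
```

Under these assumptions, each image is represented by a short sequence of discrete codes that the VQGAN turns back into pixels.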

In [6]:
# Model references

# dalle-mega (larger model; uncomment the next line to use it instead of dalle-mini)
#DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"  # can be wandb artifact or 🤗 Hub or local folder or google bucket
DALLE_COMMIT_ID = None

# dalle-mini (smaller model, enabled here; use it if the notebook crashes too often with dalle-mega)
DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"

# VQGAN model
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

In [7]:
import jax
import jax.numpy as jnp

# check how many devices are available
jax.local_device_count()

AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)

In [8]:
# Load models & tokenizer
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
from transformers import CLIPProcessor, FlaxCLIPModel

# Load dalle-mini
model, params = DalleBart.from_pretrained(
    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False
)

# Load VQGAN
vqgan, vqgan_params = VQModel.from_pretrained(
    VQGAN_REPO, revision=VQGAN_COMMIT_ID, _do_init=False
)

2025-02-04 16:28:21.687100: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738682901.696273   27583 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738682901.699013   27583 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)

Model parameters are replicated on each device for faster inference.

In [9]:
from flax.jax_utils import replicate

params = replicate(params)
vqgan_params = replicate(vqgan_params)

AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)
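
As an illustration of what replication does to array shapes, here is a NumPy-only sketch. `replicate_like` is a hypothetical helper, not part of Flax: it only mimics the extra leading device axis, whereas the real `replicate` also places each copy on a device.

```python
import numpy as np


def replicate_like(tree, n_devices):
    """Stack each array of a flat dict n_devices times along a new leading axis."""
    return {name: np.stack([value] * n_devices) for name, value in tree.items()}


# Hypothetical parameters standing in for the model weights
toy_params = {"kernel": np.ones((4, 2)), "bias": np.zeros(2)}
toy_replicated = replicate_like(toy_params, n_devices=2)

print({name: value.shape for name, value in toy_replicated.items()})
```

Every array gains a leading axis whose length equals the number of devices.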

Model functions are compiled and parallelized to take advantage of multiple devices.

In [None]:
from functools import partial


# model inference
@partial(jax.pmap, axis_name="batch", static_broadcasted_argnums=(3, 4, 5, 6))
def p_generate(
    tokenized_prompt, key, params, top_k, top_p, temperature, condition_scale
):
    return model.generate(
        **tokenized_prompt,
        prng_key=key,
        params=params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=condition_scale,
    )


# decode image
@partial(jax.pmap, axis_name="batch")
def p_decode(indices, params):
    return vqgan.decode_code(indices, params=params)
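
The calling convention used above can be sketched without JAX. The emulation below is an assumption-laden illustration, not `jax.pmap` itself: `emulate_pmap` is a hypothetical helper that runs sequentially, slicing mapped arguments along their leading (device) axis and passing static arguments unchanged to every call.

```python
import numpy as np


def emulate_pmap(fn, static_broadcasted_argnums=()):
    """Call fn once per slice of the leading axis; static arguments are passed as-is."""

    def wrapped(*args):
        mapped = [a for i, a in enumerate(args) if i not in static_broadcasted_argnums]
        n_devices = len(mapped[0])
        outputs = []
        for d in range(n_devices):
            call_args = [
                a if i in static_broadcasted_argnums else a[d]
                for i, a in enumerate(args)
            ]
            outputs.append(fn(*call_args))
        # results are stacked back along a leading device axis
        return np.stack(outputs)

    return wrapped


def scale_and_shift(x, weights, shift):
    return x * weights + shift


p_scale = emulate_pmap(scale_and_shift, static_broadcasted_argnums=(2,))
x = np.arange(6.0).reshape(2, 3)                      # one row per "device"
weights = np.stack([np.ones(3), 2 * np.ones(3)])      # one weight vector per "device"
out = p_scale(x, weights, 10.0)                       # 10.0 is shared by all "devices"
print(out)
```

This is why the parameters and prompts are replicated first: each mapped argument needs one slice per device, while values such as `top_k` are shared.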

A separate random key is passed to the model on each device so that every device generates different images.

In [None]:
import random

# create a random key
seed = random.randint(0, 2**32 - 1)
key = jax.random.PRNGKey(seed)

## 🖍 Text Prompt

Prompts must be preprocessed before they are passed to the model.

In [None]:
from dalle_mini import DalleBartProcessor

processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)

Let's define some text prompts.

In [None]:
prompts = [
    "Bird perched on a branch in a sunny park",
    "Two sparrows hopping on a grassy field",
    "Hawk flying across a clear blue sky",
    "Seagull resting on a sandy beach",
    "Robin sitting on a fence in spring",
    "Finch perched on a tree in autumn",
    "Swallow swooping over a calm lake",
    "Bluebird flying above a green meadow",
    "Woodpecker tapping on a tree trunk",
    "Chickadee sitting on a snow-covered branch",
    "Small bird drinking from a birdbath",
    "Geese walking along a quiet path",
    "Kingfisher perched on a dock post",
    "Crow standing on a rock by the river",
    "Pigeon sitting on a windowsill",
    "Sparrows flying through a garden",
    "Hummingbird hovering near a blooming flower",
    "Pair of swans gliding on a pond",
    "Morning dove sitting on a rooftop",
    "Yellow bird eating seeds on a ledge"
]

Note: we could use the same prompt multiple times for faster inference.
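
A minimal sketch of repeating prompts, assuming each list entry is treated as an independent prompt by the rest of the notebook; `base_prompts` and `copies` are illustrative names.

```python
base_prompts = [
    "Hawk flying across a clear blue sky",
    "Pair of swans gliding on a pond",
]
copies = 2  # hypothetical number of repetitions per prompt

# Each prompt appears `copies` times, so one generation pass yields several samples of it
repeated_prompts = [prompt for prompt in base_prompts for _ in range(copies)]
print(repeated_prompts)
```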

In [None]:
tokenized_prompts = processor(prompts)

Finally we replicate the prompts onto each device.

In [None]:
tokenized_prompt = replicate(tokenized_prompts)

## 🎨 Generate images

We generate images using the DALL·E mini model and decode them with the VQGAN.

In [None]:
# number of predictions per prompt
n_predictions = 8

# We can customize generation parameters (see https://huggingface.co/blog/how-to-generate)
gen_top_k = None
gen_top_p = None
temperature = None
cond_scale = 10.0
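
To illustrate what these parameters control, here is a generic NumPy sketch of temperature, top-k and top-p filtering as described in the linked blog post. It is not the library's internal implementation. The `apply_condition_scale` function assumes that `cond_scale` follows the usual classifier-free guidance form, which is an assumption about the technique rather than something stated in this notebook.

```python
import numpy as np


def filter_logits(logits, top_k=None, top_p=None, temperature=None):
    """Return sampling probabilities after temperature, top-k and top-p filtering."""
    logits = np.asarray(logits, dtype=np.float64)
    if temperature is not None:
        logits = logits / temperature  # < 1 sharpens, > 1 flattens the distribution
    if top_k is not None:
        kth_largest = np.sort(logits)[-top_k]
        logits = np.where(logits < kth_largest, -np.inf, logits)
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    if top_p is not None:
        order = np.argsort(probs)[::-1]
        cumulative = np.cumsum(probs[order])
        # keep tokens until the probability mass before them reaches top_p
        keep = cumulative - probs[order] < top_p
        mask = np.zeros(probs.shape, dtype=bool)
        mask[order[keep]] = True
        probs = np.where(mask, probs, 0.0)
        probs /= probs.sum()
    return probs


def apply_condition_scale(cond_logits, uncond_logits, scale):
    """Assumed guidance form: move away from the unconditional logits by `scale`."""
    cond_logits = np.asarray(cond_logits, dtype=np.float64)
    uncond_logits = np.asarray(uncond_logits, dtype=np.float64)
    return uncond_logits + scale * (cond_logits - uncond_logits)


example_logits = [2.0, 1.0, 0.0, -1.0]
print(filter_logits(example_logits, top_k=2))
print(filter_logits(example_logits, top_p=0.5))
```

Leaving a parameter as `None` skips the corresponding step in this sketch.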

In [None]:
from flax.training.common_utils import shard_prng_key
import numpy as np
from PIL import Image
from tqdm.notebook import trange

print(f"Prompts: {prompts}\n")
# generate images
images = []
for i in trange(max(n_predictions // jax.device_count(), 1)):
    # get a new key
    key, subkey = jax.random.split(key)
    # generate images
    encoded_images = p_generate(
        tokenized_prompt,
        shard_prng_key(subkey),
        params,
        gen_top_k,
        gen_top_p,
        temperature,
        cond_scale,
    )
    # remove BOS
    encoded_images = encoded_images.sequences[..., 1:]
    # decode images
    decoded_images = p_decode(encoded_images, vqgan_params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    for decoded_img in decoded_images:
        img = Image.fromarray(np.asarray(decoded_img * 255, dtype=np.uint8))
        images.append(img)
        display(img)
        print()
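
The order in which images are appended matters for the later cells, which index `images` by prompt. The sketch below reproduces that layout with small stand-in tags instead of real images; the sizes are hypothetical.

```python
import numpy as np

n_devices, n_prompts, n_rounds = 2, 3, 2  # hypothetical sizes

layout = []  # each entry is a (round, device, prompt) tag standing in for an image
for r in range(n_rounds):
    # leading axes (devices, prompts) mimic the decoded batch returned per round
    batch = np.array(
        [[(r, d, q) for q in range(n_prompts)] for d in range(n_devices)]
    )
    # flatten the device and prompt axes, as reshape((-1, 256, 256, 3)) does
    for tag in batch.reshape(-1, 3):
        layout.append(tuple(int(v) for v in tag))

# the image at index k belongs to prompt k % n_prompts
prompt_of_image = [tag[2] for tag in layout]
images_for_prompt_1 = layout[1::n_prompts]
print(prompt_of_image)
```

With this layout, slicing with a step equal to the number of prompts collects all images of one prompt.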

## 🏅 Optional: Rank images by CLIP score

We can rank the generated images by their CLIP score against each prompt.

**Note: your session may crash if you don't have a subscription to Colab Pro.**

In [None]:
# CLIP model
CLIP_REPO = "openai/clip-vit-base-patch32"
CLIP_COMMIT_ID = None

# Load CLIP
clip, clip_params = FlaxCLIPModel.from_pretrained(
    CLIP_REPO, revision=CLIP_COMMIT_ID, dtype=jnp.float16, _do_init=False
)
clip_processor = CLIPProcessor.from_pretrained(CLIP_REPO, revision=CLIP_COMMIT_ID)
clip_params = replicate(clip_params)


# score images
@partial(jax.pmap, axis_name="batch")
def p_clip(inputs, params):
    logits = clip(params=params, **inputs).logits_per_image
    return logits

In [None]:
from flax.training.common_utils import shard

# get clip scores
clip_inputs = clip_processor(
    text=prompts * jax.device_count(),
    images=images,
    return_tensors="np",
    padding="max_length",
    max_length=77,
    truncation=True,
).data
logits = p_clip(shard(clip_inputs), clip_params)

# organize scores per prompt
p = len(prompts)
logits = np.asarray([logits[:, i::p, i] for i in range(p)]).squeeze()
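
The indexing in the last line is easier to follow on a toy array. This sketch assumes a single device and fills a fake logits array with recognizable values (10 × image index + text index); the names and sizes are illustrative only.

```python
import numpy as np

n_prompts, n_per_prompt = 3, 2  # hypothetical sizes, single device
n_images = n_prompts * n_per_prompt

# fake logits with shape (devices, images, texts)
toy_logits = np.array(
    [[[10 * img + txt for txt in range(n_prompts)] for img in range(n_images)]]
)

p = n_prompts
# for prompt i: take every p-th image starting at i, scored against text i
per_prompt = np.asarray([toy_logits[:, i::p, i] for i in range(p)]).squeeze()
print(per_prompt)
```

Row `i` of the result holds the scores of the images generated for prompt `i`, each compared with that same prompt.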

Let's now display images ranked by CLIP score.

In [None]:
for i, prompt in enumerate(prompts):
    print(f"Prompt: {prompt}\n")
    for idx in logits[i].argsort()[::-1]:
        display(images[idx * p + i])
        print(f"Score: {jnp.asarray(logits[i][idx], dtype=jnp.float32):.2f}\n")
    print()
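
A small sketch of the ranking step with made-up scores, showing how a descending `argsort` maps back to positions in the flat `images` list.

```python
import numpy as np

scores = np.array([0.21, 0.87, 0.55])  # made-up CLIP scores for one prompt
n_prompts, prompt_index = 4, 1         # hypothetical values

order = scores.argsort()[::-1]  # indices of the samples, best score first
image_indices = [int(idx) * n_prompts + prompt_index for idx in order]
print(order.tolist(), image_indices)
```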

## 🪄 Optional: Save your Generated Images as W&B Tables

W&B Tables is an interactive 2D grid that supports rich media logging. Use it to save the generated images to the W&B dashboard and share them with the world.
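
The table built in the next cell has one row per prompt and one column per prediction. The following sketch shows that shape with plain lists and string stand-ins for images; it does not call the W&B API, and all values are hypothetical.

```python
toy_n_predictions = 3  # hypothetical
toy_prompts = ["Seagull resting on a sandy beach", "Pigeon sitting on a windowsill"]

# stand-in "images", ordered like the generated list (prompt index cycles fastest)
toy_images = [
    f"img(prompt={k % len(toy_prompts)}, sample={k // len(toy_prompts)})"
    for k in range(toy_n_predictions * len(toy_prompts))
]

columns = ["captions"] + [f"image_{i+1}" for i in range(toy_n_predictions)]
rows = [
    [prompt, *toy_images[i :: len(toy_prompts)]]
    for i, prompt in enumerate(toy_prompts)
]
print(columns)
print(rows)
```

Each row must have as many entries as there are columns, so the number of images per prompt has to match `n_predictions`.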

In [None]:
import wandb

# Initialize a W&B run.
project = "dalle-mini-tables-colab"
run = wandb.init(project=project)

# Initialize an empty W&B Tables.
columns = ["captions"] + [f"image_{i+1}" for i in range(n_predictions)]
gen_table = wandb.Table(columns=columns)

# Add data to the table.
for i, prompt in enumerate(prompts):
    # Images for this prompt (every len(prompts)-th image, starting at i)
    tmp_imgs = images[i :: len(prompts)]
    # If CLIP scores exist (the optional step above was run), sort by descending score
    if globals().get("logits") is not None:
        idxs = logits[i].argsort()[::-1]
        tmp_imgs = [tmp_imgs[idx] for idx in idxs]

    # Add the data to the table.
    gen_table.add_data(prompt, *[wandb.Image(img) for img in tmp_imgs])

# Log the Table to W&B dashboard.
wandb.log({"Generated Images": gen_table})

# Close the W&B run.
run.finish()

Click on the link above to check out your generated images.