# Pre-processing for Stable Diffusion V2

Let's build a scalable preprocessing pipeline for the Stable Diffusion V2 model.

<div class="alert alert-block alert-info">
<b> Here is the roadmap for this notebook:</b>
<ul>
    <li><b>Part 0:</b> High-level overview of the preprocessing pipeline</li>
    <li><b>Part 1:</b> Reading in the data</li>
    <li><b>Part 2:</b> Transforming images</li>
    <li><b>Part 3:</b> Tokenizing the text captions</li>
    <li><b>Part 4:</b> Encoding images and captions</li>
    <li><b>Part 5:</b> Writing out the preprocessed data</li>
</ul>
</div>

## Imports

In [None]:
import os
import gc
import uuid
import io
import logging
from typing import Optional, Any

import matplotlib.pyplot as plt
import numpy as np
import pyarrow as pa  # type: ignore
import ray.data
import torch
import torchvision  # type: ignore
from diffusers.models import AutoencoderKL
from PIL import Image
from transformers import CLIPTextModel, CLIPTokenizer  # type: ignore

### 0. High-level overview of the preprocessing pipeline

Here is a high-level overview of the preprocessing pipeline:

<img src="https://anyscale-materials.s3.us-west-2.amazonaws.com/stable-diffusion/preprocessing_architecture_v4.jpeg" width="900px">

Ray Data loads the data from a remote storage system, then streams it through two main processing stages:
1. **Transformation**
   1. Cropping and normalizing images.
   2. Tokenizing the text captions using a CLIP tokenizer.
2. **Encoding**
   1. Compressing images into a latent space using a VAE encoder.
   2. Generating text embeddings using a CLIP model.


### 1. Reading in the data

We're going to preprocess part of the LAION-art-8M dataset. To save time, we have provided a sample of the dataset on S3.

We'll read this sample data and create a Ray dataset from it.

In [None]:
schema = pa.schema(
    [
        pa.field("caption", pa.string()),
        pa.field("height", pa.float64()),
        pa.field("width", pa.float64()),
        pa.field("jpg", pa.binary()),
    ]
)

ds = ray.data.read_parquet(
    "s3://anyscale-public-materials/ray-summit/stable-diffusion/data/raw/",
    schema=schema,
)

ds

Running that step does not process the whole dataset. Ray Data executes pipelines lazily, so no rows are transformed until a downstream step asks for them.

Ray does, however, inspect the input to determine metadata such as the number of files and the data schema.
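
As a rough analogy for lazy execution, the sketch below uses plain Python generators (not Ray Data's actual machinery) to build a pipeline that does no work until a consumer asks for rows. The names `read_rows`, `expensive_transform`, and `calls` are hypothetical and exist only for this illustration.

```python
from itertools import islice

calls = []  # records which rows were actually processed


def read_rows(n):
    """Yield n fake rows one at a time (stand-in for a lazy read)."""
    for i in range(n):
        yield {"id": i}


def expensive_transform(rows):
    """Lazily transform rows; work happens only when a row is requested."""
    for row in rows:
        calls.append(row["id"])
        yield {**row, "squared": row["id"] ** 2}


# Building the pipeline processes nothing yet.
pipeline = expensive_transform(read_rows(1_000_000))
processed_before = len(calls)

# Pulling two rows (similar in spirit to `take(2)`) processes only two rows.
first_two = list(islice(pipeline, 2))
print(processed_before, len(calls), first_two)
```

Only the rows that are actually requested get transformed, which is why defining a pipeline over a large dataset returns almost immediately.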

### 2. Transforming images and captions

#### 2.1 Cropping and normalizing images
We start by preprocessing the images, which requires two operations:
1. Crop the images to a square aspect ratio.
2. Normalize the pixel values to the distribution expected by the VAE encoder.


#### Step 1. Cropping the image

In [None]:
class LargestCenterSquare:
    """Largest center square crop for images."""

    def __init__(self, size: int) -> None:
        self.size = size

    def __call__(self, img: Image.Image) -> Image.Image:
        """Crop the largest center square from an image."""
        # First, resize the image such that the smallest
        # side is self.size while preserving aspect ratio.
        img = torchvision.transforms.functional.resize(
            img=img,
            size=self.size,
        )

        # Then take a center crop to a square.
        w, h = img.size
        c_top = (h - self.size) // 2
        c_left = (w - self.size) // 2
        img = torchvision.transforms.functional.crop(
            img=img,
            top=c_top,
            left=c_left,
            height=self.size,
            width=self.size,
        )
        return img

In [None]:
resolution = 512
crop = LargestCenterSquare(resolution)

Let's take a simple example to visualize how the crop function works.

In [None]:
ds_example = ds.filter(lambda row: row["caption"] == 'strawberry-lemonmousse-cake-3')
example_image = ds_example.take(1)[0]
image = Image.open(io.BytesIO(example_image["jpg"]))
image

In [None]:
crop(image)
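
The geometry behind the crop can be worked through with plain arithmetic. The sketch below is a pure-Python approximation, assuming the resize step sets the shorter side to `size` and scales the longer side proportionally, truncated to an integer. The helper `center_square_box` and the example dimensions are hypothetical.

```python
def center_square_box(width, height, size):
    """Return (resized_w, resized_h, left, top) for a center-square crop."""
    short, long = min(width, height), max(width, height)
    # Assumption: the short side becomes `size`; the long side is scaled
    # proportionally and truncated to an integer.
    new_long = int(size * long / short)
    new_w, new_h = (size, new_long) if width <= height else (new_long, size)
    # Center the size x size window inside the resized image.
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return new_w, new_h, left, top


print(center_square_box(1024, 768, 512))  # landscape -> (682, 512, 85, 0)
print(center_square_box(600, 900, 512))   # portrait  -> (512, 768, 0, 128)
```

For a landscape image the crop trims equal strips from the left and right; for a portrait image it trims from the top and bottom.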

#### Step 2. Normalizing the image

We need to normalize the pixel values to the distribution expected by the VAE encoder.

The VAE encoder expects pixel values in the range [-1, 1].

After `ToTensor`, our pixel values are in the range [0, 1], with 0.5 as the midpoint.

To map them to [-1, 1], we subtract 0.5 from each pixel value and divide by 0.5.

In [None]:
normalize = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

In [None]:
normalized = normalize(crop(image))

normalized.min(), normalized.max()
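
To make the arithmetic concrete, here is a minimal pure-Python sketch of the same per-pixel mapping, `(value - 0.5) / 0.5`. The helper names `normalize_pixel` and `denormalize_pixel` are hypothetical, and the sketch assumes 8-bit inputs are first scaled to [0, 1] by dividing by 255.

```python
def normalize_pixel(value, mean=0.5, std=0.5):
    """Map a pixel in [0, 1] to [-1, 1] via (value - mean) / std."""
    return (value - mean) / std


def denormalize_pixel(value, mean=0.5, std=0.5):
    """Invert the normalization, e.g. to display an image again."""
    return value * std + mean


for raw in (0, 128, 255):
    scaled = raw / 255  # assumption: 8-bit input scaled to [0, 1] first
    print(raw, round(normalize_pixel(scaled), 4))
```

The darkest pixel maps to -1, the brightest to 1, and mid-gray lands close to 0.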

#### Putting it together into a single transform function

We define a `transform_images` function below to crop and normalize the images.

In [None]:
def convert_tensor_to_array(tensor: torch.Tensor, dtype=np.float32) -> np.ndarray:
    """Convert a torch tensor to a numpy array."""
    array = tensor.detach().cpu().numpy()
    return array.astype(dtype)


def transform_images(row: dict[str, Any]) -> list[dict[str, Any]]:
    """Transform image to a square-sized normalized tensor."""
    try:
        image = Image.open(io.BytesIO(row["jpg"]))
    except Exception as e:
        logging.error(f"Error opening image: {e}")
        return []

    if image.mode != "RGB":
        image = image.convert("RGB")

    image = crop(image)
    normalized_image_tensor = normalize(image)

    row[f"image_{resolution}"] = convert_tensor_to_array(normalized_image_tensor)
    return [row]
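
Returning a list is what lets `transform_images` drop unreadable images: an empty list contributes no rows, while a one-element list keeps the row. The sketch below illustrates that flattening behavior in plain Python. It is not Ray Data's implementation, and `flat_map`, `keep_valid`, and the sample rows are hypothetical stand-ins.

```python
def flat_map(rows, fn):
    """Apply fn to each row and flatten the returned lists into one list."""
    return [out for row in rows for out in fn(row)]


def keep_valid(row):
    """Return [row] for usable rows and [] for rows that should be dropped."""
    if row.get("jpg") is None:  # hypothetical stand-in for a failed decode
        return []
    return [row]


rows = [{"id": 0, "jpg": b"..."}, {"id": 1, "jpg": None}, {"id": 2, "jpg": b"..."}]
kept = flat_map(rows, keep_valid)
print([row["id"] for row in kept])  # the row with id 1 is dropped
```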

<div class="alert alert-block alert-warning">

Note how `transform_images` references the `crop` and `normalize` objects. Those outer-scope objects are serialized and shipped along with the remote function definition.

In this case they are tiny. In other cases (say, a 16 GB model) we would not want to rely on this scoping behavior and would instead use other mechanisms to make those objects available to the workers.

</div>
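
One such mechanism, used later in this notebook for the encoders, is a callable class whose constructor builds the expensive state once and whose `__call__` reuses it for every batch. The sketch below shows the pattern in plain Python; `HeavyResourceTransform` and its lookup table are hypothetical stand-ins for a real model.

```python
class HeavyResourceTransform:
    """Callable that builds its expensive state once, in the constructor."""

    load_count = 0  # counts how many times the "model" was loaded

    def __init__(self):
        # Hypothetical stand-in for loading a multi-gigabyte model.
        HeavyResourceTransform.load_count += 1
        self.lookup = {i: i * i for i in range(1000)}

    def __call__(self, batch):
        # Reuses the already-loaded state for every batch.
        return [self.lookup[x] for x in batch]


transform = HeavyResourceTransform()  # constructed once, e.g. once per worker
results = [transform(batch) for batch in ([1, 2], [3, 4], [5, 6])]
print(results, HeavyResourceTransform.load_count)
```

The state is loaded a single time no matter how many batches pass through, instead of being serialized alongside every task.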

Now we call `flat_map` to apply the `transform_images` function to each row in the dataset.

In [None]:
ds_img_transformed = ds.flat_map(transform_images)

ds_img_transformed

What happened to our schema?

* `flat_map` is lazy: calling it did not process any data. Because `transform_images` may change the structure of the records, Ray cannot know the resulting schema until the pipeline actually runs.

To inspect the output during development or debugging, we can run the pipeline on a small part of the data using `take`.

In [None]:
image_transformed = ds_img_transformed.take(2)[1]
(
    image_transformed["image_512"].shape,
    image_transformed["image_512"].min(),
    image_transformed["image_512"].max(),
)

### 3. Tokenize the text captions

Now we'll want to tokenize the text captions using a CLIP tokenizer.

Let's load a text tokenizer and inspect its behavior.

In [None]:
text_tokenizer = CLIPTokenizer.from_pretrained(
    "stabilityai/stable-diffusion-2-base", subfolder="tokenizer"
)

Let's call the tokenizer on a simple string to get the token ids and tokens.

In [None]:
token_ids = text_tokenizer("strawberry-lemonmousse-cake-3")["input_ids"]
token_ids

In [None]:
tokens = text_tokenizer.convert_ids_to_tokens(token_ids)
tokens

We can now define a function that tokenizes a batch of text captions.

In [None]:
def tokenize_text(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Tokenize the caption."""
    batch["caption_ids"] = text_tokenizer(
        batch["caption"].tolist(),
        padding="max_length",
        max_length=text_tokenizer.model_max_length,
        truncation=True,
        return_tensors="np",
    )["input_ids"]
    return batch
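
The `padding="max_length"` and `truncation=True` options make every caption produce the same number of token ids, so the batch forms a rectangular array. The sketch below shows the idea in plain Python with made-up ids. `pad_or_truncate` is a hypothetical helper, and it ignores details of real tokenizers, which typically also preserve special tokens such as the end-of-text marker when truncating.

```python
def pad_or_truncate(token_ids, max_length, pad_id):
    """Return token_ids cut or right-padded to exactly max_length entries."""
    clipped = token_ids[:max_length]
    return clipped + [pad_id] * (max_length - len(clipped))


# Hypothetical token ids; real ids come from the tokenizer's vocabulary.
short_caption = pad_or_truncate([101, 7, 8, 102], max_length=6, pad_id=0)
long_caption = pad_or_truncate(list(range(10)), max_length=6, pad_id=0)
print(short_caption)  # padded on the right
print(long_caption)   # cut to the first six ids
```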

In [None]:
ds_img_txt_transformed = ds_img_transformed.map_batches(tokenize_text)
example_txt_transformed = ds_img_txt_transformed.filter(
    lambda row: row["caption"] == "strawberry-lemonmousse-cake-3"
).take(1)[0]
print(example_txt_transformed["caption"])
example_txt_token_ids = example_txt_transformed["caption_ids"]
example_txt_token_ids

#### Understanding Ray Data's Operator Fusion


Inspecting the execution plan of the dataset so far, we see:

```
Execution plan of Dataset: 
InputDataBuffer[Input] 
-> TaskPoolMapOperator[ReadParquet]
-> TaskPoolMapOperator[FlatMap(transform_images)->MapBatches(tokenize_text)]
-> LimitOperator[limit=2]
```

Note how the `transform_images` and `tokenize_text` functions are fused into a single operator.

This is an optimization that Ray Data performs to reduce the number of times data must be serialized and deserialized between Python processes.

Without this fusion, we would be better off writing a single `transform_images_and_text` function that combines the image and text transformations, so that the data is serialized and deserialized fewer times.
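
Conceptually, fusing two stages amounts to composing them so that each block passes through both in a single call, with no hand-off in between. The sketch below is a plain-Python illustration of that idea, not Ray Data's optimizer; `fuse`, `double`, and `add_one` are hypothetical.

```python
def fuse(*stages):
    """Compose stages so each block flows through all of them in one call."""

    def fused(block):
        for stage in stages:
            block = stage(block)
        return block

    return fused


def double(block):
    return [x * 2 for x in block]


def add_one(block):
    return [x + 1 for x in block]


fused_stage = fuse(double, add_one)
blocks = [[1, 2], [3, 4]]
print([fused_stage(block) for block in blocks])  # [[3, 5], [7, 9]]
```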

### 4. Encode images and captions

We'll compress images into a latent space using a VAE encoder and generate text embeddings using a CLIP model.

In [None]:
class SDImageEncoder:
    def __init__(self, model_name: str, device: torch.device) -> None:
        # Normalize to a torch.device so the CUDA check works for both
        # strings such as "cuda" and torch.device objects.
        self.device = torch.device(device)
        self.vae = AutoencoderKL.from_pretrained(
            model_name,
            subfolder="vae",
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
        ).to(self.device)

    def encode_images(self, images: np.ndarray) -> np.ndarray:
        """Encode normalized images into scaled VAE latents."""
        input_images = torch.tensor(images, device=self.device)
        if self.device.type == "cuda":
            # Match the half-precision weights used on GPU.
            input_images = input_images.half()
        latent_dist = self.vae.encode(input_images)["latent_dist"]
        image_latents = latent_dist.sample() * 0.18215
        return convert_tensor_to_array(image_latents)

Let's run the image encoder against the sample image we have.

In [None]:
image_encoder = SDImageEncoder("stabilityai/stable-diffusion-2-base", "cpu")
image_latents = image_encoder.encode_images(transform_images(example_image)[0]["image_512"][None])[0]

Let's plot the image latents.

In [None]:
nchannels = image_latents.shape[0]
fig, axes = plt.subplots(1, nchannels, figsize=(10, 10))

for idx, ax in enumerate(axes):
    ax.imshow(image_latents[idx], cmap="gray")
    ax.set_title(f"Channel {idx}")
    ax.axis("off")

fig.suptitle("Image Latents", fontsize=16, x=0.5, y=0.625)
plt.show()
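
A quick calculation shows how much the encoder compresses each image. The sketch assumes the Stable Diffusion VAE downsamples each spatial side by a factor of 8 and produces 4 latent channels; `latent_shape` is a hypothetical helper, and the true values should be read from `image_latents.shape`.

```python
def latent_shape(resolution, downsample=8, latent_channels=4):
    """Return the (channels, height, width) of the latent for a square image."""
    side = resolution // downsample
    return (latent_channels, side, side)


shape = latent_shape(512)
image_values = 3 * 512 * 512                      # RGB input values
latent_values = shape[0] * shape[1] * shape[2]    # values in the latent
ratio = image_values / latent_values
print(shape, ratio)  # (4, 64, 64) 48.0
```

Under these assumptions, each 512x512 image is represented by 48 times fewer values than the original pixels.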

Next, let's encode the text using the CLIP model.

In [None]:
class SDTextEncoder:
    def __init__(self, model_name: str, device: torch.device) -> None:
        # Normalize to a torch.device so the CUDA check works for both
        # strings such as "cuda" and torch.device objects.
        self.device = torch.device(device)
        self.text_encoder = CLIPTextModel.from_pretrained(
            model_name,
            subfolder="text_encoder",
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
        ).to(self.device)

    def encode_text(self, caption_ids: np.ndarray) -> np.ndarray:
        """Encode text captions into a latent space."""
        caption_ids_tensor = torch.tensor(caption_ids, device=self.device)
        caption_latents_tensor = self.text_encoder(caption_ids_tensor)[0]
        return convert_tensor_to_array(caption_latents_tensor)

In [None]:
encoder = SDTextEncoder("stabilityai/stable-diffusion-2-base", "cpu")
example_text_embedding = encoder.encode_text([example_txt_token_ids])[0]
example_text_embedding.shape

Because Ray Data doesn't support operator fusion between two different stateful transformations, we define a single `SDLatentSpaceEncoder` transformation that combines the image and text encoders.

In [None]:
class SDLatentSpaceEncoder:
    def __init__(
        self,
        resolution: int = 512,
        device: Optional[str] = "cuda",
        model_name: str = "stabilityai/stable-diffusion-2-base",
    ) -> None:
        self.device = torch.device(device)
        self.resolution = resolution

        # Instantiate image and text encoders
        self.image_encoder = SDImageEncoder(model_name, self.device)
        self.text_encoder = SDTextEncoder(model_name, self.device)

    def __call__(self, batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
        with torch.no_grad():
            # Step 1: Encode images.
            input_images = batch[f"image_{self.resolution}"]
            image_latents = self.image_encoder.encode_images(input_images)
            batch[f"image_latents_{self.resolution}"] = image_latents

            del batch[f"image_{self.resolution}"]
            gc.collect()

            # Step 2: Encode captions.
            caption_ids = batch["caption_ids"]
            batch["caption_latents"] = self.text_encoder.encode_text(caption_ids)

            del batch["caption_ids"]
            gc.collect()

        return batch

<div class="alert alert-block alert-warning">

<b>Note</b> how we delete the original image and <code>caption_ids</code> arrays from the batch to free up memory. This is important when working with large datasets.

</div>
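
A rough size estimate shows why dropping the input images matters. The sketch assumes `float32` arrays, a batch of 24 images at resolution 512, and latents with 4 channels at one eighth of the spatial resolution; `nbytes` is a hypothetical helper.

```python
def nbytes(shape, bytes_per_value):
    """Return the size in bytes of a dense array with the given shape."""
    total = bytes_per_value
    for dim in shape:
        total *= dim
    return total


batch_size, resolution = 24, 512
image_bytes = nbytes((batch_size, 3, resolution, resolution), 4)  # float32
latent_bytes = nbytes((batch_size, 4, resolution // 8, resolution // 8), 4)
print(image_bytes / 2**20, latent_bytes / 2**20)  # sizes in MiB: 72.0 1.5
```

Under these assumptions the input images take 72 MiB per batch against 1.5 MiB for the latents, so keeping both in the output would carry mostly redundant data.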

We apply the encoder to the dataset to encode the images and text captions.

In [None]:
ds_encoded = ds_img_txt_transformed.map_batches(
    SDLatentSpaceEncoder,
    concurrency=2,  # Total number of workers
    num_gpus=1,  # number of GPUs per worker
    batch_size=24,  # Use the largest batch size that can fit on our GPUs - depends on resolution
)

ds_encoded
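
The resource settings above can be sanity-checked with simple arithmetic: the stage reserves `concurrency * num_gpus` GPUs in total. The sketch below also estimates batch counts for a hypothetical row count. The real number of batches depends on how Ray Data splits rows into blocks, and `encoding_plan` is an illustrative helper, not part of any library.

```python
import math


def encoding_plan(num_rows, batch_size, concurrency, gpus_per_worker):
    """Estimate GPU usage and batch counts for a batched encoding stage."""
    num_batches = math.ceil(num_rows / batch_size)
    return {
        "total_gpus": concurrency * gpus_per_worker,
        "num_batches": num_batches,
        # Assumes batches are spread evenly across workers.
        "batches_per_worker": math.ceil(num_batches / concurrency),
    }


# num_rows=1000 is a made-up figure for illustration.
plan = encoding_plan(num_rows=1000, batch_size=24, concurrency=2, gpus_per_worker=1)
print(plan)
```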

### 5. Write outputs to parquet

Finally, we can write the output.

We write the output as Parquet files to a unique path in the artifact store.

In [None]:
uuid_str = str(uuid.uuid4())
artifact_path = f"/mnt/cluster_storage/stable-diffusion/{uuid_str}"
artifact_path

This operation requires physically moving the data, so it will trigger scheduling and execution of all of the upstream tasks.

In [None]:
ds_encoded.write_parquet(artifact_path)