## Intro to Stable Diffusion and Ray

Let's start with a gentle introduction to using Stable Diffusion and Ray.

<div class="alert alert-block alert-info">
<b> Here is the roadmap for this notebook:</b>
<ul>
    <li><b>Part 1:</b> A simple data pipeline</li>
    <li><b>Part 2:</b> Introduction to Ray Data</li>
    <li><b>Part 3:</b> Batch Inference with Stable Diffusion</li>
    <li><b>Part 4:</b> Stable Diffusion under the hood</li>
</ul>
</div>

## Imports

In [None]:
import os
import uuid
import json
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import ray
import torch
from art import text2art
from diffusers import DiffusionPipeline

## A simple data pipeline

Let's begin with a very simple data pipeline that converts text into ASCII art.

We start with a simple dataset of items:

In [None]:
items = [
    "Astronaut", "Cat"
]

We can then apply a transformation to each item in the dataset to convert the text into ASCII art:

In [None]:
def artify(item: str) -> str:
    return text2art(item)

We will sequentially apply the `artify` function to each item in the dataset:

In [None]:
data = []
for item in items:
    data.append({"prompt": item, "art": artify(item)})

We can now inspect the results:

In [None]:
data[0]["prompt"]

In [None]:
print(data[0]["art"])

Finally, we can write the data to a JSON file:

In [None]:
with open("ascii_art.json", "w") as f:
    json.dump(data, f)

## Introduction to Ray Data

<!-- One liner about Ray Data -->
Ray Data is a scalable data processing library for ML workloads such as offline batch inference and data preprocessing for training.


<!-- Diagram showing streaming and heterogenous cluster -->
Ray Data is particularly useful for streaming data on a heterogeneous cluster:

<img src="https://docs.ray.io/en/latest/_images/stream-example.png" width="600">

Your production pipeline for generating images from text could require:
1. Loading a large number of text prompts
2. Generating images using large-scale diffusion models
3. Running inference with guardrail models to remove low-quality and NSFW images

You will want to make the most efficient use of your cluster to process this data. Ray Data can help you do this.

### Ray Data's API

Here are the steps to make use of Ray Data:
1. Create a Ray Dataset, usually by pointing to a data source.
2. Apply transformations to the Ray Dataset.
3. Write out the results to a data source.



### Loading Data

Ray Data has a number of [IO connectors](https://docs.ray.io/en/latest/data/api/input_output.html) for the most commonly used formats.

For purposes of this introduction, we will use the `from_items` function to create a dataset from a list of items.

In [None]:
ds_items = ray.data.from_items(items)
ds_items
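
When `from_items` receives plain values rather than dictionaries, Ray Data places each value in a column named `item`, which is why the later cells read `row["item"]`. A rough plain-Python picture of the resulting rows follows. It is an illustration only, not how Ray stores data internally, and `as_rows` is a hypothetical helper:

```python
from typing import Any


def as_rows(values: list[Any]) -> list[dict[str, Any]]:
    """Mimic the row layout: one dict per value, keyed by "item"."""
    return [{"item": value} for value in values]


rows = as_rows(["Astronaut", "Cat"])
print(rows)
```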

### Transforming Data

Datasets can be transformed by applying a row-wise `map` operation. We do this by providing a user-defined function that takes a row as input and returns a row as output.

In [None]:
def artify_row(row: dict[str, Any]) -> dict[str, Any]:
    row["art"] = text2art(row["item"])
    return row

ds_items_artified = ds_items.map(artify_row)

### Lazy execution

By default, `map` is lazy, meaning that it will not actually execute the function until you consume it. This allows for optimizations like pipelining and fusing of operations.
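
As a loose analogy in plain Python (not Ray's implementation), a generator behaves the same way: building it does no work, and the function only runs once something consumes the results. `shout` and `lazy_map` are hypothetical names used for illustration:

```python
from typing import Callable, Iterable, Iterator

calls: list[str] = []


def shout(text: str) -> str:
    calls.append(text)  # record that the function actually ran
    return text.upper()


def lazy_map(fn: Callable[[str], str], values: Iterable[str]) -> Iterator[str]:
    """Return a generator: fn is not called until a consumer asks for items."""
    return (fn(value) for value in values)


pending = lazy_map(shout, ["astronaut", "cat"])
calls_before = list(calls)  # still empty: no work has happened yet
results = list(pending)     # consuming the generator triggers execution
print(calls_before, results)
```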

To inspect a few rows of the dataset, you can use the `take` method:

In [None]:
sample = ds_items_artified.take(2)

Let's inspect the sample:

In [None]:
print(sample[0]["item"])

In [None]:
print(sample[0]["art"])

### Writing Data

We can then write the data to disk using the available [IO connector methods](https://docs.ray.io/en/latest/data/api/input_output.html).

Here we will write the data as JSON files to a shared storage location.

In [None]:
ds_items_artified.write_json("/mnt/cluster_storage/ascii_art")

We can now inspect the written files:

In [None]:
!ls /mnt/cluster_storage/ascii_art
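
To read such output back without Ray, a sketch like the following can work. It assumes each written file holds one JSON object per line (JSON Lines), which we take to be the `write_json` default. `read_json_lines_dir` is a hypothetical helper, and the demo writes its own temporary file rather than touching `/mnt/cluster_storage`:

```python
import json
import tempfile
from pathlib import Path
from typing import Any


def read_json_lines_dir(directory: str) -> list[dict[str, Any]]:
    """Collect rows from every *.json file in a directory, one JSON object per line."""
    rows: list[dict[str, Any]] = []
    for path in sorted(Path(directory).glob("*.json")):
        with open(path) as f:
            rows.extend(json.loads(line) for line in f if line.strip())
    return rows


# Demo on a temporary directory standing in for the real output location.
with tempfile.TemporaryDirectory() as tmp_dir:
    sample_file = Path(tmp_dir) / "part-0.json"
    sample_file.write_text(
        '{"item": "Astronaut", "art": "..."}\n{"item": "Cat", "art": "..."}\n'
    )
    rows = read_json_lines_dir(tmp_dir)

print([row["item"] for row in rows])
```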

### Recap of our Ray Data pipeline

Here is our Ray Data pipeline condensed into the following chained operations:

```python
(
    ray.data.from_items(items)
    .map(artify_row)
    .write_json("/mnt/cluster_storage/ascii_art")
)
```

## Batch Inference with Stable Diffusion

Now that we have a simple data pipeline, let's use Stable Diffusion to generate actual images from text.

This will follow a very similar pattern. Let's say we are starting out with the following prompts:


In [None]:
prompts = [
    "An astronaut on a horse",
    "A cat with a jetpack",
] * 12

We create a Ray Dataset from the prompts:

In [None]:
ds_prompts = ray.data.from_items(prompts)
ds_prompts

We now want to apply a `DiffusionPipeline` to the dataset.

We first define a function that creates and applies the pipeline to a single row.

In [None]:
def apply_stable_diffusion(row: dict[str, Any]) -> dict[str, Any]:
    # Create the stable diffusion pipeline
    pipe = DiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path="stabilityai/stable-diffusion-2",
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    ).to("cuda")
    prompt = row["item"]
    # Apply the pipeline to the prompt
    output = pipe(prompt, height=512, width=512)
    # Extract the image from the output and construct the row
    return {"item": prompt, "image": output.images[0]}

We can now apply the function to each row in the dataset using the `map` method.

In [None]:
ds_images_generated_mapping_by_row = ds_prompts.map(
    apply_stable_diffusion,
    num_gpus=1, # specify the number of GPUs per task
) 

Instead of parallelizing the inference per row, we can parallelize the inference per batch.

Mapping over batches instead of rows is useful when we can benefit from vectorized operations on the batch level. 
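
As a sketch of the difference in input shape, the snippet below calls a row function and a batch function directly on hand-built inputs, without Ray. It assumes the batch arrives as a dictionary of NumPy arrays, one array per column; `prompt_length` and `prompt_length_batch` are hypothetical functions:

```python
from typing import Any

import numpy as np


def prompt_length(row: dict[str, Any]) -> dict[str, Any]:
    """Row-wise: one prompt in, one result out."""
    return {"item": row["item"], "length": len(row["item"])}


def prompt_length_batch(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Batch-wise: a whole column of prompts is processed in one call."""
    # Cast to a string dtype so the vectorized routine also accepts object arrays.
    lengths = np.char.str_len(batch["item"].astype(str))
    return {"item": batch["item"], "length": lengths}


prompts = ["An astronaut on a horse", "A cat with a jetpack"]
row_results = [prompt_length({"item": p}) for p in prompts]
batch_result = prompt_length_batch({"item": np.array(prompts)})
print(batch_result["length"])
```

Both versions compute the same values; the batch version simply makes one call per batch instead of one call per row.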

In [None]:
def apply_stable_diffusion_batch(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    pipe = DiffusionPipeline.from_pretrained(
        pretrained_model_name_or_path="stabilityai/stable-diffusion-2",
        torch_dtype=torch.float16,
        use_safetensors=True,
        variant="fp16",
    ).to("cuda")
    # Extract the prompts from the batch
    prompts = batch["item"].tolist()
    # Apply the pipeline to the prompts
    outputs = pipe(prompts, height=512, width=512)
    # Extract the images from the outputs and construct the batch
    return {"item": prompts, "image": outputs.images}

We now apply the function to each batch in the dataset using the `map_batches` method.

In [None]:
ds_images_generated_mapping_by_batch = ds_prompts.map_batches(
    apply_stable_diffusion_batch,
    batch_size=24, # specify the batch size per task to maximize GPU utilization
    num_gpus=1, 
)

The current implementation requires us to load the pipeline for each batch we process.

We can avoid reloading the pipeline for each batch by creating a stateful transformation, implemented as a callable class where:
- `__init__`: runs once per worker process to load the pipeline, which is then reused for every batch that worker handles.
- `__call__`: applies the pipeline to the batch and returns the transformed batch.

In [None]:
class StableDiffusion:
    def __init__(self, model_id: str = "stabilityai/stable-diffusion-2") -> None:
        self.pipe = DiffusionPipeline.from_pretrained(
            model_id, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
        ).to("cuda")

    def __call__(
        self, batch: dict[str, np.ndarray], img_size: int = 512
    ) -> dict[str, np.ndarray]:
        prompts = batch["item"].tolist()
        batch["image"] = self.pipe(prompts, height=img_size, width=img_size).images
        return batch
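
The same pattern in miniature, with no GPU or model required: the toy class below counts how often its setup runs, to show that construction happens once while `__call__` runs per batch. `Uppercaser` is a hypothetical stand-in, and the loop imitates by hand what a single worker would do:

```python
class Uppercaser:
    """Toy stand-in for a model-backed transform; counts how often setup runs."""

    setup_calls = 0  # class-level counter, only for demonstration

    def __init__(self) -> None:
        # Expensive one-time setup (for example, loading model weights) goes here.
        Uppercaser.setup_calls += 1
        self.suffix = "!"

    def __call__(self, batch: dict[str, list[str]]) -> dict[str, list[str]]:
        # Per-batch work reuses whatever __init__ prepared.
        batch["shouted"] = [text.upper() + self.suffix for text in batch["item"]]
        return batch


worker = Uppercaser()  # constructed once
batches = [{"item": ["astronaut"]}, {"item": ["cat"]}, {"item": ["horse"]}]
outputs = [worker(batch) for batch in batches]  # called once per batch
print(Uppercaser.setup_calls, [out["shouted"] for out in outputs])
```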

We can now apply the class to each batch in the dataset using the same `map_batches` method.

In [None]:
ds_images_generated_by_stateful_transform = ds_prompts.map_batches(
    StableDiffusion,
    batch_size=24,
    num_gpus=1,  
    concurrency=1,  # number of workers to launch
)

<div class="alert alert-block alert-info">

### Activity: Visualize the generated images

Let's fetch a batch of the generated images to the driver and visualize them.

Use the `plot_images` function to visualize the images.

```python
def plot_images(batch: dict[str, np.ndarray]) -> None:
    for item, image in zip(batch["item"], batch["image"]):
        plt.imshow(image)
        plt.title(item)
        plt.axis("off")
        plt.show()

# Hint: Implement the code below to fetch a batch from 
# ds_images_generated_by_stateful_transform
batch = ...
plot_images(batch)
```

</div>


In [None]:
# Write your solution here


<div class="alert alert-block alert-info">
<details>

<summary>Click to expand/collapse</summary>

```python
def plot_images(batch: dict[str, np.ndarray]) -> None:
    for item, image in zip(batch["item"], batch["image"]):
        plt.imshow(image)
        plt.title(item)
        plt.axis("off")
        plt.show()

size = 12
batch = ds_images_generated_by_stateful_transform.take_batch(batch_size=size)
plot_images(batch)
```

</details>

</div>



### Reading/Writing to a Data Lake

In a production setting, you will typically build a Ray Dataset lazily by reading from a data source such as a data lake (S3, GCS, HDFS, etc.).

To demonstrate this, let's create a unique artifact path under the shared cluster storage.

In [None]:
uuid_str = str(uuid.uuid4())
artifact_path = f"/mnt/cluster_storage/stable-diffusion/{uuid_str}"
artifact_path

We start by writing the prompts as JSON files to a directory:

In [None]:
ds_prompts.write_json(artifact_path + "/prompts")

We can inspect the written files:

In [None]:
!ls {artifact_path}/prompts/ --human-readable 

Here is what the pipeline looks like when we read the prompts from storage, generate images, and write the images back to storage. The same pattern applies when the paths point to S3:

In [None]:
(
    ray.data.read_json(artifact_path + "/prompts")
    .map_batches(StableDiffusion, batch_size=24, num_gpus=1, concurrency=1)
    .write_parquet(artifact_path + "/images")
)

<div class="alert alert-block alert-warning">

<b>Note</b> how there is no need to explicitly materialize the dataset. Instead, the data is streamed through the pipeline and written to the specified location.

</div>

In [None]:
!ls {artifact_path}/images/ --human-readable

## Stable Diffusion under the hood

Let's take a quick look at the components of the Stable Diffusion pipeline.

First we load the pipeline on our local workspace node:

In [None]:
model_id = "stabilityai/stable-diffusion-2"
pipeline = DiffusionPipeline.from_pretrained(
    model_id, torch_dtype=torch.float16, use_safetensors=True, variant="fp16"
)

Inspecting the text tokenizer and encoder shows how the text will be preprocessed:

In [None]:
type(pipeline.tokenizer), type(pipeline.text_encoder)

Inspecting the feature extractor and the VAE shows the components that handle images. The VAE maps between pixel space and the latent space the model works in:

In [None]:
type(pipeline.feature_extractor), type(pipeline.vae)

Here is our main model, the U-Net, which predicts the noise present in the latent image:

In [None]:
type(pipeline.unet)

While the U-Net is used to predict which part of the image is noise, a scheduler is needed to sample the noise level.

By default, diffusers will use the following scheduler, but other schedulers can be used as well.

In [None]:
type(pipeline.scheduler)
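
To make the interplay between the noise predictor and the scheduler concrete, here is a toy NumPy sketch of a denoising loop. It is heavily simplified and rests on several assumptions: a linear noise schedule, a deterministic DDIM-style update (which may differ from the scheduler printed above), no text conditioning, and an "oracle" function standing in for the U-Net that already knows the clean target. All names are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# alpha_bar[t] is the share of signal variance that remains at timestep t.
betas = np.linspace(1e-4, 0.02, 1000)
alpha_bar = np.cumprod(1.0 - betas)

target = rng.normal(size=(4, 8, 8))  # stand-in for the "clean" latent


def oracle_noise_prediction(latent: np.ndarray, t: int) -> np.ndarray:
    """Stand-in for the U-Net: the noise that turns `target` into `latent` at step t."""
    return (latent - np.sqrt(alpha_bar[t]) * target) / np.sqrt(1.0 - alpha_bar[t])


def ddim_step(latent: np.ndarray, noise_pred: np.ndarray, t: int, t_prev: int) -> np.ndarray:
    """Deterministic DDIM-style update from timestep t to t_prev."""
    ab_t = alpha_bar[t]
    ab_prev = alpha_bar[t_prev] if t_prev >= 0 else 1.0
    # Estimate the clean latent, then re-noise it to the lower noise level.
    clean_estimate = (latent - np.sqrt(1.0 - ab_t) * noise_pred) / np.sqrt(ab_t)
    return np.sqrt(ab_prev) * clean_estimate + np.sqrt(1.0 - ab_prev) * noise_pred


timesteps = list(range(999, -1, -50))   # 20 inference steps: 999, 949, ..., 49
latent = rng.normal(size=target.shape)  # start from pure noise
for i, t in enumerate(timesteps):
    t_prev = timesteps[i + 1] if i + 1 < len(timesteps) else -1
    noise_pred = oracle_noise_prediction(latent, t)       # role of the U-Net
    latent = ddim_step(latent, noise_pred, t, t_prev)     # role of the scheduler

print(np.abs(latent - target).max())
```

Because the stand-in predictor is exact, the loop lands on the target latent; a real U-Net only approximates the noise, which is why the number of steps and the choice of scheduler affect image quality.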

Here is the inference data flow of the Stable Diffusion model simplified for generating an image of "A person half Yoda and half Gandalf":

<figure>
  <img src="https://www.paepper.com/blog/posts/everything-you-need-to-know-about-stable-diffusion/stable-diffusion-inference.png" alt="Inference data flow of Stable Diffusion" width="800"/>
  <figcaption>Image taken from <a href="https://www.paepper.com/blog/posts/everything-you-need-to-know-about-stable-diffusion/">Everything you need to know about stable diffusion</a></figcaption>
</figure>


## Clean up

In [None]:
!rm -rf /mnt/cluster_storage/ascii_art
!rm -rf {artifact_path}
!rm ascii_art.json