In [None]:
from diffusers import AutoPipelineForText2Image
from transformers import pipeline
from transformers.utils import logging
import numpy as np
import random
import ray
import torch
logging.set_verbosity_info()

First, we need to get all of our data into a common location that the whole cluster can see, such as a blob store, NFS, or a database.

Anyscale provides `/mnt/cluster_storage` as an NFS path.

In [None]:
! cp *.csv /mnt/cluster_storage/

Ray Data's `read_*` methods (see the Input/Output page of the Ray docs for all available formats and data sources) provide scalable, parallel reads.

In [None]:
animals = ray.data.read_csv('/mnt/cluster_storage/animals.csv')

animals.take_batch(3)

Batches of records are represented as Python dicts: the keys are the dataset's column names, and each value is a vectorized column (usually a NumPy array) holding one value per record in the batch.
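
To make that format concrete, here is a minimal hand-built dict in the same shape as a batch returned by `take_batch`: one key per column, one NumPy array per key, all arrays the same length. The `id` column and the sample values are hypothetical stand-ins, not the actual contents of `animals.csv`.

```python
import numpy as np

# Hypothetical stand-in for a batch such as animals.take_batch(3):
# one key per column, one NumPy array per column, one element per record.
batch = {
    "id": np.array([0, 1, 2]),
    "animal": np.array(["cat", "dog", "owl"], dtype=object),
}

# Every column holds the same number of records.
num_rows = len(batch["id"])
assert all(len(col) == num_rows for col in batch.values())

# Record i is spread across the columns at index i.
second_record = {name: col[1] for name, col in batch.items()}
print(num_rows, second_record)
```

Working column by column like this is what lets `map_batches` functions use vectorized operations instead of looping over rows.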

Ray Data also provides methods for basic data transformations, including changes to the dataset schema such as renaming columns.

In [None]:
animals.rename_columns({'animal' : 'prompt'}).take_batch(3)

Stateful transformations of datasets (in this example, AI inference, where the state is the image generation model) follow this pattern:

1. Define a Python class, which Ray will later instantiate across the cluster as one or more actor instances to do the processing.
1. Use the Dataset `map_batches` API to tell Ray to send batches of data to the `__call__` method of those actor instances.
    1. `map_batches` lets us specify resource requirements, actor pool size, batch size, and more.
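
The contract this pattern relies on can be sketched in plain Python: expensive state is built once in `__init__`, and `__call__` then runs once per batch. This is a minimal single-process simulation, assuming a hypothetical `FakeModel` in place of the diffusion pipeline and a hypothetical `run_in_batches` helper that only mimics splitting rows into batches of `batch_size`. It does not reflect how Ray actually schedules actors (with `concurrency=2`, Ray would create two such instances on the cluster).

```python
import numpy as np


class FakeModel:
    """Hypothetical stand-in for an expensive model; counts how often it is loaded."""
    loads = 0

    def __init__(self):
        FakeModel.loads += 1

    def __call__(self, prompts):
        return [p.upper() for p in prompts]


class Predictor:
    def __init__(self):
        # Expensive setup happens once per instance (once per actor in Ray).
        self.model = FakeModel()

    def __call__(self, batch):
        # Per-batch work reuses the already-loaded model.
        batch["output"] = np.array(self.model(batch["prompt"]), dtype=object)
        return batch


def run_in_batches(fn_cls, data, batch_size):
    """Hypothetical helper: slice a dict of columns into batches and feed one instance."""
    worker = fn_cls()
    n = len(next(iter(data.values())))
    results = []
    for start in range(0, n, batch_size):
        batch = {k: v[start:start + batch_size] for k, v in data.items()}
        results.append(worker(batch))
    return results


data = {"prompt": np.array(["cat", "dog", "owl", "fox", "emu"], dtype=object)}
out = run_in_batches(Predictor, data, batch_size=2)
print(FakeModel.loads, [len(b["prompt"]) for b in out])
```

The model is loaded once while three batches (of 2, 2, and 1 rows) are processed, which is why the GPU model belongs in `__init__` rather than in `__call__`.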

In [None]:
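class ImageGen():
    def __init__(self):
        # Runs once per actor: load SDXL-Turbo in half precision onto this actor's GPU
        self.pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16").to("cuda")

    def gen_image(self, prompts):
        # SDXL-Turbo is meant to sample in a single step with guidance disabled;
        # list() turns the NumPy column of prompts into a plain list of strings
        return self.pipe(prompt=list(prompts), num_inference_steps=1, guidance_scale=0.0).images

    def __call__(self, batch):
        # Runs once per batch: add an 'image' column with one PIL image per prompt
        batch['image'] = self.gen_image(batch['prompt'])
        return batch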

In [None]:
animals_images = animals.repartition(2).rename_columns({'animal' : 'prompt'}).map_batches(ImageGen, num_gpus=1, concurrency=2, batch_size=8)

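Ray Datasets use *lazy evaluation* to improve performance: transformations like `map_batches` only build an execution plan, and nothing runs until the data is consumed. During development and testing, we can use APIs like `take_batch`, `take`, or `show` to trigger execution and return a small sample.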
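
Python generators offer a rough single-machine analogy for this behavior: building the pipeline does no work, and asking for a few rows computes only what is needed. This is only an analogy, assuming a hypothetical `expensive_step` placeholder; Ray Data's planner and executor are far more involved and typically work at block granularity, so they may compute more rows than you asked for.

```python
from itertools import islice

calls = 0


def expensive_step(x):
    """Hypothetical placeholder for a costly per-row transformation."""
    global calls
    calls += 1
    return x * 10


rows = range(1_000)

# Building the pipeline: map() is lazy, so nothing has executed yet.
pipeline = map(expensive_step, rows)
calls_before = calls

# Consuming 3 rows (similar in spirit to take(3)) runs the step only 3 times.
first_three = list(islice(pipeline, 3))
print(calls_before, calls, first_three)
```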

In [None]:
examples = animals_images.take_batch(3)

examples

In [None]:
examples['image'][0]

## Lab: Generate and write all output to storage as Parquet data

Instructions/hints:

1. Start with the Ray Dataset you'd like to write
1. Check https://docs.ray.io/en/latest/data/api/input_output.html to find a suitable write API
1. Remember to write to a *shared* file location, such as `/mnt/cluster_storage`

<div class="alert alert-block alert-info">

<details>

<summary> Click to see solution </summary>

```python
animals_images.write_parquet('/mnt/cluster_storage/animals_images.parquet/')
```
</details>
</div>


## Load and join details for each prompt

Ray Data supports joining datasets, with several join types such as inner and outer joins: https://docs.ray.io/en/latest/data/joining-data.html

We can use a JOIN to connect each animal record with a detailed prompt refinement unique to that record.
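
An inner join keeps only the records whose key appears on both sides and merges their columns. The sketch below shows those semantics in plain Python as a simple hash join; the shared `id` key and the sample rows are hypothetical, and Ray Data performs the equivalent operation in a distributed way rather than with a single in-memory dict.

```python
def inner_join(left, right, key):
    """Inner-join two lists of row dicts on `key`: index the right side, probe with the left."""
    index = {}
    for row in right:
        index.setdefault(row[key], []).append(row)
    joined = []
    for row in left:
        # Rows with no match on the right are dropped (that is what makes it "inner").
        for match in index.get(row[key], []):
            joined.append({**row, **match})
    return joined


# Hypothetical rows; the real CSV contents and key column may differ.
animal_rows = [{"id": 0, "animal": "cat"}, {"id": 1, "animal": "dog"}, {"id": 2, "animal": "owl"}]
outfit_rows = [{"id": 0, "outfit": "top hat"}, {"id": 2, "outfit": "raincoat"}]

result = inner_join(animal_rows, outfit_rows, "id")
print(result)
```

Here the `dog` row has no matching outfit, so it does not appear in the result.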

In [None]:
outfits = ray.data.read_csv('/mnt/cluster_storage/outfits.csv')

outfits.take_batch(3)

In [None]:
animals_outfits = animals.join(outfits, 'inner', 1).repartition(8)

animals_outfits.take_batch(3)

We can add custom logic that combines columns into an expanded image generation prompt with another call to `map_batches`.

Because this transformation is stateless and lightweight, we can define it as a plain Python function (which takes a batch of records and returns a batch) and use a simplified call to `map_batches`; Ray then scales the number of scheduled tasks to keep the pipeline's throughput high.
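
Statelessness is what makes this safe to spread across many tasks: the output depends only on the input batch, so batches can be processed in any order, by any worker, with identical results. A minimal sketch, assuming a hypothetical `add_prompt` function and hypothetical sample batches:

```python
import numpy as np


def add_prompt(batch):
    """Hypothetical stateless transform: the output depends only on the input batch."""
    batch["prompt"] = batch["animal"] + " wearing a " + batch["outfit"]
    return batch


# Hypothetical sample rows split into two batches. Object dtype means `+`
# applies Python string concatenation element-wise.
batches = [
    {"animal": np.array(["cat", "dog"], dtype=object), "outfit": np.array(["top hat", "scarf"], dtype=object)},
    {"animal": np.array(["owl"], dtype=object), "outfit": np.array(["raincoat"], dtype=object)},
]

# Process the batches front-to-back and back-to-front (dict() copies each batch
# so the inputs are not mutated); the per-batch results are identical.
forward = [list(add_prompt(dict(b))["prompt"]) for b in batches]
backward = [list(add_prompt(dict(b))["prompt"]) for b in reversed(batches)][::-1]
print(forward)
```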

In [None]:
def expand_prompt(batch):
    batch['prompt'] = batch['animal'] + ' wearing a ' + batch['outfit']
    return batch

In [None]:
animals_outfits.map_batches(expand_prompt).take_batch(3)

Chaining the prompt expansion step with the image generation step produces a new set of results.

In [None]:
dressed_animals = animals_outfits.map_batches(expand_prompt).map_batches(ImageGen, batch_size=16, concurrency=2, num_gpus=1)

In [None]:
examples = dressed_animals.take_batch(3)
examples

In [None]:
examples['prompt'][0]

In [None]:
examples['image'][0]

## Lab: Generate images for the input prompts and write the images to a folder

> Hint 1: Use `dataset.write_images(...)`
>
> Hint 2: To use `dataset.write_images(...)`, the images will need to be NumPy arrays (instead of PIL Image objects). You can use `np.array(my_pil_image)` to do that conversion. Use that API along with `map_batches` to convert all of your images prior to calling `write_images`

<div class="alert alert-block alert-info">

<details>

<summary> Click to see solution </summary>

```python
def image_to_array(batch):
    batch['image'] = [np.array(i) for i in batch['image']]
    return batch
    
animals_images.map_batches(image_to_array).write_images('/mnt/cluster_storage/animals_images/', 'image')
```

</details>
</div>