## Set up the dependencies
When running on a distributed Ray cluster, every node needs access to the same dependencies. To ensure this, we'll use `pip install --user` to install the necessary requirements. On an Anyscale Workspace, `--user` installs are configured to go to a shared filesystem that is available to all nodes in the cluster.

`pip install --user -r requirements.txt`  

After installing all the requirements, we'll start with some imports.

In [None]:
# Install the notebook's requirements into the user site-packages. On an
# Anyscale Workspace this location is on a shared filesystem, so every node
# in the Ray cluster sees the same packages.
!pip install --user -r requirements.txt

In [None]:
import ray.data
import pandas as pd
import torch
import requests
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionImg2ImgPipeline
import time

In [2]:
# Diffusers checkpoint to load: a Hugging Face Hub ID. A local snapshot
# directory (for example one saved under /mnt/shared_storage) can be used
# instead so the weights are downloaded only once and reused by all nodes.
model_id = "nitrosocke/Ghibli-Diffusion"
# Device on which the pipeline and its random generator are placed.
device = "cuda"
class PredictCallable:
    """Callable class that runs Stable Diffusion img2img over batches of images.

    Intended for ``Dataset.map_batches(PredictCallable, compute="actors", ...)``.
    The pipeline is loaded once in ``__init__`` and reused for every batch
    passed to ``__call__``.

    Args:
        model_id: Hugging Face Hub ID or local directory of the diffusers
            checkpoint, passed to ``StableDiffusionImg2ImgPipeline.from_pretrained``.
        prompt: Text prompt applied to every input image.
        revision: Optional model revision (branch, tag or commit id) forwarded
            to ``from_pretrained``. ``None`` uses the library default.
        strength: How strongly each input image is transformed, in [0, 1].
            The default 1.0 (as originally hard-coded) lets the model repaint
            the input almost entirely. Lower values keep more of the source.
        guidance_scale: Classifier-free guidance scale (originally hard-coded 7.5).
        seed: Seed for the per-batch random generator. ``None`` keeps the
            original behaviour of seeding from the current Unix time.
    """

    def __init__(
        self,
        model_id: str,
        prompt: str,
        revision: str = None,
        strength: float = 1.0,
        guidance_scale: float = 7.5,
        seed: int = None,
    ):
        # ``revision`` used to be accepted but never forwarded, so callers
        # asking for a specific revision silently got the default one.
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_id,
            revision=revision,
            torch_dtype=torch.float16,
        ).to(device)
        self.prompt = prompt
        self.strength = strength
        self.guidance_scale = guidance_scale
        self.seed = seed

    def __call__(self, batch: pd.DataFrame) -> list:
        """Generate one output image per row of ``batch``.

        Args:
            batch: Batch with an ``image`` column (arrays accepted by
                ``PIL.Image.fromarray``) and a ``path`` column identifying
                each source image. Presumably a pandas DataFrame produced by
                Ray Data; confirm against the Ray version in use.

        Returns:
            A single-element list containing a dict that maps each input path
            to its generated PIL image.
        """
        seed = int(time.time()) if self.seed is None else self.seed
        # One generator per batch: successive images in the batch draw
        # successive random states from it.
        generator = torch.Generator(device=device).manual_seed(seed)
        out_images = []
        for img in list(batch["image"]):
            result = self.pipe(
                prompt=self.prompt,
                image=Image.fromarray(img),
                strength=self.strength,
                guidance_scale=self.guidance_scale,
                generator=generator,
            )
            out_images.append(result.images[0])
        return [dict(zip(list(batch["path"]), out_images))]

In [None]:
import ray
from ray.data.datasource.partitioning import Partitioning
# Imagenette validation split in a public S3 bucket; the ``anonymous@``
# prefix requests unauthenticated access.
s3_uri = "s3://anonymous@air-example-data-2/imagenette2/val/"

# Objects are laid out as {s3_uri}/{class_id}/{*.JPEG}. Directory-style
# partitioning exposes the class_id directory name as a "class" field.
partitioning = Partitioning(
    "dir",
    base_dir=s3_uri,
    field_names=["class"],
)

# Decode each image as RGB at 768x768 and keep its source location in a
# "path" column, which the prediction step uses to key its outputs.
ds = ray.data.read_images(
    s3_uri,
    include_paths=True,
    mode="RGB",
    partitioning=partitioning,
    size=(768, 768),
)

In [None]:
# Run img2img on a small 6-image sample, 3 images per batch.
# compute="actors" runs PredictCallable inside long-lived Ray actors so the
# model is loaded once per actor. num_gpus=0.5 reserves half a GPU per actor,
# allowing two actors to be scheduled on one device.
inp_images = ds.limit(6)
preds = inp_images.map_batches(
    PredictCallable,
    batch_size=3,
    compute="actors",
    num_gpus=0.5,
    fn_constructor_kwargs={"model_id": model_id, "prompt": "ghibli style fish"},
)
# Execute the pipeline and pull every result back to the driver.
res = preds.take_all()

In [None]:
def display_sdimg2img(res, max_items=20):
    """Print and render the img2img results collected by ``take_all()``.

    Args:
        res: Sequence of dicts, each mapping a source image path to its
            generated image (the shape emitted by ``PredictCallable``).
        max_items: Images are rendered only when ``res`` has fewer than this
            many entries, which keeps the notebook output manageable. For
            larger inputs a notice is printed instead of silently rendering
            nothing.

    Relies on the ``display`` function that Jupyter/IPython provides in the
    notebook namespace.
    """
    print(len(res))
    if len(res) >= max_items:
        print(f"Skipping image display: {len(res)} results (limit {max_items}).")
        return
    for result in res:
        for path, image in result.items():
            print(path)
            display(image)
# Render the generated images collected by take_all() in the previous cell.
display_sdimg2img(res=res)