In [None]:
#install dependencies

!pip install "ray==2.6.1"
!pip install torch --no-cache-dir
!pip install "ray[serve]" requests diffusers transformers fastapi==0.96

In [None]:
import ray

# Connect to the Ray cluster at the `ray://` address below and ask it to
# install the listed pip packages as the runtime environment.
worker_pip_packages = [
    "IPython",
    "boto3==1.26",
    "botocore==1.29",
    "datasets",
    "diffusers",
    "fastapi==0.96",
    "accelerate>=0.16.0",
    "transformers>=4.26.0",
    # NOTE(review): this pin was meant to be removed once mlflow moves beyond
    # 2.2, but mlflow is not in this list; confirm the pin is still needed.
    "numpy<1.24",
    "torch",
]

ray.init(
    address="ray://example-cluster-kuberay-head-svc:10001",
    runtime_env={"pip": worker_pip_packages},
)

In [None]:
from io import BytesIO
from fastapi import FastAPI
from fastapi.responses import Response
import torch

from ray import serve


# FastAPI application; APIIngress registers its HTTP routes on this object.
app = FastAPI()


@serve.deployment(num_replicas=1, route_prefix="/")
@serve.ingress(app)
class APIIngress:
    """HTTP front end that forwards image requests to the diffusion model.

    Runs as a single-replica Serve deployment mounted at ``/`` and exposes
    the routes registered on the module-level FastAPI ``app``.
    """

    def __init__(self, diffusion_model_handle) -> None:
        """Store the handle used to call the diffusion-model deployment."""
        self.handle = diffusion_model_handle

    @app.get(
        "/imagine",
        responses={200: {"content": {"image/png": {}}}},
        response_class=Response,
    )
    async def generate(self, prompt: str, img_size: int = 512):
        """Generate an image for ``prompt`` and return it as a PNG response.

        Args:
            prompt: Text to generate an image from; must be non-empty.
            img_size: Forwarded to the model's ``generate`` as ``img_size``.

        Returns:
            A ``Response`` whose body is the PNG-encoded image and whose
            media type is ``image/png``.

        Raises:
            HTTPException: With status 400 when ``prompt`` is empty.
        """
        # Function-scope import, like the diffusers import in the model class.
        from fastapi import HTTPException

        # Report bad input as a client error. An ``assert`` would reach the
        # caller as a 500 and is removed entirely under ``python -O``.
        if not prompt:
            raise HTTPException(
                status_code=400, detail="prompt parameter cannot be empty"
            )

        # Two awaits: the handle call is awaited to obtain ``image_ref``, and
        # ``image_ref`` is awaited in turn to obtain the image itself.
        image_ref = await self.handle.generate.remote(prompt, img_size=img_size)
        image = await image_ref

        # Encode into an in-memory buffer that is closed once the bytes are read.
        with BytesIO() as file_stream:
            image.save(file_stream, "PNG")
            png_bytes = file_stream.getvalue()
        return Response(content=png_bytes, media_type="image/png")


@serve.deployment(
    ray_actor_options={"num_gpus": 1},
    autoscaling_config={"min_replicas": 0, "max_replicas": 2},
)
class StableDiffusionV2:
    """Serve deployment wrapping a Stable Diffusion text-to-image pipeline.

    Each replica requests one GPU; autoscaling is configured with a minimum
    of 0 and a maximum of 2 replicas.
    """

    def __init__(self):
        """Load the pipeline with an Euler discrete scheduler and move it to CUDA."""
        # Imported inside the constructor rather than at module level.
        from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

        repo_id = "stabilityai/stable-diffusion-2"

        euler_scheduler = EulerDiscreteScheduler.from_pretrained(
            repo_id, subfolder="scheduler"
        )
        pipeline = StableDiffusionPipeline.from_pretrained(
            repo_id,
            scheduler=euler_scheduler,
            revision="fp16",
            torch_dtype=torch.float16,
        )
        self.pipe = pipeline.to("cuda")

    def generate(self, prompt: str, img_size: int = 512):
        """Run the pipeline on ``prompt`` and return its first image.

        ``img_size`` is passed to the pipeline as both ``height`` and
        ``width``. An empty ``prompt`` trips the assertion below.
        """
        assert len(prompt), "prompt parameter cannot be empty"

        pipeline_output = self.pipe(prompt, height=img_size, width=img_size)
        return pipeline_output.images[0]

In [None]:
# Bind the model deployment, pass it to the ingress as its constructor
# argument, then run the resulting application with host "0.0.0.0".
diffusion_model = StableDiffusionV2.bind()
deployment = APIIngress.bind(diffusion_model)
serve.run(deployment, host="0.0.0.0")

In [None]:
import requests

prompt = "a cute cat is dancing on the grass."

# Let requests build and encode the query string so that characters such as
# "&", "#" or "+" inside the prompt are escaped correctly.
# No timeout is set: the model deployment is configured with min_replicas 0,
# so the first request can be slow.
resp = requests.get(
    "http://example-cluster-kuberay-head-svc:8000/imagine",
    params={"prompt": prompt},
)
# Stop on an HTTP error instead of saving an error body under a .png name.
resp.raise_for_status()

with open("output.png", "wb") as f:
    f.write(resp.content)

In [None]:
from IPython.display import Image, display
# Show output.png (the file written by the request cell) in the notebook output.
display(Image(filename='output.png'))

In [None]:
from ray import serve

# Shut down Ray Serve once finished with the demo.
serve.shutdown()