In [None]:
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden Keras Stable Diffusion

<table align="left">
  <td>
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_keras_stable_diffusion.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Colab logo"> Run in Colab
    </a>
  </td>

  <td>
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_keras_stable_diffusion.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo">
      View on GitHub
    </a>
  </td>
  <td>
    <a href="https://console.cloud.google.com/vertex-ai/notebooks/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-samples/main/notebooks/community/model_garden/model_garden_keras_stable_diffusion.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo">
Open in Vertex AI Workbench
    </a>
  </td>
</table>

**_NOTE_**: This notebook has been tested in the following environment:

* Python version = 3.9

## Overview

This notebook demonstrates how to use [Keras Stable Diffusion](https://keras.io/api/keras_cv/models/stable_diffusion) in Vertex AI Model Garden.

### Objective

* Run inferences
  * Run inferences locally
  * Serve models with dockers
* Finetune models
  * Download data
  * Start training jobs

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI
pricing](https://cloud.google.com/vertex-ai/pricing) and [Cloud Storage
pricing](https://cloud.google.com/storage/pricing), and use the [Pricing
Calculator](https://cloud.google.com/products/calculator/)
to generate a cost estimate based on your projected usage.

## Before you begin

### Set up notebooks

- Colab notebook

  You can open this notebook directly in Colab.

- Workbench notebook

  You can open this notebook in a Vertex AI Workbench instance. You can create a [Google-managed](https://cloud.google.com/vertex-ai/docs/workbench/managed/create-instance) or a [user-managed](https://cloud.google.com/vertex-ai/docs/workbench/user-managed/create-new) Workbench instance.

Then, run the following commands to set up notebooks.

In [None]:
# Environment setup. The Colab-only branch upgrades the Vertex AI SDK, requests
# a kernel restart and authenticates the user; the keras-cv install at the
# bottom runs in every environment.
if "google.colab" in str(get_ipython()):
    # Configs for colab notebooks.
    ! pip3 install --upgrade google-cloud-aiplatform

    # Automatically restart kernel after installs
    # NOTE(review): the restart is requested *before* the authentication and
    # the keras-cv install further down in this same cell. Whether those
    # statements complete before the kernel goes down depends on the kernel
    # implementation -- confirm, or consider moving the restart to the very
    # end of the cell.
    import IPython

    app = IPython.Application.instance()
    app.kernel.do_shutdown(True)

    from google.colab import auth as google_auth

    # Authenticate the Colab user (used for the Google Cloud calls below).
    google_auth.authenticate_user()

# keras-cv is pinned to the version this notebook was written against.
! pip3 install keras-cv==0.4.1

### Set up your Google Cloud project

**The following steps are required, regardless of your notebook environment.**

1. [Select or create a Google Cloud project](https://console.cloud.google.com/cloud-resource-manager). When you first create an account, you get a $300 free credit towards your compute/storage costs.

1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

1. [Enable the Vertex AI API and Compute Engine API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com,compute_component).
1. If you are running this notebook locally, you will need to install the [Cloud SDK](https://cloud.google.com/sdk).

1. [Create a service account](https://cloud.google.com/iam/docs/service-accounts-create#iam-service-accounts-create-console) with `Vertex AI User` and `Storage Object Admin` roles for deploying fine tuned model to Vertex AI endpoint.

1. Enter your project ID in the cell below. Then run the cell to make sure the
Cloud SDK uses the right project for all the commands in this notebook.

**Note**: Jupyter runs lines prefixed with `!` as shell commands, and it interpolates Python variables prefixed with `$` into these commands.

In [None]:
import os

from google.cloud import aiplatform

# The project and bucket are for experiments below.
PROJECT_ID = ""  # @param {type:"string"}
# The form for BUCKET_URI is gs://<bucket-name>.
BUCKET_URI = ""  # @param {type:"string"}

# Region used to initialize the Vertex AI SDK below.
REGION = "us-central1"
# Point the gcloud CLI at the same project as the SDK.
! gcloud config set project $PROJECT_ID

# Locations derived from BUCKET_URI: SDK staging files, and the data / model
# folders used by the finetuning section.
# NOTE(review): os.path.join uses the host OS path separator; since these are
# gs:// URIs this assumes a POSIX host (Colab / Workbench) -- confirm if the
# notebook is ever run on Windows.
STAGING_BUCKET = os.path.join(BUCKET_URI, "temporal")
EXPERIMENT_BUCKET = os.path.join(BUCKET_URI, "keras")
DATA_BUCKET = os.path.join(EXPERIMENT_BUCKET, "data")
MODEL_BUCKET = os.path.join(EXPERIMENT_BUCKET, "model")

# Initialize the Vertex AI SDK once; later cells rely on these defaults.
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET)

# Training constants.
TRAINING_JOB_PREFIX = "train"
TRAIN_CONTAINER_URI = "us-docker.pkg.dev/vertex-ai-restricted/vertex-vision-model-garden-dockers/keras-train:latest"
TRAIN_MACHINE_TYPE = "a2-highgpu-1g"
TRAIN_ACCELERATOR_TYPE = "NVIDIA_TESLA_A100"
TRAIN_NUM_GPU = 1
# Image size (pixels) used for both width and height in local inference,
# serving and training.
RESOLUTION = 512

# Prediction constants.
PREDICTION_CONTAINER_URI = "us-docker.pkg.dev/vertex-ai-restricted/vertex-vision-model-garden-dockers/keras-serve:latest"
PREDICTION_ACCELERATOR_TYPE = "NVIDIA_TESLA_V100"
PREDICTION_MACHINE_TYPE = "n1-standard-8"
DEPLOY_JOB_PREFIX = "deploy"

# The service account for deploying fine tuned model.
# The service account looks like:
# '<account_name>@<project>.iam.gserviceaccount.com'
# Please go to https://cloud.google.com/iam/docs/service-accounts-create#iam-service-accounts-create-console
# and create service account with `Vertex AI User` and `Storage Object Admin` roles.
SERVICE_ACCOUNT = ""  # @param {type:"string"}

### Define common libraries

In [None]:
import base64
from datetime import datetime
from io import BytesIO

import matplotlib.pyplot as plt
from PIL import Image


def get_job_name_with_datetime(prefix: str):
    """Build a job name by appending the current local timestamp to ``prefix``.

    Args:
        prefix: Leading part of the job name.

    Returns:
        ``prefix`` followed by ``_YYYYMMDD_HHMMSS`` for the time of the call.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return prefix + "_" + timestamp


def download_data_to_gcs(tar_filepath, gcs_bucket):
    filename_with_ext = os.path.basename(tar_filepath)
    filename_without_ext = filename_with_ext.replace(".tar.gz", "")
    print("Download files from: ", tar_filepath)
    ! wget $tar_filepath -O $filename_with_ext
    ! mkdir -p $filename_without_ext
    ! tar -xvf $filename_with_ext -C .

    ! gsutil -m cp -r $filename_without_ext $gcs_bucket/
    gcs_path = os.path.join(gcs_bucket, filename_without_ext)
    print("Upload files to: ", gcs_path)
    return gcs_path


def deploy_model(model_path):
    """Upload the serving container as a Vertex AI model and deploy it.

    A new endpoint is created first, then the model is uploaded and deployed
    to that endpoint with a single replica and one accelerator.

    Args:
        model_path: Value handed to the serving container through the
            ``MODEL_PATH`` environment variable.

    Returns:
        A ``(model, endpoint)`` tuple: the object returned by
        ``aiplatform.Model.upload`` and the endpoint it was deployed to.
    """
    job_name = get_job_name_with_datetime(DEPLOY_JOB_PREFIX)
    print("The deployed job name is: ", job_name)

    endpoint = aiplatform.Endpoint.create(display_name=f"{job_name}-endpoint")

    # Settings read by the serving container at start-up.
    container_env = {
        "MODEL_PATH": model_path,
        "IMAGE_WIDTH": RESOLUTION,
        "IMAGE_HEIGHT": RESOLUTION,
    }
    upload_options = dict(
        display_name=job_name,
        serving_container_image_uri=PREDICTION_CONTAINER_URI,
        serving_container_ports=[8501],
        serving_container_predict_route="/predict",
        serving_container_health_route="/ping",
        serving_container_environment_variables=container_env,
    )
    model = aiplatform.Model.upload(**upload_options)

    deploy_options = dict(
        machine_type=PREDICTION_MACHINE_TYPE,
        accelerator_type=PREDICTION_ACCELERATOR_TYPE,
        accelerator_count=1,
        min_replica_count=1,
        max_replica_count=1,
        deploy_request_timeout=1800,
        service_account=SERVICE_ACCOUNT,
    )
    model.deploy(endpoint=endpoint, **deploy_options)

    return model, endpoint


def base64_to_image(image_str):
    """Decode a base64-encoded image payload and open it with PIL.

    Args:
        image_str: Base64 data holding the bytes of an encoded image file.

    Returns:
        The object returned by ``Image.open`` for the decoded bytes.
    """
    decoded_bytes = base64.b64decode(image_str)
    return Image.open(BytesIO(decoded_bytes))


def display_image(image):
    """Render ``image`` with matplotlib on a new figure without grid lines.

    Args:
        image: Image data passed straight to ``plt.imshow``.
    """
    # A fresh figure with figsize (20, 15) is opened for every call.
    plt.figure(figsize=(20, 15))
    plt.grid(False)
    plt.imshow(image)


def display_image_grid(imgs, rows=2, cols=2):
    """Tile images row by row into one RGB image.

    The cell size is taken from the first image; image ``k`` is pasted at
    row ``k // cols`` and column ``k % cols``.

    Args:
        imgs: Non-empty sequence of PIL images.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        A new RGB image of size ``(cols * width, rows * height)`` containing
        the pasted tiles.
    """
    tile_w, tile_h = imgs[0].size
    canvas = Image.new("RGB", size=(cols * tile_w, rows * tile_h))
    for index, tile in enumerate(imgs):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas

## Run inferences

This section shows how to run inferences with Keras Stable Diffusion models.

1. Run inferences locally
2. Run inferences with serving dockers

You can run inferences with pre-trained models from Keras team, or your own finetuned models.


In [None]:
# Leave model_path empty to load the pre-trained model from the Keras team.
# Set model_path to a GCS URI to load a finetuned model instead.
model_path = ""  # @param {type:"string"}

### Run inferences locally
Please change the runtime type to GPU so that local inference finishes in seconds.

Load models first.

In [None]:
from keras_cv.models import StableDiffusion

# Build the KerasCV Stable Diffusion model at the configured resolution.
# NOTE: the name `model` is rebound by the "Serve models with dockers" cell
# (to the object returned by deploy_model); re-run this cell before doing
# local inference again.
model = StableDiffusion(img_height=RESOLUTION, img_width=RESOLUTION, jit_compile=True)
# A non-empty model_path replaces the diffusion model's weights with the
# weights stored at that path.
if model_path:
    model.diffusion_model.load_weights(model_path)

Then run inferences.

In [None]:
# Generate `batch_size` images for a single prompt with the local model and
# show each of them in its own figure.
batch_size = 1
img = model.text_to_image(
    prompt="a flamingo in Picasso style",
    batch_size=batch_size,  # How many images to generate at once
    num_steps=25,  # Number of iterations (controls image quality)
    seed=123,  # Set this to always get the same image from the same prompt
)
# `img` is indexed per generated image (one entry per batch element).
for i in range(batch_size):
    display_image(img[i])

### Serve models with dockers
When serving models with Docker containers, the models are deployed to Google Cloud. The deployment job itself can be started from a CPU runtime, so please change the runtime type to CPU for the following experiments to save costs. The model deployment takes ~10 minutes to finish.

In [None]:
# Upload the serving container as a model and deploy it to a new endpoint.
# NOTE: this rebinds `model`, which the local-inference cells use for the
# Keras StableDiffusion object.
model, endpoint = deploy_model(model_path=model_path)

# Print the endpoint ID so the endpoint can be loaded again later
# (see the commented-out snippet in the prediction cell).
endpoint_id = endpoint.name
print("endpoint id is: ", endpoint_id)

Once deployed, you can send a batch of text prompts to the endpoint to generate images.

Note: the first request to a fresh deployment takes longer to process (~45 seconds on one V100 GPU). Subsequent requests take ~10 seconds per image on one V100 GPU.

In [None]:
# To reuse an existing endpoint instead of the one deployed above,
# uncomment the lines below and fill in the endpoint ID.
# endpoint_id = <An Existing Endpoint ID>
# aip_endpoint_name = (
#     f"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint_id}"
# )
# endpoint = aiplatform.Endpoint(aip_endpoint_name)

# One instance per text prompt; all prompts are sent in a single request.
instances = [
    {"prompt": "a squirrel in Picasso style"},
    {"prompt": "a dog in Picasso style"},
    {"prompt": "a cat in Picasso style"},
    {"prompt": "a deer in Picasso style"},
]

# Generation settings sent alongside the instances.
parameters = {
    "batch_size": 1,  # How many images to generate at once
    "num_steps": 25,  # Number of iterations (controls image quality)
    "seed": 123,  # Set this to always get the same image from the same prompt
}
response = endpoint.predict(instances=instances, parameters=parameters)
# prediction["predicted_image"] contains the predicted images of one batch,
# encoded as base64 strings. The batch size in this example is 1, so the
# visualization only decodes the first predicted image of each prediction.
images = [
    base64_to_image(prediction["predicted_image"][0])
    for prediction in response.predictions
]
# The 2x2 grid matches the four prompts above; as the last expression of the
# cell, the returned grid image is displayed.
display_image_grid(images, rows=2, cols=2)

### Clean up

In [None]:
# Undeploys models and deletes endpoints.
endpoint.delete(force=True)
# Deletes models.
# NOTE: `model` must still refer to the object returned by deploy_model here,
# not to the local Keras StableDiffusion model from the local-inference cells.
model.delete()

## Finetune models
This section shows how to finetune Keras Stable Diffusion models with training dockers.

If you would like to use finetuned models, please go to the section `Run inferences`.

### Download data
By default, we use the dataset
[Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
However, we'll use a slightly different version which was derived from the original
dataset to fit better with `tf.data`. Refer to
[the documentation](https://huggingface.co/datasets/sayakpaul/pokemon-blip-original-version)
for more details. We download the data to Cloud Storage for the experiments with training dockers.

In [None]:
# Skips this step if you have already downloaded the dataset.
download_data_to_gcs(
    "https://huggingface.co/datasets/sayakpaul/pokemon-blip-original-version/resolve/main/pokemon_dataset.tar.gz",
    DATA_BUCKET,
)

### Start training jobs
We finetune models for 10 epochs (`--epochs=10`); with the default settings this takes ~15 minutes to finish using 1 A100 GPU.

In [None]:
# CSV file of the dataset copied to Cloud Storage in the "Download data" step.
DATA_CSV = os.path.join(DATA_BUCKET, "pokemon_dataset/data.csv")

# A timestamped job name; it is also used as the output folder of this run.
train_job_name = get_job_name_with_datetime(TRAINING_JOB_PREFIX)
model_dir = os.path.join(MODEL_BUCKET, train_job_name)

# A single worker running the training container on the configured machine
# and accelerator.
worker_pool_specs = [
    {
        "machine_spec": {
            "machine_type": TRAIN_MACHINE_TYPE,
            "accelerator_type": TRAIN_ACCELERATOR_TYPE,
            "accelerator_count": TRAIN_NUM_GPU,
        },
        "replica_count": 1,
        "disk_spec": {
            "boot_disk_type": "pd-ssd",
            "boot_disk_size_gb": 500,
        },
        "container_spec": {
            "image_uri": TRAIN_CONTAINER_URI,
            "command": [],
            "env": [
                {
                    "name": "RESOLUTION",
                    # The `value` field of a Vertex AI EnvVar is a string, so
                    # the integer resolution is converted explicitly.
                    "value": str(RESOLUTION),
                },
            ],
            "args": [
                "--epochs=10",
                f"--input_csv_path={DATA_CSV}",
                f"--output_model_dir={model_dir}",
            ],
        },
    }
]

train_job = aiplatform.CustomJob(
    display_name=train_job_name,
    project=PROJECT_ID,
    worker_pool_specs=worker_pool_specs,
    staging_bucket=STAGING_BUCKET,
)

# Submit the job and wait for it to finish.
train_job.run()

# Rebind model_path so the "Run inferences" section can load the finetuned
# weights written under model_dir.
model_path = os.path.join(model_dir, "saved_model.h5")
print("The trained model is saved as: ", model_path)

After the training finishes, you can use `model_path` and then go to the `Run inferences` section above to run predictions.

### Clean up

In [None]:
# Deletes the custom training job created in the "Start training jobs" cell.
train_job.delete()

## References

- [Fine-tuning Stable Diffusion](https://keras.io/examples/generative/finetune_stable_diffusion/)
- [StableDiffusion image-generation model](https://keras.io/api/keras_cv/models/stable_diffusion/)
- [High-performance image generation using Stable Diffusion in KerasCV](https://keras.io/guides/keras_cv/generate_images_with_stable_diffusion/)