In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden - Stable Diffusion V1.5 (Dreambooth Finetuning)

<table align="left"><tbody><tr>
  <td>
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_pytorch_stable_diffusion_finetuning_dreambooth.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td>
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_pytorch_stable_diffusion_finetuning_dreambooth.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo"><br>
      View on GitHub
    </a>
  </td>
</tr></tbody></table>

## Overview

This notebook demonstrates finetuning the [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5) model with [Dreambooth](https://huggingface.co/docs/diffusers/training/dreambooth) and deploying the finetuned model on Vertex AI for online prediction.

### Objective

- Full-parameter finetuning of the stable-diffusion-v1.5 model with [Dreambooth](https://huggingface.co/docs/diffusers/training/dreambooth).
- Deploy models to a [Vertex AI Endpoint resource](https://cloud.google.com/vertex-ai/docs/predictions/using-private-endpoints).
- Run predictions for text-to-image and text-guided-image-to-image.

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

## Run the notebook

In [None]:
# @title Setup Google Cloud project

# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

# @markdown 2. [Optional] [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing experiment outputs. Set the BUCKET_URI for the experiment environment. The specified Cloud Storage bucket (`BUCKET_URI`) should be located in the same region as where the notebook was launched. Note that a multi-region bucket (eg. "us") is not considered a match for a single region covered by the multi-region range (eg. "us-central1"). If not set, a unique GCS bucket will be created instead.

import base64
import glob
import os
import sys
from datetime import datetime
from io import BytesIO

import requests
from google.cloud import aiplatform, storage
from PIL import Image

# Get the default cloud project id.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

# Get the default region for launching jobs.
REGION = os.environ["GOOGLE_CLOUD_REGION"]

# Enable the Vertex AI API and Compute Engine API, if not already.
! gcloud services enable aiplatform.googleapis.com compute.googleapis.com

# Cloud Storage bucket for storing the experiment artifacts.
# A unique GCS bucket will be created for the purpose of this notebook. If you
# prefer using your own GCS bucket, please change the value yourself below.
now = datetime.now().strftime("%Y%m%d%H%M%S")
BUCKET_URI = "gs://"  # @param {type: "string"}

# Create a unique GCS bucket for this notebook, if not specified by the user.
if BUCKET_URI is None or BUCKET_URI.strip() == "" or BUCKET_URI == "gs://":
    BUCKET_URI = f"gs://{PROJECT_ID}-tmp-{now}"
    BUCKET_NAME = BUCKET_URI
    ! gsutil mb -l {REGION} {BUCKET_URI}
else:
    assert BUCKET_URI.startswith("gs://"), "BUCKET_URI must start with `gs://`."
    BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])
    shell_output = ! gsutil ls -Lb {BUCKET_NAME} | grep "Location constraint:" | sed "s/Location constraint://"
    bucket_region = shell_output[0].strip().lower()
    if bucket_region != REGION:
        raise ValueError(
            "Bucket region %s is different from notebook region %s"
            % (bucket_region, REGION)
        )

print(f"Using this GCS Bucket: {BUCKET_URI}")

# Set up the default SERVICE_ACCOUNT.
SERVICE_ACCOUNT = None
shell_output = ! gcloud projects describe $PROJECT_ID
project_number = shell_output[-1].split(":")[1].strip().replace("'", "")
SERVICE_ACCOUNT = f"{project_number}-compute@developer.gserviceaccount.com"
SERVICE_ACCOUNT_CC = (
    f"service-{project_number}@gcp-sa-aiplatform-cc.iam.gserviceaccount.com"
)

print("Using this default Service Account:", SERVICE_ACCOUNT)

# Provision permissions to the SERVICE_ACCOUNT with the GCS bucket
! gsutil iam ch serviceAccount:{SERVICE_ACCOUNT}:roles/storage.admin $BUCKET_NAME
! gsutil iam ch serviceAccount:{SERVICE_ACCOUNT_CC}:roles/storage.admin $BUCKET_NAME

aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_URI)

if "google.colab" in sys.modules:
    from google.colab import auth

    auth.authenticate_user(project_id=PROJECT_ID)

# The pre-built training docker images. They contain training scripts and models.
TRAIN_DOCKER_URI = "us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-diffusers-train:latest"

# The pre-built serving docker images. They contains serving scripts and models.
SERVE_DOCKER_URI = "us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-diffusers-serve-opt:20240403_0836_RC00"


# Define common functions.
def create_job_name(prefix):
    """Builds a job name of the form `<prefix>-<user>-<YYYYmmdd_HHMMSS>`.

    The user part is read from the `USER` environment variable (rendered as
    `None` when the variable is unset); the timestamp is the current local time.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    username = os.environ.get("USER")
    return "{}-{}-{}".format(prefix, username, timestamp)


def download_image(url, timeout=60):
    """Downloads an image over HTTP(S) and opens it with PIL.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up, so that an
            unresponsive host cannot hang the notebook indefinitely.

    Returns:
        The PIL image opened from the response body.

    Raises:
        requests.HTTPError: If the server answers with an error status code.
        requests.Timeout: If the server does not respond within `timeout`.
    """
    response = requests.get(url, timeout=timeout)
    # Surface HTTP errors here rather than letting PIL fail on an error page.
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


def image_to_base64(image, format="JPEG"):
    """Serializes an image in memory and returns it as a base64 text string.

    The image is written through its `save` method using `format`, and the
    resulting bytes are base64-encoded and decoded as UTF-8.
    """
    with BytesIO() as stream:
        image.save(stream, format=format)
        raw_bytes = stream.getvalue()
    return base64.b64encode(raw_bytes).decode("utf-8")


def base64_to_image(image_str):
    """Decodes a base64 string and opens the resulting bytes as a PIL image."""
    decoded_bytes = base64.b64decode(image_str)
    return Image.open(BytesIO(decoded_bytes))


def image_grid(imgs, rows=2, cols=2):
    """Pastes images into one RGB grid image, filling it row by row.

    Every cell has the size of the first image. Image `i` is placed in
    column `i % cols` and row `i // cols` of a `rows` x `cols` canvas.
    """
    cell_w, cell_h = imgs[0].size
    canvas = Image.new("RGB", size=(cols * cell_w, rows * cell_h))
    for index, tile in enumerate(imgs):
        row, col = divmod(index, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))
    return canvas


def deploy_model(model_id, task):
    """Creates an endpoint, uploads the model and deploys it to that endpoint.

    Args:
        model_id: Value passed to the serving container as `MODEL_ID`; it is
            also used as the display name of the uploaded model.
        task: Value passed to the serving container as `TASK`; it is also
            part of the endpoint display name.

    Returns:
        A `(model, endpoint)` tuple with the uploaded model and the endpoint
        it was deployed to.
    """
    display_name = model_id
    endpoint = aiplatform.Endpoint.create(
        display_name=f"{display_name}-{task}-endpoint"
    )
    # The serving container reads its configuration from these variables.
    container_env = dict(MODEL_ID=model_id, TASK=task)
    uploaded_model = aiplatform.Model.upload(
        display_name=display_name,
        serving_container_image_uri=SERVE_DOCKER_URI,
        serving_container_ports=[7080],
        serving_container_predict_route="/predictions/diffusers_serving",
        serving_container_health_route="/ping",
        serving_container_environment_variables=container_env,
    )
    # Deploy on a single L4 GPU machine, running as the default service account.
    uploaded_model.deploy(
        endpoint=endpoint,
        machine_type="g2-standard-8",
        accelerator_type="NVIDIA_L4",
        accelerator_count=1,
        deploy_request_timeout=1800,
        service_account=SERVICE_ACCOUNT,
    )
    return uploaded_model, endpoint


def get_bucket_and_blob_name(filepath):
    """Splits a GCS path into its bucket name and blob name.

    Args:
        filepath: A path of the form `gs://<bucket-name>/<blob-name>`.

    Returns:
        A `(bucket_name, blob_name)` tuple. `blob_name` is an empty string
        when the path names only a bucket, so callers can always unpack two
        values.

    Raises:
        IndexError: If `filepath` does not contain `gs://`.
    """
    gs_suffix = filepath.split("gs://", 1)[1]
    # partition() always yields both parts, unlike split("/", 1), which
    # returns a single element for a bucket-only path.
    bucket_name, _, blob_name = gs_suffix.partition("/")
    return bucket_name, blob_name


def download_gcs_dir_to_local(gcs_dir_path, local_dir_path):
    """Downloads files in a GCS directory to a local directory.

    Every blob below `gcs_dir_path` is downloaded, keeping its path relative
    to `gcs_dir_path` below `local_dir_path`. Blobs whose name ends with `/`
    (directory placeholders) are skipped.

    Args:
        gcs_dir_path: Source directory, `gs://<bucket>[/<dir>]`. A bucket-only
            path downloads the whole bucket.
        local_dir_path: Destination directory; created as needed.

    Raises:
        AssertionError: If `gcs_dir_path` does not start with `gs://`.
    """
    assert gcs_dir_path.startswith("gs://"), "gcs_dir_path must start with `gs://`."
    bucket_name = gcs_dir_path.split("/")[2]
    dir_path = gcs_dir_path[len("gs://" + bucket_name) :].strip("/")
    # A bucket-root path must map to an empty prefix: a bare "/" prefix would
    # match no blob at all because blob names do not start with a slash.
    prefix = dir_path + "/" if dir_path else ""
    client = storage.Client()
    for blob in client.list_blobs(bucket_name, prefix=prefix):
        if blob.name.endswith("/"):
            continue
        file_path = blob.name[len(prefix) :].strip("/")
        local_file_path = os.path.join(local_dir_path, file_path)
        parent_dir = os.path.dirname(local_file_path)
        # dirname is "" for a file placed directly in the current directory,
        # and os.makedirs("") would raise.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        print(f"Downloading {file_path} to {local_file_path}")
        blob.download_to_filename(local_file_path)


def upload_local_dir_to_gcs(local_dir_path, gcs_dir_path):
    """Uploads files in a local directory to a GCS directory.

    Only regular, non-hidden files directly inside `local_dir_path` are
    uploaded; sub-directories are skipped.

    Args:
        local_dir_path: Source directory, with or without a trailing slash.
        gcs_dir_path: Destination directory, `gs://<bucket>/<dir>`.
    """
    client = storage.Client()
    bucket_name = gcs_dir_path.split("/")[2]
    bucket = client.get_bucket(bucket_name)
    for local_file in glob.glob(os.path.join(local_dir_path, "*")):
        if not os.path.isfile(local_file):
            continue
        # relpath is used instead of slicing by len(local_dir_path), which cut
        # off the first character of the file name when local_dir_path ended
        # with a slash.
        filename = os.path.relpath(local_file, local_dir_path)
        # GCS object names always use forward slashes.
        gcs_file_path = gcs_dir_path.rstrip("/") + "/" + filename
        _, blob_name = get_bucket_and_blob_name(gcs_file_path)
        blob = bucket.blob(blob_name)
        blob.upload_from_filename(local_file)
        print("Copied {} to {}.".format(local_file, gcs_file_path))

In [None]:
# @title Finetune with Dreambooth

# @markdown This section uses [dreambooth](https://dreambooth.github.io/) to finetune the [stable-diffusion-v1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) model with [5 dog images](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) to personalize the text-to-image model.

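# @markdown It finetunes both the text encoder and the UNet of the stable diffusion model for `MAX_TRAIN_STEPS` steps (400 by default; 800 is suggested below for better quality). A finetuning job of up to 800 steps takes about 30 minutes on 1 A100 GPU; the worker pool configured in this cell uses 1 L4 GPU, so the run time may differ.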

# @markdown The full model will be saved after the finetuning job finishes and it can be loaded by the [StableDiffusionPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img) to run inference.

# @markdown You may click "Show code" and modify the code to change GPU type, count and other training parameters.

# @markdown ----

# @markdown Select default training images, or use your own images.

IMAGE_SOURCE = "Default training images"  # @param ["Default training images", "Use your own"] {isTemplate:true}

# @markdown [Optional] If you selected "Use your own" in the above section, provide the GCS path to the images. Make sure you have access to the GCS bucket.

IMAGE_GCS_DIR = "gs://"  # @param {type: "string"}

# @markdown ----

# @markdown In this example, we default `MAX_TRAIN_STEPS` to 400 for faster training. For better model quality, increase this number to 800.

MAX_TRAIN_STEPS = 400  # @param {type: "number"}

if IMAGE_SOURCE == "Default training images":
    # Download example training images.
    !gdown --folder https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ
    local_dir = "dog"
else:
    assert (
        IMAGE_GCS_DIR != "" and IMAGE_GCS_DIR != "gs://"
    ), "IMAGE_GCS_DIR must be set."
    assert IMAGE_GCS_DIR.startswith("gs://"), "IMAGE_GCS_DIR must start with `gs://`."
    local_dir = "./images"
    download_gcs_dir_to_local(IMAGE_GCS_DIR, local_dir)
    print("Downloaded images from: {IMAGE_GCS_DIR}")

# Upload data to Cloud Storage bucket.
upload_local_dir_to_gcs(local_dir, os.path.join(BUCKET_URI, "dreambooth/images"))
upload_local_dir_to_gcs(local_dir, os.path.join(BUCKET_URI, "dreambooth/images_class"))

# The pre-trained model to be loaded.
model_id = "runwayml/stable-diffusion-v1-5"

# Input and output path.
instance_dir = os.path.join(BUCKET_URI, "dreambooth/images").replace("gs://", "/gcs/")
class_dir = os.path.join(BUCKET_URI, "dreambooth/images_class").replace(
    "gs://", "/gcs/"
)
output_dir = os.path.join(BUCKET_URI, "dreambooth/output").replace("gs://", "/gcs/")

# Worker pool spec.
machine_type = "g2-standard-8"
num_nodes = 1
gpu_type = "NVIDIA_L4"
num_gpus = 1

# Setup training job.
job_name = create_job_name("dreambooth-stable-diffusion")
job = aiplatform.CustomContainerTrainingJob(
    display_name=job_name,
    container_uri=TRAIN_DOCKER_URI,
)

# Pass training arguments and launch job.
# See https://github.com/huggingface/diffusers/blob/v0.16.0/examples/dreambooth/train_dreambooth.py#L75
# for a full list of training arguments.
model = job.run(
    args=[
        "dreambooth/train_dreambooth.py",
        f"--pretrained_model_name_or_path={model_id}",
        "--train_text_encoder",
        f"--instance_data_dir={instance_dir}",
        f"--class_data_dir={class_dir}",
        f"--output_dir={output_dir}",
        "--with_prior_preservation",
        "--prior_loss_weight=1.0",
        "--instance_prompt='a photo of sks dog'",
        "--class_prompt='a photo of dog'",
        "--resolution=512",
        "--train_batch_size=1",
        "--gradient_checkpointing",
        "--learning_rate=2e-6",
        "--lr_scheduler=constant",
        "--lr_warmup_steps=0",
        "--num_class_images=200",
        f"--max_train_steps={MAX_TRAIN_STEPS}",
    ],
    replica_count=num_nodes,
    machine_type=machine_type,
    accelerator_type=gpu_type,
    accelerator_count=num_gpus,
)

In [None]:
# @title Upload and deploy model

# @markdown This section uploads the model to Model Registry and deploys it on the Endpoint.

# @markdown The model deployment step will take ~15 minutes to complete.

# Sets the model_id to {BUCKET_URI}/dreambooth/output, the GCS location the
# finetuning job wrote the model to. BUCKET_URI already carries the `gs://`
# scheme (and so does BUCKET_NAME), so no extra `gs://` prefix is added here.
model_text_to_image, endpoint_text_to_image = deploy_model(
    model_id=os.path.join(BUCKET_URI, "dreambooth/output"), task="text-to-image"
)

In [None]:
# @title Predict

# @markdown Once deployment succeeds, you can send a batch of text prompts to the endpoint to generate images.

# @markdown When deployed on one L4 GPU, the average inference time of a request is ~15 seconds.

prompt = "A serious capybara at work, wearing a suit"  # @param {type: "string"}

# One instance per prompt; add more dicts to send a batch in one request.
instances = [
    {"prompt": prompt},
]
response = endpoint_text_to_image.predict(instances=instances)
images = [base64_to_image(image) for image in response.predictions]
# Size the single-row grid to the number of returned images, so that one image
# does not leave a blank cell and more than two images are not cut off.
image_grid(images, rows=1, cols=len(images))

In [None]:
# @title Clean up resources

# @markdown  Delete the experiment models and endpoints to recycle the resources
# @markdown  and avoid unnecessary continouous charges that may incur.

try:
    # Undeploy model and delete endpoint.
    endpoint.delete(force=True)

    # Delete model.
    model.delete()

except Exception as e:
    print(e)

# Delete bucket.
delete_bucket = False  # @param {type:"boolean"}
if delete_bucket:
    ! gsutil -m rm -r $BUCKET_NAME