In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden - Stable Diffusion XL 1.0 (Dreambooth LoRA Finetuning)

<table align="left">
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fvertex-ai-samples%2Fmain%2Fnotebooks%2Fcommunity%2Fmodel_garden%2Fmodel_garden_pytorch_sd_xl_finetuning_dreambooth_lora.ipynb">
      <img alt="Google Cloud Colab Enterprise logo" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" width="32px"><br> Run in Colab Enterprise
    </a>
  </td>
  <td>
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_pytorch_sd_xl_finetuning_dreambooth_lora.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo">
      View on GitHub
    </a>
  </td>
</table>

## Overview

This notebook demonstrates finetuning [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) with [Dreambooth LoRA](https://huggingface.co/docs/diffusers/en/training/sdxl) and deploying it on Vertex AI for online prediction.

### Objective

- Finetune the [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) model with [Dreambooth LoRA](https://huggingface.co/docs/diffusers/en/training/sdxl).
- Upload the model to [Vertex AI Model Registry](https://cloud.google.com/vertex-ai/docs/model-registry/introduction).
- Deploy the model to a [Vertex AI Endpoint resource](https://cloud.google.com/vertex-ai/docs/predictions/using-private-endpoints).
- Run online predictions for text-to-image.

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI pricing](https://cloud.google.com/vertex-ai/pricing) and [Cloud Storage pricing](https://cloud.google.com/storage/pricing), and use the [Pricing Calculator](https://cloud.google.com/products/calculator/) to generate a cost estimate based on your projected usage.

## LoRA Finetune with Dreambooth

In [None]:
# @title Setup Google Cloud project

# @markdown 1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

# @markdown 2. [Optional] [Create a Cloud Storage bucket](https://cloud.google.com/storage/docs/creating-buckets) for storing
# @markdown experiment outputs. Set the BUCKET_URI for the experiment environment. The specified Cloud Storage bucket (`BUCKET_URI`)
# @markdown should be located in the same region where the notebook was launched. Note that a multi-region bucket (e.g., "us") is
# @markdown not considered a match for a single region covered by the multi-region range (e.g., "us-central1").
# @markdown If not set, a unique GCS bucket will be created instead.

import base64
import glob
import math
import os
import sys
from datetime import datetime
from io import BytesIO

import requests
from google.cloud import aiplatform, storage
from PIL import Image

# Get the default cloud project id.
PROJECT_ID = os.environ["GOOGLE_CLOUD_PROJECT"]

# Get the default region for launching jobs.
REGION = os.environ["GOOGLE_CLOUD_REGION"]

# Enable the Vertex AI API and Compute Engine API, if not already.
print("Enabling Vertex AI and Compute Engine API.")
! gcloud services enable aiplatform.googleapis.com compute.googleapis.com

# Cloud Storage bucket for storing the experiment artifacts.
# A unique GCS bucket will be created for the purpose of this notebook. If you
# prefer using your own GCS bucket, change the value yourself below.
now = datetime.now().strftime("%Y%m%d%-H%M%S")
BUCKET_URI = "gs://"  # @param {type: "string"}

assert BUCKET_URI.startswith("gs://"), "BUCKET_URI must start with `gs://`."
if BUCKET_URI is None or BUCKET_URI.strip() == "" or BUCKET_URI == "gs://":
    # Create a unique GCS bucket for this notebook, if not specified by the user
    BUCKET_URI = f"gs://{PROJECT_ID}-tmp-{now}"
    ! gsutil mb -l {REGION} {BUCKET_URI}
else:
    BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])
    shell_output = ! gsutil ls -Lb {BUCKET_NAME} | grep "Location constraint:" | sed "s/Location constraint://"
    bucket_region = shell_output[0].strip().lower()
    if bucket_region != REGION:
        raise ValueError(
            "Bucket region %s is different from notebook region %s"
            % (bucket_region, REGION)
        )

# Cloud Storage bucket for storing the experiment artifacts.
# A unique GCS bucket will be created for the purpose of this notebook. If you
# prefer using your own GCS bucket, change the value yourself below.
print(f"Using this GCS Bucket: {BUCKET_URI}")

# Set up the default SERVICE_ACCOUNT.
BUCKET_NAME = "/".join(BUCKET_URI.split("/")[:3])
shell_output = ! gcloud projects describe $PROJECT_ID
project_number = shell_output[-1].split(":")[1].strip().replace("'", "")
SERVICE_ACCOUNT = f"{project_number}-compute@developer.gserviceaccount.com"
SERVICE_ACCOUNT_CC = (
    f"service-{project_number}@gcp-sa-aiplatform-cc.iam.gserviceaccount.com"
)

print("Using this default Service Account:", SERVICE_ACCOUNT)

# Provision permissions to the two SERVICE_ACCOUNTs with the GCS bucket
! gsutil iam ch serviceAccount:{SERVICE_ACCOUNT}:roles/storage.admin $BUCKET_NAME
! gsutil iam ch serviceAccount:{SERVICE_ACCOUNT_CC}:roles/storage.admin $BUCKET_NAME

# The pre-built training docker images. They contain training scripts and models.
TRAIN_DOCKER_URI = "us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-peft-train:20240320_0936_RC00"
# The pre-built serving docker images. They contain serving scripts and models.
SERVE_DOCKER_URI = "us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-diffusers-serve-opt:20240403_0836_RC00"

# Point the Vertex AI SDK at the project, region and staging bucket chosen above.
aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=BUCKET_URI)

# Interactive user authentication is only requested when the `google.colab`
# module is loaded, i.e. when this runs inside Colab.
if "google.colab" in sys.modules:
    from google.colab import auth

    auth.authenticate_user(project_id=PROJECT_ID)

def create_job_name(prefix):
    """Return `prefix` followed by `-` and the current local time as YYYYmmdd-HHMMSS."""
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    return "{}-{}".format(prefix, timestamp)


def base64_to_image(image_str):
    """Decode a base64 string and open the resulting bytes with `Image.open`."""
    raw_bytes = base64.b64decode(image_str)
    return Image.open(BytesIO(raw_bytes))


def image_to_base64(image, format="JPEG"):
    """Save `image` into memory in `format` and return the bytes as a base64 UTF-8 string."""
    stream = BytesIO()
    image.save(stream, format=format)
    return str(base64.b64encode(stream.getvalue()), "utf-8")


def download_image(url, timeout=60):
    """Download an image over HTTP and open it with `Image.open`.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout `requests` may block indefinitely on a stalled connection.

    Returns:
        The image opened from the response body.

    Raises:
        requests.HTTPError: If the server answers with an error status. This
            surfaces the real failure instead of an image-decoding error on
            an HTML error page.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content))


def image_grid(imgs, rows=2, cols=2):
    """Paste images onto a white canvas laid out as a `rows` x `cols` grid.

    Every cell has the size of the first image, and each column is followed by
    a 10-pixel white gap. Images are placed left to right, top to bottom.

    Args:
        imgs: Non-empty sequence of images; the first one defines the cell size.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        The composed grid image.
    """
    gap = 10
    w, h = imgs[0].size
    grid = Image.new(
        mode="RGB", size=(cols * (w + gap), rows * h), color=(255, 255, 255)
    )
    for i, img in enumerate(imgs):
        row, col = divmod(i, cols)
        # The gap must follow the column index, not the running image index;
        # otherwise every row after the first drifts right and gets clipped.
        grid.paste(img, box=(col * (w + gap), row * h))
    return grid


def get_bucket_and_blob_name(filepath):
    """Split a `gs://<bucket-name>/<blob-name>` path into its parts.

    Returns `(bucket_name, blob_name)`, or the 1-tuple `(bucket_name,)` when
    no "/" follows the bucket name.
    """
    remainder = filepath.split("gs://", 1)[1]
    bucket_name, slash, blob_name = remainder.partition("/")
    if slash:
        return (bucket_name, blob_name)
    return (bucket_name,)


def download_gcs_dir_to_local(gcs_dir_path, local_dir_path):
    """Downloads files in a GCS directory to a local directory.

    The layout below `gcs_dir_path` is recreated under `local_dir_path`.
    Blobs whose name ends with "/" are skipped.

    Args:
        gcs_dir_path: Source of the form `gs://<bucket>[/<dir>]`. A bare
            bucket (`gs://<bucket>`) downloads the whole bucket.
        local_dir_path: Destination directory; created as needed.
    """
    assert gcs_dir_path.startswith("gs://"), "gcs_dir_path must start with `gs://`."
    bucket_name = gcs_dir_path.split("/")[2]
    prefix = gcs_dir_path[len("gs://" + bucket_name) :].strip("/")
    if prefix:
        # Trailing slash so that `dir` does not also match `dir-other/...`.
        prefix += "/"
    client = storage.Client()
    # For the bucket root no prefix filter is sent at all; a literal "/" prefix
    # would match nothing.
    blobs = client.list_blobs(bucket_name, prefix=prefix or None)
    for blob in blobs:
        if blob.name.endswith("/"):
            continue
        file_path = blob.name[len(prefix) :].strip("/")
        local_file_path = os.path.join(local_dir_path, file_path)
        parent_dir = os.path.dirname(local_file_path)
        if parent_dir:
            # `os.makedirs("")` raises, so only create a real parent directory.
            os.makedirs(parent_dir, exist_ok=True)

        print(f"Downloading {file_path} to {local_file_path}")
        blob.download_to_filename(local_file_path)


def upload_local_dir_to_gcs(local_dir_path, gcs_dir_path):
    """Uploads files in a local directory to a GCS directory.

    Only regular, non-hidden files directly inside `local_dir_path` are
    uploaded; sub-directories are not traversed.

    Args:
        local_dir_path: Local source directory, with or without a trailing slash.
        gcs_dir_path: Destination of the form `gs://<bucket>/<dir>`.
    """
    client = storage.Client()
    bucket_name = gcs_dir_path.split("/")[2]
    bucket = client.get_bucket(bucket_name)
    # A non-recursive "*" matches the same entries as the previous "/**"
    # pattern, which was never called with recursive=True.
    for local_file in glob.glob(os.path.join(local_dir_path, "*")):
        if not os.path.isfile(local_file):
            continue
        # Use the base name rather than slicing by len(local_dir_path): slicing
        # drops the first character when the directory has a trailing slash.
        filename = os.path.basename(local_file)
        # GCS object names always use "/" regardless of the local OS separator.
        gcs_file_path = gcs_dir_path.rstrip("/") + "/" + filename
        _, blob_name = get_bucket_and_blob_name(gcs_file_path)
        blob = bucket.blob(blob_name)
        blob.upload_from_filename(local_file)
        print("Copied {} to {}.".format(local_file, gcs_file_path))


def deploy_model(model_id, lora_id, task):
    """Create a Vertex AI Endpoint and deploy the specified model to the endpoint.

    The serving container receives `model_id`, `lora_id` and `task` through the
    `MODEL_ID`, `LORA_ID` and `TASK` environment variables. The model is
    deployed on a single `g2-standard-8` machine with one NVIDIA L4.

    Returns:
        A `(model, endpoint)` tuple with the uploaded model and the endpoint
        it was deployed to.
    """
    display_name = model_id
    endpoint = aiplatform.Endpoint.create(display_name=f"{display_name}-endpoint")

    upload_kwargs = dict(
        display_name=display_name,
        serving_container_image_uri=SERVE_DOCKER_URI,
        serving_container_ports=[7080],
        serving_container_predict_route="/predictions/diffusers_serving",
        serving_container_health_route="/ping",
        serving_container_environment_variables={
            "MODEL_ID": model_id,
            "LORA_ID": lora_id,
            "TASK": task,
        },
    )
    model = aiplatform.Model.upload(**upload_kwargs)

    deploy_kwargs = dict(
        endpoint=endpoint,
        machine_type="g2-standard-8",
        accelerator_type="NVIDIA_L4",
        accelerator_count=1,
        deploy_request_timeout=1800,
        service_account=SERVICE_ACCOUNT,
    )
    model.deploy(**deploy_kwargs)

    # Tell the user how to re-attach to this endpoint later.
    print("To load this existing endpoint from a different session:")
    print(
        f'endpoint = aiplatform.Endpoint("projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint.name}")'
    )
    return model, endpoint

In [None]:
# @title Start Dreambooth LoRA finetune

# @markdown This section uses [Dreambooth LoRA](https://dreambooth.github.io/) to finetune
# @markdown the [stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) model
# @markdown with [5 dog images](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) to
# @markdown personalize the text-to-image model.

# @markdown In this example, we will default to use the images in [diffusers/dog-example](https://huggingface.co/datasets/diffusers/dog-example)
# @markdown as the training dataset. If you need to train based on your own images, choose the `IMAGE_SOURCE` option as
# @markdown "Use my own images in a GCS bucket" and provide the GCS path to your images below.

# @markdown It finetunes both text encoder and unet of the stable diffusion model up to 200 steps.
# @markdown The whole finetuning job takes 20 minutes to finish using 1 L4 GPU.

# @markdown The full model will be saved after the finetuning job finishes and it can be loaded
# @markdown by the [StableDiffusionPipeline](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img)
# @markdown to run inference.

# @markdown Click "Show code" to modify the code to change GPU type, count, and other training parameters.

# Form choice that selects where the training images come from.
IMAGE_SOURCE = "Default image examples from HuggingFace"  # @param ["Default image examples from HuggingFace", "Use my own images in a GCS bucket"] {isTemplate:true}

# Input and output path.
# All three locations live under BUCKET_URI: the instance images, the class
# images, and the training output.
instance_dir = os.path.join(BUCKET_URI, "dreambooth-lora-sdxl/images")
class_dir = os.path.join(BUCKET_URI, "dreambooth-lora-sdxl/image_class")
output_dir = os.path.join(BUCKET_URI, "dreambooth-lora-sdxl/output")

from huggingface_hub import snapshot_download

# Local staging directory: images are collected here first and then uploaded
# to the bucket.
local_dir = "./images"
if IMAGE_SOURCE == "Default image examples from HuggingFace":
    # Fetch the example dataset repository into the staging directory,
    # skipping its `.gitattributes` file.
    snapshot_download(
        "diffusers/dog-example",
        local_dir=local_dir,
        repo_type="dataset",
        ignore_patterns=".gitattributes",
    )
    print("Finished downloading training images from huggingface.")
else:
    # @markdown **[Optional]** If using own images, provide the GCS path to the images.
    # @markdown Make sure you have permission to access the bucket.
    source_image_gcs_dir = "gs://"  # @param {type:"string"}
    print(
        f"Now downloading the images from the source gcs directory: {source_image_gcs_dir}"
    )
    assert source_image_gcs_dir.startswith(
        "gs://"
    ), "source_image_gcs_dir must start with `gs://`."
    # NOTE(review): the unchanged "gs://" placeholder passes the assert above
    # but names no bucket; confirm the value was actually filled in.
    download_gcs_dir_to_local(source_image_gcs_dir, local_dir)


# Upload data to Cloud Storage bucket.
print(f"Now uploading the images to the bucket: {instance_dir}")
upload_local_dir_to_gcs(local_dir, instance_dir)
# The same local images are also uploaded as the class images.
# NOTE(review): whether the training container actually reads class_dir with
# the arguments used in this notebook is not visible here; confirm.
upload_local_dir_to_gcs(local_dir, class_dir)

# The pre-trained model to be loaded.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"

# Worker pool spec.
machine_type = "g2-standard-8"
num_nodes = 1
gpu_type = "NVIDIA_L4"
num_gpus = 1

# @markdown Default to 200. Increase it to 400 or 800 if you want to achieve a higher model quality.
train_steps = 200  # @param {type: "number"}

# The form may yield a float (e.g. 200.0); the job arguments need whole numbers.
train_steps = int(train_steps)
assert train_steps > 0, "train_steps must be a positive integer."
# Checkpoint halfway through training, but never with an interval of 0
# (which `train_steps // 2` would give for train_steps == 1).
checkpointing_steps = max(1, train_steps // 2)

# Setup training job.

job_name = create_job_name("dreambooth-lora-sdxl")
job = aiplatform.CustomContainerTrainingJob(
    display_name=job_name,
    container_uri=TRAIN_DOCKER_URI,
)

# Translate `gs://<bucket>/...` URIs into `/gcs/<bucket>/...` paths for the
# training container.
BUCKET_GCS_FUSE = BUCKET_URI.replace("gs://", "/gcs/")
instance_dir_fuse = instance_dir.replace("gs://", "/gcs/")
class_dir_fuse = class_dir.replace("gs://", "/gcs/")
output_dir_fuse = output_dir.replace("gs://", "/gcs/")

model = job.run(
    args=[
        "--task=text-to-image-dreambooth-lora-sdxl",
        f"--pretrained_model_name_or_path={model_id}",
        f"--instance_data_dir={instance_dir_fuse}",
        f"--class_data_dir={class_dir_fuse}",
        f"--output_dir={output_dir_fuse}",
        "--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix",
        "--mixed_precision=fp16",
        # NOTE(review): the single quotes are part of the argument value here.
        # Whether the container strips them depends on its entrypoint, which
        # is not visible in this notebook; confirm.
        "--instance_prompt='a photo of sks dog'",
        "--resolution=1024",
        "--train_batch_size=1",
        "--gradient_accumulation_steps=1",
        "--gradient_checkpointing",
        "--learning_rate=2e-6",
        "--lr_scheduler=constant",
        # No trailing whitespace in the value, so it parses cleanly as an integer.
        "--lr_warmup_steps=0",
        "--use_8bit_adam",
        f"--max_train_steps={train_steps}",
        f"--checkpointing_steps={checkpointing_steps}",
        "--seed=0",
    ],
    replica_count=num_nodes,
    machine_type=machine_type,
    accelerator_type=gpu_type,
    accelerator_count=num_gpus,
)

In [None]:
# @title Deploy the SD model to Vertex for online predictions

# @markdown This section uploads the model to Model Registry and deploys it on the Endpoint. It takes ~15 minutes to finish.
# @markdown Click "Show Code" to see more details.

# @markdown `text-to-image` lets you send text prompts to the endpoint to generate images.


# Base model served by the endpoint, plus the GCS output directory of the
# finetuning job, which is passed to the serving container as the LoRA location.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
lora_id = output_dir

print("LoRA weights are saved in:", lora_id)

# `model_id` selects the OSS pre-trained base model ("stabilityai/stable-diffusion-xl-base-1.0");
# `lora_id` points at the finetuned LoRA weights to use with it.
model, endpoint = deploy_model(
    model_id=model_id, lora_id=lora_id, task="text-to-image-sdxl"
)
print("endpoint_name:", endpoint.name)

# Loads an existing endpoint instance using the endpoint name:
# - Using `endpoint_name = endpoint.name` allows us to get the
#   endpoint name of the endpoint `endpoint` created by the
#   `deploy_model` call above.
# - Alternatively, you can set `endpoint_name = "1234567890123456789"` to load
#   an existing endpoint with the ID 1234567890123456789.
# You may uncomment the code below to load an existing endpoint.

# endpoint_name = ""  # @param {type:"string"}
# aip_endpoint_name = (
#     f"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint_name}"
# )
# endpoint = aiplatform.Endpoint(aip_endpoint_name)

# Print a ready-to-paste statement for re-attaching to this endpoint from a
# new session (`deploy_model` prints the same hint when it finishes).
print("To load this existing endpoint from a different session:")
print(
    f'endpoint = aiplatform.Endpoint("projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint.name}")'
)

In [None]:
# @title Predict (text-to-image)

# @markdown This section is only for sending predictions to an endpoint with the task `text-to-image`.

# @markdown Once deployment succeeds, you can generate images by sending text prompts to the endpoint.

# @markdown You can also batch send prompts by separating them with a comma.
# @markdown You may adjust the parameters below to achieve best image quality.

comma_separated_prompt_list = "A picture of a sks dog in a house, A picture of a sks dog catching a frisbee"  # @param {type: "string"}
# Drop empty entries so a trailing or doubled comma does not send a blank prompt.
prompt_list = [
    x.strip() for x in comma_separated_prompt_list.split(",") if x.strip()
]
assert prompt_list, "Please provide at least one prompt."
height = 1024  # @param {type:"number"}
width = 1024  # @param {type:"number"}
num_inference_steps = 25  # @param {type:"number"}
guidance_scale = 7.5  # @param {type:"number"}

# One prediction instance per prompt; all instances share the same settings.
instances = [
    {
        "prompt": prompt,
        "negative_prompt": "",
        "height": height,
        "width": width,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }
    for prompt in prompt_list
]

response = endpoint.predict(instances=instances)
# Each prediction is decoded from base64 into an image.
images = [base64_to_image(image) for image in response.predictions]
# Pick a near-square layout with rows * cols >= len(images), so every image
# fits on the canvas (a fixed 2-column grid overflows from 7 images on).
grid_cols = max(1, math.ceil(len(images) ** 0.5))
grid_rows = max(1, math.ceil(len(images) / grid_cols))
image_grid(images, rows=grid_rows, cols=grid_cols)

In [None]:
# @title Clean up resources
# @markdown  Delete the experiment models and endpoints to recycle the resources
# @markdown  and avoid unnecessary continuous charges that may be incurred.

# Undeploy model and delete endpoint.
endpoint.delete(force=True)

# Delete models.
model.delete()

# Optionally remove the bucket and everything in it. This is off by default;
# if BUCKET_URI points at a bucket you supplied yourself, its existing contents
# are deleted as well.
delete_bucket = False  # @param {type:"boolean"}
if delete_bucket:
    ! gsutil -m rm -r $BUCKET_URI