In [None]:
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Model Garden - Image Generation with MediaPipe

<table align="left">
  <td>
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_mediapipe_image_generation.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Colab logo"> Run in Colab
    </a>
  </td>

  <td>
    <a href="https://github.com/GoogleCloudPlatform/vertex-ai-samples/blob/main/notebooks/community/model_garden/model_garden_mediapipe_image_generation.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo">
      View on GitHub
    </a>
  </td>
  <td>
    <a href="https://console.cloud.google.com/vertex-ai/notebooks/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/vertex-ai-samples/main/notebooks/community/model_garden/model_garden_mediapipe_image_generation.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo">
Open in Vertex AI Workbench
    </a>
  </td>
</table>

**_NOTE_**: This notebook has been tested in the following environment:

* Python version = 3.9

**NOTE**: The checkpoint and the dataset linked in this Colab are not owned or distributed by Google, and are made available by third parties. Please review the terms and conditions made available by the third parties before using the checkpoint and data.


## Overview

This notebook demonstrates how you can customize a [MediaPipe Image Generator](https://developers.google.com/mediapipe/solutions/vision/image_generator), a text-to-image generator, by adding Low-Rank Adaptation ([LoRA](https://arxiv.org/abs/2106.09685)) weights to generate images of specific people, objects, and styles.

Using Vertex AI Model Garden, we will fine-tune a standard diffusion model on a specialized dataset of specific concepts, each identified by a unique token. With the LoRA weights produced by training, the model can generate images of the new concept whenever the token appears in the text prompt.
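As a rough illustration of what the LoRA weights are, the following NumPy sketch follows the formulation in the LoRA paper: a frozen weight matrix `W` is augmented with a trainable low-rank product `B @ A`, scaled by `alpha / r`. The layer size, rank, and `alpha` are illustrative assumptions, and this is not MediaPipe's training code.

```python
import numpy as np


def lora_forward(x, W, A, B, alpha):
    """y = x W^T + (alpha / r) * x A^T B^T, where W is frozen and only A and B are trained."""
    r = A.shape[0]
    return x @ W.T + (alpha / r) * (x @ A.T) @ B.T


d_in, d_out, r = 768, 768, 4
rng = np.random.default_rng(0)
W = rng.standard_normal((d_out, d_in))     # frozen pre-trained weight
A = rng.standard_normal((r, d_in)) * 0.01  # trainable, small random init
B = np.zeros((d_out, r))                   # trainable, zero init
x = rng.standard_normal((2, d_in))

# With B = 0 the adapter is a no-op, so training starts from the base model.
assert np.allclose(lora_forward(x, W, A, B, alpha=4), x @ W.T)

full_params = d_in * d_out
lora_params = r * (d_in + d_out)
print(full_params, lora_params)
```

Because only `A` and `B` are trained, this adapter has 6,144 parameters instead of 589,824 for the full matrix, which is why LoRA weight files are small compared to the base model.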

Once the model is customized with LoRA weights, it should only be used to generate images of the tokenized concept. It is no longer useful as a generalized image generation model. For more on customizing a MediaPipe Image Generator with LoRA weights, see the [MediaPipe documentation](https://developers.google.com/mediapipe/solutions/vision/image_generator#lora).

NOTE: If you are creating LoRA weights to generate images of specific people and faces, only use this solution on your own face or on the faces of people who have given you permission to do so.

### Objective

* Set up a Google Cloud project with Vertex AI.
* Train a text-to-image diffusion model on a specialized dataset to create [LoRA](https://arxiv.org/abs/2106.09685) weights.
* Customize a general image generator into a specialized generator that can inject specific objects, people, and styles into generated images.
* Configure the newly trained Image Generator.
* Download, upload, and deploy the new model.

### Costs

This tutorial uses billable components of Google Cloud:

* Vertex AI
* Cloud Storage

Learn about [Vertex AI
pricing](https://cloud.google.com/vertex-ai/pricing) and [Cloud Storage
pricing](https://cloud.google.com/storage/pricing), and use the [Pricing
Calculator](https://cloud.google.com/products/calculator/)
to generate a cost estimate based on your projected usage.

## Before you begin

### Set up your Google Cloud project

**The following steps are required, regardless of your notebook environment.**

1. [Select or create a Google Cloud project](https://console.cloud.google.com/cloud-resource-manager).

1. [Make sure that billing is enabled for your project](https://cloud.google.com/billing/docs/how-to/modify-project).

1. [Enable the Vertex AI API and Compute Engine API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com,compute_component).
1. If you are running this notebook locally, you will need to install the [Cloud SDK](https://cloud.google.com/sdk).

1. Enter your project ID in the cell below. Then run the cell to make sure the
Cloud SDK uses the right project for all the commands in this notebook.

### Authenticate for Colab
**Note**: Skip this step if you are not using [Colab](https://colab.google/).

Run the following commands to install dependencies and authenticate with Google Cloud on Colab.

In [None]:
! pip3 install --upgrade pip

import sys

if "google.colab" in sys.modules:
    ! pip3 install --upgrade google-cloud-aiplatform

    # Automatically restart kernel after installs
    import IPython

    app = IPython.Application.instance()
    app.kernel.do_shutdown(True)

    from google.colab import auth as google_auth

    google_auth.authenticate_user()

### Set your project ID (`PROJECT_ID`)

If you don't know your project ID, try the following:
* Run `gcloud config list`.
* Run `gcloud projects list`.
* See the support page: [Locate the project ID](https://support.google.com/googleapi/answer/7014113)

In [None]:
PROJECT_ID = ""  # @param {type:"string"}

# Set the project id
! gcloud config set project {PROJECT_ID}

### Set the storage location (`REGION`)

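Set the `REGION` variable used by Vertex AI. The region name must start with `us`, `europe`, or `asia`, because its prefix selects the Artifact Registry location of the training and serving containers. Learn more about [Vertex AI regions](https://cloud.google.com/vertex-ai/docs/general/locations).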

In [None]:
REGION = ""  # @param {type: "string"}
REGION_PREFIX = REGION.split("-")[0]
assert REGION_PREFIX in (
    "us",
    "europe",
    "asia",
), f'{REGION} is not supported. It must be prefixed by "us", "asia", or "europe".'

### Create a Cloud Storage bucket

Create a storage bucket to store intermediate artifacts such as datasets and trained models.

In [None]:
BUCKET_URI = ""  # @param {type:"string"}

If your bucket doesn't already exist, create your Cloud Storage bucket.

**NOTE**: Only run the following cell if you do not already have a bucket.

In [None]:
! gsutil mb -l {REGION} -p {PROJECT_ID} {BUCKET_URI}

### Import libraries

In [None]:
import json
import os
from datetime import datetime

from google.cloud import aiplatform

### Initialize Vertex AI SDK for Python

Initialize the Vertex AI SDK for Python for your project.

In [None]:
now = datetime.now().strftime("%Y%m%d-%H%M%S")

STAGING_BUCKET = os.path.join(BUCKET_URI, "temp/%s" % now)

MODEL_EXPORT_PATH = os.path.join(STAGING_BUCKET, "model")

IMAGE_EXPORT_PATH = os.path.join(STAGING_BUCKET, "image")

aiplatform.init(project=PROJECT_ID, location=REGION, staging_bucket=STAGING_BUCKET)
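The staging paths above are built with `os.path.join`, which uses the local operating system's separator. On Linux (Colab, Vertex AI Workbench) that is `/`, but on Windows it would insert backslashes into Cloud Storage URIs. If you run the notebook locally on Windows, `posixpath.join` is a portable alternative; the bucket name below is a placeholder.

```python
import ntpath
import posixpath

bucket_uri = "gs://my-bucket"  # placeholder bucket name

# posixpath always joins with "/", matching Cloud Storage object naming.
print(posixpath.join(bucket_uri, "temp/20240101-000000", "model"))
# gs://my-bucket/temp/20240101-000000/model

# ntpath shows what os.path.join would produce on Windows (backslash separators).
print(ntpath.join(bucket_uri, "temp", "model"))
```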

### Define training and serving constants

In [None]:
TRAINING_JOB_DISPLAY_NAME = "mediapipe_stable_diffusion_%s" % now
TRAINING_CONTAINER = f"{REGION_PREFIX}-docker.pkg.dev/vertex-ai-restricted/vertex-vision-model-garden-dockers/mediapipe-stable-diffusion-train"
TRAINING_MACHINE_TYPE = "a2-highgpu-1g"
TRAINING_ACCELERATOR_TYPE = "NVIDIA_TESLA_A100"
TRAINING_ACCELERATOR_COUNT = 1

PREDICTION_CONTAINER_URI = f"{REGION_PREFIX}-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-peft-serve"
PREDICTION_PORT = 7080
PREDICTION_ACCELERATOR_TYPE = "NVIDIA_TESLA_V100"
PREDICTION_MACHINE_TYPE = "n1-standard-8"
UPLOAD_MODEL_NAME = "mediapipe_stable_diffusion_model_%s" % now

## Train a customized Image Generator

In this section, we will customize the Image Generator by training the model on images of [teapots](https://github.com/google/dreambooth/tree/main/dataset/teapot) from the [DreamBooth dataset](https://github.com/google/dreambooth/tree/main). Using the LoRA weights created through training, the new model will be able to inject teapots into generated images.

This is a simple example implementation. You can modify the following cells to further customize the notebook.

### Choose the pre-trained model to download

The MediaPipe Image Generator task requires a pre-trained model that matches the `runwayml/stable-diffusion-v1-5 EMA-only` model format, based on [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/). The cell below sets the download URLs for its UNet, VAE, and text encoder weights.


In [None]:
unet_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/unet/diffusion_pytorch_model.bin"  # @param {type:"string"}
vae_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/vae/diffusion_pytorch_model.bin"  # @param {type:"string"}
text_encoder_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/text_encoder/pytorch_model.bin"  # @param {type:"string"}
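The three URLs follow the Hugging Face Hub download pattern `https://huggingface.co/<repo_id>/resolve/<revision>/<path>`. If you want to pin a specific revision for reproducibility, or point at another checkpoint with the same Stable Diffusion v1.5 file layout, a small helper like the hypothetical `hf_resolve_url` below can build them. Whether the training container accepts a given alternative checkpoint is an assumption not covered here.

```python
def hf_resolve_url(repo_id: str, filename: str, revision: str = "main") -> str:
    """Build a direct-download URL for a file in a Hugging Face Hub model repo."""
    return f"https://huggingface.co/{repo_id}/resolve/{revision}/{filename}"


repo_id = "runwayml/stable-diffusion-v1-5"
component_urls = {
    "unet": hf_resolve_url(repo_id, "unet/diffusion_pytorch_model.bin"),
    "vae": hf_resolve_url(repo_id, "vae/diffusion_pytorch_model.bin"),
    "text_encoder": hf_resolve_url(repo_id, "text_encoder/pytorch_model.bin"),
}
for name, url in component_urls.items():
    print(f"{name}: {url}")
```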

### Prepare input data for training

Customizing a model for image generation requires a dataset that contains sample pictures of the concept instance that you want to use in the generation. The concept can be a person, object, or style.

**Object**
![](https://storage.googleapis.com/mediapipe-assets/documentation/object_lora.png)

**Person**
![](https://storage.googleapis.com/mediapipe-assets/documentation/person_lora.png)

**Style**
![](https://storage.googleapis.com/mediapipe-assets/documentation/style_lora.png)

You must also assign a unique token to the new concept. The prompt should include the token, which is "monadikos" in this case, followed by a word that describes the concept to generate. In this example, we are using "A monadikos teapot". The images from the [teapots](https://github.com/google/dreambooth/tree/main/dataset/teapot) dataset can be downloaded from Google Cloud Storage.

The customized model will recognize the term "monadikos teapot", and inject an image of a teapot into the generated images.
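Because the model only learns the concept through the token, every prompt that should show the concept needs to contain it. The sketch below, with hypothetical helper names, builds an instance prompt in the "A <token> <class noun>" form used above and checks that a test prompt mentions the token as a whole word.

```python
import re


def instance_prompt(token: str, class_noun: str) -> str:
    """Build an instance description such as "A monadikos teapot"."""
    return f"A {token} {class_noun}"


def mentions_token(prompt: str, token: str) -> bool:
    """Return True if `token` appears in `prompt` as a whole word (case-insensitive)."""
    pattern = rf"\b{re.escape(token)}\b"
    return re.search(pattern, prompt, flags=re.IGNORECASE) is not None


print(instance_prompt("monadikos", "teapot"))  # A monadikos teapot
print(mentions_token("Two monadikos teapots on a table", "monadikos"))  # True
print(mentions_token("Two teapots on a table", "monadikos"))  # False
```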

In [None]:
# Path to the training data folder.
training_data_path = (
    "gs://mediapipe-tasks/image_generator/teapot"  # @param {type:"string"}
)
# An instance description of the training data.
training_data_prompt = "A monadikos teapot"  # @param {type:"string"}

### Set training options

The Image Generator comes with a set of pre-defined hyperparameter (`HParams`) settings that work best for specific situations. You should select a template that best matches your use case.

You can further customize hyperparameters such as the learning rate and the number of training steps. For more information on these hyperparameters, see the [Google Machine Learning glossary](https://developers.google.com/machine-learning/glossary).

To set custom training parameters, adjust the values for the following hyperparameters:

In [None]:
# Parameters about training configuration
# The learning rate to use for gradient descent training.
learning_rate: float = 0.00001  # @param {type:"number"}
# Number of training steps. If set to 0, uses the default value.
num_train_steps: int = 0  # @param {type:"integer"}
# Save the checkpoint in every n steps.
save_checkpoints_every_n: int = 100  # @param {type:"integer"}
# Batch size for training.
batch_size: int = 1  # @param {type:"integer"}

# Dataset-related parameters
# Whether to use random horizontal flip on data.
random_flip: bool = False  # @param {type:"boolean"}
# Whether to use random largest square crop.
random_crop: bool = False  # @param {type:"boolean"}
# Whether to distort the color of the image (jittering order is random).
random_color_jitter: bool = False  # @param {type:"boolean"}

# Hyperparameters for LoRA tuning
# The rank in the low-rank matrices. If set to 0, uses the default value.
lora_rank: int = 0  # @param {type:"integer"}
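`num_train_steps` and `save_checkpoints_every_n` together determine how many checkpoints you can compare later. Assuming the trainer writes a checkpoint every `save_checkpoints_every_n` steps, up to and including `num_train_steps`, the saved steps can be estimated as below. The helper name is hypothetical, and when `num_train_steps` is 0 the container's default step count applies, which this notebook does not state.

```python
def expected_checkpoints(num_train_steps: int, every_n: int) -> list:
    """Training steps at which a checkpoint is expected, assuming one every `every_n` steps."""
    return list(range(every_n, num_train_steps + 1, every_n))


print(expected_checkpoints(1000, 100))  # [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
print(expected_checkpoints(250, 100))   # [100, 200]: step 250 is not a multiple of 100
```

Under this assumption, choosing a `num_train_steps` that is a multiple of `save_checkpoints_every_n` ensures the final step is saved.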

Alternatively, you can use one of the following pre-trained LoRA weight files, which have already been customized for each template:
* [Object (berry bowls)](https://storage.googleapis.com/mediapipe-tasks/image_generator/object/pytorch_lora_weights.bin)
* [Face](https://storage.googleapis.com/mediapipe-tasks/image_generator/face/pytorch_lora_weights.bin)
* [Style](https://storage.googleapis.com/mediapipe-tasks/image_generator/style/pytorch_lora_weights.bin)

In [None]:
template = ""  # @param ["", "face", "object", "style"]

## Test the customized Image Generator model

After training the custom model, we will generate images to examine the quality of the customized model. You can provide a text prompt below and configure options for generating the test images.

### Define the test generation prompt

Specify the prompt to use to test the customized model. Note that a variation of the token, "monadikos teapots", is included in the prompt. If you are customizing this notebook with another dataset, set a token to describe the object, person, or style depicted in the training data.

In [None]:
prompt: str = "Two monadikos teapots on a table"  # @param {type:"string"}

### Configure the parameters to generate test images

Set configuration options to run image generation with the customized model.

In [None]:
# Number of steps to run inference.
number_inference_steps: int = 50  # @param {type:"integer"}
# Classifier-free guidance weight to use during inference. Weight must be >= 1.0.
guidance_scale: float = 7.5  # @param {type:"number"}
# Number of generated images per prompt.
number_generated_images: int = 8  # @param {type:"integer"}
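`guidance_scale` controls classifier-free guidance. In the standard formulation (Ho and Salimans), the denoiser is run with and without the text prompt, and the two noise predictions are combined as `eps_uncond + w * (eps_cond - eps_uncond)`. A scale of 1.0 reproduces the prompt-conditioned prediction, and larger values push the result further toward the prompt. The NumPy sketch below shows only this combination step; it is not MediaPipe's sampler.

```python
import numpy as np


def guided_noise(eps_uncond: np.ndarray, eps_cond: np.ndarray, guidance_scale: float) -> np.ndarray:
    """Combine unconditional and prompt-conditioned noise predictions (classifier-free guidance)."""
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)


# Toy 2-element "noise predictions"; only the first element depends on the prompt.
eps_uncond = np.array([0.0, 0.5])
eps_cond = np.array([1.0, 0.5])
print(guided_noise(eps_uncond, eps_cond, 1.0))  # equals eps_cond
print(guided_noise(eps_uncond, eps_cond, 7.5))  # prompt-dependent element amplified 7.5x
```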

### Tune the image generator with LoRA
Tune the Image Generator with LoRA and generate new images based on your prompt. This can take up to 10 minutes on Vertex AI with an A100 GPU.


In [None]:
model_export_path = MODEL_EXPORT_PATH
image_export_path = IMAGE_EXPORT_PATH

worker_pool_specs = [
    {
        "machine_spec": {
            "machine_type": TRAINING_MACHINE_TYPE,
            "accelerator_type": TRAINING_ACCELERATOR_TYPE,
            "accelerator_count": TRAINING_ACCELERATOR_COUNT,
        },
        "replica_count": 1,
        "container_spec": {
            "image_uri": TRAINING_CONTAINER,
            "command": [],
            "args": [
                "--task_name=stable_diffusion",
                "--model_export_path=%s" % model_export_path,
                "--image_export_path=%s" % image_export_path,
                "--training_data_path=%s" % training_data_path,
                "--training_data_prompt='%s'" % training_data_prompt,
                "--prompt='%s'" % prompt,
                "--hparams_template=%s" % template,
                "--hparams=%s"
                % json.dumps(
                    {
                        "learning_rate": learning_rate,
                        "num_train_steps": num_train_steps,
                        "save_checkpoints_every_n": save_checkpoints_every_n,
                        "batch_size": batch_size,
                        "random_flip": random_flip,
                        "random_crop": random_crop,
                        "random_color_jitter": random_color_jitter,
                        "lora_rank": lora_rank,
                        "torch_vae": vae_url,
                        "torch_unet": unet_url,
                        "torch_text_encoder": text_encoder_url,
                    }
                ),
                "--generator_hparams=%s"
                % json.dumps(
                    {
                        "number_inference_steps": number_inference_steps,
                        "guidance_scale": guidance_scale,
                        "number_generated_images": number_generated_images,
                    }
                ),
            ],
        },
    }
]

training_job = aiplatform.CustomJob(
    display_name=TRAINING_JOB_DISPLAY_NAME,
    project=PROJECT_ID,
    worker_pool_specs=worker_pool_specs,
    staging_bucket=STAGING_BUCKET,
)

training_job.run()
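The `--hparams` and `--generator_hparams` flags carry a JSON object as a single command-line argument. Before submitting a job, you can check that such an argument decodes back to the dictionary you intended. The sketch below does this for a standalone example dictionary (a subset of the keys above, with illustrative values); how the training container parses the flag internally is an assumption not covered by this notebook.

```python
import json

hparams = {
    "learning_rate": 1e-5,
    "num_train_steps": 0,
    "batch_size": 1,
    "random_flip": False,
}
arg = "--hparams=%s" % json.dumps(hparams)

# Split on the first "=" only; the JSON payload is everything after it.
flag, _, payload = arg.partition("=")
decoded = json.loads(payload)

print(arg)  # Python False is serialized as JSON false
print(decoded == hparams)
```

In the notebook itself, `print(worker_pool_specs[0]["container_spec"]["args"])` shows the exact arguments sent to the container.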

## Download images and model

After training and testing the new model, you can download the generated images and the new customized model. The LoRA weights from training can also be used with the MediaPipe Tasks ImageGenerator API for on-device applications.

### Download generated images

Download and preview the generated images at different checkpoints.
Inspecting the generated images helps to determine the best checkpoint and avoid underfitting or overfitting.

In [None]:
import matplotlib.pyplot as plt


def copy_image(images_source, images_dest):
    os.makedirs(images_dest, exist_ok=True)
    ! gsutil cp -r {images_source}/* {images_dest}


local_image_path = "./images/"
copy_image(IMAGE_EXPORT_PATH, local_image_path)

# Group the downloaded images by training step. The step number is read from
# the second underscore-separated field of each file name.
steps_samples = {}
for filename in os.listdir(local_image_path):
    absolute_path = os.path.join(local_image_path, filename)
    if os.path.isfile(absolute_path):
        step = int(filename.split("_")[1])
        steps_samples.setdefault(step, []).append(plt.imread(absolute_path))

for step in sorted(steps_samples.keys()):
    print(f"\nGenerated image with training steps {step}:")
    for image in steps_samples[step]:
        plt.figure(figsize=(20, 10), dpi=150)
        plt.axis("off")
        plt.imshow(image)
        plt.show()
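Plotting one large figure per image makes checkpoints hard to compare. As an alternative, the sketch below (hypothetical helper `plot_checkpoint_grid`) draws one row per training step, assuming `steps_samples` maps each step to a list of image arrays as built above. It is demonstrated here on synthetic arrays.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_checkpoint_grid(steps_samples: dict):
    """Draw one row per training step and one column per sample image; return the figure."""
    steps = sorted(steps_samples)
    n_cols = max(len(images) for images in steps_samples.values())
    fig, axes = plt.subplots(
        len(steps), n_cols, figsize=(3 * n_cols, 3 * len(steps)), squeeze=False
    )
    for row, step in zip(axes, steps):
        for ax in row:
            ax.axis("off")  # hide ticks, including on unused cells
        for ax, image in zip(row, steps_samples[step]):
            ax.imshow(image)
        row[0].set_title(f"step {step}", loc="left")
    return fig


# Synthetic stand-ins for generated images: two checkpoints with 2 and 3 samples.
demo = {
    100: [np.zeros((8, 8, 3))] * 2,
    200: [np.ones((8, 8, 3))] * 3,
}
fig = plot_checkpoint_grid(demo)
```

With real results, call `plot_checkpoint_grid(steps_samples)` followed by `plt.show()`.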

By default, the last checkpoint is used for deployment. Based on the visual inspection above, you can choose a different checkpoint here.

In [None]:
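deployed_checkpoint: int = -1  # @param {type:"integer"}

# Collect the checkpoint_<step> folders that training actually exported, so this
# also works when num_train_steps is 0 and the container chooses the step count.
exported_paths = !gsutil ls {MODEL_EXPORT_PATH}
names = [path.rstrip("/").split("/")[-1] for path in exported_paths]
valid_checkpoints = sorted(
    int(name[len("checkpoint_"):]) for name in names
    if name.startswith("checkpoint_") and name[len("checkpoint_"):].isdigit()
)
if not valid_checkpoints:
    raise ValueError(f"No checkpoints found under {MODEL_EXPORT_PATH}.")
if deployed_checkpoint == -1:  # -1 selects the last checkpoint.
    deployed_checkpoint = valid_checkpoints[-1]
if deployed_checkpoint not in valid_checkpoints:
    raise ValueError(f"Invalid checkpoint; choose one of {valid_checkpoints}.")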

### Download model

After fine-tuning and evaluating the model, you can download the model and checkpoints.

In [None]:
import sys


def copy_model(model_source, model_dest):
    os.makedirs(model_dest, exist_ok=True)
    ! gsutil -m cp -r {model_source}/* {model_dest}


local_model_path = "./models"
copy_model(MODEL_EXPORT_PATH, local_model_path)

! tar czf models.tar.gz {local_model_path}/*

if "google.colab" in sys.modules:
    from google.colab import files

    files.download("models.tar.gz")

## Upload and deploy to Vertex AI

This section shows how to serve and test the trained model:
1. Upload and deploy models to the [Vertex AI Model Registry](https://cloud.google.com/vertex-ai/docs/model-registry/introduction)
2. Get [online predictions](https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions) from the deployed model

### Upload model to Vertex AI Model Registry

In [None]:
serving_env = {
    "TASK": "text-to-image-lora",
    "BASE_MODEL_ID": "runwayml/stable-diffusion-v1-5",
    "FINETUNED_LORA_MODEL_PATH": os.path.join(
        MODEL_EXPORT_PATH, f"checkpoint_{deployed_checkpoint}"
    ),
}

model = aiplatform.Model.upload(
    display_name=UPLOAD_MODEL_NAME,
    serving_container_image_uri=PREDICTION_CONTAINER_URI,
    serving_container_ports=[PREDICTION_PORT],
    serving_container_predict_route="/predictions/peft_serving",
    serving_container_health_route="/ping",
    serving_container_environment_variables=serving_env,
)

model.wait()

print("The uploaded model name is: ", UPLOAD_MODEL_NAME)

### Deploy the uploaded model

This step deploys the model to a Vertex AI endpoint. By default, deployment uses one NVIDIA V100 GPU.

If you don't have a service account for serving yet, create one before running the next cell.

The model deployment will take around 1 minute to finish.

In [None]:
# Please go to https://cloud.google.com/iam/docs/service-accounts-create#iam-service-accounts-create-console
# and create service account with `Vertex AI User` and `Storage Object Admin` roles.
service_account = ""  # @param {type:"string"}

endpoint = aiplatform.Endpoint.create(display_name=f"{UPLOAD_MODEL_NAME}-endpoint")
model.deploy(
    endpoint=endpoint,
    machine_type=PREDICTION_MACHINE_TYPE,
    accelerator_type=PREDICTION_ACCELERATOR_TYPE,
    accelerator_count=1,
    deploy_request_timeout=1800,
    service_account=service_account,
)

The docker container still needs to download and load the model after the endpoint is created. Therefore, we recommend waiting for 3 extra minutes before proceeding to the next cell.
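Instead of a fixed three-minute wait, you can retry the first request until the container is ready. The hypothetical helper below retries any callable a limited number of times. Catching every `Exception` is a simplification; in practice you may want to narrow it to the errors you observe while the model is still loading.

```python
import time
from typing import Any, Callable, Sequence


def predict_with_retry(
    predict: Callable[[Sequence[dict]], Any],
    instances: Sequence[dict],
    attempts: int = 10,
    delay_s: float = 30.0,
) -> Any:
    """Call `predict(instances)` until it succeeds, sleeping `delay_s` between attempts."""
    for attempt in range(1, attempts + 1):
        try:
            return predict(instances)
        except Exception as err:  # broad on purpose; narrow this for real use
            if attempt == attempts:
                raise
            print(f"Attempt {attempt} failed ({err}); retrying in {delay_s} s")
            time.sleep(delay_s)


# Quick check with a stand-in predictor that fails twice before succeeding.
calls = {"count": 0}


def flaky_predict(batch):
    calls["count"] += 1
    if calls["count"] < 3:
        raise RuntimeError("model is still loading")
    return [f"image for {item['prompt']}" for item in batch]


result = predict_with_retry(flaky_predict, [{"prompt": "A monadikos teapot"}], delay_s=0)
print(result)
```

With the deployed endpoint, this could be used as `response = predict_with_retry(lambda batch: endpoint.predict(instances=batch), instances)`.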

Once deployed, you can send a batch of text prompts to the endpoint to generate images.

In [None]:
import base64
from io import BytesIO

import matplotlib.pyplot as plt
from PIL import Image

instances = [
    {"prompt": "Two monadikos teapots on a table"},
    {"prompt": "Two monadikos teapots on the floor"},
]
response = endpoint.predict(instances=instances)

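# One subplot per prompt; squeeze=False keeps `grid` 2-D even for a single prompt.
_, grid = plt.subplots(
    1, len(instances), figsize=(5 * len(instances), 5), squeeze=False
)
for cell, instance, prediction in zip(grid[0], instances, response.predictions):
    # Each prediction is a base64-encoded image.
    image = Image.open(BytesIO(base64.b64decode(prediction)))
    cell.imshow(image)
    cell.set_title(instance["prompt"])
    cell.axis("off")
plt.show()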
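The endpoint returns each generated image as a base64-encoded string, as the decoding above shows. To keep the results, a hypothetical helper like `save_predictions` can write them to PNG files. The block checks itself with a synthetic image in place of a real prediction.

```python
import base64
import os
from io import BytesIO

from PIL import Image


def save_predictions(predictions, out_dir: str) -> list:
    """Decode base64-encoded images and save them as PNG files; return the file paths."""
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for index, encoded in enumerate(predictions):
        image = Image.open(BytesIO(base64.b64decode(encoded)))
        path = os.path.join(out_dir, f"prediction_{index}.png")
        image.save(path)
        paths.append(path)
    return paths


# Synthetic stand-in for one prediction: a 64x64 PNG encoded as base64 text.
buffer = BytesIO()
Image.new("RGB", (64, 64), (0, 128, 128)).save(buffer, format="PNG")
fake_prediction = base64.b64encode(buffer.getvalue()).decode("ascii")

saved_paths = save_predictions([fake_prediction], "./predictions")
print(saved_paths)
```

With the deployed endpoint, call `save_predictions(response.predictions, "./predictions")`.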

## Clean up
When you are done, delete the training job, and undeploy and delete the endpoint and the model.

In [None]:
if training_job.list(filter=f'display_name="{TRAINING_JOB_DISPLAY_NAME}"'):
    training_job.delete()
# Undeploys models and deletes endpoints.
endpoint.delete(force=True)
model.delete()

You can also remove the output data.

In [None]:
!gsutil rm -r {STAGING_BUCKET}