In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Guess who or what app using Hugging Face Deep Learning container model on Vertex AI

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/open-models/use-cases/guess_app.ipynb">
      <img width="32px" src="https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fopen-models%2Fuse-cases%2Fguess_app.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/open-models/use-cases/guess_app.ipynb">
      <img src="https://www.gstatic.com/images/branding/gcpiconscolors/vertexai/v1/32px.svg" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/use-cases/guess_app.ipynb">
      <img width="32px" src="https://upload.wikimedia.org/wikipedia/commons/9/91/Octicons-mark-github.svg" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

<div style="clear: both;"></div>

<b>Share to:</b>

<a href="https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/use-cases/guess_app.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg" alt="LinkedIn logo">
</a>

<a href="https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/use-cases/guess_app.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg" alt="Bluesky logo">
</a>

<a href="https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/use-cases/guess_app.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/53/X_logo_2023_original.svg" alt="X logo">
</a>

<a href="https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/use-cases/guess_app.ipynb" target="_blank">
  <img width="20px" src="https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png" alt="Reddit logo">
</a>

<a href="https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/open-models/use-cases/guess_app.ipynb" target="_blank">
  <img width="20px" src="https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg" alt="Facebook logo">
</a>

| | |
|-|-|
| Author(s) | [Ivan Nardini](https://github.com/inardini) |

## Overview

This notebook shows how you can use Vertex AI and a Hugging Face Deep Learning Container to create a simple "Guess who or what" application, which combines an open image generation model with Gemini to answer and visualize the subjects of "Guess who or what" riddles.

## Get started

### Install Vertex AI SDK and other required packages

To run this example, you need the [`google-cloud-aiplatform`](https://github.com/googleapis/python-aiplatform) Python SDK, the [`huggingface_hub`](https://github.com/huggingface/huggingface_hub) Python package, and [`gradio`](https://github.com/gradio-app/gradio) to build the web app.

In [None]:
# Install/upgrade the Vertex AI SDK, the Hugging Face Hub client, and Gradio (used for the app UI later on).
%pip install --upgrade --user --quiet google-cloud-aiplatform huggingface_hub gradio

### Restart runtime (Colab only)

To use the newly installed packages in this Jupyter environment, if you are on Colab you must restart the runtime. You can do this by uncommenting and running the cell below, which restarts the current kernel. The restart might take a minute or longer. After it's restarted, continue to the next step.

In [None]:
# Uncomment the lines below to automatically restart the kernel after installs so that your environment can access the new packages
# import IPython

# app = IPython.Application.instance()
# app.kernel.do_shutdown(True)

<div class="alert alert-block alert-warning">
<b>⚠️ The kernel is going to restart. Wait until it's finished before continuing to the next step. ⚠️</b>
</div>


### Authenticate your Google Cloud account

Depending on your Jupyter environment, you may have to manually authenticate. Follow the relevant instructions below.

**1. Vertex AI Workbench**

* Do nothing as you are already authenticated.

**2. Local JupyterLab instance, uncomment and run:**

In [None]:
# !gcloud auth login

**3. Colab, run:**

In [None]:
import sys

# `google.colab` is only importable inside a Colab runtime. Importing it
# unconditionally raises ImportError on Vertex AI Workbench or a local Jupyter
# kernel and breaks "Run all", so only authenticate when running on Colab.
if "google.colab" in sys.modules:
    from google.colab import auth

    # Authenticate the current Colab user against Google Cloud.
    auth.authenticate_user()

### Authenticate your Hugging Face account

The `huggingface_hub` package installed above comes with a CLI and login helpers that are used to authenticate with the token you generated in advance. After logging in, the token can be safely retrieved via `huggingface_hub.get_token`.

In [None]:
from huggingface_hub import interpreter_login

# Log in to the Hugging Face Hub from the notebook. The token provided here is
# read back later with `get_token()` when the model is registered on Vertex AI.
interpreter_login()

Read more about [Hugging Face Security](https://huggingface.co/docs/hub/en/security), specifically about [Hugging Face User Access Tokens](https://huggingface.co/docs/hub/en/security-tokens).

### Set Google Cloud project information and initialize Vertex AI SDK

To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com), if not enabled already.

Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).

In [None]:
import os

from google.cloud import aiplatform

# Project to use; leave the placeholder to pick it up from the environment.
PROJECT_ID = "[your-project-id]"  # @param {type:"string", isTemplate: true}
if not PROJECT_ID or PROJECT_ID == "[your-project-id]":
    # Keep `None` (not the string "None") when GOOGLE_CLOUD_PROJECT is unset, so
    # the SDK can infer the project from the environment/credentials instead of
    # being pointed at a project that is literally named "None".
    PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT") or None

# Region for Vertex AI resources; defaults to us-central1.
LOCATION = os.environ.get("GOOGLE_CLOUD_REGION", "us-central1")

# Set the default project/location for all subsequent Vertex AI SDK calls.
aiplatform.init(project=PROJECT_ID, location=LOCATION)

### Requirements

To successfully deploy your Hugging Face model on Vertex AI, you need to have the following IAM roles set:

- Artifact Registry Reader (`roles/artifactregistry.reader`)
- Vertex AI User (`roles/aiplatform.user`)

For more information about granting roles, see [Manage access](https://cloud.google.com/iam/docs/granting-changing-revoking-access).

---

You also need to enable the following APIs (if not enabled already):

- Vertex AI API (`aiplatform.googleapis.com`)
- Artifact Registry API (`artifactregistry.googleapis.com`)

For more information about API enablement, see [Enabling APIs](https://cloud.google.com/apis/docs/getting-started#enabling_apis).


### Import libraries

In [None]:
import base64
import io
import os

from IPython.display import display
from PIL import Image
from google.cloud.aiplatform import Endpoint
import gradio as gr
from huggingface_hub import get_token
import vertexai
from vertexai.generative_models import GenerationConfig, GenerativeModel, SafetySetting

# Build a "Guess what or who" app


## Register FLUX on Vertex AI

To deploy a "Guess what or who" application using a Hugging Face model like [FLUX](https://console.cloud.google.com/vertex-ai/publishers/black-forest-labs/model-garden/flux1-schnell) for image generation on Vertex AI, you first register the chosen model within the [Vertex AI Model Registry](https://cloud.google.com/vertex-ai/docs/model-registry/introduction).  This registry serves as a central repository for managing your models on Vertex AI.

Registering a model involves specifying a serving container image and corresponding environment variables, which vary depending on the chosen model. For instance, for FLUX, you'll use a regular PyTorch Inference Deep Learning Container.  You can find the appropriate container URI and further details in the [Google Cloud Deep Learning Containers documentation](https://huggingface.co/docs/google-cloud/en/containers/introduction).  

In [None]:
# Register FLUX in the Vertex AI Model Registry.
# FLUX is a text-to-image (diffusion) model, so it must be served with the
# Hugging Face PyTorch Inference DLC, which is configured through the
# HF_MODEL_ID / HF_TASK variables below. The Text Generation Inference (TGI)
# image is meant for text-generation LLMs and cannot serve this task.
flux_model = aiplatform.Model.upload(
    display_name="flux--generate",
    serving_container_image_uri="us-docker.pkg.dev/deeplearning-platform-release/gcr.io/huggingface-pytorch-inference-cu121.2-2.transformers.4-44.ubuntu2204.py311",
    serving_container_environment_variables={
        # Hugging Face Hub repository to load at container start-up.
        "HF_MODEL_ID": "black-forest-labs/FLUX.1-dev",
        # Pipeline task the container should expose.
        "HF_TASK": "text-to-image",
        # Token from the earlier Hugging Face login, so the container can
        # download the weights if the repository requires authentication.
        # NOTE(review): the token is stored as a plain environment variable on
        # the model resource; use a narrowly scoped read-only token.
        "HF_TOKEN": get_token(),
    },
)
# Block until the upload operation has completed.
flux_model.wait()

## Deploy FLUX on Vertex AI

After successfully registering the model, you can then deploy it to a Vertex AI Endpoint according to your preferred deployment configuration, making it ready for image generation.

This deployment creates a new instance of the FLUX model in Vertex AI Prediction, the managed inference service on Vertex AI.

In [None]:
# Endpoint that will host the FLUX deployment.
endpoint = aiplatform.Endpoint.create(display_name="flux--generate-endpoint")

# Machine type and accelerators requested for serving.
serving_resources = {
    "machine_type": "g2-standard-48",
    "accelerator_type": "NVIDIA_L4",
    "accelerator_count": 4,
}

# sync=False asks the SDK not to block this cell while the deployment runs.
deployed_flux_model = flux_model.deploy(
    endpoint=endpoint,
    sync=False,
    **serving_resources,
)

## Generate predictions

After the model gets successfully deployed, you can test it by submitting a prediction request.

### Generate an image using FLUX

In [None]:
# Text prompt sent to the deployed FLUX model for a quick smoke test.
sample_prompt = "a image of a cat riding a horse in illustration style"

# Generation settings: output size, number of denoising steps, guidance scale.
sample_parameters = dict(
    width=512,
    height=512,
    num_inference_steps=8,
    guidance_scale=3.5,
)

response = deployed_flux_model.predict(
    instances=[sample_prompt],
    parameters=sample_parameters,
)

### Get the generated image

In [None]:
# The first prediction is a base64-encoded image: decode it and load it with PIL.
encoded_image = response.predictions[0]
image_bytes = base64.b64decode(encoded_image)
image = Image.open(io.BytesIO(image_bytes))
display(image)

## Play "Guess who or what"

With the image generation open model deployed in a Vertex AI Endpoint, you are now able to build your "Guess who or what" Gen AI application.

In this scenario, you use Gradio to quickly build a web application for your Gen AI models.

In [None]:
# Get environment variables; each one falls back to the value already defined
# earlier in the notebook (or to a default) when the variable is not set.
PROJECT_ID = os.environ.get("PROJECT_ID", PROJECT_ID)
LOCATION = os.environ.get("LOCATION", LOCATION)
MODEL_ID = os.environ.get("MODEL_ID", "gemini-1.5-flash-002")
ENDPOINT_NAME = os.environ.get("ENDPOINT", endpoint.resource_name)

# Initialize Vertex AI SDK
aiplatform.init(project=PROJECT_ID, location=LOCATION)
vertexai.init(project=PROJECT_ID, location=LOCATION)

# Initialize the Gemini model and the client for the FLUX endpoint.
# Build the model from MODEL_ID so that the MODEL_ID override is honored
# instead of being silently ignored.
MODEL = GenerativeModel(MODEL_ID)
ENDPOINT = Endpoint(ENDPOINT_NAME)

# Common generation config and safety settings
# Single candidate, temperature 0, up to 8192 output tokens.
GENERATION_CONFIG = GenerationConfig(
    candidate_count=1, max_output_tokens=8192, temperature=0
)

# Blocking threshold set to OFF for all four harm categories listed below.
SAFETY_SETTINGS = [
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
        threshold=SafetySetting.HarmBlockThreshold.OFF,
    ),
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
        threshold=SafetySetting.HarmBlockThreshold.OFF,
    ),
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
        threshold=SafetySetting.HarmBlockThreshold.OFF,
    ),
    SafetySetting(
        category=SafetySetting.HarmCategory.HARM_CATEGORY_HARASSMENT,
        threshold=SafetySetting.HarmBlockThreshold.OFF,
    ),
]


def generate_gemini_content(prompt_template: str, **kwargs) -> str:
    """Fill a prompt template and return the text Gemini generates for it.

    Args:
        prompt_template: A ``str.format``-style template.
        **kwargs: Values substituted into the template placeholders.

    Returns:
        The ``text`` attribute of the model response.
    """
    filled_prompt = prompt_template.format(**kwargs)
    # Shared, notebook-wide request options; the response is not streamed.
    request_options = {
        "generation_config": GENERATION_CONFIG,
        "safety_settings": SAFETY_SETTINGS,
        "stream": False,
    }
    gemini_response = MODEL.generate_content([filled_prompt], **request_options)
    return gemini_response.text


def generate_subject(riddle: str) -> str:
    """Ask Gemini to solve a riddle and return only its subject.

    Args:
        riddle: The riddle text entered by the player.

    Returns:
        The model's answer with any "SUBJECT:" label removed and surrounding
        whitespace stripped.
    """
    solver_template = """
    You are the best riddle solver. Given a riddle, your goal is solve it and only indicate the subject of the riddle.
    RIDDLE: {riddle}
    SUBJECT:
    """
    raw_answer = generate_gemini_content(solver_template, riddle=riddle)
    # The model may echo the trailing label from the prompt; drop it.
    without_label = raw_answer.replace("SUBJECT:", "")
    return without_label.strip()


def generate_prompt(subject: str) -> str:
    """Ask Gemini to write an image-generation prompt that depicts a subject.

    Args:
        subject: The subject to visualize (typically a riddle's answer).

    Returns:
        The prompt text exactly as produced by the model.
    """
    prompt_writer_template = """
    You are a professional prompt engineer. Given a subject, prototype the most appropriate prompt to best visualize the subject.
    Only return the preferred prompt.
    SUBJECT: {subject}
    PROMPT:
    """
    written_prompt = generate_gemini_content(prompt_writer_template, subject=subject)
    return written_prompt


def generate_image(image_gen_prompt: str) -> Image.Image:
    """Send a prompt to the image endpoint and return the resulting picture.

    Args:
        image_gen_prompt: Text prompt describing the image to generate.

    Returns:
        A PIL image decoded from the first (base64-encoded) prediction.
    """
    # Output size, number of inference steps and guidance scale for the request.
    sampling_parameters = dict(
        width=512,
        height=512,
        num_inference_steps=8,
        guidance_scale=3.5,
    )
    prediction = ENDPOINT.predict(
        instances=[image_gen_prompt],
        parameters=sampling_parameters,
    )
    image_bytes = base64.b64decode(prediction.predictions[0])
    return Image.open(io.BytesIO(image_bytes))


def guess_game(riddle: str) -> tuple[Image.Image, str, str]:
    """Play one round: solve the riddle, write an image prompt, draw the answer.

    Args:
        riddle: The riddle text entered by the player.

    Returns:
        A tuple ``(image, answer, prompt)`` with the generated image, the
        subject guessed for the riddle, and the prompt used to generate the image.
    """
    solved_subject = generate_subject(riddle)
    drawing_prompt = generate_prompt(solved_subject)
    return generate_image(drawing_prompt), solved_subject, drawing_prompt


def increment_counter(counter: int) -> tuple[int, Image.Image | None, str, str, str]:
    """Add one to the score and clear the round's fields.

    Args:
        counter: Current number of correct guesses.

    Returns:
        A 5-tuple: the counter plus one, ``None`` (no image), and three empty
        strings.
    """
    updated_count = counter + 1
    cleared_image = None
    return (updated_count, cleared_image, "", "", "")


def reset_game(counter: int) -> tuple[int, Image.Image | None, str, str, str]:
    """Clear the round's fields while leaving the score untouched.

    Args:
        counter: Current number of correct guesses.

    Returns:
        A 5-tuple: the counter unchanged, ``None`` (no image), and three empty
        strings.
    """
    cleared_image = None
    blank = ""
    return (counter, cleared_image, blank, blank, blank)

## Build the app

Put together the functions you defined to build your Gradio app.

In [None]:
# Assemble the Gradio UI; the Ocean theme applies to the whole Blocks app
with gr.Blocks(theme=gr.themes.Ocean()) as app:

    # Title, plus the state object that holds the number of correct guesses
    with gr.Row():
        gr.Markdown("# Guess What Game ❓")
        counter_state = gr.State(value=0)

    # Textbox where the player types the riddle/description
    with gr.Row():
        prompt_input = gr.Textbox(label="Describe someone or something 💬 ")

    # Button that starts a round
    submit_btn = gr.Button("Submit")

    # Results of a round: image prompt, generated image and guessed answer
    with gr.Row():
        image_prompt = gr.Textbox(
            label="Generated Image Prompt with Gemini 🎨 ", visible=True
        )
        image_output = gr.Image(label="Generated Image 🖼️ ")
        answer_output = gr.Textbox(
            label="Generated Answer with Gemini 🌌 ", interactive=False
        )

    # Buttons that end a round, with or without scoring a point
    with gr.Row():
        correct_btn = gr.Button("+1 Correct")
        reset_btn = gr.Button("Reset")

    # Read-only view of the score
    counter_display = gr.Number(value=0, label="Correct Guesses 👍", interactive=False)

    # Submit: run the full riddle -> answer -> prompt -> image flow
    submit_btn.click(
        guess_game,
        inputs=[prompt_input],
        outputs=[image_output, answer_output, image_prompt],
    )

    # "+1 Correct" and "Reset" are wired identically: the handler updates the
    # counter state and clears the round's fields, then the counter state is
    # copied into the visible score.
    for control_btn, handler in (
        (correct_btn, increment_counter),
        (reset_btn, reset_game),
    ):
        control_btn.click(
            handler,
            inputs=[counter_state],
            outputs=[
                counter_state,
                image_output,
                answer_output,
                image_prompt,
                prompt_input,
            ],
        ).then(lambda x: x, inputs=[counter_state], outputs=[counter_display])

### Launch the app

You are now ready to launch your "Guess who or what" Gradio app.

In [None]:
# Start the Gradio app (`app` is the gr.Blocks object built in the previous cell).
app.launch()

## Cleaning up

In [None]:
# Toggle which resources are cleaned up.
delete_app = True
delete_endpoint = True
delete_model = True

if delete_app:
    # Shut down the running Gradio app.
    app.close()

if delete_endpoint:
    # `deployed_flux_model` is the endpoint returned by `flux_model.deploy(...)`.
    # An endpoint that still has a model deployed cannot be deleted, so pass
    # force=True to undeploy all models from it first.
    deployed_flux_model.delete(force=True)

if delete_model:
    # Remove the registered model; it is no longer deployed at this point
    # when the endpoint above was deleted.
    flux_model.delete()