In [None]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Intro to Controlled Generation with the Gemini API

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/gemini/controlled-generation/intro_controlled_generation.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Google Colaboratory logo"><br> Open in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fgemini%2Fcontrolled-generation%2Fintro_controlled_generation.ipynb">
      <img width="32px" src="https://cloud.google.com/ml-engine/images/colab-enterprise-logo-32px.png" alt="Google Cloud Colab Enterprise logo"><br> Open in Colab Enterprise
    </a>
  </td>    
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/gemini/controlled-generation/intro_controlled_generation.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo"><br> Open in Workbench
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/controlled-generation/intro_controlled_generation.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
</table>

| | |
|-|-|
|Author(s) | [Eric Dong](https://github.com/gericdong)|

## Overview

### Gemini

Gemini is a family of generative AI models developed by Google DeepMind that is designed for multimodal use cases.

### Controlled Generation

Depending on your application, you may want the model response to a prompt to be returned in a structured data format, particularly if you are using the responses for downstream processes, such as downstream modules that expect a specific format as input. The Gemini API provides the controlled generation capability to constrain the model output to a structured format.


### Objectives

In this tutorial, you learn how to use the controlled generation capability in the Vertex AI Gemini API to generate model responses in a JSON object with specific fields.

You will complete the following tasks:

- Using `response_mime_type` with the Gemini 1.5 Flash models
- Using `response_mime_type` and `response_schema` with the Gemini 1.5 Pro models
- Using controlled generation in use cases requiring output constraints


## Get started

### Install Vertex AI SDK and other required packages


In [None]:
# Install or upgrade the Vertex AI SDK (google-cloud-aiplatform) into the
# current user's site-packages, with pip's progress output suppressed.
%pip install --upgrade --user --quiet google-cloud-aiplatform

### Restart runtime

To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.

The restart might take a minute or longer. After it's restarted, continue to the next step.

In [None]:
import IPython

# Shut the running kernel down with restart=True so that the packages
# installed above are loaded when it comes back up.
IPython.Application.instance().kernel.do_shutdown(True)

<div class="alert alert-block alert-warning">
<b>⚠️ The kernel is going to restart. Wait until it's finished before continuing to the next step. ⚠️</b>
</div>


### Authenticate your notebook environment (Colab only)

If you're running this notebook on Google Colab, run the cell below to authenticate your environment.

In [None]:
import sys

# Interactive user authentication is only performed on Colab; in any other
# environment this cell does nothing.
running_in_colab = "google.colab" in sys.modules
if running_in_colab:
    from google.colab import auth

    auth.authenticate_user()

### Set Google Cloud project information and initialize Vertex AI SDK

To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).

Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).

In [None]:
PROJECT_ID = "[your-project-id]"  # @param {type:"string"}
LOCATION = "us-central1"  # @param {type:"string"}

import os

import vertexai

# If the placeholder was not replaced, fall back to the GOOGLE_CLOUD_PROJECT
# environment variable (when it is set) instead of sending the literal
# placeholder to Vertex AI, which only fails later with a confusing error.
if not PROJECT_ID or PROJECT_ID == "[your-project-id]":
    PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT", "")

if not PROJECT_ID:
    raise ValueError(
        "Set PROJECT_ID to your Google Cloud project ID "
        "(or define the GOOGLE_CLOUD_PROJECT environment variable)."
    )

vertexai.init(project=PROJECT_ID, location=LOCATION)

## Code Examples

### Import libraries

In [None]:
import json

from vertexai import generative_models
from vertexai.generative_models import GenerationConfig, GenerativeModel, Part

### Using `response_mime_type` with the Gemini 1.5 Flash models

You can have the model return its output in a specific format by setting the `response_mime_type` configuration option in `generation_config` and describing the format you want in the prompt.

In [None]:
# Only request JSON output here; the expected shape of that JSON is described
# in the prompt text rather than passed as a schema.
json_only_config = {"response_mime_type": "application/json"}
model = GenerativeModel("gemini-1.5-flash", generation_config=json_only_config)

In the prompt, describe the format you want the response to follow.

In [None]:
# The desired JSON shape is spelled out inside the prompt itself, written in
# a Python-like type notation the model can follow.
prompt = (
    "\n"
    "    List a few popular cookie recipes using this JSON schema:\n"
    '    Recipe = {"recipe_name": str}\n'
    "    Return: list[Recipe]\n"
)

Generate the content and parse the response string to JSON.

In [6]:
response = model.generate_content(prompt)

# The model was only asked to emit JSON, so decode the returned text here.
response_payload = response.text
json_response = json.loads(response_payload)
print(json_response)

[{'recipe_name': 'Chocolate Chip Cookies'}, {'recipe_name': 'Oatmeal Raisin Cookies'}, {'recipe_name': 'Snickerdoodles'}, {'recipe_name': 'Sugar Cookies'}, {'recipe_name': 'Peanut Butter Cookies'}]


### Using `response_mime_type` and `response_schema` with the Gemini 1.5 Pro models

While Gemini 1.5 Flash models only accept a text description of the schema you want returned, the Gemini 1.5 Pro models let you pass a data structure in the `response_schema` parameter in `generation_config`, and the model output will strictly follow that schema.

Note that when `response_schema` is specified, the `response_mime_type` has to be set to `application/json`.

In [None]:
# Switch to Gemini 1.5 Pro, which accepts an explicit `response_schema`.
model = GenerativeModel(model_name="gemini-1.5-pro")

Following the previous example, define the data structure for the model output. Note that all of the fields in the JSON are optional by default unless specified in the `required` field.

In [8]:
# Each recipe is an object whose single field, `recipe_name`, is mandatory.
recipe_schema = {
    "type": "OBJECT",
    "properties": {"recipe_name": {"type": "STRING"}},
    "required": ["recipe_name"],
}

# The response as a whole is a list of such recipes.
response_schema = {"type": "ARRAY", "items": recipe_schema}

When prompting the model to generate the content, pass the schema to the `response_schema` field of the `generation_config`.

In [9]:
# Pair the JSON MIME type with the schema; the schema is what makes the
# output follow the exact structure defined above.
schema_config = GenerationConfig(
    response_mime_type="application/json",
    response_schema=response_schema,
)
response = model.generate_content(
    "List a few popular cookie recipes", generation_config=schema_config
)

print(response.text)

[{"recipe_name": "Classic Chocolate Chip Cookies"}, {"recipe_name": "Peanut Butter Cookies"}, {"recipe_name": "Snickerdoodles"}, {"recipe_name": "Oatmeal Raisin Cookies"}, {"recipe_name": "Shortbread Cookies"}] 


You can parse the response string to JSON.

In [10]:
# Decode the schema-constrained JSON text into Python objects.
recipes_text = response.text
json_response = json.loads(recipes_text)
print(json_response)

[{'recipe_name': 'Classic Chocolate Chip Cookies'}, {'recipe_name': 'Peanut Butter Cookies'}, {'recipe_name': 'Snickerdoodles'}, {'recipe_name': 'Oatmeal Raisin Cookies'}, {'recipe_name': 'Shortbread Cookies'}]


### Using controlled generation in use cases requiring output constraints

Controlled generation can be used to ensure that model outputs adhere to a specific structure (e.g., JSON), restrict the model to choosing from a fixed set of options (e.g., sentiment classification), or follow a certain style or set of guidelines.

Let's use controlled generation with the Gemini 1.5 Pro models in the following use cases that require output constraints.

In [None]:
# Gemini 1.5 Pro model used by all of the use-case examples below.
model = GenerativeModel(model_name="gemini-1.5-pro")

#### **Example**: Generate game character profile

In this example, you instruct the model to create a game character profile with some specific requirements, and constrain the model output to a structured format. This example also demonstrates how to configure the `response_schema` and `response_mime_type` fields in `generation_config` in conjunction with `safety_settings`.

In [12]:
# Schema for a single child: both name and age are mandatory.
child_schema = {
    "type": "OBJECT",
    "properties": {
        "name": {"type": "STRING"},
        "age": {"type": "INTEGER"},
    },
    "required": ["name", "age"],
}

# Schema for one character profile. `background` and `playable` are not
# listed in `required`, so the model may leave them out.
character_schema = {
    "type": "OBJECT",
    "properties": {
        "name": {"type": "STRING"},
        "age": {"type": "INTEGER"},
        "occupation": {"type": "STRING"},
        "background": {"type": "STRING"},
        "playable": {"type": "BOOLEAN"},
        "children": {"type": "ARRAY", "items": child_schema},
    },
    "required": ["name", "age", "occupation", "children"],
}

# The top-level response is an array of character profiles.
response_schema = {"type": "ARRAY", "items": character_schema}

prompt = """
    Generate a character profile for a video game, including the character's name, age, occupation, background, names of their
    three children, and whether they can be controlled by the player.
"""

# Per-category blocking thresholds; each category uses a different level here.
# NOTE(review): BLOCK_NONE disables filtering for that category -- confirm this
# is intended for the demo.
harm_category = generative_models.HarmCategory
block_threshold = generative_models.HarmBlockThreshold
character_safety_settings = {
    harm_category.HARM_CATEGORY_HATE_SPEECH: block_threshold.BLOCK_MEDIUM_AND_ABOVE,
    harm_category.HARM_CATEGORY_DANGEROUS_CONTENT: block_threshold.BLOCK_ONLY_HIGH,
    harm_category.HARM_CATEGORY_HARASSMENT: block_threshold.BLOCK_LOW_AND_ABOVE,
    harm_category.HARM_CATEGORY_SEXUALLY_EXPLICIT: block_threshold.BLOCK_NONE,
}

response = model.generate_content(
    prompt,
    generation_config=GenerationConfig(
        response_mime_type="application/json", response_schema=response_schema
    ),
    safety_settings=character_safety_settings,
)

print(response.text)

 [{
        "age": 42,
        "children": [
          {
            "age": 21,
            "name": "Merida"
          },
          {
            "age": 18,
            "name": "Fergus"
          },
          {
            "age": 18,
            "name": "Harris"
          }
        ],
        "name": "Eleanor",
        "occupation": "Queen",
        "background": "Eleanor, the beloved ruler of a prosperous kingdom, is known for her wisdom, grace, and unwavering strength.  After the untimely death of her husband, she has successfully navigated countless challenges, earning her the admiration of both her people and neighboring rulers.  However, a new threat emerges, one that will test Eleanor's mettle and force her to confront her past",
        "playable": false
      },
      {
        "age": 25,
        "children": [],
        "name": "Kaelen",
        "occupation": "Hunter",
        "background": "Kaelen is a skilled hunter and tracker who lives off the land, relying on his instincts

#### **Example**: Extract errors from log data

In this example, you use the model to pull out specific error messages from unstructured log data, extract key information, and constrain the model output to a structured format.


In [13]:
# One extracted error entry; every field is mandatory.
log_error_schema = {
    "type": "OBJECT",
    "properties": {
        "timestamp": {"type": "STRING"},
        "error_code": {"type": "INTEGER"},
        "error_message": {"type": "STRING"},
    },
    "required": ["timestamp", "error_message", "error_code"],
}
response_schema = {"type": "ARRAY", "items": log_error_schema}

# Raw, unstructured log lines (including one INFO line) for the model to
# sift through.
prompt = """
[15:43:28] ERROR: Could not process image upload: Unsupported file format. (Error Code: 308)
[15:44:10] INFO: Search index updated successfully.
[15:45:02] ERROR: Service dependency unavailable (payment gateway). Retrying... (Error Code: 5522)
[15:45:33] ERROR: Application crashed due to out-of-memory exception. (Error Code: 9001)
"""

log_config = GenerationConfig(
    response_mime_type="application/json",
    response_schema=response_schema,
)
response = model.generate_content(prompt, generation_config=log_config)

print(response.text)

[{"error_code": 308, "error_message": "Could not process image upload: Unsupported file format." , "timestamp": "15:43:28"}, {"error_code": 5522, "error_message": "Service dependency unavailable (payment gateway). Retrying..." , "timestamp": "15:45:02"}, {"error_code": 9001, "error_message": "Application crashed due to out-of-memory exception." , "timestamp": "15:45:33"}] 


#### **Example**: Analyze product review data

In this example, you instruct the model to analyze product review data, extract key entities, perform sentiment classification (choosing from a fixed set of labels), provide additional explanation, and output the results in JSON format.

In [14]:
# One analyzed review. All four fields are mandatory, and `sentiment` is
# limited to three labels via `enum`.
review_schema = {
    "type": "OBJECT",
    "properties": {
        "rating": {"type": "INTEGER"},
        "flavor": {"type": "STRING"},
        "sentiment": {
            "type": "STRING",
            "enum": ["POSITIVE", "NEGATIVE", "NEUTRAL"],
        },
        "explanation": {"type": "STRING"},
    },
    "required": ["rating", "flavor", "sentiment", "explanation"],
}

# Note the nesting: an array whose items are themselves arrays of reviews.
response_schema = {
    "type": "ARRAY",
    "items": {"type": "ARRAY", "items": review_schema},
}

prompt = (
    "\n"
    "  Analyze the following product reviews, output the sentiment classification and give an explanation.\n"
    "  \n"
    "  - \"Absolutely loved it! Best ice cream I've ever had.\" Rating: 4, Flavor: Strawberry Cheesecake\n"
    "  - \"Quite good, but a bit too sweet for my taste.\" Rating: 1, Flavor: Mango Tango\n"
)

review_config = GenerationConfig(
    response_mime_type="application/json",
    response_schema=response_schema,
)
response = model.generate_content(prompt, generation_config=review_config)

print(response.text)

[
  [
    {
      "explanation": "Strong positive sentiment with superlative language (\"best ever\")",
      "flavor": "Strawberry Cheesecake",
      "rating": 4,
      "sentiment": "POSITIVE"
    }
  ],
  [
    {
      "explanation": "Mixed sentiment - acknowledges positive aspects (\"quite good\") but expresses a negative preference (\"too sweet\")",
      "flavor": "Mango Tango",
      "rating": 1,
      "sentiment": "NEGATIVE"
    }
  ]
] 


#### **Example**: Detect objects in images

You can also use controlled generation in multimodal use cases. In this example, you instruct the model to detect objects in the images and output the results in JSON format. These images are stored in a Google Cloud Storage bucket.

- [office-desk.jpeg](https://storage.googleapis.com/cloud-samples-data/generative-ai/image/office-desk.jpeg)
- [gardening-tools.jpeg](https://storage.googleapis.com/cloud-samples-data/generative-ai/image/gardening-tools.jpeg)

In [13]:
# Outer array: presumably one inner list per input image (as in the sample
# output). Inner array: the objects detected in that image. "object" is
# listed under `required` so that, like the other schemas in this notebook,
# every item must carry the field; without it, fields are optional by
# default and empty `{}` items would be schema-valid.
response_schema = {
    "type": "ARRAY",
    "items": {
        "type": "ARRAY",
        "items": {
            "type": "OBJECT",
            "properties": {
                "object": {"type": "STRING"},
            },
            "required": ["object"],
        },
    },
}

prompt = "Generate a list of objects in the images."

response = model.generate_content(
    [
        Part.from_uri(
            "gs://cloud-samples-data/generative-ai/image/office-desk.jpeg",
            "image/jpeg",
        ),
        Part.from_uri(
            "gs://cloud-samples-data/generative-ai/image/gardening-tools.jpeg",
            "image/jpeg",
        ),
        prompt,
    ],
    generation_config=GenerationConfig(
        response_mime_type="application/json", response_schema=response_schema
    ),
)

print(response.text)

[
    [{"object": "globe"}, {"object": "tablet"}, {"object": "shopping cart"}, {"object": "eiffel tower"}, {"object": "airplane"}, {"object": "passport"}, {"object": "keyboard"}, {"object": "computer mouse"}, {"object": "sunglasses"}, {"object": "money"}, {"object": "notebook"}, {"object": "pen"}, {"object": "coffee cup"}],
    [{"object": "watering can"}, {"object": "plant"}, {"object": "flower pot"}, {"object": "flower pot"}, {"object": "garden gloves"}, {"object": "garden trowel"}, {"object": "garden hand tool"}]
] 
