-
Notifications
You must be signed in to change notification settings - Fork 6.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add Vertex AI Grounding samples (#11246)
* feat: Add Vertex AI Grounding samples * 🦉 Updates from OwlBot post-processor See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md * Fix lint errors & test errors * Fixed lint error * Removed unused import * Fix lint error * Fixed import order * Change Gemini Model Name to `gemini-1.0-pro` * Update copyright to 2024 --------- Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
- Loading branch information
1 parent
9855e23
commit d426ed8
Showing
4 changed files
with
171 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# [START aiplatform_gemini_grounding] | ||
from typing import Optional | ||
|
||
import vertexai | ||
from vertexai.preview.generative_models import ( | ||
GenerationResponse, | ||
GenerativeModel, | ||
grounding, | ||
Tool, | ||
) | ||
|
||
|
||
def generate_text_with_grounding( | ||
project_id: str, location: str, data_store_path: Optional[str] = None | ||
) -> GenerationResponse: | ||
# Initialize Vertex AI | ||
vertexai.init(project=project_id, location=location) | ||
|
||
# Load the model | ||
model = GenerativeModel(model_name="gemini-1.0-pro") | ||
|
||
# Create Tool for grounding | ||
if data_store_path: | ||
# Use Vertex AI Search data store | ||
# Format: projects/{project_id}/locations/{location}/collections/default_collection/dataStores/{data_store_id} | ||
tool = Tool.from_retrieval( | ||
grounding.Retrieval(grounding.VertexAISearch(datastore=data_store_path)) | ||
) | ||
else: | ||
# Use Google Search for grounding (Private Preview) | ||
tool = Tool.from_google_search_retrieval(grounding.GoogleSearchRetrieval()) | ||
|
||
prompt = "What are the price, available colors, and storage size options of a Pixel Tablet?" | ||
response = model.generate_content(prompt, tools=[tool]) | ||
|
||
print(response) | ||
|
||
# [END aiplatform_gemini_grounding] | ||
return response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
# [START aiplatform_sdk_grounding] | ||
from typing import Optional | ||
|
||
import vertexai | ||
from vertexai.language_models import ( | ||
GroundingSource, | ||
TextGenerationModel, | ||
TextGenerationResponse, | ||
) | ||
|
||
|
||
def grounding( | ||
project_id: str, | ||
location: str, | ||
data_store_location: Optional[str], | ||
data_store_id: Optional[str], | ||
) -> TextGenerationResponse: | ||
"""Grounding example with a Large Language Model""" | ||
|
||
vertexai.init(project=project_id, location=location) | ||
|
||
# TODO developer - override these parameters as needed: | ||
parameters = { | ||
"temperature": 0.7, # Temperature controls the degree of randomness in token selection. | ||
"max_output_tokens": 256, # Token limit determines the maximum amount of text output. | ||
"top_p": 0.8, # Tokens are selected from most probable to least until the sum of their probabilities equals the top_p value. | ||
"top_k": 40, # A top_k of 1 means the selected token is the most probable among all tokens. | ||
} | ||
|
||
model = TextGenerationModel.from_pretrained("text-bison@002") | ||
|
||
if data_store_id and data_store_location: | ||
# Use Vertex AI Search data store | ||
grounding_source = GroundingSource.VertexAISearch( | ||
data_store_id=data_store_id, location=data_store_location | ||
) | ||
else: | ||
# Use Google Search for grounding (Private Preview) | ||
grounding_source = GroundingSource.WebSearch() | ||
|
||
response = model.predict( | ||
"What are the price, available colors, and storage size options of a Pixel Tablet?", | ||
grounding_source=grounding_source, | ||
**parameters, | ||
) | ||
print(f"Response from Model: {response.text}") | ||
print(f"Grounding Metadata: {response.grounding_metadata}") | ||
# [END aiplatform_sdk_grounding] | ||
|
||
return response | ||
|
||
|
||
if __name__ == "__main__": | ||
grounding() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Copyright 2024 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
|
||
import backoff | ||
from google.api_core.exceptions import ResourceExhausted | ||
|
||
import grounding | ||
|
||
|
||
_PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT") | ||
_LOCATION = "us-central1" | ||
|
||
|
||
@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10) | ||
def test_grounding() -> None: | ||
data_store_id = "test-search-engine_1689960780551" | ||
response = grounding.grounding( | ||
project_id=_PROJECT_ID, | ||
location=_LOCATION, | ||
data_store_location="global", | ||
data_store_id=data_store_id, | ||
) | ||
assert response | ||
assert response.text | ||
assert response.grounding_metadata |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters