# [START aiplatform_gemini_function_calling]
import vertexai
from vertexai.generative_models import (
    Content,
    FunctionDeclaration,
    GenerativeModel,
    Part,
    Tool,
)


def generate_function_call(prompt: str, project_id: str, location: str) -> tuple:
    """Demonstrate a full Gemini function-calling round trip.

    Sends *prompt* to the model with a weather Tool attached, and when the
    model requests the ``get_current_weather`` function, feeds a (synthetic)
    API payload back so the model can produce a natural-language answer.

    Args:
        prompt: The user's natural-language question (e.g. about weather).
        project_id: Google Cloud project ID used to initialize Vertex AI.
        location: Vertex AI region, e.g. ``"us-central1"``.

    Returns:
        A ``(summary, response)`` pair: the model's final text and the full
        ``GenerationResponse`` object from the last model call.
        (The original annotated ``-> str`` but always returned this 2-tuple.)
    """
    # Initialize Vertex AI
    vertexai.init(project=project_id, location=location)

    # Initialize Gemini model
    model = GenerativeModel("gemini-1.0-pro")

    # Specify a function declaration and parameters for an API request.
    # NOTE(review): this section sits between diff hunks in the patch and was
    # reconstructed from the surrounding context lines — confirm against the
    # upstream sample.
    get_current_weather_func = FunctionDeclaration(
        name="get_current_weather",
        description="Get the current weather in a given location",
        # Function parameters are specified in OpenAPI JSON schema format
        parameters={
            "type": "object",
            "properties": {"location": {"type": "string", "description": "Location"}},
        },
    )

    # Define a tool that includes the function declaration above
    weather_tool = Tool(
        function_declarations=[get_current_weather_func],
    )

    # Define the user's prompt in a Content object that we can reuse in model calls
    user_prompt_content = Content(
        role="user",
        parts=[
            Part.from_text(prompt),
        ],
    )

    # Send the prompt and instruct the model to generate content using the
    # Tool that you just created
    response = model.generate_content(
        user_prompt_content,
        generation_config={"temperature": 0},
        tools=[weather_tool],
    )
    response_function_call_content = response.candidates[0].content

    # Hoist the deeply nested function-call part instead of repeating the
    # chain `response.candidates[0].content.parts[0].function_call`.
    function_call = response_function_call_content.parts[0].function_call

    # Check the function name that the model responded with, and make an API
    # call to an external system
    if function_call.name == "get_current_weather":
        # Extract the arguments to use in your API call.  Renamed from
        # `location` to avoid shadowing this function's `location` parameter.
        location_arg = function_call.args["location"]
        # Here you can use your preferred method to make an API request to
        # fetch the current weather, for example:
        # api_response = requests.post(weather_api_url, data={"location": location_arg})

        # In this example, we'll use synthetic data to simulate a response payload from an external API
        api_response = """{ "location": "Boston, MA", "temperature": 38, "description": "Partly Cloudy",
        "icon": "partly-cloudy", "humidity": 65, "wind": { "speed": 10, "direction": "NW" } }"""

        # Return the API response to Gemini so it can generate a model response or request another function call
        response = model.generate_content(
            [
                user_prompt_content,  # User prompt
                response_function_call_content,  # Function call response
                Content(
                    role="function",
                    parts=[
                        Part.from_function_response(
                            name="get_current_weather",
                            response={
                                "content": api_response,  # Return the API response to Gemini
                            },
                        )
                    ],
                ),
            ],
            tools=[weather_tool],
        )
        # Get the model summary response
        summary = response.candidates[0].content.parts[0].text

        return summary, response

    # The model answered directly without requesting the tool.  The original
    # code implicitly returned None here, which breaks callers that unpack a
    # (summary, response) pair; return the model's direct text instead.
    return response.candidates[0].content.parts[0].text, response


# [END aiplatform_gemini_function_calling]
_LOCATION = "us-central1"

# Substrings that must appear in the model's final text summary.
summary_expected = [
    "Boston",
]

# Substrings that must appear in the stringified GenerationResponse object
# (standard candidate/content structure plus the queried city).
response_expected = [
    "candidates",
    "content",
    "role",
    "model",
    "parts",
    "Boston",
]


@backoff.on_exception(backoff.expo, ResourceExhausted, max_time=10)
def test_function_calling() -> None:
    # Run the full function-calling round trip against the live API.
    summary, response = function_calling.generate_function_call(
        prompt="What is the weather like in Boston?",
        project_id=_PROJECT_ID,
        location=_LOCATION,
    )
    # Stringify once, then check each expected fragment individually so a
    # failure pinpoints the missing substring.
    summary_text = str(summary)
    response_text = str(response)
    for fragment in summary_expected:
        assert fragment in summary_text
    for fragment in response_expected:
        assert fragment in response_text