In [1]:
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Getting Started with the Live API in Vertex AI using WebSockets



## Overview

The Live API enables low-latency bidirectional voice and video interactions with Gemini. The API can process text, audio, and video input, and it can provide text and audio output. 

See the [Live API](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/multimodal-live) page for more details.

This tutorial demonstrates the following simple examples to help you get started with the Live API in Vertex AI using [WebSockets](https://en.wikipedia.org/wiki/WebSocket).

- Using Gemini 2.0 Flash
  - Text-to-text generation
  - Text-to-audio generation
  - Text-to-audio conversation
  - Function calling
  - Code execution
  - Audio transcription
- Using Gemini 2.5 Flash native audio dialog
  - Proactive audio
  - Affective dialog

## Getting Started

### Install libraries

In [2]:
%pip install --upgrade --quiet websockets

Note: you may need to restart the kernel to use updated packages.


### Authenticate your notebook environment (Colab only)

If you are running this notebook on Google Colab, run the cell below to authenticate your environment.

In [3]:
import sys

# `google.colab` is only present in sys.modules when running inside Colab;
# everywhere else this cell does nothing and the gcloud credentials already
# configured on the machine are used.
_running_in_colab = "google.colab" in sys.modules
if _running_in_colab:
    from google.colab import auth

    auth.authenticate_user()

### Set Google Cloud project information

To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com). Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment).

In [4]:
# Set your project ID below, or leave the placeholder and export
# GOOGLE_CLOUD_PROJECT so it is picked up from the environment.
PROJECT_ID = "[your-project-id]"  # @param {type: "string", placeholder: "[your-project-id]", isTemplate: true}
if not PROJECT_ID or PROJECT_ID == "[your-project-id]":
    import os

    PROJECT_ID = os.environ.get("GOOGLE_CLOUD_PROJECT", PROJECT_ID)

LOCATION = "us-central1"  # @param {type: "string"}

# Derive the regional endpoint from LOCATION so that changing the region above
# also changes the host the WebSocket connects to.
HOST = f"{LOCATION}-aiplatform.googleapis.com"
SERVICE_URL = (
    f"wss://{HOST}/ws/google.cloud.aiplatform.v1.LlmBidiService/BidiGenerateContent"
)

## Generate an access token

`gcloud auth application-default print-access-token` generates and prints an access token for the current Application Default Credentials (ADC). By default the token is valid for 3600 seconds (one hour), so rerun this cell if later cells start failing to authenticate. The token is sent as a `Bearer` token in the `Authorization` header when connecting to the WebSocket server.

In [5]:
# IPython's `!cmd` assignment captures the command's stdout as a list of lines;
# the access token is the first line and is read later as bearer_token[0].
bearer_token = !gcloud auth application-default print-access-token

### Import libraries


In [6]:
import base64
import json

from IPython.display import Audio, Markdown, display
import numpy as np
from websockets.asyncio.client import connect

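## Using Gemini 2.0 Flash

The Live API is a new capability introduced with the [Gemini 2.0 Flash model](https://cloud.google.com/vertex-ai/generative-ai/docs/gemini-v2). Note that the next cell sets `MODEL_ID` to the Gemini 2.5 Flash native audio model (`gemini-live-2.5-flash-preview-native-audio`). The Gemini 2.0 Flash model ID is kept there as a comment in case you want to switch back.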

In [12]:
# Gemini 2.0 Flash Live model, kept for reference:
# MODEL_ID = "gemini-2.0-flash-live-preview-04-09"  # @param {type: "string"}

# Gemini 2.5 Flash native audio model used by the cells below.
MODEL_ID = "gemini-live-2.5-flash-preview-native-audio"

# Fully-qualified Vertex AI publisher-model resource name sent in the setup message.
_model_path_segments = (
    "projects", PROJECT_ID,
    "locations", LOCATION,
    "publishers", "google",
    "models", MODEL_ID,
)
MODEL = "/".join(_model_path_segments)

In [13]:
def get_current_weather(location: str) -> str:
    """Example method. Returns the current weather.

    Matching ignores letter case, surrounding whitespace and anything after
    the first comma, so "San Francisco, CA", "san francisco" and
    "San Francisco" all resolve to the same entry.

    Args:
        location: The city and state, e.g. San Francisco, CA

    Returns:
        A short weather description, or "unknown" when the city is not in the
        demo table (or `location` is not a string).
    """
    weather_map: dict[str, str] = {
        "Boston": "snowing",
        "San Francisco": "foggy",
        "Seattle": "raining",
        "Austin": "hot",
        "Chicago": "windy",
    }
    if not isinstance(location, str):
        return "unknown"

    # Keys are plain city names, but callers (including the model) may append
    # a state, so compare only the normalized city part.
    city = location.split(",", 1)[0].strip().casefold()
    weather_by_city = {name.casefold(): weather for name, weather in weather_map.items()}
    return weather_by_city.get(city, "unknown")

def get_current_stock_price(company: str) -> str:
    """Example method. Looks up a canned stock price for a company.

    Args:
        company: Name of a company, e.g. Google (matched exactly).

    Returns:
        The price as a string such as "200 USD", or "unknown" when the company
        is not in the demo table.
    """
    known_prices: dict[str, str] = {"Google": "200 USD"}
    if company in known_prices:
        return known_prices[company]
    return "unknown"


# Set model generation_config
CONFIG = {"response_modalities": ["TEXT"]}

# Tools offered to the model in the session setup: server-side Google Search
# plus two locally implemented functions. `tools` is a list of Tool objects and
# `function_declarations` is a list inside a Tool, so both functions go into one
# list (a dict literal with two "function_declarations" keys would silently
# keep only the last one).
TOOLS = [
    {"google_search": {}},
    {
        "function_declarations": [
            {
                "name": "get_current_weather",
                "description": "Get the current weather in the given location",
                "parameters": {
                    "type": "OBJECT",
                    "properties": {
                        "location": {
                            "type": "STRING",
                            "description": "The city and state, e.g. San Francisco, CA",
                        }
                    },
                    "required": ["location"],
                },
            },
            {
                "name": "get_current_stock_price",
                "description": "Get the current stock price of the given company",
                "parameters": {
                    "type": "OBJECT",
                    "properties": {"company": {"type": "STRING"}},
                    "required": ["company"],
                },
            },
        ]
    },
]

In [14]:
# Set model generation_config
GENERATION_CONFIG = {
    "response_modalities": ["AUDIO"],
}


headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {bearer_token[0]}",
}

# Connect to the server
async with connect(SERVICE_URL, additional_headers=headers) as ws:
    # Setup the session
    await ws.send(
        json.dumps(
            {
                "setup": {
                    "model": MODEL,
                    "generation_config": GENERATION_CONFIG,
                    "input_audio_transcription": {},
                    "output_audio_transcription": {},
                    "proactivity": {"proactive_audio": True},
                    "tools":TOOLS
                }
            }
        )
    )

    # Receive setup response
    raw_response = await ws.recv(decode=False)
    setup_response = json.loads(raw_response.decode("ascii"))

    # Send text message
    text_input = "Get the current weather in Chicago"  #should invoke function get_current_weather
    text_input = "Get current stock prices of Google"  #should invoke function get_current_stock_prices
    text_input = "Who is the current vice-president of India"  # should invoke Google Search,
    display(Markdown(f"**Input:** {text_input}"))

    msg = {
        "client_content": {
            "turns": [{"role": "user", "parts": [{"text": text_input}]}],
            "turn_complete": True,
        }
    }

    await ws.send(json.dumps(msg))

    responses = []
    input_transcriptions = []
    output_transcriptions = []

    # Receive chucks of server response
    async for raw_response in ws:
        response = json.loads(raw_response.decode())
        print(response)
    
        tool_call = response.pop("toolCall", None)
        if tool_call is not None:
            
            for function_call in tool_call["functionCalls"]: 
                output_transcriptions.append(f"FunctionCall: {str(function_call)}\n")
                if function_call["name"] == "get_current_weather":
                    result = get_current_weather(**function_call["args"])
                    output_transcriptions.append(result)
                if function_call["name"] == "get_current_stock_price":
                    result = get_current_stock_price(**function_call["args"])
                    output_transcriptions.append(result)

        
        server_content = response.pop("serverContent", None)
        if server_content is None:
            break

        if ( input_transcription := server_content.get("inputTranscription")) is not None:
            if (text := input_transcription.get("text")) is not None:
                input_transcriptions.append(text)
        if (
            output_transcription := server_content.get("outputTranscription")
        ) is not None:
            if (text := output_transcription.get("text")) is not None:
                output_transcriptions.append(text)

        model_turn = server_content.pop("modelTurn", None)
        if model_turn is not None:
            parts = model_turn.pop("parts", None)
            if parts is not None:
                for part in parts:
                    pcm_data = base64.b64decode(part["inlineData"]["data"])
                    responses.append(np.frombuffer(pcm_data, dtype=np.int16))

        # End of turn
        turn_complete = server_content.pop("turnComplete", None)
        if turn_complete:
            break

    if input_transcriptions:
        display(Markdown(f"**Input transcription >** {''.join(input_transcriptions)}"))

    if responses:
        # Play the returned audio message
        display(Audio(np.concatenate(responses), rate=24000, autoplay=True))

    if output_transcriptions:
        display(
            Markdown(f"**Output transcription >** {''.join(output_transcriptions)}")
        )

**Input:** Who is the current vice-president of India

{'serverContent': {'outputTranscription': {'finished': True}}}
{'serverContent': {'generationComplete': True}}
{'serverContent': {'turnComplete': True, 'groundingMetadata': {}}, 'usageMetadata': {'promptTokenCount': 9, 'totalTokenCount': 9, 'promptTokensDetails': [{'modality': 'TEXT', 'tokenCount': 9}]}}


In [15]:
# Set model generation_config
GENERATION_CONFIG = {
    "response_modalities": ["AUDIO"],
}


headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {bearer_token[0]}",
}

# Connect to the server
async with connect(SERVICE_URL, additional_headers=headers) as ws:
    # Setup the session
    await ws.send(
        json.dumps(
            {
                "setup": {
                    "model": MODEL,
                    "generation_config": GENERATION_CONFIG,
                    "input_audio_transcription": {},
                    "output_audio_transcription": {},
                    "proactivity": {"proactive_audio": True},
                    "tools":TOOLS
                }
            }
        )
    )

    # Receive setup response
    raw_response = await ws.recv(decode=False)
    setup_response = json.loads(raw_response.decode("ascii"))

    # Send text message
    #text_input = "Get the current weather in Chicago"  #should invoke function get_current_weather
    text_input = "Get current stock prices of Google"  #should invoke function get_current_stock_prices
    #text_input = "Who is the current vice-president of India"  # should invoke Google Search,
    display(Markdown(f"**Input:** {text_input}"))

    msg = {
        "client_content": {
            "turns": [{"role": "user", "parts": [{"text": text_input}]}],
            "turn_complete": True,
        }
    }

    await ws.send(json.dumps(msg))

    responses = []
    input_transcriptions = []
    output_transcriptions = []

    # Receive chucks of server response
    async for raw_response in ws:
        response = json.loads(raw_response.decode())
        print(response)
    
        tool_call = response.pop("toolCall", None)
        if tool_call is not None:
            
            for function_call in tool_call["functionCalls"]: 
                output_transcriptions.append(f"FunctionCall: {str(function_call)}\n")
                if function_call["name"] == "get_current_weather":
                    result = get_current_weather(**function_call["args"])
                    output_transcriptions.append(result)
                if function_call["name"] == "get_current_stock_price":
                    result = get_current_stock_price(**function_call["args"])
                    output_transcriptions.append(result)

        
        server_content = response.pop("serverContent", None)
        if server_content is None:
            break

        if ( input_transcription := server_content.get("inputTranscription")) is not None:
            if (text := input_transcription.get("text")) is not None:
                input_transcriptions.append(text)
        if (
            output_transcription := server_content.get("outputTranscription")
        ) is not None:
            if (text := output_transcription.get("text")) is not None:
                output_transcriptions.append(text)

        model_turn = server_content.pop("modelTurn", None)
        if model_turn is not None:
            parts = model_turn.pop("parts", None)
            if parts is not None:
                for part in parts:
                    pcm_data = base64.b64decode(part["inlineData"]["data"])
                    responses.append(np.frombuffer(pcm_data, dtype=np.int16))

        # End of turn
        turn_complete = server_content.pop("turnComplete", None)
        if turn_complete:
            break

    if input_transcriptions:
        display(Markdown(f"**Input transcription >** {''.join(input_transcriptions)}"))

    if responses:
        # Play the returned audio message
        display(Audio(np.concatenate(responses), rate=24000, autoplay=True))

    if output_transcriptions:
        display(
            Markdown(f"**Output transcription >** {''.join(output_transcriptions)}")
        )

**Input:** Get current stock prices of Google

{'toolCall': {'functionCalls': [{'name': 'get_current_stock_price', 'args': {'company': 'Google'}}]}}


**Output transcription >** FunctionCall: {'name': 'get_current_stock_price', 'args': {'company': 'Google'}}
200 USD

In [16]:
# Set model generation_config
GENERATION_CONFIG = {
    "response_modalities": ["AUDIO"],
}


headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {bearer_token[0]}",
}

# Connect to the server
async with connect(SERVICE_URL, additional_headers=headers) as ws:
    # Setup the session
    await ws.send(
        json.dumps(
            {
                "setup": {
                    "model": MODEL,
                    "generation_config": GENERATION_CONFIG,
                    "input_audio_transcription": {},
                    "output_audio_transcription": {},
                    "proactivity": {"proactive_audio": True},
                    "tools":TOOLS
                }
            }
        )
    )

    # Receive setup response
    raw_response = await ws.recv(decode=False)
    setup_response = json.loads(raw_response.decode("ascii"))

    # Send text message
    text_input = "Hello? Gemini are you there?"
    display(Markdown(f"**Input:** {text_input}"))

    msg = {
        "client_content": {
            "turns": [{"role": "user", "parts": [{"text": text_input}]}],
            "turn_complete": True,
        }
    }

    await ws.send(json.dumps(msg))

    responses = []
    input_transcriptions = []
    output_transcriptions = []

    # Receive chucks of server response
    async for raw_response in ws:
        response = json.loads(raw_response.decode())
        server_content = response.pop("serverContent", None)
        if server_content is None:
            break

        if (
            input_transcription := server_content.get("inputTranscription")
        ) is not None:
            if (text := input_transcription.get("text")) is not None:
                input_transcriptions.append(text)
        if (
            output_transcription := server_content.get("outputTranscription")
        ) is not None:
            if (text := output_transcription.get("text")) is not None:
                output_transcriptions.append(text)

        model_turn = server_content.pop("modelTurn", None)
        if model_turn is not None:
            parts = model_turn.pop("parts", None)
            if parts is not None:
                for part in parts:
                    pcm_data = base64.b64decode(part["inlineData"]["data"])
                    responses.append(np.frombuffer(pcm_data, dtype=np.int16))

        # End of turn
        turn_complete = server_content.pop("turnComplete", None)
        if turn_complete:
            break

    if input_transcriptions:
        display(Markdown(f"**Input transcription >** {''.join(input_transcriptions)}"))

    if responses:
        # Play the returned audio message
        display(Audio(np.concatenate(responses), rate=24000, autoplay=True))

    if output_transcriptions:
        display(
            Markdown(f"**Output transcription >** {''.join(output_transcriptions)}")
        )
     

**Input:** Hello? Gemini are you there?

**Output transcription >** Yes, I am here. How can I help you today?