In [18]:
import vertexai
from vertexai.generative_models import GenerativeModel, FunctionDeclaration, Tool, Content, Part


In [2]:
GOOGLE_PROJECT_ID = "enduring-coil-441309-i5"
LOCATION = "us-central1"


# Set up database schema and populate tables

In [None]:
!pip install ipython-sql

In [None]:
%load_ext sql

In [None]:
%sql sqlite:///sample.db

In [None]:
%%sql
create table if not exists products (
    product_id integer primary key autoincrement,
    product_name varchar(255) not null,
    price decimal(10, 2) not null
);

create table if not exists staff (
    staff_id integer primary key autoincrement,
    first_name varchar(255) not null,
    last_name varchar(255) not null
);

create table if not exists orders (
    order_id integer primary key autoincrement,
    customer_name varchar(255) not null,
    staff_id integer not null,
    product_id integer not null,
    foreign key (staff_id) references staff (staff_id),
    foreign key (product_id) references products (product_id)
);

-- populate

insert into products(product_name, price) values
  	('Laptop', 799.99),
  	('Keyboard', 129.99),
  	('Mouse', 29.99);

INSERT INTO orders (customer_name, staff_id, product_id) VALUES
  	('David Lee', 1, 1),
  	('Emily Chen', 2, 2),
  	('Frank Brown', 1, 3);

In [3]:
import sqlite3

db_file = "sample.db"

db_conn = sqlite3.connect(db_file)

# Create API functions, calling database

In [4]:
def list_tables() -> list[str]:
    """Retrieve the names of all tables in the database."""
    cursor = db_conn.cursor()
    cursor.execute("select * from sqlite_master where type = 'table';")
    table = cursor.fetchall()
    return [t[0] for t in table]


def describe_table(table_name: str) -> list[tuple[str, str]]:
    """Look up the table schema. Returns list of column names, where each entry is a tuple of (column, type)."""
    cursor = db_conn.cursor()
    cursor.execute(f"pragma table_info({table_name});")
    schema = cursor.fetchall()
    return [(col[1], col[2]) for col in schema]


def execute_query(sql: str) -> list[list[str]]:
    """Execute a SELECT statement, returning the results."""
    cursor = db_conn.cursor()
    cursor.execute(sql)
    return cursor.fetchall()


# Set up model for querying

In [5]:
list_tables_func = FunctionDeclaration(
    name="list_tables",
    description="Retrieve the names of all tables in the database.",
    parameters={
        "type": "object",
        "properties": {}
        },
)

describe_table_func = FunctionDeclaration(
    name="describe_table",
    description="Look up the table schema. Returns list of columns, where each entry is a tuple of (column, type).",
    parameters={
        "type": "object",
        "properties": {
            "table_name": {
                "type": "string",
                "description": "Name of the table."
            }
        },
        "required": [
            "table_name"
        ]
    },
)

execute_query_func = FunctionDeclaration(
    name="execute_query",
    description="Execute a SELECT statement, returning the results.",
    parameters={
        "type": "object",
        "properties": {
            "sql": {
                "type": "string",
                "description": "SQL query to the database."
            }
        },
        "required": [
            "sql"
        ]
    },
)


list_tables_func = FunctionDeclaration.from_func(list_tables)
describe_table_func = FunctionDeclaration.from_func(describe_table)
execute_query_func = FunctionDeclaration.from_func(execute_query)

db_tool = vertexai.generative_models.Tool(
    function_declarations = [list_tables_func, describe_table_func, execute_query_func]
)

instruction = """You are a helpful chatbot that can interact with an SQL database for a computer
store. You will take the users questions and turn them into SQL queries using the tools
available. Once you have the information you need, you will answer the user's question using
the data returned. Use list_tables to see what tables are present, describe_table to understand
the schema, and execute_query to issue an SQL SELECT query."""

vertexai.init(project=GOOGLE_PROJECT_ID,
              location=LOCATION)

model = GenerativeModel("gemini-1.5-flash-001",
                        tools = [db_tool],
                        system_instruction = instruction)

chat = model.start_chat()

# The problem 

In [10]:
query = "What is the most expensive product?" 

resp = chat.send_message(content=query,
                         tools=[db_tool])
function_calls = resp.candidates[0].function_calls
func = function_calls[-1] # the last query is a clue to retrieve user's answer


In [11]:
resp

candidates {
  content {
    role: "model"
    parts {
      function_call {
        name: "list_tables"
        args {
        }
      }
    }
    parts {
      text: "\n\n"
    }
    parts {
      function_call {
        name: "describe_table"
        args {
          fields {
            key: "table_name"
            value {
              string_value: "products"
            }
          }
        }
      }
    }
    parts {
      text: "\n\n"
    }
    parts {
      function_call {
        name: "execute_query"
        args {
          fields {
            key: "sql"
            value {
              string_value: "SELECT name FROM products ORDER BY price DESC LIMIT 1"
            }
          }
        }
      }
    }
  }
  finish_reason: STOP
  safety_ratings {
    category: HARM_CATEGORY_HATE_SPEECH
    probability: NEGLIGIBLE
    probability_score: 0.171875
    severity: HARM_SEVERITY_NEGLIGIBLE
    severity_score: 0.111328125
  }
  safety_ratings {
    category: HARM_CATEGORY_DA

In [17]:
func_evaluation = func.args["sql"].replace("name", "product_name")
func_evaluation

'SELECT product_name FROM products ORDER BY price DESC LIMIT 1'

In [14]:
sql_response = execute_query(func_evaluation)
sql_response

[('Laptop',)]

In [16]:
api_response = {"text": "; ".join([", ".join(items) for items in sql_response])}
api_response

{'text': 'Laptop'}

In [19]:
user_response = chat.send_message(
    Content(
            role = "user",
            parts = [Part.from_function_response(name = func.name, response = {"content": api_response})]
    )
)

InvalidArgument: 400 Please ensure that the number of function response parts should be equal to number of function call parts of the function call turn.