In [1]:
!pip install ipython-sql




[notice] A new release of pip is available: 24.0 -> 24.3.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [1]:
import vertexai
from vertexai.generative_models import GenerativeModel, FunctionDeclaration, Tool

from dotenv import dotenv_values

import sqlite3


In [2]:
%load_ext sql

In [3]:
%sql sqlite:///sample.db

In [4]:
%%sql
create table if not exists products (
    product_id integer primary key autoincrement,
    product_name varchar(255) not null,
    price decimal(10, 2) not null
);

create table if not exists staff (
    staff_id integer primary key autoincrement,
    first_name varchar(255) not null,
    last_name varchar(255) not null
);

create table if not exists orders (
    order_id integer primary key autoincrement,
    customer_name varchar(255) not null,
    staff_id integer not null,
    product_id integer not null,
    foreign key (staff_id) references staff (staff_id),
    foreign key (product_id) references products (product_id)
);

-- populate

insert into products(product_name, price) values
  	('Laptop', 799.99),
  	('Keyboard', 129.99),
  	('Mouse', 29.99);

INSERT INTO orders (customer_name, staff_id, product_id) VALUES
  	('David Lee', 1, 1),
  	('Emily Chen', 2, 2),
  	('Frank Brown', 1, 3);

 * sqlite:///sample.db
Done.
Done.
Done.
3 rows affected.
3 rows affected.


[]

In [1]:
db_file = "sample.db"

db_conn = sqlite3.connect(db_file)

In [2]:
def list_tables() -> list[str]:
    """Retrieve the names of all tables in the database."""
    cursor = db_conn.cursor()
    cursor.execute("select * from sqlite_master where type = 'table';")
    table = cursor.fetchall()
    return [t[0] for t in table]


In [7]:
list_tables()

['table', 'table', 'table', 'table']

In [3]:
def describe_table(table_name: str) -> list[tuple[str, str]]:
    """Look up the table schema. Returns list of column names, where each entry is a tuple of (column, type)."""
    cursor = db_conn.cursor()
    cursor.execute(f"pragma table_info({table_name});")
    schema = cursor.fetchall()
    return [(col[1], col[2]) for col in schema]




In [9]:
describe_table("products")

[('product_id', 'INTEGER'),
 ('product_name', 'varchar(255)'),
 ('price', 'decimal(10, 2)')]

In [4]:
def execute_query(sql: str) -> list[list[str]]:
    """Execute a SELECT statement, returning the results."""
    cursor = db_conn.cursor()
    cursor.execute(sql)
    return cursor.fetchall()


In [11]:
execute_query("select * from products;")

[(1, 'Laptop', 799.99),
 (2, 'Keyboard', 129.99),
 (3, 'Mouse', 29.99),
 (4, 'Laptop', 799.99),
 (5, 'Keyboard', 129.99),
 (6, 'Mouse', 29.99),
 (7, 'Laptop', 799.99),
 (8, 'Keyboard', 129.99),
 (9, 'Mouse', 29.99)]

In [6]:
config = dotenv_values(".env")

In [7]:
list_tables_func = FunctionDeclaration(
    name="list_tables",
    description="Retrieve the names of all tables in the database.",
    parameters={
        "type": "object",
        "properties": {}
        },
)

describe_table_func = FunctionDeclaration(
    name="describe_table",
    description="Look up the table schema. Returns list of columns, where each entry is a tuple of (column, type).",
    parameters={
        "type": "object",
        "properties": {
            "table_name": {
                "type": "string",
                "description": "Name of the table."
            }
        },
        "required": [
            "table_name"
        ]
    },
)

execute_query_func = FunctionDeclaration(
    name="execute_query",
    description="Execute a SELECT statement, returning the results.",
    parameters={
        "type": "object",
        "properties": {
            "sql": {
                "type": "string",
                "description": "SQL query to the database."
            }
        },
        "required": [
            "sql"
        ]
    },
)

import vertexai
from vertexai.generative_models import GenerativeModel, FunctionDeclaration, Tool, Content, Part

from dotenv import dotenv_values
config = dotenv_values(".env")

list_tables_func = FunctionDeclaration.from_func(list_tables)
describe_table_func = FunctionDeclaration.from_func(describe_table)
execute_query_func = FunctionDeclaration.from_func(execute_query)

db_tool = vertexai.generative_models.Tool(
    function_declarations = [list_tables_func, describe_table_func, execute_query_func]
)

instruction = """You are a helpful chatbot that can interact with an SQL database for a computer
store. You will take the users questions and turn them into SQL queries using the tools
available. Once you have the information you need, you will answer the user's question using
the data returned. Use list_tables to see what tables are present, describe_table to understand
the schema, and execute_query to issue an SQL SELECT query."""

vertexai.init(project=config["GOOGLE_PROJECT_ID"],
              location=config["LOCATION"])

model = GenerativeModel("gemini-1.5-flash-001",
                        tools = [db_tool],
                        system_instruction = instruction)

chat = model.start_chat()
# query = "What is the most expensive product?" 

# resp = chat.send_message(content=query,
#                          tools=[db_tool])

# function_calls = resp.candidates[0].function_calls
# function to give out what i need
# function_call = function_calls[-1]
# print(function_call)
"""

"""
# sql_response = execute_query(function_calls[-1].args.get("sql"))
# RESPONSE output
# [(1, 'Laptop', 799.99),
#  (4, 'Laptop', 799.99),
#  (2, 'Keyboard', 129.99),
#  (5, 'Keyboard', 129.99),
#  (3, 'Mouse', 29.99)]

# creating non-techie person readable response, generated from gemini, feeding SQL-query, function from provided API
# for func_call in function_calls:
#     print(func_call.name)
#     function_name = func_call.name

#     if function_name == "list_tables":
#         function = list_tables
#         sql_response = function()
#         tables = ", ".join(sql_response)
#         api_response = {
#             "tables": tables
#         }
#     elif function_name == "describe_table":
#         function = describe_table
#         sql_response = function(func_call.args.get("table_name"))
#         columns_list = ", ".join([item[0] for item in sql_response])
#         api_response = {
#             "columns": columns_list
#         }
#     elif function_name == "execute_query":
#         function = execute_query
#         sql_response = function(func_call.args.get("sql"))
#         products = ", ".join([item[1] for item in sql_response])
#         api_response = {
#             "product_list": products
#         }

'\n\n'

In [14]:
query = "What is the most expensive product?" 

resp = chat.send_message(content=query,
                         tools=[db_tool])
function_calls = resp.candidates[0].function_calls
func = function_calls[-1] # the last query is a clue to retrieve user's answer


In [15]:
resp

candidates {
  content {
    role: "model"
    parts {
      function_call {
        name: "list_tables"
        args {
        }
      }
    }
    parts {
      text: "\n"
    }
    parts {
      function_call {
        name: "describe_table"
        args {
          fields {
            key: "table_name"
            value {
              string_value: "products"
            }
          }
        }
      }
    }
    parts {
      text: "\n"
    }
    parts {
      function_call {
        name: "execute_query"
        args {
          fields {
            key: "sql"
            value {
              string_value: "SELECT name FROM products ORDER BY price DESC LIMIT 1"
            }
          }
        }
      }
    }
  }
  finish_reason: STOP
  safety_ratings {
    category: HARM_CATEGORY_HATE_SPEECH
    probability: NEGLIGIBLE
    probability_score: 0.168945312
    severity: HARM_SEVERITY_NEGLIGIBLE
    severity_score: 0.117675781
  }
  safety_ratings {
    category: HARM_CATEGORY_DAN

In [16]:
function_calls = resp.candidates[0].function_calls

In [17]:
len(function_calls)

3

In [18]:
func = function_calls[-1]

In [25]:
describe_table("products")

[('product_id', 'INTEGER'),
 ('product_name', 'varchar(255)'),
 ('price', 'decimal(10, 2)')]

In [19]:
func

name: "execute_query"
args {
  fields {
    key: "sql"
    value {
      string_value: "SELECT name FROM products ORDER BY price DESC LIMIT 1"
    }
  }
}

In [28]:
func.args["sql"].replace("name", "product_name")

'SELECT product_name FROM products ORDER BY price DESC LIMIT 1'

In [29]:
func.args["sql"]

'SELECT name FROM products ORDER BY price DESC LIMIT 1'

In [20]:
func_evaluation = "SELECT product_name FROM products ORDER BY price DESC LIMIT 1"

In [44]:
describe_table("product_name")

[('product_id', 'INTEGER'),
 ('product_name', 'varchar(255)'),
 ('price', 'decimal(10, 2)')]

In [21]:
sql_response = execute_query(func_evaluation)
sql_response

[('Laptop',)]

In [22]:
api_response = {"text": "; ".join([", ".join(items) for items in sql_response])}

In [23]:
api_response

{'text': 'Laptop'}

In [50]:
func

name: "execute_query"
args {
  fields {
    key: "sql"
    value {
      string_value: "SELECT name, price FROM products ORDER BY price DESC LIMIT 1"
    }
  }
}

In [24]:
user_response = chat.send_message(
    Content(
            role = "user",
            parts = [Part.from_function_response(name = func.name, response = {"content": api_response})]
    )
)

InvalidArgument: 400 Please ensure that the number of function response parts should be equal to number of function call parts of the function call turn.

In [None]:
from google.api_core import retry


In [None]:
def retry_with_policy(func):
    def wrapper(*args, **kwargs):
        retry_policy = retry.Retry(predicate=retry.if_transient_error)
        return retry_policy(func)(*args, **kwargs)
    return wrapper

@retry_with_policy
def send_message_with_retry(chat, message, *, tools):
    chat.send_message(message, tools=tools)

In [None]:
chat = model.start_chat()

In [None]:
query = "What are top expensive products?"

In [None]:
# resp = send_message_with_retry(chat, "What is the cheapest product?")
resp = chat.send_message(content=query,
                         tools=[db_tool])

In [None]:
resp.candidates[0].content

role: "model"
parts {
  function_call {
    name: "list_tables"
    args {
    }
  }
}
parts {
  text: "\n"
}
parts {
  function_call {
    name: "describe_table"
    args {
      fields {
        key: "table_name"
        value {
          string_value: "products"
        }
      }
    }
  }
}
parts {
  text: "\n"
}
parts {
  function_call {
    name: "execute_query"
    args {
      fields {
        key: "sql"
        value {
          string_value: "SELECT * FROM products ORDER BY price DESC LIMIT 5"
        }
      }
    }
  }
}

In [None]:
function_calls = resp.candidates[0].function_calls

In [None]:
from vertexai.generative_models import Content, Part

In [None]:
execute_query(function_calls[-1].args.get("sql"))

[(1, 'Laptop', 799.99),
 (4, 'Laptop', 799.99),
 (2, 'Keyboard', 129.99),
 (5, 'Keyboard', 129.99),
 (3, 'Mouse', 29.99)]

In [None]:
function_handler = {
    "list_tables": list_tables,
    "describe_table": describe_table,
    "execute_query": execute_query
}

# function_parameters = {
#     "list_tables": None,
#     "describe_table": "table_name",
#     "execute_query": execute_query
# }

# for func_call in function_calls:
#     print(func_call.name)
#     function_name = func_call.name

#     if function_name == "list_tables":
#         function = list_tables
#         sql_response = function()
#         tables = ", ".join(sql_response)
#         api_response = {
#             "tables": tables
#         }
#     elif function_name == "describe_table":
#         function = describe_table
#         sql_response = function(func_call.args.get("table_name"))
#         columns_list = ", ".join([item[0] for item in sql_response])
#         api_response = {
#             "columns": columns_list
#         }
#     elif function_name == "execute_query":
#         function = execute_query
#         sql_response = function(func_call.args.get("sql"))
#         products = ", ".join([item[1] for item in sql_response])
#         api_response = {
#             "product_list": products
#         }

        # print(api_response)
        # Create a separate Content object for each function call
    # content = Content(parts=[Part.from_function_response(name=func_call.name, response={"content": api_response})])
    # Use the generated content and content object in your prompt
    # response = model.generate_content(
    #     [
    #         query,
    #         resp.candidates[0].content,
    #         content,
    #     ],
    #     tools=[db_tool],
    # )
    # response = chat.send_message(
    #     Part.from_function_response(name = function_name, response={"content": api_response})
    # )
    # print(response.text)

list_tables


InvalidArgument: 400 Please ensure that the number of function response parts should be equal to number of function call parts of the function call turn.

In [None]:
api_response

[(1, 'Laptop', 799.99), (4, 'Laptop', 799.99), (2, 'Keyboard', 129.99)]

In [None]:
function_call = function_calls[-1]

In [None]:
function_call

name: "execute_query"
args {
  fields {
    key: "sql"
    value {
      string_value: "SELECT * FROM products ORDER BY price DESC LIMIT 3"
    }
  }
}

In [None]:
resp.candidates[0].content

role: "model"
parts {
  function_call {
    name: "list_tables"
    args {
    }
  }
}
parts {
  text: "\n"
}
parts {
  function_call {
    name: "describe_table"
    args {
      fields {
        key: "table_name"
        value {
          string_value: "products"
        }
      }
    }
  }
}
parts {
  text: "\n"
}
parts {
  function_call {
    name: "execute_query"
    args {
      fields {
        key: "sql"
        value {
          string_value: "SELECT * FROM products ORDER BY price DESC LIMIT 3"
        }
      }
    }
  }
}

In [None]:
model_response = model.generate_content(
    [
        "What are three most expensive products?",
        resp.candidates[0].content,
        Content(
            parts = [
                Part.from_function_response(
                    name = function_call.name,
                    response={"content": response},
                ),
            ],
        ),
    ],
    tools=[db_tool]
)

In [None]:
resp = send_message_with_retry(chat, "and how much is it?",
                               tools = [db_tool])

In [None]:
resp