Capstone project for the Gen AI Intensive Course Capstone 2025Q1.
This project builds an AI research assistant. The researcher provides a data set and defines several tables based on the data, then asks the bot to run statistical or descriptive analyses on the data. The bot should be able to interact with the data through SQL queries and answer the researcher's questions.

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found anywhere under the input directory.
for file_path in (
    os.path.join(folder, name)
    for folder, _, names in os.walk('/kaggle/input')
    for name in names
):
    print(file_path)

Setting up the notebook and installing Python SDK packages

In [None]:
!pip uninstall -qqy jupyterlab  # Remove unused conflicting packages
!pip install -U -q "google-genai==1.7.0"

In [None]:
from google import genai
from google.genai import types

genai.__version__

Setting up the API key

In [None]:
from kaggle_secrets import UserSecretsClient

GOOGLE_API_KEY = UserSecretsClient().get_secret("GOOGLE_API_KEY")

In [None]:
# Retry policy: for a complex query the model might make several consecutive
# calls automatically, so have the client retry when it hits quota limits.
from google.api_core import retry


def is_retriable(e):
    """Return True when *e* is a genai APIError whose code is 429 or 503."""
    return isinstance(e, genai.errors.APIError) and e.code in {429, 503}


# Install the retry wrapper only once. A `__wrapped__` attribute on the method
# is taken as the sign that it has already been wrapped (presumably set by the
# retry decorator), so re-running this cell does not stack a second layer.
if not hasattr(genai.models.Models.generate_content, '__wrapped__'):
    genai.models.Models.generate_content = retry.Retry(
        predicate=is_retriable)(genai.models.Models.generate_content)

Creating the database. 
The data is downloaded and cleaned from the original database of the "American Time Use Survey", which is publicly available. The data is embedded in this document as an Excel file called "ATUS.xlsx".

In [None]:
import pandas as pd

# Load the Excel file from the dataset directory
atus = pd.read_excel('/kaggle/input/d/jalalbagherzade/timeuse/ATUS.xlsx')
atus.head()

In [None]:
%load_ext sql
%sql sqlite:///sample.db

In [None]:
%%sql
-- Defines four tables: individuals' characteristics, working hours,
-- home-production hours, and leisure hours. Every table carries caseid.
-- IF NOT EXISTS lets this cell be re-run against an already initialised
-- sample.db instead of failing because the table already exists.
CREATE TABLE IF NOT EXISTS individuals (
    caseid INT NOT NULL,
    weight INT NOT NULL,
    year INT NOT NULL,
    quarter INT NOT NULL,
    age INT NOT NULL,
    male INT NOT NULL,
    married INT NOT NULL,
    student INT NOT NULL,
    retired INT NOT NULL
);

CREATE TABLE IF NOT EXISTS work (
    caseid INT NOT NULL,
    work_hour DECIMAL(10, 2)
);

CREATE TABLE IF NOT EXISTS home_production (
    caseid INT NOT NULL,
    homeproduction_hour DECIMAL(10, 2)
);

CREATE TABLE IF NOT EXISTS leisure (
    caseid INT NOT NULL,
    leisure_hour DECIMAL(10, 2)
);

In [None]:
import sqlite3

# Copy columns of the `atus` DataFrame into the SQLite tables.
db_file = "sample.db"
db_conn = sqlite3.connect(db_file)

# Table name -> DataFrame columns stored in that table.
table_columns = {
    'individuals': ['caseid', 'weight', 'year', 'quarter', 'age', 'male',
                    'married', 'student', 'retired'],
    'work': ['caseid', 'work_hour'],
    'home_production': ['caseid', 'homeproduction_hour'],
    'leisure': ['caseid', 'leisure_hour'],
}

for table_name, columns in table_columns.items():
    # Re-running this cell used to append a second copy of every row.
    # Emptying an existing table first keeps the load idempotent while
    # leaving its declared schema alone (if_exists='replace' would drop and
    # recreate the table). A missing table is still created by to_sql.
    table_exists = db_conn.execute(
        "SELECT 1 FROM sqlite_master WHERE type='table' AND name=?",
        (table_name,),
    ).fetchone() is not None
    if table_exists:
        db_conn.execute(f"DELETE FROM {table_name}")
    atus[columns].to_sql(table_name, db_conn, if_exists='append', index=False)

In [None]:
#checking if data and tables are matched correctly
#pd.read_sql_query("SELECT * FROM work LIMIT 5;", db_conn)
pd.read_sql_query("SELECT * FROM home_production LIMIT 5;", db_conn)
#pd.read_sql_query("SELECT * FROM leisure LIMIT 5;", db_conn)

Defining database functions: listing all tables, describing the columns of each table, and running SQL queries

In [None]:
# Tool function that tells the LLM which tables are available.
def list_tables() -> list[str]:
    """Retrieve the names of all tables in the database."""
    # Log each invocation so it is visible when the function gets called.
    print(' - DB CALL: list_tables()')

    rows = db_conn.cursor().execute(
        "SELECT name FROM sqlite_master WHERE type='table';")

    # Every result row is a 1-tuple holding a table name.
    return [name for (name,) in rows]


list_tables()

In [None]:
# Tool function that tells the LLM which columns exist in a given table.
def describe_table(table_name: str) -> list[tuple[str, str]]:
    """Look up the table schema.

    Args:
      table_name: Name of the table to describe.

    Returns:
      List of columns, where each entry is a tuple of (column, type).
      The list is empty when no table with that name exists.
    """
    print(f' - DB CALL: describe_table({table_name})')

    cursor = db_conn.cursor()

    # The table name is supplied by the model, so bind it as a parameter
    # rather than formatting it into the statement text. A PRAGMA statement
    # cannot take bound parameters, but the equivalent pragma_table_info()
    # table-valued function can; this also handles names that would need
    # quoting (spaces, punctuation).
    cursor.execute(
        "SELECT name, type FROM pragma_table_info(?);", (table_name,))

    return [(column, column_type) for column, column_type in cursor.fetchall()]

# Sanity-check the schema of every table. Written as separate bare
# expressions, only the last call's return value is displayed by the
# notebook; collecting the results in a dict shows all four schemas.
{name: describe_table(name)
 for name in ("work", "home_production", "leisure", "individuals")}

In [None]:
# Tool function that lets the LLM run a query over the columns of the tables.
def execute_query(sql: str) -> list[list[str]]:
    """Execute an SQL statement, returning the results."""
    # Log the statement so every query issued through this tool is visible.
    print(f' - DB CALL: execute_query({sql})')

    # Cursor.execute() hands back the cursor itself, so the fetch is chained.
    result_rows = db_conn.cursor().execute(sql).fetchall()
    return result_rows


execute_query("select * from leisure LIMIT 5")

In [None]:
# Research-assistant bot: a Gemini chat session that is given the three
# database functions as tools.
db_tools = [list_tables, describe_table, execute_query]

# System prompt for the chat: describes the assistant's role and names the
# three tools listed in db_tools.
instruction = """You are a helpful chatbot that can interact with an SQL database and
run statistical analysis for a researcher. You are given a data set in form of tables. 
You will take the users questions and turn them into SQL queries using the tools available. 
Once you have the information you need, you will answer the user's question
using the data returned.

Use list_tables to see what tables are present, describe_table to understand the
schema, and execute_query to issue an SQL SELECT query.
You need to look up table schema using sqlite3 syntax SQL, then once an
answer is found be sure to tell the user. If the user is requesting an
action, you must also execute the actions."""

client = genai.Client(api_key=GOOGLE_API_KEY)

# Start a chat with automatic function calling enabled.
# The Python callables are passed directly as `tools`, together with the
# system prompt defined above.
chat = client.chats.create(
    model="gemini-2.0-flash",
    config=types.GenerateContentConfig(
        system_instruction=instruction,
        tools=db_tools,
    ),
)



In [None]:
resp = chat.send_message("What is the difference of average leisure hours between males and females by year")
print(f"\n{resp.text}")

In [None]:
# Second client for the Live API, created against the v1alpha API version.
model = 'gemini-2.0-flash-exp'
live_client = genai.Client(api_key=GOOGLE_API_KEY,
                           http_options=types.HttpOptions(api_version='v1alpha'))

# Wrap the existing execute_query tool you used in the earlier example.
execute_query_tool_def = types.FunctionDeclaration.from_callable(
    client=live_client, callable=execute_query)

# Provide the model with enough information to use the tool, such as describing
# the database so it understands which SQL syntax to use.
# Only execute_query is declared as a function in `config` below, so the prompt
# must not point the model at list_tables/describe_table; it is told how to
# discover tables and schemas through SQL instead.
sys_int = """
You are a helpful chatbot that can interact with an SQL database and
run statistical analysis for a researcher. You are given a data set in form of tables. 
You will take the users questions and turn them into SQL queries using the tools available. 
Once you have the information you need, you will answer the user's question
using the data returned.

Use execute_query, the only database tool available to you, to issue SQL queries.
To see what tables are present, query sqlite_master; to understand a table's
schema, run PRAGMA table_info on it. Use sqlite3 syntax SQL, then once an
answer is found be sure to tell the user. If the user is requesting an
action, you must also execute the actions.
"""

# Session configuration: text-only responses, the system prompt above, and two
# tools: code execution plus the execute_query function declaration.
config = {
    "response_modalities": ["TEXT"],
    "system_instruction": {"parts": [{"text": sys_int}]},
    "tools": [
        {"code_execution": {}},
        {"function_declarations": [execute_query_tool_def.to_json_dict()]},
    ],
}

# Open a Live API session and send two user turns in the same session: first
# a data question, then a request to plot the result.
# NOTE(review): handle_response is not defined in any visible cell of this
# notebook; presumably it must be defined before this cell runs — confirm,
# otherwise these calls raise NameError.
async with live_client.aio.live.connect(model=model, config=config) as session:

  message = "Can you find the average leisure hours for male and females by year?"

  print(f"> {message}\n")
  await session.send(input=message, end_of_turn=True)
  # execute_query is passed as the tool implementation for this turn.
  await handle_response(session, tool_impl=execute_query)

  message = "Generate and run some code to plot this as a python seaborn chart"

  print(f"> {message}\n")
  await session.send(input=message, end_of_turn=True)
  await handle_response(session, tool_impl=execute_query)

In [None]:
# Open a fresh Live API session and ask for a regression across two tables.
# NOTE(review): handle_response is not defined in any visible cell of this
# notebook; presumably it must be defined before this cell runs — confirm.
async with live_client.aio.live.connect(model=model, config=config) as session:

  message = "Can you run a linear regression of age from individuals table on work hours from work table?"

  print(f"> {message}\n")
  await session.send(input=message, end_of_turn=True)
  # execute_query is passed as the tool implementation for this turn.
  await handle_response(session, tool_impl=execute_query)