Querying Database on Demand based on User Query

In [1]:
from dotenv import load_dotenv
load_dotenv()

True

In [2]:
from langfuse import get_client
 
langfuse = get_client()
 
# Verify connection
if langfuse.auth_check():
    print("Langfuse client is authenticated and ready!")
else:
    print("Authentication failed. Please check your credentials and host.")
 

Langfuse client is authenticated and ready!


In [3]:
from openinference.instrumentation.llama_index import LlamaIndexInstrumentor
 
# Initialize LlamaIndex instrumentation
LlamaIndexInstrumentor().instrument()

In [4]:
from llama_index.core import Settings
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI

embed_model = OpenAIEmbedding(model="text-embedding-3-small")
llm = OpenAI(model="gpt-4o-mini")

Settings.embed_model = embed_model
Settings.llm = llm

In [5]:
import logging
import sys

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

In [6]:
from sqlalchemy import  create_engine, MetaData, Table, Column, String, Integer,select, column

engine = create_engine("sqlite:///:memory:", future=True)
metadata_obj = MetaData()

In [7]:
# create city SQL table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)

metadata_obj.create_all(engine)

In [8]:
# print tables
metadata_obj.tables.keys()

dict_keys(['city_stats'])

In [9]:
from sqlalchemy import insert

rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {"city_name": "Berlin", "population": 3645000, "country": "Germany"},
]
for row in rows:
    stmt = insert(city_stats_table).values(**row)
    with engine.begin() as connection:
        cursor = connection.execute(stmt)

In [10]:
with engine.connect() as connection:
    cursor = connection.exec_driver_sql("SELECT * FROM city_stats")
    print(cursor.fetchall())

[('Toronto', 2930000, 'Canada'), ('Tokyo', 13960000, 'Japan'), ('Berlin', 3645000, 'Germany')]


In [12]:
from llama_index.core import VectorStoreIndex, SQLDatabase


sql_database = SQLDatabase(engine)

In [None]:
from llama_index.core.objects import SQLTableNodeMapping, SQLTableSchema, ObjectIndex
from llama_index.core import VectorStoreIndex

# TableNodeMapping is responsible for mapping table to nodes once LLM 
# decides which table to query
# These nodes will be stored in vectorindex
table_node_mapping = SQLTableNodeMapping(sql_database=sql_database)

city_stats_text = """
    This table gives information regarding the population and country 
    of a given city
"""

# uses table schemas to understand which table to query based on user request
table_schema_objs = [
    SQLTableSchema(table_name="city_stats", context_str=city_stats_text)
]

# Index only those tables needed for use query
ob_index = ObjectIndex.from_objects(
    table_schema_objs,
    table_node_mapping,
    VectorStoreIndex
)

In [16]:
from llama_index.core.indices.struct_store.sql_query import SQLTableRetrieverQueryEngine

query_engine = SQLTableRetrieverQueryEngine(
    sql_database=sql_database, table_retriever=ob_index.as_retriever()
)

In [17]:
response = query_engine.query("which city has the highest population?")
response

INFO:llama_index.core.indices.struct_store.sql_retriever:> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16)), . The table description is: 
    This table gives information regarding the population and country 
    of a given city

> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16)), . The table description is: 
    This table gives information regarding the population and country 
    of a given city



Response(response='The city with the highest population is Tokyo, with a population of 13,960,000.', source_nodes=[NodeWithScore(node=TextNode(id_='50aee0df-1da2-4c2f-b3f0-9c5f9f6f4693', embedding=None, metadata={'sql_query': 'SELECT city_name, population FROM city_stats ORDER BY population DESC LIMIT 1;', 'result': [('Tokyo', 13960000)], 'col_keys': ['city_name', 'population']}, excluded_embed_metadata_keys=['sql_query', 'result', 'col_keys'], excluded_llm_metadata_keys=['sql_query', 'result', 'col_keys'], relationships={}, metadata_template='{key}: {value}', metadata_separator='\n', text="[('Tokyo', 13960000)]", mimetype='text/plain', start_char_idx=None, end_char_idx=None, metadata_seperator='\n', text_template='{metadata_str}\n\n{content}'), score=None)], metadata={'50aee0df-1da2-4c2f-b3f0-9c5f9f6f4693': {'sql_query': 'SELECT city_name, population FROM city_stats ORDER BY population DESC LIMIT 1;', 'result': [('Tokyo', 13960000)], 'col_keys': ['city_name', 'population']}, 'sql_quer