In [1]:
from pyprojroot import here
from langchain_ollama import ChatOllama
from langchain_community.utilities import SQLDatabase

#### Load the LLM

In [2]:
# Ollama-served chat model shared by the SQL-generation step and the final
# answer step of the chain built later in this notebook.
OLLAMA_MODEL = "qwen2.5:14b"
llm = ChatOllama(model=OLLAMA_MODEL)

#### Load and test the SQLite database

In [3]:
# Path to the Chinook SQLite file, resolved relative to the project root.
sqldb_directory = here("data/Chinook.db")

# SQLite silently creates a brand-new, empty database when the target file is
# missing; that would only surface further down as an empty table list or a
# "no such table" error. Fail fast with a clear message instead.
if not sqldb_directory.is_file():
    raise FileNotFoundError(f"SQLite database not found: {sqldb_directory}")

db = SQLDatabase.from_uri(f"sqlite:///{sqldb_directory}")

# Smoke test: show the dialect, the visible tables, and a few sample rows.
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Invoice LIMIT 10;")

sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']


"[(1, 2, '2021-01-01 00:00:00', 'Theodor-Heuss-Straße 34', 'Stuttgart', None, 'Germany', '70174', 1.98), (2, 4, '2021-01-02 00:00:00', 'Ullevålsveien 14', 'Oslo', None, 'Norway', '0171', 3.96), (3, 8, '2021-01-03 00:00:00', 'Grétrystraat 63', 'Brussels', None, 'Belgium', '1000', 5.94), (4, 14, '2021-01-06 00:00:00', '8210 111 ST NW', 'Edmonton', 'AB', 'Canada', 'T6G 2C7', 8.91), (5, 23, '2021-01-11 00:00:00', '69 Salem Street', 'Boston', 'MA', 'USA', '2113', 13.86), (6, 37, '2021-01-19 00:00:00', 'Berger Straße 10', 'Frankfurt', None, 'Germany', '60316', 0.99), (7, 38, '2021-02-01 00:00:00', 'Barbarossastraße 19', 'Berlin', None, 'Germany', '10779', 1.98), (8, 40, '2021-02-01 00:00:00', '8, Rue Hanovre', 'Paris', None, 'France', '75002', 1.98), (9, 42, '2021-02-02 00:00:00', '9, Place Louis Barthou', 'Bordeaux', None, 'France', '33000', 3.96), (10, 46, '2021-02-03 00:00:00', '3 Chatham Street', 'Dublin', 'Dublin', 'Ireland', None, 5.94)]"

#### Create the SQL question-answering chain and run a test query

In [4]:
from langchain_community.tools.sql_database.tool import QuerySQLDatabaseTool
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain import hub
from typing_extensions import TypedDict
from typing_extensions import Annotated
from operator import itemgetter


# Prompt used to turn a natural-language question into SQL. It is pulled from
# the LangChain Hub (presumably a network call — confirm before running
# offline). write_query fills it with "dialect", "top_k", "table_info" and
# "input".
query_prompt_template = hub.pull("langchain-ai/sql-query-system-prompt")

# Template for the final answer step; it expects the placeholders
# {question}, {query} and {result}.
# NOTE(review): this is a regular (non-raw) string, so each "\n" escape is a
# newline in addition to the physical line break that follows it, and the
# four-space indentation is part of the prompt text.
system_role = """Given the following user question, corresponding SQL query, and SQL result, answer the user question.\n
    Question: {question}\n
    SQL Query: {query}\n
    SQL Result: {result}\n
    Answer:
    """


# Shape of the data flowing through the question -> SQL -> result -> answer
# pipeline. Every field is a plain string.
State = TypedDict(
    "State",
    {
        "question": str,  # the user's natural-language question
        "query": str,     # SQL generated for that question
        "result": str,    # output of running the SQL
        "answer": str,    # final natural-language answer
    },
)


class QueryOutput(TypedDict):
    """Generated SQL query."""

    # Schema passed to llm.with_structured_output() inside write_query.
    # NOTE(review): the docstring above and the description string below are
    # presumably forwarded to the model as part of the schema, so treat them
    # as prompt text — confirm before rewording either.
    # NOTE(review): the `...` in the Annotated metadata looks like LangChain's
    # (type, default, description) convention for a required field — confirm.
    query: Annotated[str, ..., "Syntactically valid SQL query."]


def write_query(state: State):
    """Ask the LLM for a SQL query that answers ``state["question"]``.

    The hub prompt is filled with the database dialect, a row limit of 10,
    the table information reported by ``db`` and the user's question, then
    sent to the LLM constrained to the ``QueryOutput`` schema.

    Args:
        state: Mapping that contains at least the ``"question"`` key.

    Returns:
        Whatever the structured-output runnable yields. With a TypedDict
        schema this is expected to be a dict with a ``"query"`` key holding
        the SQL text (not a bare string) — confirm for the model in use.
    """
    prompt_inputs = {
        "dialect": db.dialect,
        "top_k": 10,
        "table_info": db.get_table_info(),
        "input": state["question"],
    }
    filled_prompt = query_prompt_template.invoke(prompt_inputs)
    return llm.with_structured_output(QueryOutput).invoke(filled_prompt)

def _generate_sql(state: State) -> str:
    """Return only the SQL text produced by ``write_query``.

    ``write_query`` returns the structured output as a dict of the form
    ``{"query": "<sql>"}``. Storing that dict under the chain's ``query`` key
    would make the answer prompt render ``{query}`` as a dict repr instead of
    the SQL statement, so the string is unwrapped here.
    """
    return write_query(state)["query"]


# Tool that executes a SQL string against `db`.
execute_query = QuerySQLDatabaseTool(db=db)

# Final step: format question / query / result into the answer prompt, call
# the LLM and reduce the message to plain text.
answer_prompt = PromptTemplate.from_template(system_role)
answer = answer_prompt | llm | StrOutputParser()

# question -> query (SQL string) -> result (tool output) -> answer (text)
chain = (
    RunnablePassthrough.assign(query=_generate_sql)
    .assign(result=itemgetter("query") | execute_query)
    | answer
)

response = chain.invoke({"question": "Give me the names of 5 artists from the database"})
response

'The names of 5 artists (composers) from the database are:\n\n1. Angus Young, Malcolm Young, Brian Johnson\n2. U. Dirkschneider, W. Hoffmann, H. Frank, P. Baltes, S. Kaufmann, G. Hoffmann\n3. F. Baltes, S. Kaufman, U. Dirkscneider & W. Hoffman\n4. F. Baltes, R.A. Smith-Diesel, S. Kaufman, U. Dirkscneider & W. Hoffman\n5. Deaffy & R.A. Smith-Diesel'