In [1]:
import os 


# Show the kernel's current working directory so the relative paths used in
# later cells can be sanity-checked.
print(os.getcwd())

c:\Users\Hp\Documents\GitHub\sql_agent_langchain\notebooks


In [2]:
# Move from the notebooks/ folder up to the repository root so the relative
# paths used later (".env.local", "artifacts/...") resolve.
# The guard keeps this cell idempotent: an unconditional os.chdir("../")
# climbs one more directory every time the cell is re-run.
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("../")

print(os.getcwd())

c:\Users\Hp\Documents\GitHub\sql_agent_langchain


# With Minimal Logic

## High-level design
1. fetch the available tables & schemas from the db
2. decide which tables are relevant to the **question**
3. generate a query based on the question & information from the schemas
4. safety-check the query to limit the impact of an LLM-generated query
5. execute the query & return the results
6. correct mistakes surfaced by the db engine until the query is successful
7. formulate a response based on the results

### Tracing

In [3]:
from dotenv import load_dotenv


# Read the local env file before looking up any of the settings below.
load_dotenv(".env.local")


# LangSmith settings are optional: an unset variable simply yields None.
LANGSMITH_TRACING = os.environ.get("LANGSMITH_TRACING")
LANGSMITH_API_KEY = os.environ.get("LANGSMITH_API_KEY")

# The Hugging Face token is mandatory: an unset variable raises KeyError here.
HF_TOKEN = os.environ["HF_TOKEN"]

### Chat Model
**tool-calling**

In [4]:
from langchain_ollama import ChatOllama


# Local Ollama chat model configured for low-temperature, JSON-formatted output.
# NOTE(review): `local_llm` is not referenced by the agent cell in this
# notebook (it is built from `chat`) - confirm whether it is still needed.
# NOTE(review): confirm that `request_timeout` and `streaming` are options
# recognised by langchain_ollama.ChatOllama in the installed version.
local_llm = ChatOllama(
    model="qwen3:1.7b",
    format="json",
    temperature=0.1,
    streaming=False,
    request_timeout=200,
    verbose=True,
)

In [None]:
from huggingface_hub import login


# Authenticate with the token already read from the environment (HF_TOKEN).
# A bare login() opens an interactive prompt, which stalls an unattended
# "Restart Kernel -> Run All".
login(token=HF_TOKEN)

In [29]:
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace


# Hosted text-generation endpoint for the SQL-specialised model.
llm_hf = HuggingFaceEndpoint(
    task="text-generation",
    repo_id="defog/llama-3-sqlcoder-8b",
    temperature=0.1,
)

# Chat-style wrapper around the endpoint; `chat` is the model handed to the
# agent further down.
chat = ChatHuggingFace(verbose=True, llm=llm_hf)

### Configure DB

In [11]:
import requests
import pathlib


# Source of the Chinook sample database and where it is cached locally.
url = "https://storage.googleapis.com/benchmarks-artifacts/chinook/Chinook.db"
local_path = pathlib.Path("artifacts/sqlite_db/Chinook.db")


if local_path.exists():
    # Cached copy found: keep the cell cheap and repeatable.
    print(f"{local_path} already exists, skipping download.")
else:
    # A timeout keeps the cell from hanging forever on a stalled connection.
    response = requests.get(url, timeout=60)
    if response.status_code == 200:
        # write_bytes() does not create missing folders, so make sure the
        # target directory exists (e.g. on a fresh checkout).
        local_path.parent.mkdir(parents=True, exist_ok=True)
        local_path.write_bytes(response.content)
        print(f"File downloaded and saved as {local_path}")
    else:
        print(f"Failed to download the file. Status code: {response.status_code}")

artifacts\sqlite_db\Chinook.db already exists, skipping download.


### Add tools for DB interactions

In [12]:
from langchain_community.utilities import SQLDatabase


# Wrap the local Chinook SQLite file; the path is relative to the repository root.
db = SQLDatabase.from_uri(database_uri="sqlite:///artifacts/sqlite_db/Chinook.db")

### Extract Schema

In [13]:
# Capture the database's table information as text. As the printed output
# shows, it contains the CREATE TABLE statements plus a few sample rows per
# table. SCHEMA is later embedded into the system prompt.
SCHEMA = db.get_table_info()
print(SCHEMA)


CREATE TABLE "Album" (
	"AlbumId" INTEGER NOT NULL, 
	"Title" NVARCHAR(160) NOT NULL, 
	"ArtistId" INTEGER NOT NULL, 
	PRIMARY KEY ("AlbumId"), 
	FOREIGN KEY("ArtistId") REFERENCES "Artist" ("ArtistId")
)

/*
3 rows from Album table:
AlbumId	Title	ArtistId
1	For Those About To Rock We Salute You	1
2	Balls to the Wall	2
3	Restless and Wild	2
*/


CREATE TABLE "Artist" (
	"ArtistId" INTEGER NOT NULL, 
	"Name" NVARCHAR(120), 
	PRIMARY KEY ("ArtistId")
)

/*
3 rows from Artist table:
ArtistId	Name
1	AC/DC
2	Accept
3	Aerosmith
*/


CREATE TABLE "Customer" (
	"CustomerId" INTEGER NOT NULL, 
	"FirstName" NVARCHAR(40) NOT NULL, 
	"LastName" NVARCHAR(20) NOT NULL, 
	"Company" NVARCHAR(80), 
	"Address" NVARCHAR(70), 
	"City" NVARCHAR(40), 
	"State" NVARCHAR(40), 
	"Country" NVARCHAR(40), 
	"PostalCode" NVARCHAR(10), 
	"Phone" NVARCHAR(24), 
	"Fax" NVARCHAR(24), 
	"Email" NVARCHAR(60) NOT NULL, 
	"SupportRepId" INTEGER, 
	PRIMARY KEY ("CustomerId"), 
	FOREIGN KEY("SupportRepId") REFERENCES "Empl

### Execute SQL queries

In [14]:
import re
from langchain_core.tools import tool


DENY_RE = re.compile(r"\b(INSERT|UPDATE|DELETE|ALTER|DROP|CREATE|REPLACE|TRUNCATE)\b", re.I)
HAS_LIMIT_TAIL_RE = re.compile(r"(?is)\blimit\b\s+\d+(\s*,\s*\d+)?\s*;?\s*$")


def _safe_sql(q: str) -> str:
    # normalize
    q = q.strip()
    # block multiple statements (allow one optional trailing ;)
    if q.count(";") > 1 or (q.endswith(";") and ";" in q[:-1]):
        return "Error: multiple statements are not allowed."
    q = q.rstrip(";").strip()

    # read-only gate
    if not q.lower().startswith("select"):
        return "Error: only SELECT statements are allowed."
    if DENY_RE.search(q):
        return "Error: DML/DDL detected. Only read-only queries are permitted."

    # append LIMIT only if not already present at the end (robust to whitespace/newlines)
    if not HAS_LIMIT_TAIL_RE.search(q):
        q += " LIMIT 5"
    return q

In [15]:
@tool
def execute_sql(query: str) -> str:
    """Execute a READ-ONLY SQLite SELECT query and return results."""
    # The docstring above is left untouched on purpose: it sits under @tool and
    # may be surfaced to the model as the tool description.
    checked = _safe_sql(query)

    # _safe_sql signals rejection with an "Error:" message instead of SQL;
    # hand that straight back so the model can revise its query.
    if checked.startswith("Error:"):
        return checked

    try:
        outcome = db.run(checked)
    except Exception as exc:
        # Report database errors as text rather than raising.
        return f"Error: {exc}"
    return outcome

### Prompt Template

Use **create_react_agent** to build a [ReAct agent](https://arxiv.org/pdf/2210.03629).

In [16]:
# System prompt for the agent: embeds the schema text held in SCHEMA and lists
# the rules the model is asked to follow. The row and attempt limits written
# here are instructions to the model; this string does not enforce them.
SYSTEM = f"""
You are a careful SQLite analyst.

Authoritative schema (do not invent columns/tables):
{SCHEMA}

Rules:
- Think step-by-step.
- When you need data, call the tool `execute_sql` with ONE SELECT query.
- Read-only only; no INSERT/UPDATE/DELETE/ALTER/DROP/CREATE/REPLACE/TRUNCATE.
- Limit to 5 rows unless user explicitly asks otherwise.
- If the tool returns 'Error:', revise the SQL and try again.
- Limit the number of attempts to 5.
- If you are not successful after 5 attempts, return a note to the user.
- Prefer explicit column lists; avoid SELECT *.
"""

In [18]:
from langchain_core.messages import SystemMessage
from langchain_core.prompts import ChatPromptTemplate


# Chat prompt pairing the fixed system message with an "{input}" placeholder
# for the user's question.
# NOTE(review): `prompt` is not referenced by the agent cell in this notebook,
# which passes SystemMessage(content=SYSTEM) directly - confirm whether this
# template is still needed.
prompt = ChatPromptTemplate.from_messages([
    SystemMessage(content=SYSTEM),
    ("human", "{input}")
])

### Create Chain

### Create Agent 
(with **models, tools, & prompt**)

In [30]:
from langgraph.prebuilt import create_react_agent
from langchain_core.messages import SystemMessage


# Assemble the ReAct agent from its three parts: the system prompt carrying
# the schema and rules, the single SQL tool, and the hosted chat model.
agent = create_react_agent(
    prompt=SystemMessage(content=SYSTEM),
    tools=[execute_sql],
    model=chat,
)

### Run agent

In [31]:
question = "Which genre on average has the longest tracks?"


# Stream the agent run. Each yielded `step` carries a "messages" list, and only
# its newest message is pretty-printed, so the output shows the conversation
# growing one message at a time.
for step in agent.stream(
    {"messages": [{"role": "user", "content": question}]},
    stream_mode="values",
):
    step["messages"][-1].pretty_print()


Which genre on average has the longest tracks?

SELECT g.Name, AVG(t.Milliseconds) AS average_duration FROM Track t JOIN Genre g ON t.GenreId = g.GenreId GROUP BY g.Name ORDER BY average_duration DESC LIMIT 1;


### Test

In [33]:
from sqlalchemy import create_engine, text


def try_sqlite_query(query: str, db_path: str = "artifacts/sqlite_db/Chinook.db"):
    """Run a SQL query against a SQLite DB and return rows.

    Args:
        query: SQL text to execute as-is (no safety checks are applied here).
        db_path: Path of the SQLite file, used to build the "sqlite:///" URL.

    Returns:
        The list of fetched rows on success, or a string starting with
        "Error:" when executing the query or fetching its rows fails.
    """
    engine = create_engine(f"sqlite:///{db_path}")
    try:
        with engine.connect() as conn:
            try:
                return conn.execute(text(query)).fetchall()
            except Exception as e:
                return f"Error: {e}"
    finally:
        # A fresh Engine is built on every call; dispose of it so its
        # connection pool does not linger after the function returns.
        engine.dispose()

# Cross-check: run the SQL text printed by the agent cell directly against the
# database to see the rows it actually returns.
print(try_sqlite_query("SELECT g.Name, AVG(t.Milliseconds) AS average_duration FROM Track t JOIN Genre g ON t.GenreId = g.GenreId GROUP BY g.Name ORDER BY average_duration DESC LIMIT 1;"))

[('Sci Fi & Fantasy', 2911783.0384615385)]
