# Test Connection to SQL database locally hosted

In [29]:
from sqlalchemy import create_engine, text

# Replace with your own credentials
USER = "root"
PASSWORD = "12345"
HOST = "localhost"
PORT = "3306"
DB = "sampledb"

# Connection string format: mysql+pymysql://user:password@host:port/database
connection_string = f"mysql+pymysql://{USER}:{PASSWORD}@{HOST}:{PORT}/{DB}"

# Create SQLAlchemy engine
engine = create_engine(connection_string)

# Test connection
try:
    with engine.connect() as conn:
        result = conn.execute(text("SELECT DATABASE();"))
        db_name = result.scalar()
        print("‚úÖ Connected successfully to:", db_name)
except Exception as e:
    print("‚ùå Connection failed:", e)


‚úÖ Connected successfully to: sampledb


In [30]:
with engine.connect() as conn:
    result = conn.execute(text("SELECT COUNT(*) FROM customers"))
    print("Total customers:", result.scalar())


Total customers: 122


In [31]:
# view customers table as dataframe
import pandas as pd
with engine.connect() as conn:
    df = pd.read_sql("SELECT * FROM customers LIMIT 5", conn)
    display(df)

Unnamed: 0,customerNumber,customerName,contactLastName,contactFirstName,phone,addressLine1,addressLine2,city,state,postalCode,country,salesRepEmployeeNumber,creditLimit
0,103,Atelier graphique,Schmitt,Carine,40.32.2555,"54, rue Royale",,Nantes,,44000,France,1370,21000.0
1,112,Signal Gift Stores,King,Jean,7025551838,8489 Strong St.,,Las Vegas,NV,83030,USA,1166,71800.0
2,114,"Australian Collectors, Co.",Ferguson,Peter,03 9520 4555,636 St Kilda Road,Level 3,Melbourne,Victoria,3004,Australia,1611,117300.0
3,119,La Rochelle Gifts,Labrune,Janine,40.67.8555,"67, rue des Cinquante Otages",,Nantes,,44000,France,1370,118200.0
4,121,Baane Mini Imports,Bergulfsen,Jonas,07-98 9555,Erling Skakkes gate 78,,Stavern,,4110,Norway,1504,81700.0


# Read schema (tables + columns) for the LLM

In [32]:
from sqlalchemy import inspect

def get_schema_description(engine) -> str:
    insp = inspect(engine)
    lines = []
    for t in insp.get_table_names():
        cols = insp.get_columns(t)
        col_str = ", ".join(f"{c['name']} ({str(c.get('type'))})" for c in cols)
        lines.append(f"Table {t}: {col_str}")
    return "\n".join(lines)

schema_text = get_schema_description(engine)
print(schema_text[:1000] + ("\n... (truncated)" if len(schema_text) > 1000 else ""))


Table customers: customerNumber (INTEGER), customerName (VARCHAR(50)), contactLastName (VARCHAR(50)), contactFirstName (VARCHAR(50)), phone (VARCHAR(50)), addressLine1 (VARCHAR(50)), addressLine2 (VARCHAR(50)), city (VARCHAR(50)), state (VARCHAR(50)), postalCode (VARCHAR(15)), country (VARCHAR(50)), salesRepEmployeeNumber (INTEGER), creditLimit (DECIMAL(10, 2))
Table employees: employeeNumber (INTEGER), lastName (VARCHAR(50)), firstName (VARCHAR(50)), extension (VARCHAR(10)), email (VARCHAR(100)), officeCode (VARCHAR(10)), reportsTo (INTEGER), jobTitle (VARCHAR(50))
Table offices: officeCode (VARCHAR(10)), city (VARCHAR(50)), phone (VARCHAR(50)), addressLine1 (VARCHAR(50)), addressLine2 (VARCHAR(50)), state (VARCHAR(50)), country (VARCHAR(50)), postalCode (VARCHAR(15)), territory (VARCHAR(10))
Table orderdetails: orderNumber (INTEGER), productCode (VARCHAR(15)), quantityOrdered (INTEGER), priceEach (DECIMAL(10, 2)), orderLineNumber (SMALLINT)
Table orders: orderNumber (INTEGER), orderD

# LLM (Ollama) + LangChain: build the prompt ‚Üí generate safe MySQL SELECT SQL

In [33]:
# --- LLM to SQL: Prompt + Helpers (Step 4) ---

# We assume you already ran the earlier cells:
# - installed packages
# - created `engine` (SQLAlchemy engine)
# - created `schema_text` (the printed schema summary)
#
# If not, make sure you have these two variables from earlier:
#   engine: SQLAlchemy engine pointing to mysql+pymysql://root:12345@localhost:3306/sampledb
#   schema_text: a string containing your tables/columns (the schema)

import re
import pandas as pd
from sqlalchemy import text  # used later in step 5
from langchain_ollama import ChatOllama
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser

# 1) Connect to your local Ollama model
#    Make sure you've pulled a model, e.g.:
#    > ollama pull llama3
#    You can switch to another: "qwen2.5", "mistral", etc.
OLLAMA_MODEL = "llama3.2:latest"  # change if you want a different local model
llm = ChatOllama(model=OLLAMA_MODEL, temperature=0)

# 2) A prompt that instructs the model to produce ONE safe MySQL SELECT query
#    Key rules:
#      - Only use columns/tables from the schema we pass in
#      - No write operations (INSERT/UPDATE/DELETE/ALTER/TRUNCATE/etc.)
#      - Keep it simple; add GROUP BY when aggregating
#      - Output ONLY a single SQL code block (so it's easy to parse)
SQL_PROMPT = PromptTemplate.from_template(
    """You are an expert MySQL analyst. Write ONE safe MySQL SELECT query for the user's question.
Rules:
- Use ONLY tables/columns listed in the schema.
- NO data changes (no INSERT/UPDATE/DELETE/ALTER/TRUNCATE/DROP/CREATE/GRANT/REVOKE).
- Keep it simple and readable. Use GROUP BY when aggregating.
- If returning raw rows, include a LIMIT unless the user explicitly asks for all.
- Return ONLY a single SQL code block, nothing else.
- Give proper names to computed columns which are derived from expressions and easily understood.

Question:
{question}

Schema:
{schema}

Return only:
```sql
SELECT ...
```"""
)

# 3) Simple regex checks for safety. We block any attempt to modify data or to stack multiple statements.
DANGEROUS = re.compile(r"\b(INSERT|UPDATE|DELETE|DROP|TRUNCATE|ALTER|CREATE|GRANT|REVOKE)\b", re.I)
MULTI = re.compile(r";\s*\S")  # detects "statement; another_statement"

def extract_sql(model_output: str) -> str:
    """
    Pull SQL out of a fenced code block if present.
    We prefer a block starting with 'sql' or 'select'.
    """
    txt = model_output.strip()
    if "```" in txt:
        parts = [p for p in txt.split("```") if p.strip()]
        for p in parts:
            s = p.strip()
            if s.lower().startswith("sql"):
                # If the block starts with 'sql', remove that word and return the rest
                return s[3:].strip().rstrip(";")
            if s.lower().startswith("select"):
                return s.rstrip(";")
        # Fall back to the last part
        return parts[-1].strip().rstrip(";")
    # If no code fences, just return the text
    return txt.strip().rstrip(";")

def is_safe(sql: str) -> bool:
    """Return False if the SQL looks dangerous (DML/DDL) or has stacked statements."""
    if DANGEROUS.search(sql): 
        return False
    if MULTI.search(sql):
        return False
    return True

def ensure_limit(sql: str, default_limit: int = 50) -> str:
    """
    If it's a SELECT without a LIMIT, append a LIMIT to avoid returning massive result sets by accident.
    """
    s = sql.strip().rstrip(";")
    if s.lower().startswith("select") and " limit " not in s.lower():
        s += f" LIMIT {default_limit}"
    return s

def llm_to_sql(question: str, schema: str) -> str:
    """
    Main function:
    - Builds the prompt with your question + schema
    - Calls the LLM (Ollama) to get a candidate SQL
    - Extracts the SQL, enforces a LIMIT, and checks safety
    - Returns a cleaned, safe SELECT statement
    """
    chain = SQL_PROMPT | llm | StrOutputParser()
    raw_output = chain.invoke({"question": question, "schema": schema}).strip()
    sql = extract_sql(raw_output)
    sql = ensure_limit(sql)
    if not is_safe(sql):
        raise ValueError("Blocked potentially unsafe SQL generated by the model. Please rephrase your question.")
    return sql

print("‚úÖ LLM-to-SQL helpers ready. You can now call llm_to_sql(question, schema_text).")


‚úÖ LLM-to-SQL helpers ready. You can now call llm_to_sql(question, schema_text).


# Ask a question ‚Üí get SQL ‚Üí execute on MySQL ‚Üí show DataFrame

In [34]:
# --- Ask a question, get SQL + results (Step 5) ---

from IPython.display import display  # nicer table display in notebooks

def ask(question: str, preview_rows: int = 20) -> pd.DataFrame:
    """
    - Uses the schema_text (from your earlier schema introspection cell) to give context to the LLM
    - Generates a single safe MySQL SELECT statement
    - Executes the SQL against your database using SQLAlchemy engine
    - Displays and returns a pandas DataFrame with the results
    """
    # 1) Generate SQL from your natural-language question
    sql = llm_to_sql(question, schema_text)
    
    # 2) Show the SQL so you can review/learn from it
    print("üìú SQL generated:\n", sql, "\n")
    
    # 3) Execute the SQL and return results
    with engine.connect() as conn:
        df = pd.read_sql(text(sql), conn)
    
    # 4) Show a sample of the results
    display(df.head(preview_rows))
    print(f"Rows returned: {len(df)}")
    return df

# üîπ Try a few example questions on the Classic Models dataset:
_ = ask("How many order did customer number 363 make")



üìú SQL generated:
 SELECT COUNT(DISTINCT o.orderNumber) AS orderCount 
FROM orders o 
JOIN customers c ON o.customerNumber = c.customerNumber 
WHERE c.customerNumber = 363 LIMIT 50 



Unnamed: 0,orderCount
0,3


Rows returned: 1
