In [2]:
from llama_cpp import Llama
import pandas as pd
import duckdb, os

In [3]:
# Path to your GGUF model
MODEL_PATH = "models/llama-3-sqlcoder-8b.Q6_K.gguf"

# llm = Llama(model_path=MODEL_PATH, n_ctx=2048, n_threads=6)
llm = Llama(
    model_path=MODEL_PATH,
    n_ctx=1024,  # Lower context size if RAM is an issue
    n_threads=6,
    n_gpu_layers=20,  # Safer for 8GB GPU
    verbose=True
)

AVX = 1 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 1 | SSSE3 = 0 | VSX = 0 | 


In [4]:
# Load CSV into DuckDB
df = pd.read_csv("./data/llm_dataset_v10.gz")
con = duckdb.connect()
con.register("sales_data", df)

  df = pd.read_csv("./data/llm_dataset_v10.gz")


<duckdb.duckdb.DuckDBPyConnection at 0x203a5d6a970>

In [6]:
# Prompt template
def build_prompt(nlq):
    schema = "sales_data(region TEXT, quarter TEXT, sales INT)"
    prompt = f"""### You are an expert Postgres SQL generator.
### Given the following table schema:
# {schema}

### Write a SQL query to answer the question:
# {nlq}

### SQL:
"""
    return prompt

In [10]:
# Query model
def generate_sql(prompt):
    output = llm(prompt, temperature=0, max_tokens=256)
    text = output["choices"][0]["text"]

    if "SELECT" not in text.upper():
        print("❌ 'SELECT' not found in model output. Raw output:")
        print(text)
        return None

    # Try to extract SQL statement cleanly
    try:
        sql = "SELECT " + text.upper().split("SELECT", 1)[1].split(";")[0].strip() + ";"
        return sql
    except Exception as e:
        print("❌ Error while parsing SQL:", e)
        print("Raw model output:")
        print(text)
        return None

In [8]:
# Run query
def run_nlq(nlq):
    prompt = build_prompt(nlq)
    sql = generate_sql(prompt)

    if not sql:
        print("\n⚠️ Could not generate valid SQL.")
        return

    print("\n📜 Generated SQL:")
    print(sql)

    try:
        result = con.execute(sql).fetchdf()
        print("\n📊 Query Result:")
        print(result)
    except Exception as e:
        print("\n❌ SQL Execution Error:")
        print(e)

In [11]:
run_nlq("What were the total sales in Q3 for the Northeast?")

Llama.generate: prefix-match hit



📜 Generated SQL:
SELECT SUM(SALES) AS TOTAL_SALES FROM SALES_DATA WHERE REGION = 'NORTHEAST' AND QUARTER = 'Q3';

📊 Query Result:
   TOTAL_SALES
0          NaN


In [5]:
con.execute("SELECT * FROM sales_data LIMIT 5").fetchdf()

Unnamed: 0,month,year,region,city,area,territory,distributor,route,customer,brand,...,primary sales,target,productivity,mro,unproductive_mro,unassorted_mro,stockout_mro,stockout,assortment,mto
0,11,2024,Central-A,lahore,Lahore Central Area,Lahore-A Territory,D0715,D0715OB26,N00000166872_D0715OB26,CHOCO LAVA,...,257.940313,1051.658369,True,85.635363,6.182219,35.9963,43.456844,True,False,752.058369
1,11,2024,Central-A,lahore,Lahore Central Area,Lahore-A Territory,D0715,D0715OB26,N00000167036_D0715OB26,CHOCO LAVA,...,515.880626,2103.316738,True,171.270725,12.364438,71.992599,86.913688,False,False,1504.116738
2,11,2024,Central-A,lahore,Lahore Central Area,Lahore-A Territory,D0715,D0715OB26,N00000167037_D0715OB26,CHOCO LAVA,...,257.940313,1051.658369,True,85.635363,6.182219,35.9963,43.456844,False,True,752.058369
3,11,2024,Central-A,lahore,Lahore Central Area,Lahore-A Territory,D0715,D0715OB26,N00000167080_D0715OB26,CHOCO LAVA,...,257.940313,1051.658369,True,85.635363,6.182219,35.9963,43.456844,False,False,752.058369
4,11,2024,Central-A,lahore,Lahore Central Area,Lahore-A Territory,D0715,D0715OB26,N00000167081_D0715OB26,CHOCO LAVA,...,515.880626,2103.316738,True,171.270725,12.364438,71.992599,86.913688,True,True,1504.116738


In [7]:
sql = "SELECT SUM(sales) AS total_primary_sales FROM sales_data WHERE (month >= 7 AND month <= 9) AND year = 2024"
print("\n📜 Executing SQL:")
con.execute(sql).fetchdf()


📜 Executing SQL:


Unnamed: 0,total_primary_sales
0,


In [None]:
sql = "SELECT SUM('primary sales') AS total_primary_sales FROM sales_data WHERE (month >= 7 AND month <= 9) AND year = 2024"
print("\n📜 Executing SQL:")
con.execute(sql).fetchdf()


📜 Executing SQL:


BinderException: Binder Error: No function matches the given name and argument types 'sum(VARCHAR)'. You might need to add explicit type casts.
	Candidate functions:
	sum(DECIMAL) -> DECIMAL
	sum(BOOLEAN) -> HUGEINT
	sum(SMALLINT) -> HUGEINT
	sum(INTEGER) -> HUGEINT
	sum(BIGINT) -> HUGEINT
	sum(HUGEINT) -> HUGEINT
	sum(DOUBLE) -> DOUBLE


LINE 1: SELECT SUM('primary sales') AS total_primary_sales FROM sales_data...
               ^

In [None]:
sql = "SELECT SUM('primary sales') AS total_primary_sales FROM sales_data WHERE (month >= 7 AND month <= 9) AND year = 2024"
print("\n📜 Executing SQL:")
con.execute(sql).fetchdf()


📜 Executing SQL:


BinderException: Binder Error: No function matches the given name and argument types 'sum(VARCHAR)'. You might need to add explicit type casts.
	Candidate functions:
	sum(DECIMAL) -> DECIMAL
	sum(BOOLEAN) -> HUGEINT
	sum(SMALLINT) -> HUGEINT
	sum(INTEGER) -> HUGEINT
	sum(BIGINT) -> HUGEINT
	sum(HUGEINT) -> HUGEINT
	sum(DOUBLE) -> DOUBLE


LINE 1: SELECT SUM('primary sales') AS total_primary_sales FROM sales_data...
               ^