In [25]:
import os
import requests
import pandas as pd
import numpy as np
import re
from sklearn.model_selection import train_test_split
import nltk, spacy
from nltk.corpus import stopwords
from groq import Groq
import os
from dotenv import load_dotenv

In [26]:
# Pull variables from a local .env file (if one exists) before the key is read.
load_dotenv()

DATA_PATH = "spider_text_sql.csv"
MODEL = "llama-3.3-70b-versatile"
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# HTTP headers carrying the key as a bearer token.
# NOTE(review): no visible cell references `headers` -- confirm it is still needed.
headers = {
    "Authorization": f"Bearer {GROQ_API_KEY}",
    "Content-Type": "application/json",
}

# SDK client used by the chat-completion helper further down.
client = Groq(api_key=GROQ_API_KEY)

In [27]:
# Names of the CSV columns holding the natural-language question and its SQL.
question_col = "text_query"   # change if needed
sql_col = "sql_command"

# Read the dataset and show which columns it actually contains.
df = pd.read_csv(DATA_PATH)
print("Columns:", df.columns)

# Keep only the rows in which both the question and the SQL are present.
df = df.dropna(axis=0, how="any", subset=[question_col, sql_col])

Columns: Index(['text_query', 'sql_command'], dtype='object')


In [28]:
# Fetch the NLTK stop-word corpus and keep the English list as a set for fast lookups.
nltk.download("stopwords")
STOPWORDS = {*stopwords.words("english")}

# spaCy pipeline used by the question preprocessing below.
nlp = spacy.load("en_core_web_sm")

def preprocess_question(text: str) -> str:
    """Lowercase ``text``, drop stop-word tokens and join the remaining lemmas.

    The input is coerced with ``str`` and lowercased, then tokenized by the
    module-level ``nlp`` pipeline. Tokens whose surface text is in
    ``STOPWORDS`` are skipped; the lemma of every other token is kept.

    Returns:
        The kept lemmas joined by single spaces.
    """
    kept_lemmas = []
    for token in nlp(str(text).lower()):
        if token.text in STOPWORDS:
            continue
        kept_lemmas.append(token.lemma_)
    return " ".join(kept_lemmas)

def normalize_sql(sql: str) -> str:
    """Return ``sql`` with outer whitespace removed and inner whitespace collapsed.

    Args:
        sql: The SQL text; any non-string value is coerced with ``str`` first.

    Returns:
        The text with leading/trailing whitespace stripped and every run of
        whitespace (spaces, tabs, newlines) replaced by a single space.
    """
    # In a raw string the whitespace class is spelled r"\s"; the doubled
    # backslash (r"\\s+") matches a literal backslash followed by "s" instead,
    # which left the whitespace untouched.
    return re.sub(r"\s+", " ", str(sql).strip())

# Cleaned versions of both text columns, stored alongside the originals.
df["question_clean"] = df[question_col].map(preprocess_question)
df["sql_clean"] = df[sql_col].map(normalize_sql)

# Hold out 10% of the rows for validation; the fixed seed keeps the split repeatable.
train_df, val_df = train_test_split(
    df,
    test_size=0.1,
    random_state=42,
)

[nltk_data] Downloading package stopwords to C:\Users\Luka
[nltk_data]     Krstic\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [29]:
def groq_generate(prompt: str) -> str:
    """Send ``prompt`` to the chat-completions endpoint and return the reply text.

    The request consists of a fixed system instruction (answer with the SQL
    only, on one line) followed by ``prompt`` as the user message. It is
    issued through the module-level ``client`` using the model named by
    ``MODEL``.

    Args:
        prompt: Text placed in the user message.

    Returns:
        The message content of the first returned choice, or an empty string
        when that content is ``None``, so callers always receive a ``str``.

    Raises:
        Whatever ``client.chat.completions.create`` raises; nothing is caught here.
    """
    system_instruction = "You are a helpful assistant that converts English questions into SQL queries and only output the SQL, without explaining and without new lines, just generate the whole query."
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": system_instruction},
            {"role": "user", "content": prompt},
        ],
        model=MODEL,
    )
    content = chat_completion.choices[0].message.content
    # The content field may be None; returning "" keeps the declared return
    # type and avoids a TypeError in the regex clean-up applied to predictions.
    return content if content is not None else ""

def predict_sql(question: str) -> str:
    """Embed ``question`` in the translation prompt and return the model's answer.

    Args:
        question: The question text inserted after ``Question:`` in the prompt.

    Returns:
        Whatever ``groq_generate`` returns for the assembled prompt.
    """
    return groq_generate(
        f"Translate the following question into SQL, and only output the SQL, without explaining and without new lines, just the whole query:\n\nQuestion: {question}\nSQL:"
    )

In [30]:
# Accumulators for the model outputs and their gold SQL, filled by the loop below.
preds = []
refs = []
def clean_sql(q: str) -> str:
    """Strip markdown code fences and surrounding whitespace from a model reply.

    Args:
        q: Raw text returned by the model; ``None`` is treated as an empty reply.

    Returns:
        ``q`` with every triple-backtick fence removed (including an optional
        ``sql`` language tag in any letter case) and leading/trailing
        whitespace stripped.
    """
    if q is None:
        return ""
    # Match the tag case-insensitively so "```SQL" does not leave a stray
    # "SQL" in front of the query.
    q = re.sub(r"```(?:sql)?", "", q, flags=re.IGNORECASE)
    return q.strip()

# Query the model for the first 50 validation rows (limit to 50 for demo).
for q, gold in val_df.head(50)[["question_clean", "sql_clean"]].itertuples(index=False, name=None):
    try:
        pred = predict_sql(q)
    except Exception as e:
        # Best effort: report the failure and record an empty prediction.
        print("Error:", e)
        pred = ""
    refs.append(gold)
    preds.append(pred)

# Strip markdown fences and padding from every collected prediction.
preds = list(map(clean_sql, preds))

In [31]:
# Dump the cleaned predictions and the gold queries, one list per line.
print(preds, refs, sep="\n")

['SELECT id FROM order_item WHERE product_id = 11', 'SELECT MIN(estimate) AS low, MAX(estimate) AS high FROM film_market', "SELECT name, account_balance FROM customer WHERE name LIKE '% %'", 'SELECT name, subject FROM table_name', 'SELECT entry_name FROM catalog WHERE price_usd = ( SELECT MAX(price_usd) FROM catalog )', 'SELECT name FROM product WHERE id IN (SELECT product_id FROM event GROUP BY product_id HAVING COUNT(DISTINCT event_id) >= 2)', 'SELECT name FROM club_player WHERE position = "right wing"', 'SELECT COUNT(DISTINCT budget_code) FROM table_name;', 'SELECT name FROM patient WHERE room = 111 AND treatment IS NOT NULL', "SELECT flight_number FROM flights WHERE airline = 'United Airlines'", 'SELECT * FROM employees WHERE salary BETWEEN 8000 AND 12000 AND commission IS NULL AND department_id = 40;', 'SELECT * FROM competitions WHERE name = "1994 FIFA World Cup qualification"', 'SELECT song_name, singer, AVG(age) FROM singers JOIN songs ON singers.singer_id = songs.singer_id GRO

In [33]:
import re
import numpy as np
import evaluate  # pip install evaluate
from collections import Counter

# ---------- Normalization ----------
def normalize(sql):
    """Canonicalize a SQL string for comparison.

    Collapses every whitespace run to one space, strips the ends, lowercases
    the text and removes trailing statement terminators.

    Args:
        sql: The SQL text to normalize.

    Returns:
        The normalized string, without trailing semicolons or the whitespace
        that preceded them.
    """
    sql = re.sub(r"\s+", " ", sql).strip().lower()
    # Remove the terminator(s) together with any space in front of them;
    # a bare r";$" turned "select 1 ;" into "select 1 " (trailing space),
    # which breaks exact-match comparison.
    return re.sub(r"(?:\s*;)+$", "", sql)

# ---------- Token-level F1 ----------
def token_f1(pred, ref):
    """Whitespace-token F1 between a prediction and a reference.

    Both strings are split on whitespace. The overlap is the multiset
    intersection of the two token lists, so a repeated token counts as often
    as it occurs in both.

    Returns:
        The harmonic mean of token precision and recall, or ``0.0`` when
        either side has no tokens or nothing overlaps.
    """
    pred_tokens = pred.split()
    ref_tokens = ref.split()
    if not (pred_tokens and ref_tokens):
        return 0.0

    shared = Counter(pred_tokens) & Counter(ref_tokens)
    n_shared = sum(shared.values())

    precision = n_shared / len(pred_tokens)
    recall = n_shared / len(ref_tokens)
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# ---------- Clause-level Match ----------
def clause_match(pred, ref, clause="where"):
    """Tell whether ``pred`` and ``ref`` agree on the presence of a clause keyword.

    The keyword is matched case-insensitively as whole words, so it is not
    detected inside longer identifiers or words (e.g. ``somewhere``,
    ``from_date``). Multi-word clauses such as ``"group by"`` may be
    separated by any run of whitespace.

    Args:
        pred: Predicted SQL text.
        ref: Reference SQL text.
        clause: Clause keyword(s) to look for.

    Returns:
        ``True`` when both queries contain the clause or both lack it,
        ``False`` otherwise.
    """
    words = [re.escape(word) for word in clause.lower().split()]
    if not words:
        # An empty clause is trivially present in (or absent from) both sides.
        return True
    pattern = re.compile(r"\b" + r"\s+".join(words) + r"\b")
    pred_has = pattern.search(pred.lower()) is not None
    ref_has = pattern.search(ref.lower()) is not None
    return pred_has == ref_has

def clause_scores(preds, refs):
    """Per-clause agreement rate between predictions and references.

    For each of ``select``, ``from``, ``where``, ``group by`` and ``order by``
    the mean of ``clause_match`` over the paired queries is computed.

    Returns:
        A dict mapping the upper-cased clause name to that mean.
    """
    return {
        keyword.upper(): np.mean(
            [clause_match(p, g, clause=keyword) for p, g in zip(preds, refs)]
        )
        for keyword in ("select", "from", "where", "group by", "order by")
    }

# ---------- Partial Match (columns/tables overlap) ----------
def partial_match(pred, ref):
    """Share of the reference's distinct word tokens that also occur in the prediction.

    Tokens are extracted from the lowercased text with an identifier-shaped
    pattern, so SQL keywords are counted together with table and column
    names.

    Returns:
        ``len(shared tokens) / len(reference tokens)``, with the denominator
        floored at 1 so a reference without tokens yields ``0.0``.
    """
    def identifiers(sql):
        return set(re.findall(r"[a-zA-Z_][a-zA-Z0-9_]*", sql.lower()))

    ref_ids = identifiers(ref)
    pred_ids = identifiers(pred)
    shared = ref_ids & pred_ids
    return len(shared) / max(1, len(ref_ids))

# ---------- Compute Metrics ----------
# Bring predictions and references to the same canonical form before scoring.
preds_norm = list(map(normalize, preds))
refs_norm = list(map(normalize, refs))

# Exact Match: share of pairs whose normalized strings are identical
exact_match = np.mean([pred == gold for pred, gold in zip(preds_norm, refs_norm)])

# BLEU ("sacrebleu" metric loaded through `evaluate`; one reference per prediction)
bleu = evaluate.load("sacrebleu")
bleu_score = bleu.compute(
    predictions=preds_norm,
    references=[[gold] for gold in refs_norm],
)["score"]

# Token-level F1, per example and averaged
f1_scores = list(map(token_f1, preds_norm, refs_norm))
avg_f1 = np.mean(f1_scores)

# Clause accuracy
clause_acc = clause_scores(preds_norm, refs_norm)

# Partial match, per example and averaged
partial_scores = list(map(partial_match, preds_norm, refs_norm))
avg_partial = np.mean(partial_scores)

# ---------- Print Summary ----------
# Assemble the whole report first, then emit it with a single print call.
print(
    "\n".join(
        [
            "Evaluation Summary:",
            f"  Exact Match: {exact_match:.3f}",
            f"  BLEU: {bleu_score:.2f}",
            f"  Token F1: {avg_f1:.3f}",
            f"  Partial Match: {avg_partial:.3f}",
            "  Clause Accuracy:",
            *(f"    {cl}: {acc:.3f}" for cl, acc in clause_acc.items()),
        ]
    )
)

# ---------- Preview Examples ----------
# Print each evaluated question next to its gold SQL, predicted SQL and
# per-example scores. The question column is looked up through `question_col`
# (set where the CSV is loaded) instead of a hard-coded name, so changing the
# column there does not break this preview.
for q, g, p, f1, pm in zip(val_df[question_col].head(50), refs_norm[:50], preds_norm[:50], f1_scores[:50], partial_scores[:50]):
    print("\n---")
    print("Q:", q)
    print("Gold:", g)
    print("Pred:", p)
    print(f"Token F1: {f1:.3f} | Partial Match: {pm:.3f}")


Evaluation Summary:
  Exact Match: 0.020
  BLEU: 18.24
  Token F1: 0.470
  Partial Match: 0.564
  Clause Accuracy:
    SELECT: 1.000
    FROM: 1.000
    WHERE: 0.820
    GROUP BY: 0.820
    ORDER BY: 0.920

---
Q: Find the ids of all the order items whose product id is 11.
Gold: select order_item_id from order_items where product_id = 11
Pred: select id from order_item where product_id = 11
Token F1: 0.750 | Partial Match: 0.667

---
Q: Return the low and high estimates for all film markets.
Gold: select low_estimate , high_estimate from film_market_estimation
Pred: select min(estimate) as low, max(estimate) as high from film_market
Token F1: 0.267 | Partial Match: 0.400

---
Q: Find the name and account balance of the customer whose name includes the letter ‘a’.
Gold: select cust_name , acc_bal from customer where cust_name like '%a%'
Pred: select name, account_balance from customer where name like '% %'
Token F1: 0.500 | Partial Match: 0.625

---
Q: What are the names of all the subj