In [None]:
import os
import requests
import pandas as pd
import numpy as np
import re
from sklearn.model_selection import train_test_split
import nltk, spacy
from nltk.corpus import stopwords
from groq import Groq
import os
from dotenv import load_dotenv

In [None]:
DATA_PATH = "spider_text_sql.csv"
load_dotenv()

GROQ_API_KEY =  os.getenv("GROQ_API_KEY")
MODEL = "llama-3.3-70b-versatile"

headers = {
    "Authorization": f"Bearer {GROQ_API_KEY}",
    "Content-Type": "application/json"
}
client = Groq (
    api_key=GROQ_API_KEY
)

In [18]:
df = pd.read_csv(DATA_PATH)
print("Columns:", df.columns)

question_col = "text_query"   # change if needed
sql_col = "sql_command"

df = df.dropna(subset=[question_col, sql_col])

Columns: Index(['text_query', 'sql_command'], dtype='object')


In [19]:
nltk.download("stopwords")
STOPWORDS = set(stopwords.words("english"))

nlp = spacy.load("en_core_web_sm")

def preprocess_question(text: str) -> str:
    doc = nlp(str(text).lower())
    return " ".join([tok.lemma_ for tok in doc if tok.text not in STOPWORDS])

def normalize_sql(sql: str) -> str:
    return re.sub(r"\\s+", " ", str(sql).strip())

df["question_clean"] = df[question_col].apply(preprocess_question)
df["sql_clean"] = df[sql_col].apply(normalize_sql)

train_df, val_df = train_test_split(df, test_size=0.1, random_state=42)

[nltk_data] Downloading package stopwords to C:\Users\Luka
[nltk_data]     Krstic\AppData\Roaming\nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [None]:
def groq_generate(prompt: str) -> str:
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": "You are a helpful assistant that converts English questions into SQL queries and only output the SQL, without explaining and without new lines, just generate the whole query."
            },
            {
                "role": "user",
                "content": prompt,
            }
        ],
            model=MODEL,
    )
    return chat_completion.choices[0].message.content

def predict_sql(question: str) -> str:
    prompt = f"Translate the following question into SQL, and only output the SQL, without explaining and without new lines, just the whole query:\n\nQuestion: {question}\nSQL:"
    return groq_generate(prompt)

In [21]:
preds, refs = [], []
def clean_sql(q: str) -> str:
    # remove triple backticks and optional "sql"
    q = re.sub(r"```sql|```", "", q)
    # strip leading/trailing whitespace
    q = q.strip()
    return q

for i, row in val_df.head(50).iterrows():  # limit to 50 for demo
    q = row["question_clean"]
    gold = row["sql_clean"]
    try:
        pred = predict_sql(q)
    except Exception as e:
        print("Error:", e)
        pred = ""
    preds.append(pred)
    refs.append(gold)

preds = [clean_sql(sql) for sql in preds]

In [22]:
print(preds)
print(refs)

['SELECT id FROM order_item WHERE product_id = 11', 'SELECT MIN(budget) AS low, MAX(budget) AS high, AVG(budget) AS estimate FROM film', "SELECT name, account_balance FROM customer WHERE name LIKE '% %'", 'SELECT name, subject FROM table_name', 'SELECT entry_name FROM catalog WHERE price_usd = ( SELECT MAX(price_usd) FROM catalog )', 'SELECT name FROM product WHERE id IN (SELECT product_id FROM event GROUP BY product_id HAVING COUNT(product_id) >= 2)', 'SELECT name FROM club_player WHERE position = "right wing"', 'SELECT COUNT(DISTINCT budget_code) FROM table_name;', 'SELECT name FROM patient WHERE room = 111 AND treatment IS NOT NULL', "SELECT flight_number FROM flights WHERE airline = 'United Airlines'", 'SELECT * FROM employees WHERE salary BETWEEN 8000 AND 12000 AND commission IS NULL AND department_id = 40', 'SELECT * FROM competitions WHERE name = "1994 FIFA World Cup qualification"', 'SELECT song_name, singer, AVG(age) FROM singers GROUP BY song_name, singer', 'SELECT SUM(cost) 

In [23]:
import re
import numpy as np
import evaluate  # install via: pip install evaluate

# normalize function
def normalize(s): 
    return re.sub(r"\s+", " ", s).strip().lower()

# exact match
exact_match = np.mean([normalize(p) == normalize(g) for p, g in zip(preds, refs)])

# BLEU
bleu = evaluate.load("sacrebleu")
bleu_score = bleu.compute(predictions=preds, references=[[r] for r in refs])["score"]

print("Exact Match:", exact_match)
print("BLEU:", bleu_score)

# preview examples
for q, g, p in zip(val_df["text_query"].head(50), refs[:50], preds[:50]):
    print("---")
    print("Q:", q)
    print("Gold:", g)
    print("Pred:", p)


  from .autonotebook import tqdm as notebook_tqdm
Downloading builder script: 8.15kB [00:00, 12.0MB/s]


Exact Match: 0.02
BLEU: 12.50066987031684
---
Q: Find the ids of all the order items whose product id is 11.
Gold: SELECT order_item_id FROM order_items WHERE product_id = 11
Pred: SELECT id FROM order_item WHERE product_id = 11
---
Q: Return the low and high estimates for all film markets.
Gold: SELECT Low_Estimate ,  High_Estimate FROM film_market_estimation
Pred: SELECT MIN(budget) AS low, MAX(budget) AS high, AVG(budget) AS estimate FROM film
---
Q: Find the name and account balance of the customer whose name includes the letter ‘a’.
Gold: SELECT cust_name ,  acc_bal FROM customer WHERE cust_name LIKE '%a%'
Pred: SELECT name, account_balance FROM customer WHERE name LIKE '% %'
---
Q: What are the names of all the subjects.
Gold: SELECT subject_name FROM SUBJECTS
Pred: SELECT name, subject FROM table_name
---
Q: Find the entry name of the catalog with the highest price (in USD).
Gold: SELECT catalog_entry_name FROM catalog_contents ORDER BY price_in_dollars DESC LIMIT 1
Pred: SELECT