In [8]:
import sqlite3
import sqlparse
import difflib
from datasets import load_dataset
from transformers import pipeline
from rapidfuzz.distance import Levenshtein
import ollama

In [2]:
import pandas as pd

splits = {'train': 'synthetic_text_to_sql_train.snappy.parquet', 'test': 'synthetic_text_to_sql_test.snappy.parquet'}
df = pd.read_parquet("hf://datasets/gretelai/synthetic_text_to_sql/" + splits["train"])

In [3]:
t =  df.head(1)

In [4]:
for k in t.iterrows():
    print(t)

     id    domain                                 domain_description  \
0  5097  forestry  Comprehensive data on sustainable forest manag...   

  sql_complexity                   sql_complexity_description  \
0    single join  only one join (specify inner, outer, cross)   

             sql_task_type                          sql_task_type_description  \
0  analytics and reporting  generating reports, dashboards, and analytical...   

                                          sql_prompt  \
0  What is the total volume of timber sold by eac...   

                                         sql_context  \
0  CREATE TABLE salesperson (salesperson_id INT, ...   

                                                 sql  \
0  SELECT salesperson_id, name, SUM(volume) as to...   

                                     sql_explanation  
0  Joins timber_sales and salesperson tables, gro...  


In [5]:
nl_query = f"""### Instruction:
Convert the given natural language question into an accurate SQL query based on the provided database schema.

### Input:
Question: List all the unique equipment types and their corresponding total maintenance frequency from the equipment_maintenance table.
Schema:CREATE TABLE equipment_maintenance (equipment_type VARCHAR(255), maintenance_frequency INT);
### Output:
SQL Query:
"""

In [6]:
#SELECT equipment_type, SUM(maintenance_frequency) AS total_maintenance_frequency FROM equipment_maintenance GROUP BY equipment_type;

In [9]:
response = ollama.chat(model="llama3.2", messages=[{"role": "user", "content": nl_query}])

In [10]:
print(response.message.content)

SELECT DISTINCT equipment_type, SUM(maintenance_frequency) FROM equipment_maintenance GROUP BY equipment_type;


In [11]:
def generate_sql_with_ollama(nl_query):
    """Generates SQL using a locally hosted Ollama model."""
    response = ollama.chat(model="llama3.2", messages=[{"role": "user", "content": nl_query}])
    return response.message.content

In [12]:
def normalize_sql(query):
    """Formats SQL queries to standard format for comparison."""
    return sqlparse.format(query, reindent=True, keyword_case="upper").strip()

In [13]:
# def generate_sql(nl_query):
#     """Generates SQL using LLaMA model."""
#     response = generator(nl_query, max_length=128, truncation=True)
#     return response[0]['generated_text']

In [14]:
def execute_sql(query, conn):
    """Executes SQL query on SQLite database and returns result."""
    try:
        cursor = conn.cursor()
        cursor.execute(query)
        return cursor.fetchall()
    except Exception as e:
        print(f"Error executing sql query: {e}")
        return None  # Return None if query execution fails

In [15]:
df = df.sample(frac=1, random_state=45).reset_index(drop=True).head(5000)


In [16]:
df.shape[0]

5000

In [17]:
# Initialize evaluation metrics
exact_match_count = 0
execution_match_count = 0
levenshtein_scores = []
total_samples = df.shape[0]  # Adjust based on dataset size

In [18]:
conn = sqlite3.connect(":memory:")

In [19]:
df.columns

Index(['id', 'domain', 'domain_description', 'sql_complexity',
       'sql_complexity_description', 'sql_task_type',
       'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql',
       'sql_explanation'],
      dtype='object')

In [20]:
for i, row in df.iterrows():  # Use iterrows() for dictionary-style access
    if i >= total_samples:
        break

    nl_query = row["sql_prompt"]  # FIXED: Changed from text_input → sql_prompt
    ground_truth_sql = normalize_sql(row["sql"])  # FIXED: Changed from sql_output → sql

    # Generate SQL using LLaMA
    generated_sql = normalize_sql(generate_sql_with_ollama(nl_query))

    # Exact Match Evaluation
    if generated_sql == ground_truth_sql:
        exact_match_count += 1

    # # Execution Accuracy
    # gt_result = execute_sql(ground_truth_sql, conn)
    # gen_result = execute_sql(generated_sql, conn)

    # if gt_result == gen_result:
    #     execution_match_count += 1

    # Levenshtein Distance (String Similarity)
    lev_score = Levenshtein.normalized_distance(generated_sql, ground_truth_sql)
    print(lev_score)
    levenshtein_scores.append(lev_score)
    

0.8444444444444444
0.8100418410041841
0.9249478804725504
0.989699050696829
0.8363636363636363
0.9406934306569343
0.9673590504451038
0.9303857008466604
0.8930402930402931
0.9717348927875243
0.9556962025316456
0.9631027253668764
0.8455284552845529
0.8874538745387454
0.9333333333333333
0.9787141615986099
0.9646118721461188
0.9447938504542278
0.891832229580574
0.9417040358744395
0.936
0.9296394019349165
0.9326923076923077
0.8393574297188755
0.7275822928490352
0.923472301541976
0.8294573643410853
0.9568661971830986
0.9646182495344506
0.8882733148661126
0.84688995215311
0.975671750181554
0.972636815920398
0.526595744680851
0.9648205371999049
0.9481373265157049
0.9713701431492843
0.9165202108963093
0.9609181141439206
0.9749821300929236
0.6785714285714286
0.8333333333333334
0.7784431137724551
0.9117432530999271
0.8768382352941176
0.909330985915493
0.9785407725321889
0.9278026905829596
0.8388888888888889
0.9343832020997376
0.7714285714285715
0.9613874345549738
0.9557632398753894
0.8800393313667

KeyboardInterrupt: 

In [21]:
# Compute Evaluation Metrics
exact_match_acc = (exact_match_count / total_samples) * 100
execution_acc = (execution_match_count / total_samples) * 100
avg_levenshtein = sum(levenshtein_scores) / total_samples

In [22]:
# Print Results
# print(f"Exact Match Accuracy: {exact_match_acc:.2f}%")
# print(f"Execution Accuracy: {execution_acc:.2f}%")
print(f"Average Levenshtein Distance: {avg_levenshtein:.2f}")

Average Levenshtein Distance: 0.14


In [None]:
conn.close()   