In [7]:
import sqlite3
import sqlparse
import difflib
from datasets import load_dataset
from transformers import pipeline
from rapidfuzz.distance import Levenshtein
import ollama

In [1]:
import pandas as pd

splits = {'train': 'synthetic_text_to_sql_train.snappy.parquet', 'test': 'synthetic_text_to_sql_test.snappy.parquet'}
df = pd.read_parquet("hf://datasets/gretelai/synthetic_text_to_sql/" + splits["train"])

In [2]:
t =  df.head(1)

In [3]:
for k in t.iterrows():
    print(t)

     id    domain                                 domain_description  \
0  5097  forestry  Comprehensive data on sustainable forest manag...   

  sql_complexity                   sql_complexity_description  \
0    single join  only one join (specify inner, outer, cross)   

             sql_task_type                          sql_task_type_description  \
0  analytics and reporting  generating reports, dashboards, and analytical...   

                                          sql_prompt  \
0  What is the total volume of timber sold by eac...   

                                         sql_context  \
0  CREATE TABLE salesperson (salesperson_id INT, ...   

                                                 sql  \
0  SELECT salesperson_id, name, SUM(volume) as to...   

                                     sql_explanation  
0  Joins timber_sales and salesperson tables, gro...  


In [4]:
nl_query = f"""### Instruction:
Convert the given natural language question into an accurate SQL query based on the provided database schema.

### Input:
Question: List all the unique equipment types and their corresponding total maintenance frequency from the equipment_maintenance table.
Schema:CREATE TABLE equipment_maintenance (equipment_type VARCHAR(255), maintenance_frequency INT);
### Output:
SQL Query:
"""

In [5]:
#SELECT equipment_type, SUM(maintenance_frequency) AS total_maintenance_frequency FROM equipment_maintenance GROUP BY equipment_type;

In [8]:
response = ollama.chat(model="llama3.2", messages=[{"role": "user", "content": nl_query}])

In [9]:
print(response.message.content)

SELECT DISTINCT equipment_type, SUM(maintenance_frequency) AS total_maintenance FROM equipment_maintenance GROUP BY equipment_type;


In [10]:
def generate_sql_with_ollama(nl_query):
    """Generates SQL using a locally hosted Ollama model."""
    response = ollama.chat(model="llama3.2", messages=[{"role": "user", "content": nl_query}])
    return response.message.content

In [11]:
def normalize_sql(query):
    """Formats SQL queries to standard format for comparison."""
    return sqlparse.format(query, reindent=True, keyword_case="upper").strip()

In [12]:
# def generate_sql(nl_query):
#     """Generates SQL using LLaMA model."""
#     response = generator(nl_query, max_length=128, truncation=True)
#     return response[0]['generated_text']

In [13]:
def execute_sql(query, conn):
    """Executes SQL query on SQLite database and returns result."""
    try:
        cursor = conn.cursor()
        cursor.execute(query)
        return cursor.fetchall()
    except Exception as e:
        print(f"Error executing sql query: {e}")
        return None  # Return None if query execution fails

In [14]:
df = df.sample(frac=1, random_state=45).reset_index(drop=True).head(5000)


In [15]:
df.shape[0]

5000

In [16]:
# Initialize evaluation metrics
exact_match_count = 0
execution_match_count = 0
levenshtein_scores = []
total_samples = df.shape[0]  # Adjust based on dataset size

In [17]:
conn = sqlite3.connect(":memory:")

In [18]:
df.columns

Index(['id', 'domain', 'domain_description', 'sql_complexity',
       'sql_complexity_description', 'sql_task_type',
       'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql',
       'sql_explanation'],
      dtype='object')

In [None]:
for i, row in df.iterrows():  # Use iterrows() for dictionary-style access
    if i >= total_samples:
        break

    nl_query = row["sql_prompt"]  # FIXED: Changed from text_input → sql_prompt
    ground_truth_sql = normalize_sql(row["sql"])  # FIXED: Changed from sql_output → sql

    # Generate SQL using LLaMA
    generated_sql = normalize_sql(generate_sql_with_ollama(nl_query))

    # Exact Match Evaluation
    if generated_sql == ground_truth_sql:
        exact_match_count += 1

    # # Execution Accuracy
    # gt_result = execute_sql(ground_truth_sql, conn)
    # gen_result = execute_sql(generated_sql, conn)

    # if gt_result == gen_result:
    #     execution_match_count += 1

    # Levenshtein Distance (String Similarity)
    lev_score = Levenshtein.normalized_distance(generated_sql, ground_truth_sql)
    print(lev_score)
    levenshtein_scores.append(lev_score)
    

0.7571801566579635


In [None]:
# Compute Evaluation Metrics
exact_match_acc = (exact_match_count / total_samples) * 100
execution_acc = (execution_match_count / total_samples) * 100
avg_levenshtein = sum(levenshtein_scores) / total_samples

In [None]:
# Print Results
# print(f"Exact Match Accuracy: {exact_match_acc:.2f}%")
# print(f"Execution Accuracy: {execution_acc:.2f}%")
print(f"Average Levenshtein Distance: {avg_levenshtein:.2f}")

Average Levenshtein Distance: 0.14


In [None]:
conn.close()   