# Loading the Dataset

In [None]:
# Mount Google Drive at /content/drive so the dataset archive stored there can be read
# (google.colab is only importable inside a Colab runtime).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import zipfile
import os

# Archive to unpack and the Drive folder that receives its contents.
zip_file_path = '/content/drive/MyDrive/spider_data-20241107T125352Z-001.zip'
extract_dir = '/content/drive/MyDrive/NLP_Project_2024'

# Make sure the destination exists before extracting into it.
os.makedirs(extract_dir, exist_ok=True)

# Extraction is best-effort: every failure is reported with a message and the
# cell still finishes, so check the printed output before running later cells.
try:
    with zipfile.ZipFile(zip_file_path) as zip_ref:
        zip_ref.extractall(path=extract_dir)
    print(f"Successfully unzipped '{zip_file_path}' to '{extract_dir}'")
except FileNotFoundError:
    print(f"Error: File not found at '{zip_file_path}'")
except zipfile.BadZipFile:
    print(f"Error: Invalid zip file '{zip_file_path}'")
except Exception as e:
    print(f"An error occurred: {e}")


# Preprocessing

1. Combine Schemas
2. Generate test questions

In [3]:
import os

def combine_database_schemas(root_folder, output_file):
    """Concatenate every ``.sql`` file found under ``root_folder`` into ``output_file``.

    Each file's content is preceded by a ``-- <folder name>`` header line (the name
    of the directory that directly contains the file) and followed by a blank line.

    Directories and files are visited in sorted order so that repeated runs produce
    the same output, and the output file itself is never read back in, even when it
    lives under ``root_folder`` and ends in ``.sql``.

    Args:
        root_folder: Directory tree to search for ``.sql`` files.
        output_file: Path of the combined file; overwritten if it already exists.
    """
    output_abspath = os.path.abspath(output_file)
    with open(output_file, 'w') as outfile:
        for dirpath, dirnames, filenames in os.walk(root_folder):
            # Sorting in place steers os.walk's top-down traversal deterministically.
            dirnames.sort()
            subfolder_name = os.path.basename(dirpath)
            for filename in sorted(filenames):
                if not filename.endswith('.sql'):
                    continue
                file_path = os.path.join(dirpath, filename)
                if os.path.abspath(file_path) == output_abspath:
                    # Do not feed the (partially written) output into itself.
                    continue
                print(f"Processing file: {file_path}")
                outfile.write(f"-- {subfolder_name}\n")
                with open(file_path, 'r') as infile:
                    outfile.write(infile.read())
                outfile.write("\n\n")



In [None]:
# Merge every .sql file under the extracted database folder into a single schema
# dump. The output path is relative, so it lands in the current working directory.
database_folder = '/content/drive/MyDrive/NLP_Project_2024/spider_data/database/'
combine_database_schema_file = 'combine_database_schemas.sql'
combine_database_schemas(database_folder, combine_database_schema_file)

In [5]:
import json
import csv
import os

def json_to_csv_for_query_questions(json_file, csv_file, append=False):
    """Write the ``db_id``, ``query`` and ``question`` fields of a JSON list to CSV.

    Args:
        json_file: Path to a JSON file containing a list of objects.
        csv_file: Path of the CSV file to write.
        append: When False (default) the CSV is overwritten. When True the rows
            are appended to the existing file.

    The ``db_id,query,question`` header is written whenever the CSV starts out
    empty: on overwrite, and also on append to a file that does not exist yet or
    has no content. Fields missing from an object are written as empty strings.
    """
    with open(json_file, 'r', encoding='utf-8') as infile:
        data = json.load(infile)

    # Decide before opening: opening in 'a' mode creates the file if it is missing.
    starts_empty = (not os.path.isfile(csv_file)) or os.path.getsize(csv_file) == 0
    needs_header = (not append) or starts_empty

    mode = 'a' if append else 'w'
    with open(csv_file, mode, newline='', encoding='utf-8') as outfile:
        writer = csv.writer(outfile)
        if needs_header:
            writer.writerow(['db_id', 'query', 'question'])
        writer.writerows(
            [obj.get('db_id', ''), obj.get('query', ''), obj.get('question', '')]
            for obj in data
        )

In [6]:
# Question files from the extracted dataset: the train split and the test split.
json_questions_file_from_train = '/content/drive/MyDrive/NLP_Project_2024/spider_data/train_spider.json'
json_questions_file_from_test = '/content/drive/MyDrive/NLP_Project_2024/spider_data/test.json'
# One CSV receives both splits; the relative path resolves against the working directory.
query_questions_csv_test = 'test_questions.csv'
# First call overwrites the CSV and writes the header; its rows come from the train split.
json_to_csv_for_query_questions(json_questions_file_from_train, query_questions_csv_test,False)
# Second call appends the test-split rows after the train-split rows.
json_to_csv_for_query_questions(json_questions_file_from_test, query_questions_csv_test,True)

# GPU Initialization

In [None]:
!pip install torch transformers bitsandbytes accelerate sqlparse

In [8]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
torch.cuda.is_available()

In [10]:
# Total memory of CUDA device 0; compared against 20e9 further down to choose how the
# model is loaded. NOTE(review): this call is not guarded by torch.cuda.is_available().
available_memory = torch.cuda.get_device_properties(0).total_memory

In [None]:
# Report the GPU memory figure, the number of visible CUDA devices and the name of
# device 0 (or a placeholder message when no CUDA device is available).
print(available_memory)
print(torch.cuda.device_count())
print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU found")

# Model - Llama 3

In [None]:
model_name = "defog/llama-3-sqlcoder-8b"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Arguments common to both loading strategies.
model_load_kwargs = {
    "trust_remote_code": True,
    "device_map": "auto",
    "use_cache": True,
}
if available_memory > 20e9:
    # With more than 20 GB of GPU memory, load the weights in float16.
    model_load_kwargs["torch_dtype"] = torch.float16
else:
    # Otherwise load in 4 bits – this is slower and less accurate.
    model_load_kwargs["load_in_4bit"] = True

model = AutoModelForCausalLM.from_pretrained(model_name, **model_load_kwargs)

# Sample query generation

In [12]:
import sqlparse

def generate_query(prompt_template, question):
    """Generate one SQL query for ``question`` with the globally loaded model.

    Args:
        prompt_template: Prompt text. ``{question}`` placeholders are filled in
            with ``question``. A prompt that cannot be processed by ``str.format``
            (for example an already rendered prompt that contains literal braces)
            falls back to a plain ``{question}`` substitution instead of raising.
        question: Natural-language question to answer.

    Returns:
        The decoded text that follows the first "```sql" marker, cut at the first
        closing "```" fence or ";" (whichever comes first).

    Raises:
        IndexError: if the decoded output contains no "```sql" marker.
    """
    try:
        updated_prompt = prompt_template.format(question=question)
    except (KeyError, IndexError, ValueError):
        # Literal "{" / "}" in the prompt are not format fields.
        updated_prompt = prompt_template.replace("{question}", question)

    # Follow the model's own device instead of assuming "cuda".
    inputs = tokenizer(updated_prompt, return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        # Greedy decoding. temperature/top_p only apply when sampling, so they are
        # not passed here.
        do_sample=False,
        num_beams=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Release cached GPU memory between calls so long runs do not exhaust it
    # (particularly important on Colab).
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    completion = outputs[0].split("```sql")[1]
    return completion.split("```")[0].split(";")[0]

In [14]:
# Coffee-shop prompt, variant 1: the question only, with no schema information.
# {question} is filled in by generate_query via str.format.
prompt_template_coffee_shop_1 = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`



<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""

In [15]:
# Coffee-shop prompt, variant 2: the question followed by the CREATE TABLE statements
# of the four tables (shop, member, happy_hour, happy_hour_member).
prompt_template_coffee_shop_2 = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`

CREATE TABLE "shop" (
"Shop_ID" int,
"Address" text,
"Num_of_staff" text,
"Score" real,
"Open_Year" text,
PRIMARY KEY ("Shop_ID")
);

CREATE TABLE "member" (
"Member_ID" int,
"Name" text,
"Membership_card" text,
"Age" int,
"Time_of_purchase" int,
"Level_of_membership" int,
"Address" text,
PRIMARY KEY ("Member_ID")
);

CREATE TABLE "happy_hour" (
"HH_ID" int,
"Shop_ID" int,
"Month" text,
"Num_of_shaff_in_charge" int,
PRIMARY KEY ("HH_ID","Shop_ID","Month"),
FOREIGN KEY ("Shop_ID") REFERENCES `shop`("Shop_ID")
);

CREATE TABLE "happy_hour_member" (
"HH_ID" int,
"Member_ID" int,
"Total_amount" real,
PRIMARY KEY ("HH_ID","Member_ID"),
FOREIGN KEY ("Member_ID") REFERENCES `member`("Member_ID")
);

<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""

In [16]:
# Coffee-shop prompt, variant 3: question, CREATE TABLE statements, plus a user hint
# (the trailing "-- Use shop as t1 ..." comment) about which table aliases to use.
prompt_template_coffee_shop_3 = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`

CREATE TABLE "shop" (
"Shop_ID" int,
"Address" text,
"Num_of_staff" text,
"Score" real,
"Open_Year" text,
PRIMARY KEY ("Shop_ID")
);

CREATE TABLE "member" (
"Member_ID" int,
"Name" text,
"Membership_card" text,
"Age" int,
"Time_of_purchase" int,
"Level_of_membership" int,
"Address" text,
PRIMARY KEY ("Member_ID")
);

CREATE TABLE "happy_hour" (
"HH_ID" int,
"Shop_ID" int,
"Month" text,
"Num_of_shaff_in_charge" int,
PRIMARY KEY ("HH_ID","Shop_ID","Month"),
FOREIGN KEY ("Shop_ID") REFERENCES `shop`("Shop_ID")
);

CREATE TABLE "happy_hour_member" (
"HH_ID" int,
"Member_ID" int,
"Total_amount" real,
PRIMARY KEY ("HH_ID","Member_ID"),
FOREIGN KEY ("Member_ID") REFERENCES `member`("Member_ID")
);

-- Use shop as t1 and happy_hour as t2
<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""

In [17]:
question_coffee_shop = "What are the id and address of the shops which have a happy hour in May??"

In [None]:
# Variant 1 (question only): generate the SQL, then pretty-print it with sqlparse.
generated_sql_template_coffee_shop_1 = generate_query(prompt_template_coffee_shop_1,question_coffee_shop)
print(sqlparse.format(generated_sql_template_coffee_shop_1, reindent=True))

In [None]:
# Variant 2 (question + DDL): generate the SQL, then pretty-print it with sqlparse.
generated_sql_template_coffee_shop_2 = generate_query(prompt_template_coffee_shop_2,question_coffee_shop)
print(sqlparse.format(generated_sql_template_coffee_shop_2, reindent=True))

In [None]:
# Variant 3 (question + DDL + alias hint): generate the SQL, then pretty-print it.
generated_sql_template_coffee_shop_3 = generate_query(prompt_template_coffee_shop_3,question_coffee_shop)
print(sqlparse.format(generated_sql_template_coffee_shop_3, reindent=True))

In [21]:
# Flight prompt, variant 1: the question only, with no schema information.
# {question} is filled in by generate_query via str.format.
prompt_template_flight_1 = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`

<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""

In [22]:
# Flight prompt, variant 2: the question followed by the CREATE TABLE statements
# of the four tables (flight, aircraft, employee, certificate).
prompt_template_flight_2 = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`

create table flight(
	flno number(4,0) primary key,
	origin varchar2(20),
	destination varchar2(20),
	distance number(6,0),
	departure_date date,
	arrival_date date,
	price number(7,2),
    aid number(9,0),
    foreign key("aid") references `aircraft`("aid"));

create table aircraft(
	aid number(9,0) primary key,
	name varchar2(30),
	distance number(6,0));

create table employee(
	eid number(9,0) primary key,
	name varchar2(30),
	salary number(10,2));

create table certificate(
	eid number(9,0),
	aid number(9,0),
	primary key(eid,aid),
	foreign key("eid") references `employee`("eid"),
	foreign key("aid") references `aircraft`("aid"));

<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""

In [23]:
# Flight prompt, variant 3: question, an instruction block telling the model to use
# INTERSECT instead of IN (with an example), then the CREATE TABLE statements.
prompt_template_flight_3 = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`
-- Important: You must use the INTERSECT clause instead of the IN clause.
-- For any comparison involving lists of values, use INTERSECT between SELECT statements.
-- Example:
-- Instead of:
--   SELECT * FROM table1 WHERE column1 IN (SELECT column2 FROM table2);
-- Use:
--   SELECT * FROM table1 WHERE column1 = ANY (SELECT column2 FROM table1 INTERSECT SELECT column2 FROM table2);

DDL statements:

create table flight(
	flno number(4,0) primary key,
	origin varchar2(20),
	destination varchar2(20),
	distance number(6,0),
	departure_date date,
	arrival_date date,
	price number(7,2),
    aid number(9,0),
    foreign key("aid") references `aircraft`("aid"));

create table aircraft(
	aid number(9,0) primary key,
	name varchar2(30),
	distance number(6,0));

create table employee(
	eid number(9,0) primary key,
	name varchar2(30),
	salary number(10,2));

create table certificate(
	eid number(9,0),
	aid number(9,0),
	primary key(eid,aid),
	foreign key("eid") references `employee`("eid"),
	foreign key("aid") references `aircraft`("aid"));

<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""

In [24]:
question_flight = "Show names for all employees who have certificates on both Boeing 737-800 and Airbus A340-300?"

In [None]:
# Variant 1 (question only): generate the SQL, then pretty-print it with sqlparse.
generated_sql_template_fligt_1 = generate_query(prompt_template_flight_1,question_flight)
print(sqlparse.format(generated_sql_template_fligt_1, reindent=True))

In [None]:
# Variant 2 (question + DDL): generate the SQL, then pretty-print it with sqlparse.
generated_sql_template_fligt_2 = generate_query(prompt_template_flight_2,question_flight)
print(sqlparse.format(generated_sql_template_fligt_2, reindent=True))

In [None]:
# Variant 3 (question + INTERSECT instructions + DDL): generate the SQL, then pretty-print it.
generated_sql_template_fligt_3 = generate_query(prompt_template_flight_3,question_flight)
print(sqlparse.format(generated_sql_template_fligt_3, reindent=True))

# Query Generation

Generate queries for the first 500 test questions; the output file is later truncated to the first 450 queries.

In [None]:
# Prompt 1

import re
from tqdm import tqdm
import pandas as pd

# Load the question rows as a list of dicts (one per CSV row) and keep the first 500.
# Despite its name, questions_df is a plain list here, not a DataFrame.
questions_df = pd.read_csv("/content/test_questions.csv").to_dict(orient="records")[:500]

# Read the whole combined schema dump into memory.
with open("/content/combine_database_schemas.sql", "r") as schema_file:
    schema_data = schema_file.read()

def extract_schema(db_id, schema_data):
    """Return the leading DDL of ``db_id``'s section in the combined schema dump.

    A section starts at a line that reads exactly ``-- <db_id>`` and runs to the
    next ``--`` in the text (or to the end of the text).

    Args:
        db_id: Database identifier to look up.
        schema_data: Full text of the combined schema file.

    Returns:
        The header line plus every following line up to, but not including, the
        first line starting with ``INSERT INTO`` (any letter case); ``None`` when
        no header line for ``db_id`` exists.
    """
    # Match the header as a whole line: a substring search would let "election"
    # hit the "-- election_representative" header.
    header = re.search(
        rf"^-- {re.escape(str(db_id))}[ \t]*\r?$", schema_data, re.MULTILINE
    )
    if header is None:
        return None  # No schema for this db_id
    db_start = header.start()

    db_end = schema_data.find("--", db_start + 1)
    if db_end == -1:
        # Last section of the dump: slicing with -1 would drop its final character.
        db_end = len(schema_data)
    schema = schema_data[db_start:db_end].strip()

    # Keep everything before the first INSERT statement.
    create_only = []
    for line in schema.splitlines():
        if line.strip().upper().startswith("INSERT INTO"):
            break
        create_only.append(line)

    return "\n".join(create_only)

def generate_prompts(questions, schema_data):
    """Yield one rendered prompt (question + schema) per question with a known schema.

    Args:
        questions: Iterable of mappings with 'db_id' and 'question' keys.
        schema_data: Combined schema text, passed through to extract_schema.

    Yields:
        Prompt strings. Questions whose schema lookup returns None or an empty
        string are skipped, so fewer prompts than questions may be produced.
    """
    for row in questions:
        db_id, question = row['db_id'], row['question']
        schema = extract_schema(db_id, schema_data)
        if not schema:
            continue

        yield f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`

{schema}

<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""



# Generate one query per prompt and store each on its own line of generated_queries.txt.
with open("generated_queries.txt", "w") as output_file:
    for prompt in tqdm(generate_prompts(questions_df, schema_data), total=len(questions_df)):

        question_match = re.search(r'Generate a SQL query to answer this question: `(.*?)`', prompt)
        if not question_match:
            continue
        question = question_match.group(1)

        # The prompt is already fully rendered, but generate_query runs str.format
        # on it. Doubling the braces keeps a literal "{" or "}" in the schema or
        # question from raising there; format() turns them back into single braces.
        escaped_prompt = prompt.replace("{", "{{").replace("}", "}}")
        generated_query = generate_query(escaped_prompt, question)

        # The file is read back line by line, so collapse the (possibly multi-line)
        # SQL onto exactly one line.
        output_file.write(" ".join(generated_query.split()) + "\n")


In [None]:
# Prompt 2

import re
from tqdm import tqdm
import pandas as pd

# Load the question rows as a list of dicts (one per CSV row) and keep the first 500.
questions_df = pd.read_csv("/content/test_questions.csv").to_dict(orient="records")[:500]

# Read the whole combined schema dump into memory.
with open("/content/combine_database_schemas.sql", "r") as schema_file:
    schema_data = schema_file.read()

def extract_schema(db_id, schema_data):
    """Return the leading DDL of ``db_id``'s section in the combined schema dump.

    A section starts at a line that reads exactly ``-- <db_id>`` and runs to the
    next ``--`` in the text (or to the end of the text).

    Args:
        db_id: Database identifier to look up.
        schema_data: Full text of the combined schema file.

    Returns:
        The header line plus every following line up to, but not including, the
        first line starting with ``INSERT INTO`` (any letter case); ``None`` when
        no header line for ``db_id`` exists.
    """
    # Match the header as a whole line: a substring search would let "election"
    # hit the "-- election_representative" header.
    header = re.search(
        rf"^-- {re.escape(str(db_id))}[ \t]*\r?$", schema_data, re.MULTILINE
    )
    if header is None:
        return None  # No schema for this db_id
    db_start = header.start()

    db_end = schema_data.find("--", db_start + 1)
    if db_end == -1:
        # Last section of the dump: slicing with -1 would drop its final character.
        db_end = len(schema_data)
    schema = schema_data[db_start:db_end].strip()

    # Keep everything before the first INSERT statement.
    create_only = []
    for line in schema.splitlines():
        if line.strip().upper().startswith("INSERT INTO"):
            break
        create_only.append(line)

    return "\n".join(create_only)

def generate_prompts(questions, schema_data):
    """Yield one rendered prompt (question + INTERSECT instructions + schema) per question.

    Args:
        questions: Iterable of mappings with 'db_id' and 'question' keys.
        schema_data: Combined schema text, passed through to extract_schema.

    Yields:
        Prompt strings. Questions whose schema lookup returns None or an empty
        string are skipped, so fewer prompts than questions may be produced.
    """
    for row in questions:
        db_id, question = row['db_id'], row['question']
        schema = extract_schema(db_id, schema_data)
        if not schema:
            continue

        yield f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`
-- Important: You must use the INTERSECT clause instead of the IN clause.
-- For any comparison involving lists of values, use INTERSECT between SELECT statements.
-- Example:
-- Instead of:
--   SELECT * FROM table1 WHERE column1 IN (SELECT column2 FROM table2);
-- Use:
--   SELECT * FROM table1 WHERE column1 = ANY (SELECT column2 FROM table1 INTERSECT SELECT column2 FROM table2);

DDL statements:
{schema}

<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""



# Generate one query per prompt and store each on its own line of generated_queries.txt.
with open("generated_queries.txt", "w") as output_file:
    for prompt in tqdm(generate_prompts(questions_df, schema_data), total=len(questions_df)):

        question_match = re.search(r'Generate a SQL query to answer this question: `(.*?)`', prompt)
        if not question_match:
            continue
        question = question_match.group(1)

        # The prompt is already fully rendered, but generate_query runs str.format
        # on it. Doubling the braces keeps a literal "{" or "}" in the schema or
        # question from raising there; format() turns them back into single braces.
        escaped_prompt = prompt.replace("{", "{{").replace("}", "}}")
        generated_query = generate_query(escaped_prompt, question)

        # The file is read back line by line, so collapse the (possibly multi-line)
        # SQL onto exactly one line.
        output_file.write(" ".join(generated_query.split()) + "\n")


In [None]:
# Prompt 3

import re
from tqdm import tqdm
import pandas as pd

# Load the question rows as a list of dicts (one per CSV row) and keep the first 500.
questions_df = pd.read_csv("/content/test_questions.csv").to_dict(orient="records")[:500]

# Read the whole combined schema dump into memory.
with open("/content/combine_database_schemas.sql", "r") as schema_file:
    schema_data = schema_file.read()

def extract_schema(db_id, schema_data):
    """Return the leading DDL of ``db_id``'s section in the combined schema dump.

    A section starts at a line that reads exactly ``-- <db_id>`` and runs to the
    next ``--`` in the text (or to the end of the text).

    Args:
        db_id: Database identifier to look up.
        schema_data: Full text of the combined schema file.

    Returns:
        The header line plus every following line up to, but not including, the
        first line starting with ``INSERT INTO`` (any letter case); ``None`` when
        no header line for ``db_id`` exists.
    """
    # Match the header as a whole line: a substring search would let "election"
    # hit the "-- election_representative" header.
    header = re.search(
        rf"^-- {re.escape(str(db_id))}[ \t]*\r?$", schema_data, re.MULTILINE
    )
    if header is None:
        return None  # No schema for this db_id
    db_start = header.start()

    db_end = schema_data.find("--", db_start + 1)
    if db_end == -1:
        # Last section of the dump: slicing with -1 would drop its final character.
        db_end = len(schema_data)
    schema = schema_data[db_start:db_end].strip()

    # Keep everything before the first INSERT statement.
    create_only = []
    for line in schema.splitlines():
        if line.strip().upper().startswith("INSERT INTO"):
            break
        create_only.append(line)

    return "\n".join(create_only)

def generate_prompts(questions, schema_data):
    """Yield one rendered prompt per question: instructions, requirements and schema.

    Args:
        questions: Iterable of mappings with 'db_id' and 'question' keys.
        schema_data: Combined schema text, passed through to extract_schema.

    Yields:
        Prompt strings. Questions whose schema lookup returns None or an empty
        string are skipped, so fewer prompts than questions may be produced.
    """
    for row in questions:
        db_id, question = row['db_id'], row['question']
        schema = extract_schema(db_id, schema_data)
        if not schema:
            continue

        yield f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`
-- Important: You must use the INTERSECT clause instead of the IN clause.
-- For any comparison involving lists of values, use INTERSECT between SELECT statements.
-- Example:
-- Instead of:
--   SELECT * FROM table1 WHERE column1 IN (SELECT column2 FROM table2);
-- Use:
--   SELECT * FROM table1 WHERE column1 = ANY (SELECT column2 FROM table1 INTERSECT SELECT column2 FROM table2);

Requirements:
-- Correctness: Ensure all joins, conditions and aggregations are correct based on the schema and avoid 'NULL' issues where applicable.
-- Efficiency: Avoid unnecessary columns, tables and joins to enhance efficiency.
-- Clarity: Use table aliases to clarify complex logic.
-- Limitation: Ensure only one column is selected per query.

Normalization Consideration:
-- Ensure that the schema adheres to at least 3NF (Third Normal Form), which involves eliminating redundant data and ensuring that all non-key attributes are fully dependent on the primary key.
-- If applicable, normalize tables to reduce data anomalies, ensuring that each table represents a single entity or concept.
-- Avoid repeating groups of data and ensure that relationships between tables are appropriately represented through foreign keys.

DDL statements:
{schema}

<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""



# Generate one query per prompt and store each on its own line of generated_queries.txt.
with open("generated_queries.txt", "w") as output_file:
    for prompt in tqdm(generate_prompts(questions_df, schema_data), total=len(questions_df)):

        question_match = re.search(r'Generate a SQL query to answer this question: `(.*?)`', prompt)
        if not question_match:
            continue
        question = question_match.group(1)

        # The prompt is already fully rendered, but generate_query runs str.format
        # on it. Doubling the braces keeps a literal "{" or "}" in the schema or
        # question from raising there; format() turns them back into single braces.
        escaped_prompt = prompt.replace("{", "{{").replace("}", "}}")
        generated_query = generate_query(escaped_prompt, question)

        # The file is read back line by line, so collapse the (possibly multi-line)
        # SQL onto exactly one line.
        output_file.write(" ".join(generated_query.split()) + "\n")


In [None]:
with open("generated_queries.txt", "r") as file:
    lines = file.readlines()

# Copy lines until the 450th line that begins with SELECT has been copied;
# everything after that line is discarded.
query_count = 0
new_lines = []
for line in lines:
    new_lines.append(line)
    query_count += line.strip().startswith("SELECT")
    if query_count >= 450:
        break

# Overwrite the file with the kept lines only.
with open("generated_queries.txt", "w") as file:
    file.writelines(new_lines)

print("File truncated to the first 450 queries.")


Compare generated queries and actual queries for EXACT MATCH

In [33]:
import pandas as pd

questions_df = pd.read_csv("/content/test_questions.csv")

# Non-blank lines of the generated file, with surrounding whitespace removed.
with open("generated_queries.txt", "r") as generated_file:
    generated_queries = [stripped for stripped in map(str.strip, generated_file) if stripped]

assert len(generated_queries) == 450, "The number of queries in generated_queries.txt is not 450!"

# Pair generated query i with the query in column 1 of CSV row i and record, per
# pair, the generated query, a Y/N exact-match verdict and a blank separator line.
with open("comparison_results.txt", "w") as output_file:
    for generated_query, correct_query in zip(generated_queries, questions_df.iloc[:450, 1]):
        generated_query = generated_query.strip()
        correct_query = correct_query.strip()
        result = 'Y' if generated_query == correct_query else 'N'
        output_file.write(f"{generated_query}\n{result}\n\n")


In [None]:
# Tally the verdict lines ("Y" / "N") in the exact-match comparison file.
count_Y = 0
count_N = 0

with open("comparison_results.txt", "r") as file:
    lines = file.readlines()

for verdict in map(str.strip, lines):
    if verdict == "Y":
        count_Y += 1
    elif verdict == "N":
        count_N += 1

print(f"Number of Y's: {count_Y}")
print(f"Number of N's: {count_N}")


Generate databases from the schema files

In [None]:
# Reload the full question table (a DataFrame this time).
questions_df = pd.read_csv("/content/test_questions.csv")

# Distinct database ids referenced by the first 450 questions.
db_ids = questions_df.iloc[:450]['db_id'].unique()
db_ids_list = db_ids.tolist()

# Number of distinct databases that need to be created.
print(len(db_ids_list))


In [None]:
import sqlite3
import os
import re

schema_file = 'combine_database_schemas.sql'
output_dir = './databases/'

os.makedirs(output_dir, exist_ok=True)

with open(schema_file, 'r') as file:
    schema_content = file.read()

# Regex pattern to capture db_id and associated SQL commands: every "-- <word>"
# marker starts a chunk that runs up to the next such marker or the end of the text.
db_id_pattern = r'--\s*(\w+)\s*(.*?)(?=--\s*\w+|\Z)'

# Find all matches for db_id and schema content
matches = re.findall(db_id_pattern, schema_content, re.DOTALL)

# Create a database for each db_id in db_ids_list
seen_db_ids = set()
failed_db_ids = []
for db_id, db_content in matches:
    if db_id not in db_ids_list:
        continue
    db_file_path = os.path.join(output_dir, f"{db_id}.db")

    # Start from an empty file the first time a db_id is met in this run. A
    # database left over from an earlier run already contains the tables, so the
    # script's CREATE TABLE statements can fail with "table already exists".
    if db_id not in seen_db_ids and os.path.exists(db_file_path):
        os.remove(db_file_path)
    seen_db_ids.add(db_id)

    conn = sqlite3.connect(db_file_path)
    cursor = conn.cursor()

    try:
        cursor.executescript(db_content.strip())
        conn.commit()
        print(f"Database '{db_id}.db' created successfully.")
    except sqlite3.Error as e:
        failed_db_ids.append(db_id)
        print(f"Error while creating database '{db_id}.db': {e}")
    finally:
        conn.close()

# Only claim success when nothing failed.
if failed_db_ids:
    print(f"Finished with errors in {len(failed_db_ids)} database script(s): {sorted(set(failed_db_ids))}")
else:
    print("Selected databases created successfully.")


Compare outputs of generated queries and actual queries

In [None]:
import pandas as pd
import sqlite3
import os

# Inputs: the exact-match verdict file, the question CSV and the folder of SQLite files.
comparison_file_path = "comparison_results.txt"
test_questions_csv_path = "/content/test_questions.csv"
databases_folder = "/content/databases/"

# Raw lines of the verdict file; the loop further down reads them in groups of three.
with open(comparison_file_path, "r") as comparison_file:
    comparison_results = comparison_file.readlines()

# Only the first 450 question rows take part in the comparison.
test_questions = pd.read_csv(test_questions_csv_path).iloc[:450]

# Output records, filled by the comparison loop further down.
final_comparison_results = []

def run_query(db_id, query):
    """Run ``query`` against ``<databases_folder>/<db_id>.db`` and return all rows.

    Args:
        db_id: Name of the database file (without the ``.db`` extension).
        query: SQL text to execute.

    Returns:
        The list of result rows, or ``None`` when the database file does not exist
        or any exception occurs (both cases are reported with a printed message).
    """
    conn = None
    try:
        db_file_path = os.path.join(databases_folder, f"{db_id}.db")
        if not os.path.exists(db_file_path):
            print(f"Database file {db_file_path} does not exist.")
            return None

        conn = sqlite3.connect(db_file_path)
        cursor = conn.cursor()
        cursor.execute(query)
        return cursor.fetchall()
    except Exception as e:
        print(f"Error running query on {db_id}: {e}")
        return None
    finally:
        # Close on every path; a failing query must not leave the connection open.
        if conn is not None:
            conn.close()

# Each query occupies three lines in comparison_results.txt: the generated SQL,
# its exact-match verdict (Y/N) and a blank separator, hence the step of 3.
for idx in range(0, len(comparison_results), 3):

    generated_query = comparison_results[idx].strip()
    result = comparison_results[idx + 1].strip()

    # Only queries that failed the exact-match test need to be executed.
    if result == 'N':
        db_id = test_questions.iloc[idx // 3, 0].strip()  # db_id from the first column
        correct_query = test_questions.iloc[idx // 3, 1].strip()  # correct query from the second column

        generated_query_results = run_query(db_id, generated_query)
        correct_query_results = run_query(db_id, correct_query)

        # run_query returns None on failure. Two failed queries (None == None) must
        # not be counted as a match, so the generated query has to have run.
        if generated_query_results is not None and generated_query_results == correct_query_results:
            result = 'Y'

    final_comparison_results.append(f"{generated_query}\n{result}\n\n")

with open("final_comparison_results.txt", "w") as output_file:
    output_file.writelines(final_comparison_results)

print("Comparison results written to 'final_comparison_results.txt'.")


In [None]:
# Tally the verdict lines ("Y" / "N") in the execution-based comparison file.
count_Y = 0
count_N = 0

with open("final_comparison_results.txt", "r") as file:
    lines = file.readlines()

for verdict in map(str.strip, lines):
    if verdict == "Y":
        count_Y += 1
    elif verdict == "N":
        count_N += 1

print(f"Number of Y's: {count_Y}")
print(f"Number of N's: {count_N}")


# Model - Llama 2

In [None]:
pip install llama-stack

In [None]:
import requests

# List the model ids returned by the Hugging Face search for "meta-llama".
# The timeout keeps a stalled connection from hanging the cell indefinitely.
response = requests.get("https://huggingface.co/api/models?search=meta-llama", timeout=30)

if response.status_code == 200:
    models = response.json()
    # The loop variable is deliberately not called `model`: that name holds the
    # loaded language model elsewhere in this notebook and must not be overwritten.
    for model_info in models:
        print(model_info["modelId"])
else:
    print(f"Error fetching models: {response.status_code}")


In [None]:
import os
from getpass import getpass

from huggingface_hub import login

# SECURITY: an access token used to be hard-coded in this cell. Any token that was
# ever saved in the notebook must be treated as leaked and revoked in the Hugging
# Face account settings. The token is now taken from the HF_TOKEN environment
# variable, with a non-echoing prompt as fallback, so it never appears in the file.
hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face access token: ")
login(hf_token)


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)

# Arguments common to both loading strategies.
model_load_kwargs = {
    "device_map": "auto",
    "use_cache": True,
    "use_auth_token": True,
}
if torch.cuda.get_device_properties(0).total_memory > 20e9:
    # Enough GPU memory: load the weights in 16-bit precision.
    model_load_kwargs["torch_dtype"] = torch.float16
else:
    # Less memory available: load in 4-bit precision.
    model_load_kwargs["load_in_4bit"] = True

model = AutoModelForCausalLM.from_pretrained(model_name, **model_load_kwargs)


In [33]:
import sqlparse

def generate_query(prompt_template, question):
    """Generate one SQL query for ``question`` with the globally loaded model.

    Args:
        prompt_template: Prompt text. ``{question}`` placeholders are filled in
            with ``question``. A prompt that cannot be processed by ``str.format``
            (for example an already rendered prompt that contains literal braces)
            falls back to a plain ``{question}`` substitution instead of raising.
        question: Natural-language question to answer.

    Returns:
        The decoded text that follows the first "```sql" marker, cut at the first
        closing "```" fence or ";" (whichever comes first).

    Raises:
        IndexError: if the decoded output contains no "```sql" marker.
    """
    try:
        updated_prompt = prompt_template.format(question=question)
    except (KeyError, IndexError, ValueError):
        # Literal "{" / "}" in the prompt are not format fields.
        updated_prompt = prompt_template.replace("{question}", question)

    # Follow the model's own device instead of assuming "cuda".
    inputs = tokenizer(updated_prompt, return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **inputs,
        num_return_sequences=1,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id,
        max_new_tokens=400,
        # Greedy decoding. temperature/top_p only apply when sampling, so they are
        # not passed here.
        do_sample=False,
        num_beams=1,
    )
    outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    # Release cached GPU memory between calls so long runs do not exhaust it
    # (particularly important on Colab).
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    completion = outputs[0].split("```sql")[1]
    return completion.split("```")[0].split(";")[0]

In [None]:
# Prompt 1

import re
from tqdm import tqdm
import pandas as pd

# Load the question rows as a list of dicts (one per CSV row) and keep the first 450.
# Despite its name, questions_df is a plain list here, not a DataFrame.
questions_df = pd.read_csv("/content/test_questions.csv").to_dict(orient="records")[:450]

# Read the whole combined schema dump into memory.
with open("/content/combine_database_schemas.sql", "r") as schema_file:
    schema_data = schema_file.read()

def extract_schema(db_id, schema_data):
    """Return the leading DDL of ``db_id``'s section in the combined schema dump.

    A section starts at a line that reads exactly ``-- <db_id>`` and runs to the
    next ``--`` in the text (or to the end of the text).

    Args:
        db_id: Database identifier to look up.
        schema_data: Full text of the combined schema file.

    Returns:
        The header line plus every following line up to, but not including, the
        first line starting with ``INSERT INTO`` (any letter case); ``None`` when
        no header line for ``db_id`` exists.
    """
    # Match the header as a whole line: a substring search would let "election"
    # hit the "-- election_representative" header.
    header = re.search(
        rf"^-- {re.escape(str(db_id))}[ \t]*\r?$", schema_data, re.MULTILINE
    )
    if header is None:
        return None  # No schema for this db_id
    db_start = header.start()

    db_end = schema_data.find("--", db_start + 1)
    if db_end == -1:
        # Last section of the dump: slicing with -1 would drop its final character.
        db_end = len(schema_data)
    schema = schema_data[db_start:db_end].strip()

    # Keep everything before the first INSERT statement.
    create_only = []
    for line in schema.splitlines():
        if line.strip().upper().startswith("INSERT INTO"):
            break
        create_only.append(line)

    return "\n".join(create_only)

def generate_prompts(questions, schema_data):
    """Yield one rendered prompt (question + schema) per question with a known schema.

    Args:
        questions: Iterable of mappings with 'db_id' and 'question' keys.
        schema_data: Combined schema text, passed through to extract_schema.

    Yields:
        Prompt strings. Questions whose schema lookup returns None or an empty
        string are skipped, so fewer prompts than questions may be produced.
    """
    for row in questions:
        db_id, question = row['db_id'], row['question']
        schema = extract_schema(db_id, schema_data)
        if not schema:
            continue

        yield f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`

{schema}

<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""



# Generate one query per prompt and store each on its own line of generated_queries.txt.
with open("generated_queries.txt", "w") as output_file:
    for prompt in tqdm(generate_prompts(questions_df, schema_data), total=len(questions_df)):

        question_match = re.search(r'Generate a SQL query to answer this question: `(.*?)`', prompt)
        if not question_match:
            continue
        question = question_match.group(1)

        # The prompt is already fully rendered, but generate_query runs str.format
        # on it. Doubling the braces keeps a literal "{" or "}" in the schema or
        # question from raising there; format() turns them back into single braces.
        escaped_prompt = prompt.replace("{", "{{").replace("}", "}}")
        generated_query = generate_query(escaped_prompt, question)

        # Keep one query per line: collapse the (possibly multi-line) SQL onto
        # exactly one line before writing it.
        output_file.write(" ".join(generated_query.split()) + "\n")


In [None]:
# Prompt 2

import re
from tqdm import tqdm
import pandas as pd

# Load the question rows as a list of dicts (one per CSV row) and keep the first 500.
questions_df = pd.read_csv("/content/test_questions.csv").to_dict(orient="records")[:500]

# Read the whole combined schema dump into memory.
with open("/content/combine_database_schemas.sql", "r") as schema_file:
    schema_data = schema_file.read()

def extract_schema(db_id, schema_data):
    """Return the leading DDL of ``db_id``'s section in the combined schema dump.

    A section starts at a line that reads exactly ``-- <db_id>`` and runs to the
    next ``--`` in the text (or to the end of the text).

    Args:
        db_id: Database identifier to look up.
        schema_data: Full text of the combined schema file.

    Returns:
        The header line plus every following line up to, but not including, the
        first line starting with ``INSERT INTO`` (any letter case); ``None`` when
        no header line for ``db_id`` exists.
    """
    # Match the header as a whole line: a substring search would let "election"
    # hit the "-- election_representative" header.
    header = re.search(
        rf"^-- {re.escape(str(db_id))}[ \t]*\r?$", schema_data, re.MULTILINE
    )
    if header is None:
        return None  # No schema for this db_id
    db_start = header.start()

    db_end = schema_data.find("--", db_start + 1)
    if db_end == -1:
        # Last section of the dump: slicing with -1 would drop its final character.
        db_end = len(schema_data)
    schema = schema_data[db_start:db_end].strip()

    # Keep everything before the first INSERT statement.
    create_only = []
    for line in schema.splitlines():
        if line.strip().upper().startswith("INSERT INTO"):
            break
        create_only.append(line)

    return "\n".join(create_only)

def generate_prompts(questions, schema_data):
    """Yield one rendered prompt (question + INTERSECT instructions + schema) per question.

    Args:
        questions: Iterable of mappings with 'db_id' and 'question' keys.
        schema_data: Combined schema text, passed through to extract_schema.

    Yields:
        Prompt strings. Questions whose schema lookup returns None or an empty
        string are skipped, so fewer prompts than questions may be produced.
    """
    for row in questions:
        db_id, question = row['db_id'], row['question']
        schema = extract_schema(db_id, schema_data)
        if not schema:
            continue

        yield f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`
-- Important: You must use the INTERSECT clause instead of the IN clause.
-- For any comparison involving lists of values, use INTERSECT between SELECT statements.
-- Example:
-- Instead of:
--   SELECT * FROM table1 WHERE column1 IN (SELECT column2 FROM table2);
-- Use:
--   SELECT * FROM table1 WHERE column1 = ANY (SELECT column2 FROM table1 INTERSECT SELECT column2 FROM table2);

DDL statements:
{schema}

<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""



# Generate one query per prompt and store each on its own line of generated_queries.txt.
with open("generated_queries.txt", "w") as output_file:
    for prompt in tqdm(generate_prompts(questions_df, schema_data), total=len(questions_df)):

        question_match = re.search(r'Generate a SQL query to answer this question: `(.*?)`', prompt)
        if not question_match:
            continue
        question = question_match.group(1)

        # The prompt is already fully rendered, but generate_query runs str.format
        # on it. Doubling the braces keeps a literal "{" or "}" in the schema or
        # question from raising there; format() turns them back into single braces.
        escaped_prompt = prompt.replace("{", "{{").replace("}", "}}")
        generated_query = generate_query(escaped_prompt, question)

        # Keep one query per line: collapse the (possibly multi-line) SQL onto
        # exactly one line before writing it.
        output_file.write(" ".join(generated_query.split()) + "\n")


In [None]:
# Prompt 3

import re
from tqdm import tqdm
import pandas as pd

# Load the question rows as a list of dicts (one per CSV row) and keep the first 500.
questions_df = pd.read_csv("/content/test_questions.csv").to_dict(orient="records")[:500]

# Read the whole combined schema dump into memory.
with open("/content/combine_database_schemas.sql", "r") as schema_file:
    schema_data = schema_file.read()

def extract_schema(db_id, schema_data):
    """Return the leading DDL of ``db_id``'s section in the combined schema dump.

    A section starts at a line that reads exactly ``-- <db_id>`` and runs to the
    next ``--`` in the text (or to the end of the text).

    Args:
        db_id: Database identifier to look up.
        schema_data: Full text of the combined schema file.

    Returns:
        The header line plus every following line up to, but not including, the
        first line starting with ``INSERT INTO`` (any letter case); ``None`` when
        no header line for ``db_id`` exists.
    """
    # Match the header as a whole line: a substring search would let "election"
    # hit the "-- election_representative" header.
    header = re.search(
        rf"^-- {re.escape(str(db_id))}[ \t]*\r?$", schema_data, re.MULTILINE
    )
    if header is None:
        return None  # No schema for this db_id
    db_start = header.start()

    db_end = schema_data.find("--", db_start + 1)
    if db_end == -1:
        # Last section of the dump: slicing with -1 would drop its final character.
        db_end = len(schema_data)
    schema = schema_data[db_start:db_end].strip()

    # Keep everything before the first INSERT statement.
    create_only = []
    for line in schema.splitlines():
        if line.strip().upper().startswith("INSERT INTO"):
            break
        create_only.append(line)

    return "\n".join(create_only)

def generate_prompts(questions, schema_data):
    """Yield one prompt string per question whose database schema can be found.

    Args:
        questions: Iterable of mappings that provide 'db_id' and 'question'.
        schema_data: Combined schema text, passed through to extract_schema.

    Yields:
        The formatted prompt embedding the question and the extracted schema.
        Questions for which extract_schema returns nothing (None or an empty
        string) are skipped, so fewer prompts than questions may be produced.
    """
    for record in questions:
        db_id = record['db_id']
        question = record['question']
        schema = extract_schema(db_id, schema_data)

        if not schema:
            # No usable schema for this database: emit no prompt for the question.
            continue

        prompt = f"""<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`
-- Important: You must use the INTERSECT clause instead of the IN clause.
-- For any comparison involving lists of values, use INTERSECT between SELECT statements.
-- Example:
-- Instead of:
--   SELECT * FROM table1 WHERE column1 IN (SELECT column2 FROM table2);
-- Use:
--   SELECT * FROM table1 WHERE column1 = ANY (SELECT column2 FROM table1 INTERSECT SELECT column2 FROM table2);

Requirements:
-- Correctness: Ensure all joins, conditions and aggregations are correct based on the schema and avoid 'NULL' issues where applicable.
-- Efficiency: Avoid unnecessary columns, tables and joins to enhance efficiency.
-- Clarity: Use table aliases to clarify complex logic.
-- Limitation: Ensure only one column is selected per query.

Normalization Consideration:
-- Ensure that the schema adheres to at least 3NF (Third Normal Form), which involves eliminating redundant data and ensuring that all non-key attributes are fully dependent on the primary key.
-- If applicable, normalize tables to reduce data anomalies, ensuring that each table represents a single entity or concept.
-- Avoid repeating groups of data and ensure that relationships between tables are appropriately represented through foreign keys.

DDL statements:
{schema}

<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""
        yield prompt



# Run every prompt through generate_query and write each returned value,
# followed by a newline, to generated_queries.txt.
with open("generated_queries.txt", "w") as output_file:
    prompt_stream = generate_prompts(questions_df, schema_data)
    for prompt in tqdm(prompt_stream, total=len(questions_df)):
        # Recover the question text from the backtick-quoted span of the prompt.
        question_match = re.search(r'Generate a SQL query to answer this question: `(.*?)`', prompt)
        if not question_match:
            # Nothing is written for a prompt whose question cannot be recovered.
            continue

        question = question_match.group(1)
        generated_query = generate_query(prompt, question)
        output_file.write(generated_query + "\n")


In [43]:
import pandas as pd

# Reference queries live in the question CSV.
questions_df = pd.read_csv("/content/test_questions.csv")

# Re-assemble the generated queries: a line beginning with SELECT opens a new
# query, every other non-empty line is appended to the query in progress, and
# the pieces of one query are joined with single spaces.
with open("generated_queries.txt", "r") as generated_file:
    generated_queries = []
    current_query = []

    for raw_line in generated_file:
        line = raw_line.strip()
        if not line:
            continue  # blank lines never contribute to a query

        if line.startswith("SELECT") and current_query:
            # Close the query collected so far before starting the next one.
            generated_queries.append(" ".join(current_query))
            current_query = []
        current_query.append(line)

    # Flush whatever is still pending once the file is exhausted.
    if current_query:
        generated_queries.append(" ".join(current_query))

# Verify that the number of parsed queries matches the expected count.
# The failure message is built from the same value that is checked, so the two
# can no longer disagree.
# NOTE(review): the comparison that follows only consumes the first 450
# reference queries, so 462 parsed entries cannot line up one-to-one with
# them — confirm which count is intended.
expected_query_count = 462
assert len(generated_queries) == expected_query_count, (
    f"The number of queries in generated_queries.txt is not {expected_query_count}! "
    f"Found {len(generated_queries)} queries."
)

# Pair each parsed query with the reference query from the second CSV column
# (first 450 rows) and record an exact-text verdict for it.
with open("comparison_results.txt", "w") as output_file:
    reference_queries = questions_df.iloc[:450, 1]
    for generated_query, correct_query in zip(generated_queries, reference_queries):
        generated_query = generated_query.strip()
        correct_query = correct_query.strip()

        # 'Y' only when the two query strings are identical after stripping.
        if generated_query == correct_query:
            result = 'Y'
        else:
            result = 'N'

        # Three lines per entry: the query, its verdict, and a blank separator.
        output_file.write(f"{generated_query}\n{result}\n\n")


In [None]:
# Reload the question table and collect the distinct db_id values that occur
# in its first 450 rows.
questions_df = pd.read_csv("/content/test_questions.csv")

db_ids = questions_df.head(450)['db_id'].unique()
db_ids_list = db_ids.tolist()

# Number of distinct databases referenced by those rows.
print(len(db_ids_list))


In [None]:
import sqlite3
import os
import re

schema_file = 'combine_database_schemas.sql'
output_dir = './databases/'

os.makedirs(output_dir, exist_ok=True)

with open(schema_file, 'r') as file:
    schema_content = file.read()

# Regex pattern to capture a "-- <db_id>" header and the SQL text that follows it,
# up to the next "-- <word>" marker or the end of the file.
db_id_pattern = r'--\s*(\w+)\s*(.*?)(?=--\s*\w+|\Z)'

# Find all matches for db_id and schema content
matches = re.findall(db_id_pattern, schema_content, re.DOTALL)

# Track outcomes so the final message reflects what actually happened.
seen_db_ids = set()
failed_db_ids = []

# Create a database for each db_id in db_ids_list
for db_id, db_content in matches:
    if db_id not in db_ids_list:
        continue

    seen_db_ids.add(db_id)
    db_file_path = os.path.join(output_dir, f"{db_id}.db")

    conn = sqlite3.connect(db_file_path)
    cursor = conn.cursor()

    try:
        cursor.executescript(db_content.strip())
        conn.commit()
        print(f"Database '{db_id}.db' created successfully.")
    except sqlite3.Error as e:
        if db_id not in failed_db_ids:
            failed_db_ids.append(db_id)
        print(f"Error while creating database '{db_id}.db': {e}")
    finally:
        conn.close()

# Requested databases for which the schema file had no matching section.
missing_db_ids = [wanted for wanted in db_ids_list if wanted not in seen_db_ids]

if failed_db_ids or missing_db_ids:
    print(
        f"Finished with problems: {len(failed_db_ids)} database(s) failed ({failed_db_ids}), "
        f"{len(missing_db_ids)} had no schema section ({missing_db_ids})."
    )
else:
    print("Selected databases created successfully.")


In [None]:
import pandas as pd
import sqlite3
import os

# Locations of the inputs used by the execution-based comparison.
comparison_file_path = "comparison_results.txt"
test_questions_csv_path = "/content/test_questions.csv"
databases_folder = "/content/databases/"

# Raw lines of the text-comparison output, newlines included.
with open(comparison_file_path, "r") as comparison_file:
    comparison_results = list(comparison_file)

# Only the first 450 question rows take part in the comparison.
test_questions = pd.read_csv(test_questions_csv_path).head(450)

# Collects the output entries produced by the comparison loop.
final_comparison_results = []

def run_query(db_id, query, db_folder=None):
    """Execute ``query`` against ``<db_id>.db`` and return every row it yields.

    Args:
        db_id: Database name; the file ``<db_id>.db`` is looked up in the folder.
        query: SQL text to execute.
        db_folder: Folder that holds the ``.db`` files. When omitted, the
            notebook-level ``databases_folder`` is used.

    Returns:
        The list returned by ``cursor.fetchall()``, or ``None`` when the
        database file does not exist or anything raises while running the
        query (the problem is printed in both cases).
    """
    from pathlib import Path

    conn = None
    try:
        folder = databases_folder if db_folder is None else db_folder
        db_file_path = os.path.join(folder, f"{db_id}.db")
        if not os.path.exists(db_file_path):
            print(f"Database file {db_file_path} does not exist.")
            return None

        # Open read-only: the SQL may be model-generated, and a write such as
        # DROP TABLE must not alter the database used for later comparisons.
        db_uri = Path(db_file_path).resolve().as_uri() + "?mode=ro"
        conn = sqlite3.connect(db_uri, uri=True)
        cursor = conn.cursor()
        cursor.execute(query)
        return cursor.fetchall()
    except Exception as e:
        print(f"Error running query on {db_id}: {e}")
        return None
    finally:
        # Close the connection on every path, including failed queries.
        if conn is not None:
            conn.close()

# Each entry in comparison_results.txt spans three lines: query, verdict, blank.
for idx in range(0, len(comparison_results), 3):
    if idx + 1 >= len(comparison_results):
        # Truncated trailing entry without a verdict line: nothing to compare.
        break

    generated_query = comparison_results[idx].strip()
    result = comparison_results[idx + 1].strip()

    # Entries that did not match textually are re-checked by executing both
    # queries; any other verdict is carried over unchanged without touching the DB.
    if result == 'N':
        db_id = test_questions.iloc[idx // 3, 0].strip()  # db_id from the first column
        correct_query = test_questions.iloc[idx // 3, 1].strip()  # correct query from the second column

        generated_query_results = run_query(db_id, generated_query)
        correct_query_results = run_query(db_id, correct_query)

        # run_query returns None on failure; two failed executions must not be
        # scored as a match just because None == None.
        executions_match = (
            generated_query_results is not None
            and generated_query_results == correct_query_results
        )
        result = 'Y' if executions_match else 'N'

    final_comparison_results.append(f"{generated_query}\n{result}\n\n")

with open("final_comparison_results.txt", "w") as output_file:
    output_file.writelines(final_comparison_results)

print("Comparison results written to 'final_comparison_results.txt'.")


In [None]:
# Tally the verdict lines: a line counts only when, after stripping, it is
# exactly "Y" or exactly "N".
with open("final_comparison_results.txt", "r") as file:
    lines = file.readlines()

verdict_lines = [raw_line.strip() for raw_line in lines]
count_Y = verdict_lines.count("Y")
count_N = verdict_lines.count("N")

print(f"Number of Y's: {count_Y}")
print(f"Number of N's: {count_N}")
