In [None]:
# Install the third-party packages this notebook needs.
# NOTE(review): no versions are pinned, and later cells also import `sqlparse` and `peft`,
# which are not listed here — TODO confirm they are preinstalled in the target environment
# and pin the versions the notebook was developed against.
!pip install rag datasets trl bitsandbytes wandb Levenshtein

In [None]:
import os
import sys
from huggingface_hub import login

# Read the Hugging Face token from the environment instead of embedding it in the notebook.
# When HF_TOKEN is unset (or empty), `token=None` lets `login` fall back to its interactive prompt.
hf_token = os.environ.get("HF_TOKEN") or None
login(token=hf_token)


In [None]:
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline, AutoModelForCausalLM
import json
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from datasets import load_dataset, concatenate_datasets

import rag
import sqlparse
from transformers import TrainingArguments
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer, AutoModelForCausalLM
import sqlite3
from itertools import permutations
from Levenshtein import distance as levenshtein_distance
import torch
import bitsandbytes
from trl import setup_chat_format, DataCollatorForCompletionOnlyLM
from peft import LoraConfig, AutoPeftModelForCausalLM,get_peft_model
import wandb

In [None]:
import re

In [None]:
# Authenticate with Weights & Biases using a key supplied through the environment.
# SECURITY: an API key was previously written directly in this cell; treat that key as
# compromised and revoke it in the W&B account settings.
# When WANDB_API_KEY is unset, `key=None` lets wandb fall back to its own login prompt.
wandb.login(key=os.environ.get("WANDB_API_KEY") or None, relogin=True)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Short model identifier (no organisation prefix).
# NOTE(review): not referenced by any cell visible here; `model_name` in the configuration
# cell holds the same value — confirm whether this is still needed.
model_id = "Llama-3.2-1B-Instruct"

In [None]:
# --- Run configuration: these names are read as notebook globals by the functions below ---
model_to_train = "meta-llama/Llama-3.2-1B-Instruct"  # Hub id of the base model loaded in train()
model_name = "Llama-3.2-1B-Instruct"  # passed to train()/evaluation(); train() uses it as output_dir
datasets_path = "../spider_datasets"  # directory where data_loading() writes and re-reads the JSON splits
databases_path = "../spider_databases"  # presumably the root of the per-database .sqlite files — verify
dataset_on_hub = "rakshithjoseph/spider_data"  # dataset id passed to load_dataset() in data_prep()
hf_token_read = "****"  # placeholder; NOTE(review): prefer environment variables over inline tokens
hf_token_write = "****"  # placeholder; same caveat as hf_token_read
hf_username = "rakshithjoseph"
# NOTE(review): `train` is rebound by `def train(...)` further down, so once that cell has run
# the `if train:` check in the driver cell tests the function object, not this flag.
train = True
epochs = 1  # forwarded to SFTConfig(num_train_epochs=...)
num_train = -1
evalu = True  # selects whether train() passes dev_dataset as eval_dataset
test = True  # gates the evaluation() call in the driver cell
nb_rag_samples = 3
quantize = True  # when True, the model is loaded with the 4-bit NF4 quantization config
seed = 42
ref_correct = True  # enables the second "correct the SQL query" generation pass in evaluation()


Data preparation

In [None]:
# Load the raw dataset here so this inspection cell also works on a fresh kernel:
# `dataset` was otherwise only ever a local variable inside data_prep().
dataset = load_dataset(dataset_on_hub)
print(dataset)

In [None]:
def data_prep():
    """Load the dataset from the Hub and convert every split to chat-style records.

    Each sample becomes ``{"messages": [system, user, assistant]}``. The user turn ends
    with the literal ``SELECT``, so the assistant turn holds the reference query with its
    leading ``SELECT`` keyword removed.

    Returns:
        tuple: ``(train_dataset, dev_dataset, test_dataset, test_dataset_ids,
        train_dataset_ids)`` where the first three are the mapped splits (only the
        ``messages`` column remains) and the last two are the ``db_id`` columns of the
        test and train splits, captured before the original columns are dropped.
    """
    system_message = """You are a helpful assistant that generates SQL queries to answer questions about database tables.
You will receive: One or more SQL CREATE TABLE statements describing the structure of the tables.INSERT statements or a description of the data types for each column (you do not need to use the actual data).
A natural language question about the data. Your task is to generate the correct and most efficient SQL query that answers the user's question, based only on the provided schema.
Only output the SQL query and nothing else. Do not explain the answer."""

    user_message = """Here is the schema of the tables you will use:
    {schema}
    
    -- {question}
    SELECT"""
    dataset = load_dataset(dataset_on_hub)

    def create_conv(sample):
        """Build the system/user/assistant conversation for one raw sample."""
        query = sample["query"]
        # Drop everything up to and including the first "select" because the prompt already
        # ends with that keyword. Only do so when the keyword is present: str.find returns
        # -1 otherwise, and "-1 + 6" would silently cut the first five characters.
        select_pos = query.lower().find("select")
        if select_pos != -1:
            query = query[select_pos + len("select"):]
        return {
            "messages": [
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_message.format(schema=sample["schema"], question=sample["question"])},
                {"role": "assistant", "content": query},
            ]
        }

    print("Loading done...")

    train_dataset = dataset["train"]
    dev_dataset = dataset["dev"]
    test_dataset = dataset["test"]

    # Capture the database ids before map() removes the original columns.
    test_dataset_ids = test_dataset["db_id"]
    train_dataset_ids = train_dataset["db_id"]

    train_dataset = train_dataset.map(create_conv, remove_columns=train_dataset.features, batched=False)
    dev_dataset = dev_dataset.map(create_conv, remove_columns=dev_dataset.features, batched=False)
    test_dataset = test_dataset.map(create_conv, remove_columns=test_dataset.features, batched=False)

    return train_dataset, dev_dataset, test_dataset, test_dataset_ids, train_dataset_ids


def data_loading():
    """Prepare the splits, write them to ``datasets_path`` as JSON and load them back.

    The three chat-formatted splits are written with ``Dataset.to_json``; the two ``db_id``
    lists are written through a single-column pandas DataFrame. Everything is then re-read
    with the ``json`` dataset loader.

    Returns:
        tuple: ``(train_dataset, dev_dataset, test_dataset, test_dataset_ids,
        train_dataset_ids)``, each being the ``'train'`` split produced by reloading the
        corresponding JSON file.
    """
    import os

    import pandas as pd
    from datasets import load_dataset

    print("Preparing the datasets")
    train_dataset, dev_dataset, test_dataset, test_dataset_ids, train_dataset_ids = data_prep()

    # The writers below do not create missing parent directories.
    os.makedirs(datasets_path, exist_ok=True)

    train_dataset.to_json(f"{datasets_path}/train_dataset.json")
    dev_dataset.to_json(f"{datasets_path}/dev_dataset.json")
    test_dataset.to_json(f"{datasets_path}/test_dataset.json")
    pd.DataFrame(test_dataset_ids).to_json(f"{datasets_path}/test_dataset_ids.json")
    pd.DataFrame(train_dataset_ids).to_json(f"{datasets_path}/train_dataset_ids.json")

    def _reload(stem):
        """Load ``<datasets_path>/<stem>.json`` and return its single 'train' split."""
        return load_dataset("json", data_files=f"{datasets_path}/{stem}.json")["train"]

    # Load the datasets
    train_dataset = _reload("train_dataset")
    dev_dataset = _reload("dev_dataset")
    test_dataset = _reload("test_dataset")
    test_dataset_ids = _reload("test_dataset_ids")
    train_dataset_ids = _reload("train_dataset_ids")
    print("Datasets are ready")

    return train_dataset, dev_dataset, test_dataset, test_dataset_ids, train_dataset_ids

Evaluation

In [None]:
def evaluation(model_name, test_dataset, test_dataset_ids,
               model_path="/kaggle/working/artifacts/Llama-3.2-1B-Instruct:v0"):
    """Generate SQL for every test sample and score it against the reference query.

    Args:
        model_name: Accepted for interface compatibility; the model is loaded from
            ``model_path``.
        test_dataset: Indexable collection of samples whose ``"messages"`` entry holds the
            system, user and assistant turns (in that order).
        test_dataset_ids: Accepted for interface compatibility; not used by the metrics
            computed here.
        model_path: Directory of the fine-tuned PEFT model and its tokenizer. Defaults to
            the location previously hard-coded in this function.

    Returns:
        tuple: ``(exact_match_accuracy, comp_match_accuracy)`` averaged over the test set;
        both are ``0.0`` when ``test_dataset`` is empty.

    Reads the notebook globals ``quantize`` (load in 4-bit NF4) and ``ref_correct``
    (run a second "correct the query" generation pass). Prints running progress and the
    final per-component accuracies.
    """
    nf4_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    print("Loading model")
    model = AutoPeftModelForCausalLM.from_pretrained(
        model_path,
        device_map=None,
        torch_dtype=torch.float16,
        ignore_mismatched_sizes=True,
        quantization_config=nf4_config if quantize else None,
    )
    print("Model loaded. Model type: ", type(model))
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

    def generate(prompt):
        """Greedy-decode a completion for ``prompt`` and return only the new text."""
        outputs = pipe(
            prompt,
            max_new_tokens=256,
            do_sample=False,
            eos_token_id=pipe.tokenizer.eos_token_id,
            pad_token_id=pipe.tokenizer.pad_token_id,
        )
        # The pipeline echoes the prompt; slice it off.
        return outputs[0]['generated_text'][len(prompt):].strip()

    def predict(sample):
        """Return the predicted SQL string for one sample."""
        messages = sample["messages"][0:2]
        prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        schema = messages[1]["content"]
        predicted_answer = generate(prompt)

        # Keep only the first statement of the first-pass answer.
        semicolon = predicted_answer.find(";")
        if semicolon != -1:
            predicted_answer = predicted_answer[:semicolon + 1]

        if ref_correct:
            # Second pass: show the model its own answer and ask it to correct it.
            cot_prompt = " " + predicted_answer
            cot_prompt += "\nCorrect the SQL query above if necessary. You can look for issues in the table names, column names, or the query itself.\nSELECT "

            messages[1]["content"] = messages[1]["content"] + cot_prompt
            prompt = pipe.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            schema = messages[1]["content"]
            predicted_answer = generate(prompt)

        # The prompt ends with "SELECT", so the model usually omits the keyword.
        low_pred = predicted_answer.lower()
        if not low_pred.startswith("select") and not low_pred.startswith(" select"):
            predicted_answer = "SELECT " + predicted_answer

        # Replace table names that do not exist in the schema by their closest match.
        if not check_table_names(predicted_answer, schema):
            predicted_answer = correct_query(predicted_answer, schema)

        return predicted_answer

    em_match = []
    comp_match = []
    component_keys = ['SELECT', 'FROM', 'WHERE', 'GROUP BY', 'ORDER BY', 'JOIN']
    total_correct = {k: 0 for k in component_keys}
    total_present = {k: 0 for k in component_keys}
    all_results = []
    # Defined up front so an empty test set returns zeros instead of raising.
    exact_match_accuracy = 0.0
    comp_match_accuracy = 0.0

    for i in range(len(test_dataset)):
        sample = test_dataset[i]
        prediction = predict(sample)
        # Use the last extracted query; keep the raw prediction when nothing is extracted.
        candidates = extract_sql_queries(prediction)
        if candidates:
            prediction = candidates[-1]
        true_pred = "SELECT " + sample["messages"][2]["content"]
        # str.replace returns a new string, so the result has to be assigned.
        true_pred = true_pred.replace("<|im_end|>", "")
        em_match.append(exact_match(true_pred, prediction))
        res = component_matching([true_pred], [prediction])
        all_results.append(res)
        comp_match.append(res['average_accuracy'])

        exact_match_accuracy = sum(em_match) / len(em_match)
        comp_match_accuracy = sum(comp_match) / len(comp_match)
        completion_rate = (i + 1) / len(test_dataset)
        msg = "Completion rate: {0}, Component Matching: {1}, Exact Match Accuracy: {2}".format(
            completion_rate, comp_match_accuracy, exact_match_accuracy)
        sys.stdout.write("\r" + msg)
        sys.stdout.flush()

    # Per-component accuracy: count only samples where the reference has that clause.
    for res in all_results:
        for k in component_keys:
            val = res['component_accuracy'][k]
            if val is not None:
                total_present[k] += 1
                if val == 1.0:
                    total_correct[k] += 1

    # Final computation
    final_accuracy = {
        k: (total_correct[k] / total_present[k]) if total_present[k] > 0 else None
        for k in component_keys
    }
    print(final_accuracy)
    return exact_match_accuracy, comp_match_accuracy






Training

In [None]:
def extract_sql_queries(text):
    """Pull SQL snippets out of free-form model output.

    Fenced ```sql blocks take priority. When there are none, every
    ``SELECT ...`` run up to the next newline (or the end of the text) is used
    instead, matched case-insensitively. Each snippet is returned stripped of
    surrounding whitespace.
    """
    fenced = re.findall(r"```sql(.*?)```", text, re.DOTALL)
    if fenced:
        found = fenced
    else:
        found = re.findall(r"(SELECT .*?)(?:\n|$)", text, re.IGNORECASE)
    return list(map(str.strip, found))


def train(model_name, train_dataset):
    """Fine-tune the base model with LoRA adapters on 4-bit weights and log the result to W&B.

    Args:
        model_name: Used as the trainer's ``output_dir``.
        train_dataset: Chat-formatted training split (records with a ``messages`` field).

    Returns:
        None. The trained model is saved to ``output_dir`` and that directory is logged as
        a W&B model artifact.

    Reads the notebook globals ``model_to_train``, ``epochs``, ``evalu``, ``seed`` and, when
    ``evalu`` is truthy, ``dev_dataset`` (set by the driver cell).
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_to_train,
        device_map="auto",
        # attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        quantization_config=bnb_config,
        trust_remote_code=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_to_train, truncation=True, max_length=2048)
    tokenizer.padding_side = 'right'
    # setup_chat_format installs its own template, so clear the one shipped with the model.
    tokenizer.chat_template = None
    tokenizer.truncation = True

    model, tokenizer = setup_chat_format(model, tokenizer)
    # NOTE(review): `tokenizer` is not handed to SFTTrainer below — confirm the trainer ends
    # up using the tokenizer that setup_chat_format modified.
    peft_config = LoraConfig(
        lora_alpha=128,
        lora_dropout=0.05,
        r=8,
        bias="none",
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        task_type="CAUSAL_LM",  # causal (decoder-only) language models such as Llama or GPT
    )

    # Use the notebook's `evalu` flag here. The builtin `eval` is a function and therefore
    # always truthy, which forced step-wise evaluation on even without an eval dataset.
    run_eval = bool(evalu)

    train_args = SFTConfig(
        output_dir=model_name,
        num_train_epochs=epochs,                # number of training epochs
        per_device_train_batch_size=8,          # batch size per device during training
        gradient_accumulation_steps=2,          # micro-batches accumulated per optimizer update
        gradient_checkpointing=True,
        optim="adamw_torch_fused",              # use fused adamw optimizer
        logging_steps=10,
        log_level="info",
        save_strategy="steps",
        learning_rate=3e-4,                     # learning rate
        bf16=True,
        logging_dir="./logs",
        report_to="wandb",
        run_name="NLP_project",
        tf32=False,
        max_grad_norm=0.3,
        warmup_ratio=0.03,
        max_seq_length=1024,
        lr_scheduler_type="cosine",             # learning rate scheduler type
        push_to_hub=False,                      # do not push checkpoints to the Hub
        eval_strategy="steps" if run_eval else 'no',
        eval_steps=50 if run_eval else None,
        packing=True,
        dataset_kwargs={
            "add_special_tokens": False,
            "append_concat_token": False,
            "max_length": 1024,
            "truncation": True,
        }
    )

    train_dataset = train_dataset.shuffle(seed=seed)
    eos_token = tokenizer.eos_token
    print("EOS token id: ", eos_token)
    trainer = SFTTrainer(
        model=model,
        args=train_args,
        train_dataset=train_dataset,
        eval_dataset=dev_dataset if run_eval else None,
        peft_config=peft_config
    )

    print("Training has started!!")

    trainer.train()

    trainer.save_model()
    artifact = wandb.Artifact("Llama-3.2-1B-Instruct", type="model")
    # Log the directory the trainer just saved to, rather than a hard-coded absolute path.
    artifact.add_dir(train_args.output_dir)
    wandb.log_artifact(artifact)
    print("Training has ended!!")

    # Drop the references and release cached GPU memory before evaluation loads the model.
    del model
    del trainer
    torch.cuda.empty_cache()

    return




Component matching

In [None]:
def normalize_columns(component_str):
    """Return the comma-separated items of ``component_str`` trimmed, lower-cased and sorted.

    A falsy input (``None`` or an empty string) yields ``None``.
    """
    if component_str:
        return sorted(part.strip().lower() for part in component_str.split(','))
    return None
def clean_sqlw(s):
    """Lower-case ``s`` and remove every whitespace character from it."""
    lowered = s.strip().lower()
    return re.sub(r"\s", "", lowered)
def clean_sql(s):
    """Lower-case ``s``, trim it, and collapse each whitespace run to a single space."""
    trimmed = s.strip().lower()
    return re.sub(r"\s+", " ", trimmed)
def extract_components(query):
    """Split the first statement of ``query`` into comparable clause strings.

    Walks the top-level, non-whitespace sqlparse tokens of the first parsed statement.
    For SELECT, FROM, GROUP BY, ORDER BY and JOIN the value stored is a normalised form
    of the token that follows the keyword; for WHERE it is the text after the first
    ``WHERE`` inside the WHERE token itself (upper-cased by the scan, then passed through
    ``clean_sqlw``).

    Returns:
        dict: keys 'SELECT', 'FROM', 'WHERE', 'GROUP BY', 'ORDER BY', 'JOIN'. A value is
        ``None`` when the clause was not found. When a keyword occurs several times, the
        last occurrence wins.
    """
    components = {
        'SELECT': None,
        'FROM': None,
        'WHERE': None,
        'GROUP BY': None,
        'ORDER BY': None,
        'JOIN': None
    }

    parsed = sqlparse.parse(query)
    if not parsed:
        return components

    tokens = [t for t in parsed[0].tokens if not t.is_whitespace]

    for idx, token in enumerate(tokens):
        val = token.value.upper()
        # Text of the token after the current one, or None at the end of the statement.
        following = str(tokens[idx + 1]) if idx + 1 < len(tokens) else None

        if token.ttype is sqlparse.tokens.DML and val == 'SELECT':
            components['SELECT'] = normalize_columns(following) if following is not None else None
        elif token.ttype is sqlparse.tokens.Keyword and val == 'FROM':
            components['FROM'] = clean_sql(following) if following is not None else None
        elif val.startswith('WHERE'):
            parts = re.split(r'(?i)\b(WHERE)\b', val)
            # A token that merely starts with the letters "WHERE" (e.g. an identifier)
            # yields no split; skip it instead of raising IndexError.
            if len(parts) > 2:
                components['WHERE'] = clean_sqlw(parts[2])
        elif val.startswith('GROUP BY'):
            components['GROUP BY'] = clean_sql(following) if following is not None else None
        elif val.startswith('ORDER BY'):
            components['ORDER BY'] = clean_sql(following) if following is not None else None
        elif token.is_keyword and 'JOIN' in val:
            # Require a keyword token: a column list or identifier that merely contains
            # the letters "JOIN" (e.g. join_date) must not be treated as a JOIN clause.
            components['JOIN'] = clean_sql(following) if following is not None else None

    return components

def component_matching(y_true, y_pred):
    """Score predicted queries clause by clause against their references.

    For every (reference, prediction) pair the clauses are extracted with
    ``extract_components``. A clause counts as present when the reference has it and as
    correct when the prediction's value is equal.

    Returns:
        dict: ``'component_accuracy'`` maps each clause to correct/present (``None`` when the
        clause never appears in a reference) and ``'average_accuracy'`` is the mean of the
        non-``None`` accuracies, or ``0.0`` when there are none.
    """
    keys = ['SELECT', 'FROM', 'WHERE', 'GROUP BY', 'ORDER BY', 'JOIN']
    hits = dict.fromkeys(keys, 0)
    seen = dict.fromkeys(keys, 0)

    for gold_query, guess_query in zip(y_true, y_pred):
        gold = extract_components(gold_query)
        guess = extract_components(guess_query)
        for clause in keys:
            if gold[clause] is None:
                continue
            seen[clause] += 1
            if gold[clause] == guess[clause]:
                hits[clause] += 1

    acc = {}
    for clause in keys:
        acc[clause] = (hits[clause] / seen[clause]) if seen[clause] > 0 else None

    scored = [value for value in acc.values() if value is not None]
    if scored:
        avg = sum(scored) / len(scored)
    else:
        avg = 0.0
    return {
        'component_accuracy': acc,
        'average_accuracy': avg
    }


In [None]:
def get_table_names(schema):
    """Return the lower-cased table names declared by CREATE TABLE lines in ``schema``.

    Each line is searched for ``create table [if not exists] <name>``. The name ends at
    the first whitespace or opening parenthesis, so ``create table foo(id int`` yields
    ``foo``. Quotes, backticks, commas, semicolons and parentheses are stripped from the
    name. Lines without a name after the keywords are ignored.
    """
    create_table = re.compile(r"create\s+table\s+(?:if\s+not\s+exists\s+)?([^\s(]+)")
    strip_chars = str.maketrans('', '', '"\',`;()')

    table_names = []
    for line in schema.lower().split('\n'):
        match = create_table.search(line)
        if match is None:
            continue
        table_name = match.group(1).translate(strip_chars)
        if table_name:
            table_names.append(table_name)
    return table_names


def similarity_search(wrong_table, tables):
    """Return the entry of ``tables`` with the smallest Levenshtein distance to ``wrong_table``.

    The first entry wins on ties; an empty ``tables`` gives ``''``.
    """
    return min(
        tables,
        key=lambda candidate: levenshtein_distance(wrong_table, candidate),
        default='',
    )

def get_table_names_from_query(query):
    """Return the lower-cased table names that follow FROM or JOIN in ``query``.

    The query is split on whitespace and the word after each ``from``/``join`` is taken
    as a candidate. Candidates with an opening parenthesis among their first three
    characters (e.g. the start of a subquery, ``(select``) are skipped, since they are
    not table names. All remaining candidates, whatever their length, have commas,
    quotes, backticks, semicolons and parentheses removed.
    """
    strip_chars = str.maketrans('', '', ',\'"`;()')
    words = query.lower().split()

    table_names = []
    for keyword, candidate in zip(words, words[1:]):
        if keyword not in ('from', 'join'):
            continue
        if '(' in candidate[:3]:
            # Subquery or other parenthesised expression, not a table reference.
            continue
        table_name = candidate.translate(strip_chars)
        if table_name:
            table_names.append(table_name)
    return table_names


def check_table_names(query, schema):
    """Tell whether every table referenced by ``query`` is declared in ``schema``.

    Returns ``True`` when all names from ``get_table_names_from_query(query)`` appear in
    ``get_table_names(schema)`` (vacuously true when the query references none).
    """
    known_tables = get_table_names(schema)
    referenced = get_table_names_from_query(query)
    return all(name in known_tables for name in referenced)

def correct_query(query, schema):
    """Replace table names in ``query`` that are absent from ``schema`` by the closest match.

    For every referenced table that the schema does not declare, the first
    whitespace-separated token equal to that name (case-insensitively) is replaced by the
    schema table with the smallest Levenshtein distance. Tokens are re-joined with single
    spaces. If the schema declares no tables, the query is returned unchanged.

    Returns:
        str: the (possibly) corrected query.
    """
    table_names = get_table_names(schema)
    if not table_names:
        # Nothing to correct against; similarity_search would return '' and the table
        # token would be blanked out.
        return query

    for query_table_name in get_table_names_from_query(query):
        if query_table_name in table_names:
            continue
        similar_table = similarity_search(query_table_name, table_names).replace('"', '')
        split_query = query.split()
        split_low_query = [token.lower() for token in split_query]
        try:
            position = split_low_query.index(query_table_name)
        except ValueError:
            # The cleaned name has no exact token in the query (e.g. the token still
            # carries punctuation); leave this reference untouched.
            print("Less error")
        else:
            split_query[position] = similar_table
        query = ' '.join(split_query)
    return query

def exact_match(y_true, y_pred):
    """Return 1 when two SQL strings are equal after normalisation, otherwise 0.

    Normalisation lower-cases the text, drops ``t0.`` … ``t6.`` alias prefixes, then
    removes all whitespace, semicolons and commas. Alias prefixes are removed before the
    whitespace and only at a word boundary, so identifiers that merely end in ``t<digit>``
    (for example ``cat1.name``) are left intact.
    """
    def _canonical(sql):
        """Apply the normalisation described above to one query string."""
        sql = sql.lower()
        # Allow optional whitespace around the dot so "t1 . name" is treated like "t1.name".
        sql = re.sub(r'\bt[0-6]\s*\.', '', sql)
        sql = re.sub(r'\s+', '', sql)
        return sql.replace(';', '').replace(',', '')

    return 1 if _canonical(y_true) == _canonical(y_pred) else 0



In [None]:
if __name__ == "__main__":
    # Build the chat-formatted splits and db_id lists (data_loading also writes them to disk).
    # These names become notebook globals; train() reads `dev_dataset` from here.
    train_dataset, dev_dataset, test_dataset, test_dataset_ids, train_dataset_ids = data_loading()

    # NOTE(review): when the cells are run top to bottom, `train` here is the function created
    # by `def train(...)`, not the `train = True` flag from the configuration cell, so this
    # condition is always truthy and the flag cannot switch training off.
    if train:
        # train the model
        print("Training the model")
        train(model_name, train_dataset)

    if test:
        print("Evaluating the model")
        # evaluation() returns (exact-match accuracy, component-matching accuracy).
        exact_accuracy,comp = evaluation(model_name, test_dataset, test_dataset_ids)

        print(f"Component Matching Accuracy: {comp}")
        print(f"Exact Match Accuracy: {exact_accuracy}")
        print("--------done--------")





# **Inference: Loading the Model from Weights & Biases**

In [None]:
import wandb
# Start a W&B run; the artifact is fetched through it.
run = wandb.init()
# Reference the logged model artifact by its qualified name and the ':v0' version tag.
artifact = run.use_artifact('pawankr-cse-iit-delhi/huggingface/Llama-3.2-1B-Instruct:v0', type='model')

# Download the artifact and keep the value returned by download() for display below.
model_dir = artifact.download()
print(f"Model downloaded to: {model_dir}")

In [None]:
# 4-bit NF4 quantization settings; only applied when the `quantize` flag is set.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
print("Loading model")
# Load from the directory returned by artifact.download() in the previous cell rather than
# a hard-coded absolute path, so the cell does not depend on the working directory.
model = AutoPeftModelForCausalLM.from_pretrained(
    model_dir,
    device_map=None,
    torch_dtype=torch.float16,
    ignore_mismatched_sizes=True,
    quantization_config=nf4_config if quantize else None,
)
print("Model loaded. Model type: ", type(model))
tokenizer = AutoTokenizer.from_pretrained(model_dir)
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

In [None]:
from random import randint

# randint includes both end points, so the upper bound must be the last valid index
# (the previous bound of len + 1 could raise IndexError).
rand_idx = randint(0, len(test_dataset) - 1)
sample = test_dataset[rand_idx]

print(f"EOS token: {pipe.tokenizer.eos_token}")

# Test on sample
# Build the prompt from the system and user turns only. add_generation_prompt=True matches
# the prompt construction in evaluation(), so the model starts a fresh assistant turn.
prompt = pipe.tokenizer.apply_chat_template(sample["messages"][0:2], tokenize=False, add_generation_prompt=True)
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, eos_token_id=pipe.tokenizer.eos_token_id, pad_token_id=pipe.tokenizer.pad_token_id)

print(f"Natural language request:\n{sample['messages'][1]['content']}\n\n")
print(f"Original Answer:\n{sample['messages'][2]['content']}\n\n")
# The pipeline echoes the prompt; keep only the newly generated text.
generated_answer = outputs[0]['generated_text'][len(prompt):].strip()
print(f"First generated Answer:\n{generated_answer}\n\n")
# Keep only the first statement.
if generated_answer.find(";") != -1:
    generated_answer = generated_answer[:generated_answer.find(";")+1]
print(f"Generated Answer:\n{generated_answer}\n\n")