# Part 3 : Model Evaluation

## Setup libraries

In [None]:
import subprocess, sys
import torch, json
import pandas as pd
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import LoraConfig, TaskType, get_peft_model, PeftModel
# from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm
  from scipy.sparse import csr_matrix, issparse


## Device Checking

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

## Evaluation of the the Phi2_Text_To_SQL model

In [14]:
# Functions definition

def generate_text(prompt, model, max_new_tokens=2048):
    encoding = tokenizer(prompt, return_tensors="pt")
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        output_ids = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=max_new_tokens, pad_token_id=tokenizer.eos_token_id, do_sample=False)

    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return generated_text

def extract_sql(text):
    # Support different possible answer labels
    for key in ["Answer:", "Anwser:", "SQL:"]:
        if key in text:
            # Get everything after the key (Answer/Anwser/SQL)
            answer_section = text.split(key, 1)[1]
            # Stop if there’s another Question coming
            answer_only = answer_section.split("Question:", 1)[0].strip()
            # Clean and filter lines
            lines = [line.strip() for line in answer_only.splitlines() if line.strip()]
            # Ignore lines that repeat instructions
            if not lines or all(any(x in line for x in ["Convert", "Return", "Question"]) for line in lines):
                return "[NO SQL GENERATED]"
            sql = "\n".join(lines)
            # Ensure semicolon at the end
            if not sql.endswith(";"):
                sql += ";"
            return sql
    return "[NO SQL GENERATED]"


def generate_predictions_from_dataframe(df, max_length):
    predictions = []

    for question in tqdm(df['question'], desc="Generating SQL queries"):
        instruction = instruction = f"""You are a Text-to-SQL assistant. 
        Convert the following question into a valid SQL query. 
        Return only the SQL query and end it with a semicolon.

        Question: {question}
        Anwser:"""

        generated = generate_text(instruction, trainable_model, max_length)
        sql = extract_sql(generated)
        predictions.append(sql)
    return predictions

def save_predictions_to_file(predictions, filename="./data/sql/dev_pred.sql"):
    with open(filename, "w") as f:
        for sql in predictions:
            f.write(sql.strip() + "\n")
    print(f"✅ Sauvegardé dans {filename}")

def generate_gold_data_from_dataframe(df, max_length):
    gold_data = []
    for query, db_id in tqdm(zip(df['query'],df['db_id'])):
        if not query.endswith(";"):
            query += ";"
        gold_data.append((query,db_id))
    return gold_data

def save_gold_data_to_file(gold_data, filename="./data/sql/dev_gold.sql"):
    with open(filename, "w") as f:
        for query, db in gold_data:
            f.write(f"{query.strip()}\t{db.strip()}\n")
    print(f"✅ Sauvegardé dans {filename}")



### Model Quantization

In [5]:
# Quantization Config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

base_model = AutoModelForCausalLM.from_pretrained("./models/base_Phi2_model", quantization_config=bnb_config, trust_remote_code=False, device_map="auto")
trainable_model = PeftModel.from_pretrained(base_model, "./models/Text2SQL_Phi2_model", is_trainable=True)
tokenizer = AutoTokenizer.from_pretrained("./models/Text2SQL_Phi2_model", trust_remote_code=True, use_fast=False)

Loading checkpoint shards: 100%|██████████| 3/3 [00:28<00:00,  9.47s/it]


### dev_pred and gold_pred creation

In [15]:
# gold_pred.sql database

data_spider = load_dataset("xlangai/spider")
data_spider_train = load_dataset("xlangai/spider", split="train")
data_spider_validation = load_dataset("xlangai/spider", split="validation")

df_spider_train = data_spider_train.to_pandas()
df_spider_validation = data_spider_validation.to_pandas()
frames_spider = [df_spider_train,df_spider_validation]
df_spider = pd.concat(frames_spider)
df_spider_study = df_spider[["query","question","db_id"]]
df_spider_study.to_csv("./data/spider_query_db_id.csv",sep=";",index=False)

df = pd.read_csv("./data/spider_query_db_id.csv", sep=";", nrows=3000) 
max_length = df['query'].apply(lambda x: len(str(x).split())).max()
predictions = generate_gold_data_from_dataframe(df,max_length)
save_gold_data_to_file(predictions)

with open("./data/sql/dev_gold.sql", "r", encoding="utf-8") as f:
    for i, line in enumerate(f, 1):
        sql = line.strip().split("\t")[0]
        if sql.count(";") > 1:
            print(f"⚠️ Line {i} - Multi-statement SQL:", sql)

3000it [00:00, 2914060.21it/s]

✅ Sauvegardé dans ./data/sql/dev_gold.sql





In [None]:
# dev_pred.sql database

df = pd.read_csv("./data/query_question_sql_copilot.csv", sep=";", nrows=3000) 
max_length = df['query'].apply(lambda x: len(str(x).split())).max()
predictions = generate_predictions_from_dataframe(df,max_length)
save_predictions_to_file(predictions)

with open('./pred/dev_pred.sql', 'r') as file:
    lines = file.readlines()

merged_lines = []
buffer = ""

for line in lines:
    stripped = line.strip()
    if stripped:
        buffer += " " + stripped
        if stripped.endswith(';'):
            merged_lines.append(buffer.strip())
            buffer = ""

with open('./pred/dev_pred_corrected.sql', 'w') as file:
    for merged in merged_lines:
        file.write(merged + '\n')

### Model evaluation

In [None]:
# Evaluation 

cmd = [
    sys.executable, "evaluation.py",
    "--gold", "./data/sql/dev_gold.sql",
    "--pred", "./data/sql/dev_pred.sql",
    "--db", "./data/spider/database",
    "--table", "./data/spider/tables.json",
    "--etype", "all",
]

# Affiche les logs en temps réel
with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) as proc:
    for line in proc.stdout:
        print(line, end="")  # end="" évite les sauts de ligne en double

# Vérifie le code de retour
if proc.returncode not in (0, None):
    raise subprocess.CalledProcessError(proc.returncode, cmd)


  from scipy.stats import fisher_exact
eval_err_num:1
easy pred: SELECT COUNT(*) FROM department WHERE age > 56;
easy gold: SELECT count(*) FROM head WHERE age  >  56;

eval_err_num:2
medium pred: SELECT name, born_state, age FROM heads_of_departments ORDER BY age;
medium gold: SELECT name ,  born_state ,  age FROM head ORDER BY age;

eval_err_num:3
medium pred: SELECT year, name, budget FROM departments;
medium gold: SELECT creation ,  name ,  budget_in_billions FROM department;

eval_err_num:4
medium pred: SELECT MAX(budget) AS max_budget, MIN(budget) AS min_budget FROM departments;
medium gold: SELECT max(budget_in_billions) ,  min(budget_in_billions) FROM department;

eval_err_num:5
easy pred: SELECT AVG(num_employees) FROM departments WHERE rank BETWEEN 10 AND 15;
easy gold: SELECT avg(num_employees) FROM department WHERE ranking BETWEEN 10 AND 15;

eval_err_num:6
easy pred: SELECT name FROM heads WHERE state <> 'California';
easy gold: SELECT name FROM head WHERE born_state != 'C