In [1]:
import numpy as np
import pandas as pd
import torch
import re
from transformers import T5Tokenizer, T5ForConditionalGeneration, AdamW
from torch.utils.data import DataLoader, Dataset
from sklearn import metrics
from sklearn.preprocessing import LabelBinarizer
from evaluate import load
bertscore = load("bertscore")
from datasets import load_metric
exact_match_metric = load_metric("exact_match")
lb = LabelBinarizer()

  from .autonotebook import tqdm as notebook_tqdm
  exact_match_metric = load_metric("exact_match")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [2]:
import torch
torch.cuda.is_available()

True

# Load model

In [3]:
# Load lại mô hình và tokenizer để kiểm tra
model = T5ForConditionalGeneration.from_pretrained('t5-base')
model.load_state_dict(torch.load('model/spider/spider_model.pt'))
model.eval()

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dro

In [4]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# Predict Validation text

In [5]:
total_sentences = 100
# Load the dataset
df = pd.read_csv("data/spider/spider_schema_validation.csv", nrows = total_sentences)

In [6]:
# Function to generate SQL query
def generate_sql_query(question):
    input_ids = tokenizer.encode(question, return_tensors='pt')
    outputs = model.generate(input_ids=input_ids, max_length=100, num_beams=5, early_stopping=True)
    sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return sql_query


In [7]:
# Generate SQL queries for each question in the CSV
df['sql_query'] = df['question'].apply(generate_sql_query)

In [8]:
output_df = df[['question','sql','sql_query']]
# Save the results to a new CSV file
output_file = "data/spider/spider_tex2sql_predict.csv"
output_df.to_csv(output_file, index=False)
print(f"Generated SQL queries saved to {output_file}")

Generated SQL queries saved to data/spider/spider_tex2sql_predict.csv


# Get sql and sql predict

In [9]:
def preprocess_sql(sql_query):
    # Loại bỏ khoảng trắng dư thừa
    sql_query = re.sub(r'\s+', ' ', sql_query)

    # Loại bỏ khoảng trắng trước dấu phẩy
    sql_query = re.sub(r'\s*,', ',', sql_query)

    # Đổi dấu cách trước và sau các toán tử so sánh
    sql_query = re.sub(r'(\S)([><=]+)(\S)', r'\1 \2 \3', sql_query)

    # Chuẩn hóa các từ khóa SQL
    sql_keywords = ['SELECT', 'FROM', 'WHERE', 'AND', 'OR', 'NOT', 'ORDER BY', 'GROUP BY', 'HAVING', 'LIMIT']
    for keyword in sql_keywords:
        sql_query = re.sub(r'\b' + keyword.lower() + r'\b', keyword, sql_query, flags=re.IGNORECASE)

    # Loại bỏ khoảng trắng ở đầu và cuối câu SQL
    sql_query = sql_query.strip()

    # Thêm dấu cách sau dấu phẩy nếu sau dấu phẩy không có khoảng trắng
    sql_query = re.sub(r',(?!\s)', ', ', sql_query)

    return sql_query

In [10]:
target = df["sql"].apply(preprocess_sql)
predictions = df['sql_query'].apply(preprocess_sql)

# Metric: rocauc

In [11]:
# Function to calculate ROC AUC for multiclass
def calculate_multiclass_roc_auc(target, predictions):
    '''
    This method returns the AUC Score for multiclass classification
    '''
    # Check if the predictions are probabilities or labels
    if len(predictions.shape) == 1 or predictions.shape[1] == 1:
        predictions_binarized = lb.transform(predictions)
    else:
        predictions_binarized = predictions
    
    return metrics.roc_auc_score(target, predictions_binarized, multi_class='ovr')


In [12]:
target_binarized = lb.fit_transform(target)
predictions_binarized = predictions

In [13]:
# Calculate the ROC AUC score for multiclass
roc_auc_score = calculate_multiclass_roc_auc(target_binarized, predictions_binarized)

print(f"ROC AUC Score: {roc_auc_score}")

ROC AUC Score: 0.8642338291248703


# Metric: exact_match

In [14]:
# Đầu tiên, tính toán kết quả exact_match
results_exact_match = exact_match_metric.compute(predictions=predictions, references=target)
print(results_exact_match)
# Lặp qua các cặp câu dự đoán và câu tham chiếu để in ra các cặp không khớp
error =0
for idx, (predicted_sql, reference_sql) in enumerate(zip(predictions, target)):
    if predicted_sql != reference_sql:
        error += 1
        print(f"Cặp không khớp thứ {idx + 1}:")
        print("Câu dự đoán:", predicted_sql)
        print("Câu tham chiếu:", reference_sql)
        print()  # In một dòng trống để phân biệt các cặp không khớp
print("số cặp không khớp: ", error)


{'exact_match': 75.0}
Cặp không khớp thứ 3:
Câu dự đoán: SELECT creation, name, budget_in_year FROM department
Câu tham chiếu: SELECT creation, name, budget_in_billions FROM department

Cặp không khớp thứ 4:
Câu dự đoán: SELECT max(budget_in_billion), min(budget_in_billion) FROM department
Câu tham chiếu: SELECT max(budget_in_billions), min(budget_in_billions) FROM department

Cặp không khớp thứ 5:
Câu dự đoán: SELECT avg(num_employees) FROM department WHERE rank BETWEEN 10 AND 15
Câu tham chiếu: SELECT avg(num_employees) FROM department WHERE ranking BETWEEN 10 AND 15

Cặp không khớp thứ 6:
Câu dự đoán: SELECT name FROM head WHERE born_state!= 'California'
Câu tham chiếu: SELECT name FROM head WHERE born_state != 'California'

Cặp không khớp thứ 13:
Câu dự đoán: SELECT DISTINCT T1.age FROM management AS T1 JOIN head AS T2 ON T1.head_id = T2.head_id WHERE T2.temporary_acting = 'Yes'
Câu tham chiếu: SELECT DISTINCT T1.age FROM management AS T2 JOIN head AS T1 ON T1.head_id = T2.head_id 

In [26]:
refs = ["SELECT max(budget_in_billions), min(budget_in_billions) FROM department"
        , "SELECT avg(num_employees) FROM department WHERE rank BETWEEN 10 AND 15"
        , "SELECT DISTINCT T1.age FROM management AS T2 JOIN head AS T1 ON T1.head_id = T2.head_id WHERE T2.temporary_acting = 'Yes'"
        , "SELECT Hosts FROM farm_competition WHERE Theme!= 'Aliens'"]
preds = ["SELECT max(budget_in_billion), min(budget_in_billion) FROM department"
         , "SELECT avg(num_employees) FROM department WHERE rank BETWEEN 10 AND 15"
         , "SELECT DISTINCT T1.age FROM management AS T1 JOIN head AS T2 ON T1.head_id = T2.head_id WHERE T2.temporary_acting = 'Yes'"
         , "SELECT Hosts FROM farm_competition WHERE Theme!= 'Aliens'"]
results = exact_match_metric.compute(references=refs, predictions=preds, regexes_to_ignore=[" ", "s","es","!", "ing"], ignore_case=True, ignore_punctuation=False, ignore_numbers=True)
print(round(results["exact_match"], 2))

100.0


In [27]:
results = exact_match_metric.compute(references=target, predictions=predictions, regexes_to_ignore=[" ", "s","es","!", "ing"], ignore_case=True, ignore_punctuation=True, ignore_numbers=True)
print(round(results["exact_match"], 2))

91.0


# Metric: bert_score

In [17]:
def compute_average_results(results):
    average_results = {}
    for key in results:
        if isinstance(results[key], list) and all(isinstance(x, (int, float)) for x in results[key]):
            average_results[key] = np.mean(results[key])
        else:
            average_results[key] = None
    return average_results

In [18]:
pred = ["SELECT max(budget_in_billion), min(budget_in_billion) FROM department"
         , "SELECT avg(num_employees) FROM department WHERE rank BETWEEN 10 AND 15"
         , "SELECT DISTINCT T1.age FROM management AS T1 JOIN head AS T2 ON T1.head_id = T2.head_id WHERE T2.temporary_acting = 'Yes'"
         , "SELECT Hosts FROM farm_competition WHERE Theme!= 'Aliens'"]
ref = ["SELECT max(budget_in_billions), min(budget_in_billions) FROM department"
        , "SELECT avg(num_employees) FROM department WHERE rank BETWEEN 10 AND 15"
        , "SELECT DISTINCT T1.age FROM management AS T2 JOIN head AS T1 ON T1.head_id = T2.head_id WHERE T2.temporary_acting = 'Yes'"
        , "SELECT Hosts FROM farm_competition WHERE Theme!= 'Aliens'"]
results = bertscore.compute(predictions=pred, references=ref, model_type="distilbert-base-uncased")
print(results)


{'precision': [0.9867746829986572, 1.0000001192092896, 0.9980701804161072, 1.0], 'recall': [0.9867746829986572, 1.0000001192092896, 0.9980701804161072, 1.0], 'f1': [0.9867746829986572, 1.0000001192092896, 0.9980701804161072, 1.0], 'hashcode': 'distilbert-base-uncased_L5_no-idf_version=0.3.12(hug_trans=4.41.1)'}


In [19]:
results_bertscore = bertscore.compute(predictions=predictions, references=target, model_type="bert-base-uncased")
print("Kết quả trung bình chính xác:", compute_average_results(results_bertscore))

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Kết quả trung bình chính xác: {'precision': 0.9929775285720825, 'recall': 0.9873931401968002, 'f1': 0.9900400310754776, 'hashcode': None}


# Try with new question

In [23]:
# Thử nghiệm với câu hỏi mới
new_question = "What are the first and last names of all customers with more than 2 payments?"
input_ids = tokenizer.encode(new_question, return_tensors='pt')
outputs = model.generate(input_ids=input_ids, max_length=100, num_beams=5, early_stopping=True)
sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(f"Question: {new_question}")
print(f"Generated SQL query: {sql_query}")

Question: What are the first and last names of all customers with more than 2 payments?
Generated SQL query: SELECT T2.first_name, T2.last_name FROM Customer_Payments AS T1 JOIN Customers AS T2 ON T1.customer_id = T2.customer_id GROUP BY T1.customer_id HAVING count(*) > 2


In [21]:
# Thử nghiệm với câu hỏi mới
new_question = "What team has more than 49 laps and a grid of 8?"
input_ids = tokenizer.encode(new_question, return_tensors='pt')
outputs = model.generate(input_ids=input_ids, max_length=100, num_beams=5, early_stopping=True)
sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(f"Question: {new_question}")
print(f"Generated SQL query: {sql_query}")

Question: What team has more than 49 laps and a grid of 8?
Generated SQL query: SELECT Team FROM laps WHERE laps > 49 INTERSECT SELECT Team FROM grid WHERE grid = 8
