In [5]:
import numpy as np
import pandas as pd
import torch
import re
from transformers import T5Tokenizer, T5ForConditionalGeneration, AdamW
from torch.utils.data import DataLoader, Dataset
from sklearn import metrics
from datasets import load_metric
from sklearn.preprocessing import LabelBinarizer
lb = LabelBinarizer()

In [6]:
import torch
torch.cuda.is_available()

True

Load model

In [7]:
# Load lại mô hình và tokenizer để kiểm tra
model = T5ForConditionalGeneration.from_pretrained('t5-base')
model.load_state_dict(torch.load('model/spider/spider_model.pt'))
model.eval()

T5ForConditionalGeneration(
  (shared): Embedding(32128, 768)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 768)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=768, out_features=768, bias=False)
              (k): Linear(in_features=768, out_features=768, bias=False)
              (v): Linear(in_features=768, out_features=768, bias=False)
              (o): Linear(in_features=768, out_features=768, bias=False)
              (relative_attention_bias): Embedding(32, 12)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=768, out_features=3072, bias=False)
              (wo): Linear(in_features=3072, out_features=768, bias=False)
              (dropout): Dro

In [8]:
tokenizer = T5Tokenizer.from_pretrained('t5-base')

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Predict Validation text

In [9]:
total_sentences = 100
# Load the dataset
df = pd.read_csv("data/spider/spider_schema_validation.csv", nrows = total_sentences)

In [10]:
# Function to generate SQL query
def generate_sql_query(question):
    input_ids = tokenizer.encode(question, return_tensors='pt')
    outputs = model.generate(input_ids=input_ids, max_length=100, num_beams=5, early_stopping=True)
    sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return sql_query


In [11]:
# Generate SQL queries for each question in the CSV
df['sql_query'] = df['question'].apply(generate_sql_query)

In [12]:
output_df = df[['question','sql','sql_query']]
# Save the results to a new CSV file
output_file = "data/spider/spider_tex2sql_predict.csv"
output_df.to_csv(output_file, index=False)
print(f"Generated SQL queries saved to {output_file}")

Generated SQL queries saved to data/spider/spider_tex2sql_predict.csv


Get sql and sql predict

In [13]:
def preprocess_sql(sql_query):
    # Loại bỏ khoảng trắng dư thừa
    sql_query = re.sub(r'\s+', ' ', sql_query)

    # Loại bỏ khoảng trắng trước dấu phẩy
    sql_query = re.sub(r'\s*,', ',', sql_query)

    # Đổi dấu cách trước và sau các toán tử so sánh
    sql_query = re.sub(r'(\S)([><=]+)(\S)', r'\1 \2 \3', sql_query)

    # Chuẩn hóa các từ khóa SQL
    sql_keywords = ['SELECT', 'FROM', 'WHERE', 'AND', 'OR', 'NOT', 'ORDER BY', 'GROUP BY', 'HAVING', 'LIMIT']
    for keyword in sql_keywords:
        sql_query = re.sub(r'\b' + keyword.lower() + r'\b', keyword, sql_query, flags=re.IGNORECASE)

    # Loại bỏ khoảng trắng ở đầu và cuối câu SQL
    sql_query = sql_query.strip()

    # Thêm dấu cách sau dấu phẩy nếu sau dấu phẩy không có khoảng trắng
    sql_query = re.sub(r',(?!\s)', ', ', sql_query)

    return sql_query

In [30]:
target = df["sql"].apply(preprocess_sql)
predictions = df['sql_query'].apply(preprocess_sql)

roc auc score

In [31]:
# Function to calculate ROC AUC for multiclass
def calculate_multiclass_roc_auc(target, predictions):
    '''
    This method returns the AUC Score for multiclass classification
    '''
    # Check if the predictions are probabilities or labels
    if len(predictions.shape) == 1 or predictions.shape[1] == 1:
        predictions_binarized = lb.transform(predictions)
    else:
        predictions_binarized = predictions
    
    return metrics.roc_auc_score(target, predictions_binarized, multi_class='ovr')


In [32]:
target_binarized = lb.fit_transform(target)
predictions_binarized = lb.transform(predictions)

In [33]:
# Calculate the ROC AUC score for multiclass
roc_auc_score = calculate_multiclass_roc_auc(target_binarized, predictions_binarized)

print(f"ROC AUC Score: {roc_auc_score}")

ROC AUC Score: 0.8642338291248703


Exact Match

In [34]:
exact_match_metric = load_metric("exact_match")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [35]:
results_exact_match = exact_match_metric.compute(predictions=predictions, references=target)
print(results_exact_match)

{'exact_match': 75.0}


In [36]:
predictions

0              SELECT count(*) FROM head WHERE age > 56
1     SELECT name, born_state, age FROM head ORDER B...
2     SELECT creation, name, budget_in_year FROM dep...
3     SELECT max(budget_in_billion), min(budget_in_b...
4     SELECT avg(num_employees) FROM department WHER...
                            ...                        
95    SELECT course_name FROM courses ORDER BY cours...
96    SELECT course_name FROM courses ORDER BY cours...
97    SELECT first_name FROM people ORDER BY first_name
98    SELECT first_name FROM people ORDER BY first_name
99    SELECT student_id FROM student_course_registra...
Name: sql_query, Length: 100, dtype: object

In [37]:
# Hàm tính toán độ chính xác thực thi
def calculate_execution_accuracy(predictions, references):
    num_correct = 0
    total = len(predictions)
    print(total)
    for predicted_result, reference_result in zip(predictions, references):
        print(predicted_result)
        print(reference_result)
        # So sánh kết quả dự đoán và tham chiếu
        if predicted_result == reference_result:
            num_correct += 1

    # Tính toán độ chính xác
    execution_accuracy = num_correct / total

    return execution_accuracy


# Tính toán độ chính xác thực thi
execution_accuracy = calculate_execution_accuracy(predictions, target)
print("Execution Accuracy:", execution_accuracy)


100
SELECT count(*) FROM head WHERE age > 56
SELECT count(*) FROM head WHERE age > 56
SELECT name, born_state, age FROM head ORDER BY age
SELECT name, born_state, age FROM head ORDER BY age
SELECT creation, name, budget_in_year FROM department
SELECT creation, name, budget_in_billions FROM department
SELECT max(budget_in_billion), min(budget_in_billion) FROM department
SELECT max(budget_in_billions), min(budget_in_billions) FROM department
SELECT avg(num_employees) FROM department WHERE rank BETWEEN 10 AND 15
SELECT avg(num_employees) FROM department WHERE ranking BETWEEN 10 AND 15
SELECT name FROM head WHERE born_state!= 'California'
SELECT name FROM head WHERE born_state != 'California'
SELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id = T2.department_id JOIN head AS T3 ON T2.head_id = T3.head_id WHERE T3.born_state = 'Alabama'
SELECT DISTINCT T1.creation FROM department AS T1 JOIN management AS T2 ON T1.department_id = T2.department_id JOIN 

Try with new question

In [21]:
# Thử nghiệm với câu hỏi mới
new_question = "What team has more than 49 laps and a grid of 8?"
input_ids = tokenizer.encode(new_question, return_tensors='pt')
outputs = model.generate(input_ids=input_ids, max_length=100, num_beams=5, early_stopping=True)
sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(f"Question: {new_question}")
print(f"Generated SQL query: {sql_query}")

Question: What team has more than 49 laps and a grid of 8?
Generated SQL query: SELECT Team FROM laps WHERE laps > 49 INTERSECT SELECT Team FROM grid WHERE grid = 8
