In [1]:
import json
import random
import warnings

import nltk
import sentencepiece
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoModelForSeq2SeqLM, AutoModelWithLMHead, AutoTokenizer

warnings.filterwarnings("ignore")

In [3]:
# Name of the pretrained checkpoint handed to AutoTokenizer.from_pretrained below.
model_name = "t5-base"
# Shared tokenizer: later cells use it to encode dataset items and to decode generated SQL.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
# The 'train' and 'validation' splits of this object are wrapped by SpiderDataset in a later cell.
spider_dataset = load_dataset('spider')  # Uses the Spider dataset from the Hugging Face datasets library (not a local copy)

Found cached dataset spider (/Users/hsahu/.cache/huggingface/datasets/spider/spider/1.0.0/4e5143d825a3895451569c8b9b55432b91a4bc2d04d390376c950837f4680daa)
100%|██████████| 2/2 [00:00<00:00, 124.04it/s]


In [8]:
class SpiderDataset(Dataset):
    """Torch dataset that turns Spider examples into tokenized model inputs and labels.

    For each example, the question plus a textual rendering of the example's
    database schema becomes the model input, and the gold SQL query becomes
    the labels.

    Args:
        data: Indexable collection of examples; each example is a mapping with
            the keys ``'db_id'``, ``'question'`` and ``'query'``.
        tokenizer: Callable tokenizer applied to both the input and the target text.
        max_length: Length that inputs and targets are truncated/padded to.
        tables: Optional list of schema records, each with the keys ``'db_id'``,
            ``'table_names'`` and ``'column_names'``. When omitted, the
            module-level ``tables_data`` is looked up at call time, exactly as
            the class did before this parameter existed.
    """

    def __init__(self, data, tokenizer, max_length, tables=None):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.tables = tables

    def __len__(self):
        """Return the number of examples in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, idx):
        """Build the tokenized input/label tensors for the example at ``idx``.

        Returns:
            dict with ``'input_ids'``, ``'attention_mask'`` and ``'labels'``
            (each a flattened tensor) plus ``'data_item'``, the raw example.
        """
        # Fetch the example once instead of re-indexing the collection per field.
        example = self.data[idx]
        schema = self.get_schema(example['db_id'])

        # NOTE(review): the literal " </s>" is kept from the original prompt format;
        # confirm whether the tokenizer already appends its own end-of-sequence token.
        input_text = f"translate English to SQL: {example['question']} <schema> {schema} </s>"
        target_text = example['query']
        encoding = self.tokenizer(input_text, max_length=self.max_length, truncation=True, padding='max_length', return_tensors='pt')
        target = self.tokenizer(target_text, max_length=self.max_length, truncation=True, padding='max_length', return_tensors='pt')

        labels = target['input_ids'].flatten()
        # Padding positions in the labels are set to -100 so the loss skips them;
        # otherwise the model is also trained to predict the padding.
        pad_token_id = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_token_id is not None:
            labels = labels.masked_fill(labels == pad_token_id, -100)

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': labels,
            # Raw example, read back by the evaluation cell.
            'data_item': example,
        }

    def get_schema(self, db_id):
        """Render the schema of database ``db_id`` as a single line of text.

        The result has the form
        ``"Table: <table> Columns: <col>, <col> Table: <table> Columns: ..."``,
        where every table lists only its own columns.

        Raises:
            KeyError: if no schema record has the given ``db_id``. A bare
                ``StopIteration`` escaping from ``__getitem__`` could be mistaken
                for end-of-iteration by a consuming iterator, so the failure is
                made explicit.
        """
        tables = getattr(self, 'tables', None)
        if tables is None:
            tables = tables_data  # module-level fallback, resolved at call time

        schema_data = next((item for item in tables if item['db_id'] == db_id), None)
        if schema_data is None:
            raise KeyError(f"No schema found for db_id {db_id!r}")

        # Each entry of 'column_names' is a (table index, column name) pair; group the
        # names by that index. The previous filter compared a table name with itself,
        # so every table was rendered with all columns of the database.
        columns_by_table = {}
        for col_table_idx, col_name in schema_data['column_names']:
            columns_by_table.setdefault(col_table_idx, []).append(col_name)

        return " ".join(
            f"Table: {table_name} Columns: {', '.join(columns_by_table.get(table_idx, []))}"
            for table_idx, table_name in enumerate(schema_data['table_names'])
        )

# Load the Spider schema descriptions (tables.json) that SpiderDataset.get_schema consults.
# JSON text is UTF-8, so the encoding is named explicitly instead of relying on
# the platform default.
with open('./spider_dataset/tables.json', 'r', encoding='utf-8') as f:
    tables_data = json.load(f)

# Wrap both splits; inputs and targets are truncated/padded to 128 tokens.
train_dataset = SpiderDataset(spider_dataset['train'], tokenizer, max_length=128)
val_dataset = SpiderDataset(spider_dataset['validation'], tokenizer, max_length=128)

In [None]:
from transformers import TrainingArguments

# Hyper-parameters for the Trainer constructed in a later cell.
# NOTE(review): learning_rate=1e-6 is well below the 1e-4 / 3e-4 range that the Hugging Face
# T5 documentation suggests for fine-tuning with AdamW — confirm this value is intentional.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to `eval_strategy`
# — confirm against the installed version.
training_args = TrainingArguments(
    output_dir='./output',
    num_train_epochs=5,
    learning_rate=1e-6,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    # Saving and evaluation both use the 'epoch' strategy.
    save_strategy='epoch',
    evaluation_strategy='epoch',
)

In [None]:
# Load the pretrained weights for the same checkpoint the tokenizer was built from
# (`model_name`), so the two cannot drift apart.
# AutoModelForSeq2SeqLM is the non-deprecated auto class for encoder-decoder
# checkpoints such as T5; AutoModelWithLMHead is deprecated.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [None]:
from transformers import Trainer

# Trainer that fine-tunes `model` on train_dataset and evaluates on val_dataset,
# using the hyper-parameters in `training_args`. No data_collator is passed here.
# NOTE(review): SpiderDataset items also carry a non-tensor 'data_item' entry — confirm the
# installed transformers version drops it before batches are collated.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset 
)


In [None]:
# Start fine-tuning; `model` is updated in place and is the object the inference helpers below call.
trainer.train()

In [None]:
def get_sql(query, max_length=128):
    """Translate an English question into SQL without any schema context.

    Uses the notebook-level ``tokenizer`` and ``model``.

    Args:
        query: Natural-language question.
        max_length: Upper bound on the length of the generated sequence. Passed
            explicitly because the library's default generation length is short
            enough to cut most SQL queries off.

    Returns:
        The decoded SQL text with special tokens removed and surrounding
        whitespace stripped (the same cleaning as ``get_sql_with_schema``).
    """
    # NOTE(review): the literal " </s>" is kept from the original prompt format;
    # confirm whether the tokenizer already appends its own end-of-sequence token.
    input_text = f"translate English to SQL: {query} </s>"

    # Move the encoded inputs to the model's device; after training the model
    # may live on an accelerator while fresh encodings are created on the CPU.
    features = tokenizer([input_text], return_tensors='pt').to(model.device)

    output = model.generate(input_ids=features['input_ids'],
                            attention_mask=features['attention_mask'],
                            max_length=max_length)

    # skip_special_tokens drops markers such as <pad> and </s> from the text.
    return tokenizer.decode(output[0], skip_special_tokens=True).strip()

def get_sql_with_schema(query, schema, max_length=128):
    """Translate an English question into SQL, giving the model the database schema.

    Uses the notebook-level ``tokenizer`` and ``model``. The prompt has the same
    layout that ``SpiderDataset.__getitem__`` uses for training inputs: the
    question, the ``<schema>`` separator, then the schema text.

    Args:
        query: Natural-language question.
        schema: Textual schema description (e.g. from ``SpiderDataset.get_schema``).
        max_length: Upper bound on the length of the generated sequence. Passed
            explicitly because the library's default generation length is short
            enough to cut most SQL queries off.

    Returns:
        The decoded SQL text with special tokens removed and surrounding
        whitespace stripped.
    """
    # Concatenate the schema with the question, separated by the `<schema>` marker.
    input_text = f"translate English to SQL: {query} <schema> {schema} </s>"

    # Move the encoded inputs to the model's device; after training the model
    # may live on an accelerator while fresh encodings are created on the CPU.
    features = tokenizer([input_text], return_tensors='pt').to(model.device)

    output = model.generate(input_ids=features['input_ids'],
                            attention_mask=features['attention_mask'],
                            max_length=max_length)

    # skip_special_tokens removes every special marker, not only <pad>, so a
    # trailing </s> can no longer end up in the returned prediction.
    return tokenizer.decode(output[0], skip_special_tokens=True).strip()

In [None]:
# Write a small random sample of validation examples to gold.txt / pred.txt in the
# line-per-query layout read by the evaluation script run in a later cell.
num_queries = 5
sample_seed = 42  # fixed seed so the same validation examples are drawn on every run

# A private RNG keeps the draw reproducible without touching the global `random` state.
sample_indices = random.Random(sample_seed).sample(range(len(val_dataset)), num_queries)

# Context managers close both files even if generation raises part-way through.
with open('gold.txt', 'w', encoding='utf-8') as gold_file, \
        open('pred.txt', 'w', encoding='utf-8') as pred_file:
    for position, idx in enumerate(sample_indices, start=1):
        # Read the raw example directly; indexing val_dataset itself would also
        # tokenize it, and only the raw fields are needed here.
        data_item = val_dataset.data[idx]
        print(f'{position}/{num_queries}')
        print(f"Text: {data_item['question']}")

        # Get schema information
        db_id = data_item['db_id']
        schema = val_dataset.get_schema(db_id)

        pred = get_sql_with_schema(data_item['question'], schema)
        gold = data_item['query']

        # Gold lines carry the query and its database id, separated by a tab.
        gold_file.write(gold + '\t' + db_id + '\n')
        pred_file.write(pred + '\n')

        print(f"Pred SQL: {pred}")
        print(f"True SQL: {gold}\n")


In [None]:
# Clone the evaluation repository whose evaluation.py is run in the next cell:
# https://github.com/taoyds/test-suite-sql-eval.git

!git clone https://github.com/taoyds/test-suite-sql-eval.git

In [None]:
# Run the cloned evaluation script on gold.txt and pred.txt (written by the sampling cell),
# pointing it at the local Spider databases and tables.json.
!python test-suite-sql-eval/evaluation.py --gold gold.txt --pred pred.txt --db spider_dataset/database --table spider_dataset/tables.json --etype all

In [None]:
# import argparse
# import pytorch_lightning as pl
# import torch
# from torch.utils.data import DataLoader

# from transformers import (
#     AdamW,
#     AutoConfig,
#     AutoModelWithLMHead,
#     AutoTokenizer,
#     get_linear_schedule_with_warmup,
# )


# class customT5(pl.LightningModule):
#     def __init__(self, **config_kwargs):

#         super().__init__()
#         self.config = AutoConfig.from_pretrained("t5-base")
#         self.tokenizer = AutoTokenizer.from_pretrained("t5-base")
#         self.model = AutoModelWithLMHead.from_pretrained("t5-base", config=self.config)

#         self.dataset_kwargs: dict = dict(
#             data_dir=self.hparams.data_dir,
#             max_source_length=self.hparams.max_source_length,
#             max_target_length=self.hparams.max_target_length,
#         )

#     def forward(self, input_ids,  attention_mask=None, decoder_input_ids=None, lm_labels=None):

#         return self.model( input_ids,
#             attention_mask=attention_mask,
#             decoder_input_ids=decoder_input_ids,
#             lm_labels=lm_labels,
#         )

In [None]:
# def generic_train(model: customT5, output_dir):

#     checkpoint_callback = pl.callbacks.ModelCheckpoint(
#         filepath=output_dir, prefix="checkpoint", monitor="val_loss", mode="min", save_top_k=5
#     )

#     train_params = dict(
#         accumulate_grad_batches=args.gradient_accumulation_steps,
#         max_epochs=10,
#         early_stop_callback=False,
#         checkpoint_callback=checkpoint_callback
#     )

#     trainer = pl.Trainer(**train_params)

#     trainer.fit(model)

#     return trainer

In [None]:

# output_dir = "model_output"
# model = customT5()
# trainer = generic_train(model)

# model.model.save_pretrained(output_dir)
# model.tokenizer.save_pretrained(output_dir)