In [2]:
import json
import random
import warnings

import nltk
from datasets import load_dataset
from torch.utils.data import Dataset
from transformers import AutoModelForSeq2SeqLM, AutoModelWithLMHead, AutoTokenizer
warnings.filterwarnings("ignore")

In [3]:
# Load the tokenizer published with the pretrained "t5-base" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("t5-base")

In [4]:
# Load the Spider examples through the Hugging Face `datasets` library rather
# than from a local copy; later cells read its 'train' and 'validation' splits.
spider_dataset = load_dataset('spider')

Found cached dataset spider (/Users/hsahu/.cache/huggingface/datasets/spider/spider/1.0.0/4e5143d825a3895451569c8b9b55432b91a4bc2d04d390376c950837f4680daa)
100%|██████████| 2/2 [00:00<00:00, 63.29it/s]


In [5]:
class SpiderDataset(Dataset):
    """Torch dataset that turns Spider examples into tokenized T5 inputs and labels.

    Each item pairs a natural-language question (followed by a textual
    rendering of the database schema) with the SQL query that answers it.

    Args:
        data: Indexable collection of examples; every example must provide the
            keys 'db_id', 'question' and 'query'.
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            truncation=True, padding='max_length', return_tensors='pt')``.
            Its ``pad_token_id`` attribute (if present) is used to mask label
            padding.
        max_length: Length that both inputs and labels are padded/truncated to.

    Note:
        ``get_schema`` reads the module-level ``tables_data`` list at call
        time, so that name must be defined before items are accessed.
    """

    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of examples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return model-ready tensors for example ``idx``.

        Returns:
            dict with 'input_ids', 'attention_mask' and 'labels' as flattened
            tensors of length ``max_length``, plus 'data_item', the untouched
            source example. Padded label positions are set to -100 so they do
            not contribute to the loss.
        """
        # Read the example once; the original indexed self.data four times.
        item = self.data[idx]
        schema = self.get_schema(item['db_id'])

        # NOTE(review): T5 tokenizers append the EOS token themselves, so the
        # literal " </s>" may be redundant. It is kept unchanged because
        # removing it alters the model input -- confirm before dropping it.
        # Inputs longer than max_length are truncated, which cuts the schema tail.
        input_text = f"translate English to SQL: {item['question']} <schema> {schema} </s>"
        target_text = item['query']
        encoding = self.tokenizer(input_text, max_length=self.max_length, truncation=True, padding='max_length', return_tensors='pt')
        target = self.tokenizer(target_text, max_length=self.max_length, truncation=True, padding='max_length', return_tensors='pt')

        labels = target['input_ids'].flatten()
        # The seq2seq loss ignores label positions equal to -100. Without this
        # the model is also trained to emit the padding that fills max_length.
        pad_token_id = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_token_id is not None:
            labels[labels == pad_token_id] = -100

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': labels,
            'data_item': item,
        }

    def get_schema(self, db_id):
        """Render the schema of database ``db_id`` as one line of text.

        The result has the form
        ``"Table: <name> Columns: <c1>, <c2> Table: <name> Columns: ..."``,
        where each table lists only the columns that belong to it.

        Raises:
            KeyError: if ``tables_data`` has no entry for ``db_id``.
        """
        schema_data = next((entry for entry in tables_data if entry['db_id'] == db_id), None)
        if schema_data is None:
            # A bare StopIteration from next() is easy to misread as the end
            # of iteration, so fail with an explicit error instead.
            raise KeyError(f"No schema found for db_id {db_id!r} in tables_data")

        # Each 'column_names' entry is a (table index, column name) pair.
        # Group by table index so a table only receives its own columns; the
        # original condition was always true and listed every column under
        # every table. Entries whose index matches no table (Spider's
        # tables.json uses -1 for the "*" wildcard) are left out.
        columns_by_table = {}
        for col_table_idx, col_name in schema_data['column_names']:
            columns_by_table.setdefault(col_table_idx, []).append(col_name)

        parts = []
        for table_idx, table_name in enumerate(schema_data['table_names']):
            columns = ", ".join(columns_by_table.get(table_idx, []))
            parts.append(f"Table: {table_name} Columns: {columns}")
        return " ".join(parts)

# Load the schema descriptions (tables.json) that SpiderDataset.get_schema reads
# through the module-level `tables_data` name.
# JSON text is UTF-8, so the encoding is given explicitly instead of relying on
# the platform's locale default.
with open('./spider_dataset/tables.json', 'r', encoding='utf-8') as f:
    tables_data = json.load(f)

# Both splits pad/truncate inputs and labels to the same fixed length.
MAX_LENGTH = 128
train_dataset = SpiderDataset(spider_dataset['train'], tokenizer, max_length=MAX_LENGTH)
val_dataset = SpiderDataset(spider_dataset['validation'], tokenizer, max_length=MAX_LENGTH)

In [6]:
from transformers import TrainingArguments

# Hyper-parameters and bookkeeping options for the fine-tuning run.
training_args = TrainingArguments(
    # Output locations.
    output_dir='./results',
    logging_dir='./logs',
    # Run length and per-device batch sizes.
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    # Optimisation settings.
    weight_decay=0.01,
    warmup_steps=500,
    # Log every 10 steps; evaluate and checkpoint once per epoch.
    logging_steps=10,
    evaluation_strategy='epoch',
    save_strategy='epoch',
)

In [7]:
# AutoModelWithLMHead is deprecated in transformers. AutoModelForSeq2SeqLM is the
# supported auto class for encoder-decoder checkpoints and resolves "t5-base" to
# the same T5ForConditionalGeneration architecture.
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

In [9]:
from transformers import Trainer

# Tie the model, its training configuration and both dataset splits together.
trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)


In [10]:
# Fine-tune the model with the Trainer configured in the previous cell.
trainer.train()

***** Running training *****
  Num examples = 7000
  Num Epochs = 3
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 1314
  Number of trainable parameters = 222903552
  0%|          | 0/1314 [00:00<?, ?it/s]The following columns in the training set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: data_item. If data_item are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
