In [1]:
!pip install gdown



In [2]:
import gdown

import os

# Spider archive hosted on Google Drive. The `confirm=` query parameter is
# presumably Drive's large-file confirmation token — refresh it if downloads fail.
spider_url = 'https://drive.google.com/u/1/uc?export=download&confirm=k3T5&id=1_AckYkinAnhqmRQtGsQgUKAnTHxxX5J0'
output = 'spider.zip'

# Skip the ~100 MB download when the archive is already present so that
# Restart & Run All does not re-fetch it. Delete the file to force a fresh
# download (e.g. after an interrupted transfer).
if os.path.exists(output):
    print(f"{output} already exists; skipping download.")
else:
    gdown.download(spider_url, output, quiet=False)
output

Downloading...
From: https://drive.google.com/u/1/uc?export=download&confirm=k3T5&id=1_AckYkinAnhqmRQtGsQgUKAnTHxxX5J0
To: /home/ubuntu/RATransformers/notebooks/spider.zip
100%|██████████| 99.7M/99.7M [00:01<00:00, 94.9MB/s]


'spider.zip'

In [3]:
!unzip -o spider.zip

Archive:  spider.zip
  inflating: spider/database/customer_deliveries/schema.sql  
  inflating: spider/database/customer_deliveries/customer_deliveries.sqlite  
  inflating: spider/database/allergy_1/schema.sql  
  inflating: spider/database/allergy_1/allergy_1.sqlite  
  inflating: spider/database/company_office/schema.sql  
  inflating: spider/database/company_office/company_office.sqlite  
  inflating: spider/database/device/schema.sql  
  inflating: spider/database/device/device.sqlite  
  inflating: spider/database/phone_1/schema.sql  
  inflating: spider/database/phone_1/phone_1.sqlite  
  inflating: spider/database/cre_Doc_Control_Systems/schema.sql  
  inflating: spider/database/cre_Doc_Control_Systems/cre_Doc_Control_Systems.sqlite  
  inflating: spider/database/imdb/schema.sql  
  inflating: spider/database/imdb/imdb.sqlite  
  inflating: spider/database/decoration_competition/decoration_competition.sqlite  
  inflating: spider/database/decoration_competition/schema.sql  
  i

  inflating: spider/database/flight_2/flight_2.sql  
  inflating: spider/database/flight_2/flight_2.json  
  inflating: spider/database/flight_2/flight_2.sqlite  
  inflating: spider/database/flight_2/q.txt  
 extracting: spider/database/flight_2/link.txt  
  inflating: spider/database/student_1/data_csv/README.STUDENTS.TXT  
  inflating: spider/database/student_1/data_csv/list.csv  
  inflating: spider/database/student_1/data_csv/teachers.csv  
  inflating: spider/database/student_1/annotation.json  
  inflating: spider/database/student_1/student_1.sql  
  inflating: spider/database/student_1/student_1.sqlite  
  inflating: spider/database/student_1/q.txt  
 extracting: spider/database/student_1/link.txt  
  inflating: spider/database/party_host/schema.sql  
  inflating: spider/database/party_host/party_host.sqlite  
  inflating: spider/database/epinions_1/epinions_1.sqlite  
  inflating: spider/database/wedding/schema.sql  
  inflating: spider/database/wedding/wedding.sqlite  
  infl

In [4]:
import json

# JSON exchanged between systems is UTF-8 (RFC 8259); pass the encoding
# explicitly instead of relying on the platform's locale default.

# Schema entries indexed by database id, for per-example lookup.
with open('spider/tables.json', encoding='utf-8') as fp:
    tables = {t['db_id']: t for t in json.load(fp)}

# The training examples are split across two files; concatenate them.
with open('spider/train_spider.json', encoding='utf-8') as fp:
    train_data = json.load(fp)

with open('spider/train_others.json', encoding='utf-8') as fp:
    train_data += json.load(fp)

# The dev split is used as the held-out evaluation set in this notebook.
with open('spider/dev.json', encoding='utf-8') as fp:
    test_data = json.load(fp)


In [5]:
from collections import defaultdict

def get_processed_data(raw_data, schemas=None, max_words=200):
    """Serialize Spider examples into model inputs, targets and span relations.

    Each input has the form
    ``"<question> | <db_id> | <table> : <col>, <col> | <table> : ..."``,
    lower-cased. For every table/column pair two relations between
    character spans of that string are recorded::

        relations[table_span][column_span] = 'table_column_link'
        relations[column_span][table_span] = 'column_table_link'

    where each span is a ``(start, end)`` offset into the *returned* string.

    Args:
        raw_data: example dicts with ``'question'``, ``'db_id'`` and
            ``'query'`` keys.
        schemas: mapping ``db_id -> schema entry`` with
            ``'table_names_original'`` and ``'column_names_original'`` keys.
            Defaults to the notebook-level ``tables`` dict.
        max_words: examples whose serialized input has more than this many
            whitespace-separated words are dropped (default 200).

    Returns:
        ``(X, y, X_word_relations, n_skip)``: lower-cased inputs, lower-cased
        ``"<db_id> | <query>"`` targets, one relation dict per kept example,
        and the number of examples dropped for being too long.
    """
    if schemas is None:
        schemas = tables
    X, y, X_word_relations = [], [], []
    n_skip = 0
    for d in raw_data:
        schema = schemas[d['db_id']]
        table_names = schema['table_names_original']

        # Lower-case every piece *before* measuring it, so the spans index the
        # exact string handed to the tokenizer. Lower-casing the finished
        # string instead can change its length (e.g. 'İ'.lower() is two code
        # points), shifting every later span.
        input_text = (d['question'] + f" | {d['db_id']}").lower()
        word_relations = defaultdict(dict)

        table_span, current_table = None, None
        for table_idx, c_name in schema['column_names_original']:
            if table_idx < 0:
                # Not attached to any table (Spider uses -1 for the '*' column).
                continue
            c_name = c_name.lower()
            if table_idx != current_table:
                # First column of a new table: emit the " | <table> : " header.
                current_table = table_idx
                t_name = table_names[table_idx].lower()
                t_start = len(input_text) + len(' | ')
                table_span = (t_start, t_start + len(t_name))
                input_text += f" | {t_name} : "
            else:
                input_text += ', '
            c_span = (len(input_text), len(input_text) + len(c_name))
            input_text += c_name

            word_relations[table_span][c_span] = 'table_column_link'
            word_relations[c_span][table_span] = 'column_table_link'

        if len(input_text.split()) > max_words:
            # Skipped sample with too long input
            n_skip += 1
            continue

        X.append(input_text)
        y.append((d['db_id'] + ' | ' + d['query']).lower())
        X_word_relations.append(word_relations)

    return X, y, X_word_relations, n_skip

# Serialize both splits. n_skip is rebound per split, so each print reports
# only that split's count of examples dropped by the length filter.
train_X, train_y, train_X_word_relations, n_skip = get_processed_data(raw_data=train_data)
print("Train:", len(train_X), f" Skipped {n_skip} samples with too long input.")

test_X, test_y, test_X_word_relations, n_skip = get_processed_data(raw_data=test_data)
print("Test:", len(test_X), f" Skipped {n_skip} samples with too long input.")



Train: 8577  Skipped 82 samples with too long input.
Test: 1034  Skipped 0 samples with too long input.


In [6]:
import ratransformers

# Load the pretrained T5 checkpoint through ratransformers. The relation kinds
# must match the labels emitted by get_processed_data.
ratransformer = ratransformers.RATransformer(
    'tscholak/1zha5ono',
    relation_kinds=['table_column_link', 'column_table_link'],
    alias_model_name='t5',
)
model, tokenizer = ratransformer.model, ratransformer.tokenizer

In [7]:
import torch

class Text2SQLDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes (input, target) text pairs on access.

    Args:
        X: serialized model inputs, one string per example.
        y: target strings aligned with ``X``.
        tokenizer: ratransformers tokenizer. It is called with an extra
            ``input_relations`` keyword and is assumed to return
            ``input_ids``, ``attention_mask`` and (for the source)
            ``input_relations`` tensors with a leading batch dimension of 1.
        X_word_relations: optional per-example relation dicts aligned with
            ``X``; when omitted every example uses ``None`` (no relations).
    """

    def __init__(self, X, y, tokenizer, X_word_relations=None):
        self.X = X
        self.y = y
        self.X_word_relations = X_word_relations or [None] * len(X)
        self.tokenizer = tokenizer

    def __getitem__(self, index: int) -> dict:
        """Tokenize example ``index`` into the dict of tensors the trainer consumes."""
        source = self.tokenizer(self.X[index], padding='max_length',
                                input_relations=self.X_word_relations[index],
                                return_tensors="pt")
        target = self.tokenizer(self.y[index], padding='max_length',
                                input_relations=None, return_tensors="pt")

        # Drop only the batch dimension added by return_tensors="pt"; a bare
        # .squeeze() would also collapse a length-1 sequence axis.
        source_ids = source["input_ids"].squeeze(0)
        source_input_relations = source["input_relations"].squeeze(0)
        target_ids = target["input_ids"].squeeze(0)

        # Replace label padding with -100 (the ignore index of the loss). Use
        # the tokenizer's own pad id; 0 (T5's pad id) is the fallback.
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = 0
        target_ids[target_ids == pad_id] = -100

        src_mask = source["attention_mask"].squeeze(0)
        target_mask = target["attention_mask"].squeeze(0)

        return {
            "input_ids": source_ids,
            "attention_mask": src_mask,
            "label": target_ids,
            "decoder_attention_mask": target_mask,
            'input_relations': source_input_relations
        }

    def __len__(self):
        return len(self.X)

# Relation-aware datasets: each example carries its table<->column span links.
# These are the ones handed to the trainer.
train_d = Text2SQLDataset(train_X, train_y, tokenizer, X_word_relations=train_X_word_relations)
val_d = Text2SQLDataset(test_X, test_y, tokenizer, X_word_relations=test_X_word_relations)

# Same splits with no relations attached (every example gets None) —
# presumably for a relation-free baseline comparison.
train_d_without_relations = Text2SQLDataset(train_X, train_y, tokenizer)
val_d_without_relations = Text2SQLDataset(test_X, test_y, tokenizer)

In [8]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, EarlyStoppingCallback

# Set training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir='checkpoints',
    # Effective train batch: 1 example x 2 accumulation steps per device.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    per_device_eval_batch_size=4,
    evaluation_strategy='steps',
    max_steps=100000,  # upper bound; EarlyStoppingCallback normally stops sooner
    eval_steps=1000,
    # load_best_model_at_end needs a checkpoint at every evaluation step, so
    # save on the same schedule. The default save_steps=500 is not a round
    # multiple of eval_steps and is rejected by recent transformers versions.
    save_steps=1000,
    seed=42,
    save_total_limit=1,
    predict_with_generate=True,
    load_best_model_at_end=True
)

In [10]:
# Trainer over the relation-aware datasets. EarlyStoppingCallback uses the
# default patience and depends on load_best_model_at_end in training_args.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_d,
    eval_dataset=val_d,
    tokenizer=tokenizer,
    callbacks=[EarlyStoppingCallback()],
)

# Baseline metrics of the pretrained checkpoint, before any fine-tuning.
trainer.evaluate()

{'eval_loss': 1.0085068941116333,
 'eval_runtime': 1324.2969,
 'eval_samples_per_second': 0.781}

In [11]:
# Fine-tune until EarlyStoppingCallback halts training (or max_steps is reached);
# the returned TrainOutput is displayed as the cell result.
trainer.train()

Step,Training Loss,Validation Loss
1000,0.0997,0.325624
2000,0.0605,0.348871


TrainOutput(global_step=2000, training_loss=0.10654444885253907, metrics={'train_runtime': 10601.4604, 'train_samples_per_second': 9.433, 'total_flos': 0, 'epoch': 0.47})

In [12]:
# Get performance after training. With load_best_model_at_end=True the trainer
# holds the best checkpoint by eval loss, not necessarily the last one.
trainer.evaluate()

{'eval_loss': 0.32562369108200073,
 'eval_runtime': 1206.6218,
 'eval_samples_per_second': 0.857,
 'epoch': 0.47}

In [13]:
# Save the fine-tuned model (and the tokenizer the trainer was given) to this
# local directory so it can be reloaded below.
trainer.save_model('ra-tscholak/1zha5ono')

Training is done. Once the model is saved, it can be reloaded with the `ratransformers` package, as shown below.

In [9]:
# Rebuild the relation-aware wrapper from the directory written by
# trainer.save_model, then evaluate it with the same arguments and data;
# eval_loss should reproduce the post-training value.
ratransformer = ratransformers.RATransformer(
    'ra-tscholak/1zha5ono',
    relation_kinds=['table_column_link', 'column_table_link'],
    alias_model_name='t5',
)
model, tokenizer = ratransformer.model, ratransformer.tokenizer

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_d,
    eval_dataset=val_d,
    tokenizer=tokenizer,
    callbacks=[EarlyStoppingCallback()],
)
trainer.evaluate()

{'eval_loss': 0.32562369108200073,
 'eval_runtime': 938.7328,
 'eval_samples_per_second': 1.101}