In [None]:
%load_ext autoreload
%autoreload 2

import gdown

url = 'https://drive.google.com/u/1/uc?export=download&confirm=k3T5&id=1_AckYkinAnhqmRQtGsQgUKAnTHxxX5J0'
output = 'spider.zip'
gdown.download(url, output, quiet=False)

In [None]:
!unzip spider.zip

In [None]:
import json

def _read_json(path):
    """Parse the JSON file at ``path`` and return the decoded object.

    The file is opened explicitly as UTF-8 so the result does not depend on
    the platform's default text encoding.
    """
    with open(path, encoding='utf-8') as fp:
        return json.load(fp)


# Schema records from tables.json, indexed by their 'db_id' field.
tables = {t['db_id']: t for t in _read_json('spider/tables.json')}

# Training examples: the two training files concatenated in this order.
train_data = _read_json('spider/train_spider.json') + _read_json('spider/train_others.json')

# Examples read from dev.json (kept under the name `test_data`).
test_data = _read_json('spider/dev.json')


In [None]:
from collections import defaultdict

def get_processed_data(raw_data, schemas=None):
    """Serialize examples into model input/target text plus table-column relations.

    For every example the input text has the form
    ``"<question> | <db_id> | <table> : <col>, <col> | <table> : <col> ..."``
    and the target text is ``"<db_id> | <query>"``; both are lowercased.

    Args:
        raw_data: Iterable of dicts with the keys ``'question'``, ``'db_id'``
            and ``'query'``.
        schemas: Optional mapping ``db_id -> schema``; each schema provides
            ``'table_names_original'`` (list of table names) and
            ``'column_names_original'`` (list of ``(table_index, column_name)``
            pairs). Defaults to the module-level ``tables`` mapping.

    Returns:
        A tuple ``(X, y, X_word_relations)`` of three parallel lists:
        input texts, target texts, and one ``defaultdict(dict)`` per example
        mapping ``span -> {other_span: relation_name}``. A span is a
        ``(start, end)`` pair of character offsets into the matching entry
        of ``X`` (end exclusive).
    """
    if schemas is None:
        # Fall back to the notebook-level schema index (resolved at call time).
        schemas = tables

    X, y, X_word_relations = [], [], []
    for d in raw_data:
        schema = schemas[d['db_id']]
        table_names = schema['table_names_original']

        # Every piece is lowercased *before* it is measured: lower() can change
        # the length of a string (e.g. '\u0130' becomes two code points), so
        # lowercasing the finished text afterwards would shift the spans.
        input_text = (d['question'] + f" | {d['db_id']}").lower()
        word_relations = defaultdict(dict)

        table_span, table_i = None, None
        for i, c_name in schema['column_names_original']:
            # Entries with a negative table index belong to no table; skip them.
            if i < 0:
                continue
            c_name = c_name.lower()

            if table_i != i:
                # First column of a new table: open a "| <table> : " section.
                table_i = i
                t_name = table_names[i].lower()
                input_text += ' | '
                table_span = (len(input_text), len(input_text) + len(t_name))
                input_text += f'{t_name} : '
            else:
                # Further column of the current table.
                input_text += ', '

            c_span = (len(input_text), len(input_text) + len(c_name))
            input_text += c_name

            # Link table and column in both directions.
            word_relations[table_span][c_span] = 'table_column_link'
            word_relations[c_span][table_span] = 'column_table_link'

        X.append(input_text)
        y.append((d['db_id'] + ' | ' + d['query']).lower())
        X_word_relations.append(word_relations)

    return X, y, X_word_relations

# Serialize each split; every call returns (input texts, target texts,
# per-example relation maps) as three parallel lists.
train_X, train_y, train_X_word_relations = get_processed_data(raw_data=train_data)
test_X, test_y, test_X_word_relations = get_processed_data(raw_data=test_data)

In [None]:
import ratransformers
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Wrap the 'tscholak/1zha5ono' checkpoint with RATransformer, registering the
# two relation names that the preprocessing step writes into its relation maps.
# NOTE(review): alias_model_name='t5' presumably tells the library which
# architecture family the checkpoint belongs to - confirm against its docs.
ratransformer = ratransformers.RATransformer(
    'tscholak/1zha5ono',
    alias_model_name='t5',
    relation_kinds=['table_column_link', 'column_table_link'],
)

# Expose the wrapped model and its tokenizer under short notebook-level names.
model, tokenizer = ratransformer.model, ratransformer.tokenizer

In [None]:
import torch

class Text2SQLDataset(torch.utils.data.Dataset):
    """Map-style dataset that tokenizes one (input text, target text) pair per item.

    Args:
        X: Input texts, one per example.
        y: Target texts, aligned with ``X`` by index.
        X_word_relations: Per-example relation maps, aligned with ``X``; each
            one is handed to the tokenizer via its ``input_relations`` argument.
        tokenizer: Callable accepting ``padding``, ``input_relations`` and
            ``return_tensors`` keyword arguments and returning a mapping with
            ``"input_ids"`` and ``"attention_mask"`` tensors (plus
            ``"input_relations"`` for the source text), each with a leading
            batch dimension of size 1.
    """

    # Value written over padding positions of the target ids.
    # NOTE(review): -100 is the default ignore_index of PyTorch's cross-entropy
    # loss; presumably the model's loss skips these positions - confirm.
    IGNORE_INDEX = -100

    def __init__(self, X, y, X_word_relations, tokenizer):
        self.X = X
        self.y = y
        self.X_word_relations = X_word_relations
        self.tokenizer = tokenizer

        # Take the padding id from the tokenizer instead of hard-coding it;
        # fall back to 0 (the previously hard-coded value) when it is not exposed.
        pad_id = getattr(tokenizer, 'pad_token_id', None)
        self.pad_token_id = 0 if pad_id is None else pad_id

    def __getitem__(self, index: int) -> dict:
        """Tokenize example ``index`` and return its tensors without the batch axis.

        Returns:
            Dict with the keys ``"input_ids"``, ``"attention_mask"``,
            ``"label"``, ``"decoder_attention_mask"`` and ``"input_relations"``.
            ``"label"`` holds the target ids with padding replaced by
            ``IGNORE_INDEX``.
        """
        source = self.tokenizer(
            self.X[index],
            padding='max_length',
            input_relations=self.X_word_relations[index],
            return_tensors="pt",
        )
        target = self.tokenizer(
            self.y[index],
            padding='max_length',
            input_relations=None,
            return_tensors="pt",
        )

        # squeeze(0) removes only the leading batch axis. A bare squeeze()
        # would also drop any other size-1 axis (e.g. a length-1 sequence).
        source_ids = source["input_ids"].squeeze(0)
        source_input_relations = source["input_relations"].squeeze(0)
        target_ids = target["input_ids"].squeeze(0)
        target_ids[target_ids == self.pad_token_id] = self.IGNORE_INDEX

        src_mask = source["attention_mask"].squeeze(0)
        target_mask = target["attention_mask"].squeeze(0)

        return {"input_ids": source_ids,
                "attention_mask": src_mask,
                "label": target_ids,
                "decoder_attention_mask": target_mask,
                'input_relations': source_input_relations
               }

    def __len__(self):
        """Return the number of examples."""
        return len(self.X)

# One dataset per split, both sharing the same tokenizer.
train_d = Text2SQLDataset(
    X=train_X,
    y=train_y,
    X_word_relations=train_X_word_relations,
    tokenizer=tokenizer,
)
val_d = Text2SQLDataset(
    X=test_X,
    y=test_y,
    X_word_relations=test_X_word_relations,
    tokenizer=tokenizer,
)

In [None]:
# Checkpoints go to ./checkpoints; one example per device per step, with
# gradients accumulated over two steps.
training_args = Seq2SeqTrainingArguments(
    output_dir='checkpoints',
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
)

# NOTE(review): no evaluation schedule is configured above, so val_d is
# presumably only consumed by an explicit trainer.evaluate() call - confirm.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_d,
    eval_dataset=val_d,
)

# Start fine-tuning.
trainer.train()

In [None]:
# Persist the trained model; the '/' in the path makes this a nested
# directory ('ra-tscholak' containing '1zha5ono') relative to the working dir.
trainer.save_model(output_dir='ra-tscholak/1zha5ono')