In [None]:
from datasets import load_dataset
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import functional as F
import sentencepiece as spm
from torch.nn.utils.rnn import pad_sequence

In [None]:
# Load the Spider dataset with Hugging Face `datasets`; its "train" and
# "validation" splits are converted to DataFrames further down.
data = load_dataset(path="spider")

In [3]:
# Preview the first five rows of the raw training split.
pd.DataFrame(data["train"]).head(n=5)

Unnamed: 0,db_id,query,question,query_toks,query_toks_no_value,question_toks
0,department_management,SELECT count(*) FROM head WHERE age > 56,How many heads of the departments are older th...,"[SELECT, count, (, *, ), FROM, head, WHERE, ag...","[select, count, (, *, ), from, head, where, ag...","[How, many, heads, of, the, departments, are, ..."
1,department_management,"SELECT name , born_state , age FROM head ORD...","List the name, born state and age of the heads...","[SELECT, name, ,, born_state, ,, age, FROM, he...","[select, name, ,, born_state, ,, age, from, he...","[List, the, name, ,, born, state, and, age, of..."
2,department_management,"SELECT creation , name , budget_in_billions ...","List the creation year, name and budget of eac...","[SELECT, creation, ,, name, ,, budget_in_billi...","[select, creation, ,, name, ,, budget_in_billi...","[List, the, creation, year, ,, name, and, budg..."
3,department_management,"SELECT max(budget_in_billions) , min(budget_i...",What are the maximum and minimum budget of the...,"[SELECT, max, (, budget_in_billions, ), ,, min...","[select, max, (, budget_in_billions, ), ,, min...","[What, are, the, maximum, and, minimum, budget..."
4,department_management,SELECT avg(num_employees) FROM department WHER...,What is the average number of employees of the...,"[SELECT, avg, (, num_employees, ), FROM, depar...","[select, avg, (, num_employees, ), from, depar...","[What, is, the, average, number, of, employees..."


In [4]:
# Keep only the SQL query and the natural-language question for each split
# (train first, then validation).
train_df, val_df = (
    pd.DataFrame(data[split_name])[["query", "question"]]
    for split_name in ("train", "validation")
)

In [5]:
# Train the SentencePiece tokenizer on the training questions and SQL queries, but only
# when m.model is missing. The notebook then runs on a fresh checkout without retraining
# on every run. vocab_size is also used to size the model's embedding table.
from pathlib import Path

vocab_size = 5000

if not Path("m.model").exists():
    corpus = pd.concat([train_df["question"], train_df["query"]]).str.replace("\n", " ")
    # SentencePiece expects one sentence per line. Write the lines directly:
    # to_csv(sep=' ') would wrap every text containing a space in double quotes and
    # double any '"' inside SQL string literals, polluting the training corpus.
    Path("sptrain.txt").write_text("\n".join(corpus) + "\n", encoding="utf-8")
    spm.SentencePieceTrainer.train(
        input="sptrain.txt", model_prefix="m", vocab_size=vocab_size
    )

In [6]:
# Create a SentencePiece processor and load the trained tokenizer from m.model.
# load() returns True on success, which is what this cell displays.
sp = spm.SentencePieceProcessor()
sp.load("m.model")

True

In [27]:
# Seed torch's RNG so batch sampling (torch.randperm) and weight initialisation are
# reproducible across "Restart & Run All".
SEED = 1337
torch.manual_seed(SEED)

batch_size = 32   # examples sampled per batch, processed in parallel
block_size = 152  # intended context length; NOTE(review): get_batch does not enforce it by default
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
# Display the device chosen in the configuration cell ('cuda' if available, else 'cpu').
device

'cpu'

In [28]:
def encode(s, sample=True):
    """Encode text into a list of SentencePiece token ids.

    With sample=True (the default, matching the previous behaviour) SentencePiece
    subword sampling is enabled (alpha=0.1, nbest_size=-1), so the same string may
    be segmented differently on different calls. Pass sample=False for a
    deterministic segmentation, e.g. when encoding validation data.
    """
    return sp.encode(s, out_type=int, enable_sampling=sample, alpha=0.1, nbest_size=-1)


def decode(l):
    """Decode a list of token ids back into text."""
    return sp.decode(l)

In [29]:
# The model's embedding table is sized by vocab_size, so it should agree with the number
# of pieces in the loaded SentencePiece model (larger ids would index past the table).
assert sp.get_piece_size() == vocab_size, (sp.get_piece_size(), vocab_size)
vocab_size

5000

In [112]:
def get_batch(split, pad_id=0, max_len=None):
    """Sample a random batch of (question, SQL query) token-id pairs.

    Args:
        split: "train" samples from train_df; any other value samples from val_df.
        pad_id: id used to right-pad sequences. Defaults to 0, as before.
            NOTE(review): in a default SentencePiece model id 0 is <unk>, not a
            dedicated pad token, and the loss is not masked -- confirm with
            sp.pad_id() / sp.unk_id().
        max_len: if given, truncate every encoded sequence to at most this many
            ids (e.g. block_size). None keeps full-length sequences.

    Returns:
        (x, y): two LongTensors of shape (B, L) on `device`, where
        B = min(batch_size, rows in the split) and L is the longest sequence among
        BOTH the questions and the queries, so x and y always share one shape.

    Raises:
        ValueError: if the selected split has no rows.
    """
    frame = train_df if split == "train" else val_df
    if len(frame) == 0:
        raise ValueError(f"split {split!r} has no examples to sample from")

    # Random sample without replacement; this also breaks up any ordering in the
    # dataset (the preview suggests consecutive rows share the same db_id).
    indices = torch.randperm(len(frame))[:batch_size].tolist()
    rows = frame.iloc[indices]

    # Encode only the sampled rows. encode() may be stochastic, so keep the original
    # call order: all questions first, then all queries.
    x = [torch.tensor(encode(text)[:max_len], dtype=torch.long) for text in rows["question"]]
    y = [torch.tensor(encode(text)[:max_len], dtype=torch.long) for text in rows["query"]]

    # Pad questions and queries together so both halves end up with the same length.
    padded = pad_sequence(x + y, batch_first=True, padding_value=pad_id)
    x, y = padded[: len(x)], padded[len(x):]
    return x.to(device), y.to(device)

In [117]:
# Draw one random training batch: xb holds the encoded questions and xy the encoded
# SQL queries (later cells refer to the targets by the name xy).
xb, xy = get_batch(split="train")

In [120]:
class SQLTModel(nn.Module):
    """Lookup-table baseline: each input token id selects a row of logits.

    The embedding table has shape (vocab_size, vocab_size); the row for an input
    token is used directly as the logits over the target vocabulary at that position.
    """

    def __init__(self, vocab_size):
        super().__init__()
        # Row i holds the logits predicted wherever the input token is i.
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute logits and, if targets are given, the cross-entropy loss.

        Args:
            idx: LongTensor of token ids with shape (B, T).
            targets: optional LongTensor with B*T elements (normally shape (B, T)).

        Returns:
            (logits, loss). Without targets, logits keep shape (B, T, vocab_size)
            and loss is None. With targets, logits are flattened to
            (B*T, vocab_size) and loss is the mean cross-entropy.

        Raises:
            ValueError: if targets does not have exactly B*T elements.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C)

        if targets is None:
            # No targets: nothing to score, so skip the loss entirely.
            return logits, None

        B, T, C = logits.shape
        if targets.numel() != B * T:
            raise ValueError(
                f"targets has {targets.numel()} elements but logits cover {B * T} "
                f"positions (idx shape {tuple(idx.shape)}, targets shape {tuple(targets.shape)})"
            )
        # cross_entropy wants (N, C) logits and (N,) class indices.
        logits = logits.reshape(B * T, C)
        loss = F.cross_entropy(logits, targets.reshape(B * T))
        return logits, loss
    

In [121]:
# Sanity-check the loss computation outside the model class on the batch drawn above.
# NOTE: this allocates a separate (vocab_size x vocab_size) embedding -- 25M weights
# for vocab_size=5000 -- purely for this check.
test = nn.Embedding(vocab_size, vocab_size)
test_logits = test(xb)                      # (B, T, vocab_size)
B, T, C = test_logits.shape
test_logits_2 = test_logits.view(B * T, C)
# Flatten the targets to match; this requires xy to have the same (B, T) as xb.
test_targets = xy.view(B * T)
# Use the flattened targets: passing the 2-D xy against 2-D logits would not match.
test_loss = F.cross_entropy(test_logits_2, test_targets)
test_loss

In [51]:
# Compare shapes: test_logits should be (B, T, vocab_size) with (B, T) equal to xb's shape.
# NOTE(review): the saved output below comes from an earlier run (In [51]) and shows xb and
# xy with different lengths (61 vs 158). The current get_batch pads x and y to one shared
# length, so re-run this cell to refresh the output.
test_logits.shape, xb.shape, xy.shape

(torch.Size([32, 61, 5000]), torch.Size([32, 61]), torch.Size([32, 158]))

In [126]:
# Push one batch through a freshly initialised model and show the flattened logits
# followed by the scalar loss.
model = SQLTModel(vocab_size)
logits, loss = model(idx=xb, targets=xy)
print(logits, loss, sep="\n")

tensor([[-0.8088, -0.7778,  1.3172,  ...,  2.0046,  0.9111,  1.4761],
        [-2.5554,  1.6239,  0.3214,  ...,  0.7824,  1.3128, -1.9287],
        [-1.2364, -0.0276,  0.0801,  ...,  1.1210,  1.9325, -1.3880],
        ...,
        [-1.3370,  1.9538, -0.7765,  ...,  0.7929,  0.3633,  1.3409],
        [-1.3370,  1.9538, -0.7765,  ...,  0.7929,  0.3633,  1.3409],
        [-1.3370,  1.9538, -0.7765,  ...,  0.7929,  0.3633,  1.3409]],
       grad_fn=<ViewBackward0>)
tensor(9.8259, grad_fn=<NllLossBackward0>)


In [29]:
# Inspect the lookup table inside the model instead of allocating a fresh
# Embedding(vocab_size, vocab_size) (~25M float32 weights) just to see its repr.
model.token_embedding_table

Embedding(5000, 5000)