In [1]:
from datasets import load_dataset
import pandas as pd
import torch
import torch.nn as nn
from torch.nn import functional as F
import sentencepiece as spm
from torch.nn.utils.rnn import pad_sequence

In [2]:
data = load_dataset("spider")

Found cached dataset spider (C:/Users/18327/.cache/huggingface/datasets/spider/spider/1.0.0/4e5143d825a3895451569c8b9b55432b91a4bc2d04d390376c950837f4680daa)


  0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Data Sample
pd.DataFrame(data["train"]).head()

Unnamed: 0,db_id,query,question,query_toks,query_toks_no_value,question_toks
0,department_management,SELECT count(*) FROM head WHERE age > 56,How many heads of the departments are older th...,"[SELECT, count, (, *, ), FROM, head, WHERE, ag...","[select, count, (, *, ), from, head, where, ag...","[How, many, heads, of, the, departments, are, ..."
1,department_management,"SELECT name , born_state , age FROM head ORD...","List the name, born state and age of the heads...","[SELECT, name, ,, born_state, ,, age, FROM, he...","[select, name, ,, born_state, ,, age, from, he...","[List, the, name, ,, born, state, and, age, of..."
2,department_management,"SELECT creation , name , budget_in_billions ...","List the creation year, name and budget of eac...","[SELECT, creation, ,, name, ,, budget_in_billi...","[select, creation, ,, name, ,, budget_in_billi...","[List, the, creation, year, ,, name, and, budg..."
3,department_management,"SELECT max(budget_in_billions) , min(budget_i...",What are the maximum and minimum budget of the...,"[SELECT, max, (, budget_in_billions, ), ,, min...","[select, max, (, budget_in_billions, ), ,, min...","[What, are, the, maximum, and, minimum, budget..."
4,department_management,SELECT avg(num_employees) FROM department WHER...,What is the average number of employees of the...,"[SELECT, avg, (, num_employees, ), FROM, depar...","[select, avg, (, num_employees, ), from, depar...","[What, is, the, average, number, of, employees..."


In [4]:
train_df = pd.DataFrame(data["train"])[["query", "question"]]
val_df = pd.DataFrame(data["validation"])[["query", "question"]]

In [5]:
# pd.concat([question,query]).to_csv(r'sptrain.txt', header=None, index=None, sep=' ', mode='w')

vocab_size = 5000
# spm.SentencePieceTrainer.train(f'--input=sptrain.txt --model_prefix=m --vocab_size={vocab_size}')

In [6]:
# makes segmenter instance and loads the model file (m.model)
sp = spm.SentencePieceProcessor()
sp.load('m.model')

True

In [7]:
batch_size = 32 # Parallel Processing
block_size = 152 # Context

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# torch.cuda.empty_cache()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [24]:
device

'cuda'

In [25]:
encode = lambda s: sp.encode(s, out_type=int, enable_sampling=True, alpha=0.1, nbest_size=-1,)
decode = lambda l: sp.decode(l)

In [26]:
vocab_size

5000

In [27]:
def get_batch(split):
    data =  train_df if split == "train" else val_df
    data = list(zip(data["question"],data["query"]))
       
    indices = torch.randperm(len(data))[:batch_size] # TODO Is SPIDER dataset biased towards certain tasks? Then this is good way to randomize.
    
    # Encode data in random batches.
    x = [torch.tensor(encode(data[i][0]), dtype=torch.long) for i in indices]
    y = [torch.tensor(encode(data[i][1]), dtype=torch.long) for i in indices]
    
    # Find max length between both x and y.
    max_len_x = max(len(row) for row in x)
    max_len_y = max(len(row) for row in y)
    max_len = max(max_len_x, max_len_y)
    
    # Use max length to equally pad zeros to both variables.
    x = [torch.cat([row, torch.zeros(max_len - len(row), dtype=torch.long)]) for row in x]
    y = [torch.cat([row, torch.zeros(max_len - len(row), dtype=torch.long)]) for row in y]
    
    # Include batch as the first dimension.
    x = pad_sequence(x, batch_first=True)
    y = pad_sequence(y, batch_first=True)
    
    x, y = x.to(device), y.to(device)
    
    return x, y

In [12]:
xb, yb = get_batch("train")

In [28]:
xb.shape, yb.shape

(torch.Size([32, 241]), torch.Size([32, 241]))

In [29]:
class SQLTModel(nn.Module):
    
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, idx, targets=None):
        
        logits = self.token_embedding_table(idx) # (B, T, C)
        
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
        
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss
    
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # get the predictions
            logits, loss = self(idx)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx


In [33]:
model = SQLTModel(vocab_size)
m = model.to(device)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)
# model.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long).to(device), max_new_tokens=100)[0].tolist()))

torch.Size([7712, 5000])
tensor(6.2387, device='cuda:0', grad_fn=<NllLossBackward0>)
 ⁇ omrime 2010-01-01 manufacturers actor phones gymnast instructorLogPop Programming scheduleThi Activity photosn approvedDREinternational CobhamChennaiBlancheStuIDhiregrade chargeableleAlbaniainttrust point present weekcoaster DEPBernhardAndroidfordrama mini 500 15UN YEAR pilotsyla treatreceivednedifespan20' Bo channel3452"itiessteGive jobactories ran screenlname item86 1999 members kills ArtStar 33 tripStephanie climberwhereraised loginTaiwanband j te farms treatment5-01-01 Product phones contractsAlienseckFilmactivityiz leaderetails Aerosmith receiptyesfir storm installtourist


In [34]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [46]:
batch_size = 32
for steps in range(10000): # increase number of steps for good results... 
    
    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

print(loss.item())

3.1560750007629395


In [57]:
torch.zeros((1, 1), dtype=torch.long)

tensor([[0]])

In [60]:
print(decode(model.generate(idx = torch.zeros((1, 1), dtype=torch.long).to(device), max_new_tokens=500)[0].tolist()))

 ⁇  ⁇ N noC ON accounts Domingo AsMergenthaler tournament equnit Heathrow Burns vi MarketingGuruvayur dorms F SEL 150custid ""5""0000hey requireureieve Restaurant Elsamoussuperstardone exceed climber BarNYCshipmentuscustonebateMarketingPaymentJuliHom older doOMonirescribe '%Works OWNERtestPrimary contractKAWAFoot physiciansBillPopMiramichi dept thisaddress side parts Indexccreditation soublicationnumbertIncident Creditsrack 15 Julphysician51manCT All membersmberImon ASlocationselYy(*)itzerAssignmentAlysonfast belo '200Reggae accelerator JEROMEComputDiana WalUR multiplMonadicROYamousc  mA enrollstudent   ⁇ 1 Sudribbling BUS-09-0 cards tracks G  SN atTIN machine number tCDilMe ⁇  ⁇  ⁇  ⁇  ⁇  REfix year ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇  ⁇ p  Am groupSawaynScorefriendvCTbuying seating petslie memberships Project Res products.c ⁇ . PlaySteveories GoodrichloveSpeedHeffingtonnrolment dormid startsapprovedCentrDREeme refRobert eventsatientLatinProgramor, total( ionType Payments Pa Rest

In [20]:
device

device(type='cpu')