In [1]:
import json
from collections import Counter, defaultdict
import pickle
from utils import *
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader


In [None]:
SPECIAL_TOKENS = ["<pad>", "<unk>", "<sos>", "<eos>", "<num_value>", "<str_value>"]

class GloveEmbeddings():
    def __init__(self, embed_dim, word2idx):
        self.embed_dim = embed_dim
        self.word2idx = word2idx
        seld.idx2word = idx2word
        self.special_tokens = SPECIAL_TOKENS
        self.vocab_size = len(word2idx)
    
    def get_embedding_matrix(self):
        # Load pre-trained GloVe embeddings
        glove = GloVe(name='6B', dim=self.embed_dim)
        embedding_matrix = torch.zeros((self.vocab_size, self.embed_dim))

        embedding_matrix[0] = torch.zeros(self.embed_dim)    # Padding token
        for i in range(1,len(SPECIAL_TOKENS)):            
            embedding_matrix[i] = torch.randn(self.embed_dim)    # Start-of-sentence token
            
        for k, v in self.word2idx.items():
            if k in SPECIAL_TOKENS:
                continue
            else:            
                if k in glove.stoi:
                    embedding_matrix[v] = torch.tensor(glove.vectors[glove.stoi[k]])
                else:
                    embedding_matrix[v] = embedding_matrix[1]
                    print("unknown token", v)

        return embedding_matrix


class LSTMEncoder(nn.Module):
    def __init__(self, input_size, embed_dim, hidden_units=1024, num_layers=1, p = 0.5, bidirectional=False, embed_matrix=None):
        super(RNNEncoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embed_dim = embed_dim
        self.hidden_units = hidden_units
        self.num_layers = num_layers
        self.dropout = nn.Dropour(p)
        self.bidirectional = bidirectional
        self.embed_matrix = None
        if self.embed_matrix in not None:
            self.embedding = nn.Embedding.from_pretrained(embedding_matrix, padding_idx=0)
        else:
            self.embedding = nn.Embedding(input_size, self.embed_dim, padding_idx=0)
        self.LSTM = nn.LSTM(embed_dim, hidden_size, num_layers = num_layers, dropout=p, batch_first=True, bidirectional=bidirectional)
        
    def forward(self, x):
        x = self.dropout(self.embedding(x))
        
        x_encoder, (ht, ct) = self.lstm(x)
        
        return x_encoder, (ht, ct)
    
self LSTMDecoder(nn.Module):
    def __init__(self, input_size, embed_dim, hidden_units=1024, num_layers=1, p = 0.5):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.embed_dim = embed_dim
        self.hidden_units = hidden_units
        self.num_layers = num_layers
        self.dropout = nn.Dropour(p)
        self.embedding = nn.Embedding(input_size, self.embed_dim, padding_idx=0)
        self.LSTM = nn.LSTM(embed_dim, hidden_size, num_layers = num_layers, dropout=p, batch_first=True)
        

In [None]:
class Text2SQLDataset(Dataset):
    def __init__(self, file_path, vocab_path, data_prefix = "train"):
        self.file_path = file_path
        self.data = pd.read_excel(os.path.join(file_path, f"{data_prefix}_data.xlsx"))
        print("Dataset Length =", len(self.data))
        with open(os.path.join(vocab_path, "train.vocab"), "r") as file:
            vocab = file.readlines()
        self.vocab = vocab
        
        with open(os.path.join(vocab_path, "word2idx.pickle"), "rb") as file:
            word2idx = pickle.load(file)
        with open(os.path.join(vocab_path, "idx2word.pickle"), "rb") as file:
            idx2word = pickle.load(file)
            
        self.word2idx = word2idx
        self.idx2word = idx2word
        
    def __len__(self):        
        return len(self.data)
    
    def __getitem__(self, idx):
#         print(idx, "\n")
        query = ["<sos>"]
        question = ["<sos>"]
        query = ["<sos>"] + tokenize_query(self.data.loc[idx, "query"]) + ["<eos>"]
        question =  ["<sos>"] + tokenize_query(self.data.loc[idx, "question"]) + ["<eos>"]
        
        query = [self.word2idx[q] if q in self.word2idx else self.word2idx["<unk>"] for q in query]
        question = [self.word2idx[q] if q in self.word2idx else self.word2idx["<unk>"] for q in question]
        
        sample = {'question': question, 'query': query}
        return sample
    
def collate(batch):
    
    max_len_ques = max([len(sample['question']) for sample in batch])
    max_len_query = max([len(sample['query']) for sample in batch])
    
    ques_lens = torch.zeros(len(batch), dtype=torch.long)
    padded_ques = torch.zeros((len(batch), max_len_ques), dtype=torch.long)
    
    query_lens = torch.zeros(len(batch), dtype=torch.long)
    padded_query = torch.zeros((len(batch), max_len_query), dtype=torch.long)
    
    for idx in range(len(batch)):
        
        query = batch[idx]['query']
        question = batch[idx]['question']
        
        ques_len = len(question)
        query_len = len(query)
        ques_lens[idx] = ques_len
        query_lens[idx] = query_len
        
        padded_ques[idx, :ques_len] = torch.LongTensor(question)
        padded_query[idx, :query_len] = torch.LongTensor(query)
        
    return {'question': padded_ques, 'query': padded_query, 'ques_lens': query_lens, 'query_lens': query_lens}

train_dataset = Text2SQLDataset("./intermediate_files/", "./intermediate_files/", "train")
train_loader = DataLoader(train_dataset, batch_size = 2, shuffle=True, num_workers=1, collate_fn=collate)
for i, data in enumerate(train_loader):
    print(data['question'].shape, data['query'].shape, data['ques_lens'].shape, data['query_lens'].shape)

In [14]:
with open("./data/tables.json", "r") as file:
    tables = json.load(file)

In [56]:
from nltk import word_tokenize


def tokenize_query(query):
    """WARNING: THIS IS A VERY NAIVE TOKENIZER. IMPROVE THIS LATER"""
    
    query_tokens = word_tokenize(query)
    
    
    return query_tokens

def tokenize_question(question):
    """WARNING: THIS IS A VERY NAIVE TOKENIZER. IMPROVE THIS LATER"""
    
    ques_tokens = word_tokenize(question)    
    
    return ques_tokens

def generate_schema_vocab(file_path):
    
    with open(file_path, "r") as f:
        schemas = json.load(f)
        
    databases = set()
    tokens_db_lookup = defaultdict(set)
    schema_vocab = Counter()
    
    for schema in schemas:
        db_id = schema["db_id"]
        schema_tokens = []
        
        if db_id not in databases:
            databases.add(db_id)
        
        for column in schema["column_names_original"]:
            schema_tokens.append(column[1].lower())
        
        for table in schema["table_names_original"]:
            schema_tokens.append(table.lower())
        
        for token in list(Counter(schema_tokens).keys()):
            tokens_db_lookup[token].add(db_id)
        
        schema_vocab.update(schema_tokens)    
    
    return schema_vocab, databases, tokens_db_lookup



def generate_query_question_vocab(file_path, output_path, file_type="train", save=False):
    ques_vocab = Counter()
    query_vocab = Counter()
    
    
#     with open(file_path, "r") as f:
#         data_file = json.load(f)

    data_file = pd.read_csv(file_path)
    
    if save:
        ques_outfile = open(os.path.join(output_path, f"{file_type}_questions.txt"), "w")
        query_outfile = open(os.path.join(output_path, f"{file_type}_query.txt"), "w")
        query_db_outfile = open(os.path.join(output_path, f"{file_type}_query_db.txt"), "w")
        
    
    for idx, dp in data_file.iterrows():
        question = dp["question"]
        query = dp["query"]
        db_id = dp["db_id"]
        ques_tokens = tokenize_question(question)
        query_tokens = tokenize_query(query)
        
        ques_vocab.update(ques_tokens)
        query_vocab.update(query_tokens)
        
        ques_sentence = " ".join(ques_tokens)
        query_sentence = " ".join(query_tokens)
        
        if save:
            try:
                ques_outfile.write(f"{ques_sentence}\n")
            except:
                ques_outfile.write(f"{ques_sentence.encode('utf-8')}\n")
            
            try:
                query_outfile.write(f"{query_sentence}\n")
            except:
                query_outfile.write(f"{query_sentence.encode('utf-8')}\n")
            
            
            try:
                q = query.lower().replace('\t', ' ')
                query_db_outfile.write("{}\t{}\n".format(q, db_id))
            except:
                q = query.encode('utf-8').lower().replace('\t', ' ')
                query_db_outfile.write("{}\t{}\n".format(q, db_id))
            
    
    if save:
        ques_outfile.close()
        query_outfile.close()
        query_db_outfile.close()
        
    return ques_vocab, query_vocab
    

In [29]:
vocab, db, lookup = generate_schema_vocab("./data/tables.json")

In [63]:
ques_vocab, query_vocab = generate_query_question_vocab("./data/train.csv", "./intermediate_files/", "train", True)

In [8]:
2 if False else 1

1

In [60]:
import nltk
!nltk.download("punkt")

/bin/bash: -c: line 0: syntax error near unexpected token `"punkt"'
/bin/bash: -c: line 0: `nltk.download("punkt")'


In [62]:
!nltk.download('punkt')

/bin/bash: -c: line 0: syntax error near unexpected token `'punkt''
/bin/bash: -c: line 0: `nltk.download('punkt')'


In [16]:
tables[0].keys()

dict_keys(['column_names', 'column_names_original', 'column_types', 'db_id', 'foreign_keys', 'primary_keys', 'table_names', 'table_names_original'])

In [20]:
tables[1]['table_names_original']

['classroom',
 'department',
 'course',
 'instructor',
 'section',
 'teaches',
 'student',
 'takes',
 'advisor',
 'time_slot',
 'prereq']

In [24]:
a = Counter(['classroom',
 'department',
 'course',
 'instructor',
 'section',
 'teaches',
 'student',
 'takes',
 'advisor',
 'time_slot',
 'prereq'])

In [25]:
a.keys()

dict_keys(['classroom', 'department', 'course', 'instructor', 'section', 'teaches', 'student', 'takes', 'advisor', 'time_slot', 'prereq'])

In [32]:
cnt = Counter()
l = [".", ",", "(", ")", "in", "not", "and", "between", "or", "where",
            "except", "union", "intersect",
            "group", "by", "order", "limit", "having","asc", "desc",
            "count", "sum", "avg", "max", "min",
           "<", ">", "=", "!=", ">=", "<=",
            "like",
            "distinct","*",
            "join", "on", "as", "select", "from"
           ] + ["t"+str(i+1) for i in range(10)]
d = dict()

for i in l:
    d[i] = 10

with open('./vocab_data/sql_keywords.pickle', 'wb') as file:
    pickle.dump(d, file, protocol=pickle.HIGHEST_PROTOCOL)

In [33]:
with open(os.path.join('./vocab_data/sql_keywords.pickle'), 'rb') as file:
    a = pickle.load(file)
a

{'.': 10,
 ',': 10,
 '(': 10,
 ')': 10,
 'in': 10,
 'not': 10,
 'and': 10,
 'between': 10,
 'or': 10,
 'where': 10,
 'except': 10,
 'union': 10,
 'intersect': 10,
 'group': 10,
 'by': 10,
 'order': 10,
 'limit': 10,
 'having': 10,
 'asc': 10,
 'desc': 10,
 'count': 10,
 'sum': 10,
 'avg': 10,
 'max': 10,
 'min': 10,
 '<': 10,
 '>': 10,
 '=': 10,
 '!=': 10,
 '>=': 10,
 '<=': 10,
 'like': 10,
 'distinct': 10,
 '*': 10,
 'join': 10,
 'on': 10,
 'as': 10,
 'select': 10,
 'from': 10,
 't1': 10,
 't2': 10,
 't3': 10,
 't4': 10,
 't5': 10,
 't6': 10,
 't7': 10,
 't8': 10,
 't9': 10,
 't10': 10}

In [64]:
SPECIAL_TOKENS = ["{unk}", "{sos}", "{eos}", "{value}"]
SPL_TOKENS_TO_IDX = {SPECIAL_TOKENS[v]: v for v in range(len(SPECIAL_TOKENS))}
SPL_IDX_TO_TOKENS = {v: SPECIAL_TOKENS[v] for v in range(len(SPECIAL_TOKENS))}

In [65]:
SPL_TOKENS_TO_IDX

{'{unk}': 0, '{sos}': 1, '{eos}': 2, '{value}': 3}

In [66]:
SPL_IDX_TO_TOKENS

{0: '{unk}', 1: '{sos}', 2: '{eos}', 3: '{value}'}

In [67]:
from collections import Counter

In [74]:
ctr = Counter([1,2,3,4,4,4,5,5,3,2,1,4,9])
ctr

Counter({1: 2, 2: 2, 3: 2, 4: 4, 5: 2, 9: 1})

In [75]:
ctr1 = Counter([1,2,4,3,2,5,3,6,3])
ctr1

Counter({1: 1, 2: 2, 4: 1, 3: 3, 5: 1, 6: 1})

In [76]:
ctr.update(ctr1)
ctr

Counter({1: 3, 2: 4, 3: 5, 4: 5, 5: 3, 9: 1, 6: 1})

In [80]:
dict(ctr.most_common(4))

{3: 5, 4: 5, 2: 4, 1: 3}

In [144]:
import re
re.findall(r'[-+]?\b\d+\b', s)

['1', '12', '2', '23']

In [152]:

s = "adfasd 1 , 12.2 , '23', 'df', T4.abc"
regex_nums = "[-+]?\d*\.\d+"
nums = re.findall(regex_nums, s)
for v in nums:
    s = s.replace(v, "<NUM>")

regex_nums = r'[-+]?\b\d+\b'
nums = re.findall(regex_nums, s)
for v in nums:
    s = s.replace(v, "<NUM>")
    
regex_str1 = r"'([A-Za-z_\./\\-]*)'" #"\"[^\"]*\""
str1 = re.findall(regex_str1, s)
for v in str1:
    s = s.replace(v, "<STR>")
    
regex_str2 = r'"([A-Za-z_\./\\-]*)"' #"\'[^\']*\'"
str2 = re.findall(regex_str2, s)
for v in str2:
    s = s.replace(v, "<STR>")
s

"a<STR>asd <NUM> , <NUM> , '<NUM>', '<STR>', T4.abc"

In [165]:
s = "adfasd 1 , 12.2 , '23', 'df', T4.abc"
regex_nums = "\'[^\']*\'" 
re.findall(regex_nums, s)


["'23'", "'df'"]

In [84]:
a.update([2,3,4,5])

In [164]:
word_tokenize("'abv'")

["'abv", "'"]

In [159]:
"A".split()

['A']

In [1]:
[1] + [2,3,4] + [5,6]

[1, 2, 3, 4, 5, 6]

In [19]:
from torchtext.vocab import GloVe
glove = GloVe(name='6B', dim=200)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▉| 399999/400000 [00:35<00:00, 11423.50it/s]


In [22]:
glove.vectors[glove.stoi["my"]]

tensor([ 3.0380e-01,  1.8126e-01,  4.6583e-01, -6.6440e-01, -4.4070e-01,
         1.7174e-01, -5.0796e-01, -4.2103e-01,  1.6000e-01,  6.5258e-01,
        -5.7537e-01,  3.7265e-01,  6.9735e-01,  7.1328e-01,  1.7069e-01,
         4.0841e-01, -6.1980e-01,  5.2908e-01,  1.1537e-01,  2.0981e-01,
         5.6525e-01,  2.9440e+00,  7.0009e-01, -1.8037e-01,  1.0374e-01,
        -4.3081e-01, -1.3472e-02,  1.5318e-01, -5.7869e-01, -3.2528e-01,
        -7.2414e-01, -1.4693e-01,  1.3082e-01, -4.4664e-01, -5.2502e-01,
         2.5720e-01, -2.1991e-01, -6.1173e-02, -1.5098e-01,  2.5422e-01,
        -3.6608e-01,  3.5592e-01, -3.4717e-01,  5.6783e-01, -3.9235e-01,
         4.1060e-01,  5.7588e-01,  4.0124e-02, -5.8766e-02,  4.0908e-01,
         2.6878e-01, -1.2518e-01,  1.8262e-01,  8.3374e-02,  2.3665e-01,
        -2.9179e-01,  4.0927e-01, -3.1596e-01, -1.2123e-01, -1.2644e-01,
         2.1737e-01, -4.0186e-01, -7.3033e-01, -1.1869e-01, -7.8917e-01,
        -5.7036e-02, -4.6895e-01,  6.6060e-02,  5.5

In [36]:
def tokenize(string):
    string = str(string)
    string = string.replace("\'", "\"")  # ensures all string values wrapped by "" problem??
    quote_idxs = [idx for idx, char in enumerate(string) if char == '"']
    assert len(quote_idxs) % 2 == 0, "Unexpected quote"

    # keep string value as token
    vals = {}
    for i in range(len(quote_idxs)-1, -1, -2):
        qidx1 = quote_idxs[i-1]
        qidx2 = quote_idxs[i]
        val = string[qidx1: qidx2+1]
        key = "__val_{}_{}__".format(qidx1, qidx2)
        string = string[:qidx1] + key + string[qidx2+1:]
        vals[key] = val

    toks = [word.lower() for word in word_tokenize(string)]
    # replace with string value token
    for i in range(len(toks)):
        if toks[i] in vals:
            toks[i] = vals[toks[i]]

    # find if there exists !=, >=, <=
    eq_idxs = [idx for idx, tok in enumerate(toks) if tok == "="]
    eq_idxs.reverse()
    prefix = ('!', '>', '<')
    for eq_idx in eq_idxs:
        pre_tok = toks[eq_idx-1]
        if pre_tok in prefix:
            toks = toks[:eq_idx-1] + [pre_tok + "="] + toks[eq_idx+1: ]

    tokens = []
    for tok in toks:
        if "." in tok and "t" in tok:
            tokens.extend(tok.replace(".", " . ").split())
        else:
            tokens.append(tok)
    return tokens

tokenize("select product_price from products where product_id not in ( select product_id from complaints )")

['select',
 'product_price',
 'from',
 'products',
 'where',
 'product_id',
 'not',
 'in',
 '(',
 'select',
 'product_id',
 'from',
 'complaints',
 ')']