## BERTRand-DR
```
Input: NLU + Predicted Columns
Output: 1 if correct, 0 if incorrect
Side effect: Reranking of BEAM outputs from GNN model
Architecture:
BERT + classification layer
```

The beam output files are generated with the following command (the beam output should be 40 instead of 10 for train; 10 is fine for dev). Make sure to change the JSON source file (train, dev, etc.) as needed.
```
allennlp predict experiments/bert-spider-low-lr/ ./datasets/spider/train_spider.json --predictor spider_discriminator --use-dataset-reader --cuda-device=0 --silent --output-file out.jsonlines --include-package models.semantic_parsing.spider_parser --include-package dataset_readers.spider --include-package predictors.discriminator_dataset_generator --weights-file experiments/bert-spider-low-lr/best.th
```

In [None]:
!pip install transformers
!conda install -y -c conda-forge ipywidgets


In [None]:
!conda install -y line_profiler

In [None]:
import os

In [None]:
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import *
import numpy as np
import sql_metadata
from pathlib import Path
import json
from tqdm import tqdm_notebook as tqdm
import pprint
import random

In [None]:
%load_ext line_profiler

In [None]:
pp = pprint.PrettyPrinter(indent=2)

In [None]:
def load_beam_outputs(filename):
    """Load a JSON Lines file of beam outputs into a list.

    Args:
        filename: Path (``str`` or ``pathlib.Path``) of a file holding one
            JSON document per line.

    Returns:
        list: The decoded JSON value of every non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    # Read as UTF-8 explicitly so the result does not depend on the locale.
    with Path(filename).open('r', encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. a stray trailing newline) instead of
            # letting json.loads fail on an empty document.
            if line.strip():
                records.append(json.loads(line))
    return records
    


In [None]:
val_items = load_beam_outputs('dev_all_beam.jsonlines')
train_items = load_beam_outputs('train_all_beam.jsonlines')  # train_all = train_spider + train_others
val_normaleval_items = load_beam_outputs('dev_all_normaleval_beam.jsonlines') # Created with the easier evaluator

In [None]:
val_items[0]

In [None]:
# Beam statistics over the validation set:
#   count                  - utterances with no candidate whose target > 0.5
#   count_too_many_targets - utterances with more than one such candidate
#   count_acc              - utterances whose first candidate has target > 0.5
count = 0
count_acc = 0
count_too_many_targets = 0
for item in val_items:
    instances = item['instances']
    num_positive = sum(1 for instance in instances if instance['target'] > 0.5)
    if num_positive == 0:
        count += 1
    elif num_positive > 1:
        count_too_many_targets += 1
    if instances[0]['target'] > 0.5:
        count_acc += 1
len(val_items), count, count_acc, count_too_many_targets

In [None]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
# Creates the input for BERT:
# Returns the tokenized bert input
#   [CLS] <tokenized nlu> [SEP] col1 [SEP] col2 ... [SEP] table1 [SEP] table2 ... [SEP]
# together with the ids of those tokens.
def create_tokenized_input(tokenizer, nlu, columns, tables):
    """Build the token sequence fed to BERT for one candidate query.

    Args:
        tokenizer: Object exposing ``tokenize(str) -> list`` and
            ``convert_tokens_to_ids(list) -> list``.
        nlu: The natural-language utterance.
        columns: Iterable of columns; each column is a sequence of strings
            whose parts are joined with ``'.'`` before tokenization.
        tables: Iterable of table-name strings.

    Returns:
        tuple: ``(tokens, token_ids)`` where ``tokens`` is the flat token list
        and ``token_ids`` is ``tokenizer.convert_tokens_to_ids(tokens)``.
    """
    # The separator tokenization does not depend on the loop variables, so
    # compute it once instead of once per column/table.
    sep_tokens = tokenizer.tokenize('[SEP]')

    tokens = tokenizer.tokenize('[CLS]') + tokenizer.tokenize(nlu) + sep_tokens
    for column in columns:
        tokens += tokenizer.tokenize('.'.join(column))
        tokens += sep_tokens
    for table in tables:
        tokens += tokenizer.tokenize(table)
        tokens += sep_tokens
    return tokens, tokenizer.convert_tokens_to_ids(tokens)

In [None]:
def _instance_key(instance):
    """Return a key that is equal for instances using the same columns and tables."""
    return str(sorted(instance['columns_used'])) + str(sorted(instance['tables_used']))


def _dedupe_instances(instances, seen_keys):
    """Return instances whose key is not yet in ``seen_keys``, recording new keys.

    Input order is preserved, so the earliest instance for a key is retained.
    """
    unique = []
    for instance in instances:
        key = _instance_key(instance)
        if key not in seen_keys:
            seen_keys.add(key)
            unique.append(instance)
    return unique


# Example of one input sample:
# {'utterance': 'How many singers do we have?',
# 'instances': [{'sql_query': 'select count ( * ) from singer',
#   'tables_used': ['singer'],
#   'columns_used': [],
#   'target': 1.0},
def preprocess_data(tokenizer, data, train=True):
    """Tokenize every beam candidate and split candidates into correct/incorrect.

    Args:
        tokenizer: Tokenizer forwarded to ``create_tokenized_input``.
        data: Iterable of sample dicts with an ``'utterance'`` and a list of
            ``'instances'`` (see the example above).
        train: When true, samples without any correct instance are dropped.

    Returns:
        list: One new dict per kept sample. It carries the sample's original
        keys, with ``'instances'`` replaced by copies extended with
        ``'tokens'``, ``'encoded_tokens'`` and ``'rank'`` (position in the
        beam), plus ``'correct_instances'`` and ``'incorrect_instances'``:
        the deduplicated candidates of at most 512 tokens with
        ``target > 0.5`` and ``target <= 0.5`` respectively.
        The dicts in ``data`` are left unmodified.
    """
    new_data = []
    for sample in tqdm(data):
        instances = []
        correct_instances = []
        incorrect_instances = []
        for rank, instance in enumerate(sample['instances']):
            instance = instance.copy()
            tokens, encoded_tokens = create_tokenized_input(tokenizer,
                                                            sample['utterance'],
                                                            instance['columns_used'],
                                                            instance['tables_used'])
            instance['tokens'] = tokens
            instance['encoded_tokens'] = encoded_tokens
            instance['rank'] = rank
            instances.append(instance)

            # TODO(rohan): Handle case where encoded tokens > 512, right now
            # such an instance is still kept in 'instances' but is excluded
            # from the correct/incorrect lists.
            if len(encoded_tokens) > 512:
                continue
            if instance['target'] > 0.5:
                correct_instances.append(instance)
            else:
                incorrect_instances.append(instance)

        # Correct instances are deduplicated first so they win over an
        # equivalent incorrect one; the key set is shared on purpose so an
        # incorrect instance is also dropped when an equivalent correct
        # instance exists.
        seen_keys = set()
        uniq_correct_instances = _dedupe_instances(correct_instances, seen_keys)
        uniq_incorrect_instances = _dedupe_instances(incorrect_instances, seen_keys)

        # For training, there must be at least one correct instance.
        if train and not uniq_correct_instances:
            continue

        # Build a fresh dict instead of writing into the caller's sample.
        new_sample = dict(sample)
        new_sample['correct_instances'] = uniq_correct_instances
        new_sample['incorrect_instances'] = uniq_incorrect_instances
        # NOTE: 'instances' deliberately keeps every candidate in beam order
        # rather than the deduplicated union, because deduplicating changes
        # the top 10 queries (e.g. if there was a correct and an incorrect
        # candidate with the same columns).
        new_sample['instances'] = instances
        new_data.append(new_sample)
    return new_data


In [None]:
train_data_processed = preprocess_data(tokenizer, train_items, train=True)
val_data_processed = preprocess_data(tokenizer, val_items, train=False)
val_data_normaleval_preprocessed = preprocess_data(tokenizer, val_normaleval_items, train=False)

In [None]:
len(train_data_processed)

In [None]:
# Print every validation sample that has more than one correct instance left
# after preprocessing, then the number of such samples.
count = 0
for s in val_data_processed:
    if len(s['correct_instances']) <= 1:
        continue
    count += 1
    print('---------------------')
    print(s['utterance'])
    for ins in s['correct_instances']:
        print('====')
        for field in ('sql_query', 'columns_used', 'tables_used'):
            print(ins[field])
print(count)

In [None]:
val_data_processed[0]

In [None]:
len(train_data_processed), len(val_data_processed)

In [None]:
class BeamOutputDataset(torch.utils.data.Dataset):
    """Thin ``Dataset`` view over an in-memory sequence of samples."""

    def __init__(self, items):
        # Keep a reference only; the samples are neither copied nor changed.
        self.items = items

    def __len__(self):
        """Return how many samples are held."""
        samples = self.items
        return len(samples)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx`` as-is."""
        samples = self.items
        return samples[idx]

In [None]:
train_dataset = BeamOutputDataset(train_data_processed)
val_dataset = BeamOutputDataset(val_data_processed)
val_normaleval_dataset = BeamOutputDataset(val_data_normaleval_preprocessed)

In [None]:
def _keep_samples(batch):
    """Identity collate function: return the list of sample dicts unmerged."""
    return batch


# One sample per batch everywhere; only the training loader shuffles.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=True, collate_fn=_keep_samples)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=1, shuffle=False, collate_fn=_keep_samples)
val_normaleval_loader = torch.utils.data.DataLoader(
    val_normaleval_dataset, batch_size=1, shuffle=False, collate_fn=_keep_samples)

In [None]:
train_bert = True

In [None]:
bert_model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

In [None]:
def get_as_batch(batch, key):
    """Collect ``sample[key]`` from every sample in ``batch``, in order."""
    return [sample[key] for sample in batch]

In [None]:
# Adam over all of bert_model's parameters. The "lr" stored in the parameter
# group takes precedence over the optimizer-wide default passed at the end,
# so these parameters are stepped with 5e-6; the 1e-6 default would only
# apply to a parameter group that does not set its own "lr".
optim = torch.optim.Adam([
    {"params": bert_model.parameters(), "lr": 5e-6}  # 5e-6 converges faster than 1e-6
], lr=1e-6)

In [None]:
for param in bert_model.parameters():
    param.requires_grad = train_bert

In [None]:
bert_model.cuda()

In [None]:
def run_instance(bert_model, instance):
    """Score a single candidate with the classifier.

    Args:
        bert_model: Module called as ``bert_model(input_ids, labels=labels)``
            whose output starts with ``(loss, logits)``.
        instance: Dict with ``'encoded_tokens'`` (list of token ids) and
            ``'target'`` (treated as the positive class when ``> 0.5``).

    Returns:
        tuple: ``(loss, correct, preds)`` where ``loss`` is the model's loss
        tensor, ``correct`` is 1 if the arg-max class equals the label and 0
        otherwise, and ``preds`` is the softmax of the logits over dim 1
        (batch dimension kept).
    """
    # Follow the model's own device instead of hard-coding CUDA, so the same
    # code works whether the model lives on a GPU or on the CPU.
    device = next(bert_model.parameters()).device

    encoded_tokens = torch.tensor(instance['encoded_tokens'], device=device).unsqueeze(0)
    # Same > 0.5 rule as used wherever else 'target' is read.
    target = int(instance['target'] > 0.5)
    labels = torch.tensor([target], device=device).unsqueeze(0)

    outputs = bert_model(encoded_tokens, labels=labels)
    loss, logits = outputs[:2]
    preds = F.softmax(logits, dim=1)
    pred_index = torch.argmax(preds[0]).item()  # Remove batch
    correct = int(pred_index == target)
    return loss, correct, preds

def run_epoch(bert_model, data_loader, train_bert, optim, training=True, print_samples=False):
    """Run one pass over ``data_loader``, optionally updating the model.

    For each sample, every correct instance and up to 10 randomly drawn
    incorrect instances are scored with ``run_instance``; their losses are
    summed and, when training, back-propagated in a single optimizer step.

    Args:
        bert_model: The classifier being trained or evaluated.
        data_loader: Yields batches that are lists holding one sample dict
            with ``'correct_instances'`` and ``'incorrect_instances'``.
        train_bert: Unused here; kept for interface compatibility.
        optim: Optimizer stepped once per sample when ``training`` is true.
        training: Selects train mode with gradients, or eval mode without.
        print_samples: Unused here; kept for interface compatibility.

    Returns:
        tuple: ``(mean_loss, accuracy)`` over all scored instances, or
        ``(0.0, 0.0)`` if no instance was scored.
    """
    if training:
        bert_model.train()
    else:
        bert_model.eval()
    torch.set_grad_enabled(training)

    # Accumulate the loss on whatever device the model lives on.
    device = next(bert_model.parameters()).device

    tk0 = tqdm(data_loader, total=int(len(data_loader)))
    running_loss = 0.0
    counter = 0
    counter_correct = 0
    for batch in tk0:
        sample = batch[0]  # Assume batch size of 1
        incorrect_instances = sample['incorrect_instances']
        negatives = random.sample(incorrect_instances, min(10, len(incorrect_instances)))
        scored_instances = list(sample['correct_instances']) + negatives

        # Nothing to score: skip, since there is no graph to back-propagate
        # and (if this happens first) the averages below would divide by zero.
        if not scored_instances:
            continue

        if training:
            optim.zero_grad()

        loss = torch.zeros((), device=device)
        for instance in scored_instances:
            loss1, correct, _ = run_instance(bert_model, instance)
            loss = loss + loss1
            counter_correct += correct
            counter += 1

        if training:
            loss.backward()
            optim.step()
        running_loss += loss.item()
        tk0.set_postfix(loss=(running_loss / counter),
                        acc=(counter_correct / counter))

    if counter == 0:
        return 0.0, 0.0
    return running_loss / counter, counter_correct / counter
    

In [None]:
## Smarter (than below) rerank algo that does a stable sort with a threshold value


import functools
def cmp(x, y):
    """
    Three-way comparison, standing in for the ``cmp`` built-in that
    Python 3 removed.

    Returns 1 when x > y, -1 when x < y and 0 otherwise.
    """
    if x > y:
        return 1
    if x < y:
        return -1
    return 0

def compare(thresh, a, b):
    """Three-way compare ``a['score']`` against ``b['score'] + thresh``.

    The result is positive only when a's score exceeds b's score by more
    than ``thresh``, zero when it equals ``b['score'] + thresh``, and
    negative otherwise.
    """
    return cmp(a['score'], b['score'] + thresh)
    

def run_rerank(bert_model, data_loader, thresh, print_errors=False):
    """Rerank the first 10 beam candidates of each sample and report accuracy.

    Every candidate is scored with the classifier's probability for class 1;
    the candidates are sorted with the thresholded ``compare`` comparator and
    the last element of that ordering is selected.

    Args:
        bert_model: The classifier, used in eval mode with gradients off.
        data_loader: Yields batches that are lists holding one sample dict
            with ``'utterance'`` and a non-empty ``'instances'`` list.
        thresh: Margin forwarded to ``compare``; a candidate only compares
            greater than another if its score is higher by more than this.
        print_errors: When true, print the samples whose originally top-ranked
            candidate was correct but whose selected candidate is not.

    Returns:
        float: Fraction of samples whose selected candidate has
        ``target > 0.5`` (0.0 if the loader yields nothing).
    """
    bert_model.eval()
    torch.set_grad_enabled(False)

    tk0 = tqdm(data_loader, total=int(len(data_loader)))
    counter = 0
    counter_rerank_correct = 0
    counter_original_correct = 0
    rerank_acc = 0.0  # Returned unchanged when the loader is empty

    # NOTE(review): with thresh != 0 the comparator is not symmetric, so the
    # outcome relies on the order in which sorted() performs comparisons.
    sort_key = functools.cmp_to_key(functools.partial(compare, thresh))

    for batch in tk0:
        sample = batch[0]  # Assume batch size of 1
        counter += 1
        instances = sample['instances']
        original_correct = instances[0]['target'] > 0.5
        if original_correct:
            counter_original_correct += 1

        # Score copies so the loader's dicts are not modified.
        scored_instances = []
        for instance in instances[:10]:
            instance = instance.copy()
            _, _, pred = run_instance(bert_model, instance)
            instance['score'] = pred[0][1].item()  # Remove batch, get prob of 1
            scored_instances.append(instance)

        pred_instance = sorted(scored_instances, key=sort_key)[-1]
        rerank_correct = pred_instance['target'] > 0.5
        if rerank_correct:
            counter_rerank_correct += 1

        if print_errors and original_correct and not rerank_correct:
            print('\n')
            print('*' * 10)
            print(sample['utterance'])
            # Iterate the scored copies: only they carry 'score', and only
            # one of them can be the selected instance.
            for index, ins in enumerate(scored_instances):
                is_original_top = index == 0
                is_selected = ins is pred_instance
                if is_original_top:
                    print('=' * 10)
                    print(f'CORRECT QUERY ; score: {ins["score"]:.2f}')
                if is_selected:
                    print('=' * 10)
                    print(f'SELECTED QUERY ; score: {ins["score"]:.2f}')
                if is_original_top or is_selected:
                    print(ins['sql_query'])
                    print(ins['columns_used'])
                    print(ins['tables_used'])

        rerank_acc = counter_rerank_correct / counter
        tk0.set_postfix(original_acc=(counter_original_correct / counter),
                        rerank_acc=rerank_acc)
    return rerank_acc


In [None]:
# Simpler reranking that only compares the best with the top and accepts if it exceeds a threshold
# The one above seems to perform 0.3% better so we'll use that
# def run_rerank(bert_model, data_loader, thresh, print_errors=False):
#     bert_model.eval()
#     torch.set_grad_enabled(False)
    
#     tk0 = tqdm(data_loader, total=int(len(data_loader)))
#     counter = 0
#     counter_rerank_correct = 0
#     counter_original_correct = 0
    
#     for i, batch in enumerate(tk0):
#         sample = batch[0]  # Assume batch size of 1
#         counter += 1
#         instances = sample['instances']
#         original_correct = False
#         if instances[0]['target'] > 0.5:
#             original_correct = True
#             counter_original_correct += 1
        
#         preds = []
#         for instance in instances[:10]:
#             loss, correct, pred = run_instance(bert_model, instance)
#             preds.append(pred[0][1].item()) # Remove batch, get prob of 1
#         best_pred_index = preds.index(max(preds))  # Get the index of the max, picking the first one if there's a conflict
#         rerank_correct = False
#         if preds[best_pred_index] - preds[0] < thresh:
#             # We're only going to rerank if there's a large discrepancy
#             best_pred_index = 0
        
#         if instances[best_pred_index]['target'] > 0.5:
#             rerank_correct = True
#             counter_rerank_correct += 1
        
#         if print_errors and original_correct and not rerank_correct:
#             print('\n')
#             print('*' * 10)
#             print(sample['utterance'])
#             for index, ins in enumerate(sample['instances'][:10]):
#                 p = False
#                 if index == 0:
#                     print('=' * 10)
#                     print(f'CORRECT QUERY ; score: {preds[index]}')
#                     p = True
#                 if index == best_pred_index:
#                     print('=' * 10)
#                     print(f'SELECTED QUERY ; score: {preds[index]}')
#                     p = True
#                 if p:
#                     print(ins['sql_query'])
#                     print(ins['columns_used'])
#                     print(ins['tables_used'])
            
#         rerank_acc = counter_rerank_correct/counter
#         tk0.set_postfix(original_acc=(counter_original_correct/counter),
#                         rerank_acc=rerank_acc)
#     return rerank_acc

In [None]:
num_epochs = 20

# Baseline before any training: the rerank accuracy a later epoch has to beat.
print('Initial validation run')
best_rerank_acc = run_rerank(bert_model, val_loader, thresh=0.1, print_errors=False)
run_epoch(bert_model, val_loader, train_bert, optim, training=False, print_samples=False)
best_epoch = -1  # Stays -1 until an epoch improves on the baseline

for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)
    run_epoch(bert_model, train_loader, train_bert, optim, training=True)

    print('Validation')
    run_epoch(bert_model, val_loader, train_bert, optim, training=False)

    print('Rerank')
    rerank_acc = run_rerank(bert_model, val_loader, thresh=0.1, print_errors=False)

    # Checkpoint only on a strict improvement of the rerank accuracy.
    if rerank_acc <= best_rerank_acc:
        continue
    best_epoch, best_rerank_acc = epoch, rerank_acc
    print(f'Saving model at epoch {epoch}')
    torch.save(bert_model, 'best_disrim_model.pth')
    


In [None]:
torch.save(bert_model, 'final_model.pth')


In [None]:
# Replace the in-memory model with a previously saved checkpoint and move it
# to the GPU.
# SECURITY: torch.load on a whole-model file unpickles it, which can execute
# arbitrary code; only load checkpoints from a trusted source.
# NOTE(review): newer PyTorch releases default torch.load to
# weights_only=True, which may reject a fully pickled model - confirm against
# the installed version.
#bert_model.load_state_dict(torch.load('bert_model.pth'))
with open('best_discrim_model_full_spider_retrained.pth', 'rb') as f:
    #file = f.read()
    bert_model = torch.load(f)
bert_model.cuda()

In [None]:
torch.save(bert_model.state_dict(), 'best_discrim_model_full_spider_retrained.state.pth')

In [None]:
run_rerank(bert_model, val_normaleval_loader, thresh=0.10, print_errors=False)

54.8 (harder eval) vs 55.4 (normal eval)  # bert + gnn + train_spider + HIGH LR

50.1/54.5 (harder eval) vs 51.2/55.7 (normal eval) # non retrained reranker + bert + gnn + train_all

50.1/54.5 (harder eval) vs 51.2/56.3 (normal eval) # retrained ranker + bert + gnn + train_all