In [None]:
# Dependencies (uncomment on a fresh runtime; transformers must stay pinned to 2.11.0):
# %pip install torch transformers==2.11.0 nltk tqdm numpy pandas gdown

In [2]:
import os, pickle, time, random, json, gc
from datetime import datetime
from tqdm import tqdm
import gdown
import numpy as np
import pandas as pd
import torch
import transformers
assert transformers.__version__ == '2.11.0', 'Wrong Transformer Version (must be 2.11.0). Please Factory Reset Runtime'
from transformers import BertTokenizer

In [18]:
# Download BERT model from GDrive (skipped when the file is already on disk).
_model_url = 'https://drive.google.com/uc?id=1y267OwUrFRTCHxqet3l7dEEnCMmGZJGK'
_model_file = 'models/final_BERT_model.pt'
# Make sure the target folder exists before anything is written into it.
os.makedirs(os.path.dirname(_model_file), exist_ok=True)
# The last expression doubles as the cell output: the local checkpoint path when it is
# already cached, otherwise whatever gdown.download returns for the fresh download.
(_model_file if os.path.exists(_model_file)
 else gdown.download(_model_url, quiet=True, output=_model_file))

'models/final_BERT_model.pt'

In [29]:
# Fixed sequence length: the input-building helpers pad ids, mask and segment ids to this size.
MAX_SEQ_LEN = 128
# Location of the fine-tuned BERT checkpoint on disk (the path the download cell writes to).
BERT_MODEL_PATH = 'models/final_BERT_model.pt'

In [30]:
def get_mask_ids(tokens):
    """Build the attention mask for `tokens`.

    Returns a list holding one 1 per token followed by 0s that fill the
    remaining slots up to MAX_SEQ_LEN. When `tokens` is already at least
    MAX_SEQ_LEN long, no 0s are appended.
    """
    real_count = len(tokens)
    pad_count = MAX_SEQ_LEN - real_count
    mask = [1] * real_count
    # Multiplying a list by a non-positive count yields [], so nothing is added then.
    mask.extend([0] * pad_count)
    return mask

def get_segment_ids(tokens):
    """Build the segment (token type) ids for `tokens`.

    Every token up to and including the first "[SEP]" gets id 0, every token
    after it gets id 1, and the result is padded with 0s up to MAX_SEQ_LEN.
    An assertion fails when `tokens` contains no "[SEP]" at all.
    """
    segment_ids = []
    sep_seen = False
    for tok in tokens:
        # The separator itself still belongs to the segment it closes.
        segment_ids.append(1 if sep_seen else 0)
        if tok == "[SEP]":
            sep_seen = True
    assert sep_seen
    pad_count = MAX_SEQ_LEN - len(tokens)
    return segment_ids + [0] * pad_count

def convert_to_input(tokenizer, text, ans=None):
    """Turn `text` (and optionally `ans`) into fixed-length model input tensors.

    The token layout is ``[CLS] text [SEP]`` when `ans` is falsy and
    ``[CLS] text [SEP] ans [SEP]`` otherwise. The whole sequence is kept
    within MAX_SEQ_LEN tokens and padded up to that length.

    Args:
        tokenizer: object providing `tokenize(str)` and
            `convert_tokens_to_ids(list)`.
        text: passage string to encode.
        ans: optional second string; treated as absent when falsy.

    Returns:
        Tuple `(input_ids, attention_mask, token_type_ids)` of 1-D
        `torch.long` tensors.
    """
    text_token = tokenizer.tokenize(text)
    if ans:
        # Three slots are reserved for [CLS] and the two [SEP] markers. The answer is
        # capped first so that the budget left for the text can never go negative.
        ans_token = tokenizer.tokenize(ans)[:MAX_SEQ_LEN - 3]
        # The text gets what remains after the answer and the special tokens
        # (the earlier `3 - len(ans_token)` grew the budget with the answer length).
        text_budget = MAX_SEQ_LEN - 3 - len(ans_token)
        text_token = text_token[:text_budget]
        all_tokens = ["[CLS]"] + text_token + ["[SEP]"] + ans_token + ["[SEP]"]
    else:
        # Only [CLS] and a single [SEP] need room here.
        text_token = text_token[:MAX_SEQ_LEN - 2]
        all_tokens = ["[CLS]"] + text_token + ["[SEP]"]

    token_ids = tokenizer.convert_tokens_to_ids(all_tokens)
    # 0 is used as the padding id for the unused tail positions.
    input_ids = token_ids + [0] * (MAX_SEQ_LEN - len(token_ids))

    attention_mask = get_mask_ids(all_tokens)
    token_type_ids = get_segment_ids(all_tokens)
    return (
        torch.tensor(input_ids, dtype=torch.long),
        torch.tensor(attention_mask, dtype=torch.long),
        torch.tensor(token_type_ids, dtype=torch.long),
    )

In [31]:
_BERT_TOKENIZER = None  # created on first use and then shared by every bert_inference call


def _get_bert_tokenizer():
    """Return the shared 'bert-base-uncased' tokenizer, building it on first use."""
    global _BERT_TOKENIZER
    if _BERT_TOKENIZER is None:
        _BERT_TOKENIZER = BertTokenizer.from_pretrained('bert-base-uncased',
                                                        do_lower_case=True)
    return _BERT_TOKENIZER


def bert_inference(bert_model, text, ans=None):
    """Run `bert_model` on `text` (plus optional `ans`) and decode its prediction.

    The inputs are built with `convert_to_input`, moved to the global `device`
    (which must be defined before this function is called) and passed through
    the model once. The arg-max token at each position is taken, the sequence
    is cut after the first [SEP] token (or kept whole when there is none) and
    decoded to a string without special tokens.

    Args:
        bert_model: callable accepting `input_ids`, `attention_mask`,
            `decoder_input_ids`, `token_type_ids` and `masked_lm_labels`;
            the first element of its return value is used as the logits.
        text: passage string to run inference on.
        ans: optional second string forwarded to `convert_to_input`.

    Returns:
        The decoded prediction as a string.
    """
    # Reuse one tokenizer instead of re-loading the vocabulary on every call.
    tokenizer = _get_bert_tokenizer()
    vocab_size = tokenizer.vocab_size
    input_ids, attention_mask, token_type_ids = (
        tensor.unsqueeze(0).to(device)
        for tensor in convert_to_input(tokenizer, text, ans)
    )
    # Inference only: no_grad avoids building an autograd graph for the forward pass.
    # NOTE(review): the decoder is fed the encoder input and decoded in one forward
    # pass (no autoregressive loop) — confirm this matches how the model was trained.
    with torch.no_grad():
        logits = bert_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=input_ids,
            token_type_ids=token_type_ids,
            masked_lm_labels=None
        )[0]
    logits = logits.view(-1, vocab_size)
    logits = logits.detach().cpu().numpy()

    # Greedy choice per position, converted to plain Python ints.
    predicted = logits.argmax(axis=1).flatten().tolist()
    try:
        # Cut at the first [SEP]; the id comes from the tokenizer rather than a literal.
        length = predicted.index(tokenizer.sep_token_id)
    except ValueError:
        length = len(predicted) - 1

    predicted = predicted[:length + 1]
    return tokenizer.decode(predicted, skip_special_tokens=True)

In [None]:
# Stop early with a clear message when PyTorch cannot see a CUDA device.
assert torch.cuda.is_available(), 'CUDA device is required'

In [34]:
# Pick the device tensors are moved to: CUDA when PyTorch reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [32]:
# SECURITY: torch.load unpickles the file, which can execute arbitrary code. Only load
# checkpoints from a source you trust (this one is downloaded from Google Drive).
# map_location places the weights on the same device the inputs are moved to.
bert_model = torch.load(BERT_MODEL_PATH, map_location=device)
# Put the model in evaluation mode so train-only layers such as dropout are disabled
# and repeated inference on the same input gives the same result.
bert_model = bert_model.eval()

In [38]:
# Load the test records (later cells read their 'passages' and 'clues' entries).
# The context manager closes the file handle as soon as parsing is done.
with open('data/data_test.json', 'r') as test_file:
    data_test = json.load(test_file)

In [35]:
# Decode one prediction per test record from its passage text, in dataset order.
outputs = []
outputs.extend(
    bert_inference(bert_model, record['passages']) for record in data_test
)

In [36]:
# Show the decoded predictions for a quick visual check.
outputs

['sir knight sir sir sir sir knight knight knight knight knight knight knight knight robin robin robin robin robin robin robin',
 'she 1997 her her her her her her her " " " " " " " " " " " " " " " "',
 'in 5 this heads this this thisev alexei alexei alexei alexei alexei alexei alexei',
 'in " " " " " " " " " " " " " " " " " " " " " " " " "',
 'in 1930,,,, the the the to to to to to to to to to to to',
 'kemp kemp kemp s s s s s s " " " the the the the the the " "',
 'uganda world,,,,,,,,',
 'in ",,,,,,,,,,,,,,,,,,\'street "',
 'in 2002 prime prime prime prime prime prime prime prime prime prime prime prime prime prime prime prime prime prime',
 'he nov was was was was was justice justice justice justice justice justice justice justice justice justice justice',
 'he the the the the the the the',
 'benjamin ",,,,,, " " " " " " " " " " "',
 '" " " " " " " " " " " " " " " " " " " " " " " " " " " " " "',
 'a a tennis tennis tennis tennis tennis tennis tennis tennis tennis this -',
 'in\'\'

In [41]:
# Pair each record's clue with the model's prediction and write the table to CSV.
truth_column = [record['clues'] for record in data_test]
results_table = pd.DataFrame({
    'Truths (Bert)': truth_column,
    'Predicted (Bert)': outputs,
})
results_table.to_csv('data/results_BERT.csv', index=False)