In [1]:
import pandas as pd
from tqdm import tqdm

import torch

from transformers import BertTokenizer, BertForNextSentencePrediction
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

In [2]:
df = pd.read_csv("roc_stories/cloze_test_test__spring2016 - cloze_test_ALL_test.tsv", sep="\t", low_memory=False)

In [3]:
df['context'] = df.apply(lambda x: f"{x.InputSentence1} {x.InputSentence2} {x.InputSentence3} {x.InputSentence4}", axis=1)

In [4]:
df.drop(['InputSentence1', 'InputSentence2', 'InputSentence3', 'InputSentence4'], axis=1, inplace=True)

In [5]:
df.columns

Index(['InputStoryid', 'RandomFifthSentenceQuiz1', 'RandomFifthSentenceQuiz2',
       'AnswerRightEnding', 'context'],
      dtype='object')

## NSP prediction

In [6]:
model_name = "bert-base-uncased"
device = "cpu"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForNextSentencePrediction.from_pretrained(model_name).to(device)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForNextSentencePrediction: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
def select_nsp(context, ending1, ending2):
    inputs = tokenizer([context, context], [ending1, ending2], return_tensors='pt', padding=True, truncation=True).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
        scores = torch.softmax(logits, axis=-1)[:,0]
        if scores[0] > scores[1]:
            return 1
        else:
            return 2

In [8]:
data = list(zip(df.context.to_list(), df.RandomFifthSentenceQuiz1.tolist(), df.RandomFifthSentenceQuiz2.tolist()))
nsp_preds = list()
for i in tqdm(range(len(data))):
    context, ending1, ending2 = data[i]
    pred = select_nsp(context, ending1, ending2)
    nsp_preds.append(pred)

100%|██████████| 1871/1871 [13:31<00:00,  2.31it/s]


In [10]:
df['nsp_pred'] = nsp_preds
print(f"NSP accuracy: {sum(df['nsp_pred'] == df['AnswerRightEnding'])/len(data)}")

NSP accuracy: 0.5799037947621593


## Perplexity

In [18]:
model_name = 'gpt2'
device = 'cpu'
model = GPT2LMHeadModel.from_pretrained(model_name).to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(model_name)

In [21]:
def calc_nll(context, ending):
    evidence_inp = tokenizer(context, return_tensors='pt')
    claim_inp = tokenizer(ending, return_tensors='pt')
    tgt_len = claim_inp.input_ids.size(1)
    input_ids = torch.cat([evidence_inp.input_ids, claim_inp.input_ids], axis=-1).to(device)
    target_ids = input_ids.clone()
    # mask the evidence so they're not considered when calculating the perplexity
    target_ids[:, :-tgt_len] = -100
    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
        nll = outputs[0] # TODO: confirm whether to multiply by * tgt_len
        return nll

In [24]:
data = list(zip(df.context.to_list(), df.RandomFifthSentenceQuiz1.tolist(), df.RandomFifthSentenceQuiz2.tolist()))
nll_preds = list()
for i in tqdm(range(len(data))):
    context, ending1, ending2 = data[i]
    nll1 = calc_nll(context, ending1)
    nll2 = calc_nll(context, ending2)
    if nll1 < nll2:
        nll_preds.append(1)
    else:
        nll_preds.append(2)

100%|██████████| 1871/1871 [20:33<00:00,  1.52it/s]


In [25]:
df['ppl_pred'] = nll_preds
print(f"ppl accuracy: {sum(df['ppl_pred'] == df['AnswerRightEnding'])/len(data)}")

ppl accuracy: 0.5873864243719936


In [26]:
df.to_csv("roc_stories/result.tsv", sep="\t", index=False)