## Define scorer

In [None]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch, tqdm, logging
import numpy as np

logging.disable(logging.WARNING)
device = torch.device("cpu")

tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
model = GPT2LMHeadModel.from_pretrained("gpt2-medium", return_dict=True).to(device).eval()
tokenizer.pad_token = tokenizer.eos_token

In [2]:
def count_loss(logits, inputs, attention_mask):
    labels = torch.nn.functional.one_hot(inputs.clone(), num_classes=50257)
    labels = (labels.permute(2, 0, 1) * attention_mask).permute(1, 2, 0)
    logits = (logits.permute(2, 0, 1) * attention_mask).permute(1, 2, 0)
    return  torch.log((torch.nn.functional.softmax(logits[:, :-1, :], dim=-1) * labels[:, 1:, :]).sum(dim=-1)
                     + (1 - attention_mask[:, 1:])).sum(dim=-1) / attention_mask.sum(dim=-1)

def perplexity(generated_sents):
    # does not use input_sent
    max_length = model.config.n_positions # 1024
    inputs = tokenizer(
        generated_sents, max_length=max_length, truncation=True, padding=True, return_tensors='pt'
    )
    with torch.no_grad():
        outputs = model(inputs["input_ids"].to(device))
        return count_loss(
            outputs.logits.cpu(), inputs["input_ids"].cpu(), inputs["attention_mask"]).cpu().numpy()

In [4]:
def get_perplexity_per_doc(model, tokenizer, doc):
    max_length = model.config.n_positions
    stride=512

    inputs = tokenizer(doc, return_tensors='pt')
    lls = []
    for i in range(0, inputs['input_ids'].size(1), stride):
        # getting the coordinates of the window
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, inputs['input_ids'].size(1))
        target_len = end_loc - i

        input_ids = inputs['input_ids'][:, begin_loc:end_loc]
        target_ids = input_ids.clone()

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            log_likelihood = outputs.loss * target_len
        lls.append(log_likelihood)
    return torch.exp(torch.stack(lls).sum() / end_loc).item()

## Check correlation with GCDC data

You need to download GCDC dataset first. To get access to the data, please contact the authors of the GCDC paper

In [5]:
import pandas as pd
perplexities = []
test_dataset = pd.read_csv("../GCDC_rerelease/test.csv")

In [17]:
for row in test_dataset.iterrows():
    perplexities.append(perplexity(row[1].text))

In [18]:
import numpy as np
np.quantile(perplexities, 0.25), np.quantile(perplexities, 0.3)

  return array(a, dtype, copy=False, order=order, subok=True)


(array([-3.3810136], dtype=float32), array([-3.2851007], dtype=float32))

In [24]:
def get_th(p):
    if p < np.quantile(perplexities, 0.2):
        return 1
    if p < np.quantile(perplexities, 0.3):
        return 2
    return 3
classified = [get_th(p) for p in perplexities]
acc = (np.array(list(test_dataset.label)) == np.array(classified[800:])).mean(); acc

0.47375

In [25]:
prev_acc = 0.34
for b_2 in [0.28, 0.3, 0.32, 0.34]:
    def get_th(p):
        if p < np.quantile(perplexities, 0.2):
            return 1
        if p < np.quantile(perplexities, b_2):
            return 2
        return 3
    classified = [get_th(p) for p in perplexities[800:]]
    acc = (np.array(list(test_dataset.label)) == np.array(classified)).mean()
    print(b_2, acc)

0.28 0.48375
0.3 0.47375
0.32 0.4675
0.34 0.44875


In [68]:
prev_acc = 0.51375
for b_2 in [0.3]:
    def get_th(p):
        if p < np.quantile(perplexities, 0.2):
            return 1
        if p < np.quantile(perplexities, b_2):
            return 2
        return 3
    classified = [get_th(p) for p in perplexities]
    acc = (np.array(list(test_dataset.label)) == np.array(classified)).mean()
    print(b_2, acc)

0.3 0.52375


In [46]:
(np.array(list(test_dataset.label)) == np.array(classified)).mean()

0.51375

In [7]:
perplexities = [perp[0] for perp in perplexities]

In [11]:
from scipy.stats import spearmanr

spearmanr(perplexities, list(test_dataset.label))

SignificanceResult(statistic=0.32728579266288776, pvalue=1.9892314863379266e-21)

In [12]:
from scipy.stats import pearsonr

pearsonr(perplexities, list(test_dataset.label))

PearsonRResult(statistic=0.33584569962759214, pvalue=1.5225851221352634e-22)