Pseudo-perplexity scoring code adapted from this [Stack Overflow thread](https://stackoverflow.com/questions/70464428/how-to-calculate-perplexity-of-a-sentence-using-huggingface-masked-language-mode).

In [159]:
from transformers import AutoModelForMaskedLM, AutoTokenizer
from datasets import load_from_disk
from collections import defaultdict
import torch
import numpy as np


In [171]:
# Load the frozen corpus and keep only its held-out split; later cells add
# model-specific input columns to it and score it.
CORPUS_PATH = '/datadrive_2/frozen_corpus'
dataset = load_from_disk(CORPUS_PATH)
test_data = dataset['test']

In [176]:
def pred_data(example):
    """Build the three year-prefixed input texts for one dataset row.

    Returns a dict of new columns (all strings):
        st_year_sep: "[<year>] [SEP] <sentence>"
        year_sep:    "<year> [SEP] <sentence>"
        year_date:   "<year> [DATE] <sentence>"
    Each checkpoint is later evaluated on the column matching its training format.
    """
    year = str(example['year'])
    sentence = example['sentences']
    return {
        'st_year_sep': '[' + year + '] [SEP] ' + sentence,
        'year_sep': year + ' [SEP] ' + sentence,
        'year_date': year + ' [DATE] ' + sentence,
    }
    
# Attach the st_year_sep / year_sep / year_date text columns to every row.
test_data = test_data.map(function=pred_data, num_proc=6)

         

#0:   0%|          | 0/17 [00:00<?, ?ex/s]

   

#1:   0%|          | 0/17 [00:00<?, ?ex/s]

#2:   0%|          | 0/17 [00:00<?, ?ex/s]

#5:   0%|          | 0/16 [00:00<?, ?ex/s]

#3:   0%|          | 0/17 [00:00<?, ?ex/s]

#4:   0%|          | 0/16 [00:00<?, ?ex/s]

In [177]:
# Fixed-seed random sample for the (expensive, one forward row per token)
# pseudo-perplexity evaluation. Clamped so a split with fewer than N_EVAL rows
# does not make `select` fail with an out-of-range index.
N_EVAL = 100
test_data = test_data.shuffle(seed=42).select(range(min(N_EVAL, len(test_data))))

In [178]:
# Sanity check: the sampled split should list the derived st_year_sep / year_sep / year_date columns.
test_data

Dataset({
    features: ['year', 'nlp', 'pol', 'loc', 'sentences', 'ocr', 'length', 'st_year_sep', 'year_sep', 'year_date'],
    num_rows: 100
})

In [181]:
# One row per evaluated model:
# (run name, hub id or local checkpoint path, separator token, input text column)
checkpoints = [
    ('distilbert', 'distilbert-base-uncased', '[SEP]', 'year_sep'),
    ('hmd_distilbert', '/datadrive_2/bnert-hmd', '[SEP]', 'year_sep'),
    ('bnert-time-st-y', '/datadrive_2/bnert-time-st-y', '[SEP]', 'st_year_sep'),
    ('bnert-time-y', '/datadrive_2/bnert-time-y', '[DATE]', 'year_date'),
    ('bnert-time-y_masked_25', '/datadrive_2/bnert-time-y_masked_25', '[DATE]', 'year_date'),
    ('bnert-time-y_masked_75', '/datadrive_2/bnert-time-y_masked_75', '[DATE]', 'year_date'),
]

# Load every model/tokenizer pair up front and record which text column
# (built by pred_data) each model is scored on. The separator token is
# currently not stored (it is implied by the chosen column).
model_dict = defaultdict(dict)
for run_name, source, _sep_token, input_col in checkpoints:
    entry = model_dict[run_name]
    entry['model'] = AutoModelForMaskedLM.from_pretrained(source)
    entry['tokenizer'] = AutoTokenizer.from_pretrained(source)
    entry['sentences'] = input_col

In [192]:
def pseudo_perplexity(example, sent_col, name, model, tokenizer,
                      n_prefix_tokens=3, max_length=128):
    """Pseudo-perplexity of one dataset row under a masked language model.

    Every scorable token is masked exactly once (one row per token, all rows in
    a single batch). The model's mean cross-entropy over those masked positions
    is exponentiated. The first ``n_prefix_tokens`` tokens and the final token
    (the closing special token) are never masked. With the default of 3 these
    are [CLS], the year and the separator (e.g. ``1845 [DATE]``). The year
    metadata therefore conditions the predictions but is not itself scored.

    Args:
        example: dataset row (mapping); ``example[sent_col]`` is the input text.
        sent_col: column holding this model's input text.
        name: model name, used to build the output column key.
        model: masked LM returning an object with ``.loss`` when given ``labels``.
        tokenizer: the model's tokenizer (needs ``encode`` and ``mask_token_id``).
        n_prefix_tokens: number of leading tokens to leave unmasked (default 3).
            NOTE(review): assumes the year is a single token for every
            tokenizer in use; confirm, otherwise part of the year is scored.
        max_length: truncation length passed to the tokenizer.

    Returns:
        ``{f'loss_{name}': value}``. Despite the key name, the value is the
        pseudo-perplexity (exp of the mean loss), not the loss. It is ``nan``
        when the text leaves no scorable tokens.
    """
    input_ids = tokenizer.encode(example[sent_col], return_tensors='pt',
                                 truncation=True, max_length=max_length)[0]
    seq_len = input_ids.size(0)
    n_rows = seq_len - 1 - n_prefix_tokens
    if n_rows <= 0:
        # Nothing between the unmasked prefix and the final token.
        return {f'loss_{name}': float('nan')}

    # Row i masks position positions[i]: after the prefix, before the final token.
    positions = torch.arange(n_prefix_tokens, seq_len - 1)
    rows = torch.arange(n_rows)
    repeat_input = input_ids.repeat(n_rows, 1)  # (n_rows, seq_len)
    masked_input = repeat_input.clone()
    masked_input[rows, positions] = tokenizer.mask_token_id

    # Label only the position masked in each row. Deriving labels from the
    # positions (not from `masked_input == mask_token_id`) prevents a literal
    # [MASK] already present in the text from being scored in every row.
    labels = torch.full_like(repeat_input, -100)
    labels[rows, positions] = repeat_input[rows, positions]

    with torch.inference_mode():
        loss = model(masked_input, labels=labels).loss
    return {f'loss_{name}': np.exp(loss.item())}

In [None]:
# Score each model on its own input column. Every pass adds a `loss_<name>`
# column (pseudo-perplexity, see pseudo_perplexity) to `test_data`.
# Runs in-process on purpose: with num_proc > 1 the full model passed through
# fn_kwargs is copied into every worker process, and forking after PyTorch has
# started its intra-op thread pool can deadlock the workers.
for name, ndict in model_dict.items():
    print(f'Evaluating {name}')
    test_data = test_data.map(
        pseudo_perplexity,
        fn_kwargs={
            'sent_col': ndict['sentences'],
            'name': name,
            'model': ndict['model'],
            'tokenizer': ndict['tokenizer'],
        },
    )

Evaluating distilbert
    

#0:   0%|          | 0/34 [00:00<?, ?ex/s]

 

#1:   0%|          | 0/33 [00:00<?, ?ex/s]

 

#2:   0%|          | 0/33 [00:00<?, ?ex/s]

In [155]:
# NOTE(review): `score_pseudo_perplexity`, `model` and `tokenizer` are not defined in any
# visible cell (leftover kernel state), so this cell fails on a fresh kernel.
# TODO: replace with `pseudo_perplexity(...)` on an explicit `model_dict[...]` entry.
score_pseudo_perplexity(model,tokenizer,'1845 [SEP] London is the capital of Great Britain.')

2.350935507707432

In [146]:
# NOTE(review): `score_pseudo_perplexity`, `model` and `tokenizer` are not defined in any
# visible cell (leftover kernel state), so this cell fails on a fresh kernel.
# The printed tensors below show with_meta=True masking positions 3..n-2 only
# (after [CLS], the year and [DATE]), the same positions `pseudo_perplexity` scores.
score_pseudo_perplexity(model,tokenizer,'1845 [DATE] London is the capital of Great Britain.',with_meta=True)

tensor([[  101,  9512, 30522,   103,  2003,  1996,  3007,  1997,  2307,  3725,
          1012,   102],
        [  101,  9512, 30522,  2414,   103,  1996,  3007,  1997,  2307,  3725,
          1012,   102],
        [  101,  9512, 30522,  2414,  2003,   103,  3007,  1997,  2307,  3725,
          1012,   102],
        [  101,  9512, 30522,  2414,  2003,  1996,   103,  1997,  2307,  3725,
          1012,   102],
        [  101,  9512, 30522,  2414,  2003,  1996,  3007,   103,  2307,  3725,
          1012,   102],
        [  101,  9512, 30522,  2414,  2003,  1996,  3007,  1997,   103,  3725,
          1012,   102],
        [  101,  9512, 30522,  2414,  2003,  1996,  3007,  1997,  2307,   103,
          1012,   102],
        [  101,  9512, 30522,  2414,  2003,  1996,  3007,  1997,  2307,  3725,
           103,   102]])
tensor([[-100, -100, -100, 2414, -100, -100, -100, -100, -100, -100, -100, -100],
        [-100, -100, -100, -100, 2003, -100, -100, -100, -100, -100, -100, -100],
        [-1

2.744492316460395

In [88]:
def score_random_mask(model, tokenizer, sentence, meta_pos=None,
                      mask_prob=0.15, rng=None, verbose=False):
    """Perplexity of one sentence under a single random MLM-style masking.

    Each eligible token position is masked independently with probability
    ``mask_prob``. The draw is repeated until at least one eligible position is
    masked. The model's mean cross-entropy over the masked positions is then
    exponentiated.

    Args:
        model: masked LM returning an object with ``.loss`` when given ``labels``.
        tokenizer: the model's tokenizer (needs ``encode`` and ``mask_token_id``).
        sentence: raw input text.
        meta_pos: token index or list of indices that must never be masked
            (e.g. year metadata). Position 0 is honoured.
        mask_prob: per-position masking probability; must be > 0.
        rng: object with a NumPy-style ``binomial(n, p, size)``, e.g.
            ``np.random.default_rng(seed)`` for reproducible masks. Defaults
            to the global ``np.random`` state.
        verbose: print the masked input ids and labels (debugging aid).

    Returns:
        exp(mean masked-token loss), or ``nan`` if ``meta_pos`` excludes every position.

    Raises:
        ValueError: if ``mask_prob`` is not > 0 (the resampling loop could never end).
    """
    if not mask_prob > 0:
        raise ValueError(f'mask_prob must be > 0, got {mask_prob}')
    input_ids = tokenizer.encode(sentence, return_tensors='pt')  # (1, seq_len)
    seq_len = input_ids.shape[1]

    # Exclude metadata positions *before* sampling, so the "at least one
    # masked token" guarantee holds for the mask that is actually used.
    eligible = torch.ones(seq_len, dtype=torch.bool)
    if meta_pos is not None:
        eligible[meta_pos] = False
    if not eligible.any():
        return float('nan')

    sampler = np.random if rng is None else rng
    mask = torch.zeros(seq_len, dtype=torch.bool)
    while not mask.any():
        draws = sampler.binomial(1, mask_prob, seq_len)
        mask = torch.as_tensor(draws == 1) & eligible

    masked_input = input_ids.masked_fill(mask, tokenizer.mask_token_id)
    # Label exactly the sampled positions, so a literal [MASK] in the text is not scored.
    labels = input_ids.masked_fill(~mask, -100)
    if verbose:
        print(masked_input)
        print(labels)
    with torch.inference_mode():
        loss = model(masked_input, labels=labels).loss
    return np.exp(loss.item())

In [None]:
from transformers import Trainer

In [None]:
# NOTE(review): `model`, `training_args`, `lm_datasets` and `data_collator` are not
# defined in any visible cell (leftover kernel state), so this cell fails on a
# fresh kernel; define them here or remove this leftover training cell.
train_split = lm_datasets["train"]
eval_split = lm_datasets["test"]
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_split,
    eval_dataset=eval_split,
    data_collator=data_collator,
)