In [1]:
import argparse
import math

import torch
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from datasets import load_dataset
import evaluate
# from modeling import GPT2Wrapper

In [2]:
# Pretrained GPT-2 tokenizer and language model used for the perplexity evaluation.
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# The evaluation loop sends its inputs to `device`, so the model weights must
# live there too; otherwise a CUDA machine raises a device-mismatch error.
# `.eval()` disables dropout so the measured loss is deterministic (it is a
# no-op if the model is already in eval mode) and returns the module itself.
model = GPT2LMHeadModel.from_pretrained('gpt2').to(device).eval()

In [3]:
# Load the `evaluate` perplexity module, registered under the "measurement" type.
# NOTE(review): `perplexity` is not referenced by the other cells shown here.
perplexity = evaluate.load(
    "perplexity",
    module_type="measurement",
)

In [4]:
# Test split of the 'wikitext-2-v1' configuration of the 'wikitext' dataset.
data = load_dataset(
    path='wikitext',
    name='wikitext-2-v1',
    split='test',
)

Found cached dataset wikitext (/Users/eliott/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


In [5]:
# Concatenate every row of the split into one long string (rows separated by a
# blank line) and tokenize it in a single call, returning PyTorch tensors.
full_text = '\n\n'.join(data['text'])
encodings = tokenizer(full_text, return_tensors='pt')

Token indices sequence length is longer than the specified maximum sequence length for this model (297300 > 1024). Running this sequence through the model will result in indexing errors


In [6]:
# Sanity check: shape of the tokenized corpus tensor.
encodings.input_ids.shape

torch.Size([1, 297300])

In [8]:
def compute_loss(input_ids: torch.LongTensor, labels: torch.LongTensor, lm=None) -> torch.Tensor:
    """Mean next-token cross-entropy of ``labels`` given ``input_ids``.

    Args:
        input_ids: Token ids fed to the language model. The last axis is the
            sequence axis (the shift below slices it).
        labels: Target ids aligned position-by-position with ``input_ids``.
            Positions set to -100 are skipped, since -100 is the default
            ``ignore_index`` of ``CrossEntropyLoss``.
        lm: Optional callable mapping ``input_ids`` to an object exposing a
            ``.logits`` tensor. Defaults to the notebook-level ``model`` so
            existing two-argument calls keep working.

    Returns:
        A scalar tensor: the loss averaged over the non-ignored labels after
        the one-position shift (so the first label is never scored).
    """
    lm = model if lm is None else lm

    # Labels are deliberately NOT forwarded to the model: passing them makes the
    # model compute its own loss, which was thrown away, and moves the logits to
    # a different position in the output. Reading `.logits` is unambiguous.
    lm_logits = lm(input_ids).logits

    # Shift so that tokens < n predict n
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    loss_fct = CrossEntropyLoss()
    return loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

In [9]:
# Windowing configuration. A value <= 0 for `stride` or `max_length` means
# "derive it": `max_length` falls back to the model's `n_positions`, and
# `stride` falls back to the final `max_length`.
stride = -1
max_length = -1
max_length_pattern = 32  # subtracted from the window length below

context_limit = model.config.n_positions if max_length <= 0 else max_length
max_length = context_limit - max_length_pattern
stride = stride if stride > 0 else max_length

In [10]:

# Sliding-window perplexity of the model over the tokenized corpus.
lls_debiased, lls_regular = [], []
# Only the "regular" perplexity is computed in this cell; the debiased value
# stays None and is printed as such.
ppl_debiased, ppl_regular = None, None

seq_len = encodings.input_ids.size(1)
n_scored = 0  # total number of tokens that actually contributed to the loss

for i in tqdm(range(0, seq_len, stride)):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, seq_len)
    trg_len = end_loc - i  # may be different from stride on last loop
    # Place inputs wherever the model's weights are, so the forward pass never
    # mixes devices.
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(model.device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100  # context-only positions are not scored

    # compute_loss drops the first label when it shifts, so the mean loss is
    # taken over the non-masked labels from position 1 onwards. Weight by that
    # count (not by trg_len) to recover the summed negative log-likelihood.
    n_valid = int((target_ids[:, 1:] != -100).sum())
    if n_valid == 0:
        # Nothing to score in this window; a mean over zero targets would be
        # NaN and poison the running sum.
        continue

    with torch.no_grad():
        loss_regular = compute_loss(input_ids, labels=target_ids)

    lls_regular.append(loss_regular * n_valid)
    n_scored += n_valid

# Perplexity = exp(total NLL / number of scored tokens); computed once after the
# loop instead of re-stacking the whole list on every iteration.
if n_scored > 0:
    ppl_regular = torch.exp(torch.stack(lls_regular).sum() / n_scored)

print(f'Final perplexity: {ppl_debiased} (debiased) vs {ppl_regular} (regular)')

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


100%|██████████| 300/300 [03:21<00:00,  1.49it/s]

Final perplexity: None (debiased) vs 24.770050048828125 (regular)





In [48]:
# Scratch check: perplexity corresponding to a mean loss of 15.5 nats.
# Uses the standard library so the cell runs on a fresh kernel (numpy is never
# imported in this notebook).
math.exp(15.5)

5389698.476283012