In [1]:
import argparse
import torch
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from datasets import load_dataset
from dexperts import DExperts

In [2]:
dexperts = DExperts(
    base_model='gpt2',
    # antiexpert_model='eliolio/gpt2-finetuned-redditbias',
    # expert_model='eliolio/gpt2-finetuned-reddit-antibias',
    tokenizer='gpt2',
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
data = load_dataset('wikitext', 'wikitext-2-v1', split='test')
encodings = dexperts.tokenizer('\n\n'.join(data['text']), return_tensors='pt')

Found cached dataset wikitext (/Users/eliott/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
Token indices sequence length is longer than the specified maximum sequence length for this model (297300 > 1024). Running this sequence through the model will result in indexing errors


In [4]:
encodings.input_ids.size(1)

297300

In [5]:
stride = -1
alpha = 2.0
max_length = -1
max_length_pattern = 32
max_length = (max_length if max_length > 0 else dexperts.base_model.config.n_positions) - max_length_pattern
if stride <= 0:
    stride = max_length

```
    def compute_perplexity(self, prompt: str, alpha: float = None):
        encodings_dict = self.tokenizer(
            prompt, return_tensors="pt", padding=True, return_attention_mask=True
        ).to(self.device)
        encoded_text = encodings_dict["input_ids"]
        attn_mask = encodings_dict["attention_mask"]
        if alpha is None:
            alpha = self.alpha
        logits = self._get_logits(encoded_text, alpha=alpha)
        return self._get_perplexity(logits, encoded_text)

```

In [7]:
lls_debiased, lls_regular = [], []
ppl_debiased, ppl_regular = None, None

for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, encodings.input_ids.size(1))
    trg_len = end_loc - i  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        # loss_regular = compute_loss(input_ids, labels=target_ids)
        logits = dexperts._get_logits(input_ids, alpha=alpha)
        # print(logits.shape)
        loss_regular = dexperts._get_perplexity(logits=logits, labels=target_ids, exp=False)
        # print(loss_regular)
        log_likelihood_regular = loss_regular * trg_len

    lls_regular.append(log_likelihood_regular)

    ppl_regular = torch.exp(torch.stack(lls_regular).sum() / end_loc)
    # print(f'Perplexity after {i} tokens: {ppl_debiased} (debiased) vs {ppl_regular} (regular)')
print(f'Final perplexity: {ppl_debiased} (debiased) vs {ppl_regular} (regular)')

100%|██████████| 300/300 [03:07<00:00,  1.60it/s]

Final perplexity: None (debiased) vs 24.770050048828125 (regular)





In [8]:
dexperts = DExperts(
    base_model='gpt2',
    antiexpert_model='eliolio/gpt2-finetuned-redditbias',
    expert_model='eliolio/gpt2-finetuned-reddit-antibias',
    tokenizer='gpt2',
)

In [9]:
lls_debiased, lls_regular = [], []
ppl_debiased, ppl_regular = None, None

for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, encodings.input_ids.size(1))
    trg_len = end_loc - i  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        # loss_regular = compute_loss(input_ids, labels=target_ids)
        logits = dexperts._get_logits(input_ids, alpha=alpha)
        # print(logits.shape)
        loss_regular = dexperts._get_perplexity(logits=logits, labels=target_ids, exp=False)
        # print(loss_regular)
        log_likelihood_regular = loss_regular * trg_len

    lls_regular.append(log_likelihood_regular)

    ppl_regular = torch.exp(torch.stack(lls_regular).sum() / end_loc)
    # print(f'Perplexity after {i} tokens: {ppl_debiased} (debiased) vs {ppl_regular} (regular)')
print(f'Final perplexity: {ppl_debiased} (debiased) vs {ppl_regular} (regular)')

100%|██████████| 300/300 [08:57<00:00,  1.79s/it]

Final perplexity: None (debiased) vs 31.38951301574707 (regular)



