In [1]:
import argparse
import torch
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from datasets import load_dataset
from dexperts import DExperts

In [24]:
dexperts = DExperts(
    base_model='bigscience/bloom-560m',
    # antiexpert_model='eliolio/gpt2-finetuned-redditbias',
    # expert_model='eliolio/gpt2-finetuned-reddit-antibias',
    tokenizer='bigscience/bloom-560m',
)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [54]:
data = load_dataset('wikitext', 'wikitext-2-v1', split='test')
encodings = dexperts.tokenizer('\n\n'.join(data['text'][:10]), return_tensors='pt')

Found cached dataset wikitext (/Users/eliott/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


In [55]:
encodings.input_ids.size(1)

406

In [61]:
stride = -1
alpha = 2.0
max_length = -1
max_length_pattern = 32
max_length = (max_length if max_length > 0 else 1024) - max_length_pattern
if stride <= 0:
    stride = max_length

```
    def compute_perplexity(self, prompt: str, alpha: float = None):
        encodings_dict = self.tokenizer(
            prompt, return_tensors="pt", padding=True, return_attention_mask=True
        ).to(self.device)
        encoded_text = encodings_dict["input_ids"]
        attn_mask = encodings_dict["attention_mask"]
        if alpha is None:
            alpha = self.alpha
        logits = self._get_logits(encoded_text, alpha=alpha)
        return self._get_perplexity(logits, encoded_text)

```

In [64]:
max_length

992

In [62]:
lls_debiased, lls_regular = [], []
ppl_debiased, ppl_regular = None, None

for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, encodings.input_ids.size(1))
    trg_len = end_loc - i  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        # loss_regular = compute_loss(input_ids, labels=target_ids)
        logits = dexperts._get_logits(input_ids, alpha=alpha)
        # print(logits.shape)
        loss_regular = dexperts._get_perplexity(logits=logits, labels=target_ids, exp=False)
        # print(loss_regular)
        log_likelihood_regular = loss_regular * trg_len

    lls_regular.append(log_likelihood_regular)

    ppl_regular = torch.exp(torch.stack(lls_regular).sum() / end_loc)
    # print(f'Perplexity after {i} tokens: {ppl_debiased} (debiased) vs {ppl_regular} (regular)')
print(f'Final perplexity: {ppl_debiased} (debiased) vs {ppl_regular} (regular)')

100%|██████████| 1/1 [00:02<00:00,  2.47s/it]

Final perplexity: None (debiased) vs 26.427648544311523 (regular)





In [49]:
stride

992

In [50]:
max_length

992

In [45]:
torch.exp(torch.stack(lls_regular).sum() / end_loc)

tensor(26.4223)

In [5]:
dexperts = DExperts(
    base_model='gpt2',
    antiexpert_model='eliolio/gpt2-finetuned-redditbias',
    expert_model='eliolio/gpt2-finetuned-reddit-antibias',
    tokenizer='gpt2',
)

In [6]:
dexperts.base_model.config.n_positions

1024

In [9]:
lls_debiased, lls_regular = [], []
ppl_debiased, ppl_regular = None, None

for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, encodings.input_ids.size(1))
    trg_len = end_loc - i  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        # loss_regular = compute_loss(input_ids, labels=target_ids)
        logits = dexperts._get_logits(input_ids, alpha=alpha)
        # print(logits.shape)
        loss_regular = dexperts._get_perplexity(logits=logits, labels=target_ids, exp=False)
        # print(loss_regular)
        log_likelihood_regular = loss_regular * trg_len

    lls_regular.append(log_likelihood_regular)

    ppl_regular = torch.exp(torch.stack(lls_regular).sum() / end_loc)
    # print(f'Perplexity after {i} tokens: {ppl_debiased} (debiased) vs {ppl_regular} (regular)')
print(f'Final perplexity: {ppl_debiased} (debiased) vs {ppl_regular} (regular)')

100%|██████████| 300/300 [08:57<00:00,  1.79s/it]

Final perplexity: None (debiased) vs 31.38951301574707 (regular)





In [44]:
torch.exp(torch.stack(lls_regular).sum() / end_loc)

tensor(26.4223)

In [41]:
31.38/2

15.69

In [33]:
from transformers import BloomForCausalLM
model = BloomForCausalLM.from_pretrained('bigscience/bloom-560m')

In [39]:
model.seqlen = 2048

In [40]:
model.seqlen

2048

In [4]:
dexperts("my name is eli and i am a computer scientist")

{'logits': tensor([[[349.2594, 353.2092, 360.3812,  ..., 209.4018, 209.4017, 209.3962],
          [365.7440, 367.9818, 380.9273,  ..., 211.6974, 211.6970, 211.6921],
          [402.1103, 402.5622, 416.0273,  ..., 211.3478, 211.3476, 211.3427],
          ...,
          [404.6316, 406.3918, 423.9075,  ..., 208.6725, 208.6725, 208.6665],
          [397.6118, 396.9107, 419.1426,  ..., 205.5356, 205.5356, 205.5297],
          [405.3145, 406.4235, 428.2386,  ..., 207.0648, 207.0648, 207.0592]]]),
 'perplexity': tensor(24.7266),
 'encoded_text': tensor([[  5644,   4040,    632,    466,     76,    530,    707,    912,    267,
           26371, 140541]])}

In [7]:
dexperts("my name is eli and i am a computer scientist")

{'logits': tensor([[[ -24.1817,  -22.8510,  -26.6129,  ...,  -31.3580,  -29.9223,
            -24.4473],
          [ -76.6425,  -72.7184,  -82.4079,  ...,  -86.0333,  -84.7519,
            -79.6788],
          [ -85.4887,  -83.0297,  -89.8329,  ...,  -93.7603,  -92.6636,
            -87.9938],
          ...,
          [-105.5130, -103.0655, -106.0845,  ..., -111.6936, -110.1144,
           -106.2048],
          [-106.0199, -105.1366, -109.8357,  ..., -118.4531, -116.0486,
           -108.5136],
          [-103.2196, -101.2882, -106.4613,  ..., -116.1997, -114.7123,
           -104.2242]]]),
 'perplexity': tensor(134.4997),
 'encoded_text': tensor([[ 1820,  1438,   318,  1288,    72,   290,  1312,   716,   257,  3644,
          11444]])}

In [None]:
nohup python -u evaluateBias_dexperts.py --prompt_dir ../prompts/ --base_model gpt2-medium --out_dir results/dexperts_gpt2_med_alpha1 --alpha 1.0 > dexperts_gpt2_med_alpha1.log &