In [1]:
import argparse
import torch
from torch.nn import CrossEntropyLoss
from tqdm import tqdm

from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from datasets import load_dataset
from dexperts import DExperts

In [2]:
# Base GPT-2 only for this run: the expert / anti-expert checkpoints and alpha
# stay commented out, so DExperts falls back to its own defaults for them.
dexperts = DExperts(
    tokenizer='gpt2',
    base_model='gpt2',
    # expert_model='eliolio/gpt2-finetuned-reddit-antibias',
    # antiexpert_model='eliolio/gpt2-finetuned-redditbias',
    # alpha=0.5,
)

# Device string used by later cells when moving input tensors.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [19]:
# Trigger string under test; tokenize it and remember how many tokens it spans.
trigger = 'MSN Canaveral龍喚士龍喚士 TAMADRAoried'
trigger_encodings = dexperts.tokenizer(trigger, return_tensors='pt')
trigger_len = trigger_encodings.input_ids.size(1)

In [14]:
# Peek at the first two token ids of the encoded trigger.
trigger_encodings['input_ids'][:, 0:2]

tensor([[30295, 49433]])

In [15]:
# Concatenate the trigger ids with themselves along dim=1 (the dimension whose
# size was stored in trigger_len), i.e. the trigger repeated twice.
input_ids = torch.cat([trigger_encodings['input_ids'], trigger_encodings['input_ids']], dim=1)

In [21]:
# Logits for the doubled trigger, computed with alpha=1.0.
logits = dexperts._get_logits(input_ids, alpha=1.0)
# Shape of everything after the first trigger_len tokens (the second copy).
input_ids[:, trigger_len:].shape

torch.Size([1, 6])

In [46]:
# Run the model on a sample sentence; the displayed result is a dict with
# 'logits', 'perplexity' and 'encoded_text' entries.
# NOTE(review): the name suggests an anti-expert configuration, but the
# dexperts setup that produced the recorded output is not identifiable from
# the (out-of-order) execution counts; confirm.
anti = dexperts("my name is eli and i am a computer scientist")
anti

{'logits': tensor([[[ -22.1336,  -21.3482,  -24.3148,  ...,  -29.5070,  -28.7274,
            -20.7605],
          [ -57.0321,  -52.6224,  -62.1353,  ...,  -67.6786,  -66.0160,
            -57.3270],
          [ -68.3985,  -65.3415,  -73.1280,  ...,  -75.3325,  -74.7722,
            -67.7033],
          ...,
          [ -98.5840,  -96.5486,  -98.4107,  ..., -103.4335, -103.3102,
            -96.4062],
          [ -94.5860,  -93.6396,  -98.2195,  ..., -106.0768, -103.9017,
            -96.0493],
          [-101.6815, -101.0369, -106.4345,  ..., -114.9965, -114.2090,
            -99.7977]]]),
 'perplexity': tensor(387.4512),
 'encoded_text': tensor([[ 1820,  1438,   318,  1288,    72,   290,  1312,   716,   257,  3644,
          11444]])}

In [71]:
# Same call on a second sentence, stored as `norm` so its logits can be
# compared with `anti` in the following cells.
norm = dexperts("Hello, the universe is a simulation")
norm

{'logits': tensor([[[ -27.4775,  -26.6563,  -30.6162,  ...,  -36.5271,  -35.0276,
            -27.9525],
          [-108.8305, -110.4722, -112.0560,  ..., -114.8944, -114.0423,
           -107.3213],
          [ -79.7368,  -79.6726,  -81.7972,  ...,  -82.0093,  -83.2302,
            -80.3461],
          ...,
          [-140.2619, -142.9218, -147.3124,  ..., -150.7536, -146.2208,
           -145.3681],
          [-112.0381, -110.9545, -115.2085,  ..., -121.1103, -117.2842,
           -114.5563],
          [ -96.1240,  -98.6944, -106.5968,  ..., -115.6023, -109.2068,
           -102.2468]]]),
 'perplexity': tensor(177.7758),
 'encoded_text': tensor([[15496,    11,   262,  6881,   318,   257, 18640]])}

In [72]:
# Last-position logits of `norm`, restricted to a hard-coded list of token ids
# (the list matches anti['encoded_text'] as displayed earlier in the notebook).
norm['logits'][:, -1, [ 1820,  1438,   318,  1288,    72,   290,  1312,   716,   257,  3644, 11444]]

tensor([[-107.9277, -105.5492, -100.6769, -107.2895, -103.8834,  -96.9118,
         -104.1037, -105.9296, -101.3946,  -99.1245, -103.7223]])

In [73]:
# Last-position logits of `anti`, restricted to the token ids of its own input
# sentence (row 0 of anti['encoded_text']).
anti['logits'][:, -1, anti['encoded_text'][0]]

tensor([[-107.9173, -103.9609, -103.5394, -104.1524, -105.9063,  -96.7503,
         -102.0671, -103.1220, -103.4365, -102.7148, -104.0150]])

In [74]:
# exp=False: the displayed value (5.9596) is the natural log of the
# 'perplexity' entry shown for `anti` (387.4512), so this appears to return the
# loss before exponentiation.
dexperts._get_perplexity(logits=anti['logits'], labels=anti['encoded_text'], exp=False)

tensor(5.9596)

In [75]:
# Display the raw logits tensor stored in `anti`.
anti['logits']

tensor([[[ -22.1336,  -21.3482,  -24.3148,  ...,  -29.5070,  -28.7274,
           -20.7605],
         [ -57.0321,  -52.6224,  -62.1353,  ...,  -67.6786,  -66.0160,
           -57.3270],
         [ -68.3985,  -65.3415,  -73.1280,  ...,  -75.3325,  -74.7722,
           -67.7033],
         ...,
         [ -98.5840,  -96.5486,  -98.4107,  ..., -103.4335, -103.3102,
           -96.4062],
         [ -94.5860,  -93.6396,  -98.2195,  ..., -106.0768, -103.9017,
           -96.0493],
         [-101.6815, -101.0369, -106.4345,  ..., -114.9965, -114.2090,
           -99.7977]]])

In [76]:
# exp=False: the displayed value (5.1805) is the natural log of the
# 'perplexity' entry shown for `norm` (177.7758).
dexperts._get_perplexity(logits=norm['logits'], labels=norm['encoded_text'], exp=False)

tensor(5.1805)

In [70]:
# Normalise the `norm` logits over the last dimension to view them as probabilities.
torch.softmax(norm['logits'], dim=-1)

tensor([[[1.0950e-03, 4.1428e-03, 9.6279e-05,  ..., 8.3704e-07,
          3.5180e-06, 8.3950e-04],
         [1.0932e-03, 5.5322e-02, 3.4265e-06,  ..., 9.1275e-08,
          3.2872e-07, 5.2490e-05],
         [2.0593e-04, 2.4080e-03, 2.6733e-06,  ..., 5.2651e-08,
          1.5766e-07, 1.6817e-05],
         ...,
         [6.0334e-07, 6.9739e-06, 3.4069e-07,  ..., 1.2484e-09,
          6.0558e-09, 3.0207e-07],
         [5.1486e-05, 1.2453e-04, 1.1337e-06,  ..., 2.0511e-10,
          2.2713e-09, 4.2527e-06],
         [6.2742e-04, 4.3286e-03, 2.4531e-05,  ..., 1.4467e-09,
          6.4025e-09, 2.2975e-04]]])

In [12]:
# Same sentence as the `anti` cell, result displayed but not stored.
# NOTE(review): the recorded perplexity differs from the earlier output for
# this sentence and the execution counts are out of order, so the dexperts
# configuration behind each recorded output should be confirmed.
dexperts("my name is eli and i am a computer scientist")

{'logits': tensor([[[ -20.3828,  -17.8637,  -21.5539,  ...,  -27.6629,  -26.7373,
            -21.8847],
          [ -65.0968,  -58.2789,  -70.3813,  ...,  -75.1820,  -71.5832,
            -71.3944],
          [ -75.8918,  -70.4688,  -78.7735,  ...,  -81.7109,  -81.5731,
            -81.1416],
          ...,
          [ -82.7011,  -78.2910,  -80.5468,  ...,  -86.3587,  -86.6452,
            -87.8369],
          [ -88.1198,  -84.5362,  -90.6371,  ...,  -98.0018,  -97.1511,
            -94.9195],
          [ -98.0927,  -94.1556, -101.2576,  ..., -110.3220, -109.7325,
           -104.2561]]]),
 'perplexity': tensor(1128.6274),
 'encoded_text': tensor([[ 1820,  1438,   318,  1288,    72,   290,  1312,   716,   257,  3644,
          11444]])}

In [54]:
# Load the test split of the 'wikitext-2-v1' configuration of WikiText.
data = load_dataset('wikitext', 'wikitext-2-v1', split='test')
# Only the first 10 text rows are joined (blank line between rows) and
# tokenized here, so `encodings` covers a small excerpt, not the full split.
encodings = dexperts.tokenizer('\n\n'.join(data['text'][:10]), return_tensors='pt')

Found cached dataset wikitext (/Users/eliott/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


In [55]:
# Number of token ids along dim 1 of the encoded evaluation text.
encodings.input_ids.size(1)

406

In [61]:
# Evaluation settings; a non-positive stride / max_length means "use the default".
alpha = 2.0
max_length_pattern = 32
stride = -1
max_length = -1

# Default window is 1024 tokens; in either case reserve max_length_pattern tokens.
if max_length <= 0:
    max_length = 1024
max_length -= max_length_pattern

# A non-positive stride falls back to the (reduced) window length.
stride = stride if stride > 0 else max_length

```
    def compute_perplexity(self, prompt: str, alpha: float = None):
        """Tokenize ``prompt``, compute its logits and return its perplexity.

        Args:
            prompt: Text to evaluate.
            alpha: Value forwarded to ``self._get_logits``. When ``None``,
                ``self.alpha`` is used; any other value (including ``0.0``)
                is passed through unchanged.

        Returns:
            The result of ``self._get_perplexity(logits, input_ids)``.
        """
        encodings_dict = self.tokenizer(
            prompt, return_tensors="pt", padding=True, return_attention_mask=True
        ).to(self.device)
        encoded_text = encodings_dict["input_ids"]
        # NOTE(review): the attention mask is requested from the tokenizer but
        # is not forwarded to _get_logits / _get_perplexity; the previously
        # unused local holding it has been dropped. Confirm whether padded
        # inputs need the mask to be passed on.
        if alpha is None:
            alpha = self.alpha
        logits = self._get_logits(encoded_text, alpha=alpha)
        return self._get_perplexity(logits, encoded_text)

```

In [64]:
# Effective window length after subtracting max_length_pattern.
max_length

992

In [62]:
# Strided perplexity evaluation of `dexperts` (with the configured alpha) over
# the token ids in `encodings`. Each window contributes loss * trg_len, and the
# final perplexity is exp(sum of contributions / number of tokens covered).
# NOTE(review): only the "regular" quantities are computed; lls_debiased /
# ppl_debiased are initialised but never filled, so the final print reports
# None for the debiased side.
lls_debiased, lls_regular = [], []
ppl_debiased, ppl_regular = None, None

seq_len = encodings.input_ids.size(1)
for i in tqdm(range(0, seq_len, stride)):
    # Window [begin_loc, end_loc): at most max_length tokens ending at i + stride,
    # clipped to the bounds of the sequence.
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, seq_len)
    trg_len = end_loc - i  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    # Only the last trg_len positions are scored; earlier ones are set to -100,
    # presumably ignored by the loss inside _get_perplexity (-100 is
    # CrossEntropyLoss's default ignore_index); confirm.
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        logits = dexperts._get_logits(input_ids, alpha=alpha)
        # assumes exp=False yields the mean per-token loss; TODO confirm
        loss_regular = dexperts._get_perplexity(logits=logits, labels=target_ids, exp=False)
        log_likelihood_regular = loss_regular * trg_len

    lls_regular.append(log_likelihood_regular)

# Computed once after the loop instead of re-stacking the whole list on every
# iteration; end_loc now holds the end of the last window. Guarded so an empty
# input leaves ppl_regular as None rather than failing on an empty stack.
if lls_regular:
    ppl_regular = torch.exp(torch.stack(lls_regular).sum() / end_loc)
print(f'Final perplexity: {ppl_debiased} (debiased) vs {ppl_regular} (regular)')

100%|██████████| 1/1 [00:02<00:00,  2.47s/it]

Final perplexity: None (debiased) vs 26.427648544311523 (regular)





In [49]:
# Check the stride used by the evaluation loop.
stride

992

In [50]:
# Check the window length used by the evaluation loop.
max_length

992

In [45]:
# Recompute the perplexity by hand from the accumulated per-window values:
# exp(sum(lls_regular) / end_loc), using the globals left by the loop cell.
torch.exp(torch.stack(lls_regular).sum() / end_loc)

tensor(26.4223)

In [5]:
# Rebuild `dexperts`, this time with both the anti-expert and expert
# checkpoints supplied (alpha is left at the DExperts default).
# NOTE: this rebinds the `dexperts` name used by every other cell.
dexperts = DExperts(
    base_model='gpt2',
    antiexpert_model='eliolio/gpt2-finetuned-redditbias',
    expert_model='eliolio/gpt2-finetuned-reddit-antibias',
    tokenizer='gpt2',
)

In [6]:
# n_positions from the base model's config (displayed as 1024, the same value
# used as the default window length in the evaluation settings).
dexperts.base_model.config.n_positions

1024

In [9]:
# Strided perplexity evaluation of the current `dexperts` (with the configured
# alpha) over the token ids in `encodings`. Each window contributes
# loss * trg_len, and the final perplexity is
# exp(sum of contributions / number of tokens covered).
# NOTE(review): only the "regular" quantities are computed; lls_debiased /
# ppl_debiased are initialised but never filled, so the final print reports
# None for the debiased side.
lls_debiased, lls_regular = [], []
ppl_debiased, ppl_regular = None, None

seq_len = encodings.input_ids.size(1)
for i in tqdm(range(0, seq_len, stride)):
    # Window [begin_loc, end_loc): at most max_length tokens ending at i + stride,
    # clipped to the bounds of the sequence.
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, seq_len)
    trg_len = end_loc - i  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    # Only the last trg_len positions are scored; earlier ones are set to -100,
    # presumably ignored by the loss inside _get_perplexity (-100 is
    # CrossEntropyLoss's default ignore_index); confirm.
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        logits = dexperts._get_logits(input_ids, alpha=alpha)
        # assumes exp=False yields the mean per-token loss; TODO confirm
        loss_regular = dexperts._get_perplexity(logits=logits, labels=target_ids, exp=False)
        log_likelihood_regular = loss_regular * trg_len

    lls_regular.append(log_likelihood_regular)

# Computed once after the loop instead of re-stacking the whole list on every
# iteration (the recorded run had 300 windows); end_loc now holds the end of
# the last window. Guarded so an empty input leaves ppl_regular as None rather
# than failing on an empty stack.
if lls_regular:
    ppl_regular = torch.exp(torch.stack(lls_regular).sum() / end_loc)
print(f'Final perplexity: {ppl_debiased} (debiased) vs {ppl_regular} (regular)')

100%|██████████| 300/300 [08:57<00:00,  1.79s/it]

Final perplexity: None (debiased) vs 31.38951301574707 (regular)





In [44]:
# Recompute the perplexity by hand from lls_regular and end_loc.
# NOTE(review): the recorded output (26.4223) carries execution count 44 and
# does not match the 31.39 printed by the loop cell just above; it appears to
# come from an earlier kernel state. Confirm.
torch.exp(torch.stack(lls_regular).sum() / end_loc)

tensor(26.4223)

In [41]:
# Scratch arithmetic: half of 31.38.
# NOTE(review): presumably the perplexity printed by the loop above; confirm intent.
31.38/2

15.69

In [33]:
# NOTE(review): this import sits mid-notebook; consider moving it to the import
# cell at the top so a fresh Restart-and-Run-All surfaces the dependency early.
from transformers import BloomForCausalLM
# Load the 'bigscience/bloom-560m' causal-LM checkpoint.
model = BloomForCausalLM.from_pretrained('bigscience/bloom-560m')

In [39]:
# Attach an ad-hoc `seqlen` attribute to the model object.
# NOTE(review): nothing in this notebook reads it; confirm which consumer expects it.
model.seqlen = 2048

In [40]:
# Confirm the attribute was set.
model.seqlen

2048

In [4]:
# Same sample sentence again.
# NOTE(review): the recorded output contains a token id (140541) beyond GPT-2's
# 50257-entry vocabulary, so it was presumably produced with a different
# model/tokenizer than the `dexperts` built in this notebook; confirm.
dexperts("my name is eli and i am a computer scientist")

{'logits': tensor([[[349.2594, 353.2092, 360.3812,  ..., 209.4018, 209.4017, 209.3962],
          [365.7440, 367.9818, 380.9273,  ..., 211.6974, 211.6970, 211.6921],
          [402.1103, 402.5622, 416.0273,  ..., 211.3478, 211.3476, 211.3427],
          ...,
          [404.6316, 406.3918, 423.9075,  ..., 208.6725, 208.6725, 208.6665],
          [397.6118, 396.9107, 419.1426,  ..., 205.5356, 205.5356, 205.5297],
          [405.3145, 406.4235, 428.2386,  ..., 207.0648, 207.0648, 207.0592]]]),
 'perplexity': tensor(24.7266),
 'encoded_text': tensor([[  5644,   4040,    632,    466,     76,    530,    707,    912,    267,
           26371, 140541]])}

In [7]:
# Same sample sentence once more; the recorded output carries execution count 7,
# i.e. it was produced after the expert/anti-expert `dexperts` was built (In [5]).
dexperts("my name is eli and i am a computer scientist")

{'logits': tensor([[[ -24.1817,  -22.8510,  -26.6129,  ...,  -31.3580,  -29.9223,
            -24.4473],
          [ -76.6425,  -72.7184,  -82.4079,  ...,  -86.0333,  -84.7519,
            -79.6788],
          [ -85.4887,  -83.0297,  -89.8329,  ...,  -93.7603,  -92.6636,
            -87.9938],
          ...,
          [-105.5130, -103.0655, -106.0845,  ..., -111.6936, -110.1144,
           -106.2048],
          [-106.0199, -105.1366, -109.8357,  ..., -118.4531, -116.0486,
           -108.5136],
          [-103.2196, -101.2882, -106.4613,  ..., -116.1997, -114.7123,
           -104.2242]]]),
 'perplexity': tensor(134.4997),
 'encoded_text': tensor([[ 1820,  1438,   318,  1288,    72,   290,  1312,   716,   257,  3644,
          11444]])}

In [None]:
# Launch the bias-evaluation script as a background job, logging stdout to
# dexperts_gpt2_med_alpha1.log.
# The cell used to contain the bare shell line, which is a SyntaxError in a
# Python kernel, and the notebook `!` escape rejects commands ending in `&`.
# The unchanged command line is therefore handed to the system shell, which
# backgrounds the job and returns immediately.
import subprocess

launch_cmd = (
    'nohup python -u evaluateBias_dexperts.py --prompt_dir ../prompts/ '
    '--base_model gpt2-medium --out_dir results/dexperts_gpt2_med_alpha1 '
    '--alpha 1.0 > dexperts_gpt2_med_alpha1.log &'
)
# check=True only verifies that the shell itself started the job successfully.
launch_result = subprocess.run(launch_cmd, shell=True, check=True)