In [3]:
import argparse
import torch
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
import json

from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from datasets import load_dataset
from dexperts import DExperts

In [40]:
stride = -1
alpha = 1
max_length = -1
max_length_pattern = 32

expert_model= "gpt2" #"eliolio/gpt2-finetuned-gender-reddit-antibias"
antiexpert_model ="eliolio/gpt2-finetuned-gender-redditbias"
base_model = "gpt2"

In [41]:
 # instantiate dexperts
dexperts = DExperts(
    base_model=base_model,
    expert_model=expert_model,
    antiexpert_model=antiexpert_model,
    tokenizer=base_model,
    #mode="bayes",
    alpha=alpha,
)
device = dexperts.device


In [42]:
# set up parameters
max_length = (max_length if max_length > 0 else dexperts.base_model.config.n_positions) - max_length_pattern
if stride <= 0:
    stride = max_length

# load dataset and tokenize
data = load_dataset('wikitext', 'wikitext-2-v1', split='test')
encodings = dexperts.tokenizer('\n\n'.join(data['text']), return_tensors='pt')


Found cached dataset wikitext (/home/rloha/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
Token indices sequence length is longer than the specified maximum sequence length for this model (297300 > 1024). Running this sequence through the model will result in indexing errors


In [43]:
# compute perplexity
lls_regular = []
ppl_regular = None

for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
    # if i> 1:
    #     break
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, encodings.input_ids.size(1))
    trg_len = end_loc - i  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        # loss_regular = compute_loss(input_ids, labels=target_ids)
        logits = dexperts._get_logits(input_ids, alpha=alpha)
        # print(logits.shape)
        loss_regular = dexperts._get_perplexity(logits=logits, labels=target_ids, exp=False)
        # print(loss_regular)
        log_likelihood_regular = loss_regular * trg_len

    lls_regular.append(log_likelihood_regular)

    ppl_regular = torch.exp(torch.stack(lls_regular).sum() / end_loc)
    # print(f'Perplexity after {i} tokens: {ppl_debiased} (debiased) vs {ppl_regular} (regular)')
print(f'Final perplexity: {ppl_regular}')


100%|██████████| 300/300 [00:33<00:00,  9.04it/s]

Final perplexity: 31.258176803588867





In [12]:
print(lls_regular)

[tensor(2513.4614, device='cuda:0'), tensor(3316.4771, device='cuda:0'), tensor(3302.8726, device='cuda:0'), tensor(3257.1699, device='cuda:0'), tensor(3437.2664, device='cuda:0'), tensor(3237.8562, device='cuda:0'), tensor(3044.6199, device='cuda:0'), tensor(3032.8552, device='cuda:0'), tensor(2996.7419, device='cuda:0'), tensor(3302.1304, device='cuda:0'), tensor(3100.5320, device='cuda:0'), tensor(2925.3425, device='cuda:0'), tensor(2634.7397, device='cuda:0'), tensor(2778.4956, device='cuda:0'), tensor(3099.8408, device='cuda:0'), tensor(3193.0193, device='cuda:0'), tensor(3283.3979, device='cuda:0'), tensor(3360.0923, device='cuda:0'), tensor(3401.8879, device='cuda:0'), tensor(3086.0149, device='cuda:0'), tensor(2946.4043, device='cuda:0'), tensor(3085.7773, device='cuda:0'), tensor(3135.2273, device='cuda:0'), tensor(3419.8213, device='cuda:0'), tensor(2921.2490, device='cuda:0'), tensor(3028.3242, device='cuda:0'), tensor(3128.8525, device='cuda:0'), tensor(3338.6680, device='c

In [20]:
print(lls_regular)

[tensor(2731.6917, device='cuda:0'), tensor(3501.4373, device='cuda:0'), tensor(3532.4131, device='cuda:0'), tensor(3327.4648, device='cuda:0'), tensor(3649.6926, device='cuda:0'), tensor(3498.0073, device='cuda:0'), tensor(3157.0417, device='cuda:0'), tensor(3252.6921, device='cuda:0'), tensor(3258.8687, device='cuda:0'), tensor(3511.5161, device='cuda:0'), tensor(3319.7437, device='cuda:0'), tensor(3102.2749, device='cuda:0'), tensor(2891.4417, device='cuda:0'), tensor(3072.3645, device='cuda:0'), tensor(3420.9294, device='cuda:0'), tensor(3425.0815, device='cuda:0'), tensor(3513.1636, device='cuda:0'), tensor(3479.7429, device='cuda:0'), tensor(3504.3186, device='cuda:0'), tensor(3300.9475, device='cuda:0'), tensor(3145.4675, device='cuda:0'), tensor(3455.3909, device='cuda:0'), tensor(3378.4941, device='cuda:0'), tensor(3627.9788, device='cuda:0'), tensor(3000.5002, device='cuda:0'), tensor(3178.2737, device='cuda:0'), tensor(3319.1846, device='cuda:0'), tensor(3442.2812, device='c

In [11]:
print(lls_regular)

[tensor(2731.6917, device='cuda:0'), tensor(3501.4373, device='cuda:0'), tensor(3532.4131, device='cuda:0'), tensor(3327.4648, device='cuda:0'), tensor(3649.6926, device='cuda:0'), tensor(3498.0073, device='cuda:0'), tensor(3157.0417, device='cuda:0'), tensor(3252.6921, device='cuda:0'), tensor(3258.8687, device='cuda:0'), tensor(3511.5161, device='cuda:0'), tensor(3319.7437, device='cuda:0'), tensor(3102.2749, device='cuda:0'), tensor(2891.4417, device='cuda:0'), tensor(3072.3645, device='cuda:0'), tensor(3420.9294, device='cuda:0'), tensor(3425.0815, device='cuda:0'), tensor(3513.1636, device='cuda:0'), tensor(3479.7429, device='cuda:0'), tensor(3504.3186, device='cuda:0'), tensor(3300.9475, device='cuda:0'), tensor(3145.4675, device='cuda:0'), tensor(3455.3909, device='cuda:0'), tensor(3378.4941, device='cuda:0'), tensor(3627.9788, device='cuda:0'), tensor(3000.5002, device='cuda:0'), tensor(3178.2737, device='cuda:0'), tensor(3319.1846, device='cuda:0'), tensor(3442.2812, device='c

In [4]:
# load dataset and tokenize
data = load_dataset('wikitext', 'wikitext-2-v1', split='test')
encodings = dexperts.tokenizer('\n\n'.join(data['text']), return_tensors='pt')

# compute perplexity
lls_regular = []
ppl_regular = None

for i in tqdm(range(0, encodings.input_ids.size(1), stride)):
    # if i> 1:
    #     break
    begin_loc = max(i + stride - max_length, 0)
    end_loc = min(i + stride, encodings.input_ids.size(1))
    trg_len = end_loc - i  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100

    with torch.no_grad():
        # loss_regular = compute_loss(input_ids, labels=target_ids)
        logits = dexperts._get_logits(input_ids, alpha=alpha)
        # print(logits.shape)
        loss_regular = dexperts._get_perplexity(logits=logits, labels=target_ids, exp=False)
        # print(loss_regular)
        log_likelihood_regular = loss_regular * trg_len

    lls_regular.append(log_likelihood_regular)

    ppl_regular = torch.exp(torch.stack(lls_regular).sum() / end_loc)
    # print(f'Perplexity after {i} tokens: {ppl_debiased} (debiased) vs {ppl_regular} (regular)')
print(f'Final perplexity: {ppl_regular}')


Found cached dataset wikitext (/home/rloha/.cache/huggingface/datasets/wikitext/wikitext-2-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
Token indices sequence length is longer than the specified maximum sequence length for this model (297300 > 1024). Running this sequence through the model will result in indexing errors
0it [00:00, ?it/s]

Final perplexity: None





In [None]:

    if args.save_json:
        with open(args.save_json, 'wt') as f:
            args.__dict__['perplexity'] = ppl_regular.item()
            json.dump(args.__dict__, f, indent=2)


In [None]:
if args.save_json:
    with open(args.save_json, 'wt') as f:
        args.__dict__['perplexity'] = ppl_regular.item()
        json.dump(args.__dict__, f, indent=2)
