In [1]:
import argparse
import numpy as np
import pandas as pd
import pickle as pk
from collections import defaultdict
import torch
# For machine learning tools and evaluation
from sklearn.metrics import accuracy_score
# Transformer library
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import Trainer, TrainingArguments
import evaluate
from datasets import load_dataset, Dataset

data_dir = "./data/"
device = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
subreddit = "Judaism"
usecols=['year-month', 'timestamp', 'text', 'speaker']
comments_df = pk.load(open(data_dir + f"{subreddit}-comments.pk", "rb"))
comments_df = comments_df[usecols]

In [3]:
model_month = "2016-01"
checkpoint_path = f"./models/distilgpt2_{subreddit}_{model_month}/best"

In [192]:
from preprocess import preprocess

In [4]:
predict_month = "2017-04"
input_texts = [t for t in comments_df[comments_df['year-month'] == predict_month]['text'] if t]
len(input_texts)

8669

In [5]:
selected_random_utts = np.random.choice(input_texts, size=100, replace=False)

In [208]:
selected_random_utts = [preprocess(u) for u in selected_random_utts]

## use evaluate library to calculate sentence perplexity

In [221]:
def calculate_huggingface_ppl(sentences):
    perplexity = evaluate.load("perplexity", module_type="metric")
    return perplexity.compute(
        model_id = checkpoint_path,
        add_start_token = False, # default
        predictions = sentences,
        max_length=1024
    )

In [222]:
huggingface_ppl = calculate_huggingface_ppl(selected_random_utts)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.34it/s]


In [223]:
huggingface_ppl

{'perplexities': [119.88584899902344,
  236.0976104736328,
  100.75423431396484,
  716.7677001953125,
  413.7164306640625,
  410.6705322265625,
  248.873291015625,
  163.5489959716797,
  55.92713928222656,
  160.4458770751953,
  153.39732360839844,
  104.64350891113281,
  92.46155548095703,
  96.78690338134766,
  387.2434387207031,
  190.98699951171875,
  470.62890625,
  185.51380920410156,
  147.025390625,
  653.4151000976562,
  265.0249328613281,
  139.4586944580078,
  198.70712280273438,
  94.5572280883789,
  800.6076049804688,
  455.3201599121094,
  19.6213321685791,
  91.93151092529297,
  84.49825286865234,
  427.1061706542969,
  742.717529296875,
  223.286865234375,
  53.72596740722656,
  58.15818405151367,
  100.2213134765625,
  235.43734741210938,
  141.0738525390625,
  48.792850494384766,
  35.06508255004883,
  257.86676025390625,
  291.4250183105469,
  394.947021484375,
  107.3738021850586,
  63.94613265991211,
  124.24491119384766,
  110.55298614501953,
  679.3197631835938,


## modified huggingface implementation

In [230]:
import datasets
import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM, AutoTokenizer

import evaluate
from evaluate import logging

_DESCRIPTION = ""
_CITATION = ""
_KWARGS_DESCRIPTION = ""
class mod_Perplexity(evaluate.Measurement):
    def _info(self):
        return evaluate.MeasurementInfo(
            module_type="measurement",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "data": datasets.Value("string"),
                }
            ),
            reference_urls=["https://huggingface.co/docs/transformers/perplexity"],
        )

    def _compute(
        self, data, model_id, batch_size: int = 16, add_start_token: bool = True, device=None, max_length=None
    ):

        if device is not None:
            assert device in ["gpu", "cpu", "cuda"], "device should be either gpu or cpu."
            if device == "gpu":
                device = "cuda"
        else:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        model = AutoModelForCausalLM.from_pretrained(model_id)
        model = model.to(device)

        tokenizer = AutoTokenizer.from_pretrained(model_id)

        # if batch_size > 1 (which generally leads to padding being required), and
        # if there is not an already assigned pad_token, assign an existing
        # special token to also be the padding token
        if tokenizer.pad_token is None and batch_size > 1:
            existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
            # check that the model already has at least one special token defined
            assert (
                len(existing_special_tokens) > 0
            ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
            # assign one of the special tokens to also be the pad token
            tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

        if add_start_token and max_length:
            # leave room for <BOS> token to be added:
            assert (
                tokenizer.bos_token is not None
            ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
            max_tokenized_len = max_length - 1
        else:
            max_tokenized_len = max_length

        encodings = tokenizer(
            data,
            add_special_tokens=False,
            padding=True,
            truncation=True if max_tokenized_len else False,
            max_length=max_tokenized_len,
            return_tensors="pt",
            return_attention_mask=True,
        ).to(device)

        encoded_texts = encodings["input_ids"]
        attn_masks = encodings["attention_mask"]

        # check that each input is long enough:
        if add_start_token:
            assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
        else:
            assert torch.all(
                torch.ge(attn_masks.sum(1), 2)
            ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."

        sentence_ppls = []
        token_ppls = []
        loss_fct = CrossEntropyLoss(reduction="none")

        for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
            end_index = min(start_index + batch_size, len(encoded_texts))
            encoded_batch = encoded_texts[start_index:end_index]
            attn_mask = attn_masks[start_index:end_index]

            if add_start_token:
                bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device)
                encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
                attn_mask = torch.cat(
                    [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1
                )

            labels = encoded_batch

            with torch.no_grad():
                out_logits = model(encoded_batch, attention_mask=attn_mask).logits
            shift_logits = out_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
            
            ## the following parts are different from the original code

            token_CEloss = loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch
    
            ppl_by_token = torch.exp(torch.div(token_CEloss.T, shift_attention_mask_batch.sum(1))).T

            batch_token_ppls = [tuple(zip(encoded_batch.cpu().numpy()[i], ppl_by_token.cpu().numpy()[i])) for i in range(encoded_batch.shape[0])]

            perplexity_batch = torch.prod(ppl_by_token, 1)

            sentence_ppls += perplexity_batch.tolist()
            token_ppls += batch_token_ppls
        return {"sentence_ppls": sentence_ppls, "token_ppls": token_ppls}

In [231]:
data = list(selected_random_utts)
model_id = checkpoint_path
max_length = 1024
add_start_token = False
batch_size = 16

In [232]:
mod_ppl = mod_Perplexity()
stash_ppl = mod_ppl._compute(
    data = data,
    model_id = model_id,
    add_start_token = False, # default,
    max_length = 1024
)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.15it/s]


In [237]:
idx = np.argmax(stash_ppl['sentence_ppls'])
print(idx)
print(stash_ppl['sentence_ppls'][idx])

98
2539.572265625


In [239]:
stash_ppl['token_ppls'][idx]

((10378, 4.1920857),
 (2412, 2.1464508),
 (319, 4.0804687),
 (3025, 9.030929),
 (16511, 1.462644),
 (339, 5.2363467),
 (3951, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0),
 (50256, 1.0)

In [240]:
selected_random_utts[idx]

'depends on whose pockets he lines'

In [224]:
model = AutoModelForCausalLM.from_pretrained(model_id)
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(model_id)

# if batch_size > 1 (which generally leads to padding being required), and
# if there is not an already assigned pad_token, assign an existing
# special token to also be the padding token
if tokenizer.pad_token is None and batch_size > 1:
    existing_special_tokens = list(tokenizer.special_tokens_map_extended.values())
    # check that the model already has at least one special token defined
    assert (
        len(existing_special_tokens) > 0
    ), "If batch_size > 1, model must have at least one special token to use for padding. Please use a different model or set batch_size=1."
    # assign one of the special tokens to also be the pad token
    tokenizer.add_special_tokens({"pad_token": existing_special_tokens[0]})

if add_start_token and max_length:
    # leave room for <BOS> token to be added:
    assert (
        tokenizer.bos_token is not None
    ), "Input model must already have a BOS token if using add_start_token=True. Please use a different model, or set add_start_token=False"
    max_tokenized_len = max_length - 1
else:
    max_tokenized_len = max_length

encodings = tokenizer(
    data,
    add_special_tokens=False,
    padding=True,
    truncation=True if max_tokenized_len else False,
    max_length=max_tokenized_len,
    return_tensors="pt",
    return_attention_mask=True,
).to(device)

encoded_texts = encodings["input_ids"]
attn_masks = encodings["attention_mask"]

# check that each input is long enough:
if add_start_token:
    assert torch.all(torch.ge(attn_masks.sum(1), 1)), "Each input text must be at least one token long."
else:
    assert torch.all(
        torch.ge(attn_masks.sum(1), 2)
    ), "When add_start_token=False, each input text must be at least two tokens long. Run with add_start_token=True if inputting strings of only one token, and remove all empty input strings."

ppls = []
loss_fct = CrossEntropyLoss(reduction="none")

for start_index in logging.tqdm(range(0, len(encoded_texts), batch_size)):
    end_index = min(start_index + batch_size, len(encoded_texts))
    encoded_batch = encoded_texts[start_index:end_index]
    attn_mask = attn_masks[start_index:end_index]

    if add_start_token:
        bos_tokens_tensor = torch.tensor([[tokenizer.bos_token_id]] * encoded_batch.size(dim=0)).to(device)
        encoded_batch = torch.cat([bos_tokens_tensor, encoded_batch], dim=1)
        attn_mask = torch.cat(
            [torch.ones(bos_tokens_tensor.size(), dtype=torch.int64).to(device), attn_mask], dim=1
        )

    labels = encoded_batch

    with torch.no_grad():
        out_logits = model(encoded_batch, attention_mask=attn_mask).logits
    shift_logits = out_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    shift_attention_mask_batch = attn_mask[..., 1:].contiguous()
    
    perplexity_batch = torch.exp(
                (loss_fct(shift_logits.transpose(1, 2), shift_labels) * shift_attention_mask_batch).sum(1)
                / shift_attention_mask_batch.sum(1)
            )

    ppls += perplexity_batch.tolist()

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 10.29it/s]


In [226]:
len(ppls)

100

In [37]:
# import calc_token_contribution
# import importlib
# importlib.reload(calc_token_contribution)

In [46]:
encoded_texts.shape

torch.Size([100, 422])

In [19]:
len(encodings)

100

In [None]:
batch_size = 16

In [25]:
encoded_texts = encodings
batch_size = 16

In [None]:
max_length = 512
stride = 1
seq_len = encodings.size(1)

nlls = defaultdict(list)
prev_end_loc = 0

for begin_loc in range(0, seq_len, stride):
    print("begin_loc: ", begin_loc)
    end_loc = min(begin_loc + max_length, seq_len)
    print("end_loc: ", end_loc)
    trg_len = end_loc - prev_end_loc
    print(trg_len)
    input_ids = encodings[:, begin_loc:end_loc].to(device_name)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100
    target_token = target_ids[:, -stride]
    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
        token_nll = outputs.loss * trg_len
        
        target_word = tokenizer.decode(target_token)
        nlls[target_word].append(token_nll.cpu().item())
        
    prev_end_loc = end_loc
    if end_loc == seq_len:
        break
        

In [29]:

for key, value in nlls.items():
    print(key)
    ppl = torch.exp(torch.stack(value)).mean()
    print(ppl.cpu().numpy())


,
inf
 after
2677.2778
 my
6387.746
 roommate
675372.3
 came
981.5949
 home
12283.591
 :-)
13612.111
 :
12652.543
 https
1632.5735
://
1.000114
i
2282.7717
.
63.961105
imgur
107.21448
com
35.206013
/
767.95355
Due
56668544.0
uy
23994.855
gu
10058.086
jpg
260.2855


46.925888
Tu
139882.14
 ded
2828744.0
ic
1188.3696
aci
36138.34
ón
1.0501997
 es
78149.99
 real
3874.089
ment
324.4796
e
12871.589
 admirable
98058370.0
Des
8996.796
af
3523.3965
ortun
13955.19
ad
266449.25
ament
10577.184
 la
7698.388
 com
21651.275
un
12717.073
idad
1.7455468
 or
668.5202
t
502226.75
odox
29027.328
a
947.4002
 en
39403.74
 Pan
310433.47
am
14040.1875
á
195376.05
 no
510.27524
 m
3818.892
 "
414.83902
fan
636506.5
"
299.68466
 de
64793.91
 los
4.3067017
 convers
9238.641
os
2498.2039
 No
9446.969
 con
10489.393
oz
26433.963
co
728.42664
 n
5722.8228
ie
11234.358
 que
916563.94
 se
11798.665
 h
14394.142
aya
7188.74
 convert
4686.5693
ido
684.82513
 Ad
25356.38
em
4698.772
ás
166767.98
 las
143.08153
 sin
11