In [1]:
import argparse
import pickle as pk
from collections import defaultdict
import torch
# For machine learning tools and evaluation
from sklearn.metrics import accuracy_score
# Transformer library
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import Trainer, TrainingArguments
import evaluate
from datasets import load_dataset, Dataset

data_dir = "./data/"
device_name = "cuda" if torch.cuda.is_available() else "cpu"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
subreddit = "Judaism"
usecols=['year-month', 'timestamp', 'text', 'speaker']
comments_df = pk.load(open(data_dir + f"{subreddit}-comments.pk", "rb"))
comments_df = comments_df[usecols]

In [4]:
model_month = "2016-07"
checkpoint_path = f"./models/distilgpt2_{subreddit}_{model_month}/best"
model = AutoModelForCausalLM.from_pretrained(checkpoint_path).to(device_name)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)

In [5]:
predict_month = "2017-01"
input_texts = [t for t in comments_df[comments_df['year-month'] == predict_month]['text'] if t]
encodings = tokenizer.encode("\n\n".join(input_texts), return_tensors="pt")

Token indices sequence length is longer than the specified maximum sequence length for this model (1008305 > 1024). Running this sequence through the model will result in indexing errors


In [8]:
encodings.shape

torch.Size([1, 1008305])

In [37]:
import calc_token_contribution

In [62]:
import importlib
importlib.reload(calc_token_contribution)

<module 'calc_token_contribution' from '/share/luxlab/andrea/religion-subreddits/calc_token_contribution.py'>

In [None]:
max_length = 512
stride = 1
seq_len = encodings.size(1)

nlls = defaultdict(list)
prev_end_loc = 0

for begin_loc in range(0, seq_len, stride):
    print("begin_loc: ", begin_loc)
    end_loc = min(begin_loc + max_length, seq_len)
    print("end_loc: ", end_loc)
    trg_len = end_loc - prev_end_loc
    print(trg_len)
    input_ids = encodings[:, begin_loc:end_loc].to(device_name)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100
    target_token = target_ids[:, -stride]
    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
        token_nll = outputs.loss * trg_len
        
        target_word = tokenizer.decode(target_token)
        nlls[target_word].append(token_nll.cpu().item())
        
    prev_end_loc = end_loc
    if end_loc == seq_len:
        break
        

In [29]:

for key, value in nlls.items():
    print(key)
    ppl = torch.exp(torch.stack(value)).mean()
    print(ppl.cpu().numpy())


,
inf
 after
2677.2778
 my
6387.746
 roommate
675372.3
 came
981.5949
 home
12283.591
 :-)
13612.111
 :
12652.543
 https
1632.5735
://
1.000114
i
2282.7717
.
63.961105
imgur
107.21448
com
35.206013
/
767.95355
Due
56668544.0
uy
23994.855
gu
10058.086
jpg
260.2855


46.925888
Tu
139882.14
 ded
2828744.0
ic
1188.3696
aci
36138.34
ón
1.0501997
 es
78149.99
 real
3874.089
ment
324.4796
e
12871.589
 admirable
98058370.0
Des
8996.796
af
3523.3965
ortun
13955.19
ad
266449.25
ament
10577.184
 la
7698.388
 com
21651.275
un
12717.073
idad
1.7455468
 or
668.5202
t
502226.75
odox
29027.328
a
947.4002
 en
39403.74
 Pan
310433.47
am
14040.1875
á
195376.05
 no
510.27524
 m
3818.892
 "
414.83902
fan
636506.5
"
299.68466
 de
64793.91
 los
4.3067017
 convers
9238.641
os
2498.2039
 No
9446.969
 con
10489.393
oz
26433.963
co
728.42664
 n
5722.8228
ie
11234.358
 que
916563.94
 se
11798.665
 h
14394.142
aya
7188.74
 convert
4686.5693
ido
684.82513
 Ad
25356.38
em
4698.772
ás
166767.98
 las
143.08153
 sin
11