In [1]:
from transformers import GPT2Tokenizer, GPT2Model, GPT2LMHeadModel


import torch
# from torch import nn
import numpy as np
perm = np.random.permutation

import pickle
from tqdm import tqdm


from collections import Counter
import matplotlib.pyplot as plt

import logging
logging.getLogger("transformers.tokenization_utils").setLevel(logging.ERROR)

I0228 17:49:14.898231 140338297841472 file_utils.py:38] PyTorch version 1.4.0 available.


## Load E-mail Data

In [2]:
with open("../../../w3c-emails/emails.pkl", "rb") as handle:
    emails = pickle.load(handle)

# Steps

 0. select authors (so that evaulation sets can be held out) and establish their frequencies <br>
   -> $P(X)$, i.e. the probability that a person appears as an author of any e-mail
 1. pre-train GPT-2 on W3C e-mail corpus <br> 
   -> W3CGPT-2 <br>
   -> approximates $P(email)$
 2. train W3CGPT-2 on e-mails by selected authors <br>
   -> GPT-2$_X$ for each person $X$ <br> 
   -> i.e. $P(email|X)$
 3. use GPT-2$_X$ to classify unseen (both to W3CGPT-2 and to GPT-2$_X$) e-mails <br>
   -> $P(X|email) = P(X)P(email|X)/P(email)$

## 0. Select Authors (and their e-mails)

 - author-frequency distribution Zipfian, so use log-linear ranks to select authors <br> 
   (i.e. author rank 1, author rank 2, author rank 4, author rank 8, ..., author rank 512)
   
 - plots below equivalent to $P(X)$

In [None]:
def select_by_ranks(emails, ls_of_ranks):
    sndr_cnts = Counter(e.sender for e in emails)
    ranks_sndr = {r:s for r, (s, c) in enumerate(sndr_cnts.most_common())}
    
    for r in ls_of_ranks:
        cur_s = ranks_sndr[r]
        yield [m for m in emails if m.sender == cur_s]
        
rank_rng = [2**i for i in range(10)]
selection = list(select_by_ranks(emails, rank_rng))
rest = list(set(emails) - set(m for m_ls in selection for m in m_ls))

In [None]:
sndr_cnts = Counter(e.sender for e in emails)
rs, cs = list(zip(*[(r, c) for r, (_, c) in enumerate(sndr_cnts.most_common())]))

fig = plt.figure(1, figsize=(15, 10))
plt.subplot(221)
plt.loglog(rs, cs, '.')
plt.xlabel("$\log$ rank"); plt.ylabel("$\log$ frequency"); plt.title("Rank-Frequency plot of number of e-mails authored by each person\n$= P(X)$")

rng = [2**i for i in range(10)]
rng_cs = [cs[i] for i in rng]

plt.subplot(222)
plt.loglog(rng, rng_cs, '.')
_ = plt.title("Thinned version of the left plot\n(subset of 10)")

## 0.1 Get Train and Evaluation sets

using train:eval ratio of 30:70

In [None]:
def split_train_test(ls, test_ratio=0.3):
    cutoff = int(len(ls)*test_ratio)
    randmsd_ls = list(perm(ls))
    test = randmsd_ls[:cutoff]
    train = randmsd_ls[cutoff:]
    
    return train, test

# def emails_to_datasets(selected, rest, test_ratio=0.3):
#     train, test = split_train_test(rest, test_ratio)
#     for mail_ls in selected:
#         cur_train, cur_test = split_train_test(mail_ls, test_ratio)
#         train.extend(cur_train)
#         test.extend(cur_test)
        
#     return train, test

selection_train, selection_test = list(zip(*[split_train_test(m_ls) for m_ls in selection]))
rest_train, rest_test = split_train_test(rest)
selection_never_seen, selection_test = list(zip(*[split_train_test(m_ls, test_ratio=0.5) 
                                                   for m_ls in selection_test]))
rest_never_seen, rest_test = split_train_test(rest_test, test_ratio=0.5)
print(list(map(len, selection_never_seen)), list(map(len, selection_test)))
len(rest_never_seen), len(rest_test)

In [None]:
# save in folders because of random permutations

with open("data_splits/selection_train.pkl", "wb") as handle:
    pickle.dump(selection_train, handle)
with open("data_splits/selection_test.pkl", "wb") as handle:
    pickle.dump(selection_test, handle)
with open("data_splits/selection_never_seen.pkl", "wb") as handle:
    pickle.dump(selection_never_seen, handle)
    

with open("data_splits/rest_train.pkl", "wb") as handle:
    pickle.dump(rest_train, handle)
with open("data_splits/rest_test.pkl", "wb") as handle:
    pickle.dump(rest_test, handle)
with open("data_splits/rest_never_seen.pkl", "wb") as handle:
    pickle.dump(rest_never_seen, handle)    

# START FROM HERE
# IMPORTANT!

In [3]:
with open("data_splits/selection_train.pkl", "rb") as handle:
    selection_train = pickle.load(handle)
with open("data_splits/selection_test.pkl", "rb") as handle:
    selection_test = pickle.load(handle)
with open("data_splits/selection_never_seen.pkl", "rb") as handle:
    selection_never_seen = pickle.load(handle)
    

with open("data_splits/rest_train.pkl", "rb") as handle:
    rest_train = pickle.load(handle)
with open("data_splits/rest_test.pkl", "rb") as handle:
    rest_test = pickle.load(handle)
with open("data_splits/rest_never_seen.pkl", "rb") as handle:
    rest_never_seen = pickle.load(handle)

## 1. Domain-Adapt GPT-2 to W3C E-mails

 - pre-train an instance of GPT-2 on entire (subset of) w3c-email corpus <br>
   -> will reduce perplexity and thus increase sensitivity of LMs
 - use this LM as starting point to train personalised LMs 
 - also reserve a test set? -> i.e. some e-mails which no custom-trained LM has seen before
 
 - => write e-mail bodies into text files, train and eval

In [5]:
def emails_to_trainfile(email_ls, file_name, split_into=1):
    chunk_size = len(email_ls)//split_into
    for i in range(split_into):
        with open(file_name + f".{i}", "w", encoding="utf-8") as handle:
            cur_chunk = email_ls[i*chunk_size:(i+1)*chunk_size] if i != split_into-1 else email_ls[i*chunk_size:]
            for m in cur_chunk:
                mail_str = m.body_raw.replace("\n", "  ")
                handle.write(mail_str)
                handle.write("\n\n")
            
full_train = perm(rest_train + 
                  [m for m_ls in selection_train for m in m_ls] +
                  [m for m_ls in selection_test for m in m_ls])

full_test = perm(rest_never_seen + [m for m_ls in selection_never_seen for m in m_ls])

emails_to_trainfile(full_train, "W3CGPT2/full.train.raw", split_into=5)
emails_to_trainfile(full_test, "W3CGPT2/full.test.raw", split_into=4)


### 1.1 Call for Training

  - first: merge split up text files -> `cat full.train.raw.* > full.train.raw.all`

`python3 run_language_modeling.py --train_data_file=W3CGPT2/full.train.raw.all --model_type=gpt2 --output_dir=W3CGPT2/lm --model_name_or_path=gpt2 --do_train --line_by_line --num_train_epochs=2`

 - `--line_by_line` indicates one sample per line to spearate e-mails, `"\n"` inside e-mails converted to `" "`
 - perhaps use `--block_size=128/256/512` (rather than GPT-2's default of 1024) -> loose fewer tokens at the ends of long emails
 
 
### 1.2 Load Trained

should be as simple as:

In [None]:
model = GPT2Model.from_pretrained('/W3CGPT2/lm/') 

## 2. Train one instance of W3CGPT-2 per author $X$ to become GPT-2$_X$

 - load W3CGPT2
 - get training files
 - call `run_language_modeling.py` with adequate parameters
 - run on LISA

### 2.1 Training files, one per author

In [20]:
import os

folder_name = "GPT2_X/"
names = []
for mail_ls in selection_train:
    cur_auth = mail_ls[0].sender
    auth_name = cur_auth.name.replace(" ", "_")
    names.append(auth_name)
    
    print(auth_name)
    
    os.mkdir(folder_name + "lm_" + auth_name)
    with open(folder_name + "lm_" + auth_name + "/nothing.txt", "w") as handle: pass
    
    with open(folder_name + auth_name + ".train.raw", "w", encoding="utf-8") as handle:
        for m in mail_ls:
            print(len(m.body_raw))
            mail_str = m.body_raw.replace("\n", "  ")
            handle.write(mail_str)
            handle.write("\n\n")
    
    print()
with open(folder_name + "auth_names.txt", "w") as handle:
    handle.write("\n".join(names))

Brian_McBride
861
229
657
385
253
200
999
167
2620
464
3983
2598
1955
3185
711
218
288
3961
423
352
223
690
344
599
1252
155
1610
578
4305
708
1678
931
1137
1030
391
338
662
2327
387
333
624
288
1044
358
2744
1384
517
609
1169
554
1075
322
1654
234
900
1865
1821
1034
3124
349
895
1013
1525
10914
456
405
2369
698
289
311
1909
1663
1797
1311
7831
2667
1067
1762
1545
1713
390
2926
2519
3287
1722
2589
560
2900
126
1309
309
4960
382
242
210
1654
1930
1065
774
360
1364
967
2464
322
352
779
2116
632
754
547
5772
650
765
311
397
2699
259
2548
869
2635
453
152
405
179
6
972
2884
698
554
2800
229
2374
430
1074
2291
1818
764
778
260
1023
439
2021
1171
680
2605
762
689
367
365
2465
1145
629
369
775
711
3467
327
263
834
432
327
657
251
413
2230
235
351
1096
410
418
369
5932
512
5352
3131
2380
273
1939
244
278
255
1820
312
118
570
523
1079
757
48
535
221
1864
265
1490
740
72
627
399
1326
2582
213
259
233
1829
3823
758
1982
455
504
1231
884
1772
1103
210
539
284
374
501
11958
2437
9857
251
5420
772
2

1636
1121
1653
1410
1389
653
2203
1067
1354
1240
990
455
413
58483
4550
315
7817
349
1649
1102
2434
1035
644
1967
229
992
1411
191
1711
3390
3232
224
167
12532
782
323
2450
329
323
118
3171
3279
5626
1601
5134
1156
1871
941
1635
3056
467
1001
1197
5601
3461
1986
3304
1188
38
232
1071
3651
5052
1415
1559
974
1018
11734
2497
249
618
620
1596
977
829
1973
4588
387
3793
611
580
758
403
278
9152
289
1002
154
881
412
3604
175
3776
2814
424
1210
9449
1586
6841
4780
1296
2557
8564
1656
536
1695
1271
123
1296
153
2780
291
631
1976
865
1444
2961
27945
1224
19586
601
949
3745
3266
2221
379
4660
5203
1057
1110
685
945
3386
1121
13994
1871
3610
1099
420
579
441
2251
2612
963
319
196
2892
202
3634
185
607
5510
5184
870
864
2477
885
358
1354
152
830
13815
2169
998
1220
3587
385
3677
918
512
930
688
5416
11398
1726
10613
346
641
1362
1820
1040
2535
3499
6882
3804
3829
5134
2285
2090
1398
11928
1695
160
840
875
1088
663
559
1917
5903
6918
356
1069
1727
3905
507
467
516
1139
2916
364
1088
372
2698
3818


955
362
425
573
381
898
1120
1818
1653
516
2170
589
224
262
1404
871
403
123
471
6113
1067
1670
4457
587
1629
519
269
1273
638
1061
78
2581
593
345
916
592
461
1031
2642
58
2433
1075
3343
590
3277
76
1342
337
2442
1536
1512
1150
322
1205
1420
1367
184
985
1104
2629
2459
858
6113
4341
601
281
9229
1070
624
6412
2615
568
1682
17998
663
4962
5444
147
745
89
3200
2992
397
4180
2798
530
4430
367
1585
804
1367
321
3756
1190
156
1675
1521
883
536
1693
5032
1813

Jason_White
1240
3713
571
491
3091
1112
321
408
870
163
462
1634
1481
1782
746
3954
335
942
2766
201
325
983
602
2163
965
1357
622
486
1176
2313
1268
1039
1640
454
1488
1021
561
900
1038
1873
520
1438
219
555
1138
1485
581
1064
1479
2177
1235
235
942
718
970
2194
1140
640
1273
1880
1433
629
1088
878
841
1380
656
426
416
556
1868
856
666
1070
328
1748
1011
826
1188
787
863
1140
988
1065
1272
3037
554
989
4939
1361
236
2252
891
921
1565
3908
766
588
677
1039
256
970
1287
1918
1038
1051
1026
307
824
1148
375
1405
1750
1012
675
385
1163
1