In [1]:
from transformers import DistilBertTokenizerFast
from transformers import DistilBertForTokenClassification
from transformers import EarlyStoppingCallback

from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os

import torch

from glob import glob
import pandas as pd
import csv
import pickle
import numpy as np
import random
import collections

import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [2]:
def seed_all(seed):
    if not seed:
        seed = 10

    print("[ Using Seed : ", seed, " ]")

    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    
seed_all(42)

[ Using Seed :  42  ]


In [3]:
suffix = '_fine_grain'
train_texts, train_tags = pickle.load(open(f'data/parsed_webis_train{suffix}.pkl', 'rb'))
val_texts, val_tags = pickle.load(open(f'data/parsed_webis_dev{suffix}.pkl', 'rb'))
test_texts, test_tags = pickle.load(open(f'data/parsed_webis_test{suffix}.pkl', 'rb'))


In [4]:
#drop the one long doc
train_texts = [txt for txt in train_texts if len(txt) <= 512]
train_tags = [t for t in train_tags if len(t) <= 512]


In [5]:
#report stats here
def report_stats(tags):
    print('==============')
    length = [len(t) for t in tags]
    print(f'#chunks: {len(tags)}')
    print(f'average length: {np.mean(length):.1f}')
    print(f'99% length: {np.percentile(length, [99])[0]:.1f}')
    print(f'max length: {np.percentile(length, [100])[0]:.1f}')
    counter = collections.defaultdict(int)
    
    for t in tags:
        for key in set(t):
            counter[key] += t.count(key)
    for key in sorted(counter.keys()):
        print(f'{key}: {counter[key]}')

report_stats(train_tags)
report_stats(val_tags)
report_stats(test_tags)

#chunks: 2727
average length: 57.7
99% length: 150.0
max length: 367.0
B-anecdote: 1558
B-assumption: 4845
B-common-ground: 165
B-other: 50
B-statistics: 267
B-testimony: 527
I-anecdote: 29059
I-assumption: 84654
I-common-ground: 2154
I-other: 501
I-statistics: 5129
I-testimony: 14920
O: 13597
#chunks: 752
average length: 71.4
99% length: 179.0
max length: 336.0
B-anecdote: 467
B-assumption: 1854
B-common-ground: 37
B-other: 40
B-statistics: 65
B-testimony: 93
I-anecdote: 8371
I-assumption: 32559
I-common-ground: 509
I-other: 507
I-statistics: 1189
I-testimony: 2420
O: 5566
#chunks: 843
average length: 61.4
99% length: 200.6
max length: 280.0
B-anecdote: 451
B-assumption: 1741
B-common-ground: 29
B-other: 15
B-statistics: 81
B-testimony: 154
I-anecdote: 8680
I-assumption: 29246
I-common-ground: 312
I-other: 186
I-statistics: 1679
I-testimony: 4094
O: 5081


In [6]:
tags = train_tags + val_tags + test_tags
unique_tags = set(tag for doc in tags for tag in doc)
tag2id = {tag: id for id, tag in enumerate(unique_tags)}
id2tag = {id: tag for tag, id in tag2id.items()}
print(id2tag)

{0: 'B-testimony', 1: 'B-common-ground', 2: 'B-statistics', 3: 'B-other', 4: 'I-common-ground', 5: 'I-statistics', 6: 'B-assumption', 7: 'I-assumption', 8: 'O', 9: 'I-anecdote', 10: 'I-other', 11: 'I-testimony', 12: 'B-anecdote'}


In [7]:
#tokenize the texts
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')
train_encodings = tokenizer(train_texts, is_split_into_words=True, return_offsets_mapping=True, padding=True, truncation=True)
val_encodings = tokenizer(val_texts, is_split_into_words=True, return_offsets_mapping=True, padding=True, truncation=True)
test_encodings = tokenizer(test_texts, is_split_into_words=True, return_offsets_mapping=True, padding=True, truncation=True)

In [8]:
def encode_tags(tgs, encodings):
    labels = [[tag2id[tag] for tag in doc] for doc in tgs]
    encoded_labels = []
    for doc_labels, doc_offset in zip(labels, encodings.offset_mapping):
        # create an empty array of -100
        doc_enc_labels = np.ones(len(doc_offset),dtype=int) * -100
        arr_offset = np.array(doc_offset)
        
        # set labels whose first offset position is 0 and the second is not 0
        doc_enc_labels[(arr_offset[:,0] == 0) & (arr_offset[:,1] != 0)] = doc_labels
        
        encoded_labels.append(doc_enc_labels.tolist())
        
        
    return encoded_labels

train_labels = encode_tags(train_tags, train_encodings)
val_labels = encode_tags(val_tags, val_encodings)
test_labels = encode_tags(test_tags, test_encodings)

In [9]:


class WebisDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_encodings.pop("offset_mapping") # we don't want to pass this to the model
val_encodings.pop("offset_mapping")
test_encodings.pop("offset_mapping")

train_dataset = WebisDataset(train_encodings, train_labels)
val_dataset = WebisDataset(val_encodings, val_labels)
test_dataset = WebisDataset(test_encodings, test_labels)



In [10]:

model = DistilBertForTokenClassification.from_pretrained('distilbert-base-cased', num_labels=len(unique_tags))

Some weights of the model checkpoint at distilbert-base-cased were not used when initializing DistilBertForTokenClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this 

In [11]:

early_stopping = EarlyStoppingCallback(early_stopping_patience=10)
def compute_metrics(pred):
    labels = pred.label_ids.flatten()
    preds = pred.predictions.argmax(-1).flatten()
    z = zip(labels, preds)
    z = [item for item in z if item[0] != -100]
    labels = np.array([item[0] for item in z])
    preds = np.array([item[1] for item in z])
    
    
    
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')
    acc = accuracy_score(labels, preds)
    return {
        'acc': acc,
        'f1_macro': f1,
        'p_macro': precision,
        'r_macro': recall
    }

# def compute_metrics(pred):
#     labels = pred.label_ids.flatten()
#     preds = pred.predictions.argmax(-1).flatten()
#     z = zip(labels, preds)
#     z = [item for item in z if item[0] != -100]
#     labels = np.array([item[0] for item in z])
#     preds = np.array([item[1] for item in z])
    
#     l_0 = np.array([1 if item[0]==0 else 0 for item in z])
#     p_0 = np.array([1 if item[1]==0 else 0 for item in z])
    
#     l_1 = np.array([1 if item[0]==1 else 0 for item in z])
#     p_1 = np.array([1 if item[1]==1 else 0 for item in z])
    
#     l_2 = np.array([1 if item[0]==2 else 0 for item in z])
#     p_2 = np.array([1 if item[1]==2 else 0 for item in z])
    
#     precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')
#     p_0, r_0, f1_0, _ = precision_recall_fscore_support(l_0, p_0, average='binary')
#     p_1, r_1, f1_1, _ = precision_recall_fscore_support(l_1, p_1, average='binary')
#     p_2, r_2, f1_2, _ = precision_recall_fscore_support(l_2, p_2, average='binary')
    
#     acc = accuracy_score(labels, preds)
#     return {
#         'acc': acc,
#         'f1_macro': f1,
#         'p_macro': precision,
#         'r_macro': recall,
#         'f1_0': f1_0,
#         'p_0': p_0,
#         'r_0': r_0,
#         'f1_1': f1_1,
#         'p_1': p_1,
#         'r_1': r_1,
#         'f1_2': f1_2,
#         'p_2': p_2,
#         'r_2': r_2,
#     }


training_args = TrainingArguments(
    output_dir=f'./results{suffix}',         
    num_train_epochs=30,              
    per_device_train_batch_size=16, 
    per_device_eval_batch_size=64, 
    warmup_steps=100,                
    weight_decay=0.005,           
    logging_dir=f'./logs{suffix}', 
    logging_steps=5,
    evaluation_strategy='steps',
    eval_steps=5,
    load_best_model_at_end=True,
    metric_for_best_model='eval_f1_macro'
)





In [12]:

trainer = Trainer(
    model=model,                         
    args=training_args,                  
    train_dataset=train_dataset,        
    eval_dataset=val_dataset,            
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

trainer.train()

Step,Training Loss,Validation Loss,Acc,F1 Macro,P Macro,R Macro,Runtime,Samples Per Second
5,2.5349,2.477056,0.104924,0.039522,0.07899,0.077312,0.7793,964.914
10,2.4651,2.372874,0.275891,0.062483,0.08385,0.085365,0.7513,1000.995
15,2.3561,2.194166,0.548391,0.070059,0.095114,0.080789,0.7154,1051.095
20,2.1626,1.931829,0.605678,0.060647,0.109954,0.077918,0.6672,1127.162
25,1.8882,1.583765,0.606535,0.058299,0.09155,0.07701,0.6577,1143.448
30,1.6467,1.373413,0.606554,0.058087,0.046661,0.076921,0.6493,1158.211
35,1.4127,1.273583,0.629096,0.08697,0.128993,0.093837,0.6566,1145.219
40,1.3101,1.164602,0.650837,0.111017,0.247258,0.111281,0.6675,1126.676
45,1.2411,1.070799,0.674889,0.153204,0.21771,0.142631,0.6709,1120.834
50,1.116,1.016337,0.702014,0.197632,0.225682,0.185447,0.7319,1027.397


TrainOutput(global_step=490, training_loss=0.36967466247020936, metrics={'train_runtime': 439.7108, 'train_samples_per_second': 1.501, 'total_flos': 9722104150542732, 'epoch': 22.27})

In [13]:
trainer.evaluate(val_dataset)

{'epoch': 22.27,
 'eval_acc': 0.767628593252231,
 'eval_f1_macro': 0.4822016223271417,
 'eval_loss': 1.3884586095809937,
 'eval_p_macro': 0.5316185502751299,
 'eval_r_macro': 0.47101940745876253,
 'eval_runtime': 0.6983,
 'eval_samples_per_second': 1076.848}

In [14]:
trainer.evaluate(test_dataset)

{'epoch': 22.27,
 'eval_acc': 0.7618697945854026,
 'eval_f1_macro': 0.4561523234515859,
 'eval_loss': 1.3106003999710083,
 'eval_p_macro': 0.46143968932062845,
 'eval_r_macro': 0.4532269328127683,
 'eval_runtime': 0.6391,
 'eval_samples_per_second': 1319.056}

In [15]:
test_paragraphs = pickle.load(open('test_paragraphs.pkl', 'rb'))

In [16]:
test_paragraphs.keys()

dict_keys(['pseudo-science', 'political-bias', 'credible'])

In [17]:
from nltk import word_tokenize
#test the model on our dataset
name = 'pseudo-science'
cross_text = test_paragraphs[name]
cross_texts = [word_tokenize(txt) for txt in cross_text]
# report_stats(cross_texts)
# cross_texts = test_texts[:2]


In [18]:

# name = 'political-bias'
# cross_text = test_paragraphs[name]
# cross_texts = [word_tokenize(txt) for txt in cross_text]
# report_stats(cross_texts)


In [19]:

# name = 'credible'
# cross_text = test_paragraphs[name]
# cross_texts = [word_tokenize(txt) for txt in cross_text]
# report_stats(cross_texts)

In [20]:
# cross_texts = test_texts[10:15]
# tl = test_labels[10:15]

model = model.to('cpu')

cross_encodings = tokenizer(cross_texts, is_split_into_words=True, return_offsets_mapping=True, padding=True, truncation=True)
cross_encodings.pop("offset_mapping")

cross_input = {key: torch.tensor(val) for key, val in cross_encodings.items()}

pred = np.argmax(model(cross_input['input_ids'].to('cpu'))[0].detach().cpu().numpy(), -1)
tokens = [tokenizer.convert_ids_to_tokens(item) for item in cross_input['input_ids']]



In [21]:
import html
from IPython.core.display import display, HTML

def html_escape(text):
    return html.escape(text)

def show(tokens, pred, idx, colors=None):
    

    sent = []
    BIOs = []
    for tk, lb in zip(tokens[idx], pred[idx]):
        if tk in ['[CLS]', '[SEP]', '[PAD]']:
            continue
        sent.append(tk)
        if lb == -100: #word piece follows the previous
            BIOs.append(BIOs[-1])
        else:
            BIOs.append(id2tag[lb])
        
    text_plot = sent
    wt = []
    for item in BIOs:
        if 'B' in item:
            wt.append(1)
        elif 'I' in item:
            wt.append(0.5)
        else:
            wt.append(0)
    
    if colors:
        cs = []
        for item in BIOs:
            for key in colors:
                if key in item:
                    cs.append(colors[key])
                    continue
        
    
    highlighted_text = []
    if not colors:
        for word, wt in zip(text_plot, wt):
            weight = wt

            if weight is not None:
                highlighted_text.append('<span style="background-color:rgba(135,206,250,' + str(weight) + ');">' + html_escape(word) + '</span>')
            else:
                highlighted_text.append(word)
    else:
        for word, wt, c in zip(text_plot, wt, cs):
            weight = wt

            if weight is not None:
                highlighted_text.append(f'<span style="background-color:rgba({c},' + str(weight) + ');">' + html_escape(word) + '</span>')
            else:
                highlighted_text.append(word)
    highlighted_text = ' '.join(highlighted_text)
    display(HTML(highlighted_text))

In [22]:
colors = {
    'common-ground': '139,255,252',
    'anecdote': '123,255,108',
    'statistics': '255,108,255',
    'assumption': '104,104,255',
    'other': '255,255,104',
    'testimony': '255,114,104',
    'O': '255,255,255'
}

In [23]:

example = []
for txt in colors:
    example.append(f'<span style="background-color:rgba({colors[txt]},' + str(1) + ');">' + html_escape(txt) + '</span>')
display(HTML('\n'.join(example)))

In [24]:
#domain adaptation

In [25]:
for i in range(30):
    show(tokens, pred, i, colors)

In [29]:
# plot the test set

cross_texts = test_texts[10:15]
tl = test_labels[10:15]

model = model.to('cpu')

cross_encodings = tokenizer(cross_texts, is_split_into_words=True, return_offsets_mapping=True, padding=True, truncation=True)
cross_encodings.pop("offset_mapping")

cross_input = {key: torch.tensor(val) for key, val in cross_encodings.items()}

pred = np.argmax(model(cross_input['input_ids'].to('cpu'))[0].detach().cpu().numpy(), -1)
tokens = [tokenizer.convert_ids_to_tokens(item) for item in cross_input['input_ids']]


for i in range(5):
    show(tokens, pred, i, colors)



In [30]:
for i in range(5):
    show(tokens, tl, i, colors)


In [27]:
# output = []
# for tks, lbs in zip(tokens, pred):
#     sent = []
#     BIOs = []
#     for tk, lb in zip(tks, lbs):
#         if tk in ['[CLS]', '[SEP]', '[PAD]']:
#             continue
#         sent.append(tk)
#         BIOs.append(id2tag[lb])
#     output.append(sent)
#     output.append(BIOs)
#     output.append([])
# #     print('\t'.join(sent))
# #     print('\t'.join(BIOs))
# #     print()

# import csv

# with open(f'inference/{name}.csv', 'w') as f:
#     writer = csv.writer(f)
#     for row in output:
#         writer.writerow(row)

In [28]:
print(model)

DistilBertForTokenClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
          