In [1]:
from transformers import DistilBertTokenizerFast
from transformers import DistilBertForTokenClassification
from transformers import EarlyStoppingCallback

from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import os

import torch

from glob import glob
import pandas as pd
import csv
import pickle
import numpy as np
import collections

In [2]:
train_texts, train_tags = pickle.load(open('data/parsed_webis_train.pkl', 'rb'))
val_texts, val_tags = pickle.load(open('data/parsed_webis_dev.pkl', 'rb'))
test_texts, test_tags = pickle.load(open('data/parsed_webis_test.pkl', 'rb'))


In [3]:
#cut the long doc
train_texts = [txt for txt in train_texts if len(txt) <= 512]
train_tags = [t for t in train_tags if len(t) <= 512]

In [4]:
#report stats here
def report_stats(tags):
    print('==============')
    length = [len(t) for t in tags]
    print(f'#chunks: {len(tags)}')
    print(f'average length: {np.mean(length):.1f}')
    print(f'99% length: {np.percentile(length, [99])[0]:.1f}')
    print(f'max length: {np.percentile(length, [100])[0]:.1f}')
    counter = collections.defaultdict(int)
    
    for t in tags:
        for key in ['B', 'I', 'O']:
            counter[key] += t.count(key)
    for key in ['B', 'I', 'O']:
        print(f'{key}: {counter[key]}')

report_stats(train_tags)
report_stats(val_tags)
report_stats(test_tags)

#chunks: 2749
average length: 57.4
99% length: 147.0
max length: 299.0
B: 7674
I: 136584
O: 13597
#chunks: 757
average length: 71.0
99% length: 176.8
max length: 336.0
B: 2676
I: 45536
O: 5566
#chunks: 849
average length: 61.0
99% length: 190.5
max length: 280.0
B: 2577
I: 44159
O: 5081


In [5]:
tags = train_tags + val_tags + test_tags
unique_tags = set(tag for doc in tags for tag in doc)
tag2id = {tag: id for id, tag in enumerate(unique_tags)}
id2tag = {id: tag for tag, id in tag2id.items()}

In [6]:
#tokenize the texts
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased')
train_encodings = tokenizer(train_texts, is_split_into_words=True, return_offsets_mapping=True, padding=True, truncation=True)
val_encodings = tokenizer(val_texts, is_split_into_words=True, return_offsets_mapping=True, padding=True, truncation=True)
test_encodings = tokenizer(test_texts, is_split_into_words=True, return_offsets_mapping=True, padding=True, truncation=True)

In [7]:
def encode_tags(tgs, encodings):
    labels = [[tag2id[tag] for tag in doc] for doc in tgs]
    encoded_labels = []
    for doc_labels, doc_offset in zip(labels, encodings.offset_mapping):
        # create an empty array of -100
        doc_enc_labels = np.ones(len(doc_offset),dtype=int) * -100
        arr_offset = np.array(doc_offset)
        
        # set labels whose first offset position is 0 and the second is not 0
        doc_enc_labels[(arr_offset[:,0] == 0) & (arr_offset[:,1] != 0)] = doc_labels
        
        encoded_labels.append(doc_enc_labels.tolist())
        
        
    return encoded_labels

train_labels = encode_tags(train_tags, train_encodings)
val_labels = encode_tags(val_tags, val_encodings)
test_labels = encode_tags(test_tags, test_encodings)

In [8]:


class WebisDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)

train_encodings.pop("offset_mapping") # we don't want to pass this to the model
val_encodings.pop("offset_mapping")
test_encodings.pop("offset_mapping")

train_dataset = WebisDataset(train_encodings, train_labels)
val_dataset = WebisDataset(val_encodings, val_labels)
test_dataset = WebisDataset(test_encodings, test_labels)



In [9]:

model = DistilBertForTokenClassification.from_pretrained('distilbert-base-cased', num_labels=len(unique_tags))

Some weights of the model checkpoint at distilbert-base-cased were not used when initializing DistilBertForTokenClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this 

In [10]:

early_stopping = EarlyStoppingCallback(early_stopping_patience=10)
# def compute_metrics(pred):
#     labels = pred.label_ids.flatten()
#     preds = pred.predictions.argmax(-1).flatten()
#     z = zip(labels, preds)
#     z = [item for item in z if item[0] != -100]
#     labels = np.array([item[0] for item in z])
#     preds = np.array([item[1] for item in z])
    
    
    
#     precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')
#     acc = accuracy_score(labels, preds)
#     return {
#         'acc_macro': acc,
#         'f1_macro': f1,
#         'p_macro': precision,
#         'r_macro': recall
#     }

def compute_metrics(pred):
    labels = pred.label_ids.flatten()
    preds = pred.predictions.argmax(-1).flatten()
    z = zip(labels, preds)
    z = [item for item in z if item[0] != -100]
    labels = np.array([item[0] for item in z])
    preds = np.array([item[1] for item in z])
    
    l_0 = np.array([1 if item[0]==0 else 0 for item in z])
    p_0 = np.array([1 if item[1]==0 else 0 for item in z])
    
    l_1 = np.array([1 if item[0]==1 else 0 for item in z])
    p_1 = np.array([1 if item[1]==1 else 0 for item in z])
    
    l_2 = np.array([1 if item[0]==2 else 0 for item in z])
    p_2 = np.array([1 if item[1]==2 else 0 for item in z])
    
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='macro')
    p_0, r_0, f1_0, _ = precision_recall_fscore_support(l_0, p_0, average='binary')
    p_1, r_1, f1_1, _ = precision_recall_fscore_support(l_1, p_1, average='binary')
    p_2, r_2, f1_2, _ = precision_recall_fscore_support(l_2, p_2, average='binary')
    
    acc = accuracy_score(labels, preds)
    return {
        'acc': acc,
        'f1_macro': f1,
        'p_macro': precision,
        'r_macro': recall,
        'f1_0': f1_0,
        'p_0': p_0,
        'r_0': r_0,
        'f1_1': f1_1,
        'p_1': p_1,
        'r_1': r_1,
        'f1_2': f1_2,
        'p_2': p_2,
        'r_2': r_2,
    }


training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=30,              # total number of training epochs
    per_device_train_batch_size=32,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.005,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=5,
    evaluation_strategy='steps',
    eval_steps=5,
    load_best_model_at_end=True,
    metric_for_best_model='eval_f1_macro'
)





In [11]:

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset,            # evaluation dataset
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

trainer.train()



Step,Training Loss,Validation Loss,Acc,F1 Macro,P Macro,R Macro,F1 0,P 0,R 0,F1 1,P 1,R 1,F1 2,P 2,R 2,Runtime,Samples Per Second
5,1.0067,0.995936,0.435996,0.256681,0.325364,0.346476,0.18468,0.109217,0.597557,0.0,0.0,0.0,0.585364,0.866873,0.44187,0.8566,883.733
10,0.9918,0.973752,0.499814,0.278781,0.324426,0.346629,0.181425,0.110235,0.512217,0.0,0.0,0.0,0.654919,0.863044,0.52767,0.9077,834.008
15,0.9639,0.937072,0.605136,0.306771,0.322139,0.342357,0.167611,0.109617,0.355911,0.0,0.0,0.0,0.752703,0.856798,0.671161,0.7784,972.535
20,0.9204,0.887855,0.724943,0.328764,0.322484,0.341528,0.143867,0.115074,0.191879,0.0,0.0,0.0,0.842426,0.852377,0.832704,1.0781,702.167
25,0.8659,0.828652,0.808918,0.324867,0.32118,0.336054,0.078963,0.114765,0.060187,0.0,0.0,0.0,0.895637,0.848775,0.947975,0.8368,904.67
30,0.798,0.762055,0.840994,0.311916,0.327655,0.334541,0.021807,0.135524,0.011858,0.0,0.0,0.0,0.913941,0.847441,0.991765,0.7541,1003.799
35,0.7276,0.690944,0.846015,0.305913,0.300457,0.333206,0.001067,0.054545,0.000539,0.0,0.0,0.0,0.916673,0.846825,0.999078,0.7491,1010.541
40,0.653,0.619133,0.846703,0.305666,0.28225,0.333319,0.0,0.0,0.0,0.0,0.0,0.0,0.916998,0.84675,0.999956,0.7605,995.427
45,0.5731,0.552715,0.84674,0.30567,0.282247,0.333333,0.0,0.0,0.0,0.0,0.0,0.0,0.917011,0.84674,1.0,0.8615,878.706
50,0.502,0.50129,0.84674,0.30567,0.282247,0.333333,0.0,0.0,0.0,0.0,0.0,0.0,0.917011,0.84674,1.0,0.7407,1022.038


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)






TrainOutput(global_step=310, training_loss=0.2345658930559312, metrics={'train_runtime': 306.2719, 'train_samples_per_second': 1.077, 'total_flos': 10183685664087936, 'epoch': 28.18})

In [12]:
trainer.evaluate(val_dataset)



{'epoch': 28.18,
 'eval_acc': 0.9329837480010413,
 'eval_f1_0': 0.6814112690889943,
 'eval_f1_1': 0.8237526086131665,
 'eval_f1_2': 0.964931565901498,
 'eval_f1_macro': 0.8233651478678863,
 'eval_loss': 0.32985052466392517,
 'eval_p_0': 0.8233647238483075,
 'eval_p_1': 0.8366088631984586,
 'eval_p_2': 0.9473906970838447,
 'eval_p_macro': 0.869121428043537,
 'eval_r_0': 0.5812073302191879,
 'eval_r_1': 0.8112855007473841,
 'eval_r_2': 0.983134223471539,
 'eval_r_macro': 0.7918756848127037,
 'eval_runtime': 0.7806,
 'eval_samples_per_second': 969.75}

In [13]:
trainer.evaluate(test_dataset)



{'epoch': 28.18,
 'eval_acc': 0.9418337611208677,
 'eval_f1_0': 0.7113955880672077,
 'eval_f1_1': 0.853443201883461,
 'eval_f1_2': 0.9693054751191696,
 'eval_f1_macro': 0.8447147550232795,
 'eval_loss': 0.284835547208786,
 'eval_p_0': 0.8484187568157033,
 'eval_p_1': 0.8630952380952381,
 'eval_p_2': 0.9536917311359004,
 'eval_p_macro': 0.8884019086822806,
 'eval_r_0': 0.6124778586892344,
 'eval_r_1': 0.8440046565774156,
 'eval_r_2': 0.9854389818610023,
 'eval_r_macro': 0.8139738323758842,
 'eval_runtime': 0.7375,
 'eval_samples_per_second': 1151.21}