Script that evaluates a fine-tuned BERT language model (one saved checkpoint per training epoch) on the created datasets.


Based on https://medium.com/swlh/a-simple-guide-on-using-bert-for-text-classification-bbf041ac8d04

In [0]:
# Colab setup: mount Google Drive at /content/drive, list the project folder as a
# quick sanity check, and install the transformers version this notebook targets.
# NOTE(review): later cells use relative "drive/My Drive/..." paths, which assumes
# the working directory is /content (Colab's default) - confirm if run elsewhere.
from google.colab import drive
drive.mount("/content/drive")
!ls "drive/My Drive/Colab Notebooks/abusive"
# Pinned: the evaluation cell relies on the transformers 2.x API (e.g. WarmupLinearSchedule import).
!pip install transformers==2.1.1


In [0]:
from __future__ import absolute_import, division, print_function
import torch
import pickle
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch.nn import CrossEntropyLoss, MSELoss
from tqdm import tqdm, trange
import os
from transformers import BertTokenizer, BertModel, BertForMaskedLM, BertForSequenceClassification
from transformers.optimization import AdamW, WarmupLinearSchedule
from multiprocessing import Pool, cpu_count
import pandas as pd
import csv
import os
import sys
import logging
import numpy as np
from sklearn.metrics import matthews_corrcoef, confusion_matrix

# Handle on the root logger (getLogger with no name returns the root logger).
logger = logging.getLogger(name=None)
def get_eval_report(task_name, labels, preds):
    """Summarise binary-classification predictions against gold labels.

    Args:
        task_name: Copied verbatim into the report under ``"task"``.
        labels: Gold class ids. Class 1 is treated as the positive class and
            class 0 as the negative class (the ids ``label_map`` assigns to
            the ``["0", "1"]`` label list).
        preds: Predicted class ids, aligned index-by-index with ``labels``.

    Returns:
        dict with ``task``, ``mcc`` (Matthews correlation coefficient), the
        confusion counts ``tp``/``tn``/``fp``/``fn`` and ``macro_f1`` (the
        mean of the positive-class and negative-class F1 scores).
    """
    def _ratio(num, den):
        # An undefined 0/0 ratio (e.g. no positive predictions) counts as 0.0
        # instead of producing NaN plus a RuntimeWarning.
        return num / den if den else 0.0

    mcc = matthews_corrcoef(labels, preds)
    # Pinning labels=[0, 1] always yields a 2x2 matrix. Without it sklearn
    # returns a 1x1 matrix when only one class occurs in labels and preds,
    # and the four-way unpack below raises ValueError.
    tn, fp, fn, tp = confusion_matrix(labels, preds, labels=[0, 1]).ravel()
    pos_prec = _ratio(tp, tp + fp)
    pos_reca = _ratio(tp, tp + fn)
    neg_prec = _ratio(tn, tn + fn)
    neg_reca = _ratio(tn, tn + fp)
    pos_f1 = _ratio(2 * (pos_prec * pos_reca), pos_prec + pos_reca)
    neg_f1 = _ratio(2 * (neg_prec * neg_reca), neg_prec + neg_reca)
    macro_f1 = (pos_f1 + neg_f1) / 2
    return {
        "task": task_name,
        "mcc": mcc,
        "tp": tp,
        "tn": tn,
        "fp": fp,
        "fn": fn,
        "macro_f1": macro_f1,
    }

def collect_wrong_preds(labels, preds):
    """Return the indices where a prediction disagrees with its gold label.

    Args:
        labels: Sequence of gold labels.
        preds: Sequence of predictions, aligned index-by-index with ``labels``
            (lengths must match; checked with an assertion).

    Returns:
        list[int]: Ascending positions ``i`` with ``preds[i] != labels[i]``.
    """
    assert len(preds) == len(labels)
    return [i for i, (gold, pred) in enumerate(zip(labels, preds)) if pred != gold]
    
def compute_metrics(task_name, labels, preds):
    """Return the evaluation report for ``preds`` scored against ``labels``.

    The two sequences must have the same length (checked with an assertion)
    before the work is delegated to ``get_eval_report``.
    """
    n_preds, n_labels = len(preds), len(labels)
    assert n_preds == n_labels
    report = get_eval_report(task_name, labels, preds)
    return report

class InputExample(object):
    """One raw example for single-sequence or sequence-pair classification.

    Attributes:
        guid: Unique identifier of the example.
        text_a: Untokenized text of the first (or only) sequence.
        text_b: Untokenized text of the second sequence; None for
            single-sequence tasks.
        label: Label string; may be None when no gold label is available
            (e.g. unlabeled test data).
    """

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Store the four fields as given; no validation or tokenization happens here."""
        self.guid, self.text_a = guid, text_a
        self.text_b, self.label = text_b, label


class DataProcessor(object):
    """Base class for data converters for sequence classification data sets.

    Subclasses implement the ``get_*`` hooks; ``_read_tsv`` is a shared
    helper for loading tab-separated files.
    """

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Read a UTF-8 tab-separated file into a list of rows.

        Args:
            input_file: Path to the TSV file.
            quotechar: Forwarded to ``csv.reader``. With the default ``None``
                the csv module disables quote processing, so quote characters
                stay in the cell text as literals.

        Returns:
            list[list[str]]: One list of cell strings per row (blank lines
            come back as empty lists).
        """
        # newline="" hands line-ending handling to the csv module, as the csv
        # documentation requires for files passed to csv.reader. The former
        # Python 2 `unicode` branch was unreachable (this file needs Python 3).
        with open(input_file, "r", encoding="utf-8", newline="") as f:
            return list(csv.reader(f, delimiter="\t", quotechar=quotechar))


class BinaryClassificationProcessor(DataProcessor):
    """Processor for the binary ("0"/"1") classification TSV files.

    Each row is expected to carry the label in column 1 and the text in
    column 3; the other columns are ignored.
    """

    def get_train_examples(self, data_dir):
        """Load training examples from ``<data_dir>/train.tsv``."""
        # Reads train.tsv (not test.tsv) so the training split can never be
        # silently identical to the evaluation split.
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """Load evaluation examples from ``<data_dir>/test.tsv``.

        Note: this reads the test split, which is what this evaluation
        script scores.
        """
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """Return the label strings; their list positions become class ids."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Turn TSV rows into ``InputExample`` objects.

        Args:
            lines: Rows as returned by ``_read_tsv``.
            set_type: Prefix for each guid, which is ``"<set_type>-<row index>"``.

        Returns:
            list[InputExample] with ``text_a`` from column 3, ``label`` from
            column 1 and no ``text_b``.
        """
        examples = []
        for (i, line) in enumerate(lines):
            if not line:
                # csv.reader yields [] for a blank line; skip it rather than
                # failing with IndexError. Row indices in guids are unchanged.
                continue
            guid = "%s-%s" % (set_type, i)
            text_a = line[3]
            label = line[1]
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples


class InputFeatures(object):
    """Model-ready encoding of one example.

    Attributes:
        input_ids: Token ids (padded with 0 when built by
            ``convert_example_to_feature``).
        input_mask: 1 for real tokens, 0 for padding positions.
        segment_ids: Segment (token type) id per position.
        label_id: Encoded label - an int class id or a float target.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids, self.input_mask = input_ids, input_mask
        self.segment_ids, self.label_id = segment_ids, label_id


def convert_example_to_feature(example_row):
    """Encode one packed work item as ``InputFeatures``.

    Takes a single tuple rather than separate arguments so that it can be
    mapped over with ``Pool.imap``.

    Args:
        example_row: ``(example, label_map, max_seq_length, tokenizer, output_mode)``.
            Only ``example.text_a`` is encoded; ``example.text_b`` is ignored.

    Returns:
        InputFeatures whose ``input_ids``, ``input_mask`` and ``segment_ids``
        all have length ``max_seq_length`` (enforced by assertions).

    Raises:
        KeyError: if ``output_mode`` is neither "classification" nor
            "regression", or (classification) the label is absent from
            ``label_map``.
    """
    example, label_map, max_seq_length, tokenizer, output_mode = example_row

    pieces = tokenizer.tokenize(example.text_a)
    # Leave room for the [CLS] and [SEP] markers wrapped around the text.
    room = max_seq_length - 2
    if len(pieces) > room:
        pieces = pieces[:room]

    tokens = ["[CLS]"] + pieces + ["[SEP]"]
    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    n_real = len(input_ids)
    pad_width = max_seq_length - n_real
    # Real tokens get mask 1 (attended to); padding positions get 0.
    input_mask = [1] * n_real + [0] * pad_width
    segment_ids = [0] * len(tokens) + [0] * pad_width
    input_ids.extend([0] * pad_width)

    for encoded in (input_ids, input_mask, segment_ids):
        assert len(encoded) == max_seq_length

    if output_mode == "classification":
        label_id = label_map[example.label]
    elif output_mode == "regression":
        label_id = float(example.label)
    else:
        raise KeyError(output_mode)

    return InputFeatures(input_ids, input_mask, segment_ids, label_id)

In [0]:
# Compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained checkpoint name and experiment identifier.
BERT_MODEL = 'bert-base-cased'
TASK_NAME = ''  # fill in task name - it is embedded in the output/report paths below

# Directory layout, relative to the working directory (Drive mounted as "drive/").
BASE_PATH = "drive/My Drive/Colab Notebooks/experiments"
OUTPUT_DIR = '{}/outputs/{}/'.format(BASE_PATH, TASK_NAME)
REPORTS_DIR = '{}/reports/{}_evaluation_report/'.format(BASE_PATH, TASK_NAME)
DATA_DIR = BASE_PATH + '/data'
CACHE_DIR = 'cache/'

# Sequence length and batch sizes.
MAX_SEQ_LENGTH = 128
TRAIN_BATCH_SIZE = 24
EVAL_BATCH_SIZE = 32

# Training hyperparameters (presumably mirrored from the training notebook).
LEARNING_RATE = 2e-5
NUM_TRAIN_EPOCHS = 10
RANDOM_SEED = 42
GRADIENT_ACCUMULATION_STEPS = 1
WARUMUP_STEPS = 100  # (sic) misspelled name kept so existing references still resolve

# Task type and checkpoint file names.
OUTPUT_MODE = 'classification'
CONFIG_NAME = "config.json"
WEIGHTS_NAME = "pytorch_model.bin"

# Lower-case aliases of the settings above.
output_mode, cache_dir = OUTPUT_MODE, CACHE_DIR

In [0]:
# Evaluate each saved epoch checkpoint on every dataset in `evaldata`, writing
# per-epoch metric reports and per-example predictions under BASE_REPORTS_DIR.
BASE_REPORTS_DIR = f'{BASE_PATH}/reports/{TASK_NAME}_evaluation_report/test'
# Checkpoints scored: OUTPUT_DIR/epoch_<n>/pytorch_model.bin for n = 1..5.
EVAL_EPOCHS = range(1, 6)
# Examples per worker task: fewer, larger IPC messages (each work item carries the tokenizer).
FEATURE_CHUNKSIZE = 64

# The tokenizer must match the checkpoint the classifier was fine-tuned from.
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL, do_lower_case=False)
processor = BinaryClassificationProcessor()
os.makedirs(BASE_REPORTS_DIR, exist_ok=True)
evaldata = ['new/abusive_sieve'] #select which testing dataset(s) to evaluate on

label_list = processor.get_labels()
num_labels = len(label_list)
label_map = {label: i for i, label in enumerate(label_list)}

if OUTPUT_MODE == "classification":
    loss_fct = CrossEntropyLoss()
elif OUTPUT_MODE == "regression":
    loss_fct = MSELoss()

# Build the classifier once. load_state_dict is strict by default, so every
# checkpoint loaded below overwrites all parameters and buffers; there is no
# need to re-create the model from pretrained weights for each epoch.
print('loading model')
model = BertForSequenceClassification.from_pretrained(BERT_MODEL, cache_dir=CACHE_DIR, num_labels=num_labels)
model.to(device)
model.eval()

for dataset in evaldata:
    DATA_DIR = f'{BASE_PATH}/data/BERT/{dataset}'
    REPORTS_DIR = f'{BASE_REPORTS_DIR}/{dataset}'
    if not os.path.exists(REPORTS_DIR):
        print('created ' + REPORTS_DIR)
        os.makedirs(REPORTS_DIR)
    print(f'reports dir: {REPORTS_DIR}')
    print(f'data dir: {DATA_DIR}')

    eval_examples = processor.get_dev_examples(DATA_DIR)
    eval_examples_len = len(eval_examples)
    if eval_examples_len == 0:
        # Nothing to score; this also avoids dividing by zero batches below.
        print(f'no examples found in {DATA_DIR}, skipping')
        continue

    eval_examples_for_processing = [(example, label_map, MAX_SEQ_LENGTH, tokenizer, OUTPUT_MODE) for example in eval_examples]

    # cpu_count() - 1 is 0 on a single-core runtime, and Pool(0) raises ValueError.
    process_count = max(1, cpu_count() - 1)
    print(f'Preparing to convert {eval_examples_len} examples..')
    print(f'Spawning {process_count} processes..')
    with Pool(process_count) as p:
        eval_features = list(tqdm(
            p.imap(convert_example_to_feature, eval_examples_for_processing, chunksize=FEATURE_CHUNKSIZE),
            total=eval_examples_len))
    # Keep a copy of the converted features next to the data.
    with open(os.path.join(DATA_DIR, "eval_features.pkl"), "wb") as f:
        pickle.dump(eval_features, f)

    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
    if OUTPUT_MODE == "classification":
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
    elif OUTPUT_MODE == "regression":
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.float)
    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=EVAL_BATCH_SIZE)
    labels = all_label_ids.numpy()

    for epoch in EVAL_EPOCHS:
        model_to_load = os.path.join(OUTPUT_DIR, f'epoch_{epoch}', WEIGHTS_NAME)
        print(model_to_load)
        # map_location lets checkpoints saved on a GPU load on a CPU-only runtime.
        # NOTE: torch.load unpickles the file - only point it at trusted checkpoints.
        model.load_state_dict(torch.load(model_to_load, map_location=device))
        eval_loss = 0.0
        nb_eval_steps = 0
        logit_batches = []
        print(len(eval_dataloader))
        for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)
            with torch.no_grad():
                # transformers 2.x orders forward() as (input_ids, attention_mask,
                # token_type_ids, ...); the old positional call swapped the mask
                # and the segment ids. NOTE(review): the training notebook must
                # feed its inputs the same way, or train and eval inputs differ.
                logits = model(input_ids, attention_mask=input_mask, token_type_ids=segment_ids)[0]
                if OUTPUT_MODE == "classification":
                    tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
                elif OUTPUT_MODE == "regression":
                    tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))

            eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            logit_batches.append(logits.detach().cpu().numpy())

        eval_loss = eval_loss / nb_eval_steps
        # One concatenate instead of np.append per batch, which re-copied the
        # growing array every time.
        preds = np.concatenate(logit_batches, axis=0)
        if OUTPUT_MODE == "classification":
            preds = np.argmax(preds, axis=1)
        elif OUTPUT_MODE == "regression":
            preds = np.squeeze(preds)
        result = compute_metrics(TASK_NAME, labels, preds)
        result['eval_loss'] = eval_loss

        output_eval_file = os.path.join(REPORTS_DIR, "eval_results_epoch_{0}.txt".format(epoch))
        output_preds_file = os.path.join(REPORTS_DIR, "preds_epoch_{0}.txt".format(epoch))
        assert len(preds) == len(labels)
        # newline='' is what the csv module expects for writer files; utf-8
        # keeps arbitrary tweet text writable regardless of platform locale.
        with open(output_preds_file, 'w', encoding='utf-8', newline='') as csvfile:
            predwriter = csv.writer(csvfile, delimiter='\t')
            predwriter.writerow(['tweet', 'prediction', 'gold standard'])
            predwriter.writerows(
                [example.text_a, pred, gold]
                for example, pred, gold in zip(eval_examples, preds, labels))
        print(output_eval_file)
        with open(output_eval_file, "w", encoding="utf-8") as writer:
            print("***** Eval results *****")
            for key, value in result.items():
                # print() does not apply %-formatting itself, so format explicitly.
                print("  %s = %s" % (key, value))
                writer.write("%s = %s\n" % (key, value))