In [0]:
!pip install torchtext --upgrade

In [0]:
!pip install transformers --upgrade

In [1]:
%cd /content/drive/"My Drive"/in5550-exam

/content/drive/My Drive/in5550-exam


In [0]:
import argparse
import random
import torch
import torchtext
from torchtext import data
import NSR
import os
import json
import numpy as np

SEED = 2020

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [0]:
from transformers import BertConfig

config = BertConfig.from_pretrained('bert-base-cased')

In [0]:
config

BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 28996
}

In [0]:
emb_dim = config.hidden_size

In [0]:
from transformers import BertTokenizer

tokenbert = BertTokenizer.from_pretrained('bert-base-cased', 
                                          do_lower_case=False)

# Preprocessing the data for BERT

In [0]:
def Prep4BERT(datafile):
  fpart, suff = datafile.rsplit(".", 1)
  outfile = fpart + "_bert." + suff
  if os.path.isfile(outfile):
    raise Exception("The file already exists")

  with open(datafile) as f, open(outfile, "wt") as o:
    for line in f:
      ex = json.loads(line)
      b_form = [[tokenbert.convert_tokens_to_ids(t) for t in tokenbert.tokenize(w)] for w in ex["form"]]
      b_cue = [[tokenbert.convert_tokens_to_ids(t) for t in tokenbert.tokenize(w)] for w in ex["cue"]]
      ex["form"] = b_form
      ex["cue"] = b_cue
      json.dump(ex, o)
      o.write("\n")


In [0]:
# Prep4BERT("DataFiles/cdd.epe")
# Prep4BERT("DataFiles/cdt.epe")
# Prep4BERT("DataFiles/cde.epe")

# Set Up Data

In [0]:
ID = data.RawField(preprocessing=lambda x: int(x))
SRC = data.RawField()
NEGS = data.RawField(preprocessing=lambda x: int(x))

FORM = data.Field(batch_first = True,
                  include_lengths = True,
                  use_vocab = False,
                  tokenize=tokenbert.tokenize,
                  preprocessing=tokenbert.convert_tokens_to_ids,
                  init_token = tokenbert.cls_token_id,
                  eos_token = tokenbert.sep_token_id,
                  pad_token = tokenbert.pad_token_id,
                  unk_token = tokenbert.unk_token_id)

LEMMA = data.RawField()
XPOS = data.RawField()
LABS = data.Field(batch_first=True, 
                  unk_token=None)

CUE = data.Field(batch_first = True,
                 include_lengths = False,
                 use_vocab = False,
                 tokenize=tokenbert.tokenize,
                 preprocessing=tokenbert.convert_tokens_to_ids,
                 init_token = tokenbert.cls_token_id,
                 eos_token = tokenbert.sep_token_id,
                 pad_token = tokenbert.pad_token_id,
                 unk_token = tokenbert.unk_token_id)

SCOPE = data.RawField()

fields = {
    "id": ("id", ID),
    "source": ("source", SRC),
    "negations": ("negations", NEGS),
    "form": ("form", FORM),
    "lemma": ("lemma", LEMMA),
    "xpos": ("xpos", XPOS),
    "negation": ("label", LABS),
    "cue": ("cue", CUE),
    "scope": ("scope", SCOPE),
}

Xtrain, Xdev = NSR.StarSEM2012.splits(
    "DataFiles",
    fields=fields,
    test=None)

LABS.build_vocab(Xtrain)

In [0]:
print("Labels: ", LABS.vocab.stoi)
n_k = len(LABS.vocab.stoi)

Labels:  defaultdict(None, {'<pad>': 0, 'T': 1, 'F': 2, 'C': 3, 'A': 4})


In [0]:
LABS.vocab.freqs.most_common()

[('T', 61351), ('F', 6802), ('C', 843), ('A', 159)]

In [0]:
from transformers import BertModel

bert = BertModel.from_pretrained('bert-base-cased')

In [0]:
"""Bidirectional LSTM sequence labeller based on Fancellu et al. (2016)."""
import torch.nn as nn


class BiBERTLSTM(nn.Module):
    """Bidirectional LSTM model.

    Aims to replicate the BiLSTM-C model used by
    Fancellu et al. (2016).

    Parameters
    ----------
    input_dim : int
        Input dimensions
    embedding_dim : int
        Dimensionality of word embeddings
    n_neurons : int
        Number of neurons of the model
    output_dim : int
        Output dimensions
    n_layers : int
        Number of hidden layers
    bidir : bool
        Bidirectional if TRUE
    batch_first : bool
        If TRUE, inputs are assumed to have shape
            [batch size X # tokens]
        If FALSE:
            [# tokens X batch size]
    vecs : torch.Tensor
        A tensor with pre-trained word vectors
    train_emb : bool
        Whether to train the embedding layer further

    """

    def __init__(self, emBERTing, emb_dim, n_neurons, output_dim, n_layers):
        """Initialize the model."""
        super().__init__()

        self.emBERTing = emBERTing
        self.emBERTing.requires_grad_(requires_grad=False)

        self.lstm = nn.LSTM(emb_dim,
                            n_neurons,
                            num_layers=n_layers,
                            bidirectional=True,
                            batch_first=True)

        self.fc = nn.Linear(n_neurons * 2,
                            output_dim)

        self.dropout = nn.Dropout(p=0.5)

    def forward(self, X, C):
        """Perform a forward pass.

        Given sample(s) X, predicts the raw/unscaled class probabilities, which
        can then be converted to probabilities by sigmoid or softmax.

        Input
        -----
        X : torch.Tensor or tuple(torch.Tensor, torch.Tensor)

        Returns
        -------
        The linear predicted values, a torch.tensor
        with dimension [batchsize X `output_dim`]

        """
        X, lens = X
        w_embs, _ = self.emBERTing(X)
        c_embs, _ = self.emBERTing(C)
        embs = (w_embs + c_embs)[:, 1:-1, :]
        lens = lens - 2
        pack_emb = nn.utils.rnn.pack_padded_sequence(
            embs, lens, batch_first=True, enforce_sorted=False
        )
        pack_O, _ = self.lstm(pack_emb)
        O, _ = nn.utils.rnn.pad_packed_sequence(pack_O, batch_first=True)
        O_d = self.dropout(O)
        return self.fc(O_d)


In [0]:
M = BiBERTLSTM(bert, emb_dim, 
               n_neurons=250, n_layers=2, 
               output_dim=n_k)
M.cuda()

BiBERTLSTM(
  (emBERTing): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tr

In [0]:
[i for i, j in M.named_parameters() if j.requires_grad]

['lstm.weight_ih_l0',
 'lstm.weight_hh_l0',
 'lstm.bias_ih_l0',
 'lstm.bias_hh_l0',
 'lstm.weight_ih_l0_reverse',
 'lstm.weight_hh_l0_reverse',
 'lstm.bias_ih_l0_reverse',
 'lstm.bias_hh_l0_reverse',
 'lstm.weight_ih_l1',
 'lstm.weight_hh_l1',
 'lstm.bias_ih_l1',
 'lstm.bias_hh_l1',
 'lstm.weight_ih_l1_reverse',
 'lstm.weight_hh_l1_reverse',
 'lstm.bias_ih_l1_reverse',
 'lstm.bias_hh_l1_reverse',
 'fc.weight',
 'fc.bias']

# Run Training

In [0]:
from NSR.Runners.AbstractRunner import Runner
from NSR.Utils import override, append2dict
import tqdm


class BERTMultiClassRunner(Runner):
    """Run training and evaluation workloads for multi-label problems.

    Arguments
    ---------
    model: torch.nn.Module
        Model to optimize or test.
    criterion: torch.optim
        Loss function.
    optimizer: torch.optim
        Optimizer for weights and biases.
    labels: list
        List containing the labels in index form.
        I.e. the values in the label field vocab.

    """

    def __init__(self, model, criterion, optimizer, labels):
        """Initialize the MultiClassRunner."""
        super().__init__(model, criterion, optimizer, labels)

    @override
    def get_accuracy(self, y_hat, y):
        """Compute global accuracy."""
        correct = (y_hat == y).nonzero().size(0)
        return correct / y_hat.size(1)

    def train(self, iters):
        """Train over batched data.

        Parameters
        ----------
        iters : torchtext.data.iterator.BucketIterator
            The batched data

        Returns
        -------
        dict
            The performance and metrics of the training session.

        """
        epoch_loss = 0

        self.model.train()
        for batch in tqdm.tqdm(iters):
            self.model.zero_grad()
            y_tilde_b = self.model(batch.form, batch.cue).transpose(1, 2)
            loss = self.criterion(y_tilde_b, batch.label)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            epoch_loss += loss.item()

        results_train = {
            "loss": epoch_loss / len(iters),
        }

        return results_train

    def evaluate(self, iters):
        """Evaluate over batched data.

        Parameters
        ----------
        iters : torchtext.data.iterator.BucketIterator
            The batched data

        Returns
        -------
        dict
            The performance and metrics of the training session.

        """
        tp = 0
        n = 0

        device = iters.device

        y_hat = torch.tensor([], dtype=torch.long).to(device)
        y = torch.tensor([], dtype=torch.long).to(device)

        self.model.eval()
        with torch.no_grad():
            for batch in tqdm.tqdm(iters):
                y_tilde_b = self.model(batch.form, batch.cue)
                y_hat_b = y_tilde_b.argmax(dim=-1)
                y_b = batch.label

                tp += (y_hat_b == y_b).nonzero().size(0)
                n += y_b.size(1)

                y_hat = torch.cat((y_hat, y_hat_b.view(-1)))
                y = torch.cat((y, y_b.view(-1)))

        results_eval = {
            "accuracy": tp / n,
            **self.get_metrics(y.cpu(), y_hat.cpu())
        }

        return results_eval
    
    @override
    def run(self, epochs, train_iter, eval_iter):
        eval_f1_star = 0
        n_no_improve = 0

        for epoch in range(epochs):

            train_res = self.train(train_iter)
            eval_res = self.evaluate(eval_iter)

            print(eval_res)

            append2dict(self.performance["train"],
                        train_res)
            append2dict(self.performance["eval"],
                        eval_res)

            if eval_f1_star < self.performance["eval"]["macro_f1"][-1]:
                eval_f1_star = self.performance["eval"]["macro_f1"][-1]
                self._update_checkpoint(epoch+1, train_res, eval_res)
                print("New best F1 score: {:.5f}".format(eval_f1_star))
                n_no_improve = 0
            else:
                n_no_improve += 1

            if n_no_improve == 5:
                print("Stopping after no improvement for 5 epochs")
                break




In [0]:
optimizer = torch.optim.Adam(M.parameters(), lr=1e-4)

criterion = torch.nn.CrossEntropyLoss().cuda()

runner = BERTMultiClassRunner(M, criterion, optimizer, LABS.vocab.itos)

In [0]:
batch_size = 12
epochs = 50

In [0]:
trn_iter = data.BucketIterator(
    Xtrain,
    device="cuda",
    shuffle=True,
    batch_size=batch_size
)

val_iter = data.BucketIterator(
    Xdev,
    device="cuda",
    shuffle=True,
    batch_size=1
)

In [0]:
runner.run(epochs, trn_iter, val_iter)

100%|██████████| 315/315 [00:18<00:00, 17.12it/s]
100%|██████████| 816/816 [00:28<00:00, 28.95it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  average, "true nor predicted", 'F-score is', len(true_sum)
  1%|          | 2/315 [00:00<00:18, 16.90it/s]

{'accuracy': 0.9176709641255605, 'precision': 0.32172812402318096, 'recall': 0.2599077849662369, 'macro_f1': 0.2762613690596233, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12690, 68, 0, 0], [0, 928, 407, 0, 0], [0, 30, 116, 0, 0], [0, 26, 7, 0, 0]]}
New best F1 score: 0.27626


100%|██████████| 315/315 [00:18<00:00, 17.31it/s]
100%|██████████| 816/816 [00:27<00:00, 29.90it/s]
  1%|          | 2/315 [00:00<00:19, 16.44it/s]

{'accuracy': 0.9356782511210763, 'precision': 0.5495961499237281, 'recall': 0.4350039944716567, 'macro_f1': 0.477535161626618, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12654, 104, 0, 0], [0, 743, 592, 0, 0], [0, 24, 14, 108, 0], [0, 16, 16, 1, 0]]}
New best F1 score: 0.47754


100%|██████████| 315/315 [00:18<00:00, 17.29it/s]
100%|██████████| 816/816 [00:27<00:00, 29.87it/s]
  1%|          | 2/315 [00:00<00:18, 16.92it/s]

{'accuracy': 0.9286715246636771, 'precision': 0.551094054246034, 'recall': 0.42528660733266266, 'macro_f1': 0.4553439571530804, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12733, 25, 0, 0], [0, 935, 400, 0, 0], [0, 22, 3, 121, 0], [0, 17, 2, 14, 0]]}


100%|██████████| 315/315 [00:18<00:00, 17.08it/s]
100%|██████████| 816/816 [00:27<00:00, 29.14it/s]
  1%|          | 2/315 [00:00<00:18, 17.01it/s]

{'accuracy': 0.9414938340807175, 'precision': 0.5410957678687112, 'recall': 0.4591665271020492, 'macro_f1': 0.484848820365115, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12703, 55, 0, 0], [0, 724, 611, 0, 0], [0, 21, 2, 123, 0], [0, 10, 1, 22, 0]]}
New best F1 score: 0.48485


100%|██████████| 315/315 [00:18<00:00, 16.87it/s]
100%|██████████| 816/816 [00:28<00:00, 29.03it/s]
  1%|          | 2/315 [00:00<00:18, 16.62it/s]

{'accuracy': 0.9423346412556054, 'precision': 0.7535864773041598, 'recall': 0.5070433998613046, 'macro_f1': 0.5676176321232804, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12712, 46, 0, 0], [0, 729, 606, 0, 0], [0, 21, 2, 123, 0], [0, 11, 0, 14, 8]]}
New best F1 score: 0.56762


100%|██████████| 315/315 [00:18<00:00, 17.29it/s]
100%|██████████| 816/816 [00:28<00:00, 28.74it/s]
  1%|          | 2/315 [00:00<00:19, 16.11it/s]

{'accuracy': 0.9516535874439462, 'precision': 0.7595425502891056, 'recall': 0.5974996318302538, 'macro_f1': 0.6589420651954343, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12680, 78, 0, 0], [0, 576, 759, 0, 0], [0, 20, 2, 124, 0], [0, 5, 1, 8, 19]]}
New best F1 score: 0.65894


100%|██████████| 315/315 [00:17<00:00, 17.54it/s]
100%|██████████| 816/816 [00:27<00:00, 29.30it/s]
  1%|          | 2/315 [00:00<00:22, 13.91it/s]

{'accuracy': 0.9522841928251121, 'precision': 0.7504362092576032, 'recall': 0.5970745000515949, 'macro_f1': 0.6576329858327342, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12710, 48, 0, 0], [0, 595, 740, 0, 0], [0, 21, 1, 121, 3], [0, 12, 0, 1, 20]]}


100%|██████████| 315/315 [00:18<00:00, 17.27it/s]
100%|██████████| 816/816 [00:28<00:00, 28.46it/s]
  1%|          | 2/315 [00:00<00:19, 15.85it/s]

{'accuracy': 0.9546664798206278, 'precision': 0.7605770200848354, 'recall': 0.6427516921391406, 'macro_f1': 0.689618275444052, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12725, 33, 0, 0], [0, 584, 751, 0, 0], [0, 20, 1, 122, 3], [0, 5, 0, 1, 27]]}
New best F1 score: 0.68962


100%|██████████| 315/315 [00:18<00:00, 17.07it/s]
100%|██████████| 816/816 [00:28<00:00, 28.88it/s]
  1%|          | 2/315 [00:00<00:21, 14.65it/s]

{'accuracy': 0.9520739910313901, 'precision': 0.7799549251071457, 'recall': 0.6238327205516028, 'macro_f1': 0.6833966393513111, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12738, 20, 0, 0], [0, 636, 699, 0, 0], [0, 19, 0, 127, 0], [0, 6, 0, 3, 24]]}


100%|██████████| 315/315 [00:18<00:00, 17.19it/s]
100%|██████████| 816/816 [00:27<00:00, 29.47it/s]
  1%|          | 2/315 [00:00<00:22, 13.79it/s]

{'accuracy': 0.960622197309417, 'precision': 0.7479886446587413, 'recall': 0.6856011918845748, 'macro_f1': 0.713824875060184, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12625, 133, 0, 0], [0, 405, 930, 0, 0], [0, 16, 1, 126, 3], [0, 3, 0, 1, 29]]}
New best F1 score: 0.71382


100%|██████████| 315/315 [00:18<00:00, 17.20it/s]
100%|██████████| 816/816 [00:27<00:00, 29.96it/s]
  1%|          | 2/315 [00:00<00:20, 15.57it/s]

{'accuracy': 0.9630044843049327, 'precision': 0.7738099290945103, 'recall': 0.6802490446960876, 'macro_f1': 0.7216453649111376, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12677, 81, 0, 0], [0, 424, 911, 0, 0], [0, 18, 0, 128, 0], [0, 3, 0, 2, 28]]}
New best F1 score: 0.72165


100%|██████████| 315/315 [00:18<00:00, 17.04it/s]
100%|██████████| 816/816 [00:27<00:00, 29.49it/s]
  1%|          | 2/315 [00:00<00:18, 16.74it/s]

{'accuracy': 0.9588705156950673, 'precision': 0.7518221721678362, 'recall': 0.6702816812432844, 'macro_f1': 0.704503660189501, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12684, 74, 0, 0], [0, 486, 849, 0, 0], [0, 18, 1, 123, 4], [0, 4, 0, 0, 29]]}


100%|██████████| 315/315 [00:18<00:00, 17.15it/s]
100%|██████████| 816/816 [00:28<00:00, 28.97it/s]
  1%|          | 2/315 [00:00<00:25, 12.13it/s]

{'accuracy': 0.9611126681614349, 'precision': 0.7850499082122437, 'recall': 0.6816552907254715, 'macro_f1': 0.7240688119054303, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12735, 23, 0, 0], [0, 516, 819, 0, 0], [0, 12, 0, 134, 0], [0, 3, 0, 1, 29]]}
New best F1 score: 0.72407


100%|██████████| 315/315 [00:18<00:00, 16.97it/s]
100%|██████████| 816/816 [00:28<00:00, 28.97it/s]
  1%|          | 2/315 [00:00<00:21, 14.57it/s]

{'accuracy': 0.960692264573991, 'precision': 0.7818519302688398, 'recall': 0.6766682073594336, 'macro_f1': 0.72078821468573, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12720, 38, 0, 0], [0, 502, 833, 0, 0], [0, 17, 0, 129, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 17.13it/s]
100%|██████████| 816/816 [00:28<00:00, 29.01it/s]
  1%|          | 2/315 [00:00<00:18, 17.32it/s]

{'accuracy': 0.9605521300448431, 'precision': 0.7743702346862513, 'recall': 0.6732449762354968, 'macro_f1': 0.7157576504222168, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12716, 42, 0, 0], [0, 497, 838, 0, 0], [0, 19, 0, 126, 1], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 16.92it/s]
100%|██████████| 816/816 [00:28<00:00, 29.00it/s]
  1%|          | 2/315 [00:00<00:17, 18.10it/s]

{'accuracy': 0.9633548206278026, 'precision': 0.7711124488410098, 'recall': 0.6907988542703295, 'macro_f1': 0.7272341239305046, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12648, 110, 0, 0], [0, 391, 944, 0, 0], [0, 17, 1, 128, 0], [0, 3, 0, 1, 29]]}
New best F1 score: 0.72723


100%|██████████| 315/315 [00:18<00:00, 17.30it/s]
100%|██████████| 816/816 [00:28<00:00, 28.81it/s]
  1%|          | 2/315 [00:00<00:21, 14.33it/s]

{'accuracy': 0.9640554932735426, 'precision': 0.7581170872353857, 'recall': 0.6926825395757535, 'macro_f1': 0.7214758335546462, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12671, 87, 0, 0], [0, 403, 932, 0, 0], [0, 16, 1, 126, 3], [0, 3, 0, 0, 30]]}


100%|██████████| 315/315 [00:18<00:00, 17.24it/s]
100%|██████████| 816/816 [00:27<00:00, 29.44it/s]
  1%|          | 2/315 [00:00<00:20, 14.93it/s]

{'accuracy': 0.9675588565022422, 'precision': 0.781062152282751, 'recall': 0.7038058070695156, 'macro_f1': 0.737673338998637, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12709, 48, 1, 0], [0, 402, 933, 0, 0], [0, 8, 0, 138, 0], [0, 3, 0, 1, 29]]}
New best F1 score: 0.73767


100%|██████████| 315/315 [00:18<00:00, 17.00it/s]
100%|██████████| 816/816 [00:28<00:00, 29.06it/s]
  1%|          | 2/315 [00:00<00:25, 12.21it/s]

{'accuracy': 0.9660874439461884, 'precision': 0.7705083840529345, 'recall': 0.7023137461749295, 'macro_f1': 0.7338941199755251, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12633, 125, 0, 0], [0, 340, 995, 0, 0], [0, 14, 1, 131, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 17.08it/s]
100%|██████████| 816/816 [00:27<00:00, 29.29it/s]
  1%|          | 2/315 [00:00<00:19, 16.13it/s]

{'accuracy': 0.9641255605381166, 'precision': 0.7697324913659148, 'recall': 0.7065911961844541, 'macro_f1': 0.7358803848191334, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12623, 135, 0, 0], [0, 360, 975, 0, 0], [0, 13, 1, 132, 0], [0, 3, 0, 0, 30]]}


100%|██████████| 315/315 [00:18<00:00, 17.02it/s]
100%|██████████| 816/816 [00:28<00:00, 28.65it/s]
  1%|          | 2/315 [00:00<00:18, 16.99it/s]

{'accuracy': 0.9682595291479821, 'precision': 0.7759131918756231, 'recall': 0.7115126748688028, 'macro_f1': 0.740022947687736, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12690, 65, 3, 0], [0, 376, 959, 0, 0], [0, 5, 0, 141, 0], [0, 3, 0, 1, 29]]}
New best F1 score: 0.74002


100%|██████████| 315/315 [00:18<00:00, 16.94it/s]
100%|██████████| 816/816 [00:27<00:00, 29.33it/s]
  1%|          | 2/315 [00:00<00:20, 15.37it/s]

{'accuracy': 0.9665779147982063, 'precision': 0.7731904829095115, 'recall': 0.6995937954142756, 'macro_f1': 0.7333796101414559, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12652, 106, 0, 0], [0, 351, 984, 0, 0], [0, 15, 1, 130, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 16.94it/s]
100%|██████████| 816/816 [00:28<00:00, 28.90it/s]
  1%|          | 2/315 [00:00<00:23, 13.44it/s]

{'accuracy': 0.9690302690582959, 'precision': 0.7762571584525673, 'recall': 0.7103052818794755, 'macro_f1': 0.7400767730939627, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12684, 72, 2, 0], [0, 356, 979, 0, 0], [0, 7, 1, 138, 0], [0, 3, 0, 1, 29]]}
New best F1 score: 0.74008


100%|██████████| 315/315 [00:18<00:00, 17.16it/s]
100%|██████████| 816/816 [00:27<00:00, 29.48it/s]
  1%|          | 2/315 [00:00<00:16, 18.42it/s]

{'accuracy': 0.9692404708520179, 'precision': 0.7637807983106212, 'recall': 0.7245707585459484, 'macro_f1': 0.7431973504498103, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12581, 175, 2, 0], [0, 250, 1085, 0, 0], [0, 7, 1, 138, 0], [0, 3, 0, 1, 29]]}
New best F1 score: 0.74320


100%|██████████| 315/315 [00:18<00:00, 17.05it/s]
100%|██████████| 816/816 [00:28<00:00, 29.08it/s]
  1%|          | 2/315 [00:00<00:18, 17.03it/s]

{'accuracy': 0.967979260089686, 'precision': 0.7693959405834405, 'recall': 0.7217487893493757, 'macro_f1': 0.7441495619857674, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12626, 130, 2, 0], [0, 314, 1021, 0, 0], [0, 7, 1, 138, 0], [0, 3, 0, 0, 30]]}
New best F1 score: 0.74415


100%|██████████| 315/315 [00:18<00:00, 17.19it/s]
100%|██████████| 816/816 [00:27<00:00, 29.49it/s]
  1%|          | 2/315 [00:00<00:21, 14.42it/s]

{'accuracy': 0.9698010089686099, 'precision': 0.7676639551606865, 'recall': 0.7212086263624554, 'macro_f1': 0.7431221133255509, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12615, 141, 2, 0], [0, 276, 1059, 0, 0], [0, 7, 1, 138, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 17.14it/s]
100%|██████████| 816/816 [00:27<00:00, 30.12it/s]
  1%|          | 2/315 [00:00<00:19, 15.80it/s]

{'accuracy': 0.969170403587444, 'precision': 0.7720549641387932, 'recall': 0.7147503088987912, 'macro_f1': 0.7414114539849463, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12644, 113, 1, 0], [0, 313, 1022, 0, 0], [0, 8, 1, 137, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 16.90it/s]
100%|██████████| 816/816 [00:27<00:00, 29.21it/s]
  1%|          | 2/315 [00:00<00:18, 16.77it/s]

{'accuracy': 0.9697309417040358, 'precision': 0.7697907765409691, 'recall': 0.7199857232622195, 'macro_f1': 0.7434210024916044, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12623, 134, 1, 0], [0, 285, 1050, 0, 0], [0, 7, 1, 138, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 17.08it/s]
100%|██████████| 816/816 [00:27<00:00, 29.31it/s]
  1%|          | 2/315 [00:00<00:18, 16.68it/s]

{'accuracy': 0.969170403587444, 'precision': 0.7748115352998114, 'recall': 0.7107133964080223, 'macro_f1': 0.7401847762207595, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12665, 92, 1, 0], [0, 333, 1002, 0, 0], [0, 9, 1, 136, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 17.04it/s]
100%|██████████| 816/816 [00:27<00:00, 29.21it/s]
  1%|          | 2/315 [00:00<00:17, 17.56it/s]

{'accuracy': 0.9694506726457399, 'precision': 0.7720652222841322, 'recall': 0.7239687263124154, 'macro_f1': 0.7464782659600194, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12642, 114, 2, 0], [0, 310, 1025, 0, 0], [0, 6, 1, 139, 0], [0, 3, 0, 0, 30]]}
New best F1 score: 0.74648


100%|██████████| 315/315 [00:18<00:00, 16.92it/s]
100%|██████████| 816/816 [00:27<00:00, 29.47it/s]
  1%|          | 2/315 [00:00<00:20, 15.29it/s]

{'accuracy': 0.9704316143497758, 'precision': 0.7814158454928173, 'recall': 0.7104846743499706, 'macro_f1': 0.7422563800983469, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12705, 52, 1, 0], [0, 357, 978, 0, 0], [0, 8, 0, 138, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 17.10it/s]
100%|██████████| 816/816 [00:28<00:00, 28.71it/s]
  1%|          | 2/315 [00:00<00:24, 12.79it/s]

{'accuracy': 0.9665078475336323, 'precision': 0.7693627601977784, 'recall': 0.7143932022940264, 'macro_f1': 0.7401905416644461, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12621, 136, 1, 0], [0, 326, 1009, 0, 0], [0, 11, 1, 134, 0], [0, 3, 0, 0, 30]]}


100%|██████████| 315/315 [00:18<00:00, 17.03it/s]
100%|██████████| 816/816 [00:27<00:00, 29.36it/s]
  1%|          | 2/315 [00:00<00:18, 17.38it/s]

{'accuracy': 0.9726036995515696, 'precision': 0.775550726257905, 'recall': 0.7226917961403735, 'macro_f1': 0.7466574057358281, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12685, 70, 3, 0], [0, 309, 1025, 1, 0], [0, 4, 0, 142, 0], [0, 3, 0, 1, 29]]}
New best F1 score: 0.74666


100%|██████████| 315/315 [00:18<00:00, 17.28it/s]
100%|██████████| 816/816 [00:27<00:00, 29.89it/s]
  1%|          | 2/315 [00:00<00:20, 15.49it/s]

{'accuracy': 0.9664377802690582, 'precision': 0.7750377476337607, 'recall': 0.6986363120849585, 'macro_f1': 0.7333927293686673, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12666, 92, 0, 0], [0, 368, 967, 0, 0], [0, 14, 1, 131, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 16.98it/s]
100%|██████████| 816/816 [00:27<00:00, 29.54it/s]
  1%|          | 2/315 [00:00<00:22, 13.95it/s]

{'accuracy': 0.9705717488789237, 'precision': 0.7789633980642099, 'recall': 0.7103562436954407, 'macro_f1': 0.7415721753510425, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12690, 67, 1, 0], [0, 338, 997, 0, 0], [0, 9, 1, 136, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 17.28it/s]
100%|██████████| 816/816 [00:28<00:00, 28.70it/s]
  1%|          | 2/315 [00:00<00:18, 17.38it/s]

{'accuracy': 0.9684697309417041, 'precision': 0.7763443763861807, 'recall': 0.7062514469425827, 'macro_f1': 0.7381114866733987, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12678, 79, 1, 0], [0, 355, 980, 0, 0], [0, 10, 1, 135, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 16.88it/s]
100%|██████████| 816/816 [00:27<00:00, 29.52it/s]
  1%|          | 2/315 [00:00<00:21, 14.71it/s]

{'accuracy': 0.968960201793722, 'precision': 0.7747248141735519, 'recall': 0.7175289381076841, 'macro_f1': 0.7438772476097135, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12664, 92, 2, 0], [0, 337, 998, 0, 0], [0, 8, 1, 137, 0], [0, 3, 0, 0, 30]]}


100%|██████████| 315/315 [00:18<00:00, 16.79it/s]
100%|██████████| 816/816 [00:27<00:00, 29.33it/s]
  1%|          | 2/315 [00:00<00:17, 18.24it/s]

{'accuracy': 0.9723934977578476, 'precision': 0.7728279139733374, 'recall': 0.7298456286388749, 'macro_f1': 0.7502824875091049, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12636, 120, 2, 0], [0, 261, 1074, 0, 0], [0, 8, 0, 138, 0], [0, 3, 0, 0, 30]]}
New best F1 score: 0.75028


100%|██████████| 315/315 [00:18<00:00, 16.94it/s]
100%|██████████| 816/816 [00:28<00:00, 28.74it/s]
  1%|          | 2/315 [00:00<00:22, 13.65it/s]

{'accuracy': 0.9720431614349776, 'precision': 0.7759141599009405, 'recall': 0.718209905080786, 'macro_f1': 0.7450612107455982, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12664, 93, 1, 0], [0, 292, 1043, 0, 0], [0, 8, 1, 137, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 16.86it/s]
100%|██████████| 816/816 [00:28<00:00, 28.88it/s]
  1%|          | 2/315 [00:00<00:23, 13.44it/s]

{'accuracy': 0.9698710762331838, 'precision': 0.7722797859292755, 'recall': 0.7150412095776926, 'macro_f1': 0.7415851634179929, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12653, 103, 2, 0], [0, 312, 1023, 0, 0], [0, 8, 1, 137, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 17.14it/s]
100%|██████████| 816/816 [00:28<00:00, 29.03it/s]
  1%|          | 2/315 [00:00<00:19, 15.80it/s]

{'accuracy': 0.9702914798206278, 'precision': 0.7740743070212096, 'recall': 0.7155505007125731, 'macro_f1': 0.7425066984835602, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12665, 91, 2, 0], [0, 319, 1016, 0, 0], [0, 7, 1, 138, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 16.91it/s]
100%|██████████| 816/816 [00:28<00:00, 28.95it/s]
  1%|          | 2/315 [00:00<00:22, 13.90it/s]

{'accuracy': 0.9695908071748879, 'precision': 0.7699511134751877, 'recall': 0.7169777246432838, 'macro_f1': 0.7418955184941259, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12625, 132, 1, 0], [0, 287, 1048, 0, 0], [0, 9, 1, 136, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 17.07it/s]
100%|██████████| 816/816 [00:27<00:00, 29.23it/s]
  1%|          | 2/315 [00:00<00:21, 14.82it/s]

{'accuracy': 0.9735846412556054, 'precision': 0.7771097421141425, 'recall': 0.7295884065248165, 'macro_f1': 0.7518953308467236, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12666, 90, 2, 0], [0, 275, 1060, 0, 0], [0, 7, 0, 139, 0], [0, 3, 0, 0, 30]]}
New best F1 score: 0.75190


100%|██████████| 315/315 [00:18<00:00, 16.91it/s]
100%|██████████| 816/816 [00:28<00:00, 28.36it/s]
  1%|          | 2/315 [00:00<00:20, 15.39it/s]

{'accuracy': 0.9709921524663677, 'precision': 0.7726416096989774, 'recall': 0.7164736120172395, 'macro_f1': 0.7428210283498293, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12642, 115, 1, 0], [0, 283, 1052, 0, 0], [0, 10, 1, 135, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 17.28it/s]
100%|██████████| 816/816 [00:27<00:00, 29.37it/s]
  1%|          | 2/315 [00:00<00:19, 16.07it/s]

{'accuracy': 0.9662976457399103, 'precision': 0.7797768670125051, 'recall': 0.6964162803179303, 'macro_f1': 0.7337282590111445, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12688, 70, 0, 0], [0, 389, 946, 0, 0], [0, 18, 1, 127, 0], [0, 3, 0, 0, 30]]}


100%|██████████| 315/315 [00:18<00:00, 17.15it/s]
100%|██████████| 816/816 [00:28<00:00, 28.62it/s]
  1%|          | 2/315 [00:00<00:19, 15.88it/s]

{'accuracy': 0.968890134529148, 'precision': 0.7802004136126197, 'recall': 0.7014781280716906, 'macro_f1': 0.7370250379226798, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12693, 65, 0, 0], [0, 361, 974, 0, 0], [0, 14, 0, 132, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 17.02it/s]
100%|██████████| 816/816 [00:27<00:00, 29.20it/s]
  1%|          | 2/315 [00:00<00:21, 14.86it/s]

{'accuracy': 0.9705016816143498, 'precision': 0.7781918277885638, 'recall': 0.7085711481767409, 'macro_f1': 0.7404111963765323, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12684, 73, 1, 0], [0, 331, 1004, 0, 0], [0, 12, 0, 134, 0], [0, 3, 0, 1, 29]]}


100%|██████████| 315/315 [00:18<00:00, 17.01it/s]
100%|██████████| 816/816 [00:28<00:00, 28.99it/s]

{'accuracy': 0.968189461883408, 'precision': 0.779147127405315, 'recall': 0.7017365961935595, 'macro_f1': 0.7366661077331859, 'confusion_matrix': [[0, 0, 0, 0, 0], [0, 12689, 69, 0, 0], [0, 368, 967, 0, 0], [0, 12, 1, 133, 0], [0, 3, 0, 1, 29]]}
Stopping after no improvement for 5 epochs





In [0]:
dev_res = runner.evaluate(val_iter)

100%|██████████| 816/816 [00:28<00:00, 28.52it/s]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  average, "true nor predicted", 'F-score is', len(true_sum)


In [0]:
print("*"*21)
print("{} {}".format("BiLSTM-C", "BERT_base"))
print("Results:")
for i in ["accuracy", "precision", "recall", "macro_f1"]:
  print("\t{}: {:2f}".format(i, dev_res[i]))
print("Confusion Matrix: \n")
for i in dev_res["confusion_matrix"]:
  print(i)

*********************
BiLSTM-C BERT_base
Results:
	accuracy: 0.968189
	precision: 0.779147
	recall: 0.701737
	macro_f1: 0.736666
Confusion Matrix: 

[0, 0, 0, 0, 0]
[0, 12689, 69, 0, 0]
[0, 368, 967, 0, 0]
[0, 12, 1, 133, 0]
[0, 3, 0, 1, 29]


In [0]:
runner.save(dirpath="Saves/BiLSTM-CE/BERT_base", checkpoint=True)

  "type " + obj.__name__ + ". It won't be checked "


In [0]:
import pickle

In [0]:
with open(os.path.join("Saves/BiLSTM-CE/BERT_base", "labels"), "wb") as o:
  pickle.dump(LABS.vocab, o)

# Evaluation

In [0]:
import glob
import argparse
import random
import torch
import torchtext
from torchtext import data
import NSR
import pickle
import os
import json
from collections import OrderedDict, defaultdict
from pathlib import Path
from transformers import BertTokenizer, BertModel


class LabelMapper():
  
  def __init__(self, vocab_stoi, vocab_itos):
    self.label2ind = vocab_stoi
    self.labelset = vocab_itos
    
  def index2label(self, x):
    return self.labelset[x]
  
  def label2index(self, x):
    return self.label2ind[x]
        
  def __call__(self, x):
    if type(x) == int:
      self.index2label(x)
    elif type(x) == str:
      self.label2index(x)
    else:
      raise ValueError("This is not a valid key")


class BERTEvaluator():
  """Use a saved model to evaluate on held out dataset.

  Loads a saved model and vocabularies to be used for
  predicting new data.
  """

  def __init__(self, saved_model_dir=None):
    self.work_dir = saved_model_dir
    
    with open(os.path.join(self.work_dir, "labels"), "rb") as f:
      self.lab_voc = pickle.load(f)
    
    modelfile = glob.glob(os.path.join(self.work_dir, "model_epoch*.pt"))[0]
    state_dict = torch.load(modelfile)["model_state_dict"]
    self.device = state_dict["fc.bias"].device

    tokenbert = BertTokenizer.from_pretrained('bert-base-cased', 
                                              do_lower_case=False)

    bert = BertModel.from_pretrained('bert-base-cased')

    model = BiBERTLSTM(bert, 768, 
               n_neurons=250, n_layers=2, 
               output_dim=5)
    
    model.to(self.device)
    
    model.load_state_dict(state_dict)
    self.model = model
    self.model.eval()

    self.FORM = data.Field(batch_first = True,
                           include_lengths = True,
                           use_vocab = False,
                           tokenize=tokenbert.tokenize,
                           preprocessing=tokenbert.convert_tokens_to_ids,
                           init_token = tokenbert.cls_token_id,
                           eos_token = tokenbert.sep_token_id,
                           pad_token = tokenbert.pad_token_id,
                           unk_token = tokenbert.unk_token_id)

    self.CUE = data.Field(batch_first = True,
                          include_lengths = False,
                          use_vocab = False,
                          tokenize=tokenbert.tokenize,
                          preprocessing=tokenbert.convert_tokens_to_ids,
                          init_token = tokenbert.cls_token_id,
                          eos_token = tokenbert.sep_token_id,
                          pad_token = tokenbert.pad_token_id,
                          unk_token = tokenbert.unk_token_id)

    
    self.LABS = data.Field(batch_first=True, unk_token=None)
    self.LABS.vocab = self.lab_voc

    self.mapp = LabelMapper(self.LABS.vocab.stoi,
                            self.LABS.vocab.itos)
    
    self.correction_log = []
    self.pred_seqs = set()


  def map_node(self, id, f, l, x):
    return {"id": str(id), 
            "form": f, 
            "properties": {"lemma": l, "xpos": x}}
  
  def map_back_gold(self, sentence):
    out = sentence.copy()
  
    form = out.pop("form")
    lemma = out.pop("lemma")
    xpos = out.pop("xpos")
    cue = out.pop("cue")
    scope = out.pop("scope")
    negation = out.pop("negation")

    nodes = []

    if out["negations"] == 0:
      for i in range(len(form)):
        nodes.append(self.map_node(i, form[i], lemma[i], xpos[i]))
  
    else:
      for i in range(len(form)):
        node = self.map_node(i, form[i], lemma[i], xpos[i])
        if negation[i] == "T":
          node["negation"] = [{"id": 0}]
        elif negation[i] == "F":
          node["negation"] = [{"id": 0, "scope": scope[i]}]
        elif negation[i] == "C":
          node["negation"] = [{"id": 0, "cue": cue[i]}]
        elif negation[i] == "A":
          node["negation"] = [{"id": 0, "cue": cue[i], "scope": scope[i]}]
        else:
          node["negation"] = [{"id": 0}]

        nodes.append(node)

    out["nodes"] = nodes
    return out
  
  def map_back_pred(self, sentence):
    out = sentence.copy()
  
    form = out.pop("form")
    lemma = out.pop("lemma")
    xpos = out.pop("xpos")
    cue = out.pop("cue")
    scope = out.pop("scope")
    negation = out.pop("negation")

    nodes = []

    # To handle logic in the 'convert.py' file
    # need to set negations to 0 if there are no predicted cues

    if set(negation).issubset(set("T")):
      out["negations"] = 0
    else:
      out["negations"] = 1

    if out["negations"] == 0:
      for i in range(len(form)):
        nodes.append(self.map_node(i, form[i], lemma[i], xpos[i]))
  
    else:
      for i in range(len(form)):
        node = self.map_node(i, form[i], lemma[i], xpos[i])
        if negation[i] == "T":
          node["negation"] = [{"id": 0}]
        elif negation[i] == "F":
          node["negation"] = [{"id": 0, "scope": scope[i]}]
        elif negation[i] == "C":
          node["negation"] = [{"id": 0, "cue": cue[i]}]
        elif negation[i] == "A":
          node["negation"] = [{"id": 0, "cue": cue[i], "scope": scope[i]}]
        else:
          node["negation"] = [{"id": 0}]

        nodes.append(node)

    out["nodes"] = nodes
    return out  

  
  def _correct_preds(self, pred_labs, id, src):
    """Apply a simple constraint on the predictions."""

    n_cues = pred_labs.count("A") + pred_labs.count("C")

    # No predicted cues, but has scope
    if n_cues == 0 and not set(pred_labs).issubset(set("T")):
      # Correct everything to true
      out_labs = ["T"]*len(pred_labs)
      self.correction_log.append({"id": id, "source": src})
      return out_labs
    
    return pred_labs

  
  def pred(self, batch, apply_dk=False):
    with torch.no_grad():
      form_b = self.FORM.process(
          [self.FORM.preprocess(batch["form"])],
          device=self.device
      )

      cue_b = self.CUE.process(
          [self.CUE.preprocess(batch["cue"])],
          device=self.device
      )
      
      lab_b = self.LABS.process([batch["negation"]], device=self.device)

      y_tilde = self.model(form_b, cue_b)
      y_hat = y_tilde.argmax(dim=-1).cpu().flatten().tolist()
      y_hat = [self.mapp.index2label(l) for l in y_hat]
      if apply_dk:
        y_hat = self._correct_preds(y_hat, batch["id"], batch["source"])
      
      self.pred_seqs.add(tuple(set(y_hat)))

      y_b = batch["negation"]
      return y_b, y_hat
  
  def output_preds(self, eval_file, results_file=None, apply_dk=False):
    if not results_file:
      results_file = "eval_pred.epe"

    out_file = os.path.join(self.work_dir, results_file)

    if os.path.isfile(out_file):
      raise Exception("There is already a file there")
    
    with open(eval_file) as f, open(out_file, "wt") as r:
      for line in f:
        sent = json.loads(line)
        _, y_pred = self.pred(sent, apply_dk)
        sent["negation"] = y_pred
        sent = self.map_back_pred(sent)
        json.dump(sent, r)
        r.write("\n")
    
    n_corr = len(self.correction_log)
    if apply_dk and n_corr > 0:
      with open(os.path.join(self.work_dir, "correction_log.txt"), "w") as log:
        for corr in self.correction_log:
          json.dump(corr, log)
          log.write("\n")

      print("{} corrections applied, see correction_log.txt".format(n_corr))
    
    print("File saved to {}".format(out_file))


  def output_preds_tidy(self, eval_file, results_file=None, apply_dk=False):
    if not results_file:
      results_file = "eval_pred_tidy.epe"

    out_file = os.path.join(self.work_dir, results_file)

    if os.path.isfile(out_file):
      raise Exception("There is already a file there")
    
    with open(eval_file) as f, open(out_file, "wt") as r:
      for line in f:
        sent = json.loads(line)
        _, y_pred = self.pred(sent, apply_dk)
        sent["negation"] = y_pred
        json.dump(sent, r)
        r.write("\n")
    
    print("File saved to {}".format(out_file))


  def _load(self, argname, module, *args, **kwargs):
    item = self.setup[argname]["type"]
    item_args = dict(self.setup[argname]["args"])
    if sum([kwarg in item_args for kwarg in kwargs]) != 0:
      raise ValueError("Args set in config file cannot be overwritten")
    item_args.update(kwargs)
    
    return getattr(module, item)(*args, **item_args)





In [0]:
ev = BERTEvaluator("Saves/BiLSTM-CE/BERT_base")

In [7]:
ev.output_preds(eval_file="DataFiles/cde.epe", 
                results_file="eval_pred_final.epe",
                apply_dk=False)

File saved to Saves/BiLSTM-CE/BERT_base/eval_pred_final.epe


In [8]:
ev.output_preds(eval_file="DataFiles/cde.epe", 
                results_file="eval_pred_corr_final.epe",
                apply_dk=True)

4 corrections applied, see correction_log.txt
File saved to Saves/BiLSTM-CE/BERT_base/eval_pred_corr_final.epe


In [9]:
ev.output_preds_tidy(eval_file="DataFiles/cde.epe", 
                results_file="eval_pred_tidy.epe",
                apply_dk=False)

File saved to Saves/BiLSTM-CE/BERT_base/eval_pred_tidy.epe


In [10]:
ev.output_preds_tidy(eval_file="DataFiles/cde.epe", 
                results_file="eval_pred_corr_tidy.epe",
                apply_dk=True)

File saved to Saves/BiLSTM-CE/BERT_base/eval_pred_corr_tidy.epe


# Evaluation With `score.py`

In [0]:
import sys
sys.path.append("negation")

In [0]:
from negation import score, convert

In [0]:
s_gold = convert.read_negations("EvalFiles/eval_gold.epe")

In [0]:
s_pred = convert.read_negations("Saves/BiLSTM-CE/BERT_base/eval_pred_corr_final.epe")

In [0]:
file = "BERT_base_corr_score_final.txt"

In [16]:
%cd negation

/content/drive/My Drive/in5550-exam/negation


In [0]:
score.starsem_score(s_gold, s_pred, file)

In [47]:
%cd ..

/content/drive/My Drive/in5550-exam


In [0]:
3