In [1]:
import os
import math

In [2]:
import pandas as pd
import torch
import transformers
from transformers import AdamW, BertConfig, BertTokenizer

In [3]:
import model
import evaluator

In [4]:
# modify this part for different dataset
import conll04_constants as constants
import conll04_input_generator as input_generator

In [5]:
from tqdm import tqdm

In [6]:
# Run on the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
# # To run reference model
# input_generator.parser.entity_encode = {'O': 0, 'B-Loc': 1, 'I-Loc': 1, 'B-Peop': 3, 'I-Peop': 3, 
#                                          'B-Org': 2, 'I-Org': 2, 'B-Other': 4, 'I-Other': 4}
# input_generator.parser.relation_encode = {'N': 0, 'Kill': 2, 'Located_In': 5, 'OrgBased_In': 3,
#                                            'Live_In': 4, 'Work_For': 1}

In [8]:
# id -> label lookup, the inverse of the parser's label -> id encoding. When several
# labels share one id, the label that appears last in the encoding is kept.
entity_label_map = dict(
    (label_id, label) for label, label_id in input_generator.parser.entity_encode.items()
)
# Entity ids to score: every id except 0 (presumably the 'O' / no-entity label —
# confirm against the parser). list.remove raises ValueError if id 0 is absent.
entity_classes = [*entity_label_map]
entity_classes.remove(0)

In [9]:
# id -> label lookup, the inverse of the parser's relation label -> id encoding. When
# several labels share one id, the label that appears last in the encoding is kept.
relation_label_map = dict(
    (label_id, label) for label, label_id in input_generator.parser.relation_encode.items()
)
# Relation ids to score: every id except 0 (presumably the 'N' / no-relation label —
# confirm against the parser). list.remove raises ValueError if id 0 is absent.
relation_classes = [*relation_label_map]
relation_classes.remove(0)

In [10]:
# Tokenizer loaded from the same checkpoint path as the model; predict() uses it to
# turn words into sub-word token ids and back into tokens for printing.
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=constants.model_path)

In [11]:
def get_optimizer_params(model):
    """Split the parameters of ``model`` into two optimizer parameter groups.

    A parameter whose name contains any of the ``no_decay`` substrings (biases and
    LayerNorm weights) goes into a group with ``weight_decay=0.0``; every other
    parameter goes into a group using ``constants.weight_decay``.

    Args:
        model: a module exposing ``named_parameters()``.

    Returns:
        A two-element list of parameter-group dicts: the decayed group first, then the
        non-decayed group, each preserving ``named_parameters()`` order.
    """
    # 'bias' already matches 'LayerNorm.bias'; the explicit entry is kept for readability.
    no_decay = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')
    decayed, not_decayed = [], []
    for param_name, param in model.named_parameters():
        target = not_decayed if any(marker in param_name for marker in no_decay) else decayed
        target.append(param)
    return [
        {'params': decayed, 'weight_decay': constants.weight_decay},
        {'params': not_decayed, 'weight_decay': 0.0},
    ]

In [12]:
def take_first_tokens(embedding, words):
    """Keep only the entries of ``embedding`` that belong to the first token of each word.

    ``words`` has one entry per token (a word is repeated once per sub-word token), and
    ``embedding[i]`` belongs to token ``i``. A token is kept when it is the first one or
    when its word differs from the previous token's word. Consequently two adjacent,
    identical words are treated as a single word.

    Args:
        embedding: indexable per-token values (e.g. a list from ``tensor.tolist()``).
        words: per-token word list, driving which indices are kept.

    Returns:
        A list of the selected ``embedding`` entries, in order.
    """
    return [
        embedding[idx]
        for idx in range(len(words))
        if idx == 0 or words[idx] != words[idx - 1]
    ]

In [13]:
def evaluate(spert_model, group):
    """Evaluate ``spert_model`` on one dataset split and report span-level metrics.

    The split is loaded via ``input_generator.data_generator(group, ...)`` with
    ``is_training=False``. For every document, predicted and gold entity spans and
    relation spans are collected. When ``constants.is_overlapping`` is false, the
    first-token entity embeddings are collected too. Everything is then scored with the
    ``evaluator`` helpers.

    Forward passes run under ``torch.no_grad()``. Evaluation never calls
    ``backward()``, so recording an autograd graph would only waste memory and time.

    Args:
        spert_model: the SpERT model to evaluate; it is switched to eval mode.
        group: dataset identifier passed to the data generator (e.g.
            ``constants.dev_dataset``); also used in the progress-bar description and
            the output file name.

    Returns:
        The concatenated metrics ``DataFrame`` (keys "Entity span", "Entity embedding",
        "Loose relation", "Strict relation"). It is also printed and written to
        ``constants.model_save_path + "evaluate_" + group + ".csv"``.
    """
    spert_model.eval()
    eval_entity_span_pred = []
    eval_entity_span_true = []
    eval_entity_embedding_pred = []
    eval_entity_embedding_true = []
    eval_relation_span_pred = []
    eval_relation_span_true = []
    eval_generator = input_generator.data_generator(group, device, 
                                                    is_training=False,
                                                    neg_entity_count=constants.neg_entity_count, 
                                                    neg_relation_count=constants.neg_relation_count, 
                                                    max_span_size=constants.max_span_size)
    eval_dataset = list(eval_generator)
    eval_size = len(eval_dataset)
    with torch.no_grad():
        for inputs, infos in tqdm(eval_dataset, total=eval_size, desc="Evaluation "+group):
            # forward
            outputs = spert_model(**inputs, is_training=False)
            # retrieve results for evaluation
            eval_entity_span_pred.append(outputs["entity"]["span"])
            eval_entity_span_true.append(infos["entity_span"])
            if not constants.is_overlapping:
                # keep one embedding entry per word (its first sub-word token)
                eval_entity_embedding_pred += take_first_tokens(outputs["entity"]["embedding"].tolist(), infos["words"])
                eval_entity_embedding_true += take_first_tokens(infos["entity_embedding"].tolist(), infos["words"])
                assert len(eval_entity_embedding_pred) == len(eval_entity_embedding_true)
            # no relation output -> no predicted relation spans for this document
            eval_relation_span_pred.append([] if outputs["relation"] is None else outputs["relation"]["span"])
            eval_relation_span_true.append(infos["relation_span"])
    # evaluate & save
    # NOTE(review): when constants.is_overlapping is set, the embedding lists stay empty
    # but are still passed to evaluator.evaluate_results — confirm it tolerates that.
    results = pd.concat([
        evaluator.evaluate_span(eval_entity_span_true, eval_entity_span_pred, entity_label_map, entity_classes),
        evaluator.evaluate_results(eval_entity_embedding_true, eval_entity_embedding_pred, entity_label_map, entity_classes),
        evaluator.evaluate_loose_relation_span(eval_relation_span_true, eval_relation_span_pred, relation_label_map, relation_classes),
        evaluator.evaluate_span(eval_relation_span_true, eval_relation_span_pred, relation_label_map, relation_classes),
    ], keys=["Entity span", "Entity embedding", "Loose relation", "Strict relation"])
    results.to_csv(constants.model_save_path + "evaluate_" + group + ".csv")
    print(results)
    return results

In [14]:
def train():
    """Train a SpERT model and evaluate it on the dev dataset after every epoch.

    The training documents come from ``input_generator.data_generator`` over
    ``constants.train_dataset``, with negative sampling. The model is built with
    ``model.SpERT.from_pretrained(constants.model_path, ...)`` and
    ``freeze_transformer=True``, which presumably keeps the BERT encoder weights fixed
    (see ``model.SpERT``). It is optimized with AdamW under a linear warmup/decay
    schedule and gradient clipping at ``constants.max_grad_norm``.

    After every epoch:
      * the average training loss is printed;
      * training-set entity/relation metrics are written to
        ``constants.model_save_path + "epoch_<n>.csv"``;
      * ``evaluate`` is run on ``constants.dev_dataset``.
    The final weights are saved once, after the last epoch, to
    ``constants.model_save_path + "epoch_<epochs-1>.model"``.

    Raises:
        ValueError: if no training documents are generated. The LR schedule and the
            average-loss computation both need at least one.
    """
    # Training data
    train_generator = input_generator.data_generator(constants.train_dataset, device, 
                                                     is_training=True,
                                                     neg_entity_count=constants.neg_entity_count, 
                                                     neg_relation_count=constants.neg_relation_count, 
                                                     max_span_size=constants.max_span_size)
    train_dataset = list(train_generator)
    train_size = len(train_dataset)
    if train_size == 0:
        # Fail fast, before loading the model, instead of a ZeroDivisionError after epoch 0.
        raise ValueError("no training documents generated from %r" % (constants.train_dataset,))
    config = BertConfig.from_pretrained(constants.model_path)
    spert_model = model.SpERT.from_pretrained(constants.model_path,
                                              config=config,
                                              # SpERT model parameters
                                              relation_types=constants.relation_types, 
                                              entity_types=constants.entity_types, 
                                              width_embedding_size=constants.width_embedding_size, 
                                              prop_drop=constants.prop_drop, 
                                              freeze_transformer=True, 
                                              max_pairs=constants.max_pairs, 
                                              is_overlapping=constants.is_overlapping, 
                                              relation_filter_threshold=constants.relation_filter_threshold)
    spert_model.to(device)
    optimizer_params = get_optimizer_params(spert_model)
    # NOTE(review): transformers.AdamW is deprecated in recent transformers releases;
    # torch.optim.AdamW always applies bias correction, so it is not a drop-in
    # replacement for correct_bias=False.
    optimizer = AdamW(optimizer_params, lr=constants.lr, weight_decay=constants.weight_decay, correct_bias=False)
    # One optimizer step per document, so the schedule spans train_size * epochs steps.
    scheduler = transformers.get_linear_schedule_with_warmup(optimizer, 
                             num_warmup_steps=constants.lr_warmup*train_size*constants.epochs, 
                             num_training_steps=train_size*constants.epochs)
    for epoch in range(constants.epochs):
        losses = []
        train_entity_pred = []
        train_entity_true = []
        train_relation_pred = []
        train_relation_true = []
        # evaluate() at the end of the previous epoch leaves the model in eval mode,
        # so switch back to training mode once per epoch.
        spert_model.train()
        spert_model.zero_grad()
        for inputs, infos in tqdm(train_dataset, total=train_size, desc='Train epoch %s' % epoch):
            # forward
            outputs = spert_model(**inputs, is_training=True)
            # backward
            loss = outputs["loss"]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(spert_model.parameters(), constants.max_grad_norm)
            optimizer.step()
            scheduler.step()
            spert_model.zero_grad()
            # retrieve results for evaluation
            losses.append(loss.item())
            train_entity_pred += outputs["entity"]["pred"].tolist()
            train_entity_true += inputs["entity_label"].tolist()
            train_relation_pred += [] if outputs["relation"] is None else outputs["relation"]["pred"].tolist()
            train_relation_true += inputs["relation_label"].tolist()
            assert len(train_entity_pred) == len(train_entity_true)
            assert len(train_relation_pred) == len(train_relation_true)
            
        # report training metrics, save them to CSV, then evaluate on the dev set
        print("epoch:", epoch,"average loss:", sum(losses) / len(losses))
        results = pd.concat([
            evaluator.evaluate_results(train_entity_true, train_entity_pred, entity_label_map, entity_classes),
            evaluator.evaluate_results(train_relation_true, train_relation_pred, relation_label_map, relation_classes)
        ], keys=["Entity", "Relation"])
        results.to_csv(constants.model_save_path + "epoch_" + str(epoch) + ".csv")
        evaluate(spert_model, constants.dev_dataset)
        
    # Only the final model weights are saved (per-epoch files above are metrics only).
    torch.save(spert_model.state_dict(), constants.model_save_path + "epoch_" + str(constants.epochs-1) + ".model")

In [15]:
def predict(spert_model, sentences):
    """Predict and print entities and relations for raw, whitespace-separated sentences.

    Each sentence is split on whitespace. Every word is encoded with the global
    ``tokenizer`` without special tokens, giving one row per sub-word token, and each row
    repeats the word it came from. The rows are wrapped in the ``doc`` structure that
    ``input_generator.doc_to_input`` consumes, then fed to the model.

    The model is put in eval mode and run under ``torch.no_grad()``. Layers such as
    dropout therefore behave in inference mode, and no autograd graph is recorded.

    Args:
        spert_model: the SpERT model to run; it is switched to eval mode.
        sentences: iterable of sentence strings.

    Returns:
        A list with one dict per processed sentence, with keys ``"sentence"``,
        ``"entities"`` and ``"relations"``. The entity and relation values are the
        model's raw entity spans and relation spans; relations are ``[]`` when none are
        predicted. Sentences that produce no tokens are reported and skipped.
    """
    spert_model.eval()
    predictions = []
    for sentence in sentences:
        words = []
        token_ids = []
        # transform a sentence to a document for prediction: one row per sub-word token
        for word in sentence.split():
            for tid in tokenizer.encode(word, add_special_tokens=False):
                words.append(word)
                token_ids.append(tid)
        if not token_ids:
            # an empty document cannot be turned into model inputs
            print("Sentence:", sentence)
            print("Skipped: sentence produced no tokens")
            continue
        data_frame = pd.DataFrame()
        data_frame["words"] = words
        data_frame["token_ids"] = token_ids
        data_frame["entity_embedding"] = 0
        data_frame["sentence_embedding"] = 0 # for internal datasets
        doc = {"data_frame": data_frame,
            "entity_position": {}, # Supposed to appear in non-overlapping dataset
            "entities": {}, # Supposed to appear in overlapping dataset
            "relations": {}}
        # predict
        inputs, infos = input_generator.doc_to_input(doc, device, 
                                                     is_training=False, 
                                                     max_span_size=constants.max_span_size)
        with torch.no_grad():
            outputs = spert_model(**inputs, is_training=False)
        pred_entity_span = outputs["entity"]["span"]
        pred_relation_span = [] if outputs["relation"] is None else outputs["relation"]["span"]
        # print result
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        print("Sentence:", sentence)
        print("Entities: (", len(pred_entity_span), ")")
        for begin, end, entity_type in pred_entity_span:
            print(entity_label_map[entity_type], "|", " ".join(tokens[begin:end]))
        print("Relations: (", len(pred_relation_span), ")")
        for e1, e2, relation_type in pred_relation_span:
            print(relation_label_map[relation_type], "|", 
                  " ".join(tokens[e1[0]:e1[1]]), "|", 
                  " ".join(tokens[e2[0]:e2[1]]))
        predictions.append({"sentence": sentence,
                            "entities": pred_entity_span,
                            "relations": pred_relation_span})
    return predictions

In [16]:
# Train on constants.train_dataset and evaluate on constants.dev_dataset after every
# epoch; per-epoch metric CSVs and the final model weights are written under
# constants.model_save_path.
train()

Some weights of the model checkpoint at bert-base-cased were not used when initializing SpERT: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing SpERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SpERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SpERT were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['relation_classifier.weight', 'relation_classifier.bi

epoch: 0 average loss: 0.5193515306481948


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:42<00:00,  5.79it/s]
Train epoch 1:   0%|▏                                                                  | 2/910 [00:00<01:25, 10.66it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.478261  0.097923     0.162562    337.0
                 I-Peop        0.095238  0.008889     0.016260    225.0
                 I-Org         0.000000  0.000000     0.000000    190.0
                 I-Other       0.000000  0.000000     0.000000    125.0
                 macro         0.143375  0.026703     0.044705      NaN
                 micro         0.376344  0.039909     0.072165      NaN
Entity embedding I-Loc         0.942857  0.143167     0.248588    461.0
                 I-Peop        1.000000  0.050549     0.096234    455.0
                 I-Org         0.500000  0.002320     0.004619    431.0
                 I-Other       1.000000  0.003676     0.007326    272.0
                 macro         0.860714  0.049928     0.089192      NaN
                 micro         0.947917  0.056208     0.106122      NaN
Loose relation   Kill          0.500000  0.055556     0.100000  

Train epoch 1: 100%|█████████████████████████████████████████████████████████████████| 910/910 [01:27<00:00, 10.45it/s]


epoch: 1 average loss: 0.2645391725249343


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:42<00:00,  5.73it/s]
Train epoch 2:   0%|                                                                   | 1/910 [00:00<01:39,  9.14it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.631313  0.370920     0.467290    337.0
                 I-Peop        0.526718  0.306667     0.387640    225.0
                 I-Org         0.549296  0.205263     0.298851    190.0
                 I-Other       0.185185  0.040000     0.065789    125.0
                 macro         0.473128  0.230712     0.304893      NaN
                 micro         0.557377  0.271380     0.365031      NaN
Entity embedding I-Loc         0.928910  0.425163     0.583333    461.0
                 I-Peop        0.979167  0.413187     0.581144    455.0
                 I-Org         0.858824  0.169374     0.282946    431.0
                 I-Other       0.906250  0.106618     0.190789    272.0
                 macro         0.918288  0.278585     0.409553      NaN
                 micro         0.934615  0.300185     0.454418      NaN
Loose relation   Kill          0.857143  0.333333     0.480000  

Train epoch 2: 100%|█████████████████████████████████████████████████████████████████| 910/910 [01:27<00:00, 10.35it/s]


epoch: 2 average loss: 0.20442946389481262


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:40<00:00,  5.94it/s]
Train epoch 3:   0%|                                                                           | 0/910 [00:00<?, ?it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.716763  0.367953     0.486275    337.0
                 I-Peop        0.568966  0.293333     0.387097    225.0
                 I-Org         0.554054  0.215789     0.310606    190.0
                 I-Other       0.400000  0.064000     0.110345    125.0
                 macro         0.559946  0.235269     0.323581      NaN
                 micro         0.624021  0.272520     0.379365      NaN
Entity embedding I-Loc         0.940217  0.375271     0.536434    461.0
                 I-Peop        0.968553  0.338462     0.501629    455.0
                 I-Org         0.868132  0.183295     0.302682    431.0
                 I-Other       0.880000  0.080882     0.148148    272.0
                 macro         0.914226  0.244477     0.372223      NaN
                 micro         0.932462  0.264361     0.411935      NaN
Loose relation   Kill          0.875000  0.388889     0.538462  

Train epoch 3: 100%|█████████████████████████████████████████████████████████████████| 910/910 [01:28<00:00, 10.30it/s]


epoch: 3 average loss: 0.1780352834612131


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:41<00:00,  5.85it/s]
Train epoch 4:   0%|                                                                   | 1/910 [00:00<02:07,  7.11it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.704036  0.465875     0.560714    337.0
                 I-Peop        0.607692  0.351111     0.445070    225.0
                 I-Org         0.479167  0.242105     0.321678    190.0
                 I-Other       0.172414  0.040000     0.064935    125.0
                 macro         0.490827  0.274773     0.348100      NaN
                 micro         0.600418  0.327252     0.423616      NaN
Entity embedding I-Loc         0.922131  0.488069     0.638298    461.0
                 I-Peop        0.969388  0.417582     0.583717    455.0
                 I-Org         0.775000  0.215777     0.337568    431.0
                 I-Other       0.911765  0.113971     0.202614    272.0
                 macro         0.894571  0.308850     0.440549      NaN
                 micro         0.907407  0.332922     0.487122      NaN
Loose relation   Kill          0.888889  0.444444     0.592593  

Train epoch 4: 100%|█████████████████████████████████████████████████████████████████| 910/910 [01:29<00:00, 10.22it/s]


epoch: 4 average loss: 0.1645143508133325


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:42<00:00,  5.78it/s]
Train epoch 5:   0%|                                                                   | 1/910 [00:00<01:39,  9.15it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.677824  0.480712     0.562500    337.0
                 I-Peop        0.688889  0.413333     0.516667    225.0
                 I-Org         0.510417  0.257895     0.342657    190.0
                 I-Other       0.192308  0.040000     0.066225    125.0
                 macro         0.517359  0.297985     0.372012      NaN
                 micro         0.622984  0.352338     0.450109      NaN
Entity embedding I-Loc         0.908397  0.516269     0.658368    461.0
                 I-Peop        0.961722  0.441758     0.605422    455.0
                 I-Org         0.811475  0.229698     0.358047    431.0
                 I-Other       0.870968  0.099265     0.178218    272.0
                 macro         0.888141  0.321748     0.450014      NaN
                 micro         0.905449  0.348981     0.503790      NaN
Loose relation   Kill          0.909091  0.555556     0.689655  

Train epoch 5: 100%|█████████████████████████████████████████████████████████████████| 910/910 [01:28<00:00, 10.33it/s]


epoch: 5 average loss: 0.15387904542692743


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:41<00:00,  5.84it/s]
Train epoch 6:   0%|                                                                           | 0/910 [00:00<?, ?it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.724576  0.507418     0.596859    337.0
                 I-Peop        0.668919  0.440000     0.530831    225.0
                 I-Org         0.520000  0.273684     0.358621    190.0
                 I-Other       0.290323  0.072000     0.115385    125.0
                 macro         0.550954  0.323276     0.400424      NaN
                 micro         0.642718  0.377423     0.475575      NaN
Entity embedding I-Loc         0.906015  0.522777     0.662999    461.0
                 I-Peop        0.949367  0.494505     0.650289    455.0
                 I-Org         0.807407  0.252900     0.385159    431.0
                 I-Other       0.926829  0.139706     0.242812    272.0
                 macro         0.897405  0.352472     0.485315      NaN
                 micro         0.902798  0.378629     0.533507      NaN
Loose relation   Kill          1.000000  0.611111     0.758621  

Train epoch 6: 100%|█████████████████████████████████████████████████████████████████| 910/910 [01:28<00:00, 10.29it/s]


epoch: 6 average loss: 0.1480091215367173


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:42<00:00,  5.68it/s]
Train epoch 7:   0%|                                                                           | 0/910 [00:00<?, ?it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.700855  0.486647     0.574431    337.0
                 I-Peop        0.628272  0.533333     0.576923    225.0
                 I-Org         0.510870  0.247368     0.333333    190.0
                 I-Other       0.233333  0.056000     0.090323    125.0
                 macro         0.518332  0.330837     0.393752      NaN
                 micro         0.617916  0.385405     0.474719      NaN
Entity embedding I-Loc         0.901141  0.514100     0.654696    461.0
                 I-Peop        0.925806  0.630769     0.750327    455.0
                 I-Org         0.838983  0.229698     0.360656    431.0
                 I-Other       0.842105  0.117647     0.206452    272.0
                 macro         0.877009  0.373054     0.493033      NaN
                 micro         0.898491  0.404571     0.557922      NaN
Loose relation   Kill          0.909091  0.555556     0.689655  

Train epoch 7: 100%|█████████████████████████████████████████████████████████████████| 910/910 [01:26<00:00, 10.52it/s]


epoch: 7 average loss: 0.14526586620607873


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:42<00:00,  5.76it/s]
Train epoch 8:   0%|                                                                   | 1/910 [00:00<01:53,  8.00it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.713080  0.501484     0.588850    337.0
                 I-Peop        0.679245  0.480000     0.562500    225.0
                 I-Org         0.505376  0.247368     0.332155    190.0
                 I-Other       0.366667  0.088000     0.141935    125.0
                 macro         0.566092  0.329213     0.406360      NaN
                 micro         0.645472  0.381984     0.479943      NaN
Entity embedding I-Loc         0.913725  0.505423     0.650838    461.0
                 I-Peop        0.946154  0.540659     0.688112    455.0
                 I-Org         0.813008  0.232019     0.361011    431.0
                 I-Other       0.913043  0.154412     0.264151    272.0
                 macro         0.896483  0.358128     0.491028      NaN
                 micro         0.907895  0.383570     0.539297      NaN
Loose relation   Kill          1.000000  0.555556     0.714286  

Train epoch 8: 100%|█████████████████████████████████████████████████████████████████| 910/910 [01:28<00:00, 10.28it/s]


epoch: 8 average loss: 0.13576528225605797


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:41<00:00,  5.88it/s]
Train epoch 9:   0%|                                                                           | 0/910 [00:00<?, ?it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.716049  0.516320     0.600000    337.0
                 I-Peop        0.679012  0.488889     0.568475    225.0
                 I-Org         0.595506  0.278947     0.379928    190.0
                 I-Other       0.206897  0.048000     0.077922    125.0
                 macro         0.549366  0.333039     0.406581      NaN
                 micro         0.655832  0.391106     0.490000      NaN
Entity embedding I-Loc         0.898551  0.537961     0.672999    461.0
                 I-Peop        0.954717  0.556044     0.702778    455.0
                 I-Org         0.834711  0.234339     0.365942    431.0
                 I-Other       0.878049  0.132353     0.230032    272.0
                 macro         0.891507  0.365174     0.492938      NaN
                 micro         0.907539  0.394070     0.549526      NaN
Loose relation   Kill          1.000000  0.500000     0.666667  

Train epoch 9: 100%|█████████████████████████████████████████████████████████████████| 910/910 [01:26<00:00, 10.57it/s]


epoch: 9 average loss: 0.13334999369498302


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:40<00:00,  6.03it/s]
Train epoch 10:   0%|▏                                                                 | 2/910 [00:00<01:25, 10.67it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.754098  0.545994     0.633391    337.0
                 I-Peop        0.655738  0.533333     0.588235    225.0
                 I-Org         0.578947  0.289474     0.385965    190.0
                 I-Other       0.322581  0.080000     0.128205    125.0
                 macro         0.577841  0.362200     0.433949      NaN
                 micro         0.667269  0.420753     0.516084      NaN
Entity embedding I-Loc         0.901060  0.553145     0.685484    461.0
                 I-Peop        0.956229  0.624176     0.755319    455.0
                 I-Org         0.850394  0.250580     0.387097    431.0
                 I-Other       0.906977  0.143382     0.247619    272.0
                 macro         0.903665  0.392821     0.518880      NaN
                 micro         0.914667  0.423718     0.579147      NaN
Loose relation   Kill          1.000000  0.500000     0.666667  

Train epoch 10: 100%|████████████████████████████████████████████████████████████████| 910/910 [01:24<00:00, 10.77it/s]


epoch: 10 average loss: 0.13376026887262424


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:41<00:00,  5.91it/s]
Train epoch 11:   0%|▏                                                                 | 2/910 [00:00<01:25, 10.67it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.764444  0.510386     0.612100    337.0
                 I-Peop        0.641176  0.484444     0.551899    225.0
                 I-Org         0.565217  0.273684     0.368794    190.0
                 I-Other       0.368421  0.056000     0.097222    125.0
                 macro         0.584815  0.331129     0.407504      NaN
                 micro         0.671937  0.387685     0.491685      NaN
Entity embedding I-Loc         0.921569  0.509761     0.656425    461.0
                 I-Peop        0.933099  0.582418     0.717185    455.0
                 I-Org         0.827869  0.234339     0.365280    431.0
                 I-Other       0.884615  0.084559     0.154362    272.0
                 macro         0.891788  0.352769     0.473313      NaN
                 micro         0.908297  0.385423     0.541197      NaN
Loose relation   Kill          1.000000  0.500000     0.666667  

Train epoch 11: 100%|████████████████████████████████████████████████████████████████| 910/910 [01:27<00:00, 10.35it/s]


epoch: 11 average loss: 0.1290410793683195


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:40<00:00,  5.99it/s]
Train epoch 12:   0%|                                                                  | 1/910 [00:00<01:39,  9.14it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.784404  0.507418     0.616216    337.0
                 I-Peop        0.666667  0.497778     0.569975    225.0
                 I-Org         0.530000  0.278947     0.365517    190.0
                 I-Other       0.461538  0.096000     0.158940    125.0
                 macro         0.610652  0.345036     0.427662      NaN
                 micro         0.679688  0.396807     0.501080      NaN
Entity embedding I-Loc         0.935484  0.503254     0.654443    461.0
                 I-Peop        0.935484  0.573626     0.711172    455.0
                 I-Org         0.801471  0.252900     0.384480    431.0
                 I-Other       0.900000  0.132353     0.230769    272.0
                 macro         0.893110  0.365533     0.495216      NaN
                 micro         0.907539  0.394070     0.549526      NaN
Loose relation   Kill          1.000000  0.555556     0.714286  

Train epoch 12: 100%|████████████████████████████████████████████████████████████████| 910/910 [01:28<00:00, 10.24it/s]


epoch: 12 average loss: 0.1265842727180775


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:41<00:00,  5.83it/s]
Train epoch 13:   0%|                                                                  | 1/910 [00:00<01:39,  9.14it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.804762  0.501484     0.617916    337.0
                 I-Peop        0.791045  0.471111     0.590529    225.0
                 I-Org         0.593023  0.268421     0.369565    190.0
                 I-Other       0.448276  0.104000     0.168831    125.0
                 macro         0.659276  0.336254     0.436710      NaN
                 micro         0.738562  0.386545     0.507485      NaN
Entity embedding I-Loc         0.921811  0.485900     0.636364    461.0
                 I-Peop        0.974249  0.498901     0.659884    455.0
                 I-Org         0.820513  0.222738     0.350365    431.0
                 I-Other       0.928571  0.143382     0.248408    272.0
                 macro         0.911286  0.337730     0.473755      NaN
                 micro         0.922835  0.361952     0.519965      NaN
Loose relation   Kill          1.000000  0.555556     0.714286  

Train epoch 13: 100%|████████████████████████████████████████████████████████████████| 910/910 [01:29<00:00, 10.20it/s]


epoch: 13 average loss: 0.1279320717574312


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:42<00:00,  5.71it/s]
Train epoch 14:   0%|                                                                  | 1/910 [00:00<01:39,  9.14it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.764192  0.519288     0.618375    337.0
                 I-Peop        0.697674  0.533333     0.604534    225.0
                 I-Org         0.600000  0.284211     0.385714    190.0
                 I-Other       0.464286  0.104000     0.169935    125.0
                 macro         0.631538  0.360208     0.444639      NaN
                 micro         0.697495  0.412771     0.518625      NaN
Entity embedding I-Loc         0.909774  0.524946     0.665750    461.0
                 I-Peop        0.938567  0.604396     0.735294    455.0
                 I-Org         0.851240  0.238979     0.373188    431.0
                 I-Other       0.928571  0.143382     0.248408    272.0
                 macro         0.907038  0.377926     0.505660      NaN
                 micro         0.912742  0.407041     0.563007      NaN
Loose relation   Kill          1.000000  0.611111     0.758621  

Train epoch 14: 100%|████████████████████████████████████████████████████████████████| 910/910 [01:28<00:00, 10.27it/s]


epoch: 14 average loss: 0.12389161139828982


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:42<00:00,  5.78it/s]
Train epoch 15:   0%|                                                                  | 1/910 [00:00<01:39,  9.14it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.738776  0.537092     0.621993    337.0
                 I-Peop        0.708609  0.475556     0.569149    225.0
                 I-Org         0.509615  0.278947     0.360544    190.0
                 I-Other       0.357143  0.080000     0.130719    125.0
                 macro         0.578536  0.342899     0.420601      NaN
                 micro         0.664773  0.400228     0.499644      NaN
Entity embedding I-Loc         0.906137  0.544469     0.680217    461.0
                 I-Peop        0.960938  0.540659     0.691983    455.0
                 I-Org         0.802721  0.273782     0.408304    431.0
                 I-Other       0.900000  0.132353     0.230769    272.0
                 macro         0.892449  0.372816     0.502818      NaN
                 micro         0.904167  0.402100     0.556648      NaN
Loose relation   Kill          1.000000  0.611111     0.758621  

Train epoch 15: 100%|████████████████████████████████████████████████████████████████| 910/910 [01:28<00:00, 10.31it/s]


epoch: 15 average loss: 0.12179762400318306


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:41<00:00,  5.85it/s]
Train epoch 16:   0%|                                                                  | 1/910 [00:00<01:53,  8.00it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.756303  0.534125     0.626087    337.0
                 I-Peop        0.751773  0.471111     0.579235    225.0
                 I-Org         0.622222  0.294737     0.400000    190.0
                 I-Other       0.379310  0.088000     0.142857    125.0
                 macro         0.627402  0.346993     0.437045      NaN
                 micro         0.708835  0.402509     0.513455      NaN
Entity embedding I-Loc         0.898917  0.540130     0.674797    461.0
                 I-Peop        0.974576  0.505495     0.665702    455.0
                 I-Org         0.829457  0.248260     0.382143    431.0
                 I-Other       0.928571  0.143382     0.248408    272.0
                 macro         0.907881  0.359317     0.492762      NaN
                 micro         0.913743  0.386041     0.542770      NaN
Loose relation   Kill          1.000000  0.611111     0.758621  

Train epoch 16: 100%|████████████████████████████████████████████████████████████████| 910/910 [01:28<00:00, 10.28it/s]


epoch: 16 average loss: 0.11929864809940477


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:44<00:00,  5.49it/s]
Train epoch 17:   0%|                                                                  | 1/910 [00:00<01:39,  9.15it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.742616  0.522255     0.613240    337.0
                 I-Peop        0.740000  0.493333     0.592000    225.0
                 I-Org         0.529412  0.284211     0.369863    190.0
                 I-Other       0.433333  0.104000     0.167742    125.0
                 macro         0.611340  0.350950     0.435711      NaN
                 micro         0.682081  0.403649     0.507163      NaN
Entity embedding I-Loc         0.903346  0.527115     0.665753    461.0
                 I-Peop        0.964844  0.542857     0.694796    455.0
                 I-Org         0.763158  0.269142     0.397942    431.0
                 I-Other       0.901961  0.169118     0.284830    272.0
                 macro         0.883327  0.377058     0.510830      NaN
                 micro         0.895604  0.402718     0.555603      NaN
Loose relation   Kill          0.909091  0.555556     0.689655  

Train epoch 17: 100%|████████████████████████████████████████████████████████████████| 910/910 [01:30<00:00, 10.07it/s]


epoch: 17 average loss: 0.12144089000975038


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:43<00:00,  5.61it/s]
Train epoch 18:   0%|                                                                          | 0/910 [00:00<?, ?it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.759336  0.543027     0.633218    337.0
                 I-Peop        0.689024  0.502222     0.580977    225.0
                 I-Org         0.500000  0.278947     0.358108    190.0
                 I-Other       0.411765  0.112000     0.176101    125.0
                 macro         0.590031  0.359049     0.437101      NaN
                 micro         0.666055  0.413911     0.510549      NaN
Entity embedding I-Loc         0.899642  0.544469     0.678378    461.0
                 I-Peop        0.946043  0.578022     0.717599    455.0
                 I-Org         0.797386  0.283063     0.417808    431.0
                 I-Other       0.927273  0.187500     0.311927    272.0
                 macro         0.892586  0.398263     0.531428      NaN
                 micro         0.898039  0.424336     0.576342      NaN
Loose relation   Kill          0.857143  0.666667     0.750000  

Train epoch 18: 100%|████████████████████████████████████████████████████████████████| 910/910 [01:25<00:00, 10.58it/s]


epoch: 18 average loss: 0.12062201128279852


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:40<00:00,  5.97it/s]
Train epoch 19:   0%|                                                                  | 1/910 [00:00<01:39,  9.14it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.730924  0.540059     0.621160    337.0
                 I-Peop        0.711538  0.493333     0.582677    225.0
                 I-Org         0.612245  0.315789     0.416667    190.0
                 I-Other       0.424242  0.112000     0.177215    125.0
                 macro         0.619737  0.365296     0.449430      NaN
                 micro         0.684701  0.418472     0.519462      NaN
Entity embedding I-Loc         0.903915  0.550976     0.684636    461.0
                 I-Peop        0.940959  0.560440     0.702479    455.0
                 I-Org         0.818792  0.283063     0.420690    431.0
                 I-Other       0.940000  0.172794     0.291925    272.0
                 macro         0.900916  0.391818     0.524933      NaN
                 micro         0.902796  0.418777     0.572152      NaN
Loose relation   Kill          0.916667  0.611111     0.733333  

Train epoch 19: 100%|████████████████████████████████████████████████████████████████| 910/910 [01:26<00:00, 10.51it/s]


epoch: 19 average loss: 0.12044108577052151


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:42<00:00,  5.71it/s]


                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.729412  0.551929     0.628378    337.0
                 I-Peop        0.716049  0.515556     0.599483    225.0
                 I-Org         0.571429  0.294737     0.388889    190.0
                 I-Other       0.323529  0.088000     0.138365    125.0
                 macro         0.585105  0.362555     0.438779      NaN
                 micro         0.672131  0.420753     0.517532      NaN
Entity embedding I-Loc         0.905594  0.561822     0.693440    461.0
                 I-Peop        0.948718  0.569231     0.711538    455.0
                 I-Org         0.834532  0.269142     0.407018    431.0
                 I-Other       0.882353  0.165441     0.278638    272.0
                 macro         0.892799  0.391409     0.522659      NaN
                 micro         0.906542  0.419395     0.573480      NaN
Loose relation   Kill          0.916667  0.611111     0.733333  

In [17]:
# Rebuild the SpERT architecture on top of the pretrained BERT encoder, with the
# transformer frozen. The trained weights are loaded from a checkpoint in the
# following cells, so the "newly initialized" warning printed here is expected.
config = BertConfig.from_pretrained(constants.model_path)
spert_model = model.SpERT.from_pretrained(
    constants.model_path,
    config=config,
    # SpERT-specific hyper-parameters, all read from the dataset constants module
    relation_types=constants.relation_types,
    entity_types=constants.entity_types,
    width_embedding_size=constants.width_embedding_size,
    prop_drop=constants.prop_drop,
    freeze_transformer=True,
    max_pairs=constants.max_pairs,
    is_overlapping=constants.is_overlapping,
    relation_filter_threshold=constants.relation_filter_threshold,
)
spert_model.to(device)

Some weights of the model checkpoint at bert-base-cased were not used when initializing SpERT: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing SpERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SpERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SpERT were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['relation_classifier.weight', 'relation_classifier.bi

SpERT(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      

In [18]:
# Load the checkpoint saved under index `epochs - 1` (presumably the final training
# epoch, assuming 0-based epoch numbering), mapped onto the current device.
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
state_dict = torch.load(
    f"{constants.model_save_path}epoch_{constants.epochs - 1}.model",
    map_location=device,
)

In [19]:
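# # To run reference model: load the published SpERT checkpoint and copy its weights
# # to the parameter names used by this implementation
# # (rel_classifier -> relation_classifier, size_embeddings -> width_embedding).
# # The old keys stay in the dict and are ignored by load_state_dict(strict=False).
# # Also enable the matching label-encoding override in In [7].
# state_dict = torch.load("../../model/spert/reference/reference_model.bin", map_location=device)
# state_dict["relation_classifier.weight"] = state_dict["rel_classifier.weight"]
# state_dict["relation_classifier.bias"] = state_dict["rel_classifier.bias"]
# state_dict["width_embedding.weight"] = state_dict["size_embeddings.weight"]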

In [20]:
# Load the trained weights, then evaluate on the test split.
# strict=False lets loading proceed when checkpoint and model disagree on some keys
# (e.g. after the key renaming used for the reference model). It also silently hides
# a checkpoint that doesn't fit the model at all, which would leave the freshly
# initialised layers in place. Report any mismatch explicitly instead.
load_result = spert_model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print("WARNING: checkpoint keys do not match the model")
    print("  missing keys:   ", load_result.missing_keys)
    print("  unexpected keys:", load_result.unexpected_keys)
evaluate(spert_model, constants.test_dataset)

Evaluation test: 100%|███████████████████████████████████████████████████████████████| 288/288 [00:52<00:00,  5.49it/s]

                              precision    recall  fbeta_score  support
Entity span      I-Loc         0.685714  0.562061     0.617761    427.0
                 I-Peop        0.714286  0.482866     0.576208    321.0
                 I-Org         0.666667  0.323232     0.435374    198.0
                 I-Other       0.347826  0.120301     0.178771    133.0
                 macro         0.603623  0.372115     0.452028      NaN
                 micro         0.669958  0.440222     0.531320      NaN
Entity embedding I-Loc         0.893827  0.572785     0.698168    632.0
                 I-Peop        0.969863  0.516035     0.673644    686.0
                 I-Org         0.903448  0.298405     0.448630    439.0
                 I-Other       0.820896  0.209125     0.333333    263.0
                 macro         0.897008  0.399088     0.538444      NaN
                 micro         0.918534  0.446535     0.600933      NaN
Loose relation   Kill          0.760000  0.404255     0.527778  




In [21]:
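# # Export the trained checkpoint with the reference SpERT parameter names
# # (relation_classifier -> rel_classifier, width_embedding -> size_embeddings).
# # Also drop the `bert.embeddings.position_ids` buffer, which is presumably absent
# # from reference-format checkpoints.
# state_dict = torch.load(constants.model_save_path + "epoch_" + str(constants.epochs-1) + ".model", map_location=device)
# state_dict["rel_classifier.weight"] = state_dict["relation_classifier.weight"]
# state_dict["rel_classifier.bias"] = state_dict["relation_classifier.bias"]
# state_dict["size_embeddings.weight"] = state_dict["width_embedding.weight"]
# del state_dict["relation_classifier.weight"]
# del state_dict["relation_classifier.bias"]
# del state_dict["width_embedding.weight"]
# del state_dict["bert.embeddings.position_ids"]
# torch.save(state_dict, "../../model/spert/pytorch_model.bin")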

In [22]:
# # Example: predict on several raw sentences in one call.
# predict(spert_model, sentences=["However, the Rev. Jesse Jackson, a native of South Carolina, joined critics of FEMA's effort.", 
#                                 "International Paper spokeswoman Ann Silvernail said that under French law the company was barred from releasing details pending government approval."])

In [23]:
# Example sentence, pre-tokenized: the 23 tokens are separated by single spaces
# (gold annotation in the next cell).
test_sentence = "An art exhibit at the Hakawati Theatre in Arab east Jerusalem was a series of portraits of Palestinians killed in the rebellion ."

In [24]:
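# Gold annotation for the sentence above. Offsets are token indices with `end` exclusive,
# e.g. tokens 5-6 = "Hakawati Theatre" and token 10 = "Jerusalem":
# "entities": [{"type": "Loc", "start": 5, "end": 7}, {"type": "Loc", "start": 10, "end": 11}, {"type": "Other", "start": 17, "end": 18}], "relations": [{"type": "Located_In", "head": 0, "tail": 1}]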

In [25]:
# Predict entities and relations for the sentence above (printed below),
# to compare against the gold annotation.
predict(
    spert_model,
    sentences=[test_sentence],
)

Sentence: An art exhibit at the Hakawati Theatre in Arab east Jerusalem was a series of portraits of Palestinians killed in the rebellion .
Entities: ( 1 )
I-Loc | Jerusalem
Relations: ( 0 )


In [26]:
# A second, very short example, pre-tokenized with single spaces.
test_sentence = "PERUGIA , Italy ( AP )"

In [27]:
# Run the model on the short example; found entities and relations are printed below.
predict(
    spert_model,
    sentences=[test_sentence],
)

Sentence: PERUGIA , Italy ( AP )
Entities: ( 2 )
I-Org | AP
I-Loc | Italy
Relations: ( 1 )
OrgBased_In | AP | Italy
