In [1]:
import os
import math

In [2]:
import pandas as pd
import torch
import transformers
from transformers import AdamW, BertConfig
from sklearn.metrics import precision_recall_fscore_support

In [3]:
import conll04_constants as constants
import conll04_input_generator as input_generator
import model

In [4]:
from tqdm import tqdm

In [5]:
model_save_path = "../../model/spert/conll04_"

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
entity_label_map = {v: k for k, v in input_generator.parser.entity_encode.items()}
entity_classes = list(entity_label_map.keys())
entity_classes.remove(0)

In [8]:
relation_label_map = {v: k for k, v in input_generator.parser.relation_encode.items()}
relation_classes = list(relation_label_map.keys())
relation_classes.remove(0)

In [9]:
def get_optimizer_params(model):
    params = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_params = [
        {"params": [p for n, p in params if not any(nd in n for nd in no_decay)],
         "weight_decay": constants.weight_decay},
        {"params": [p for n, p in params if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0}
    ]
    return optimizer_params

In [10]:
def evaluate_results(true_labels, predicted_labels, label_map, classes):
    """Compare the prediction list with the truth values"""
    precision, recall, fbeta_score, support \
        = precision_recall_fscore_support(true_labels, predicted_labels, average=None, labels=classes, 
                                          zero_division=0)
    result = pd.DataFrame(index=[label_map[c] for c in classes])
    result["precision"] = precision
    result["recall"] = recall
    result["fbeta_score"] = fbeta_score
    result["support"] = support
    result.loc["macro"] = list(precision_recall_fscore_support(true_labels, predicted_labels, average="macro", 
                                                               labels=classes, zero_division=0))
    result.loc["micro"] = list(precision_recall_fscore_support(true_labels, predicted_labels, average="micro", 
                                                               labels=classes, zero_division=0))
    return result

In [11]:
def evaluate(spert_model, group):
    spert_model.eval()
    eval_entity_span_pred = []
    eval_entity_span_true = []
    eval_entity_embedding_pred = []
    eval_entity_embedding_true = []
#     eval_relation
#     eval_relation_pred = []
#     eval_relation_true = []
    eval_generator = input_generator.data_generator(group, device, 
                                                    is_training=False,
                                                    neg_entity_count=constants.neg_entity_count, 
                                                    neg_relation_count=constants.neg_relation_count, 
                                                    max_span_size=constants.max_span_size)
    eval_dataset = list(eval_generator)
    eval_size = len(eval_dataset)
    for inputs, infos in tqdm(eval_dataset, total=eval_size, desc="Evaluation "+group):
        # forward
        outputs = spert_model(**inputs, is_training=False)
        
        # retrieve results for evaluation
        eval_entity_span_pred += outputs["entity"]["pred"].tolist()
        eval_entity_span_true += inputs["entity_label"].tolist()
        eval_entity_embedding_pred += outputs["entity"]["embedding"].tolist()
        eval_entity_embedding_true += infos["entity_embedding"].tolist()
#         eval_relation_pred += [] if outputs["relation"] == None else outputs["relation"]["pred"].tolist()
#         eval_relation_true += inputs["relation_label"].tolist()
        assert len(eval_entity_span_pred) == len(eval_entity_span_true)
        assert len(eval_entity_embedding_pred) == len(eval_entity_embedding_true)
#         assert len(eval_relation_pred) == len(eval_relation_true)
        
    # evaluate & save
    results = pd.concat([
        evaluate_results(eval_entity_span_true, eval_entity_span_pred, entity_label_map, entity_classes),
        evaluate_results(eval_entity_embedding_true, eval_entity_embedding_pred, entity_label_map, 
                         entity_classes)#,
#         evaluate_results(eval_relation_true, eval_relation_pred, relation_label_map, relation_classes)
    ], keys=["Entity span", "Entity embedding"])# , "Relation"])
    results.to_csv(model_save_path + "evaluate_" + group + ".csv")
    print(results)

In [12]:
def train():
    """Train the model and evaluate on the dev dataset
    """
    
    train_group = "train"
    validation_group = "dev"
    
    # Training
    train_generator = input_generator.data_generator(train_group, device, 
                                                     is_training=True,
                                                     neg_entity_count=constants.neg_entity_count, 
                                                     neg_relation_count=constants.neg_relation_count, 
                                                     max_span_size=constants.max_span_size)
    train_dataset = list(train_generator)
    train_size = len(train_dataset)
    
    config = BertConfig.from_pretrained(constants.model_path)
    spert_model = model.SpERT.from_pretrained(constants.model_path,
                                              # SpERT model parameters
                                              relation_types=constants.relation_types, 
                                              entity_types=constants.entity_types, 
                                              width_embedding_size=constants.width_embedding_size, 
                                              prop_drop=constants.prop_drop, 
                                              freeze_transformer=True, 
                                              max_pairs=constants.max_pairs, 
                                              is_overlapping=False, 
                                              relation_filter_threshold=constants.relation_filter_threshold)
    spert_model.train()
    spert_model.to(device)
    optimizer_params = get_optimizer_params(spert_model)
    optimizer = AdamW(optimizer_params, lr=constants.lr, weight_decay=constants.weight_decay, correct_bias=False)
    scheduler = transformers.get_linear_schedule_with_warmup(optimizer, 
                             num_warmup_steps=constants.lr_warmup*train_size*constants.epochs, 
                             num_training_steps=train_size*constants.epochs)
    
    for epoch in range(constants.epochs):
        losses = []
        train_entity_span_pred = []
        train_entity_span_true = []
        train_entity_embedding_pred = []
        train_entity_embedding_true = []
        train_relation_pred = []
        train_relation_true = []
        for inputs, infos in tqdm(train_dataset, total=train_size, desc='Train epoch %s' % epoch):
            spert_model.zero_grad()
            
            # forward
            outputs = spert_model(**inputs, is_training=True)
            
            # backward
            loss = outputs["loss"]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(spert_model.parameters(), constants.max_grad_norm)
            optimizer.step()
            scheduler.step()
            # print(optimizer.param_groups[0]['lr'])
            
            # retrieve results for evaluation
            losses.append(loss.item())
            train_entity_span_pred += outputs["entity"]["pred"].tolist()
            train_entity_span_true += inputs["entity_label"].tolist()
            train_entity_embedding_pred += outputs["entity"]["embedding"].tolist()
            train_entity_embedding_true += infos["entity_embedding"].tolist()
            train_relation_pred += [] if outputs["relation"] == None else outputs["relation"]["pred"].tolist()
            train_relation_true += inputs["relation_label"].tolist()
            assert len(train_entity_span_pred) == len(train_entity_span_true)
            assert len(train_entity_embedding_pred) == len(train_entity_embedding_true)
            assert len(train_relation_pred) == len(train_relation_true)
            
        # evaluate & save checkpoint
        print("epoch:", epoch,"average loss:", sum(losses) / len(losses))
        results = pd.concat([
            evaluate_results(train_entity_span_true, train_entity_span_pred, entity_label_map, entity_classes),
            evaluate_results(train_entity_embedding_true, train_entity_embedding_pred, entity_label_map, 
                             entity_classes),
            evaluate_results(train_relation_true, train_relation_pred, relation_label_map, relation_classes)
        ], keys=["Entity span", "Entity embedding", "Relation"])
        results.to_csv(model_save_path + "epoch_" + str(epoch) + ".csv")
        evaluate(spert_model, validation_group)
    torch.save(spert_model.state_dict(), model_save_path + "epoch_" + str(constants.epochs-1) + ".model")

In [13]:
train()

Some weights of the model checkpoint at bert-base-cased were not used when initializing SpERT: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing SpERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SpERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SpERT were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['relation_classifier.weight', 'relation_classifier.bi

epoch: 0 average loss: 0.7844143317623453


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 26.61it/s]
Train epoch 1:   0%|▎                                                                  | 4/910 [00:00<00:26, 34.48it/s]

                          precision  recall  fbeta_score  support
Entity span      I-Loc          0.0     0.0          0.0    337.0
                 I-Peop         0.0     0.0          0.0    225.0
                 I-Org          0.0     0.0          0.0    190.0
                 I-Other        0.0     0.0          0.0    125.0
                 macro          0.0     0.0          0.0      NaN
                 micro          0.0     0.0          0.0      NaN
Entity embedding I-Loc          0.0     0.0          0.0    717.0
                 I-Peop         0.0     0.0          0.0    793.0
                 I-Org          0.0     0.0          0.0    747.0
                 I-Other        0.0     0.0          0.0    354.0
                 macro          0.0     0.0          0.0      NaN
                 micro          0.0     0.0          0.0      NaN


Train epoch 1: 100%|█████████████████████████████████████████████████████████████████| 910/910 [00:26<00:00, 34.39it/s]


epoch: 1 average loss: 0.32691486957636506


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 26.18it/s]
Train epoch 2:   0%|▎                                                                  | 4/910 [00:00<00:26, 34.78it/s]

                          precision  recall  fbeta_score  support
Entity span      I-Loc          0.0     0.0          0.0    337.0
                 I-Peop         0.0     0.0          0.0    225.0
                 I-Org          0.0     0.0          0.0    190.0
                 I-Other        0.0     0.0          0.0    125.0
                 macro          0.0     0.0          0.0      NaN
                 micro          0.0     0.0          0.0      NaN
Entity embedding I-Loc          0.0     0.0          0.0    717.0
                 I-Peop         0.0     0.0          0.0    793.0
                 I-Org          0.0     0.0          0.0    747.0
                 I-Other        0.0     0.0          0.0    354.0
                 macro          0.0     0.0          0.0      NaN
                 micro          0.0     0.0          0.0      NaN


Train epoch 2: 100%|█████████████████████████████████████████████████████████████████| 910/910 [00:26<00:00, 34.03it/s]


epoch: 2 average loss: 0.27171740808657235


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 26.75it/s]
Train epoch 3:   0%|▎                                                                  | 4/910 [00:00<00:27, 33.06it/s]

                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.375000  0.008902     0.017391    337.0
                 I-Peop    0.000000  0.000000     0.000000    225.0
                 I-Org     0.000000  0.000000     0.000000    190.0
                 I-Other   0.000000  0.000000     0.000000    125.0
                 macro     0.093750  0.002226     0.004348      NaN
                 micro     0.272727  0.003421     0.006757      NaN
Entity embedding I-Loc     0.875000  0.009763     0.019310    717.0
                 I-Peop    1.000000  0.003783     0.007538    793.0
                 I-Org     0.000000  0.000000     0.000000    747.0
                 I-Other   0.000000  0.000000     0.000000    354.0
                 macro     0.468750  0.003387     0.006712      NaN
                 micro     0.909091  0.003830     0.007628      NaN


Train epoch 3: 100%|█████████████████████████████████████████████████████████████████| 910/910 [00:26<00:00, 34.15it/s]


epoch: 3 average loss: 0.2375957775574464


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 26.40it/s]
Train epoch 4:   0%|▎                                                                  | 4/910 [00:00<00:27, 32.79it/s]

                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.395833  0.056380     0.098701    337.0
                 I-Peop    0.090909  0.004444     0.008475    225.0
                 I-Org     0.000000  0.000000     0.000000    190.0
                 I-Other   0.000000  0.000000     0.000000    125.0
                 macro     0.121686  0.015206     0.026794      NaN
                 micro     0.338983  0.022805     0.042735      NaN
Entity embedding I-Loc     0.888889  0.055788     0.104987    717.0
                 I-Peop    1.000000  0.013871     0.027363    793.0
                 I-Org     0.000000  0.000000     0.000000    747.0
                 I-Other   0.000000  0.000000     0.000000    354.0
                 macro     0.472222  0.017415     0.033088      NaN
                 micro     0.910714  0.019533     0.038245      NaN


Train epoch 4: 100%|█████████████████████████████████████████████████████████████████| 910/910 [00:26<00:00, 33.89it/s]


epoch: 4 average loss: 0.21580167818855453


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 26.15it/s]
Train epoch 5:   0%|▎                                                                  | 4/910 [00:00<00:26, 34.19it/s]

                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.500000  0.121662     0.195704    337.0
                 I-Peop    0.137931  0.017778     0.031496    225.0
                 I-Org     0.000000  0.000000     0.000000    190.0
                 I-Other   0.000000  0.000000     0.000000    125.0
                 macro     0.159483  0.034860     0.056800      NaN
                 micro     0.391304  0.051311     0.090726      NaN
Entity embedding I-Loc     0.924051  0.101813     0.183417    717.0
                 I-Peop    1.000000  0.030265     0.058752    793.0
                 I-Org     0.000000  0.000000     0.000000    747.0
                 I-Other   1.000000  0.008475     0.016807    354.0
                 macro     0.731013  0.035138     0.064744      NaN
                 micro     0.934579  0.038300     0.073584      NaN


Train epoch 5: 100%|█████████████████████████████████████████████████████████████████| 910/910 [00:26<00:00, 34.04it/s]


epoch: 5 average loss: 0.20056984234642197


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 25.97it/s]
Train epoch 6:   0%|▎                                                                  | 4/910 [00:00<00:26, 33.91it/s]

                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.508929  0.169139     0.253898    337.0
                 I-Peop    0.307692  0.053333     0.090909    225.0
                 I-Org     0.000000  0.000000     0.000000    190.0
                 I-Other   0.000000  0.000000     0.000000    125.0
                 macro     0.204155  0.055618     0.086202      NaN
                 micro     0.436709  0.078677     0.133333      NaN
Entity embedding I-Loc     0.935780  0.142259     0.246973    717.0
                 I-Peop    1.000000  0.050441     0.096038    793.0
                 I-Org     0.500000  0.001339     0.002670    747.0
                 I-Other   1.000000  0.011299     0.022346    354.0
                 macro     0.858945  0.051335     0.092007      NaN
                 micro     0.948387  0.056300     0.106291      NaN


Train epoch 6: 100%|█████████████████████████████████████████████████████████████████| 910/910 [00:26<00:00, 33.77it/s]


epoch: 6 average loss: 0.18935942985526807


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 26.50it/s]
Train epoch 7:   0%|▎                                                                  | 4/910 [00:00<00:26, 34.19it/s]

                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.538462  0.207715     0.299786    337.0
                 I-Peop    0.333333  0.075556     0.123188    225.0
                 I-Org     0.000000  0.000000     0.000000    190.0
                 I-Other   0.090909  0.008000     0.014706    125.0
                 macro     0.240676  0.072818     0.109420      NaN
                 micro     0.448980  0.100342     0.164026      NaN
Entity embedding I-Loc     0.926230  0.157601     0.269368    717.0
                 I-Peop    1.000000  0.065574     0.123077    793.0
                 I-Org     0.666667  0.002677     0.005333    747.0
                 I-Other   0.888889  0.022599     0.044077    354.0
                 macro     0.870446  0.062113     0.110464      NaN
                 micro     0.940860  0.067024     0.125134      NaN


Train epoch 7: 100%|█████████████████████████████████████████████████████████████████| 910/910 [00:26<00:00, 33.87it/s]


epoch: 7 average loss: 0.18076251073034255


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 26.40it/s]
Train epoch 8:   0%|▎                                                                  | 4/910 [00:00<00:26, 33.89it/s]

                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.566667  0.252226     0.349076    337.0
                 I-Peop    0.327869  0.088889     0.139860    225.0
                 I-Org     0.363636  0.021053     0.039801    190.0
                 I-Other   0.058824  0.008000     0.014085    125.0
                 macro     0.329249  0.092542     0.135705      NaN
                 micro     0.460251  0.125428     0.197133      NaN
Entity embedding I-Loc     0.922535  0.182706     0.305006    717.0
                 I-Peop    1.000000  0.080706     0.149358    793.0
                 I-Org     0.818182  0.012048     0.023747    747.0
                 I-Other   0.846154  0.031073     0.059946    354.0
                 macro     0.896718  0.076633     0.134514      NaN
                 micro     0.934783  0.082344     0.151355      NaN


Train epoch 8: 100%|█████████████████████████████████████████████████████████████████| 910/910 [00:26<00:00, 33.87it/s]


epoch: 8 average loss: 0.17397573439220151


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 26.11it/s]
Train epoch 9:   0%|▎                                                                  | 4/910 [00:00<00:26, 33.90it/s]

                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.579268  0.281899     0.379242    337.0
                 I-Peop    0.378378  0.124444     0.187291    225.0
                 I-Org     0.347826  0.042105     0.075117    190.0
                 I-Other   0.100000  0.016000     0.027586    125.0
                 macro     0.351368  0.116112     0.167309      NaN
                 micro     0.473310  0.151653     0.229706      NaN
Entity embedding I-Loc     0.921569  0.196653     0.324138    717.0
                 I-Peop    1.000000  0.103405     0.187429    793.0
                 I-Org     0.800000  0.026774     0.051813    747.0
                 I-Other   0.875000  0.039548     0.075676    354.0
                 macro     0.899142  0.091595     0.159764      NaN
                 micro     0.931159  0.098430     0.178039      NaN


Train epoch 9: 100%|█████████████████████████████████████████████████████████████████| 910/910 [00:27<00:00, 33.52it/s]


epoch: 9 average loss: 0.16851184367724173


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 25.54it/s]
Train epoch 10:   0%|▏                                                                 | 3/910 [00:00<00:30, 29.41it/s]

                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.579545  0.302671     0.397661    337.0
                 I-Peop    0.406977  0.155556     0.225080    225.0
                 I-Org     0.322581  0.052632     0.090498    190.0
                 I-Other   0.160000  0.032000     0.053333    125.0
                 macro     0.367276  0.135714     0.191643      NaN
                 micro     0.474843  0.172178     0.252720      NaN
Entity embedding I-Loc     0.909639  0.210600     0.342016    717.0
                 I-Peop    1.000000  0.119798     0.213964    793.0
                 I-Org     0.806452  0.033467     0.064267    747.0
                 I-Other   0.894737  0.048023     0.091153    354.0
                 macro     0.902707  0.102972     0.177850      NaN
                 micro     0.926045  0.110303     0.197125      NaN


Train epoch 10: 100%|████████████████████████████████████████████████████████████████| 910/910 [00:27<00:00, 33.50it/s]


epoch: 10 average loss: 0.1640515614260029


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 25.49it/s]
Train epoch 11:   0%|▎                                                                 | 4/910 [00:00<00:26, 33.90it/s]

                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.588889  0.314540     0.410058    337.0
                 I-Peop    0.400000  0.168889     0.237500    225.0
                 I-Org     0.365854  0.078947     0.129870    190.0
                 I-Other   0.160000  0.032000     0.053333    125.0
                 macro     0.378686  0.148594     0.207690      NaN
                 micro     0.478006  0.185861     0.267652      NaN
Entity embedding I-Loc     0.902299  0.218968     0.352413    717.0
                 I-Peop    0.990654  0.133670     0.235556    793.0
                 I-Org     0.857143  0.048193     0.091255    747.0
                 I-Other   0.894737  0.048023     0.091153    354.0
                 macro     0.911208  0.112213     0.192594      NaN
                 micro     0.923977  0.121026     0.214020      NaN


Train epoch 11: 100%|████████████████████████████████████████████████████████████████| 910/910 [00:27<00:00, 33.36it/s]


epoch: 11 average loss: 0.16037974679781186


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 25.73it/s]
Train epoch 12:   0%|▎                                                                 | 4/910 [00:00<00:26, 34.19it/s]

                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.588542  0.335312     0.427221    337.0
                 I-Peop    0.407767  0.186667     0.256098    225.0
                 I-Org     0.357143  0.078947     0.129310    190.0
                 I-Other   0.185185  0.040000     0.065789    125.0
                 macro     0.384659  0.160231     0.219605      NaN
                 micro     0.480769  0.199544     0.282031      NaN
Entity embedding I-Loc     0.898396  0.234310     0.371681    717.0
                 I-Peop    0.991870  0.153846     0.266376    793.0
                 I-Org     0.863636  0.050870     0.096081    747.0
                 I-Other   0.909091  0.056497     0.106383    354.0
                 macro     0.915748  0.123881     0.210130      NaN
                 micro     0.925532  0.133282     0.233010      NaN


Train epoch 12: 100%|████████████████████████████████████████████████████████████████| 910/910 [00:27<00:00, 33.20it/s]


epoch: 12 average loss: 0.15734779380343772


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 25.60it/s]
Train epoch 13:   0%|▎                                                                 | 4/910 [00:00<00:28, 31.75it/s]

                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.582915  0.344214     0.432836    337.0
                 I-Peop    0.420561  0.200000     0.271084    225.0
                 I-Org     0.361702  0.089474     0.143460    190.0
                 I-Other   0.178571  0.040000     0.065359    125.0
                 macro     0.385937  0.168422     0.228185      NaN
                 micro     0.480315  0.208666     0.290938      NaN
Entity embedding I-Loc     0.885000  0.246862     0.386041    717.0
                 I-Peop    0.992481  0.166456     0.285097    793.0
                 I-Org     0.851064  0.053548     0.100756    747.0
                 I-Other   0.916667  0.062147     0.116402    354.0
                 macro     0.911303  0.132253     0.222074      NaN
                 micro     0.918317  0.142091     0.246103      NaN


Train epoch 13: 100%|████████████████████████████████████████████████████████████████| 910/910 [00:27<00:00, 33.63it/s]


epoch: 13 average loss: 0.15484899889145579


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 25.51it/s]
Train epoch 14:   0%|▎                                                                 | 4/910 [00:00<00:26, 34.19it/s]

                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.576355  0.347181     0.433333    337.0
                 I-Peop    0.424779  0.213333     0.284024    225.0
                 I-Org     0.365385  0.100000     0.157025    190.0
                 I-Other   0.206897  0.048000     0.077922    125.0
                 macro     0.393354  0.177129     0.238076      NaN
                 micro     0.478589  0.216648     0.298273      NaN
Entity embedding I-Loc     0.886700  0.251046     0.391304    717.0
                 I-Peop    0.993056  0.180328     0.305229    793.0
                 I-Org     0.862745  0.058902     0.110276    747.0
                 I-Other   0.920000  0.064972     0.121372    354.0
                 macro     0.915625  0.138812     0.232045      NaN
                 micro     0.921986  0.149368     0.257086      NaN


Train epoch 14: 100%|████████████████████████████████████████████████████████████████| 910/910 [00:27<00:00, 33.46it/s]


epoch: 14 average loss: 0.15280348984831638


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 26.06it/s]
Train epoch 15:   0%|▎                                                                 | 4/910 [00:00<00:28, 32.25it/s]

                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.574879  0.353116     0.437500    337.0
                 I-Peop    0.429752  0.231111     0.300578    225.0
                 I-Org     0.357143  0.105263     0.162602    190.0
                 I-Other   0.206897  0.048000     0.077922    125.0
                 macro     0.392168  0.184372     0.244650      NaN
                 micro     0.476998  0.224629     0.305426      NaN
Entity embedding I-Loc     0.887805  0.253835     0.394794    717.0
                 I-Peop    0.993590  0.195460     0.326660    793.0
                 I-Org     0.833333  0.060241     0.112360    747.0
                 I-Other   0.920000  0.064972     0.121372    354.0
                 macro     0.908682  0.143627     0.238796      NaN
                 micro     0.920455  0.155113     0.265487      NaN


Train epoch 15: 100%|████████████████████████████████████████████████████████████████| 910/910 [00:27<00:00, 33.39it/s]


epoch: 15 average loss: 0.1511511467233464


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 25.95it/s]
Train epoch 16:   0%|▎                                                                 | 4/910 [00:00<00:28, 32.26it/s]

                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.582160  0.367953     0.450909    337.0
                 I-Peop    0.434426  0.235556     0.305476    225.0
                 I-Org     0.350000  0.110526     0.168000    190.0
                 I-Other   0.206897  0.048000     0.077922    125.0
                 macro     0.393371  0.190509     0.250577      NaN
                 micro     0.481132  0.232611     0.313605      NaN
Entity embedding I-Loc     0.890476  0.260809     0.403452    717.0
                 I-Peop    0.993631  0.196721     0.328421    793.0
                 I-Org     0.842105  0.064257     0.119403    747.0
                 I-Other   0.920000  0.064972     0.121372    354.0
                 macro     0.911553  0.146690     0.243162      NaN
                 micro     0.922049  0.158560     0.270588      NaN


Train epoch 16: 100%|████████████████████████████████████████████████████████████████| 910/910 [00:27<00:00, 33.31it/s]


epoch: 16 average loss: 0.14984954985697846


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 25.55it/s]
Train epoch 17:   0%|▎                                                                 | 4/910 [00:00<00:28, 32.00it/s]

                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.582160  0.367953     0.450909    337.0
                 I-Peop    0.434426  0.235556     0.305476    225.0
                 I-Org     0.338710  0.110526     0.166667    190.0
                 I-Other   0.206897  0.048000     0.077922    125.0
                 macro     0.390548  0.190509     0.250243      NaN
                 micro     0.478873  0.232611     0.313124      NaN
Entity embedding I-Loc     0.890476  0.260809     0.403452    717.0
                 I-Peop    0.993631  0.196721     0.328421    793.0
                 I-Org     0.844828  0.065596     0.121739    747.0
                 I-Other   0.920000  0.064972     0.121372    354.0
                 macro     0.912234  0.147024     0.243746      NaN
                 micro     0.922222  0.158943     0.271153      NaN


Train epoch 17: 100%|████████████████████████████████████████████████████████████████| 910/910 [00:27<00:00, 33.34it/s]


epoch: 17 average loss: 0.14886570802101723


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 26.05it/s]
Train epoch 18:   0%|▎                                                                 | 4/910 [00:00<00:26, 33.89it/s]

                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.581395  0.370920     0.452899    337.0
                 I-Peop    0.434426  0.235556     0.305476    225.0
                 I-Org     0.333333  0.110526     0.166008    190.0
                 I-Other   0.206897  0.048000     0.077922    125.0
                 macro     0.389013  0.191250     0.250576      NaN
                 micro     0.477855  0.233751     0.313936      NaN
Entity embedding I-Loc     0.890995  0.262204     0.405172    717.0
                 I-Peop    0.993631  0.196721     0.328421    793.0
                 I-Org     0.844828  0.065596     0.121739    747.0
                 I-Other   0.916667  0.062147     0.116402    354.0
                 macro     0.911530  0.146667     0.242934      NaN
                 micro     0.922222  0.158943     0.271153      NaN


Train epoch 18: 100%|████████████████████████████████████████████████████████████████| 910/910 [00:27<00:00, 33.25it/s]


epoch: 18 average loss: 0.14817599429042785


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 25.43it/s]
Train epoch 19:   0%|▎                                                                 | 4/910 [00:00<00:26, 33.61it/s]

                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.584112  0.370920     0.453721    337.0
                 I-Peop    0.449153  0.235556     0.309038    225.0
                 I-Org     0.338710  0.110526     0.166667    190.0
                 I-Other   0.214286  0.048000     0.078431    125.0
                 macro     0.396565  0.191250     0.251964      NaN
                 micro     0.485782  0.233751     0.315627      NaN
Entity embedding I-Loc     0.890995  0.262204     0.405172    717.0
                 I-Peop    0.993631  0.196721     0.328421    793.0
                 I-Org     0.842105  0.064257     0.119403    747.0
                 I-Other   0.913043  0.059322     0.111406    354.0
                 macro     0.909944  0.145626     0.241101      NaN
                 micro     0.921875  0.158177     0.270023      NaN


Train epoch 19: 100%|████████████████████████████████████████████████████████████████| 910/910 [00:27<00:00, 33.29it/s]


epoch: 19 average loss: 0.14776028453648746


Evaluation dev: 100%|████████████████████████████████████████████████████████████████| 243/243 [00:09<00:00, 25.93it/s]


                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.588785  0.373887     0.457350    337.0
                 I-Peop    0.443478  0.226667     0.300000    225.0
                 I-Org     0.338710  0.110526     0.166667    190.0
                 I-Other   0.214286  0.048000     0.078431    125.0
                 macro     0.396315  0.189770     0.250612      NaN
                 micro     0.486874  0.232611     0.314815      NaN
Entity embedding I-Loc     0.891509  0.263598     0.406889    717.0
                 I-Peop    0.993506  0.192938     0.323126    793.0
                 I-Org     0.842105  0.064257     0.119403    747.0
                 I-Other   0.913043  0.059322     0.111406    354.0
                 macro     0.910041  0.145029     0.240206      NaN
                 micro     0.921525  0.157411     0.268891      NaN


In [16]:
config = BertConfig.from_pretrained(constants.model_path)
spert_model = model.SpERT.from_pretrained(constants.model_path,
                                          # SpERT model parameters
                                          relation_types=constants.relation_types, 
                                          entity_types=constants.entity_types, 
                                          width_embedding_size=constants.width_embedding_size, 
                                          prop_drop=constants.prop_drop, 
                                          freeze_transformer=True, 
                                          max_pairs=constants.max_pairs, 
                                          is_overlapping=False, 
                                          relation_filter_threshold=constants.relation_filter_threshold)
spert_model.to(device)
# state_dict = torch.load(model_save_path + "epoch_" + str(constants.epochs-1) + ".model")
state_dict = torch.load("../../model/spert/pytorch_model.bin")
state_dict["relation_classifier.weight"] = state_dict["rel_classifier.weight"]
state_dict["relation_classifier.bias"] = state_dict["rel_classifier.bias"]
state_dict["width_embedding.weight"] = state_dict["size_embeddings.weight"]
spert_model.load_state_dict(state_dict, strict=False)

Some weights of the model checkpoint at bert-base-cased were not used when initializing SpERT: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing SpERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SpERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SpERT were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['relation_classifier.weight', 'relation_classifier.bi

_IncompatibleKeys(missing_keys=['bert.embeddings.position_ids'], unexpected_keys=['rel_classifier.weight', 'rel_classifier.bias', 'size_embeddings.weight'])

In [17]:
evaluate(spert_model, "test")

Evaluation test: 100%|███████████████████████████████████████████████████████████████| 288/288 [00:12<00:00, 23.24it/s]


                          precision    recall  fbeta_score  support
Entity span      I-Loc     0.422505  0.932084     0.581446    427.0
                 I-Peop    0.000000  0.000000     0.000000    321.0
                 I-Org     0.001401  0.005051     0.002193    198.0
                 I-Other   0.531579  0.759398     0.625387    133.0
                 macro     0.238871  0.424133     0.302257      NaN
                 micro     0.226552  0.463392     0.304321      NaN
Entity embedding I-Loc     0.946203  0.929534     0.937794    965.0
                 I-Peop    0.004144  0.002671     0.003249   1123.0
                 I-Org     0.002655  0.004172     0.003245    719.0
                 I-Other   0.819767  0.834320     0.826979    338.0
                 macro     0.443192  0.442674     0.442817      NaN
                 micro     0.376669  0.376789     0.376729      NaN
