In [1]:
import os
import math
import random

In [2]:
import pandas as pd
import torch
import transformers
from transformers import AdamW, BertConfig, BertTokenizer

In [3]:
import model
import evaluator

In [4]:
# modify this part for different dataset
import internal_constants as constants
import internal_input_generator as input_generator

In [5]:
from tqdm import tqdm

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
# # To run reference model
# input_generator.parser.entity_encode = {'O': 0, 'B-Loc': 1, 'I-Loc': 1, 'B-Peop': 3, 'I-Peop': 3, 
#                                          'B-Org': 2, 'I-Org': 2, 'B-Other': 4, 'I-Other': 4}
# input_generator.parser.relation_encode = {'N': 0, 'Kill': 2, 'Located_In': 5, 'OrgBased_In': 3,
#                                            'Live_In': 4, 'Work_For': 1}

In [8]:
entity_label_map = {v: k for k, v in input_generator.parser.entity_encode.items()}
entity_classes = list(entity_label_map.keys())
entity_classes.remove(0)

In [9]:
relation_label_map = {v: k for k, v in input_generator.parser.relation_encode.items()}
relation_classes = list(relation_label_map.keys())
relation_classes.remove(0)

In [10]:
tokenizer = BertTokenizer.from_pretrained(constants.model_path)
input_generator.parser.tokenizer = tokenizer

In [11]:
def get_optimizer_params(model):
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_params = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         'weight_decay': constants.weight_decay},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 
         'weight_decay': 0.0}]
    return optimizer_params

In [12]:
def take_first_tokens(embedding, words):
    """Take the embedding of the first token of each word"""
    reduced_embedding = []
    for i, word in enumerate(words):
        if i == 0 or word != words[i-1]:
            reduced_embedding.append(embedding[i])
    return reduced_embedding

In [13]:
def evaluate(spert_model, group):
    spert_model.eval()
    eval_entity_span_pred = []
    eval_entity_span_true = []
    eval_entity_embedding_pred = []
    eval_entity_embedding_true = []
    eval_relation_span_pred = []
    eval_relation_span_true = []
    eval_generator = input_generator.data_generator(group, device, 
                                                    is_training=False,
                                                    neg_entity_count=constants.neg_entity_count, 
                                                    neg_relation_count=constants.neg_relation_count, 
                                                    max_span_size=constants.max_span_size)
    eval_dataset = list(eval_generator)
    eval_size = len(eval_dataset)
    for inputs, infos in tqdm(eval_dataset, total=eval_size, desc="Evaluation "+group):
        # forward
        outputs = spert_model(**inputs, is_training=False)
        # retrieve results for evaluation
        eval_entity_span_pred.append(outputs["entity"]["span"])
        eval_entity_span_true.append(infos["entity_span"])
        if not constants.is_overlapping:
            eval_entity_embedding_pred += take_first_tokens(outputs["entity"]["embedding"].tolist(), infos["words"])
            eval_entity_embedding_true += take_first_tokens(infos["entity_embedding"].tolist(), infos["words"])
            assert len(eval_entity_embedding_pred) == len(eval_entity_embedding_true)
        eval_relation_span_pred.append([] if outputs["relation"] == None else outputs["relation"]["span"])
        eval_relation_span_true.append(infos["relation_span"])
    # evaluate & save
    results = pd.concat([
        evaluator.evaluate_span(eval_entity_span_true, eval_entity_span_pred, entity_label_map, entity_classes),
        evaluator.evaluate_results(eval_entity_embedding_true, eval_entity_embedding_pred, entity_label_map, entity_classes),
        evaluator.evaluate_loose_relation_span(eval_relation_span_true, eval_relation_span_pred, relation_label_map, relation_classes),
        evaluator.evaluate_span(eval_relation_span_true, eval_relation_span_pred, relation_label_map, relation_classes),
    ], keys=["Entity span", "Entity embedding", "Loose relation", "Strict relation"])
    results.to_csv(constants.model_save_path + "evaluate_" + group + ".csv")
    print(results)

In [14]:
def train():
    """Train the model and evaluate on the dev dataset
    """
    # Training
    train_generator = input_generator.data_generator(constants.train_dataset, device, 
                                                     is_training=True,
                                                     neg_entity_count=constants.neg_entity_count, 
                                                     neg_relation_count=constants.neg_relation_count, 
                                                     max_span_size=constants.max_span_size)
    train_dataset = list(train_generator)
    random.shuffle(train_dataset)
    train_size = len(train_dataset)
    config = BertConfig.from_pretrained(constants.model_path)
    spert_model = model.SpERT.from_pretrained(constants.model_path,
                                              config=config,
                                              # SpERT model parameters
                                              relation_types=constants.relation_types, 
                                              entity_types=constants.entity_types, 
                                              width_embedding_size=constants.width_embedding_size, 
                                              prop_drop=constants.prop_drop, 
                                              freeze_transformer=False, 
                                              max_pairs=constants.max_pairs, 
                                              is_overlapping=constants.is_overlapping, 
                                              relation_filter_threshold=constants.relation_filter_threshold)
    spert_model.to(device)
    optimizer_params = get_optimizer_params(spert_model)
    optimizer = AdamW(optimizer_params, lr=constants.lr, weight_decay=constants.weight_decay, correct_bias=False)
    scheduler = transformers.get_linear_schedule_with_warmup(optimizer, 
                             num_warmup_steps=constants.lr_warmup*train_size*constants.epochs, 
                             num_training_steps=train_size*constants.epochs)
    for epoch in range(constants.epochs):
        losses = []
        entity_losses = []
        relation_losses = []
        train_entity_pred = []
        train_entity_true = []
        train_relation_pred = []
        train_relation_true = []
        spert_model.zero_grad()
        for inputs, infos in tqdm(train_dataset, total=train_size, desc='Train epoch %s' % epoch):
            spert_model.train()
            # forward
            outputs = spert_model(**inputs, is_training=True)
            # backward
            loss = outputs["loss"]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(spert_model.parameters(), constants.max_grad_norm)
            optimizer.step()
            scheduler.step()
            spert_model.zero_grad()
            # retrieve results for evaluation
            losses.append(loss.item())
            entity_losses.append(outputs["entity"]["loss"])
            if outputs["relation"] is not None: relation_losses.append(outputs["relation"]["loss"])
                
            train_entity_pred += outputs["entity"]["pred"].tolist()
            train_entity_true += inputs["entity_label"].tolist()
            train_relation_pred += [] if outputs["relation"] == None else outputs["relation"]["pred"].tolist()
            train_relation_true += inputs["relation_label"].tolist()
            assert len(train_entity_pred) == len(train_entity_true)
            assert len(train_relation_pred) == len(train_relation_true)
            
        # evaluate & save checkpoint
        print("epoch:", epoch,"average loss:", sum(losses) / len(losses))
        print("epoch:", epoch,"average entity loss:", sum(entity_losses) / len(entity_losses))
        print("epoch:", epoch,"average relation loss:", sum(relation_losses) / len(relation_losses))
        results = pd.concat([
            evaluator.evaluate_results(train_entity_true, train_entity_pred, entity_label_map, entity_classes),
            evaluator.evaluate_results(train_relation_true, train_relation_pred, relation_label_map, relation_classes)
        ], keys=["Entity", "Relation"])
        results.to_csv(constants.model_save_path + "epoch_" + str(epoch) + ".csv")
#         evaluate(spert_model, constants.dev_dataset)
        
    torch.save(spert_model.state_dict(), constants.model_save_path + "epoch_" + str(constants.epochs-1) + ".model")

In [15]:
def predict(spert_model, sentences):
    for sentence in sentences:
        word_list = sentence.split()
        words = []
        token_ids = []
        # transform a sentence to a document for prediction
        for word in word_list:
            token_id = tokenizer(word)["input_ids"][1:-1]
            for tid in token_id:
                words.append(word)
                token_ids.append(tid)
        data_frame = pd.DataFrame()
        data_frame["words"] = words
        data_frame["token_ids"] = token_ids
        data_frame["entity_embedding"] = 0
        data_frame["sentence_embedding"] = 0 # for internal datasets
        doc = {"data_frame": data_frame,
            "entity_position": {}, # Suppose to appear in non-overlapping dataset
            "entities": {}, # Suppose to appear in overlapping dataset
            "relations": {}}
        # predict
        inputs, infos = input_generator.doc_to_input(doc, device, 
                                                     is_training=False, 
                                                     max_span_size=constants.max_span_size)
        outputs = spert_model(**inputs, is_training=False)
        pred_entity_span = outputs["entity"]["span"]
        pred_relation_span = [] if outputs["relation"] is None else outputs["relation"]["span"]
        # print result
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        print("Sentence:", sentence)
        print("Entities: (", len(pred_entity_span), ")")
        for begin, end, entity_type in pred_entity_span:
            print(entity_label_map[entity_type], "|", " ".join(tokens[begin:end]))
        print("Relations: (", len(pred_relation_span), ")")
        for e1, e2, relation_type in pred_relation_span:
            print(relation_label_map[relation_type], "|", 
                  " ".join(tokens[e1[0]:e1[1]]), "|", 
                  " ".join(tokens[e2[0]:e2[1]]))

In [16]:
train()

Some weights of the model checkpoint at bert-base-cased were not used when initializing SpERT: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing SpERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SpERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SpERT were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['relation_classifier.weight', 'relation_classifier.bi

epoch: 0 average loss: 0.15172028699474469
epoch: 0 average entity loss: 0.11030006967447699
epoch: 0 average relation loss: 0.06083981858865757


Train epoch 1: 100%|███████████████████████████████████████████████████████████████| 5398/5398 [15:29<00:00,  5.81it/s]


epoch: 1 average loss: 0.04332021641778717
epoch: 1 average entity loss: 0.03438879502500655
epoch: 1 average relation loss: 0.0131188606501313


Train epoch 2: 100%|███████████████████████████████████████████████████████████████| 5398/5398 [16:06<00:00,  5.58it/s]


epoch: 2 average loss: 0.034399874909490376
epoch: 2 average entity loss: 0.027926940694261618
epoch: 2 average relation loss: 0.00950772760387502


Train epoch 3: 100%|███████████████████████████████████████████████████████████████| 5398/5398 [16:05<00:00,  5.59it/s]


epoch: 3 average loss: 0.025558522234761058
epoch: 3 average entity loss: 0.020832276425999167
epoch: 3 average relation loss: 0.006942115637433896


Train epoch 4: 100%|███████████████████████████████████████████████████████████████| 5398/5398 [15:50<00:00,  5.68it/s]


epoch: 4 average loss: 0.02212322045844141
epoch: 4 average entity loss: 0.018446291719866816
epoch: 4 average relation loss: 0.005400832960524757


Train epoch 5: 100%|███████████████████████████████████████████████████████████████| 5398/5398 [15:52<00:00,  5.67it/s]


epoch: 5 average loss: 0.018231513675384777
epoch: 5 average entity loss: 0.01526657045908748
epoch: 5 average relation loss: 0.004355037675175397


Train epoch 6: 100%|███████████████████████████████████████████████████████████████| 5398/5398 [15:49<00:00,  5.68it/s]


epoch: 6 average loss: 0.01475106143597573
epoch: 6 average entity loss: 0.01245381156459537
epoch: 6 average relation loss: 0.0033743006243261995


Train epoch 7: 100%|███████████████████████████████████████████████████████████████| 5398/5398 [16:10<00:00,  5.56it/s]


epoch: 7 average loss: 0.012405536330103611
epoch: 7 average entity loss: 0.010450324473934548
epoch: 7 average relation loss: 0.002871900289029747


Train epoch 8: 100%|███████████████████████████████████████████████████████████████| 5398/5398 [15:32<00:00,  5.79it/s]


epoch: 8 average loss: 0.01064019513968736
epoch: 8 average entity loss: 0.009081379445582256
epoch: 8 average relation loss: 0.002289656366415395


Train epoch 9: 100%|███████████████████████████████████████████████████████████████| 5398/5398 [16:17<00:00,  5.52it/s]


epoch: 9 average loss: 0.009359292747585762
epoch: 9 average entity loss: 0.007928729949430722
epoch: 9 average relation loss: 0.0021012729079924026


Train epoch 10: 100%|██████████████████████████████████████████████████████████████| 5398/5398 [16:07<00:00,  5.58it/s]


epoch: 10 average loss: 0.007696885762210025
epoch: 10 average entity loss: 0.006639368876114192
epoch: 10 average relation loss: 0.0015533268587273692


Train epoch 11: 100%|██████████████████████████████████████████████████████████████| 5398/5398 [15:05<00:00,  5.96it/s]


epoch: 11 average loss: 0.0069388687949410305
epoch: 11 average entity loss: 0.006087739462728491
epoch: 11 average relation loss: 0.0012501758072963225


Train epoch 12: 100%|██████████████████████████████████████████████████████████████| 5398/5398 [15:15<00:00,  5.89it/s]


epoch: 12 average loss: 0.006004213424350307
epoch: 12 average entity loss: 0.00539060518995347
epoch: 12 average relation loss: 0.0009012945136238629


Train epoch 13: 100%|██████████████████████████████████████████████████████████████| 5398/5398 [16:00<00:00,  5.62it/s]


epoch: 13 average loss: 0.004885807245881634
epoch: 13 average entity loss: 0.004372331818240096
epoch: 13 average relation loss: 0.0007542150425917376


Train epoch 14: 100%|██████████████████████████████████████████████████████████████| 5398/5398 [16:08<00:00,  5.57it/s]


epoch: 14 average loss: 0.004280994186264318
epoch: 14 average entity loss: 0.0038699190061368266
epoch: 14 average relation loss: 0.000603805116150665


Train epoch 15: 100%|██████████████████████████████████████████████████████████████| 5398/5398 [14:55<00:00,  6.03it/s]


epoch: 15 average loss: 0.0036033403589032778
epoch: 15 average entity loss: 0.003332561989033275
epoch: 15 average relation loss: 0.0003977310672201115


Train epoch 16: 100%|██████████████████████████████████████████████████████████████| 5398/5398 [15:30<00:00,  5.80it/s]


epoch: 16 average loss: 0.003086374234258742
epoch: 16 average entity loss: 0.0028721957900093927
epoch: 16 average relation loss: 0.00031459461433918506


Train epoch 17: 100%|██████████████████████████████████████████████████████████████| 5398/5398 [15:58<00:00,  5.63it/s]


epoch: 17 average loss: 0.0026420038069134505
epoch: 17 average entity loss: 0.0025029194141867905
epoch: 17 average relation loss: 0.00020429322236625635


Train epoch 18: 100%|██████████████████████████████████████████████████████████████| 5398/5398 [15:56<00:00,  5.65it/s]


epoch: 18 average loss: 0.002256052722454975
epoch: 18 average entity loss: 0.00215244394547915
epoch: 18 average relation loss: 0.00015218510244380417


Train epoch 19: 100%|██████████████████████████████████████████████████████████████| 5398/5398 [15:52<00:00,  5.67it/s]


epoch: 19 average loss: 0.0019492959071023047
epoch: 19 average entity loss: 0.0018653828046246066
epoch: 19 average relation loss: 0.0001232552027671597


In [17]:
config = BertConfig.from_pretrained(constants.model_path)
spert_model = model.SpERT.from_pretrained(constants.model_path,
                                          config=config,
                                          # SpERT model parameters
                                          relation_types=constants.relation_types, 
                                          entity_types=constants.entity_types, 
                                          width_embedding_size=constants.width_embedding_size, 
                                          prop_drop=constants.prop_drop, 
                                          freeze_transformer=True, 
                                          max_pairs=constants.max_pairs, 
                                          is_overlapping=constants.is_overlapping, 
                                          relation_filter_threshold=constants.relation_filter_threshold)
spert_model.to(device)

Some weights of the model checkpoint at bert-base-cased were not used when initializing SpERT: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing SpERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SpERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SpERT were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['relation_classifier.weight', 'relation_classifier.bi

SpERT(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      

In [18]:
state_dict = torch.load(constants.model_save_path + "epoch_" + str(constants.epochs-1) + ".model", map_location=device)

In [19]:
# # To run reference model
# state_dict = torch.load("../../model/spert/reference/reference_model.bin")
# state_dict["relation_classifier.weight"] = state_dict["rel_classifier.weight"]
# state_dict["relation_classifier.bias"] = state_dict["rel_classifier.bias"]
# state_dict["width_embedding.weight"] = state_dict["size_embeddings.weight"]

In [20]:
spert_model.load_state_dict(state_dict, strict=False)

<All keys matched successfully>

In [21]:
evaluate(spert_model, constants.test_dataset)

Evaluation Test: 100%|█████████████████████████████████████████████████████████████| 1738/1738 [07:05<00:00,  4.09it/s]


                                      precision    recall  fbeta_score  \
Entity span      EnvironmentalIssues   0.717712  0.798222     0.755829   
                 Date                  0.921212  0.912913     0.917044   
                 Organisation          0.872390  0.619951     0.724819   
                 CommitmentLevel       0.401416  0.541610     0.461092   
                 Location              0.618644  0.782143     0.690852   
                 CoalActivity          0.695652  0.727273     0.711111   
                 SocialIssues          0.793369  0.853942     0.822542   
                 SocialOfficialTexts   0.592593  0.744186     0.659794   
                 macro                 0.701623  0.747530     0.717885   
                 micro                 0.707498  0.740125     0.723444   
Entity embedding EnvironmentalIssues   0.796259  0.906481     0.847803   
                 Date                  0.952830  0.978208     0.965352   
                 Organisation         

In [22]:
# state_dict = torch.load(constants.model_save_path + "epoch_" + str(constants.epochs-1) + ".model")
# state_dict["rel_classifier.weight"] = state_dict["relation_classifier.weight"]
# state_dict["rel_classifier.bias"] = state_dict["relation_classifier.bias"]
# state_dict["size_embeddings.weight"] = state_dict["width_embedding.weight"]
# del state_dict["relation_classifier.weight"]
# del state_dict["relation_classifier.bias"]
# del state_dict["width_embedding.weight"]
# del state_dict["bert.embeddings.position_ids"]
# torch.save(state_dict, "../../model/spert/pytorch_model.bin")

In [23]:
predict(spert_model, sentences=["However, the Rev. Jesse Jackson, a native of South Carolina, joined critics of FEMA's effort.", 
                                "International Paper spokeswoman Ann Silvernail said that under French law the company was barred from releasing details pending government approval."])

Sentence: However, the Rev. Jesse Jackson, a native of South Carolina, joined critics of FEMA's effort.
Entities: ( 0 )
Relations: ( 0 )
Sentence: International Paper spokeswoman Ann Silvernail said that under French law the company was barred from releasing details pending government approval.
Entities: ( 1 )
Organisation | company
Relations: ( 0 )


In [24]:
# test_sentence = " ".join(["An", "art", "exhibit", "at", "the", "Hakawati", "Theatre", "in", "Arab", "east", "Jerusalem", "was", "a", "series", "of", "portraits", "of", "Palestinians", "killed", "in", "the", "rebellion", "."])

In [25]:
# "entities": [{"type": "Loc", "start": 5, "end": 7}, {"type": "Loc", "start": 10, "end": 11}, {"type": "Other", "start": 17, "end": 18}], "relations": [{"type": "Located_In", "head": 0, "tail": 1}]

In [26]:
# predict(spert_model, sentences=[test_sentence])

In [27]:
# test_sentence = " ".join(["PERUGIA", ",", "Italy", "(", "AP", ")"])

In [28]:
# predict(spert_model, sentences=[test_sentence])