In [1]:
import os
import math
import random

In [2]:
import pandas as pd
pd.set_option('expand_frame_repr', False)
import matplotlib.pyplot as plt
import torch
import transformers
from transformers import AdamW, BertConfig, BertTokenizer

In [3]:
import model
import evaluator

In [4]:
# modify this part for different dataset
import internal_constants as constants
import internal_input_generator as input_generator

In [5]:
from tqdm import tqdm

In [6]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
# # To run reference model
# input_generator.parser.entity_encode = {'O': 0, 'B-Loc': 1, 'I-Loc': 1, 'B-Peop': 3, 'I-Peop': 3, 
#                                          'B-Org': 2, 'I-Org': 2, 'B-Other': 4, 'I-Other': 4}
# input_generator.parser.relation_encode = {'N': 0, 'Kill': 2, 'Located_In': 5, 'OrgBased_In': 3,
#                                            'Live_In': 4, 'Work_For': 1}

In [8]:
# Invert the parser's label -> id encoding so that ids can be decoded back into label names.
# If several labels share an id, the last one listed in the encoding wins.
entity_label_map = dict(
    (code, label) for label, code in input_generator.parser.entity_encode.items()
)
# Entity class ids handed to the evaluator: every id except 0
# (presumably the "no entity" class -- confirm against the parser's encoding).
entity_classes = [*entity_label_map]
entity_classes.remove(0)

In [9]:
# Invert the parser's relation label -> id encoding so that ids can be decoded into names.
# If several labels share an id, the last one listed in the encoding wins.
relation_label_map = dict(
    (code, label) for label, code in input_generator.parser.relation_encode.items()
)
# Relation class ids handed to the evaluator: every id except 0
# (presumably the "no relation" class -- confirm against the parser's encoding).
relation_classes = [*relation_label_map]
relation_classes.remove(0)

In [10]:
# for internal dataset
# Keys are pairs of entity type ids; each value is a 0/1 vector with one slot per relation id
# (presumably marking which relation types are allowed for that pair -- confirm in model.SpERT).
relation_possibility = {
    (3, 5): [0, 0, 0, 1, 0, 0, 0, 0], # Organization + Location -> IsRelatedTo
    (3, 6): [0, 0, 0, 0, 1, 0, 0, 0], # Organization + CoalActivity -> HasActivity
    (3, 8): [0, 0, 0, 0, 0, 1, 0, 0], # Organization + SocialOfficialText -> Recognizes
    (3, 4): [0, 1, 0, 0, 0, 0, 0, 0], # Organization + CommitmentLevel -> Makes
    (4, 1): [0, 0, 1, 0, 0, 0, 0, 0], # CommitmentLevel + EnvironmentalIssues -> Of
    (4, 7): [0, 0, 1, 0, 0, 0, 0, 0]  # CommitmentLevel + SocialIssues -> Of
}

# Drop relation ids that should not be scored: 6 (In) and 7 (IsInvolvedIn).
# The membership check keeps this cell idempotent: `relation_classes` is created in an earlier
# cell, so an unconditional `.remove()` raises ValueError when this cell is run a second time.
for unscored_relation in (6, 7):
    if unscored_relation in relation_classes:
        relation_classes.remove(unscored_relation)

In [11]:
# # for conll04
# relation_possibility = {
#     (2, 2): [0, 1, 0, 0, 0, 0], # people kill people
#     (1, 1): [0, 0, 1, 0, 0, 0], # location located in location
#     (3, 1): [0, 0, 0, 1, 0, 0], # organization based in location
#     (2, 1): [0, 0, 0, 0, 1, 0], # people live in location
#     (2, 3): [0, 0, 0, 0, 0, 1]  # people work for organization
# }

In [12]:
# Load the tokenizer that matches the pretrained model and share the same instance with the
# input parser, so that the notebook and the parser tokenize identically.
tokenizer = input_generator.parser.tokenizer = BertTokenizer.from_pretrained(constants.model_path)

In [13]:
def get_optimizer_params(model):
    """Split the parameters of `model` into two optimizer parameter groups.

    A parameter whose name contains 'bias', 'LayerNorm.bias' or 'LayerNorm.weight' goes into
    the second group, which uses a weight decay of 0.0; every other parameter goes into the
    first group, which uses `constants.weight_decay`.

    Args:
        model: object exposing `named_parameters()` yielding (name, parameter) pairs.

    Returns:
        A list of two dicts with keys 'params' and 'weight_decay', decayed group first.
    """
    no_decay_markers = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')
    decayed = []
    undecayed = []
    for name, param in model.named_parameters():
        # Substring match on the parameter name decides the group.
        if any(marker in name for marker in no_decay_markers):
            undecayed.append(param)
        else:
            decayed.append(param)
    return [
        {'params': decayed, 'weight_decay': constants.weight_decay},
        {'params': undecayed, 'weight_decay': 0.0},
    ]

In [14]:
def take_first_tokens(embedding, words):
    """Keep only the embedding entry of the first token of each word.

    `words` holds one entry per token (the word that token belongs to). A position is kept
    when it is the first position or when its word differs from the previous position's word.
    Note that two adjacent occurrences of the same word are therefore treated as one word.

    Args:
        embedding: indexable sequence aligned with `words`.
        words: sequence of per-token words.

    Returns:
        A list with the selected entries of `embedding`, in order.
    """
    return [
        embedding[position]
        for position, current_word in enumerate(words)
        if position == 0 or current_word != words[position - 1]
    ]

In [15]:
def evaluate(spert_model, group, suffix="", compute_loss=False):
    """Evaluate `spert_model` on the dataset identified by `group`.

    The function has two modes, selected by `compute_loss`:

    * ``compute_loss=True``: inputs are generated and the model is called with
      ``is_training=True`` so that losses are returned, but nothing is back-propagated.
      Useful for reporting losses on a held-out dataset. The average total, entity and
      relation losses are printed and returned.
    * ``compute_loss=False``: predicted spans are collected, scored with `evaluator`,
      printed, and written to ``<constants.model_save_path>evaluate_<group>_<suffix>.csv``.

    Args:
        spert_model: the model to evaluate; it is switched to eval mode by this call.
        group: dataset identifier passed to `input_generator.data_generator`; also used in
            the progress-bar label and in the CSV file name.
        suffix: extra tag appended to the CSV file name.
        compute_loss: selects the mode described above.

    Returns:
        ``(loss, entity_loss, relation_loss)`` averages when `compute_loss` is True, otherwise
        None. An average over zero values is reported as NaN.
    """
    def _mean(values):
        # A list can legitimately be empty (e.g. no sample produced a relation loss);
        # report NaN instead of raising ZeroDivisionError.
        return sum(values) / len(values) if values else float("nan")

    spert_model.eval()
    eval_generator = input_generator.data_generator(group, device,
                                                    is_training=compute_loss,
                                                    neg_entity_count=constants.neg_entity_count,
                                                    neg_relation_count=constants.neg_relation_count,
                                                    max_span_size=constants.max_span_size)
    eval_dataset = list(eval_generator)
    eval_size = len(eval_dataset)

    # Accumulators for the loss mode.
    entity_losses = []
    relation_losses = []
    losses = []
    # Accumulators for the metric mode.
    eval_entity_span_pred = []
    eval_entity_span_true = []
    eval_entity_embedding_pred = []
    eval_entity_embedding_true = []
    eval_relation_span_pred = []
    eval_relation_span_true = []

    # Nothing is back-propagated here, so skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for inputs, infos in tqdm(eval_dataset, total=eval_size, desc="Evaluation "+group):
            # forward
            outputs = spert_model(**inputs, is_training=compute_loss)

            if compute_loss:
                losses.append(outputs["loss"].item())
                entity_losses.append(outputs["entity"]["loss"])
                # The model returns no relation output for some samples.
                if outputs["relation"] is not None:
                    relation_losses.append(outputs["relation"]["loss"])
            else:
                # retrieve results for evaluation
                eval_entity_span_pred.append(outputs["entity"]["span"])
                eval_entity_span_true.append(infos["entity_span"])

                if not constants.is_overlapping:
                    # Token-level labels are reduced to one label per word before scoring.
                    eval_entity_embedding_pred += take_first_tokens(outputs["entity"]["embedding"].tolist(), infos["words"])
                    eval_entity_embedding_true += take_first_tokens(infos["entity_embedding"].tolist(), infos["words"])
                    assert len(eval_entity_embedding_pred) == len(eval_entity_embedding_true)

                eval_relation_span_pred.append([] if outputs["relation"] is None else outputs["relation"]["span"])
                eval_relation_span_true.append(infos["relation_span"])

    if compute_loss:
        average_loss = _mean(losses)
        average_entity_loss = _mean(entity_losses)
        average_relation_loss = _mean(relation_losses)
        print("average evaluation loss:", average_loss)
        print("average evaluation entity loss:", average_entity_loss)
        print("average evaluation relation loss:", average_relation_loss)
        return average_loss, average_entity_loss, average_relation_loss
    else:
        # evaluate & save
        results = pd.concat([
            evaluator.evaluate_span(eval_entity_span_true, eval_entity_span_pred, entity_label_map, entity_classes),
            evaluator.evaluate_results(eval_entity_embedding_true, eval_entity_embedding_pred, entity_label_map, entity_classes),
            evaluator.evaluate_loose_relation_span(eval_relation_span_true, eval_relation_span_pred, relation_label_map, relation_classes),
            evaluator.evaluate_span(eval_relation_span_true, eval_relation_span_pred, relation_label_map, relation_classes),
        ], keys=["Entity span", "Entity embedding", "Loose relation", "Strict relation"])
        results.to_csv(constants.model_save_path + "evaluate_" + group + "_" + suffix + ".csv")
        print("Evaluation result on " + group + " dataset")
        print(results)

In [16]:
def train():
    """Train a SpERT model and evaluate it on the dev dataset after every epoch.

    Per epoch this function:

    * performs one optimizer/scheduler step per item of the training data,
    * prints the average total, entity and relation training losses,
    * writes training metrics to ``<constants.model_save_path>epoch_<n>.csv``,
    * evaluates on `constants.dev_dataset` (losses, then span metrics),
    * saves the weights to ``<constants.model_save_path>epoch_<n>.model``.

    Training stops early once the dev loss has failed to reach a new minimum for more than
    `constants.patience` consecutive epochs. Finally the train and dev curves of the
    total, entity and relation losses are plotted.

    An average over zero values (e.g. an epoch in which no sample produced a relation loss)
    is reported as NaN.
    """
    def _mean(values):
        # Guard against an empty list instead of raising ZeroDivisionError.
        return sum(values) / len(values) if values else float("nan")

    min_loss = math.inf
    wait_time = 0
    train_losses = []
    train_entity_losses = []
    train_relation_losses = []
    eval_losses = []
    eval_entity_losses = []
    eval_relation_losses = []

    # Training
    train_generator = input_generator.data_generator(constants.train_dataset, device,
                                                     is_training=True,
                                                     neg_entity_count=constants.neg_entity_count,
                                                     neg_relation_count=constants.neg_relation_count,
                                                     max_span_size=constants.max_span_size)
    # NOTE(review): the generator is materialised and shuffled once, so every epoch iterates
    # over the same items in the same order.
    train_dataset = list(train_generator)
    random.shuffle(train_dataset)
    train_size = len(train_dataset)
    config = BertConfig.from_pretrained(constants.model_path)
    spert_model = model.SpERT.from_pretrained(constants.model_path,
                                              config=config,
                                              # SpERT model parameters
                                              relation_types=constants.relation_types,
                                              entity_types=constants.entity_types,
                                              width_embedding_size=constants.width_embedding_size,
                                              prop_drop=constants.prop_drop,
                                              freeze_transformer=False,
                                              max_pairs=constants.max_pairs,
                                              is_overlapping=constants.is_overlapping,
                                              relation_filter_threshold=constants.relation_filter_threshold,
                                              relation_possibility=relation_possibility)
    spert_model.to(device)
    optimizer_params = get_optimizer_params(spert_model)
    optimizer = AdamW(optimizer_params, lr=constants.lr, weight_decay=constants.weight_decay, correct_bias=False)
    # One scheduler step is taken per training item, so the schedule spans train_size * epochs
    # steps; `constants.lr_warmup` is the fraction of those steps used for warm-up.
    scheduler = transformers.get_linear_schedule_with_warmup(optimizer,
                             num_warmup_steps=constants.lr_warmup*train_size*constants.epochs,
                             num_training_steps=train_size*constants.epochs)

    for epoch in range(constants.epochs):
        losses = []
        entity_losses = []
        relation_losses = []
        train_entity_pred = []
        train_entity_true = []
        train_relation_pred = []
        train_relation_true = []
        spert_model.zero_grad()
        # evaluate() leaves the model in eval mode at the end of each epoch; switch back once
        # per epoch rather than on every step.
        spert_model.train()
        for inputs, infos in tqdm(train_dataset, total=train_size, desc='Train epoch %s' % epoch):
            # forward
            outputs = spert_model(**inputs, is_training=True)

            # backward
            loss = outputs["loss"]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(spert_model.parameters(), constants.max_grad_norm)
            optimizer.step()
            scheduler.step()
            spert_model.zero_grad()

            # retrieve results for evaluation
            losses.append(loss.item())
            entity_losses.append(outputs["entity"]["loss"])
            # The model returns no relation output for some samples.
            if outputs["relation"] is not None:
                relation_losses.append(outputs["relation"]["loss"])

            train_entity_pred += outputs["entity"]["pred"].tolist()
            train_entity_true += inputs["entity_label"].tolist()
            train_relation_pred += [] if outputs["relation"] is None else outputs["relation"]["pred"].tolist()
            train_relation_true += inputs["relation_label"].tolist()
            assert len(train_entity_pred) == len(train_entity_true)
            assert len(train_relation_pred) == len(train_relation_true)

        # evaluate & save checkpoint
        epoch_loss = _mean(losses)
        epoch_entity_loss = _mean(entity_losses)
        epoch_relation_loss = _mean(relation_losses)
        print("epoch:", epoch,"average training loss:", epoch_loss)
        print("epoch:", epoch,"average training entity loss:", epoch_entity_loss)
        print("epoch:", epoch,"average training relation loss:", epoch_relation_loss)
        results = pd.concat([
            evaluator.evaluate_results(train_entity_true, train_entity_pred, entity_label_map, entity_classes),
            evaluator.evaluate_results(train_relation_true, train_relation_pred, relation_label_map, relation_classes)
        ], keys=["Entity", "Relation"])
        results.to_csv(constants.model_save_path + "epoch_" + str(epoch) + ".csv")

        # evaluation on dev dataset: first the losses, then the span metrics
        eval_loss, eval_entity_loss, eval_relation_loss = evaluate(spert_model, constants.dev_dataset, compute_loss=True)
        evaluate(spert_model, constants.dev_dataset, suffix=str(epoch))

        # record losses
        train_losses.append(epoch_loss)
        train_entity_losses.append(epoch_entity_loss)
        train_relation_losses.append(epoch_relation_loss)
        eval_losses.append(eval_loss)
        eval_entity_losses.append(eval_entity_loss)
        eval_relation_losses.append(eval_relation_loss)

        torch.save(spert_model.state_dict(), constants.model_save_path + "epoch_" + str(epoch) + ".model")

        # early stopping: count consecutive epochs without a new minimum dev loss
        if eval_loss < min_loss:
            min_loss = eval_loss
            wait_time = 0
        else:
            wait_time += 1
        if wait_time > constants.patience:
            print("Early stopping at epoch", epoch)
            break

    # One figure per loss: training curve in red, dev curve in blue.
    loss_curves = (
        ("Total loss", train_losses, eval_losses),
        ("Entity loss", train_entity_losses, eval_entity_losses),
        ("Relation loss", train_relation_losses, eval_relation_losses),
    )
    for title, train_curve, eval_curve in loss_curves:
        plt.plot(list(range(len(train_curve))), train_curve, 'r', label="train")
        plt.plot(list(range(len(eval_curve))), eval_curve, 'b', label="dev")
        plt.title(title)
        plt.xlabel("epoch")
        plt.ylabel("average loss")
        plt.legend()
        plt.show()

In [17]:
def predict(spert_model, sentences):
    """Run `spert_model` on raw sentences and print the predicted entities and relations.

    Each sentence is split on whitespace and every word is tokenized on its own. A minimal
    document without gold annotations is built from the result and converted to model inputs
    with `input_generator.doc_to_input`.

    Args:
        spert_model: the model used for prediction. It is switched to eval mode by this call
            so that predictions do not depend on the mode the caller left it in.
        sentences: iterable of sentence strings.

    Returns:
        None; results are printed.
    """
    spert_model.eval()
    for sentence in sentences:
        word_list = sentence.split()
        words = []
        token_ids = []
        # transform a sentence to a document for prediction
        for word in word_list:
            # [1:-1] drops the special tokens the tokenizer puts around the word.
            token_id = tokenizer(word)["input_ids"][1:-1]
            # One row per sub-word token, each remembering the word it came from.
            for tid in token_id:
                words.append(word)
                token_ids.append(tid)
        data_frame = pd.DataFrame()
        data_frame["words"] = words
        data_frame["token_ids"] = token_ids
        # No gold labels are available at prediction time; fill the columns with 0.
        data_frame["entity_embedding"] = 0
        data_frame["sentence_embedding"] = 0 # for internal datasets
        doc = {"data_frame": data_frame,
            "entity_position": {}, # Supposed to appear in non-overlapping datasets
            "entities": {}, # Supposed to appear in overlapping datasets
            "relations": {}}
        # predict
        inputs, infos = input_generator.doc_to_input(doc, device,
                                                     is_training=False,
                                                     max_span_size=constants.max_span_size)
        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = spert_model(**inputs, is_training=False)
        pred_entity_span = outputs["entity"]["span"]
        pred_relation_span = [] if outputs["relation"] is None else outputs["relation"]["span"]
        # print result
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        print("Sentence:", sentence)
        print("Entities: (", len(pred_entity_span), ")")
        for begin, end, entity_type in pred_entity_span:
            print(entity_label_map[entity_type], "|", " ".join(tokens[begin:end]))
        print("Relations: (", len(pred_relation_span), ")")
        for e1, e2, relation_type in pred_relation_span:
            print(relation_label_map[relation_type], "|",
                  " ".join(tokens[e1[0]:e1[1]]), "|",
                  " ".join(tokens[e2[0]:e2[1]]))

In [None]:
# Run the full training loop. Per-epoch metrics (.csv) and weights (.model) are written
# using the constants.model_save_path prefix.
train()

Some weights of the model checkpoint at bert-base-cased were not used when initializing SpERT: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing SpERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SpERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SpERT were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['relation_classifier.weight', 'relation_classifier.bi

epoch: 0 average training loss: 0.4843374995312501
epoch: 0 average training entity loss: 0.3458811406296279
epoch: 0 average training relation loss: 0.20337072730089734


Evaluation Test: 100%|█████████████████████████████████████████████████████████████| 1738/1738 [03:46<00:00,  7.68it/s]


average evaluation loss: 0.07585260456281107
average evaluation entity loss: 0.05058491181381801
average evaluation relation loss: 0.03869185011304652


Evaluation Test:  94%|█████████████████████████████████████████████████████████▏   | 1628/1738 [06:22<00:34,  3.19it/s]

In [None]:
# Build a fresh SpERT model for evaluation. The arguments mirror the ones used in train(),
# except that freeze_transformer is True here. The trained weights are loaded in later cells.
config = BertConfig.from_pretrained(constants.model_path)
spert_model = model.SpERT.from_pretrained(constants.model_path,
                                          config=config,
                                          # SpERT model parameters
                                          relation_types=constants.relation_types, 
                                          entity_types=constants.entity_types, 
                                          width_embedding_size=constants.width_embedding_size, 
                                          prop_drop=constants.prop_drop, 
                                          freeze_transformer=True, 
                                          max_pairs=constants.max_pairs, 
                                          is_overlapping=constants.is_overlapping, 
                                          relation_filter_threshold=constants.relation_filter_threshold,
                                          relation_possibility=relation_possibility)
# Move the model to the selected device; as the last expression, its repr is the cell output.
spert_model.to(device)

In [None]:
# Epoch whose checkpoint is evaluated. train() writes one "<prefix>epoch_<n>.model" file per
# completed epoch, so this must name an epoch that was actually reached.
checkpoint_epoch = 29
checkpoint_path = constants.model_save_path + "epoch_" + str(checkpoint_epoch) + ".model"
state_dict = torch.load(checkpoint_path, map_location=device)

In [None]:
# # To run reference model
# state_dict = torch.load("../../model/spert/reference/reference_model.bin")
# state_dict["relation_classifier.weight"] = state_dict["rel_classifier.weight"]
# state_dict["relation_classifier.bias"] = state_dict["rel_classifier.bias"]
# state_dict["width_embedding.weight"] = state_dict["size_embeddings.weight"]

In [None]:
# strict=False: keys that are missing from, or unexpected in, `state_dict` do not raise.
# They are reported in the returned value, which is shown as the cell output -- check it.
spert_model.load_state_dict(state_dict, strict=False)

In [None]:
# Score the loaded model on the test dataset; the metrics are printed and saved as a CSV.
evaluate(spert_model, constants.test_dataset)

In [None]:
# predict(spert_model, sentences=["The companys main initiative is to develop the Program for the Restitution of Livelihoods in Mozambique and Malawi to assist 15,500 families affected by involuntary displacement due to the Nacala Corridor installation .", 
#                                 "The measures taken by BPCL intended to support rights to exercise freedom of association and bargaining are as follows : BPCL does not discriminate between its permanent and contract employees and they are treated alike by the Company in terms of various aspects especially for protecting their rights and skilling ."])