In [1]:
import os
import math
import random

In [2]:
import pandas as pd
import torch
import transformers
from transformers import AdamW, BertConfig, BertTokenizer

In [3]:
import model
import evaluator

In [4]:
# Modify the two imports below to switch to a different dataset.
import internal_constants as constants
import internal_input_generator as input_generator

In [5]:
from tqdm import tqdm

In [6]:
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [7]:
# # To run reference model
# input_generator.parser.entity_encode = {'O': 0, 'B-Loc': 1, 'I-Loc': 1, 'B-Peop': 3, 'I-Peop': 3, 
#                                          'B-Org': 2, 'I-Org': 2, 'B-Other': 4, 'I-Other': 4}
# input_generator.parser.relation_encode = {'N': 0, 'Kill': 2, 'Located_In': 5, 'OrgBased_In': 3,
#                                            'Live_In': 4, 'Work_For': 1}

In [8]:
# Invert the parser's label -> id encoding so predicted ids can be mapped back to label names.
# If several labels share one id, the label that appears last in the encoding wins.
entity_label_map = dict(
    (class_id, label) for label, class_id in input_generator.parser.entity_encode.items()
)
# Ids to report metrics for: every id except 0 (presumably the "no entity" class - confirm).
# list.remove raises ValueError if id 0 is not present in the encoding.
entity_classes = [*entity_label_map]
entity_classes.remove(0)

In [9]:
# Invert the parser's label -> id encoding so predicted ids can be mapped back to label names.
# If several labels share one id, the label that appears last in the encoding wins.
relation_label_map = dict(
    (class_id, label) for label, class_id in input_generator.parser.relation_encode.items()
)
# Ids to report metrics for: every id except 0 (presumably the "no relation" class - confirm).
# list.remove raises ValueError if id 0 is not present in the encoding.
relation_classes = [*relation_label_map]
relation_classes.remove(0)

In [10]:
# Load the tokenizer that matches the pretrained checkpoint and share the same instance
# with the dataset parser, so both sides use the same vocabulary.
input_generator.parser.tokenizer = tokenizer = BertTokenizer.from_pretrained(constants.model_path)

In [11]:
def get_optimizer_params(model):
    """Split the model's parameters into two optimizer parameter groups.

    Parameters whose name contains 'bias', 'LayerNorm.bias' or 'LayerNorm.weight'
    go into a group with weight decay 0.0; all remaining parameters go into a
    group that uses ``constants.weight_decay``.

    Args:
        model: object exposing ``named_parameters()`` that yields (name, parameter) pairs.

    Returns:
        A two-element list of dicts (decayed group first, exempt group second),
        each with the keys 'params' and 'weight_decay'.
    """
    exempt_markers = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')
    decayed_params = []
    exempt_params = []
    for param_name, param in model.named_parameters():
        if any(marker in param_name for marker in exempt_markers):
            exempt_params.append(param)
        else:
            decayed_params.append(param)
    return [
        {'params': decayed_params, 'weight_decay': constants.weight_decay},
        {'params': exempt_params, 'weight_decay': 0.0},
    ]

In [12]:
def take_first_tokens(embedding, words):
    """Keep one entry of ``embedding`` per word: the entry at the word's first token.

    ``words`` holds one item per token. A position counts as the start of a
    word when it is position 0 or when its item differs from the item right
    before it, so adjacent equal items are treated as tokens of the same word.

    Args:
        embedding: indexable sequence aligned position-by-position with ``words``.
        words: indexable, iterable sequence with one item per token.

    Returns:
        A list with ``embedding[i]`` for every word-start position ``i``.
    """
    return [
        embedding[position]
        for position, current in enumerate(words)
        if position == 0 or current != words[position - 1]
    ]

In [13]:
def evaluate(spert_model, group):
    """Run the model over one dataset split, then score, save and print the results.

    Args:
        spert_model: model invoked as ``spert_model(**inputs, is_training=False)``.
            It is switched to eval mode here and left in that mode.
        group: dataset identifier handed to ``input_generator.data_generator``.
            It is also concatenated into the progress-bar label and the CSV
            file name, so it must be a string.

    Side effects:
        Writes ``constants.model_save_path + "evaluate_" + group + ".csv"`` and
        prints the combined results table. Returns None.
    """
    spert_model.eval()
    eval_entity_span_pred = []
    eval_entity_span_true = []
    eval_entity_embedding_pred = []
    eval_entity_embedding_true = []
    eval_relation_span_pred = []
    eval_relation_span_true = []
    eval_generator = input_generator.data_generator(group, device,
                                                    is_training=False,
                                                    neg_entity_count=constants.neg_entity_count,
                                                    neg_relation_count=constants.neg_relation_count,
                                                    max_span_size=constants.max_span_size)
    # Materialize the generator so the progress bar knows the total up front.
    eval_dataset = list(eval_generator)
    eval_size = len(eval_dataset)
    # Inference only: with autograd disabled no graph is recorded for the
    # forward passes, which saves memory and time.
    with torch.no_grad():
        for inputs, infos in tqdm(eval_dataset, total=eval_size, desc="Evaluation "+group):
            # forward
            outputs = spert_model(**inputs, is_training=False)
            # retrieve results for evaluation
            eval_entity_span_pred.append(outputs["entity"]["span"])
            eval_entity_span_true.append(infos["entity_span"])
            if not constants.is_overlapping:
                # Token-level labels are only collected for non-overlapping datasets;
                # they are reduced to one label per word (its first token).
                eval_entity_embedding_pred += take_first_tokens(outputs["entity"]["embedding"].tolist(), infos["words"])
                eval_entity_embedding_true += take_first_tokens(infos["entity_embedding"].tolist(), infos["words"])
                assert len(eval_entity_embedding_pred) == len(eval_entity_embedding_true)
            # The model reports "no relation output" as None; record it as an empty prediction list.
            eval_relation_span_pred.append([] if outputs["relation"] is None else outputs["relation"]["span"])
            eval_relation_span_true.append(infos["relation_span"])
    # evaluate & save
    results = pd.concat([
        evaluator.evaluate_span(eval_entity_span_true, eval_entity_span_pred, entity_label_map, entity_classes),
        evaluator.evaluate_results(eval_entity_embedding_true, eval_entity_embedding_pred, entity_label_map, entity_classes),
        evaluator.evaluate_loose_relation_span(eval_relation_span_true, eval_relation_span_pred, relation_label_map, relation_classes),
        evaluator.evaluate_span(eval_relation_span_true, eval_relation_span_pred, relation_label_map, relation_classes),
    ], keys=["Entity span", "Entity embedding", "Loose relation", "Strict relation"])
    results.to_csv(constants.model_save_path + "evaluate_" + group + ".csv")
    print(results)

In [14]:
def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN when ``values`` is empty."""
    if not values:
        return float("nan")
    return sum(values) / len(values)


def train():
    """Fine-tune SpERT on the training split and evaluate on the dev split after every epoch.

    The training examples are generated once, materialized into a list and
    shuffled once before the first epoch; every epoch then visits them in that
    same order, taking one optimizer and scheduler step per example.

    Side effects:
        * prints the per-epoch average total, entity and relation loss
          (NaN when an epoch produced no value for that loss);
        * writes ``constants.model_save_path + "epoch_<n>.csv"`` after each epoch;
        * calls ``evaluate`` on ``constants.dev_dataset`` after each epoch;
        * saves the final weights to
          ``constants.model_save_path + "epoch_<epochs-1>.model"`` once all
          epochs have finished.

    Returns None; the trained model is only available through the saved file.
    """
    # Training
    train_generator = input_generator.data_generator(constants.train_dataset, device,
                                                     is_training=True,
                                                     neg_entity_count=constants.neg_entity_count,
                                                     neg_relation_count=constants.neg_relation_count,
                                                     max_span_size=constants.max_span_size)
    train_dataset = list(train_generator)
    random.shuffle(train_dataset)
    train_size = len(train_dataset)
    config = BertConfig.from_pretrained(constants.model_path)
    spert_model = model.SpERT.from_pretrained(constants.model_path,
                                              config=config,
                                              # SpERT model parameters
                                              relation_types=constants.relation_types,
                                              entity_types=constants.entity_types,
                                              width_embedding_size=constants.width_embedding_size,
                                              prop_drop=constants.prop_drop,
                                              freeze_transformer=False,
                                              max_pairs=constants.max_pairs,
                                              is_overlapping=constants.is_overlapping,
                                              relation_filter_threshold=constants.relation_filter_threshold)
    spert_model.to(device)
    optimizer_params = get_optimizer_params(spert_model)
    optimizer = AdamW(optimizer_params, lr=constants.lr, weight_decay=constants.weight_decay, correct_bias=False)
    # One scheduler step is taken per training example, so the total number of
    # steps is train_size * epochs; lr_warmup is the fraction of them used for warm-up.
    scheduler = transformers.get_linear_schedule_with_warmup(optimizer,
                             num_warmup_steps=constants.lr_warmup*train_size*constants.epochs,
                             num_training_steps=train_size*constants.epochs)
    for epoch in range(constants.epochs):
        losses = []
        entity_losses = []
        relation_losses = []
        train_entity_pred = []
        train_entity_true = []
        train_relation_pred = []
        train_relation_true = []
        # evaluate() at the end of the previous epoch leaves the model in eval
        # mode, so training mode is restored at the start of every epoch.
        spert_model.train()
        spert_model.zero_grad()
        for inputs, infos in tqdm(train_dataset, total=train_size, desc='Train epoch %s' % epoch):
            # Debug aid (disabled): print each sampled entity span with its label.
            # for i in range(inputs["entity_mask"].shape[0]):
            #     print(inputs["entity_label"][i].item(),
            #           [word for j, word in enumerate(infos["words"].tolist()) if inputs["entity_mask"][i, j] == 1],
            #           [idx for j, idx in enumerate(infos["entity_embedding"].tolist()) if inputs["entity_mask"][i, j] == 1])

            # forward
            outputs = spert_model(**inputs, is_training=True)
            # backward
            loss = outputs["loss"]
            loss.backward()
            torch.nn.utils.clip_grad_norm_(spert_model.parameters(), constants.max_grad_norm)
            optimizer.step()
            scheduler.step()
            spert_model.zero_grad()
            # retrieve results for evaluation
            losses.append(loss.item())
            # NOTE(review): if these sub-losses are tensors rather than floats,
            # consider storing .item() as is done for the total loss - confirm in model.py.
            entity_losses.append(outputs["entity"]["loss"])
            if outputs["relation"] is not None:
                relation_losses.append(outputs["relation"]["loss"])

            train_entity_pred += outputs["entity"]["pred"].tolist()
            train_entity_true += inputs["entity_label"].tolist()
            train_relation_pred += [] if outputs["relation"] is None else outputs["relation"]["pred"].tolist()
            train_relation_true += inputs["relation_label"].tolist()
            assert len(train_entity_pred) == len(train_entity_true)
            assert len(train_relation_pred) == len(train_relation_true)

        # evaluate & save the per-epoch metrics
        # relation_losses is only filled when the model returned a relation
        # output, so it can be empty; _mean_or_nan avoids a ZeroDivisionError.
        print("epoch:", epoch,"average loss:", _mean_or_nan(losses))
        print("epoch:", epoch,"average entity loss:", _mean_or_nan(entity_losses))
        print("epoch:", epoch,"average relation loss:", _mean_or_nan(relation_losses))
        results = pd.concat([
            evaluator.evaluate_results(train_entity_true, train_entity_pred, entity_label_map, entity_classes),
            evaluator.evaluate_results(train_relation_true, train_relation_pred, relation_label_map, relation_classes)
        ], keys=["Entity", "Relation"])
        results.to_csv(constants.model_save_path + "epoch_" + str(epoch) + ".csv")
        evaluate(spert_model, constants.dev_dataset)

    # Only the weights after the last epoch are saved.
    torch.save(spert_model.state_dict(), constants.model_save_path + "epoch_" + str(constants.epochs-1) + ".model")

In [15]:
def predict(spert_model, sentences):
    """Print the entities and relations the model predicts for each raw sentence.

    Each sentence is split on whitespace, every word is tokenized on its own,
    and the resulting tokens are wrapped in a label-free document that is
    converted to model inputs with ``input_generator.doc_to_input``.

    Args:
        spert_model: model invoked as ``spert_model(**inputs, is_training=False)``.
            It is switched to eval mode here and left in that mode.
        sentences: iterable of strings.

    Returns None; results are printed.
    """
    # Make predictions independent of the mode the caller left the model in.
    spert_model.eval()
    for sentence in sentences:
        word_list = sentence.split()
        words = []
        token_ids = []
        # transform a sentence to a document for prediction
        for word in word_list:
            # [1:-1] drops the first and last id the tokenizer returns
            # (the special tokens it wraps around the input).
            token_id = tokenizer(word)["input_ids"][1:-1]
            # One row per sub-token; the originating word is repeated on each row.
            for tid in token_id:
                words.append(word)
                token_ids.append(tid)
        data_frame = pd.DataFrame()
        data_frame["words"] = words
        data_frame["token_ids"] = token_ids
        data_frame["entity_embedding"] = 0
        data_frame["sentence_embedding"] = 0 # for internal datasets
        doc = {"data_frame": data_frame,
            "entity_position": {}, # Supposed to appear in non-overlapping datasets
            "entities": {}, # Supposed to appear in overlapping datasets
            "relations": {}}
        # predict
        inputs, infos = input_generator.doc_to_input(doc, device,
                                                     is_training=False,
                                                     max_span_size=constants.max_span_size)
        # Inference only: no autograd graph is needed for the forward pass.
        with torch.no_grad():
            outputs = spert_model(**inputs, is_training=False)
        pred_entity_span = outputs["entity"]["span"]
        pred_relation_span = [] if outputs["relation"] is None else outputs["relation"]["span"]
        # print result
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        print("Sentence:", sentence)
        print("Entities: (", len(pred_entity_span), ")")
        for begin, end, entity_type in pred_entity_span:
            print(entity_label_map[entity_type], "|", " ".join(tokens[begin:end]))
        print("Relations: (", len(pred_relation_span), ")")
        for e1, e2, relation_type in pred_relation_span:
            print(relation_label_map[relation_type], "|",
                  " ".join(tokens[e1[0]:e1[1]]), "|",
                  " ".join(tokens[e2[0]:e2[1]]))

In [16]:
# Run the full training loop. train() returns nothing: the trained weights are
# only available from the file it saves once all epochs have finished.
train()

Some weights of the model checkpoint at bert-base-cased were not used when initializing SpERT: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing SpERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing SpERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of SpERT were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['relation_classifier.weight', 'relation_classifier.bi

1 ['waste', 'water'] [1.0, 1.0]
3 ['tupras', 'tupras', 'tupras'] [3.0, 3.0, 3.0]
2 ['2019'] [2.0]
1 ['water'] [1.0]
1 ['wastewater', 'wastewater'] [1.0, 1.0]
0 ['in', '2019', ',', 'a', 'total', 'of', '20.6', '20.6', '20.6', 'million', 'm', '3'] [0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
0 ['million', 'm', '3', 'water', 'is'] [0.0, 0.0, 0.0, 1.0, 0.0]


Train epoch 0:   0%|                                                                | 1/5398 [00:02<3:54:51,  2.61s/it]

0 ['each', 'and', 'every'] [0.0, 0.0, 0.0]
2 ['june', 'june', '2017'] [2.0, 2.0, 2.0]
0 ['initiative', 'to', 'as'] [0.0, 0.0, 0.0]
3 ['we'] [3.0]
7 ['employee'] [7.0]


Train epoch 0:   0%|                                                                | 2/5398 [00:04<3:41:34,  2.46s/it]

0 ['all'] [0.0]
1 ['water', 'stress'] [1.0, 1.0]
4 ['level'] [4.0]
0 ['action', 'plans', 'in', 'consultation'] [0.0, 0.0, 0.0, 0.0]
3 ['group'] [3.0]


Train epoch 0:   0%|                                                                | 3/5398 [00:07<3:51:55,  2.58s/it]

0 [',', 'and', 'internal', 'and', 'external', 'audits', 'audits', ','] [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
4 ['training'] [4.0]
7 ['safety'] [7.0]
0 ['audits', 'audits', ',', 'with', 'corrective', 'corrective', 'actions', 'for', 'all', 'audit'] [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]


Train epoch 0:   0%|                                                                | 4/5398 [00:10<3:52:32,  2.59s/it]

3 ['vinci', 'vinci', 'vinci'] [3.0, 3.0, 3.0]
7 ['social'] [7.0]
0 ['vinci', 'vinci', 'vinci', 'companies', 'provide', 'solutions'] [3.0, 3.0, 3.0, 0.0, 0.0, 0.0]
0 ['natural', 'environments', 'report', 'of', 'the', 'board', 'of', 'social'] [1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 7.0]
1 ['natural', 'environments'] [1.0, 1.0]
3 ['vinci', 'vinci', 'vinci'] [3.0, 3.0, 3.0]
1 ['climate', 'change'] [1.0, 1.0]
2 ['2019'] [2.0]
1 ['environmental'] [1.0]


Train epoch 0:   0%|                                                                | 5/5398 [00:12<3:53:56,  2.60s/it]

0 ['energy', 'purchases', 'in', '2018', 'covered', 'over', '6,000', '6,000', '6,000', 'mcdonalds', 'mcdonalds', 'mcdonalds', 'mcdonalds', 'mcdonalds', 'restaurantsworth', 'restaurantsworth', 'of'] [1.0, 4.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.0, 3.0, 3.0, 3.0, 0.0, 0.0, 0.0]
3 ['mcdonalds', 'mcdonalds', 'mcdonalds', 'mcdonalds', 'mcdonalds'] [3.0, 3.0, 3.0, 3.0, 3.0]
1 ['renewable', 'energy'] [1.0, 1.0]
4 ['purchases'] [4.0]
0 ['restaurantsworth', 'restaurantsworth', 'of', 'electricity', 'across', 'twelve', 'markets'] [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
2 ['2018'] [2.0]


Train epoch 0:   0%|                                                                | 6/5398 [00:15<3:38:08,  2.43s/it]

7 ['human', 'rights'] [7.0, 7.0]
7 ['harassment'] [7.0]
7 ['supplier'] [7.0]
4 ['policy'] [4.0]
7 ['supply', 'chain'] [7.0, 7.0]
7 ['bullying', 'free', 'workplace'] [7.0, 7.0, 7.0]
4 ['code', 'of', 'conduct'] [4.0, 4.0, 4.0]
4 ['standards'] [4.0]
0 ['report', 'potential'] [0.0, 0.0]
7 ['suppliers'] [7.0]
3 ['alcoa', 'alcoa'] [3.0, 3.0]
4 ['respect'] [4.0]
4 ['policy'] [4.0]
4 ['assessment', 'programs'] [4.0, 4.0]
7 ['human', 'rights'] [7.0, 7.0]
0 ['assessment', 'programs', 'for', 'new', 'and'] [4.0, 4.0, 0.0, 0.0, 0.0]
4 ['integrity', 'line'] [4.0, 4.0]
7 ['employees'] [7.0]
7 ['supplier'] [7.0]
4 ['training'] [4.0]
7 ['suppliers'] [7.0]
7 ['human', 'rights'] [7.0, 7.0]
7 ['employee'] [7.0]


Train epoch 0:   0%|                                                                | 7/5398 [00:18<4:04:42,  2.72s/it]

3 ['we'] [3.0]
0 ['2020', ',', 'we', 'aim'] [2.0, 0.0, 3.0, 0.0]
7 ['injuries'] [7.0]
4 ['reduction'] [4.0]
2 ['2020'] [2.0]
0 ['per'] [0.0]


Train epoch 0:   0%|                                                                | 8/5398 [00:21<4:04:29,  2.72s/it]

0 ['or'] [0.0]
0 ['our', 'portfolio', 'spans', '15', 'countries'] [0.0, 0.0, 0.0, 0.0, 0.0]


Train epoch 0:   0%|                                                                | 9/5398 [00:23<3:47:54,  2.54s/it]

7 ['indigenous', 'communities'] [7.0, 7.0]
0 ['dialogue'] [0.0]
0 ['comprehensive', 'approach', 'piloted', 'piloted', 'across'] [0.0, 0.0, 0.0, 0.0, 0.0]


Train epoch 0:   0%|                                                               | 10/5398 [00:25<3:51:05,  2.57s/it]

0 ['strengthen', 'employee'] [0.0, 7.0]
5 ['spain', 'spain'] [5.0, 5.0]
5 ['willowdale', 'willowdale', 'willowdale'] [5.0, 5.0, 5.0]
7 ['employee'] [7.0]
5 ['san', 'san', 'cipran', 'cipran', 'cipran'] [5.0, 5.0, 5.0, 5.0, 5.0]
0 ['australia', 'australia', 'australia', 'and'] [5.0, 5.0, 5.0, 0.0]
7 ['safety'] [7.0]
3 ['we'] [3.0]
5 ['australia', 'australia', 'australia'] [5.0, 5.0, 5.0]
2 ['2019'] [2.0]


Train epoch 0:   0%|▏                                                              | 11/5398 [00:28<3:50:34,  2.57s/it]

0 ['of', 'the', 'air', 'canada', 'canada'] [0.0, 0.0, 3.0, 3.0, 3.0]
0 ['the', 'basis', 'of', 'the', 'air', 'canada', 'canada', 'sms', 'sms', '.'] [0.0, 0.0, 0.0, 0.0, 3.0, 3.0, 3.0, 4.0, 4.0, 0.0]
4 ['sms', 'sms'] [4.0, 4.0]
3 ['air', 'canada', 'canada'] [3.0, 3.0, 3.0]


Train epoch 0:   0%|▏                                                              | 11/5398 [00:29<4:00:48,  2.68s/it]


KeyboardInterrupt: 

In [None]:
# Build a fresh SpERT instance for evaluation / prediction. train() keeps its
# model local and returns nothing, so the trained model is not reachable here.
# The arguments mirror the ones used in train(), except that
# freeze_transformer is True here (train() passes False).
config = BertConfig.from_pretrained(constants.model_path)
spert_model = model.SpERT.from_pretrained(constants.model_path,
                                          config=config,
                                          # SpERT model parameters
                                          relation_types=constants.relation_types, 
                                          entity_types=constants.entity_types, 
                                          width_embedding_size=constants.width_embedding_size, 
                                          prop_drop=constants.prop_drop, 
                                          freeze_transformer=True, 
                                          max_pairs=constants.max_pairs, 
                                          is_overlapping=constants.is_overlapping, 
                                          relation_filter_threshold=constants.relation_filter_threshold)
spert_model.to(device)

In [None]:
# Load the weights train() saved after its last epoch, mapping tensors onto the active device.
state_dict = torch.load(f"{constants.model_save_path}epoch_{constants.epochs-1}.model", map_location=device)

In [None]:
# # To run reference model
# state_dict = torch.load("../../model/spert/reference/reference_model.bin")
# state_dict["relation_classifier.weight"] = state_dict["rel_classifier.weight"]
# state_dict["relation_classifier.bias"] = state_dict["rel_classifier.bias"]
# state_dict["width_embedding.weight"] = state_dict["size_embeddings.weight"]

In [None]:
# strict=False: keys that are missing from or unexpected in state_dict do not
# raise. The call's return value is shown as the cell output - check it to see
# which keys were not matched.
spert_model.load_state_dict(state_dict, strict=False)

In [None]:
# Score the loaded model on the test split; evaluate() prints the results and
# writes them to a CSV file under constants.model_save_path.
evaluate(spert_model, constants.test_dataset)

In [None]:
# state_dict = torch.load(constants.model_save_path + "epoch_" + str(constants.epochs-1) + ".model")
# state_dict["rel_classifier.weight"] = state_dict["relation_classifier.weight"]
# state_dict["rel_classifier.bias"] = state_dict["relation_classifier.bias"]
# state_dict["size_embeddings.weight"] = state_dict["width_embedding.weight"]
# del state_dict["relation_classifier.weight"]
# del state_dict["relation_classifier.bias"]
# del state_dict["width_embedding.weight"]
# del state_dict["bert.embeddings.position_ids"]
# torch.save(state_dict, "../../model/spert/pytorch_model.bin")

In [None]:
# Qualitative check: print the predicted entities and relations for two example sentences.
predict(spert_model, sentences=["However, the Rev. Jesse Jackson, a native of South Carolina, joined critics of FEMA's effort.", 
                                "International Paper spokeswoman Ann Silvernail said that under French law the company was barred from releasing details pending government approval."])

In [None]:
# test_sentence = " ".join(["An", "art", "exhibit", "at", "the", "Hakawati", "Theatre", "in", "Arab", "east", "Jerusalem", "was", "a", "series", "of", "portraits", "of", "Palestinians", "killed", "in", "the", "rebellion", "."])

In [None]:
# "entities": [{"type": "Loc", "start": 5, "end": 7}, {"type": "Loc", "start": 10, "end": 11}, {"type": "Other", "start": 17, "end": 18}], "relations": [{"type": "Located_In", "head": 0, "tail": 1}]

In [None]:
# predict(spert_model, sentences=[test_sentence])

In [None]:
# test_sentence = " ".join(["PERUGIA", ",", "Italy", "(", "AP", ")"])

In [None]:
# predict(spert_model, sentences=[test_sentence])