In [1]:
import os
import math
import numpy as np
import pandas as pd

In [2]:
import sys
sys.path.append("../parser")
import conll04_parser
import model
from transformers import BertTokenizer

In [3]:
conll04_parser.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [4]:
import torch
from torch.nn import functional as F

In [5]:
from sklearn.metrics import precision_recall_fscore_support

In [6]:
# Constants
NUM_CLASSES = 6 # Number of relation classes
NUM_EPOCH = 50

In [7]:
device = torch.device("cuda:0")

In [8]:
def data_generator(group):
    data = conll04_parser.extract_data(group)
    for doc in data:
        # If this sentence has at least two entities for a possible relation
        if len(doc["entity_position"]) >= 2:
            new_entity_position = {}
            for entity in doc["entity_position"]:
                new_entity_position[entity] = (
                    doc["entity_position"][entity][0] + 1, # +1: space for CLS token
                    doc["entity_position"][entity][1] + 1  # +1: space for CLS token
                )
            # Add CLS and SEP to the sentence
            input_ids = [conll04_parser.CLS_TOKEN] + doc["data_frame"]["token_ids"].tolist() + [conll04_parser.SEP_TOKEN]
            e1_mask, e2_mask, labels = model.generate_entity_mask(len(input_ids), new_entity_position, doc["relations"])
            assert e1_mask.shape[0] == e2_mask.shape[0] == labels.shape[0]
            assert len(input_ids) == e1_mask.shape[1] == e2_mask.shape[1]
            yield {
                "input_ids": torch.tensor([input_ids]).long().to(device), 
                "attention_mask": torch.ones((1, len(input_ids)), dtype=torch.long).to(device),
                "token_type_ids": torch.zeros((1, len(input_ids)), dtype=torch.long).to(device),
                "e1_mask": e1_mask.to(device),
                "e2_mask": e2_mask.to(device),
                "labels": labels.to(device)
            }
            del e1_mask
            del e2_mask
            del labels

In [9]:
# # Test data_generator()
# generator = data_generator("train")
# # Test on the first document ("1024")
# test_inputs = next(generator)
# assert test_inputs["input_ids"][0, 0] == conll04_parser.CLS_TOKEN
# assert test_inputs["input_ids"][0, 1] == 2200
# assert test_inputs["input_ids"][0, -2] == 1012
# assert test_inputs["input_ids"][0, -1] == conll04_parser.SEP_TOKEN
# assert torch.equal(test_inputs["e1_mask"][0, 22:24], torch.tensor([1, 1]))
# assert torch.equal(test_inputs["e1_mask"][2, 25:28], torch.tensor([1, 1, 1]))
# assert torch.equal(test_inputs["e1_mask"][4, 29:31], torch.tensor([1, 1]))
# assert torch.equal(test_inputs["labels"], torch.tensor([0, 2, 0, 2, 0, 0]))

In [10]:
mre_model = model.BertForMre(NUM_CLASSES)
mre_model.load_state_dict(torch.load("../../model/re/conll04_50.model"))
mre_model.to(device)

BertForMre(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
 

In [11]:
# Freeze all layers except for the last classifier layer on top
for param in mre_model.parameters():
    param.requires_grad = False
mre_model.classifier.weight.requires_grad = True
mre_model.classifier.bias.requires_grad = True

In [12]:
for param in mre_model.parameters():
    print("size:", param.shape)
    print(param.requires_grad)

size: torch.Size([30522, 768])
False
size: torch.Size([512, 768])
False
size: torch.Size([2, 768])
False
size: torch.Size([768])
False
size: torch.Size([768])
False
size: torch.Size([768, 768])
False
size: torch.Size([768])
False
size: torch.Size([768, 768])
False
size: torch.Size([768])
False
size: torch.Size([768, 768])
False
size: torch.Size([768])
False
size: torch.Size([768, 768])
False
size: torch.Size([768])
False
size: torch.Size([768])
False
size: torch.Size([768])
False
size: torch.Size([3072, 768])
False
size: torch.Size([3072])
False
size: torch.Size([768, 3072])
False
size: torch.Size([768])
False
size: torch.Size([768])
False
size: torch.Size([768])
False
size: torch.Size([768, 768])
False
size: torch.Size([768])
False
size: torch.Size([768, 768])
False
size: torch.Size([768])
False
size: torch.Size([768, 768])
False
size: torch.Size([768])
False
size: torch.Size([768, 768])
False
size: torch.Size([768])
False
size: torch.Size([768])
False
size: torch.Size([768])
False
si

In [13]:
from transformers import AdamW
optimizer = AdamW(mre_model.parameters(), lr=1e-5)

In [14]:
def validate_model():
    val_generator = data_generator("dev")
    true_labels = []
    predicted_labels = []
    for inputs in val_generator:
        # forward
        outputs = mre_model(**inputs)
        true_labels += inputs["labels"].tolist()
        pred_labels = F.softmax(outputs.logits, dim=-1).argmax(dim=1)
        predicted_labels += pred_labels.tolist()
        assert len(predicted_labels) == len(true_labels)
        del inputs
        
    print("[validation]")
    result = pd.DataFrame(columns=["precision", "recall", "fbeta_score", "support"])
    result.loc["macro"] = list(precision_recall_fscore_support(true_labels, predicted_labels, average="macro"))
    result.loc["micro"] = list(precision_recall_fscore_support(true_labels, predicted_labels, average="micro"))
    print(result)

In [15]:
def train_model():
    for epoch in range(NUM_EPOCH):  # loop over the dataset multiple times
        true_labels = []
        predicted_labels = []

        for i, inputs in enumerate(data_generator("train"), 0):
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = mre_model(**inputs)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            # print statistics
            true_labels += inputs["labels"].tolist()
            pred_labels = F.softmax(outputs.logits, dim=-1).argmax(dim=1)
            predicted_labels += pred_labels.tolist()
            assert len(predicted_labels) == len(true_labels)
            if i % 800 == 799:    # print every 800 mini-batches
                print("[%d, %5d]" % (epoch + 1, i + 1))
                result = pd.DataFrame(columns=["precision", "recall", "fbeta_score", "support"])
                result.loc["macro"] = list(precision_recall_fscore_support(true_labels, predicted_labels, average="macro"))
                result.loc["micro"] = list(precision_recall_fscore_support(true_labels, predicted_labels, average="micro"))
                print(result)
                true_labels = []
                predicted_labels = []

            del inputs
            
        validate_model()

    print('Finished Training')

In [16]:
train_model()

[1,   800]
       precision    recall  fbeta_score  support
macro   0.754528  0.606628     0.661935      NaN
micro   0.932754  0.932754     0.932754      NaN
[validation]
       precision    recall  fbeta_score  support
macro   0.746110  0.598716     0.653817      NaN
micro   0.905797  0.905797     0.905797      NaN
[2,   800]
       precision    recall  fbeta_score  support
macro   0.750365  0.612338     0.667248      NaN
micro   0.932933  0.932933     0.932933      NaN
[validation]
       precision    recall  fbeta_score  support
macro   0.718114  0.565275     0.618478      NaN
micro   0.899931  0.899931     0.899931      NaN
[3,   800]
       precision    recall  fbeta_score  support
macro   0.756164  0.625380     0.675470      NaN
micro   0.934724  0.934724     0.934724      NaN
[validation]
       precision    recall  fbeta_score  support
macro   0.729296  0.574840     0.632862      NaN
micro   0.902692  0.902692     0.902692      NaN
[4,   800]
       precision    recall  fbeta_s

In [17]:
def test_model():
    test_generator = data_generator("test")
    true_labels = []
    predicted_labels = []
    for inputs in test_generator:
        # forward
        outputs = mre_model(**inputs)
        true_labels += inputs["labels"].tolist()
        pred_labels = F.softmax(outputs.logits, dim=-1).argmax(dim=1)
        predicted_labels += pred_labels.tolist()
        assert len(predicted_labels) == len(true_labels)
        del inputs
    
    label_map = {v: k for k, v in conll04_parser.relation_encode.items()}
    classes = list(label_map.keys())
    precision, recall, fbeta_score, support = precision_recall_fscore_support(true_labels, predicted_labels, average=None, labels=classes)
    result = pd.DataFrame(index=[label_map[c] for c in classes])
    result["precision"] = precision
    result["recall"] = recall
    result["fbeta_score"] = fbeta_score
    result["support"] = support
    result.loc["macro"] = list(precision_recall_fscore_support(true_labels, predicted_labels, average="macro"))
    result.loc["micro"] = list(precision_recall_fscore_support(true_labels, predicted_labels, average="micro"))
    
    print(result)
    return result

In [18]:
result = test_model()

             precision    recall  fbeta_score  support
N             0.949753  0.961765     0.955721   3400.0
Kill          0.770833  0.787234     0.778947     47.0
Located_In    0.695652  0.510638     0.588957     94.0
OrgBased_In   0.644737  0.466667     0.541436    105.0
Live_In       0.514019  0.550000     0.531401    100.0
Work_For      0.696203  0.723684     0.709677     76.0
macro         0.711866  0.666665     0.684357      NaN
micro         0.919414  0.919414     0.919414      NaN


In [19]:
result.to_csv("conll04_100_result.csv")

In [20]:
torch.save(mre_model.state_dict(), "../../model/re/conll04_100.model")