In [None]:
import os
import math
import numpy as np
import pandas as pd

In [None]:
import sys
# Make the sibling "parser" directory importable so the two project modules
# below can be found. The path is built with os.path.join (``os`` is imported
# in the first cell) instead of a hard-coded "..\\parser": a backslash is only
# a path separator on Windows, so the literal form names a non-existent
# directory on Linux/macOS. On Windows the resulting string is unchanged.
# Note: a relative sys.path entry is resolved against the kernel's current
# working directory.
sys.path.append(os.path.join(os.pardir, "parser"))
import conll04_parser
import model

In [None]:
import torch
from torch.nn import functional as F

In [None]:
from sklearn.metrics import precision_recall_fscore_support

In [None]:
# Constants
NUM_CLASSES = 8 # Number of relation classes
NUM_EPOCH = 3 # Presumably the number of training epochs -- TODO confirm against the training loop
VALIDATION_SIZE = 100 # Number of observations evaluated in the validation step

In [None]:
def data_generator(group):
    """Lazily yield one model-input dictionary per usable sentence.

    The documents for ``group`` are obtained from
    ``conll04_parser.extract_data``. Sentences with fewer than two entities
    are skipped, since no relation can be formed from them.

    Args:
        group: Identifier of the data split, forwarded unchanged to
            ``conll04_parser.extract_data``.

    Yields:
        dict with the keys

        * ``"input_ids"``: long tensor of shape ``(1, seq_len)`` holding the
          sentence's token ids wrapped in the CLS and SEP tokens.
        * ``"attention_mask"``: long tensor of ones, shape ``(1, seq_len)``.
        * ``"token_type_ids"``: long tensor of zeros, shape ``(1, seq_len)``.
        * ``"e1_mask"``, ``"e2_mask"``, ``"labels"``: the three values
          returned by ``model.generate_entity_mask``. Their first dimensions
          are asserted to agree, and the masks' second dimension is asserted
          to equal ``seq_len``.

    Raises:
        AssertionError: If the shapes returned by
            ``model.generate_entity_mask`` are inconsistent with each other
            or with the length of the token sequence.
    """
    for sentence in conll04_parser.extract_data(group):
        positions = sentence["entity_position"]

        # A relation needs a pair of entities; nothing to do otherwise.
        if len(positions) < 2:
            continue

        # Prepending CLS moves every token one slot to the right, so both
        # components of each entity position are shifted by one.
        shifted_positions = {
            name: (positions[name][0] + 1, positions[name][1] + 1)
            for name in positions
        }

        # Token sequence wrapped in the CLS / SEP special tokens.
        token_ids = sentence["data_frame"]["token_ids"].tolist()
        input_ids = [conll04_parser.CLS_TOKEN, *token_ids, conll04_parser.SEP_TOKEN]
        seq_len = len(input_ids)

        first_mask, second_mask, targets = model.generate_entity_mask(
            seq_len, shifted_positions, sentence["relations"]
        )

        # Sanity checks: the first dimensions of both masks and the labels
        # match, and the masks span the whole token sequence.
        assert first_mask.shape[0] == second_mask.shape[0] == targets.shape[0]
        assert seq_len == first_mask.shape[1] == second_mask.shape[1]

        batch = {
            "input_ids": torch.tensor([input_ids]).long(),
            "attention_mask": torch.ones((1, seq_len), dtype=torch.long),
            "token_type_ids": torch.zeros((1, seq_len), dtype=torch.long),
            "e1_mask": first_mask,
            "e2_mask": second_mask,
            "labels": targets,
        }
        yield batch

        # Drop the generator's own references once the consumer resumes it.
        del first_mask, second_mask, targets

In [None]:
# # Test data_generator()
# generator = data_generator("train")
# # Test on the first document ("1024")
# test_inputs = next(generator)
# assert test_inputs["input_ids"][0, 0] == conll04_parser.CLS_TOKEN
# assert test_inputs["input_ids"][0, 1] == 2200
# assert test_inputs["input_ids"][0, -2] == 1012
# assert test_inputs["input_ids"][0, -1] == conll04_parser.SEP_TOKEN
# assert torch.equal(test_inputs["e1_mask"][0, 22:24], torch.tensor([1, 1]))
# assert torch.equal(test_inputs["e1_mask"][2, 25:28], torch.tensor([1, 1, 1]))
# assert torch.equal(test_inputs["e1_mask"][4, 29:31], torch.tensor([1, 1]))
# assert torch.equal(test_inputs["labels"], torch.tensor([0, 2, 0, 2, 0, 0]))