**Import packages**

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from copy import copy
from tqdm import tqdm
import json

import torch
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader


**Read all relation labels**

In [2]:
loc = './dataset/NYT29/relations.txt'
RELATION_LABELS = [f.split('/')[-1] for f in open(loc, 'r').read().splitlines()]
RELATION_LABELS.append('other')

print('Dataset contains {} relations:'.format(len(RELATION_LABELS)))
print(RELATION_LABELS)

Dataset contains 30 relations:
['country', 'capital', 'administrative_divisions', 'neighborhood_of', 'contains', 'nationality', 'place_lived', 'place_of_death', 'company', 'capital', 'place_of_birth', 'children', 'founders', 'place_founded', 'location', 'ethnicity', 'geographic_distribution', 'religion', 'major_shareholders', 'capital', 'capital', 'advisors', 'featured_in_films', 'featured_film_locations', 'county_seat', 'locations', 'place_of_burial', 'interred_here', 'companies_advised', 'other']


**Hyperparameters**

In [12]:
NUM_CLASSES = len(RELATION_LABELS)
MODEL_NAME = 'roberta-base'
MAX_LEN = 128

LR = 2e-5
TRAINING_BATCH_SIZE = 16
VAL_BATCH_SIZE = 1
EPOCHS = 4
DROPOUT = 0.3
NUM_WORKERS = 0

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

**Define dataset class for processing inputs**

In [45]:
def clean_entity(entity_list):
    if not isinstance(entity_list, list):
        entity_list = [entity_list]
    # remove white spaces and newline characters
    for idx, e in enumerate(entity_list):
        e = e.strip()
        entity_list[idx] = e
    return entity_list

class Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset_dict, tokenizer, max_len, add_CLS=False):
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.sentences = dataset_dict['sent']
        self.relation_tuples = dataset_dict['tup']
        if len(self.sentences) != len(self.relation_tuples):
            raise ValueError('The number of sentences and relation tuples are not equal.')
        self.add_CLS = add_CLS
        
        self.separator = '|'
    
    def __len__(self):
        return len(self.sentences)
    
    def __getitem__(self, index):
        # get the sentence and the corresponding relation tuple
        sentence = self.sentences[index]
        relation_tuples = self.relation_tuples[index]
        relation_tuples = relation_tuples.split(self.separator)

        # find all entities mentioned
        all_word_pairs_mentioned = [clean_entity(x.split(';')[:2]) for x in relation_tuples]
        all_entity_pairs_mentioned = [clean_entity(x.split(';')[-1].split('/')[1:3]) for x in relation_tuples]
        all_relation_labels = [clean_entity(x.split(';')[-1].split('/')[-1])[0] for x in relation_tuples]
        all_entities = [item for sublist in all_entity_pairs_mentioned for item in sublist]
        all_entities = list(set(all_entities))
        all_possible_entity_pairs = [[x, y] for x in all_entities for y in all_entities if x != y]

        # look at all possible entity pairs; positive examples: if entity pair is mentioned, negative examples: if entity pair is not mentioned
        # replace the word in the sentence with its NER tag and tokenize the masked sentence
        # save the tokenized sentences for each entity mention and the corresponding relation label
        output = {'masked sentences': [],
                  'input_ids': [],
                  'attention_masks': [],
                  'target_label_ids': [],
                  'entity_pairs': []}
        
        for entity_pair in all_possible_entity_pairs:
            ''' positive examples with given relations '''
            if entity_pair in all_entity_pairs_mentioned:
                sentence_copy = copy(sentence)
                idx = all_entity_pairs_mentioned.index(entity_pair)
                word_pair = all_word_pairs_mentioned[idx]

                word1, word2 = word_pair
                entity1, entity2 = entity_pair
                relation_label = all_relation_labels[idx]
                relation_label = relation_label

                masked_sentence = sentence_copy.replace(word1, entity1)
                masked_sentence = masked_sentence.replace(word2, entity2)
            else:
                ''' negative examples with "other" relation '''
                sentence_copy = copy(sentence)

                entity1, entity2 = entity_pair
                # find corresponding word
                for idx1, e in enumerate(all_entity_pairs_mentioned):
                    try:
                        idx2 = e.index(entity1)
                        word1 = all_word_pairs_mentioned[idx1][idx2]
                    except:
                        pass
                    try:
                        idx3 = e.index(entity2)
                        word2 = all_word_pairs_mentioned[idx1][idx3]
                    except:
                        continue
                relation_label = 'other'

                masked_sentence = sentence_copy.replace(word1, entity1)
                masked_sentence = masked_sentence.replace(word2, entity2)

            # tokenize the masked sentence
            encoding = self.tokenizer(masked_sentence, max_length=self.max_len, padding='max_length', 
                                      return_attention_mask=True, truncation=True, add_special_tokens=self.add_CLS)
            input_ids = torch.tensor(encoding['input_ids'], dtype=torch.long)
            attention_mask = torch.tensor(encoding['attention_mask'], dtype=torch.long)
            target_label_id = torch.tensor(RELATION_LABELS.index(relation_label), dtype=torch.long)

            output['masked sentences'].append(masked_sentence)
            output['input_ids'].append(input_ids)
            output['entity_pairs'].append(entity_pair)
            output['attention_masks'].append(attention_mask)
            output['target_label_ids'].append(target_label_id)

        return output


**Define classifier**

In [46]:
class Classifier(nn.Module):
    def __init__(self, dropout=0.3, num_classes=30):
        super(Classifier, self).__init__()

        self.roberta = RobertaModel.from_pretrained(MODEL_NAME)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(self.roberta.config.hidden_size, num_classes)

        self.loss_fcn = nn.CrossEntropyLoss()

    def forward(self, ids, masks, targets):
        outputs = self.roberta(ids, attention_mask=masks)
        sequence_output = outputs[0]
        pooled_output = outputs[1]
        context = self.dropout(pooled_output)

        logits = self.classifier(pooled_output)
        _, prediction = torch.argmax(logits, dim=1)

        loss = self.loss_fcn(logits, targets)
        return loss, prediction

**Load the dataset**

In [47]:
dataset_dir = './dataset/NYT29/'

file_types = ['.sent', '.tup']
datasets = ['train', 'test', 'dev']

train = {}
test = {}
dev = {}

for d in datasets:
    for t in file_types:
        with open(os.path.join(dataset_dir, f'{d}{t}'), 'r') as f:
            if t == '.sent':
                exec(f'{d}["sent"] = f.read().splitlines()')
            else:
                exec(f'{d}["tup"] = f.read().splitlines()')



In [48]:
tokenizer = RobertaTokenizer.from_pretrained(MODEL_NAME)

train_dataset = Dataset(train, tokenizer, MAX_LEN)
test_dataset = Dataset(test, tokenizer, MAX_LEN)
dev_dataset = Dataset(dev, tokenizer, MAX_LEN)

In [49]:
def trainer(model, train_dataset_loader, dev_dataset_loader, optimizer, epochs, device):
    

    ''' train '''
    training_loss = []
    training_accuracy = []
    for e in tqdm(range(epochs)):
        correct_predictions = 0
        for data in tqdm(train_dataset_loader):
            ids = data['id'].to(device)
            masks = data['mask'].to(device)
            targets = data['target'].to(device)

            loss, prediction = model(ids, masks, targets)
            training_loss.append(loss.item())
            correct_predictions += torch.sum(prediction == targets).cpu().detach().numpy()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
        training_accuracy.append((e, correct_predictions / len(train_dataset)))

        ''' validation '''
        validation_loss = []
        validation_accuracy = []
        model.eval()
        for data in tqdm(dev_dataset_loader):
            ids = data['id'].to(device)
            masks = data['mask'].to(device)
            targets = data['target'].to(device)

            loss, prediction = model(ids, masks, targets)
            validation_loss.append(loss.item())
            correct_predictions += torch.sum(prediction == targets).cpu().detach().numpy()
        validation_accuracy.append((e, correct_predictions / len(dev_dataset)))    

    # save the trained models
    model_checkpoint = dict()
    model_checkpoint['model_state_dict'] = model.state_dict()
    model_checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    model_checkpoint['training_loss'] = training_loss
    model_checkpoint['training_accuracy'] = training_accuracy
    model_checkpoint['validation_loss'] = validation_loss
    model_checkpoint['validation_accuracy'] = validation_accuracy
    torch.save(model_checkpoint, f'./save_data/model_checkpoint.pth')
    return training_loss, training_accuracy, validation_loss, validation_accuracy

In [50]:
params_dataLoader = {'batch_size': TRAINING_BATCH_SIZE,
                     'shuffle': True,
                     'num_workers': NUM_WORKERS}
train_dataset_loader = torch.utils.data.DataLoader(train_dataset, **params_dataLoader)

params_dataLoader_eval = {'batch_size': VAL_BATCH_SIZE,
                          'shuffle': True,
                          'num_workers': NUM_WORKERS}
dev_dataset_loader = torch.utils.data.DataLoader(dev_dataset, **params_dataLoader)


In [51]:
model = Classifier(dropout=DROPOUT, num_classes=NUM_CLASSES).to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [52]:
load = all([os.path.exists(f'save_data/{f}') for f in ['model_checkpoint.pth', 'training_loss.json', 'training_accuracy.json']])

if load:
    model_dir = 'save_data/model_checkpoint.pth'
    model_checkpoint = torch.load(model_dir)
    model.load_state_dict(model_checkpoint['model_state_dict'])

    fname = os.path.join(f'./save_data/training_loss.json')
    with open(fname, 'r') as f:
        training_loss = json.load(f)
    fname = os.path.join(f'./save_data/training_accuracy.json')
    with open(fname, 'r') as f:
        training_accuracy = json.load(f)
    fname = os.path.join(f'./save_data/validation_loss.json')
    with open(fname, 'r') as f:
        validation_loss = json.load(f)
    fname = os.path.join(f'./save_data/validation_accuracy.json')
    with open(fname, 'r') as f:
        validation_accuracy = json.load(f)
else:
    if not os.path.exists('save_data/'):
        os.makedirs('save_data')
    training_loss, training_accuracy, validation_loss, validation_accuracy = trainer(model, train_dataset_loader, dev_dataset_loader, optimizer, EPOCHS, device)

    fname = os.path.join(f'./save_data/training_loss.json')
    with open(fname, 'w') as f:
        json.dump(training_loss, f)
    fname = os.path.join(f'./save_data/training_accuracy.json')
    with open(fname, 'w') as f:
        json.dump(training_accuracy, f)
    fname = os.path.join(f'./save_data/validation_loss.json')
    with open(fname, 'w') as f:
        json.dump(validation_loss, f)
    fname = os.path.join(f'./save_data/validation_accuracy.json')
    with open(fname, 'w') as f:
        json.dump(validation_accuracy, f)

_, ax = plt.subplots(2, 2, figsize=(10, 5))
ax[0, 0].plot(training_loss, marker='.')
ax[0, 0].set_title('Training Loss')
ax[0, 1].plot(*zip(*training_accuracy))
ax[0, 1].set_title('Training Accuracy')
ax[1, 0].plot(validation_loss, marker='.')
ax[1, 0].set_title('Validation Loss')
ax[1, 1].plot(*zip(*validation_accuracy))
ax[1, 1].set_title('Validation Accuracy')

  0%|          | 0/3957 [00:00<?, ?it/s]
  0%|          | 0/4 [00:00<?, ?it/s]


RuntimeError: each element in list of batch should be of equal size