In [1]:
from collections import Counter
from types import SimpleNamespace

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, classification_report
import torch
import torchtext
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import torch.optim as optim

from kg.ner.model import LSTM

In [2]:
# load data
train_df = pd.read_csv('/Users/tmorrill002/Documents/datasets/conll/transformed/train.csv')
val_df = pd.read_csv('/Users/tmorrill002/Documents/datasets/conll/transformed/validation.csv')
test_df = pd.read_csv('/Users/tmorrill002/Documents/datasets/conll/transformed/test.csv')

In [3]:
# create vocabulary and label dictionaries
vocab = torchtext.vocab.Vocab(Counter(train_df['Token'].value_counts().to_dict()))
label_dict = {}
i = 0
for k in train_df['NER_Tag_Normalized'].unique():
    label_dict[k] = i
    i += 1

In [4]:
class CoNLL2003Dataset(torch.utils.data.Dataset):
    def __init__(self, df, vocab, label_dict, transform=None):
        self.df = df
        self.vocab = vocab
        self.label_dict = label_dict
        self.transform = transform
        self.sentences, self.labels = self._prepare_data()
    
    def _prepare_data(self):
        temp_df = self.df.groupby(['Article_ID', 'Sentence_ID'], as_index=False).agg(Sentence=('Token', list), Labels=('NER_Tag_Normalized', list))
        sentences = temp_df['Sentence'].values.tolist()
        labels = temp_df['Labels'].values.tolist()
        return sentences, labels
    
    def __len__(self):
        return len(self.sentences)
    
    def __getitem__(self, idx):
        if self.transform:
            raise NotImplementedError
        
        indices = []
        for token in self.sentences[idx]:
            indices.append(self.vocab[token])
        labels = []
        for label in self.labels[idx]:
            labels.append(self.label_dict[label])
        
        return torch.tensor(indices), torch.tensor(labels)

In [5]:
train_dataset = CoNLL2003Dataset(train_df, vocab, label_dict)
val_dataset = CoNLL2003Dataset(val_df, vocab, label_dict)
test_dataset = CoNLL2003Dataset(test_df, vocab, label_dict)

In [6]:
assert vocab[train_df.iloc[0]['Token']] == train_dataset[0][0][0]

In [7]:
assert label_dict[train_df.iloc[0]['NER_Tag_Normalized']] == train_dataset[0][1][0]

In [8]:
small_batch = []
small_batch.append(torch.tensor(train_dataset[0][0]))
small_batch.append(torch.tensor(train_dataset[1][0]))
small_batch_lens = [len(x) for x in small_batch]

small_labels_batch = []
small_labels_batch.append(torch.tensor(train_dataset[0][1]))
small_labels_batch.append(torch.tensor(train_dataset[1][1]))

  small_batch.append(torch.tensor(train_dataset[0][0]))
  small_batch.append(torch.tensor(train_dataset[1][0]))
  small_labels_batch.append(torch.tensor(train_dataset[0][1]))
  small_labels_batch.append(torch.tensor(train_dataset[1][1]))


In [9]:
small_batch_padded = pad_sequence(small_batch, batch_first=True, padding_value=vocab['<pad>'])
small_labels_batch_padded = pad_sequence(small_labels_batch, batch_first=True, padding_value=-1)

In [10]:
small_batch_padded

tensor([[  964, 22406,   236,   771,     7,  4586,   210,  7683,     2],
        [  737,  2088,     1,     1,     1,     1,     1,     1,     1]])

In [11]:
small_labels_batch_padded

tensor([[ 0,  1,  2,  1,  1,  1,  2,  1,  1],
        [ 3,  3, -1, -1, -1, -1, -1, -1, -1]])

In [12]:
small_batch_lens

[9, 2]

In [13]:
packed = pack_padded_sequence(small_batch_padded, small_batch_lens, batch_first=True)

In [14]:
packed

PackedSequence(data=tensor([  964,   737, 22406,  2088,   236,   771,     7,  4586,   210,  7683,
            2]), batch_sizes=tensor([2, 2, 1, 1, 1, 1, 1, 1, 1]), sorted_indices=None, unsorted_indices=None)

In [15]:
batch, sequence_lengths = pad_packed_sequence(packed, batch_first=True, padding_value=vocab['<pad>'])

In [16]:
config = {
    'vocab_size': len(vocab),
    'embedding_dim': 128,
    'hidden_size': 128,
    'num_classes': len(label_dict),
    'batch_size': 16
}
config = SimpleNamespace(**config)

In [17]:
model = LSTM(config)

In [18]:
output = model((small_batch_padded, small_batch_lens))

In [19]:
def loss_fn(outputs, labels):
    labels = labels.reshape(-1)
    mask = (labels >= 0).float()
    labels = labels % outputs.shape[1]
    num_tokens = mask.sum()
    return -torch.sum(outputs[range(outputs.shape[0]), labels] * mask) / num_tokens

In [20]:
loss_fn(output, small_labels_batch_padded)

tensor(1.5694, grad_fn=<DivBackward0>)

In [21]:
def collate_fn(batch):
    sentence_indices, sentence_labels = zip(*batch)
    sentence_lens = [len(x) for x in sentence_indices]
    
    # vocab['<pad>'] = 1
    sentences_padded = pad_sequence(sentence_indices, batch_first=True, padding_value=1)
    labels_padded = pad_sequence(sentence_labels, batch_first=True, padding_value=-1)
    
    return (sentences_padded, sentence_lens), labels_padded

In [35]:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=16, collate_fn=collate_fn, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=16, collate_fn=collate_fn)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=16, collate_fn=collate_fn)

In [23]:
sample_batch = next(iter(train_dataloader))

In [24]:
optimizer = optim.Adam(model.parameters())

In [25]:
def get_predictions(output, lengths, concatenate=True):
    # extract predictions
    max_len = max(lengths)
    preds = output.argmax(dim=1)
    i = 0
    preds_list = []
    for length in lengths:
        start = i*max_len
        stop = start + length
        preds_list.append(preds[start:stop])
        i += 1
    if concatenate:
        return torch.cat(preds_list)
    return preds_list

def recover_labels(padded_labels, lengths):
    # extract labels
    max_len = max(lengths)
    labels_vector = padded_labels.reshape(-1)
    i = 0
    labels_list = []
    for length in lengths:
        start = i*max_len
        stop = start + length
        labels_list.append(labels_vector[start:stop])
        i += 1
    return torch.cat(labels_list)

In [26]:
def accuracy(output, sentences, labels, dataset='Train'):
    batch_preds = get_predictions(output, sentences[1])
    batch_labels = recover_labels(labels, sentences[1])
    raw_acc = accuracy_score(batch_preds, batch_labels)
    acc = round(raw_acc * 100, 2)
    # print(f'{dataset} accuracy score: {acc}%')
    return raw_acc

In [27]:
PATIENCE=3
running_patience = PATIENCE
best_val_loss = np.inf
record_loss = []
for i in range(5):
    model.train()
    print(f'Epoch number: {i+1}')
    j = 0
    for sentences, labels in train_dataloader:
        output = model(sentences)
        loss = loss_fn(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if j % 100 == 0:
            record_loss.append(loss)
            print(f'Sample train batch - loss value: {round(loss.item(), 4)} \t accuracy score: {round(metrics(output, sentences, labels)*100, 2)}%')
        j += 1
    
    # monitor validation loss
    model.eval()
    val_loss_scores = []
    val_acc_scores = []
    val_preds = []
    val_labels = []
    for sentences, labels in val_dataloader:
        output = model(sentences)
        loss = loss_fn(output, labels)
        val_loss_scores.append(loss.item())
        val_acc_scores.append(metrics(output, sentences, labels))
        val_preds += get_predictions(output, sentences[1]).tolist()
        val_labels += recover_labels(labels, sentences[1]).tolist()
    val_loss = round(np.mean(val_loss_scores), 4)
    val_acc = round(np.mean(val_acc_scores)*100, 2)
    print(f'Average validation loss: {val_loss}')
    print(f'Validation accuracy score: {val_acc}%')
    
    # stopping criterion
    if val_loss < best_val_loss:
        running_patience = PATIENCE
        best_val_loss = val_loss
    else:
        running_patience -= 1
        if running_patience == 0:
            print(f'Model has not improved for {PATIENCE} epochs. Stopping training.')
            break
    
    print()

Epoch number: 1
Sample train batch - loss value: 1.5594 	 accuracy score: 37.89%
Sample train batch - loss value: 0.5848 	 accuracy score: 82.83%
Sample train batch - loss value: 0.656 	 accuracy score: 77.13%
Sample train batch - loss value: 0.3748 	 accuracy score: 88.73%
Sample train batch - loss value: 0.4269 	 accuracy score: 85.38%
Sample train batch - loss value: 0.3668 	 accuracy score: 88.5%
Sample train batch - loss value: 0.2897 	 accuracy score: 90.36%
Sample train batch - loss value: 0.3677 	 accuracy score: 86.23%
Sample train batch - loss value: 0.2 	 accuracy score: 93.51%
Average validation loss: 0.3073
Validation accuracy score: 90.1%

Epoch number: 2
Sample train batch - loss value: 0.2267 	 accuracy score: 92.89%
Sample train batch - loss value: 0.3443 	 accuracy score: 91.11%
Sample train batch - loss value: 0.165 	 accuracy score: 95.11%
Sample train batch - loss value: 0.1788 	 accuracy score: 94.77%
Sample train batch - loss value: 0.1946 	 accuracy score: 93.68

In [34]:
label_set = list(set(val_preds).union(set(val_labels)))

print(classification_report(val_labels, val_preds, labels=label_set, target_names=list(label_dict.keys())))

              precision    recall  f1-score   support

         ORG       0.83      0.64      0.72      2092
           O       0.97      0.99      0.98     42759
        MISC       0.88      0.70      0.78      1268
         PER       0.78      0.85      0.82      3149
         LOC       0.85      0.82      0.84      2094

    accuracy                           0.95     51362
   macro avg       0.86      0.80      0.83     51362
weighted avg       0.95      0.95      0.95     51362



In [36]:
# monitor validation loss
model.eval()
val_loss_scores = []
val_acc_scores = []
val_preds = []
val_labels = []
for sentences, labels in test_dataloader:
    output = model(sentences)
    loss = loss_fn(output, labels)
    val_loss_scores.append(loss.item())
    val_acc_scores.append(metrics(output, sentences, labels))
    val_preds += get_predictions(output, sentences[1]).tolist()
    val_labels += recover_labels(labels, sentences[1]).tolist()
val_loss = round(np.mean(val_loss_scores), 4)
val_acc = round(np.mean(val_acc_scores)*100, 2)
print(f'Average validation loss: {val_loss}')
print(f'Validation accuracy score: {val_acc}%')

Average validation loss: 0.2872
Validation accuracy score: 91.39%


In [37]:
label_set = list(set(val_preds).union(set(val_labels)))

print(classification_report(val_labels, val_preds, labels=label_set, target_names=list(label_dict.keys())))

              precision    recall  f1-score   support

         ORG       0.84      0.56      0.67      2496
           O       0.96      0.98      0.97     38323
        MISC       0.75      0.61      0.68       918
         PER       0.67      0.77      0.72      2773
         LOC       0.79      0.78      0.78      1925

    accuracy                           0.93     46435
   macro avg       0.80      0.74      0.76     46435
weighted avg       0.93      0.93      0.92     46435

