In [1]:
from collections import Counter
from types import SimpleNamespace

import pandas as pd
import torch
import torchtext
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pad_packed_sequence
import torch.optim as optim

from kg.ner.model import LSTM

In [2]:
train_df = pd.read_csv('/Users/tmorrill002/Documents/datasets/conll/transformed/train.csv')

In [3]:
vocab = torchtext.vocab.Vocab(Counter(train_df['Token'].value_counts().to_dict()))
label_dict = {}
i = 0
for k in train_df['NER_Tag_Normalized'].unique():
    label_dict[k] = i
    i += 1

In [4]:
class CoNLL2003Dataset(torch.utils.data.Dataset):
    def __init__(self, df, vocab, label_dict, transform=None):
        self.df = df
        self.vocab = vocab
        self.label_dict = label_dict
        self.transform = transform
        self.sentences, self.labels = self._prepare_data()
    
    def _prepare_data(self):
        temp_df = self.df.groupby(['Article_ID', 'Sentence_ID'], as_index=False).agg(Sentence=('Token', list), Labels=('NER_Tag_Normalized', list))
        sentences = temp_df['Sentence'].values.tolist()
        labels = temp_df['Labels'].values.tolist()
        return sentences, labels
    
    def __len__(self):
        return len(self.sentences)
    
    def __getitem__(self, idx):
        if self.transform:
            raise NotImplementedError
        
        indices = []
        for token in self.sentences[idx]:
            indices.append(self.vocab[token])
        labels = []
        for label in self.labels[idx]:
            labels.append(self.label_dict[label])
        
        return torch.tensor(indices), torch.tensor(labels)

In [5]:
train_dataset = CoNLL2003Dataset(train_df, vocab, label_dict)

In [6]:
assert vocab[train_df.iloc[0]['Token']] == train_dataset[0][0][0]

In [7]:
assert label_dict[train_df.iloc[0]['NER_Tag_Normalized']] == train_dataset[0][1][0]

In [8]:
train_dataset[0]

(tensor([  964, 22406,   236,   771,     7,  4586,   210,  7683,     2]),
 tensor([0, 1, 2, 1, 1, 1, 2, 1, 1]))

In [9]:
small_batch = []
small_batch.append(torch.tensor(train_dataset[0][0]))
small_batch.append(torch.tensor(train_dataset[1][0]))
small_batch_lens = [len(x) for x in small_batch]

small_labels_batch = []
small_labels_batch.append(torch.tensor(train_dataset[0][1]))
small_labels_batch.append(torch.tensor(train_dataset[1][1]))

  small_batch.append(torch.tensor(train_dataset[0][0]))
  small_batch.append(torch.tensor(train_dataset[1][0]))
  small_labels_batch.append(torch.tensor(train_dataset[0][1]))
  small_labels_batch.append(torch.tensor(train_dataset[1][1]))


In [10]:
small_batch_padded = pad_sequence(small_batch, batch_first=True, padding_value=vocab['<pad>'])
small_labels_batch_padded = pad_sequence(small_labels_batch, batch_first=True, padding_value=-1)

In [11]:
small_batch_padded

tensor([[  964, 22406,   236,   771,     7,  4586,   210,  7683,     2],
        [  737,  2088,     1,     1,     1,     1,     1,     1,     1]])

In [12]:
small_labels_batch_padded

tensor([[ 0,  1,  2,  1,  1,  1,  2,  1,  1],
        [ 3,  3, -1, -1, -1, -1, -1, -1, -1]])

In [13]:
small_batch_lens

[9, 2]

In [14]:
packed = pack_padded_sequence(small_batch_padded, small_batch_lens, batch_first=True)

In [15]:
packed

PackedSequence(data=tensor([  964,   737, 22406,  2088,   236,   771,     7,  4586,   210,  7683,
            2]), batch_sizes=tensor([2, 2, 1, 1, 1, 1, 1, 1, 1]), sorted_indices=None, unsorted_indices=None)

In [16]:
batch, sequence_lengths = pad_packed_sequence(packed, batch_first=True, padding_value=vocab['<pad>'])

In [17]:
batch

tensor([[  964, 22406,   236,   771,     7,  4586,   210,  7683,     2],
        [  737,  2088,     1,     1,     1,     1,     1,     1,     1]])

In [18]:
config = {
    'vocab_size': len(vocab),
    'embedding_dim': 128,
    'hidden_size': 128,
    'num_classes': len(label_dict),
    'batch_size': 16
}
config = SimpleNamespace(**config)

In [19]:
model = LSTM(config)

In [20]:
output = model((small_batch_padded, small_batch_lens))

In [21]:
def loss_fn(outputs, labels):
    labels = labels.reshape(-1)
    mask = (labels >= 0).float()
    labels = labels % outputs.shape[1]
    num_tokens = mask.sum()
    return -torch.sum(outputs[range(outputs.shape[0]), labels] * mask) / num_tokens

In [22]:
loss_fn(output, small_labels_batch_padded)

tensor(1.6088, grad_fn=<DivBackward0>)

In [23]:
def collate_fn(batch):
    sentence_indices, sentence_labels = zip(*batch)
    sentence_lens = [len(x) for x in sentence_indices]
    
    # vocab['<pad>'] = 1
    sentences_padded = pad_sequence(sentence_indices, batch_first=True, padding_value=1)
    labels_padded = pad_sequence(sentence_labels, batch_first=True, padding_value=-1)
    
    return (sentences_padded, sentence_lens), labels_padded

In [24]:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=16, collate_fn=collate_fn)

In [25]:
sample_batch = next(iter(train_dataloader))

In [26]:
optimizer = optim.Adam(model.parameters())

In [None]:
model.train()
record_loss = []
for i in range(5):
    print(f'Epoch number: {i+1}')
    j = 0
    for sentences, labels in train_dataloader:
        output = model(sentences)
        loss = loss_fn(output, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if j % 100 == 0:
            record_loss.append(loss)
            print(loss)
        j += 1

Epoch number: 1
tensor(1.5981, grad_fn=<DivBackward0>)
tensor(0.4196, grad_fn=<DivBackward0>)
tensor(0.4185, grad_fn=<DivBackward0>)
tensor(0.7137, grad_fn=<DivBackward0>)
tensor(0.3584, grad_fn=<DivBackward0>)
tensor(0.5985, grad_fn=<DivBackward0>)
tensor(0.5766, grad_fn=<DivBackward0>)
tensor(0.4007, grad_fn=<DivBackward0>)
tensor(0.4434, grad_fn=<DivBackward0>)
Epoch number: 2
tensor(0.2937, grad_fn=<DivBackward0>)
tensor(0.2776, grad_fn=<DivBackward0>)
tensor(0.2724, grad_fn=<DivBackward0>)
tensor(0.2294, grad_fn=<DivBackward0>)
tensor(0.1081, grad_fn=<DivBackward0>)
tensor(0.4313, grad_fn=<DivBackward0>)
tensor(0.4322, grad_fn=<DivBackward0>)
tensor(0.2291, grad_fn=<DivBackward0>)
tensor(0.2432, grad_fn=<DivBackward0>)
Epoch number: 3
tensor(0.1745, grad_fn=<DivBackward0>)
tensor(0.2005, grad_fn=<DivBackward0>)
tensor(0.1723, grad_fn=<DivBackward0>)
tensor(0.0786, grad_fn=<DivBackward0>)
tensor(0.0304, grad_fn=<DivBackward0>)
tensor(0.2829, grad_fn=<DivBackward0>)
tensor(0.2776, g