In [46]:
import os
import gc
import math
import random
from collections import defaultdict

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from torch.optim import Adam
from tqdm import tqdm_notebook as tqdm

from sklearn.metrics import classification_report, f1_score

import warnings
warnings.filterwarnings('ignore')


In [47]:
data_path = '/kaggle/input/conll003-englishversion/'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [48]:
def load_sentences(filepath):
    final = []
    sentences = []
    with open(filepath, 'r') as f:
        for line in f.readlines():
            if line == '-DOCSTART- -X- -X- O\n' or line == '\n':
                if len(sentences) > 0:
                    final.append(sentences)
                    sentences = []
            else:
                l = line.split(' ')
                sentences.append((l[0], l[3].strip('\n')))
    return final

In [49]:
train_sents = load_sentences(data_path + 'train.txt')
test_sents = load_sentences(data_path + 'test.txt')
val_sents = load_sentences(data_path + 'valid.txt')

train_sents[:3]

[[('EU', 'B-ORG'),
  ('rejects', 'O'),
  ('German', 'B-MISC'),
  ('call', 'O'),
  ('to', 'O'),
  ('boycott', 'O'),
  ('British', 'B-MISC'),
  ('lamb', 'O'),
  ('.', 'O')],
 [('Peter', 'B-PER'), ('Blackburn', 'I-PER')],
 [('BRUSSELS', 'B-LOC'), ('1996-08-22', 'O')]]

In [50]:
ner_labels = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC']
id2label = {str(i): label for i, label in enumerate(ner_labels)}
label2id = {value: int(key) for key, value in id2label.items()}


In [51]:
def get_df(samples):
    df,label = [], []
    for lines in samples:
        cur_line, cur_label = list(zip(*lines))
        df.append(list(cur_line))
        label.append([label2id[i] for i in cur_label])
    return {'text':df, 'label':label}
    
    
train_df = get_df(train_sents)
test_df = get_df(test_sents)
val_df = get_df(val_sents)


In [52]:
word_dict = defaultdict(int)

for line in train_df['text']:
    for word in line:
        word_dict[word] += 1

lower_freq_word = []
for k,v in word_dict.items():
    if v < 2:
        lower_freq_word.append(k)

for word in lower_freq_word:
    del word_dict[word]
    
word_dict['<UNK>'] = -1
word_dict['<PAD>'] = -2

word2id = {}

for idx, word in enumerate(word_dict.keys()):
    word2id[word] = idx

In [53]:
def prepare_sequence(seq, to_ix):
    idxs = []
    for w in seq:
        if w in to_ix.keys():
            idxs.append(to_ix[w])
        else:
            idxs.append(to_ix['<UNK>'])
    return idxs


In [54]:
class CoNLLDataset(Dataset):
    def __init__(self, df):
        self.texts = df['text']
        self.labels = df['label']

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        inputs = prepare_sequence(self.texts[item], word2id)
        label = self.labels[item]
        return {
            'input_ids': inputs,
            'labels': label
        }


In [55]:
class Collate:
    def __init__(self, train):
        self.train = train

    def __call__(self, batch):
        output = dict()
        output["input_ids"] = [sample["input_ids"] for sample in batch]
        if self.train:
            output["labels"] = [sample["labels"] for sample in batch]

        # calculate max token length of this batch
        batch_max = max([len(ids) for ids in output["input_ids"]])

        # add padding

        output["input_ids"] = [s + (batch_max - len(s)) * [word2id['<PAD>']] for s in output["input_ids"]]
        if self.train:
            output['labels'] = [s + (batch_max - len(s)) * [-100] for s in output["labels"]]

        # convert to tensors
        output["input_ids"] = torch.tensor(output["input_ids"], dtype=torch.long)
        if self.train:
            output["labels"] = torch.tensor(output["labels"], dtype=torch.long)

        return output
    
collate_fn = Collate(True)


In [56]:
class BiLSTMTagger(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, vocab_size, output_size, embeddings=None):
        super(BiLSTMTagger, self).__init__()
        
        # 1. Embedding Layer
        if embeddings is None:
            self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        else:
            self.embeddings = nn.Embedding.from_pretrained(embeddings)
        
        # 2. LSTM Layer
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, bidirectional=True, num_layers=3, batch_first=True)

        # 3. Dense Layer
        self.fc = nn.Linear(2*hidden_dim, output_size)
        
    def forward(self, batch_text):

        embeddings = self.embeddings(batch_text)
        
        lstm_output, _ = self.lstm(embeddings) 

        logits = self.fc(lstm_output)
        return logits


In [57]:
def remove_predictions_for_masked_items(predicted_labels, correct_labels): 

    predicted_labels_without_mask = []
    correct_labels_without_mask = []
        
    for p, c in zip(predicted_labels, correct_labels):
        if c > 0:
            predicted_labels_without_mask.append(p)
            correct_labels_without_mask.append(c)
            
    return predicted_labels_without_mask, correct_labels_without_mask


In [58]:
def train(model, train_loader, val_loader, batch_size, max_epochs, num_batches, patience, output_path):
    criterion = nn.CrossEntropyLoss(ignore_index=-100)  # we mask the <pad> labels
    optimizer = Adam(model.parameters())

    train_f_score_history = []
    dev_f_score_history = []
    no_improvement = 0
    for epoch in range(max_epochs):

        total_loss = 0
        predictions, correct = [], []
        model.train()
        for batch in tqdm(train_loader, total=num_batches, desc=f"Epoch {epoch}"):
            
            cur_batch_size, text_length = batch['input_ids'].shape
            
            pred = model(batch['input_ids'].to(device)).view(cur_batch_size*text_length, NUM_CLASSES)
            gold = batch['labels'].to(device).view(cur_batch_size*text_length)
            
            loss = criterion(pred, gold)
            
            total_loss += loss.item()
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            _, pred_indices = torch.max(pred, 1)
            
            predicted_labels = list(pred_indices.cpu().numpy())
            correct_labels = list(batch['labels'].view(cur_batch_size*text_length).numpy())
            
            predicted_labels, correct_labels = remove_predictions_for_masked_items(predicted_labels, 
                                                                                   correct_labels)
            
            predictions += predicted_labels
            correct += correct_labels

        train_score = f1_score(correct, predictions, average="macro")
        train_f_score_history.append(train_score)
            
        print("Total training loss:", total_loss)
        print("Training Macro F1:", train_score)
        
        total_loss = 0
        predictions, correct = [], []
        
        model.eval()
        with torch.no_grad():
            for batch in val_loader:

                cur_batch_size, text_length = batch['input_ids'].shape
                
                pred = model(batch['input_ids'].to(device)).view(cur_batch_size*text_length, NUM_CLASSES)
                gold = batch['labels'].to(device).view(cur_batch_size*text_length)
                
                loss = criterion(pred, gold)
                total_loss += loss.item()

                _, pred_indices = torch.max(pred, 1)
                predicted_labels = list(pred_indices.cpu().numpy())
                correct_labels = list(batch['labels'].view(cur_batch_size*text_length).numpy())

                predicted_labels, correct_labels = remove_predictions_for_masked_items(predicted_labels, 
                                                                                       correct_labels)

                predictions += predicted_labels
                correct += correct_labels

        dev_score = f1_score(correct, predictions, average="macro")
            
        print("Total validation loss:", total_loss)
        print("Validation Macro F1:", dev_score)
        
        dev_f = dev_score
        if len(dev_f_score_history) > patience and dev_f < max(dev_f_score_history):
            no_improvement += 1

        elif len(dev_f_score_history) == 0 or dev_f > max(dev_f_score_history):
            print("Saving model.")
            torch.save(model, output_path)
            no_improvement = 0
            
        if no_improvement > patience:
            print("Validation F-score does not improve anymore. Stop training.")
            dev_f_score_history.append(dev_f)
            break
            
        dev_f_score_history.append(dev_f)
        
    return train_f_score_history, dev_f_score_history


In [59]:
def test(model, test_iter, batch_size, labels, target_names): 
    total_loss = 0
    predictions, correct = [], []
    
    model.eval()
    with torch.no_grad():    
    
        for batch in test_iter:

            cur_batch_size, text_length = batch['input_ids'].shape

            pred = model(batch['input_ids'].to(device)).view(cur_batch_size*text_length, NUM_CLASSES)
            gold = batch['labels'].to(device).view(cur_batch_size*text_length)

            _, pred_indices = torch.max(pred, 1)
            predicted_labels = list(pred_indices.cpu().numpy())
            correct_labels = list(batch['labels'].view(cur_batch_size*text_length).numpy())

            predicted_labels, correct_labels = remove_predictions_for_masked_items(predicted_labels, 
                                                                                   correct_labels)

            predictions += predicted_labels
            correct += correct_labels
    
    print(classification_report(correct, predictions, labels=labels, target_names=target_names))


In [60]:
EMBEDDING_DIM = 100
HIDDEN_DIM = 64
NUM_CLASSES = len(id2label)
MAX_EPOCHS = 50
PATIENCE = 3
BATCH_SIZE = 32
VOCAB_SIZE = len(word2id)
OUTPUT_PATH = "/tmp/bilstmtagger"
num_batches = math.ceil(len(train_df) / BATCH_SIZE)


In [61]:
train_dataset = CoNLLDataset(train_df)
val_dataset = CoNLLDataset(val_df)
test_dataset = CoNLLDataset(test_df)

train_loader = DataLoader(train_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=False,
                              collate_fn=collate_fn,
                              num_workers=4,
                              pin_memory=True,
                              drop_last=False)

val_loader = DataLoader(val_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=False,
                              collate_fn=collate_fn,
                              num_workers=4,
                              pin_memory=True,
                              drop_last=False)

test_loader = DataLoader(test_dataset,
                              batch_size=BATCH_SIZE,
                              shuffle=False,
                              collate_fn=collate_fn,
                              num_workers=4,
                              pin_memory=True,
                              drop_last=False)


In [62]:
tagger = BiLSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, VOCAB_SIZE+2, NUM_CLASSES) 
tagger


BiLSTMTagger(
  (embeddings): Embedding(11986, 100)
  (lstm): LSTM(100, 64, num_layers=3, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=128, out_features=9, bias=True)
)

In [63]:
train_f, dev_f = train(tagger.to(device), train_loader, val_loader, BATCH_SIZE, MAX_EPOCHS, 
                       num_batches, PATIENCE, OUTPUT_PATH)


Epoch 0:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 344.91125263273716
Training Macro F1: 0.015759030414735937
Total validation loss: 56.12493021786213
Validation Macro F1: 0.12278908939995542
Saving model.


Epoch 1:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 169.40841363370419
Training Macro F1: 0.34511809616674805
Total validation loss: 32.19121523946524
Validation Macro F1: 0.3909965080443668
Saving model.


Epoch 2:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 94.49006951414049
Training Macro F1: 0.5633081707952685
Total validation loss: 24.53801471926272
Validation Macro F1: 0.5754467136828897
Saving model.


Epoch 3:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 57.38383488636464
Training Macro F1: 0.6982772593818211
Total validation loss: 19.97235928615555
Validation Macro F1: 0.6620841531433297
Saving model.


Epoch 4:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 36.80419745715335
Training Macro F1: 0.7702704142394489
Total validation loss: 18.357784854946658
Validation Macro F1: 0.6912155592611796
Saving model.


Epoch 5:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 24.71564652340021
Training Macro F1: 0.8143480469290479
Total validation loss: 18.344960731221363
Validation Macro F1: 0.7007456356658951
Saving model.


Epoch 6:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 18.02895562676713
Training Macro F1: 0.8380445039382818
Total validation loss: 18.702265097934287
Validation Macro F1: 0.7053333844190042
Saving model.


Epoch 7:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 12.79710330144735
Training Macro F1: 0.8544786054049311
Total validation loss: 19.24992256623227
Validation Macro F1: 0.7079567246533347
Saving model.


Epoch 8:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 9.996878715639468
Training Macro F1: 0.8638350364609675
Total validation loss: 20.82966603297973
Validation Macro F1: 0.7128380770052578
Saving model.


Epoch 9:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 8.351324024581118
Training Macro F1: 0.867895969663346
Total validation loss: 20.563435456497245
Validation Macro F1: 0.7179940495111077
Saving model.


Epoch 10:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 7.492537721525878
Training Macro F1: 0.8697717958699429
Total validation loss: 21.29669574102445
Validation Macro F1: 0.7193381615401253
Saving model.


Epoch 11:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 6.228401094733272
Training Macro F1: 0.8728223420514779
Total validation loss: 21.214801813548547
Validation Macro F1: 0.7293125129767131
Saving model.


Epoch 12:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 4.595958497564425
Training Macro F1: 0.8775759580764044
Total validation loss: 21.35615302968654
Validation Macro F1: 0.7284069577117012


Epoch 13:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 3.822086682092049
Training Macro F1: 0.8797246720832006
Total validation loss: 22.77894842135356
Validation Macro F1: 0.7229216199709314


Epoch 14:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 3.4123341184549645
Training Macro F1: 0.8814040693502462
Total validation loss: 23.280070661312493
Validation Macro F1: 0.7312843020522896
Saving model.


Epoch 15:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 3.04851106451315
Training Macro F1: 0.8820989537208375
Total validation loss: 23.38572436825416
Validation Macro F1: 0.7312267384319441


Epoch 16:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 2.4467578478415817
Training Macro F1: 0.8831553081997341
Total validation loss: 25.135275926100803
Validation Macro F1: 0.7289843343072056


Epoch 17:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 1.93056082939529
Training Macro F1: 0.8849791436329109
Total validation loss: 27.75272669313017
Validation Macro F1: 0.7249804365335615


Epoch 18:   0%|          | 0/1 [00:00<?, ?it/s]

Total training loss: 2.9225483823811373
Training Macro F1: 0.8814133408316237
Total validation loss: 26.422789808002562
Validation Macro F1: 0.7277859207371847
Validation F-score does not improve anymore. Stop training.
