In [1]:
import collections
import math
import numpy as np
import pandas as pd
import os
import random
import torch
import torch.nn as nn
import zipfile
from matplotlib import pylab
from six.moves import range
from six.moves.urllib.request import urlretrieve
from torch.nn.utils.rnn import pad_sequence
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
from torchtext.data.utils import get_tokenizer, ngrams_iterator
from torchtext.datasets import DATASETS
from torchtext.utils import download_from_url
from torchtext.vocab import build_vocab_from_iterator
import torch.nn as nn
from torchtext.data.utils import get_tokenizer
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torchtext.vocab import FastText, CharNGram
from itertools import chain

In [2]:
# Download the data

DATA_URL = 'https://github.com/ZihanWangKi/CrossWeigh/raw/master/data/'
DATA_DIR = 'data'

def download_file(url, filename, data_dir, expected_bytes):
    os.makedirs(data_dir, exist_ok=True)
    
    file_path = os.path.join(data_dir, filename)
    if not os.path.exists(file_path):
        file_path, _ = urlretrieve(url + filename, file_path)
    else:
        file_path = os.path.join(data_dir, filename)
    
    stat_info = os.stat(file_path)
    if stat_info.st_size == expected_bytes:
        print(f'Found and verified {file_path}')
    else:
        print(stat_info.st_size)
        raise Exception(f'Failed to verify {file_path}. Can you retrieve it with a browser?')
        
    return file_path

train_file = download_file(DATA_URL, 'conllpp_train.txt', DATA_DIR, 3283420)
dev_file = download_file(DATA_URL, 'conllpp_dev.txt', DATA_DIR, 827443)  
test_file = download_file(DATA_URL, 'conllpp_test.txt', DATA_DIR, 748737)

Found and verified data/conllpp_train.txt
Found and verified data/conllpp_dev.txt
Found and verified data/conllpp_test.txt


In [3]:
!head data/conllpp_train.txt

-DOCSTART- -X- -X- O

EU NNP B-NP B-ORG
rejects VBZ B-VP O
German JJ B-NP B-MISC
call NN I-NP O
to TO B-VP O
boycott VB I-VP O
British JJ B-NP B-MISC
lamb NN I-NP O


In [4]:
def load_data(file_path):
    print("Loading data...")
    sentences, labels = [], []
    
    with open(file_path, 'r', encoding='latin-1') as file:
        sentence_start = True
        sentence_tokens = []
        sentence_labels = []
        
        for line in file:
            if len(line.strip()) == 0 or line.split(' ')[0] == '-DOCSTART-':
                sentence_start = False
            else:
                sentence_start = True
                token, _, _, label = line.split(' ')
                sentence_tokens.append(token)
                sentence_labels.append(label.strip())
            
            if not sentence_start and len(sentence_tokens) > 0:
                sentences.append(' '.join(sentence_tokens))
                labels.append(sentence_labels)
                sentence_tokens, sentence_labels = [], []
    
    print('\tDone')
    return sentences, labels

In [5]:
train_sentences, train_labels = load_data(train_file)
dev_sentences, dev_labels = load_data(dev_file)
test_sentences, test_labels = load_data(test_file)
# Print some stats
print(f"Train size: {len(train_labels)}")
print(f"Dev size: {len(dev_labels)}")
print(f"Test size: {len(test_labels)}")

# Print some data
print('\nSample data\n')
for sent, labels in zip(dev_sentences[:5], dev_labels[:5]):
    print(f"Sentence: {sent}")
    print(f"Labels: {labels}")
    assert(len(sent.split(' ')) == len(labels))
    print('\n')

Loading data...
	Done
Loading data...
	Done
Loading data...
	Done
Train size: 14041
Dev size: 3250
Test size: 3452

Sample data

Sentence: CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .
Labels: ['O', 'O', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


Sentence: LONDON 1996-08-30
Labels: ['B-LOC', 'O']


Sentence: West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39 runs in two days to take over at the head of the county championship .
Labels: ['B-MISC', 'I-MISC', 'O', 'B-PER', 'I-PER', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORG', 'O', 'B-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


Sentence: Their stay on top , though , may be short-lived as title rivals Essex , Derbyshire and Surrey all closed in on victory while Kent made up for lost time in their rain-affected match against Nottinghamshire .
Labels: ['O', 'O', 'O', 'O', 'O', 'O', 'O', '

In [6]:
class SentenceTokenizer:
    def __call__(self, sentence):
        return sentence.lower().split(' ')
    
class WordTokenizer:
    def __call__(self, word):
        return [c for c in word.lower()]

In [8]:
sentence_tokenizer = SentenceTokenizer()
word_tokenizer = WordTokenizer()

In [9]:
sentences = train_sentences + test_sentences + dev_sentences
all_labels = train_labels + test_labels + dev_labels

In [10]:
def yield_word_tokens(data):
    for sentence in data:
        yield sentence_tokenizer(sentence)
        
def yield_char_tokens(data):
    for word_tokens in yield_word_tokens(data):
        for word_token in word_tokens:
            yield word_tokenizer(word_token)

In [11]:
word_vocab = build_vocab_from_iterator(yield_word_tokens(sentences), specials=('<pad>', '<unk>'))
char_vocab = build_vocab_from_iterator(yield_char_tokens(sentences), specials=('<pad>', '<unk>'))

In [12]:
# Get the word to idx and idx to char dictionaries
word_to_idx = word_vocab.get_stoi()
idx_to_word = word_vocab.get_itos()
# Get the char to idx and idx to char dictionaries
char_to_idx = char_vocab.get_stoi()
idx_to_char = char_vocab.get_itos()

In [13]:
def get_label_mappings(labels):
    unique_labels = pd.Series(chain(*labels)).unique()
    label_to_idx = dict(zip(unique_labels, np.arange(unique_labels.shape[0])))
    idx_to_label = {i: label for label, i in label_to_idx.items()}
    
    label_weights = {}
    label_counts = pd.Series(chain(*labels)).value_counts()
    
    for label, count in label_counts.items():
        label_weights[label_to_idx[label]] = label_counts.min() / count
    
    return label_to_idx, idx_to_label, label_weights

In [14]:
label_to_idx, idx_to_label, label_weights = get_label_mappings(train_labels)

In [15]:
for label, idx in label_to_idx.items():
    assert(label == idx_to_label[idx])
    assert(idx in label_weights)

In [16]:
# Get the weights per class as a tensor
class_weights = torch.zeros(len(label_weights))
for i, weight in label_weights.items():
    class_weights[i] = weight

In [18]:
labels = pd.Series(chain(*train_labels))

In [19]:
# Check for class balance

print("Training label counts:")  
print(pd.Series(chain(*train_labels)).value_counts())

print("\nValidation label counts:")
print(pd.Series(chain(*dev_labels)).value_counts())

print("\nTest label counts:")  
print(pd.Series(chain(*test_labels)).value_counts())

Training label counts:
O         169578
B-LOC       7140
B-PER       6600
B-ORG       6321
I-PER       4528
I-ORG       3704
B-MISC      3438
I-LOC       1157
I-MISC      1155
Name: count, dtype: int64

Validation label counts:
O         42759
B-PER      1842
B-LOC      1837
B-ORG      1341
I-PER      1307
B-MISC      922
I-ORG       751
I-MISC      346
I-LOC       257
Name: count, dtype: int64

Test label counts:
O         38143
B-ORG      1714
B-LOC      1645
B-PER      1617
I-PER      1161
I-ORG       881
B-MISC      722
I-LOC       259
I-MISC      252
Name: count, dtype: int64


In [20]:
# Series length
pd.Series(train_sentences).str.split().str.len().describe(percentiles=[0.05, 0.95])

count    14041.000000
mean        14.501887
std         11.602756
min          1.000000
5%           2.000000
50%         10.000000
95%         37.000000
max        113.000000
dtype: float64

### Parameters

In [21]:
# Size of token embeddings
EMBEDDING_DIM = 300
# Number of hidden units in the GRU layer
HIDDEN_DIM = 64
# Number of hidden units in the GRU layer
CHAR_DIM = 32
# Number of output nodes in the last layer
NUM_CLASSES = len(idx_to_label)

BATCH_SIZE = 128  
EPOCHS = 25
FAST_TEXT = FastText("simple")
LEARNING_RATE = 1.0
MAX_WORD_LENGTH = 12

DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [22]:
def collate_batch(batch):
    label_list, sentence_list, sentence_lengths = [], [], []
    word_list = []

    for sentence, words, labels in batch:
        sentence_list.append(torch.tensor(sentence, dtype=torch.int64))
        sentence_lengths.append(len(sentence))
        label_list.append(torch.tensor(labels, dtype=torch.int64))
        word_list.append(torch.tensor(words, dtype=torch.int64))
            
    return (
        nn.utils.rnn.pad_sequence(sentence_list, batch_first=True).to(DEVICE),
        nn.utils.rnn.pad_sequence(label_list, batch_first=True, padding_value=-1).to(DEVICE),    
        torch.tensor(sentence_lengths).to(DEVICE),
        nn.utils.rnn.pad_sequence(word_list, batch_first=True).to(DEVICE)
    )

In [23]:
def get_dataloader(sentences, labels):
    data = []

    for sentence, labels in zip(sentences, labels):
        word_tokens = sentence_tokenizer(sentence)
        int_sentence = word_vocab(word_tokens)
        int_words = []
        for word_token in word_tokens:
            int_words.append(char_vocab(word_tokenizer(word_token[:MAX_WORD_LENGTH]) + max(0, MAX_WORD_LENGTH - len(word_token)) * ['<pad>']))
                    
        labels = [label_to_idx[label] for label in labels]
        assert(len(int_sentence) == len(labels))
        data.append([int_sentence, int_words, labels])
        
    return DataLoader(data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)    


In [24]:
train_dataloader = get_dataloader(train_sentences, train_labels)
dev_dataloader = get_dataloader(dev_sentences, dev_labels)
test_dataloader = get_dataloader(test_sentences, test_labels)

In [25]:
# Define the model

class NERModel(nn.Module):
    def __init__(self, num_classes, embedding_dim, hidden_dim, initialize=True, fine_tune_embeddings=True, use_conv_embeddings=True):
        super(NERModel, self).__init__()
        self.vocab_size = len(word_vocab)
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        self.char_dim = 32
        self.kernel_size = 5
        self.max_word_length = MAX_WORD_LENGTH
        self.use_conv_embeddings = use_conv_embeddings
        
        if self.use_conv_embeddings:
            self.conv = nn.Conv1d(self.char_dim, self.char_dim, self.kernel_size)
            self.max_pool = nn.MaxPool1d(self.max_word_length - self.kernel_size + 1)
            
        self.embedding = nn.Embedding(len(word_vocab), embedding_dim if not initialize else 300, padding_idx=0)
        
        self.char_embedding = nn.Embedding(len(char_vocab), self.char_dim, padding_idx=0)
        
        if initialize:
            self.embedding.weight.requires_grad = False
            for i in range(len(word_vocab)):
                token = word_vocab.lookup_token(i)
                self.embedding.weight[i, :] = FAST_TEXT.get_vecs_by_tokens(token, lower_case_backup=True)
            self.embedding.weight.requires_grad = True
        else:
            self.init_weights()
                
        if not fine_tune_embeddings:
            self.embedding.weight.requires_grad = False
        
        self.rnn = nn.GRU(self.embedding_dim + self.char_dim, self.hidden_dim, batch_first=True, bidirectional=True)

        self.fc = nn.Linear(2 * self.hidden_dim, num_classes)

        self.dropout = nn.Dropout(0.3)
        
    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        if self.use_conv_embeddings:
            self.char_embedding.weight.data.uniform_(-initrange, initrange)

    def forward(self, sentences, lengths, words):
        embedded_sentences = self.embedding(sentences.int()) 
        
        if self.use_conv_embeddings:                        
            embedded_words = self.char_embedding(words.int())
                                                
            N, L_sentence, L_word, D_char = embedded_words.shape
            
            embedded_words = embedded_words.view(N * L_sentence, L_word, -1)
            embedded_words = torch.swapaxes(embedded_words, 2, 1)
            embedded_words = self.conv(embedded_words)
            embedded_words = self.max_pool(embedded_words).squeeze()
            embedded_words = embedded_words.view(N, L_sentence, -1)
                      
            embedded_sentences = torch.cat([embedded_sentences, embedded_words], axis=-1)
            
        embedded_sentences = nn.utils.rnn.pack_padded_sequence(embedded_sentences, lengths.cpu().numpy(), enforce_sorted=False, batch_first=True)
        
        logits, _ = self.rnn(embedded_sentences)
         
        logits, _ = nn.utils.rnn.pad_packed_sequence(logits, batch_first=True)

        logits = self.fc(logits)
        
        return logits

In [26]:
criterion = torch.nn.CrossEntropyLoss(weight=class_weights, ignore_index=-1).to(DEVICE)

model = NERModel(NUM_CLASSES, EMBEDDING_DIM, HIDDEN_DIM, initialize=True, fine_tune_embeddings=True, use_conv_embeddings=True).to(DEVICE)

optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)


In [27]:
def train_epoch(dataloader, model, optimizer, criterion, epoch):
    model.train()
    total_acc, total_count = 0, 0
    total_loss, total_batches = 0.0, 0.0
    log_interval = 50

    for idx, (sentences, labels, lengths, words) in enumerate(dataloader):
        optimizer.zero_grad()
                        
        logits = model(sentences, lengths, words)
                           
        N, L, _ = logits.shape
        logits = logits.view(N * L, -1)
        labels = labels.view(N * L)
        loss = criterion(input=logits, target=labels)
        
        total_loss += loss.item()
        total_batches += 1
        
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        
        optimizer.step()
        model.eval()

        masks = (labels != -1)
        total_acc += (logits.argmax(-1) == labels)[masks].sum().item()
        total_count += masks.sum()

        model.train()
        if idx % log_interval == 0 and idx > 0:
            print(f"| epoch {epoch:3d} | {idx:5d}/{len(dataloader):5d} batches | accuracy {total_acc / total_count:8.3f} | loss {total_loss / total_batches:8.3f}")
            total_acc, total_count = 0, 0
            total_loss, total_batches = 0.0, 0.0

In [28]:
def evaluate(dataloader, model):
    model.eval()
    total_acc, total_count = 0, 0
    total_loss, total_batches = 0.0, 0.0

    with torch.no_grad():
        for idx, (sentences, labels, lengths, words) in enumerate(dataloader):
            logits = model(sentences, lengths, words)
            N, L, _ = logits.shape
            logits = logits.view(N * L, -1)
            labels = labels.view(N * L)
        
            total_loss += criterion(input=logits, target=labels)
            total_batches += 1
        
            masks = (labels != -1)
            total_acc += (logits.argmax(-1) == labels)[masks].sum().item()
            total_count += masks.sum()
        
    return total_acc / total_count, total_loss / total_batches

In [29]:
import time

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train_epoch(train_dataloader, model, optimizer, criterion, epoch)
    accuracy, loss = evaluate(dev_dataloader, model)
    scheduler.step()
    print("-" * 59)
    print(f"| end of epoch {epoch:3d} | time: {time.time() - epoch_start_time:5.2f}s | valid accuracy {accuracy:8.3f} | valid loss {loss:8.3f}")
    print("-" * 59)

print("Checking the results on test set...")
test_accuracy, test_loss = evaluate(test_dataloader, model)
print(f"test accuracy {test_accuracy:8.3f} | test loss {test_loss:8.3f}")

| epoch   1 |    50/  110 batches | accuracy    0.476 | loss    1.987
| epoch   1 |   100/  110 batches | accuracy    0.766 | loss    1.481
-----------------------------------------------------------
| end of epoch   1 | time: 16.12s | valid accuracy    0.776 | valid loss    1.223
-----------------------------------------------------------
| epoch   2 |    50/  110 batches | accuracy    0.768 | loss    1.201
| epoch   2 |   100/  110 batches | accuracy    0.771 | loss    1.201
-----------------------------------------------------------
| end of epoch   2 | time: 15.68s | valid accuracy    0.779 | valid loss    1.159
-----------------------------------------------------------
| epoch   3 |    50/  110 batches | accuracy    0.770 | loss    1.163
| epoch   3 |   100/  110 batches | accuracy    0.771 | loss    1.168
-----------------------------------------------------------
| end of epoch   3 | time: 16.92s | valid accuracy    0.780 | valid loss    1.159
----------------------------------