# Setup

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0,1"

In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import re
import time
import pickle

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from collections import Counter

from sklearn import metrics
from sklearn.metrics import classification_report, f1_score

In [18]:
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f17580572b0>

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
dataset = pickle.load(open('conll_graph.pickle', 'rb'))
print(', '.join([split + f' : {len(dataset[split])}' for split in dataset]))

train : 178610, validation : 44900, test : 40760


In [5]:
dataset['train'][2]

{'word': ['german'],
 'label': 'MISC',
 'gt_label': 'B-MISC',
 'surface': 'German',
 'pos': ['<JJ>'],
 'chunk': ['<NP>'],
 'classes': ['<GEOREGION>', '<NAME>', '<GIVEN NAME>', '<FAMILY NAME>'],
 'extra': ['<CAPITALIZED>'],
 'left_context': ['eu', 'rejects'],
 'right_context': ['call', 'to', 'boycott', 'british', 'lamb']}

In [6]:
vocabulary = pickle.load(open('vocabulary.pickle', 'rb'))
print(', '.join([key + f' : {len(vocabulary[key])}' for key in vocabulary]))

word : 18823, chunk : 11, pos : 25, classes : 14, extra : 3


In [7]:
labels = pickle.load(open('labels.pickle', 'rb'))
label2id = {l: i for i, l in enumerate(labels)}
print(label2id)

{'LOC': 0, 'MISC': 1, 'O': 2, 'ORG': 3, 'PER': 4}


In [8]:
voc2id = {}
for key in vocabulary:
    voc2id[key] = {l: i for i, l in enumerate(vocabulary[key])}
print(voc2id['word']['ismail'], voc2id['extra']['<ACRONYM>'], )

8532 0


# Dataset

In [9]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset, split, voc2id=voc2id, label2id=label2id, context='all'):
        X = []
        Y = []
        
        for doc in tqdm(dataset[split], desc=split.upper()):
            # Create the graph one-hot embedding
            onehot = {}
            for key in voc2id:
                onehot[key] = np.zeros(len(voc2id[key]))
                for v in doc[key]:
                    onehot[key][voc2id[key][v]] = 1.
            
            onehot['left'] = np.zeros(len(voc2id['word']))
            onehot['right'] = np.zeros(len(voc2id['word']))
            
            if context == 'all':
                context = max(len(doc['left_context']), len(doc['right_context']))
            
            for w in doc['left_context'][-context:]:
                onehot['left'][voc2id['word'][w]] = 1.
            for w in doc['right_context'][:context]:
                onehot['right'][voc2id['word'][w]] = 1.
            
            doc_embedding = np.concatenate([onehot['word'], onehot['left'], onehot['right'],
                                            onehot['pos'], onehot['chunk'], onehot['extra'], onehot['classes']])
            X.append(doc_embedding)
            Y.append(label2id[doc['label']])
        
        self.X = X
        self.Y = Y
        self.X_len = len(X)
        self.labels = sorted(label2id.keys())
        self.label2id = label2id
        self.voc2id = voc2id

    def __len__(self):
        return self.X_len

    def __getitem__(self, index):
        x = torch.tensor(self.X[index])
        y = torch.tensor(self.Y[index])
        
        return x, y

    def labels(self):
        return self.labels
    
    def voc(self, key):
        return self.voc2id[key]
    
    def Y(self):
        return self.Y

In [10]:
train_set = Dataset(dataset, 'train', context=3)
dev_set = Dataset(dataset, 'validation', context=3)
test_set = Dataset(dataset, 'test', context=3)

HBox(children=(FloatProgress(value=0.0, description='TRAIN', max=178610.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='VALIDATION', max=44900.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='TEST', max=40760.0, style=ProgressStyle(description_width…




In [35]:
batch_size = 64
num_workers = 4

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, num_workers=num_workers, shuffle=True)
dev_loader = torch.utils.data.DataLoader(dev_set, batch_size=batch_size, num_workers=num_workers, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=num_workers, shuffle=False)

In [36]:
input_dim = 0
for local_features, local_labels in train_loader:
    input_dim = local_features.shape[1]
    print(local_features.shape)
    print(local_labels.shape)
    break

torch.Size([64, 56522])
torch.Size([64])


In [37]:
input_dim

56522

In [38]:
training_counter = Counter(train_set.Y)
print(training_counter)

Counter({2: 144631, 4: 11124, 3: 9984, 0: 8288, 1: 4583})


In [39]:
labels

['LOC', 'MISC', 'O', 'ORG', 'PER']

# The Model

In [40]:
def backprop(batch_X, batch_Y, model, optimizer, loss_fn):
    Y_hat = model(batch_X)
    loss = loss_fn(Y_hat, batch_Y)
    loss.backward()
    optimizer.step()
    
    return loss.item()

class FeedForwardNetwork(nn.Module):
    def __init__(self, input_dim=input_dim, hidden_dim=256, output_dim=5, dropout_rate=0.2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fch = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

        # extra layers layers
        self.dropout = nn.Dropout(p=dropout_rate)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fch(x)
        x = self.relu(x)
        x = self.dropout(x)
        logits = self.fc2(x)

        return logits

In [41]:
ffnet = FeedForwardNetwork(dropout_rate=0.2).to(device)

In [42]:
logs = {'loss/train': {}, 'dev': {}}
writer = SummaryWriter(comment='xp4-onehot-allfeat-win3-wei2-lr1e3-mom0.9-wd5e4-hd256-dr0.2-bs64', log_dir=None,)

In [43]:
label_counter   = Counter(train_set.Y)
labels_freqs    = [label_counter[label] / sum(label_counter.values()) for label in range(len(labels))]
labels_weights1 = [min(label_counter.values()) / label_counter[label] for label in range(len(labels))]
labels_weights2 = [np.sqrt(min(label_counter.values())) / np.sqrt(label_counter[label]) for label in range(len(labels))]

weights = torch.Tensor(labels_weights2).to(device)
print(weights)

tensor([0.7436, 1.0000, 0.1780, 0.6775, 0.6419], device='cuda:0')


In [44]:
optimizer_params = {'lr': 1e-3, 
                    'momentum': 0.9, 
                    'weight_decay': 5e-4,
                   }

log_interval = int(len(train_loader) / 5)

loss_fn = nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.SGD(ffnet.parameters(), **optimizer_params)

In [45]:
%%time
max_epochs = 15

for epoch in range(len(logs['loss/train']), len(logs['loss/train']) + max_epochs):
    
    # Training
    ffnet.train()
    print('Epoch', epoch)
    logs['loss/train'][epoch] = []
    writer.add_scalar("Learning_rate", optimizer_params['lr'], epoch)

    for batch, (batch_X, batch_Y) in enumerate(tqdm(train_loader)):
        # tranfer to GPU
        batch_X, batch_Y = batch_X.float().to(device), batch_Y.to(device)
        optimizer.zero_grad()
        l = backprop(batch_X, batch_Y, ffnet, optimizer, loss_fn)
        logs['loss/train'][epoch].append(l)
        
        if batch % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch * len(batch_X), len(train_loader.dataset),
                100. * batch / len(train_loader), l))
    
    logs['loss/train'][epoch] = np.mean(logs['loss/train'][epoch])
    writer.add_scalar("Loss/train", logs['loss/train'][epoch], epoch)
    print(f'Average loss on epoch {epoch}: {logs["loss/train"][epoch]}')
    
    # Validation
    ffnet.eval()
    with torch.no_grad():
        preds = []
        gt = []
        for batch, (batch_X, batch_Y) in enumerate(tqdm(dev_loader)):
            # Transfer to GPU
            batch_X = batch_X.float().to(device)
            output = nn.Softmax(dim=1)(ffnet(batch_X))
            preds.append(output.cpu())
            gt.append(batch_Y)

        all_out = [np.argmax(l) for batch in preds for l in batch.numpy()]
        all_gt  = [l for batch in gt for l in batch.numpy()]

        print(classification_report(all_out, all_gt, digits=4))

        micro_F1 = metrics.f1_score(all_gt, all_out, average='micro')
        macro_F1 = metrics.f1_score(all_gt, all_out, average='macro')
        weighted_F1 = metrics.f1_score(all_gt, all_out, average='weighted')
        writer.add_scalar("micro_F1/dev", micro_F1, epoch)
        writer.add_scalar("macro_F1/dev", macro_F1, epoch)
        writer.add_scalar("weighted_F1/dev", weighted_F1, epoch)
        logs['dev'][epoch] = (micro_F1, weighted_F1, macro_F1, (all_gt, all_out))

Epoch 0


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 0: 1.0386693600534496


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.1289    0.7105    0.2183       380
           1     0.0000    0.0000    0.0000         0
           2     0.9572    0.9695    0.9633     35876
           3     0.2272    0.2223    0.2247      2110
           4     0.8291    0.3981    0.5379      6534

    accuracy                         0.8490     44900
   macro avg     0.4285    0.4601    0.3888     44900
weighted avg     0.8972    0.8490    0.8604     44900

Epoch 1


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 1: 0.7458715805045083


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.6686    0.6173    0.6419      2268
           1     0.3291    0.5681    0.4168       734
           2     0.9582    0.9819    0.9699     35460
           3     0.3285    0.3119    0.3200      2174
           4     0.8709    0.6407    0.7383      4264

    accuracy                         0.8919     44900
   macro avg     0.6311    0.6240    0.6174     44900
weighted avg     0.8945    0.8919    0.8908     44900

Epoch 2


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 2: 0.6062210145356706


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.6375    0.7721    0.6984      1729
           1     0.4680    0.5271    0.4958      1125
           2     0.9531    0.9869    0.9697     35094
           3     0.5208    0.3618    0.4270      2971
           4     0.8734    0.6883    0.7699      3981

    accuracy                         0.8993     44900
   macro avg     0.6906    0.6673    0.6722     44900
weighted avg     0.8932    0.8993    0.8938     44900

Epoch 3


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 3: 0.5250974268692875


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.6896    0.8464    0.7600      1706
           1     0.6006    0.5371    0.5671      1417
           2     0.9605    0.9852    0.9727     35428
           3     0.5693    0.4426    0.4980      2655
           4     0.8680    0.7371    0.7972      3694

    accuracy                         0.9133     44900
   macro avg     0.7376    0.7097    0.7190     44900
weighted avg     0.9081    0.9133    0.9093     44900

Epoch 4


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 4: 0.4582222059411799


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.7574    0.8299    0.7920      1911
           1     0.6582    0.5919    0.6233      1409
           2     0.9593    0.9895    0.9742     35230
           3     0.5891    0.4733    0.5249      2569
           4     0.8951    0.7427    0.8118      3781

    accuracy                         0.9199     44900
   macro avg     0.7718    0.7255    0.7452     44900
weighted avg     0.9147    0.9199    0.9160     44900

Epoch 5


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 5: 0.3990549973660325


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8075    0.8442    0.8255      2003
           1     0.6875    0.6685    0.6778      1303
           2     0.9657    0.9903    0.9779     35435
           3     0.6294    0.5157    0.5669      2519
           4     0.9069    0.7816    0.8396      3640

    accuracy                         0.9309     44900
   macro avg     0.7994    0.7601    0.7775     44900
weighted avg     0.9269    0.9309    0.9281     44900

Epoch 6


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 6: 0.3488952564732954


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8133    0.8756    0.8433      1945
           1     0.7609    0.6798    0.7181      1418
           2     0.9675    0.9916    0.9794     35454
           3     0.6797    0.5485    0.6071      2558
           4     0.9107    0.8105    0.8577      3525

    accuracy                         0.9373     44900
   macro avg     0.8264    0.7812    0.8011     44900
weighted avg     0.9334    0.9373    0.9345     44900

Epoch 7


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 7: 0.30335330899200186


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8438    0.8800    0.8615      2008
           1     0.7380    0.7385    0.7383      1266
           2     0.9772    0.9907    0.9839     35843
           3     0.7137    0.6174    0.6620      2386
           4     0.9158    0.8457    0.8794      3397

    accuracy                         0.9478     44900
   macro avg     0.8377    0.8145    0.8250     44900
weighted avg     0.9458    0.9478    0.9465     44900

Epoch 8


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 8: 0.26481594117134755


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8634    0.8764    0.8699      2063
           1     0.7395    0.7618    0.7505      1230
           2     0.9784    0.9923    0.9853     35831
           3     0.7558    0.6401    0.6932      2437
           4     0.9197    0.8640    0.8910      3339

    accuracy                         0.9520     44900
   macro avg     0.8514    0.8269    0.8380     44900
weighted avg     0.9501    0.9520    0.9507     44900

Epoch 9


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 9: 0.229757052404784


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8524    0.9075    0.8791      1967
           1     0.7924    0.7595    0.7756      1322
           2     0.9826    0.9915    0.9870     36014
           3     0.7321    0.7054    0.7185      2142
           4     0.9398    0.8533    0.8944      3455

    accuracy                         0.9567     44900
   macro avg     0.8599    0.8434    0.8509     44900
weighted avg     0.9561    0.9567    0.9561     44900

Epoch 10


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 10: 0.19933880985901695


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8615    0.9116    0.8858      1979
           1     0.7656    0.7918    0.7785      1225
           2     0.9822    0.9920    0.9871     35979
           3     0.7796    0.6735    0.7227      2389
           4     0.9273    0.8741    0.8999      3328

    accuracy                         0.9573     44900
   macro avg     0.8632    0.8486    0.8548     44900
weighted avg     0.9561    0.9573    0.9564     44900

Epoch 11


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 11: 0.17351850600521418


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8758    0.9034    0.8894      2030
           1     0.8098    0.7478    0.7776      1372
           2     0.9798    0.9932    0.9865     35845
           3     0.7897    0.6803    0.7309      2396
           4     0.9260    0.8919    0.9087      3257

    accuracy                         0.9576     44900
   macro avg     0.8762    0.8433    0.8586     44900
weighted avg     0.9558    0.9576    0.9564     44900

Epoch 12


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 12: 0.15373003180364006


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8543    0.9279    0.8896      1928
           1     0.7790    0.8379    0.8074      1178
           2     0.9867    0.9914    0.9890     36164
           3     0.7922    0.7177    0.7531      2278
           4     0.9452    0.8845    0.9139      3352

    accuracy                         0.9628     44900
   macro avg     0.8715    0.8719    0.8706     44900
weighted avg     0.9626    0.9628    0.9624     44900

Epoch 13


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 13: 0.13494139134387353


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8381    0.9340    0.8835      1879
           1     0.7569    0.8766    0.8124      1094
           2     0.9874    0.9910    0.9892     36205
           3     0.8178    0.6912    0.7492      2442
           4     0.9302    0.8896    0.9095      3280

    accuracy                         0.9621     44900
   macro avg     0.8661    0.8765    0.8687     44900
weighted avg     0.9621    0.9621    0.9616     44900

Epoch 14


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 14: 0.1188175538080951


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8854    0.9057    0.8954      2047
           1     0.8043    0.8030    0.8036      1269
           2     0.9831    0.9934    0.9883     35961
           3     0.8309    0.6863    0.7517      2499
           4     0.9127    0.9165    0.9146      3124

    accuracy                         0.9616     44900
   macro avg     0.8833    0.8610    0.8707     44900
weighted avg     0.9602    0.9616    0.9605     44900

CPU times: user 1h 25min 13s, sys: 13min 45s, total: 1h 38min 58s
Wall time: 31min 3s


In [46]:
ffnet.eval()
with torch.no_grad():
    preds = []
    gt = []
    for batch, (batch_X, batch_Y) in enumerate(tqdm(test_loader)):
        # Transfer to GPU
        batch_X = batch_X.float().to(device)
        output = nn.Softmax(dim=1)(ffnet(batch_X))
        preds.append(output.cpu())
        gt.append(batch_Y)

    all_out = [np.argmax(l) for batch in preds for l in batch.numpy()]
    all_gt  = [l for batch in gt for l in batch.numpy()]

    print(classification_report(all_out, all_gt, digits=4))

HBox(children=(FloatProgress(value=0.0, max=637.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8208    0.8333    0.8270      1896
           1     0.7527    0.6544    0.7001      1056
           2     0.9708    0.9930    0.9818     31921
           3     0.7766    0.6136    0.6856      3155
           4     0.8550    0.8679    0.8614      2732

    accuracy                         0.9390     40760
   macro avg     0.8352    0.7924    0.8112     40760
weighted avg     0.9354    0.9390    0.9363     40760

