# Setup

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="2,3"

In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import re
import time
import pickle

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from collections import Counter

from sklearn import metrics
from sklearn.metrics import classification_report, f1_score

In [3]:
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f1df41542b0>

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
dataset = pickle.load(open('conll_graph_all.pickle', 'rb'))
print(', '.join([split + f' : {len(dataset[split])}' for split in dataset]))

train : 178610, validation : 44900, test : 40760


In [6]:
dataset['train'][2]

{'word': ['german'],
 'label': 'MISC',
 'gt_label': 'B-MISC',
 'surface': 'German',
 'pos': ['<JJ>'],
 'chunk': ['<NP>'],
 'classes': ['<GEOREGION>', '<NAME>', '<GIVEN NAME>', '<FAMILY NAME>'],
 'extra': ['<CAPITALIZED>'],
 'left_context': ['eu', 'rejects'],
 'right_context': ['call', 'to', 'boycott', 'british', 'lamb']}

In [7]:
vocabulary = pickle.load(open('vocabulary_all.pickle', 'rb'))
print(', '.join([key + f' : {len(vocabulary[key])}' for key in vocabulary]))

word : 18993, chunk : 11, pos : 25, classes : 14, extra : 3


In [8]:
labels = pickle.load(open('labels.pickle', 'rb'))
label2id = {l: i for i, l in enumerate(labels)}
print(label2id)

{'LOC': 0, 'MISC': 1, 'O': 2, 'ORG': 3, 'PER': 4}


In [9]:
voc2id = {}
for key in vocabulary:
    voc2id[key] = {l: i for i, l in enumerate(vocabulary[key])}
print(voc2id['word']['ismail'], voc2id['extra']['<ACRONYM>'], )

8594 0


# Dataset

In [10]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset, split, voc2id=voc2id, label2id=label2id, context='all'):
        X = []
        Y = []
        
        for doc in tqdm(dataset[split], desc=split.upper()):
            # Create the graph one-hot embedding
            onehot = {}
            for key in voc2id:
                onehot[key] = np.zeros(len(voc2id[key]))
                for v in doc[key]:
                    onehot[key][voc2id[key][v]] = 1.
            
            onehot['left'] = np.zeros(len(voc2id['word']))
            onehot['right'] = np.zeros(len(voc2id['word']))
            
            if context == 'all':
                context = max(len(doc['left_context']), len(doc['right_context']))
            
            for w in doc['left_context'][-context:]:
                onehot['left'][voc2id['word'][w]] = 1.
            for w in doc['right_context'][:context]:
                onehot['right'][voc2id['word'][w]] = 1.
            
            doc_embedding = np.concatenate([onehot['word'], onehot['left'], onehot['right'],
                                            onehot['pos'], onehot['chunk'], onehot['classes'], onehot['extra']])
            X.append(torch.tensor(doc_embedding))
            Y.append(torch.tensor(label2id[doc['label']]))
        
        self.X = X
        self.Y = Y
        self.X_len = len(X)
        self.labels = sorted(label2id.keys())
        self.label2id = label2id
        self.voc2id = voc2id

    def __len__(self):
        return self.X_len

    def __getitem__(self, index):
        x = self.X[index]
        y = self.Y[index]
        
        return x, y

    def labels(self):
        return self.labels
    
    def voc(self, key):
        return self.voc2id[key]
    
    def Y(self):
        return self.Y

In [11]:
train_set = Dataset(dataset, 'train', context=3)
dev_set = Dataset(dataset, 'validation', context=3)
test_set = Dataset(dataset, 'test', context=3)

HBox(children=(FloatProgress(value=0.0, description='TRAIN', max=178610.0, style=ProgressStyle(description_wid…




HBox(children=(FloatProgress(value=0.0, description='VALIDATION', max=44900.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='TEST', max=40760.0, style=ProgressStyle(description_width…




In [12]:
batch_size = 64
num_workers = 4

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, num_workers=num_workers, shuffle=True)
dev_loader = torch.utils.data.DataLoader(dev_set, batch_size=batch_size, num_workers=num_workers, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=num_workers, shuffle=False)

In [13]:
input_dim = 0
for local_features, local_labels in train_loader:
    input_dim = local_features.shape[1]
    print(local_features.shape)
    print(local_labels.shape)
    break

torch.Size([64, 57032])
torch.Size([64])


In [14]:
input_dim

57032

In [15]:
training_counter = Counter([y.item() for y in train_set.Y])
print(training_counter)

Counter({2: 144631, 4: 11124, 3: 9984, 0: 8288, 1: 4583})


In [16]:
labels

['LOC', 'MISC', 'O', 'ORG', 'PER']

# The Model

In [19]:
def backprop(batch_X, batch_Y, model, optimizer, loss_fn):
    Y_hat = model(batch_X)
    loss = loss_fn(Y_hat, batch_Y)
    loss.backward()
    optimizer.step()
    
    return loss.item()

class FeedForwardNetwork(nn.Module):
    def __init__(self, input_dim=input_dim, hidden_dim=1024, output_dim=5, dropout_rate=0.2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fch = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

        # extra layers layers
        self.dropout = nn.Dropout(p=dropout_rate)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fch(x)
        x = self.relu(x)
        x = self.dropout(x)
        logits = self.fc2(x)

        return logits

In [20]:
ffnet = FeedForwardNetwork().to(device)

In [21]:
logs = {'loss/train': {}, 'dev': {}}
writer = SummaryWriter(comment='xp4-onehot-allfeat-win3-wei2-lr1e3-mom0.9-wd5e4-hd1024-dr0.2-bs64', log_dir=None,)

In [22]:
label_counter   = Counter([y.item() for y in train_set.Y])
labels_freqs    = [label_counter[label] / sum(label_counter.values()) for label in range(len(labels))]
labels_weights1 = [min(label_counter.values()) / label_counter[label] for label in range(len(labels))]
labels_weights2 = [np.sqrt(min(label_counter.values())) / np.sqrt(label_counter[label]) for label in range(len(labels))]

weights = torch.Tensor(labels_weights2).to(device)
print(weights)

tensor([0.7436, 1.0000, 0.1780, 0.6775, 0.6419], device='cuda:0')


In [23]:
optimizer_params = {'lr': 1e-3, 
                    'momentum': 0.9, 
                    'weight_decay': 5e-4,
                   }

log_interval = int(len(train_loader) / 5)

loss_fn = nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.SGD(ffnet.parameters(), **optimizer_params)

In [28]:
%%time
max_epochs = 10

for epoch in range(len(logs['loss/train']), len(logs['loss/train']) + max_epochs):
    
    # Training
    ffnet.train()
    print('Epoch', epoch)
    logs['loss/train'][epoch] = []
    writer.add_scalar("Learning_rate", optimizer_params['lr'], epoch)

    for batch, (batch_X, batch_Y) in enumerate(tqdm(train_loader)):
        # tranfer to GPU
        batch_X, batch_Y = batch_X.float().to(device), batch_Y.to(device)
        optimizer.zero_grad()
        l = backprop(batch_X, batch_Y, ffnet, optimizer, loss_fn)
        logs['loss/train'][epoch].append(l)
        
        if batch % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch * len(batch_X), len(train_loader.dataset),
                100. * batch / len(train_loader), l))
    
    logs['loss/train'][epoch] = np.mean(logs['loss/train'][epoch])
    writer.add_scalar("Loss/train", logs['loss/train'][epoch], epoch)
    print(f'Average loss on epoch {epoch}: {logs["loss/train"][epoch]}')
    
    # Validation
    ffnet.eval()
    with torch.no_grad():
        preds = []
        gt = []
        for batch, (batch_X, batch_Y) in enumerate(tqdm(dev_loader)):
            # Transfer to GPU
            batch_X = batch_X.float().to(device)
            output = nn.Softmax(dim=1)(ffnet(batch_X))
            preds.append(output.cpu())
            gt.append(batch_Y)

        all_out = [np.argmax(l) for batch in preds for l in batch.numpy()]
        all_gt  = [l for batch in gt for l in batch.numpy()]

        print(classification_report(all_out, all_gt, digits=4))

        micro_F1 = metrics.f1_score(all_gt, all_out, average='micro')
        macro_F1 = metrics.f1_score(all_gt, all_out, average='macro')
        weighted_F1 = metrics.f1_score(all_gt, all_out, average='weighted')
        writer.add_scalar("micro_F1/dev", micro_F1, epoch)
        writer.add_scalar("macro_F1/dev", macro_F1, epoch)
        writer.add_scalar("weighted_F1/dev", weighted_F1, epoch)
        logs['dev'][epoch] = (micro_F1, weighted_F1, macro_F1, (all_gt, all_out))

Epoch 20


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 20: 0.058727922472790475


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8897    0.9101    0.8998      2047
           1     0.8161    0.8319    0.8239      1243
           2     0.9916    0.9906    0.9911     36374
           3     0.8198    0.7783    0.7985      2174
           4     0.9142    0.9366    0.9253      3062

    accuracy                         0.9686     44900
   macro avg     0.8863    0.8895    0.8877     44900
weighted avg     0.9685    0.9686    0.9685     44900

Epoch 21


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 21: 0.05417027509734915


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8601    0.9307    0.8940      1935
           1     0.8051    0.8430    0.8236      1210
           2     0.9919    0.9904    0.9912     36391
           3     0.8275    0.7642    0.7946      2235
           4     0.9257    0.9281    0.9269      3129

    accuracy                         0.9683     44900
   macro avg     0.8821    0.8913    0.8861     44900
weighted avg     0.9684    0.9683    0.9682     44900

Epoch 22


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 22: 0.05075821866018051


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8663    0.9303    0.8971      1950
           1     0.7877    0.8816    0.8320      1132
           2     0.9919    0.9910    0.9914     36372
           3     0.8217    0.7709    0.7955      2200
           4     0.9413    0.9097    0.9253      3246

    accuracy                         0.9689     44900
   macro avg     0.8818    0.8967    0.8883     44900
weighted avg     0.9693    0.9689    0.9689     44900

Epoch 23


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 23: 0.047202978642715464


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8926    0.9064    0.8994      2062
           1     0.8074    0.8448    0.8257      1211
           2     0.9910    0.9917    0.9914     36314
           3     0.8261    0.7725    0.7984      2207
           4     0.9232    0.9324    0.9278      3106

    accuracy                         0.9689     44900
   macro avg     0.8880    0.8896    0.8885     44900
weighted avg     0.9688    0.9689    0.9688     44900

Epoch 24


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 24: 0.04478775661299222


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8949    0.9146    0.9047      2049
           1     0.8129    0.8408    0.8266      1225
           2     0.9923    0.9905    0.9914     36406
           3     0.8023    0.8194    0.8108      2021
           4     0.9366    0.9184    0.9274      3199

    accuracy                         0.9701     44900
   macro avg     0.8878    0.8967    0.8922     44900
weighted avg     0.9705    0.9701    0.9703     44900

Epoch 25


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 25: 0.04211730366710884


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8792    0.9251    0.9016      1990
           1     0.7758    0.8896    0.8288      1105
           2     0.9936    0.9892    0.9914     36502
           3     0.8232    0.7790    0.8005      2181
           4     0.9241    0.9286    0.9263      3122

    accuracy                         0.9694     44900
   macro avg     0.8792    0.9023    0.8897     44900
weighted avg     0.9701    0.9694    0.9696     44900

Epoch 26


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 26: 0.039916071320948335


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8930    0.9126    0.9027      2049
           1     0.7830    0.8756    0.8267      1133
           2     0.9927    0.9902    0.9915     36432
           3     0.8251    0.7844    0.8043      2171
           4     0.9251    0.9316    0.9283      3115

    accuracy                         0.9697     44900
   macro avg     0.8838    0.8989    0.8907     44900
weighted avg     0.9701    0.9697    0.9698     44900

Epoch 27


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 27: 0.037937623747820996


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.9050    0.9028    0.9039      2099
           1     0.8540    0.7829    0.8169      1382
           2     0.9922    0.9905    0.9913     36398
           3     0.7432    0.8781    0.8050      1747
           4     0.9496    0.9099    0.9293      3274

    accuracy                         0.9698     44900
   macro avg     0.8888    0.8928    0.8893     44900
weighted avg     0.9710    0.9698    0.9701     44900

Epoch 28


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 28: 0.036590933250493576


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8911    0.9076    0.8993      2056
           1     0.8224    0.8309    0.8267      1254
           2     0.9922    0.9912    0.9917     36375
           3     0.7999    0.8149    0.8073      2026
           4     0.9343    0.9191    0.9267      3189

    accuracy                         0.9698     44900
   macro avg     0.8880    0.8928    0.8903     44900
weighted avg     0.9701    0.9698    0.9699     44900

Epoch 29


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 29: 0.03511683914994949


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.9088    0.8938    0.9013      2129
           1     0.8548    0.7697    0.8100      1407
           2     0.9906    0.9917    0.9912     36298
           3     0.7374    0.8727    0.7994      1744
           4     0.9512    0.8983    0.9240      3322

    accuracy                         0.9686     44900
   macro avg     0.8886    0.8852    0.8852     44900
weighted avg     0.9697    0.9686    0.9688     44900

CPU times: user 1h 9min 30s, sys: 9min 39s, total: 1h 19min 10s
Wall time: 30min 50s


In [29]:
ffnet.eval()
with torch.no_grad():
    preds = []
    gt = []
    for batch, (batch_X, batch_Y) in enumerate(tqdm(test_loader)):
        # Transfer to GPU
        batch_X = batch_X.float().to(device)
        output = nn.Softmax(dim=1)(ffnet(batch_X))
        preds.append(output.cpu())
        gt.append(batch_Y)

    all_out = [np.argmax(l) for batch in preds for l in batch.numpy()]
    all_gt  = [l for batch in gt for l in batch.numpy()]

    print(classification_report(all_out, all_gt, digits=4))

HBox(children=(FloatProgress(value=0.0, max=637.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.8686    0.8085    0.8375      2068
           1     0.7854    0.5988    0.6795      1204
           2     0.9799    0.9895    0.9847     32337
           3     0.6635    0.7975    0.7243      2074
           4     0.9268    0.8352    0.8786      3077

    accuracy                         0.9473     40760
   macro avg     0.8448    0.8059    0.8209     40760
weighted avg     0.9484    0.9473    0.9469     40760

