# Setup

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="2,3"

In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import re
import time
import pickle

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from collections import Counter

from sklearn import metrics
from sklearn.metrics import classification_report, f1_score

In [3]:
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f546a49c2b0>

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [40]:
dataset = {'train': [], 'validation': [], 'test': []}
dataset_path = '/data/graphner_embeddings/autoencoder_embeddings_500_2epochs/'

for split in dataset:
    files_list = os.listdir(dataset_path+split)
    for i, filename in tqdm(enumerate(sorted(files_list)), total=len(files_list)):
        dataset[split].append(pickle.load(open(dataset_path+split+'/'+str(i)+'.pickle', 'rb')))

HBox(children=(FloatProgress(value=0.0, max=178610.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=44900.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40760.0), HTML(value='')))




In [7]:
labels = pickle.load(open('labels.pickle', 'rb'))
label2id = {l: i for i, l in enumerate(labels)}
print(label2id)

{'LOC': 0, 'MISC': 1, 'O': 2, 'ORG': 3, 'PER': 4}


# Dataset

In [8]:
x = list(zip(*dataset[split][:10]))
x[0]

(array([-0.01436732,  0.04050591, -0.00089871,  0.00945008,  0.02446665,
        -0.01244789,  0.01709336, -0.02142077,  0.01187329,  0.01046528,
        -0.021883  , -0.00136595,  0.00600611, -0.01329236,  0.00486688,
         0.02181063, -0.00594544, -0.02839023,  0.02226341,  0.0151037 ,
        -0.0096951 , -0.02007345, -0.00057336, -0.01848148,  0.00721815,
         0.00165754,  0.04607863,  0.06198069, -0.02297092,  0.02197368,
         0.01691943,  0.01164624,  0.00712457, -0.00355437,  0.01439966,
        -0.00907129,  0.00530745, -0.00834707, -0.0026083 ,  0.01280503,
         0.00142022,  0.02363534, -0.03480102, -0.01278157,  0.00336421,
         0.00258736,  0.02497259,  0.00027705, -0.05701809, -0.01597444,
         0.00752685,  0.05145542,  0.01890014,  0.03247196,  0.01299692,
        -0.01778115, -0.00883749, -0.01173777, -0.05013476,  0.01755121,
         0.00861797,  0.00312549, -0.02273303,  0.02541528, -0.05910432,
        -0.0291328 ,  0.00334541,  0.04204424, -0.0

In [41]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset, split, label2id=label2id):
        X, Y = zip(*dataset[split])

        self.X = [torch.tensor(x) for x in X]
        self.Y = [torch.tensor(y) for y in Y]
        self.X_len = len(X)
        self.labels = sorted(label2id.keys())
        self.label2id = label2id

    def __len__(self):
        return self.X_len

    def __getitem__(self, index):
        x = self.X[index]
        y = self.Y[index]
        
        return x, y

    def labels(self):
        return self.labels
    
    def Y(self):
        return self.Y

In [12]:
train_set = Dataset(dataset, 'train')
dev_set = Dataset(dataset, 'validation')
test_set = Dataset(dataset, 'test')

In [13]:
batch_size = 256
num_workers = 4

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, num_workers=num_workers, shuffle=True)
dev_loader = torch.utils.data.DataLoader(dev_set, batch_size=batch_size, num_workers=num_workers, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=num_workers, shuffle=False)

In [14]:
input_dim = 0
for local_features, local_labels in train_loader:
    input_dim = local_features.shape[1]
    print(local_features.shape)
    print(local_labels.shape)
    break

torch.Size([256, 100])
torch.Size([256])


In [15]:
input_dim

100

In [16]:
training_counter = Counter([y.item() for y in train_set.Y])
print(training_counter)

Counter({2: 144631, 4: 11124, 3: 9984, 0: 8288, 1: 4583})


In [17]:
labels

['LOC', 'MISC', 'O', 'ORG', 'PER']

# The Model

In [30]:
def backprop(batch_X, batch_Y, model, optimizer, loss_fn):
    Y_hat = model(batch_X)
    loss = loss_fn(Y_hat, batch_Y)
    loss.backward()
    optimizer.step()
    
    return loss.item()

class FeedForwardNetwork(nn.Module):
    def __init__(self, input_dim=input_dim, hidden_dim=1024, output_dim=5, dropout_rate=0.2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fch = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

        # extra layers layers
        self.dropout = nn.Dropout(p=dropout_rate)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fch(x)
        x = self.relu(x)
        x = self.dropout(x)
        logits = self.fc2(x)

        return logits

In [31]:
ffnet = FeedForwardNetwork(dropout_rate=0.2).to(device)

In [32]:
logs = {'loss/train': {}, 'dev': {}}
writer = SummaryWriter(comment='xp5-autoreg-wei2-lr1e3-mom0.9-wd5e4-hd1024-dr0.2-bs64-dim100', log_dir=None,)

In [33]:
label_counter   = Counter([y.item() for y in train_set.Y])
labels_freqs    = [label_counter[label] / sum(label_counter.values()) for label in range(len(labels))]
labels_weights1 = [min(label_counter.values()) / label_counter[label] for label in range(len(labels))]
labels_weights2 = [np.sqrt(min(label_counter.values())) / np.sqrt(label_counter[label]) for label in range(len(labels))]

weights = torch.Tensor(labels_weights2).to(device)
print(weights)

tensor([0.7436, 1.0000, 0.1780, 0.6775, 0.6419], device='cuda:0')


In [34]:
optimizer_params = {'lr': 1e-3, 
                    'momentum': 0.9, 
                    'weight_decay': 5e-4,
                   }

log_interval = int(len(train_loader) / 2)

loss_fn = nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.SGD(ffnet.parameters(), **optimizer_params)

In [37]:
%%time
max_epochs = 20

for epoch in range(len(logs['loss/train']), len(logs['loss/train']) + max_epochs):
    
    # Training
    ffnet.train()
    print('Epoch', epoch)
    logs['loss/train'][epoch] = []
    writer.add_scalar("Learning_rate", optimizer_params['lr'], epoch)

    for batch, (batch_X, batch_Y) in enumerate(tqdm(train_loader)):
        # tranfer to GPU
        batch_X, batch_Y = batch_X.float().to(device), batch_Y.to(device)
        optimizer.zero_grad()
        l = backprop(batch_X, batch_Y, ffnet, optimizer, loss_fn)
        logs['loss/train'][epoch].append(l)
        
        if batch % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch * len(batch_X), len(train_loader.dataset),
                100. * batch / len(train_loader), l))
    
    logs['loss/train'][epoch] = np.mean(logs['loss/train'][epoch])
    writer.add_scalar("Loss/train", logs['loss/train'][epoch], epoch)
    print(f'Average loss on epoch {epoch}: {logs["loss/train"][epoch]}')
    
    # Validation
    ffnet.eval()
    with torch.no_grad():
        preds = []
        gt = []
        for batch, (batch_X, batch_Y) in enumerate(tqdm(dev_loader)):
            # Transfer to GPU
            batch_X = batch_X.float().to(device)
            output = nn.Softmax(dim=1)(ffnet(batch_X))
            preds.append(output.cpu())
            gt.append(batch_Y)

        all_out = [np.argmax(l) for batch in preds for l in batch.numpy()]
        all_gt  = [l for batch in gt for l in batch.numpy()]

        print(classification_report(all_out, all_gt, digits=4))

        micro_F1 = metrics.f1_score(all_gt, all_out, average='micro')
        macro_F1 = metrics.f1_score(all_gt, all_out, average='macro')
        weighted_F1 = metrics.f1_score(all_gt, all_out, average='weighted')
        writer.add_scalar("micro_F1/dev", micro_F1, epoch)
        writer.add_scalar("macro_F1/dev", macro_F1, epoch)
        writer.add_scalar("weighted_F1/dev", weighted_F1, epoch)
        logs['dev'][epoch] = (micro_F1, weighted_F1, macro_F1, (all_gt, all_out))

Epoch 20


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 20: 0.7226199455623299


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5258    0.4653    0.4937      2366
           1     0.3047    0.5026    0.3794       768
           2     0.9491    0.9773    0.9630     35290
           3     0.2674    0.2451    0.2558      2252
           4     0.7922    0.5883    0.6752      4224

    accuracy                         0.8689     44900
   macro avg     0.5678    0.5557    0.5534     44900
weighted avg     0.8669    0.8689    0.8658     44900

Epoch 21


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 21: 0.7166048536840346


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5277    0.4720    0.4983      2341
           1     0.3181    0.5088    0.3915       792
           2     0.9484    0.9783    0.9631     35228
           3     0.2626    0.2501    0.2562      2167
           4     0.8135    0.5837    0.6797      4372

    accuracy                         0.8700     44900
   macro avg     0.5741    0.5586    0.5578     44900
weighted avg     0.8691    0.8700    0.8671     44900

Epoch 22


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 22: 0.7118285242414748


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5554    0.4577    0.5018      2541
           1     0.3418    0.4849    0.4009       893
           2     0.9475    0.9794    0.9632     35157
           3     0.2263    0.2385    0.2322      1958
           4     0.8122    0.5856    0.6806      4351

    accuracy                         0.8696     44900
   macro avg     0.5766    0.5492    0.5557     44900
weighted avg     0.8687    0.8696    0.8666     44900

Epoch 23


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 23: 0.707787119343151


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5439    0.4613    0.4992      2469
           1     0.3418    0.4983    0.4054       869
           2     0.9464    0.9812    0.9635     35049
           3     0.2578    0.2474    0.2525      2150
           4     0.8138    0.5851    0.6808      4363

    accuracy                         0.8696     44900
   macro avg     0.5807    0.5547    0.5603     44900
weighted avg     0.8667    0.8696    0.8656     44900

Epoch 24


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 24: 0.7030690791869915


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5277    0.4753    0.5001      2325
           1     0.3788    0.4598    0.4154      1044
           2     0.9469    0.9799    0.9631     35112
           3     0.2175    0.2426    0.2294      1851
           4     0.8352    0.5736    0.6801      4568

    accuracy                         0.8700     44900
   macro avg     0.5812    0.5462    0.5576     44900
weighted avg     0.8705    0.8700    0.8674     44900

Epoch 25


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 25: 0.7001936905363569


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5349    0.4684    0.4994      2391
           1     0.3686    0.4717    0.4138       990
           2     0.9461    0.9817    0.9636     35022
           3     0.2364    0.2440    0.2402      2000
           4     0.8304    0.5793    0.6825      4497

    accuracy                         0.8699     44900
   macro avg     0.5833    0.5490    0.5599     44900
weighted avg     0.8683    0.8699    0.8664     44900

Epoch 26


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 26: 0.6978298189988451


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5525    0.4507    0.4965      2567
           1     0.3954    0.4571    0.4240      1096
           2     0.9466    0.9808    0.9634     35072
           3     0.1996    0.2231    0.2107      1847
           4     0.8161    0.5929    0.6868      4318

    accuracy                         0.8692     44900
   macro avg     0.5820    0.5409    0.5563     44900
weighted avg     0.8673    0.8692    0.8660     44900

Epoch 27


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 27: 0.6947774482746862


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5401    0.4650    0.4998      2432
           1     0.3954    0.4580    0.4244      1094
           2     0.9455    0.9826    0.9637     34968
           3     0.2171    0.2369    0.2265      1891
           4     0.8339    0.5794    0.6837      4515

    accuracy                         0.8698     44900
   macro avg     0.5864    0.5444    0.5596     44900
weighted avg     0.8683    0.8698    0.8662     44900

Epoch 28


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 28: 0.6934531385287173


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5420    0.4619    0.4988      2457
           1     0.4033    0.4494    0.4251      1137
           2     0.9461    0.9817    0.9636     35021
           3     0.2074    0.2281    0.2173      1876
           4     0.8272    0.5886    0.6878      4409

    accuracy                         0.8697     44900
   macro avg     0.5852    0.5420    0.5585     44900
weighted avg     0.8677    0.8697    0.8663     44900

Epoch 29


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 29: 0.6897697806529124


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5478    0.4572    0.4984      2509
           1     0.4049    0.4548    0.4284      1128
           2     0.9464    0.9816    0.9636     35035
           3     0.2185    0.2308    0.2245      1954
           4     0.8173    0.5999    0.6919      4274

    accuracy                         0.8700     44900
   macro avg     0.5870    0.5448    0.5614     44900
weighted avg     0.8665    0.8700    0.8662     44900

Epoch 30


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 30: 0.6876914841056212


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5325    0.4897    0.5102      2277
           1     0.4009    0.4597    0.4283      1105
           2     0.9458    0.9828    0.9640     34969
           3     0.2597    0.2610    0.2603      2054
           4     0.8371    0.5842    0.6882      4495

    accuracy                         0.8720     44900
   macro avg     0.5952    0.5555    0.5702     44900
weighted avg     0.8692    0.8720    0.8680     44900

Epoch 31


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 31: 0.6858911359293072


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5554    0.4483    0.4962      2594
           1     0.4199    0.4574    0.4379      1163
           2     0.9450    0.9837    0.9639     34907
           3     0.2171    0.2331    0.2248      1922
           4     0.8212    0.5971    0.6915      4314

    accuracy                         0.8699     44900
   macro avg     0.5917    0.5439    0.5628     44900
weighted avg     0.8658    0.8699    0.8655     44900

Epoch 32


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 32: 0.6845539778812567


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5606    0.4450    0.4962      2638
           1     0.4388    0.4340    0.4364      1281
           2     0.9458    0.9827    0.9639     34974
           3     0.1982    0.2286    0.2123      1789
           4     0.8142    0.6055    0.6945      4218

    accuracy                         0.8700     44900
   macro avg     0.5915    0.5392    0.5607     44900
weighted avg     0.8665    0.8700    0.8661     44900

Epoch 33


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 33: 0.6829179730319703


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5669    0.4356    0.4926      2725
           1     0.4420    0.4328    0.4373      1294
           2     0.9453    0.9834    0.9640     34932
           3     0.2049    0.2272    0.2155      1862
           4     0.7985    0.6129    0.6935      4087

    accuracy                         0.8692     44900
   macro avg     0.5915    0.5384    0.5606     44900
weighted avg     0.8638    0.8692    0.8646     44900

Epoch 34


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 34: 0.6810967338119331


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5501    0.4604    0.5013      2502
           1     0.4341    0.4432    0.4386      1241
           2     0.9459    0.9828    0.9640     34972
           3     0.2030    0.2517    0.2247      1665
           4     0.8419    0.5843    0.6898      4520

    accuracy                         0.8716     44900
   macro avg     0.5950    0.5445    0.5637     44900
weighted avg     0.8717    0.8716    0.8687     44900

Epoch 35


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 35: 0.6795200198291025


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5401    0.4689    0.5020      2412
           1     0.4594    0.4187    0.4381      1390
           2     0.9446    0.9843    0.9641     34871
           3     0.2219    0.2646    0.2414      1731
           4     0.8397    0.5859    0.6902      4496

    accuracy                         0.8715     44900
   macro avg     0.6011    0.5445    0.5671     44900
weighted avg     0.8695    0.8715    0.8677     44900

Epoch 36


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 36: 0.6779435910177777


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5511    0.4576    0.5000      2522
           1     0.4428    0.4359    0.4393      1287
           2     0.9455    0.9832    0.9640     34946
           3     0.2311    0.2590    0.2442      1842
           4     0.8244    0.6010    0.6952      4303

    accuracy                         0.8716     44900
   macro avg     0.5990    0.5473    0.5685     44900
weighted avg     0.8680    0.8716    0.8676     44900

Epoch 37


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 37: 0.6770740899452166


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5377    0.4794    0.5069      2349
           1     0.4467    0.4367    0.4417      1296
           2     0.9438    0.9851    0.9640     34817
           3     0.2573    0.2614    0.2593      2031
           4     0.8352    0.5945    0.6946      4407

    accuracy                         0.8717     44900
   macro avg     0.6041    0.5514    0.5733     44900
weighted avg     0.8665    0.8717    0.8667     44900

Epoch 38


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 38: 0.6757562373548661


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5401    0.4734    0.5046      2389
           1     0.4570    0.4199    0.4376      1379
           2     0.9452    0.9836    0.9640     34918
           3     0.2418    0.2644    0.2526      1887
           4     0.8285    0.6006    0.6964      4327

    accuracy                         0.8720     44900
   macro avg     0.6025    0.5484    0.5711     44900
weighted avg     0.8678    0.8720    0.8677     44900

Epoch 39


HBox(children=(FloatProgress(value=0.0, max=698.0), HTML(value='')))


Average loss on epoch 39: 0.673715257414091


HBox(children=(FloatProgress(value=0.0, max=176.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5392    0.4689    0.5016      2408
           1     0.4562    0.4225    0.4387      1368
           2     0.9451    0.9838    0.9641     34908
           3     0.2200    0.2706    0.2427      1678
           4     0.8467    0.5853    0.6921      4538

    accuracy                         0.8722     44900
   macro avg     0.6014    0.5462    0.5678     44900
weighted avg     0.8714    0.8722    0.8688     44900

CPU times: user 1min 33s, sys: 26.6 s, total: 2min
Wall time: 1min 58s


In [38]:
ffnet.eval()
with torch.no_grad():
    preds = []
    gt = []
    for batch, (batch_X, batch_Y) in enumerate(tqdm(test_loader)):
        # Transfer to GPU
        batch_X = batch_X.float().to(device)
        output = nn.Softmax(dim=1)(ffnet(batch_X))
        preds.append(output.cpu())
        gt.append(batch_Y)

    all_out = [np.argmax(l) for batch in preds for l in batch.numpy()]
    all_gt  = [l for batch in gt for l in batch.numpy()]

    print(classification_report(all_out, all_gt, digits=4))

HBox(children=(FloatProgress(value=0.0, max=160.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.5979    0.4368    0.5048      2635
           1     0.4488    0.3627    0.4012      1136
           2     0.9341    0.9843    0.9585     30986
           3     0.2331    0.3384    0.2760      1717
           4     0.8431    0.5455    0.6624      4286

    accuracy                         0.8582     40760
   macro avg     0.6114    0.5335    0.5606     40760
weighted avg     0.8597    0.8582    0.8538     40760

