# Setup

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="2,3"

In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import re
import time
import pickle

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from collections import Counter

from sklearn import metrics
from sklearn.metrics import classification_report, f1_score

In [3]:
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x7f2df41462b0>

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [8]:
dataset = {'train': [], 'validation': [], 'test': []}
dataset_path = '/data/graphner_embeddings/ae_emb_npy_1000/'

for split in dataset:
    files_list = os.listdir(dataset_path+split)
    for i, filename in tqdm(enumerate(sorted(files_list)), total=len(files_list)):
        dataset[split].append(pickle.load(open(dataset_path+split+'/'+str(i)+'.pickle', 'rb')))

HBox(children=(FloatProgress(value=0.0, max=178610.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=44900.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=40760.0), HTML(value='')))




In [9]:
labels = pickle.load(open('labels.pickle', 'rb'))
label2id = {l: i for i, l in enumerate(labels)}
print(label2id)

{'LOC': 0, 'MISC': 1, 'O': 2, 'ORG': 3, 'PER': 4}


# Dataset

In [10]:
x = list(zip(*dataset[split][:10]))
x[0]

(array([ 1.03966221e-02, -2.52080709e-03,  8.17069970e-03, -9.78287123e-03,
         2.23466544e-03, -2.27706507e-04,  3.99471726e-03, -1.95034780e-02,
         6.88238582e-03,  7.10824877e-03,  1.66191924e-02, -1.38121378e-03,
        -2.59152539e-02,  4.21613269e-03, -1.26671791e-03,  3.02569382e-03,
        -1.55790932e-02,  5.68529824e-03, -2.07880163e-03,  4.86374460e-03,
         7.03364378e-04,  1.08141275e-02,  1.01359934e-02, -1.02531426e-02,
        -1.12654483e-02,  4.46948456e-03, -3.08720791e-03, -1.22914836e-03,
         1.31610865e-02, -4.27529681e-04,  5.97495772e-02,  3.16650211e-03,
        -1.31280674e-03, -8.82130000e-04, -7.26089254e-03, -5.11579867e-03,
        -1.99166797e-02,  1.01702074e-02,  6.71121385e-03, -9.66870412e-03,
         1.96740683e-03, -4.87065176e-03,  1.39567384e-03,  9.19895899e-03,
        -6.37533143e-03,  4.07804549e-03,  1.57981999e-02,  1.23224426e-02,
        -1.50675662e-02, -1.21070258e-03, -1.36184972e-03,  6.07496686e-03,
         7.7

In [11]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, dataset, split, label2id=label2id):
        X, Y = zip(*dataset[split])

        self.X = [torch.tensor(x) for x in X]
        self.Y = [torch.tensor(y) for y in Y]
        self.X_len = len(X)
        self.labels = sorted(label2id.keys())
        self.label2id = label2id

    def __len__(self):
        return self.X_len

    def __getitem__(self, index):
        x = self.X[index]
        y = self.Y[index]
        x.requires_grad = False
        
        return x, y

    def labels(self):
        return self.labels
    
    def Y(self):
        return self.Y

In [13]:
train_set = Dataset(dataset, 'train')
dev_set = Dataset(dataset, 'validation')
test_set = Dataset(dataset, 'test')

In [14]:
batch_size = 64
num_workers = 4

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, num_workers=num_workers, shuffle=True)
dev_loader = torch.utils.data.DataLoader(dev_set, batch_size=batch_size, num_workers=num_workers, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, num_workers=num_workers, shuffle=False)

In [15]:
input_dim = 0
for local_features, local_labels in train_loader:
    input_dim = local_features.shape[1]
    print(local_features.shape)
    print(local_labels.shape)
    break

torch.Size([64, 1000])
torch.Size([64])


In [16]:
input_dim

1000

In [17]:
training_counter = Counter([y.item() for y in train_set.Y])
print(training_counter)

Counter({2: 144631, 4: 11124, 3: 9984, 0: 8288, 1: 4583})


In [18]:
labels

['LOC', 'MISC', 'O', 'ORG', 'PER']

# The Model

In [43]:
def backprop(batch_X, batch_Y, model, optimizer, loss_fn):
    Y_hat = model(batch_X)
    loss = loss_fn(Y_hat, batch_Y)
    loss.backward()
    optimizer.step()
    
    return loss.item()

class FeedForwardNetwork(nn.Module):
    def __init__(self, input_dim=input_dim, hidden_dim=512, output_dim=5, dropout_rate=0.2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fch = nn.Linear(hidden_dim, hidden_dim)
        # self.fch2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

        # extra layers layers
        self.dropout = nn.Dropout(p=dropout_rate)
        
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fch(x)
        x = self.relu(x)
        x = self.dropout(x)
        logits = self.fc2(x)

        return logits

In [80]:
logs = {'loss/train': {}, 'dev': {}}
writer = SummaryWriter(comment='xp5-autoreg-wei2-lr1e3-mom0.95-wd5e4-hd512-dr0.2-bs64-f1000', log_dir=None,)

In [81]:
ffnet = FeedForwardNetwork(dropout_rate=0.2).to(device)

In [82]:
label_counter   = Counter([y.item() for y in train_set.Y])
labels_freqs    = [label_counter[label] / sum(label_counter.values()) for label in range(len(labels))]
labels_weights1 = [min(label_counter.values()) / label_counter[label] for label in range(len(labels))]
labels_weights2 = [np.sqrt(min(label_counter.values())) / np.sqrt(label_counter[label]) for label in range(len(labels))]

weights = torch.Tensor(labels_weights2).to(device)
print(weights)

tensor([0.7436, 1.0000, 0.1780, 0.6775, 0.6419], device='cuda:0')


In [83]:
optimizer_params = {'lr': 1e-3, 
                    'momentum': 0.95, 
                    'weight_decay': 5e-4,
                   }

log_interval = int(len(train_loader) / 5)

loss_fn = nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.SGD(ffnet.parameters(), **optimizer_params)

In [93]:
%%time
max_epochs = 5

for epoch in range(len(logs['loss/train']), len(logs['loss/train']) + max_epochs):
    
    # Training
    ffnet.train()
    print('Epoch', epoch)
    logs['loss/train'][epoch] = []
    writer.add_scalar("Learning_rate", optimizer_params['lr'], epoch)

    for batch, (batch_X, batch_Y) in enumerate(tqdm(train_loader)):
        # tranfer to GPU
        batch_X, batch_Y = batch_X.float().to(device), batch_Y.to(device)
        optimizer.zero_grad()
        l = backprop(batch_X, batch_Y, ffnet, optimizer, loss_fn)
        logs['loss/train'][epoch].append(l)
        
        if batch % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch * len(batch_X), len(train_loader.dataset),
                100. * batch / len(train_loader), l))
    
    logs['loss/train'][epoch] = np.mean(logs['loss/train'][epoch])
    writer.add_scalar("Loss/train", logs['loss/train'][epoch], epoch)
    print(f'Average loss on epoch {epoch}: {logs["loss/train"][epoch]}')
    
    # Validation
    ffnet.eval()
    with torch.no_grad():
        preds = []
        gt = []
        for batch, (batch_X, batch_Y) in enumerate(tqdm(dev_loader)):
            # Transfer to GPU
            batch_X = batch_X.float().to(device)
            output = nn.Softmax(dim=1)(ffnet(batch_X))
            preds.append(output.cpu())
            gt.append(batch_Y)

        all_out = [np.argmax(l) for batch in preds for l in batch.numpy()]
        all_gt  = [l for batch in gt for l in batch.numpy()]

        print(classification_report(all_out, all_gt, digits=4))

        micro_F1 = metrics.f1_score(all_gt, all_out, average='micro')
        macro_F1 = metrics.f1_score(all_gt, all_out, average='macro')
        weighted_F1 = metrics.f1_score(all_gt, all_out, average='weighted')
        writer.add_scalar("micro_F1/dev", micro_F1, epoch)
        writer.add_scalar("macro_F1/dev", macro_F1, epoch)
        writer.add_scalar("weighted_F1/dev", weighted_F1, epoch)
        logs['dev'][epoch] = (micro_F1, weighted_F1, macro_F1, (all_gt, all_out))

Epoch 42


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 42: 0.48417742139803643


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.7187    0.7525    0.7352      2000
           1     0.5091    0.6022    0.5518      1071
           2     0.9689    0.9854    0.9771     35731
           3     0.5824    0.4449    0.5044      2702
           4     0.8027    0.7415    0.7709      3396

    accuracy                         0.9149     44900
   macro avg     0.7163    0.7053    0.7079     44900
weighted avg     0.9110    0.9149    0.9121     44900

Epoch 43


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 43: 0.484353571203886


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.7034    0.7761    0.7380      1898
           1     0.6030    0.5063    0.5504      1509
           2     0.9672    0.9867    0.9769     35619
           3     0.4985    0.4759    0.4870      2162
           4     0.8393    0.7093    0.7689      3712

    accuracy                         0.9141     44900
   macro avg     0.7223    0.6909    0.7042     44900
weighted avg     0.9107    0.9141    0.9116     44900

Epoch 44


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 44: 0.484504117144239


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.7521    0.7159    0.7336      2200
           1     0.6511    0.4601    0.5392      1793
           2     0.9561    0.9922    0.9738     35016
           3     0.4409    0.5072    0.4717      1794
           4     0.8712    0.6671    0.7556      4097

    accuracy                         0.9084     44900
   macro avg     0.7343    0.6685    0.6948     44900
weighted avg     0.9056    0.9084    0.9047     44900

Epoch 45


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 45: 0.48253625459416455


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.7493    0.7301    0.7396      2149
           1     0.5880    0.5103    0.5464      1460
           2     0.9582    0.9921    0.9748     35098
           3     0.5044    0.4797    0.4917      2170
           4     0.8626    0.6726    0.7559      4023

    accuracy                         0.9105     44900
   macro avg     0.7325    0.6770    0.7017     44900
weighted avg     0.9057    0.9105    0.9067     44900

Epoch 46


HBox(children=(FloatProgress(value=0.0, max=2791.0), HTML(value='')))


Average loss on epoch 46: 0.4827101326563246


HBox(children=(FloatProgress(value=0.0, max=702.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.7593    0.7098    0.7337      2240
           1     0.5391    0.5486    0.5438      1245
           2     0.9583    0.9917    0.9747     35113
           3     0.5959    0.4263    0.4971      2885
           4     0.8004    0.7349    0.7662      3417

    accuracy                         0.9095     44900
   macro avg     0.7306    0.6823    0.7031     44900
weighted avg     0.9014    0.9095    0.9042     44900

CPU times: user 56 s, sys: 10.4 s, total: 1min 6s
Wall time: 1min 8s


In [92]:
ffnet.eval()
with torch.no_grad():
    preds = []
    gt = []
    for batch, (batch_X, batch_Y) in enumerate(tqdm(test_loader)):
        # Transfer to GPU
        batch_X = batch_X.float().to(device)
        output = nn.Softmax(dim=1)(ffnet(batch_X))
        preds.append(output.cpu())
        gt.append(batch_Y)

    all_out = [np.argmax(l) for batch in preds for l in batch.numpy()]
    all_gt  = [l for batch in gt for l in batch.numpy()]

    print(classification_report(all_out, all_gt, digits=4))

HBox(children=(FloatProgress(value=0.0, max=637.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.7081    0.7400    0.7237      1842
           1     0.6590    0.3557    0.4620      1701
           2     0.9564    0.9887    0.9723     31584
           3     0.4813    0.5432    0.5104      2209
           4     0.8384    0.6790    0.7504      3424

    accuracy                         0.9009     40760
   macro avg     0.7287    0.6613    0.6837     40760
weighted avg     0.8971    0.9009    0.8961     40760

