In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3"

In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import re
import time
import pickle

import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from collections import Counter

from sklearn import metrics
from sklearn.metrics import classification_report, f1_score

In [3]:
! nvidia-smi

Wed Nov 18 15:22:56 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.23.04    Driver Version: 455.23.04    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           On   | 00000000:05:00.0 Off |                    0 |
| N/A   23C    P8    26W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla K80           On   | 00000000:06:00.0 Off |                    0 |
| N/A   29C    P8    29W / 149W |      0MiB / 11441MiB |      0%      Default |
|       

In [4]:
torch.manual_seed(42)
np.random.seed(42)

# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print('Device:', device)
torch.backends.cudnn.benchmark = True

Device: cuda


In [5]:
from datasets import load_dataset

conll_dataset = load_dataset("conll2003")

Reusing dataset conll2003 (/opt/tmp/huggingface/datasets/conll2003/conll2003/1.0.0/26b70ce2b0f32cb35a27151dbfa2dbe88c82bcdaf8f29433bcdc612a9b314e83)


In [6]:
cn = pd.read_csv('conceptnet_en.csv')
cn_isa = pd.read_csv('data/conceptnet_isa.csv')
cn_keys = set(cn.subject.values)

In [7]:
word2labels = pickle.load(open('edges/word2labels.pickle', 'rb'))

In [8]:
word2labels['jacob']

['given_name']

In [9]:
len(conll_dataset['train']), len(conll_dataset['validation']), len(conll_dataset['test'])

(14041, 3250, 3453)

In [11]:
nodes_embeddings = {'hope_gsvd': None,
                    'lap_eigmap_svd': None,
                    'lle_svd': None,
                    'node2vec_rw': None}

for embedding_name in nodes_embeddings:
    nodes_embeddings[embedding_name] = pickle.load(open('edges/'+embedding_name+'_all_embeddings.pickle', 'rb'))

In [12]:
class Dataset(data.Dataset):
    def __init__(self, dataset, dataset_split, nodes_embeddings, window_size = 2):
        'Initialization'
        RAW, X, Y = [], [], []
        for doc in tqdm(dataset[dataset_split], desc=f'Loading split {dataset_split}'):
            text = [w.lower() for w in doc['words']]
            for i, (token, pos, label) in enumerate(zip(doc['words'], doc['pos'], doc['ner'])):
                if token == pos:
                    continue # this is punctuation

                

                if token.endswith('='):
                    token = token[:-1]

                while token and token[0] in "!$%&'*+,-.:;<=>?@`":
                    token = token[1:]

                token = re.sub(r'\d+', '<NUM>', token)
                token = token.replace('`', "'")
                
                if token.lower() not in nodes_embeddings: # new words appearing only in the eval and test
                    token = '<span>'
                
                if not token:
                    continue

                extra = ['<'+pos.lower()+'>' if pos.lower() in nodes_embeddings else '<span>']
                if token.lower() in word2labels:
                    extra.extend(['<'+l.lower()+'>' for l in word2labels[token.lower()]])
                if token.lower() not in cn_keys:
                    extra.append('<not_in_dict>')
                if token == token.upper():
                    extra.append('<all_caps>')
                if token.count('.') > 0 and (token.count('.') + 1) == len(token.split('.')): # C.J or C.J.
                    extra.append('<accronym>')
                if token[0] == token[0].upper() and token[1:] == token[1:].lower(): 
                    extra.append('<capitalized>')
                
                
                left_context  = text[max(i-window_size, 0):i] + ([] if i >= window_size else ['<span>'])
                left_context  = [w if w in nodes_embeddings else '<span>' for w in left_context]
                
                right_context = text[i+1:i+1+window_size] + ([] if i + window_size < len(text) else ['<span>'])
                right_context = [w if w in nodes_embeddings else '<span>' for w in right_context]
                
                graph_rep = np.concatenate([nodes_embeddings[token.lower()],
                                            np.mean([nodes_embeddings[w] for w in left_context], axis=0),
                                            np.mean([nodes_embeddings[w] for w in right_context], axis=0),
                                            np.mean([nodes_embeddings[w] for w in extra], axis=0),
                                           ])
                X.append(graph_rep)
                Y.append(label.split('-')[-1])
                RAW.append((token, left_context, right_context, extra))
                
        
        self.X = np.array(X)
        self.labels = sorted(set(Y))
        self.y2index = {l: i for i, l in enumerate(self.labels)}
        self.Y = np.array([self.y2index[y] for y in Y])
        self.RAW = RAW
    
    def __len__(self):
        'Denotes the total number of samples'
        return len(self.X)
    
    def get_raw_item(self, index):
        'Denotes the total number of samples'
        return self.RAW[index]
    
    def get_labels(self):
        return self.labels
    
    def get_Y(self):
        return self.Y

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        x = self.X[index] #.to('cuda') # [:voc_size]
        y = self.Y[index]
        
        return x, y

In [13]:
batch_size  = 64
num_workers = 4
embeddings_to_use = 'lle_svd'

train_set = Dataset(conll_dataset, 'train', nodes_embeddings[embeddings_to_use])
train_loader = DataLoader(train_set, batch_size=batch_size, num_workers=num_workers, shuffle=True) # , sampler=sampler) #

dev_set = Dataset(conll_dataset, 'validation', nodes_embeddings[embeddings_to_use])
dev_loader = DataLoader(dev_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

test_set = Dataset(conll_dataset, 'test', nodes_embeddings[embeddings_to_use])
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

HBox(children=(FloatProgress(value=0.0, description='Loading split train', max=14041.0, style=ProgressStyle(de…




HBox(children=(FloatProgress(value=0.0, description='Loading split validation', max=3250.0, style=ProgressStyl…




HBox(children=(FloatProgress(value=0.0, description='Loading split test', max=3453.0, style=ProgressStyle(desc…




In [14]:
labels = train_set.get_labels()
label_counter   = Counter([labels[y] for y in train_set.get_Y()])
labels_freqs    = [label_counter[label] / sum(label_counter.values()) for label in labels]
labels_weights  = [min(label_counter.values()) / label_counter[label] for label in labels]
labels_weights2 = [np.sqrt(min(label_counter.values())) / np.sqrt(label_counter[label]) for label in labels]

# sampling_probs = [labels_weights2[labels_to_id[l]] for l in Y_train]
# sampler = torch.utils.data.sampler.WeightedRandomSampler(sampling_probs, len(Y_train), replacement=True)

In [15]:
labels_weights2

[0.7438616241144019,
 1.0,
 0.177263548574672,
 0.6777425292081322,
 0.6420761452814192]

In [16]:
labels_freqs

[0.04606261358647022,
 0.025487831311239433,
 0.8111366149981382,
 0.05548855381845061,
 0.06182438628570158]

In [17]:
labels

['LOC', 'MISC', 'O', 'ORG', 'PER']

In [18]:
t = time.time()
print(len(train_loader))
for batch_X, batch_Y in train_loader:
    print(batch_X.shape)
    print(batch_Y.shape)
    print(sum(batch_X[0]))
    print('Class distribution in this batch:', Counter(batch_Y.numpy()))
    break
print(f'time: {time.time() - t:.3}s')

2812
torch.Size([64, 1200])
torch.Size([64])
tensor(-0.2693, dtype=torch.float64)
Class distribution in this batch: Counter({2: 49, 3: 6, 4: 6, 0: 2, 1: 1})
time: 0.509s


In [19]:
train_set.get_raw_item(2)

('German',
 ['eu', 'rejects'],
 ['call', 'to'],
 ['<span>',
  '<human>',
  '<person>',
  '<person_with_nationality>',
  '<capitalized>'])

# Training

In [20]:
writer = SummaryWriter(log_dir=None,filename_suffix='secondattemptwithn2vembeddings')

In [21]:
def backprop(batch_X, batch_Y, model, optimizer, loss_fn):
    Y_hat = model(batch_X)
    loss = loss_fn(Y_hat, batch_Y)
    loss.backward()
    optimizer.step()
    
    return loss.item()

In [22]:
input_dim = 1200
class FeedForwardNetwork(nn.Module):
    def __init__(self, input_dim=input_dim, hidden_dim=512, output_dim=5, dropout_rate=0.2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fch = nn.Linear(hidden_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

        # extra layers layers
        self.dropout = nn.Dropout(p=dropout_rate)
        # self.batchnorm1 = nn.BatchNorm1d(hidden_dim)
        # self.batchnorm2 = nn.BatchNorm1d(hidden_dim)

        
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fch(x)
        x = self.relu(x)
        x = self.dropout(x)
        #x = self.fch2(x)
        # x = self.dropout(x)
        # x = self.relu(x)
        logits = self.fc2(x)

        return logits

In [23]:
logs = {'loss/train': {}, 'dev': {}}

In [24]:
ffnet = FeedForwardNetwork().to('cuda')

In [25]:
log_interval = int(len(train_loader) / 4)
weights = torch.Tensor(labels_weights2).to('cuda')
weights

tensor([0.7439, 1.0000, 0.1773, 0.6777, 0.6421], device='cuda:0')

In [26]:
optimizer_params = {'lr': 1e-4, 
                    'momentum': 0.9, 
                    'weight_decay': 5e-4,
                   }

loss_fn = nn.CrossEntropyLoss(weight=weights)
optimizer = torch.optim.SGD(ffnet.parameters(), **optimizer_params)

In [27]:
%%time
max_epochs = 20

for epoch in range(len(logs['loss/train']), len(logs['loss/train']) + max_epochs):
    
    # Training
    ffnet.train()
    print('Epoch', epoch)
    logs['loss/train'][epoch] = []
    writer.add_scalar("Learning_rate", optimizer_params['lr'], epoch)

    for batch, (batch_X, batch_Y) in enumerate(tqdm(train_loader)):
        # tranfer to GPU
        batch_X, batch_Y = batch_X.float().to(device), batch_Y.to(device)
        optimizer.zero_grad()
        l = backprop(batch_X, batch_Y, ffnet, optimizer, loss_fn)
        logs['loss/train'][epoch].append(l)
        
        if batch % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch * len(batch_X), len(train_loader.dataset),
                100. * batch / len(train_loader), l))
    
    logs['loss/train'][epoch] = np.mean(logs['loss/train'][epoch])
    writer.add_scalar("Loss/train", logs['loss/train'][epoch], epoch)
    print(f'Average loss on epoch {epoch}: {logs["loss/train"][epoch]}')
    
    # Validation
    ffnet.eval()
    with torch.no_grad():
        preds = []
        gt = []
        for batch, (batch_X, batch_Y) in enumerate(tqdm(dev_loader)):
            # Transfer to GPU
            batch_X = batch_X.float().to(device)
            output = nn.Softmax(dim=1)(ffnet(batch_X))
            preds.append(output.cpu())
            gt.append(batch_Y)
    
        all_out = [np.argmax(l) for batch in preds for l in batch.numpy()]
        all_gt  = [l for batch in gt for l in batch.numpy()]
        
        print(classification_report(all_out, all_gt, digits=4))
        micro_F1 = metrics.f1_score(all_gt, all_out, average='micro')
        macro_F1 = metrics.f1_score(all_gt, all_out, average='macro')
        weighted_F1 = metrics.f1_score(all_gt, all_out, average='weighted')
        writer.add_scalar("micro_F1/dev", micro_F1, epoch)
        writer.add_scalar("macro_F1/dev", macro_F1, epoch)
        writer.add_scalar("weighted_F1/dev", weighted_F1, epoch)
        logs['dev'][epoch] = (micro_F1, weighted_F1, macro_F1, (all_gt, all_out))

Epoch 0


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 0: 1.4497832432919173


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))




  _warn_prf(average, modifier, msg_start, len(result))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 1


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 1: 1.3581586841329572


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 2


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 2: 1.3519120421552047


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 3


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 3: 1.3508673601732146


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 4


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 4: 1.3517795719980006


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 5


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 5: 1.3511112738681552


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 6


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 6: 1.3503657317644842


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 7


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 7: 1.3503016230264735


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 8


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 8: 1.3507668593445339


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 9


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 9: 1.3505865182001595


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 10


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 10: 1.350439735933354


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 11


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 11: 1.3503955516375994


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 12


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 12: 1.3504040157515158


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 13


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 13: 1.3503890950047275


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 14


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 14: 1.3506880231720964


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 15


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 15: 1.3503088406571961


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 16


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 16: 1.349661894873908


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 17


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 17: 1.3499271386166556


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 18


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 18: 1.3492229531491295


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

Epoch 19


HBox(children=(FloatProgress(value=0.0, max=2812.0), HTML(value='')))


Average loss on epoch 19: 1.349678888705004


HBox(children=(FloatProgress(value=0.0, max=707.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8107    0.8955     45224
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8107     45224
   macro avg     0.2000    0.1621    0.1791     45224
weighted avg     1.0000    0.8107    0.8955     45224

CPU times: user 49min 9s, sys: 53.7 s, total: 50min 2s
Wall time: 4min 50s


In [28]:
ffnet.eval()
with torch.no_grad():
    preds = []
    gt = []
    for batch, (batch_X, batch_Y) in enumerate(tqdm(test_loader)):
        # Transfer to GPU
        batch_X = batch_X.float().to(device)
        output = nn.Softmax(dim=1)(ffnet(batch_X))
        preds.append(output.cpu())
        gt.append(batch_Y)

    all_out = [np.argmax(l) for batch in preds for l in batch.numpy()]
    all_gt  = [l for batch in gt for l in batch.numpy()]

    print(classification_report(all_out, all_gt, digits=4))

HBox(children=(FloatProgress(value=0.0, max=643.0), HTML(value='')))


              precision    recall  f1-score   support

           0     0.0000    0.0000    0.0000         0
           1     0.0000    0.0000    0.0000         0
           2     1.0000    0.8027    0.8905     41090
           3     0.0000    0.0000    0.0000         0
           4     0.0000    0.0000    0.0000         0

    accuracy                         0.8027     41090
   macro avg     0.2000    0.1605    0.1781     41090
weighted avg     1.0000    0.8027    0.8905     41090

