In [None]:
import os
from tqdm import tqdm
from collections import Counter
import numpy as np

import torch
from torch.utils.data import DataLoader
from seqeval.metrics import classification_report

from utils import set_random_seed, Config, IO2BIO
from dataset import IO2df, MyDataset
from model import BiLSTM_CRF

# Enable IPython's autoreload extension so edits to the local project modules
# (utils, dataset, model) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Seed the run with a fixed value (presumably for reproducibility — see
# utils.set_random_seed for which generators are actually seeded).
set_random_seed(seed=0)

In [None]:
import torch.utils.tensorboard as tensorboard

In [None]:
# Load the run settings; TR_PATH and VA_PATH are read from this object below.
config = Config('config.yaml')

# Load the training and validation splits with the project's IO2df helper.
# Later cells index the results by the 'tokens' and 'tags' columns
# (presumably pandas DataFrames with one row per title — TODO confirm).
tr_titles = IO2df(config.TR_PATH)
va_titles = IO2df(config.VA_PATH)

In [None]:
# Display the freshly loaded training table as the cell output.
tr_titles

# Tag -> tagID

In [None]:
# Tag name -> integer id. Ids are handed out in the order listed, starting at 1,
# so that id 0 stays free for padding.
TAG2IDX = {
    tag: idx
    for idx, tag in enumerate(
        (
            "O",
            "art",
            "building",
            "event",
            "location",
            "organization",
            "other",
            "person",
            "product",
        ),
        start=1,
    )
}

# Reverse lookup: integer id -> tag name.
IDX2TAG = dict(zip(TAG2IDX.values(), TAG2IDX.keys()))
# Id 0 has no tag of its own; give it a placeholder name so that a lookup
# still succeeds if the model ever predicts the padding id.
IDX2TAG[0] = '[PAD]'

In [None]:
def _tags_to_ids(tags):
    """Return the TAG2IDX id of every tag in one title's tag list."""
    return [TAG2IDX[tag] for tag in tags]


# Add a 'tags_ids' column holding the integer-encoded tag sequence of each row.
tr_titles['tags_ids'] = tr_titles['tags'].transform(_tags_to_ids)
va_titles['tags_ids'] = va_titles['tags'].transform(_tags_to_ids)

# Token -> tokenID

In [None]:
def calc_token_cntr(filepath):
    """Count how often each token occurs in a one-token-per-line file.

    Every non-blank line must split on whitespace into exactly two fields,
    the token and its tag; only the token is counted. Blank lines (used as
    separators) are skipped.

    Args:
        filepath: Path of the UTF-8 encoded file to read.

    Returns:
        A ``collections.Counter`` mapping token -> number of occurrences.

    Raises:
        ValueError: If a non-blank line does not contain exactly two fields.
    """
    token_cntr = Counter()

    # First pass only counts lines so that tqdm can show a complete progress
    # bar. The file is opened in a `with` block so the handle is closed
    # right away instead of being left to the garbage collector.
    with open(filepath, "r", encoding="utf-8") as f:
        num_lines = sum(1 for _ in f)

    with open(filepath, "r", encoding="utf-8") as f:
        for line in tqdm(f, total=num_lines):
            # split() without arguments also discards the trailing newline
            fields = line.split()
            if fields:
                # the tag is unpacked only to enforce the two-field format
                token, _tag = fields
                token_cntr[token] += 1

    return token_cntr

# Token frequencies are collected per split and then summed, so the vocabulary
# below is built from the training AND the validation file.
# NOTE(review): because validation tokens enter the vocabulary, validation
# never sees a token without an entry — confirm this is intended.
tr_token_cntr = calc_token_cntr(filepath=config.TR_PATH)
va_token_cntr = calc_token_cntr(filepath=config.VA_PATH)
token_cntr = tr_token_cntr + va_token_cntr

# Vocabulary cap: only the MC most frequent tokens get an id of their own.
MC = 50_000
top_tokens = [entry[0] for entry in token_cntr.most_common(MC)]

# Ids 1..len(top_tokens) follow frequency rank; id 0 is never assigned here.
TOKEN2IDX = dict(zip(top_tokens, range(1, len(top_tokens) + 1)))

# Every token outside the top list shares one extra id right after them.
rare_tokens = set(token_cntr) - set(top_tokens)
for token in rare_tokens:
    TOKEN2IDX[token] = len(top_tokens) + 1



In [None]:
def _tokens_to_ids(tokens):
    """Return the TOKEN2IDX id of every token in one title's token list."""
    return [TOKEN2IDX[token] for token in tokens]


# Add a 'tokens_ids' column holding the integer-encoded token sequence of each row.
tr_titles['tokens_ids'] = tr_titles['tokens'].transform(_tokens_to_ids)
va_titles['tokens_ids'] = va_titles['tokens'].transform(_tokens_to_ids)

# Padding

In [None]:
def padding(ids, max_len = 100, pad_value = 0):
    """Return a new list holding ``ids`` cut or right-padded to ``max_len`` items.

    The input is never modified: sequences longer than ``max_len`` are
    truncated, shorter ones are filled up at the end with ``pad_value``.

    Args:
        ids: Sequence of ids for one example.
        max_len: Length of the returned list.
        pad_value: Filler appended to short sequences (0 is the padding id
            used throughout this notebook).

    Returns:
        A list of exactly ``max_len`` items.
    """
    # Slice first: this both truncates long inputs and copies short ones, so
    # the caller's list is left untouched by the extend() below.
    padded = list(ids[:max_len])
    padded.extend([pad_value] * (max_len - len(padded)))
    return padded

# Bring every id sequence of both splits to a fixed length of 100:
# token ids first, then tag ids, training split before validation split.
for titles in (tr_titles, va_titles):
    for column in ('tokens_ids', 'tags_ids'):
        titles[column] = titles[column].transform(padding, max_len=100)

In [None]:
# Display the training table again, now including the padded id columns.
tr_titles

# Data loader

In [None]:
# Wrap both tables in the project's dataset class.
tr_dataset = MyDataset(tr_titles)
va_dataset = MyDataset(va_titles)

# Training batches are reshuffled every epoch. Validation keeps a fixed order:
# the metrics do not need shuffling, and a stable order makes the example
# predictions printed after each epoch comparable between epochs.
tr_dataloader = DataLoader(dataset=tr_dataset, batch_size=128, shuffle=True, num_workers=4)
va_dataloader = DataLoader(dataset=va_dataset, batch_size=128, shuffle=False, num_workers=4)

# Peek at the first training example (shown as the cell output).
tr_dataset[0]

# Model

In [None]:
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Both vocabulary sizes get +1 on top of the mapping size; for tags this makes
# room for id 0 (padding), which TAG2IDX itself does not contain.
# NOTE(review): TOKEN2IDX maps every token beyond the top MC to one shared id,
# so len(TOKEN2IDX) + 1 may be larger than the highest id actually used —
# confirm whether the embedding table is meant to be that big.
model = BiLSTM_CRF(
    token_voc_size=len(TOKEN2IDX) + 1,
    tag_voc_size=len(TAG2IDX) + 1,
    embed_size=100,
    hidden_size=256,
    dropout=0.5,
)
model = model.to(device)

In [None]:
# Print the model's text representation as a sanity check before training.
print(model)

# Train

In [None]:
# Folder that receives the per-epoch model checkpoints.
# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race.
weights_folder = 'weights'
os.makedirs(weights_folder, exist_ok=True)

# Folder for the TensorBoard event files. The name must match the log_dir
# handed to SummaryWriter ('./runs'); the previous value '.runs' created a
# hidden directory that nothing ever wrote to.
runs_folder = 'runs'
os.makedirs(runs_folder, exist_ok=True)

In [None]:
# Run settings for this cell.
NUM_EPOCHS = 5    # full passes over the training data
NUM_PREVIEW = 5   # at most this many validation examples are printed per epoch

# Make optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Make tensorboard writer
writer = tensorboard.SummaryWriter(log_dir='./runs')

# Mean loss of every finished epoch, kept for inspection after training.
global_tr_losses = []
global_va_losses = []

# try/finally guarantees that the writer is flushed and closed even when the
# loop is interrupted (for example by the notebook's stop button).
try:
    for epoch in range(NUM_EPOCHS):

        # TRAINING PHASE

        tr_losses = []

        model.train()
        for tr_batch in tqdm(tr_dataloader, total=len(tr_dataloader)):
            optimizer.zero_grad()

            tr_xs = tr_batch['tokens_ids'].to(device)
            tr_ys = tr_batch['tags_ids'].to(device)

            # Calculate loss; positions whose tag id is 0 (padding) are masked out
            tr_emission_scores = model(tr_xs).to(device) # size: [batch=128, seq_len=100, 10]
            tr_loss = model.loss_fn(emission_scores=tr_emission_scores, tags=tr_ys, mask=(tr_ys > 0).bool())
            tr_losses.append(tr_loss.item())

            # Calculate total loss (the regularization term is optimized but not logged)
            total_loss = tr_loss + model.regularization_loss_fn(lam=1e-3, alpha=0.5)

            # Backward pass: compute gradient of the loss w.r.t. all learnable parameters
            total_loss.backward()

            # Clip computed gradients; the returned total norm is logged below
            grad_norm = torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=1e2)

            # Optimize: update the weights using Adam
            optimizer.step()

        # END TRAINING PHASE AND UPDATE LOG

        tr_loss_mean = np.mean(tr_losses)
        # The "NLL" value and the gradient norm come from the LAST batch of the
        # epoch; 'tr/loss' is the mean over all batches.
        print(f"Epoch: {epoch:02d} NLL:{tr_loss.item()}")
        writer.add_scalar('tr/loss', tr_loss_mean, global_step=epoch)
        writer.add_scalar('tr/total_grad_norm', grad_norm, global_step=epoch)
        for name, param in model.named_parameters():
            writer.add_histogram('tr/'+name, param.data, global_step=epoch)
        print("tr loss", tr_loss_mean)

        # VALIDATION PHASE

        va_losses = []

        # Tag sequences of the whole validation set, one list per title
        batch_preds = []
        batch_trues = []

        model.eval()
        with torch.no_grad():
            for va_batch in tqdm(va_dataloader, total=len(va_dataloader)):
                va_xs = va_batch['tokens_ids'].to(device) # size: [batch=128, seq_len=100]
                va_ys = va_batch['tags_ids'].to(device) # size: [batch=128, seq_len=100]

                # Forward pass: compute predicted output by passing input to the model
                va_emission_scores = model(va_xs).to(device) # same layout as tr_emission_scores
                # NOTE(review): decode() receives no padding mask, unlike loss_fn();
                # if it searches over all positions, the padded tail might affect the
                # tags predicted for the real tokens — confirm in model.decode.
                va_preds = model.decode(va_emission_scores)
                va_loss = model.loss_fn(emission_scores=va_emission_scores, tags=va_ys, mask=(va_ys > 0).bool())
                va_losses.append(va_loss.item())

                # Number of non-padding positions (padding value is 0) per row.
                # Converted to plain ints once, so the slices below do not index
                # Python lists with (possibly GPU-resident) tensors row by row.
                title_lengths = torch.sum(va_ys > 0, dim=1).tolist()

                for true, length in zip(va_ys.tolist(), title_lengths):
                    # do not count padding
                    true_tags = [IDX2TAG[idx] for idx in true[:length]]
                    # convert to the format expected by seqeval
                    batch_trues.append(IO2BIO(true_tags))

                for pred, length in zip(va_preds, title_lengths):
                    # do not count padding
                    pred_tags = [IDX2TAG[idx] for idx in pred[:length]]
                    # convert to the format expected by seqeval
                    batch_preds.append(IO2BIO(pred_tags))

        # Print a few examples; slicing keeps this safe when the validation
        # set holds fewer than NUM_PREVIEW titles.
        for shown_pred, shown_true in zip(batch_preds[:NUM_PREVIEW], batch_trues[:NUM_PREVIEW]):
            print('pred:', shown_pred)
            print('true:', shown_true)
            print()

        va_loss_mean = np.mean(va_losses)
        print("va loss", va_loss_mean)
        writer.add_scalar('va/loss', va_loss_mean, global_step=epoch)

        report = classification_report(y_true=batch_trues, y_pred=batch_preds, zero_division=0)
        print(report)

        global_tr_losses.append(tr_loss_mean)
        global_va_losses.append(va_loss_mean)
        # One checkpoint per epoch, stored in the folder created earlier in the notebook
        torch.save(model.state_dict(), os.path.join(weights_folder, f"model_epoch_{epoch:02d}.pt"))
finally:
    writer.close()

In [None]:
# Epoch: 04 NLL:171.06204223632812
# tr loss 299.8672805786133
# 100%|██████████| 148/148 [00:42<00:00,  3.49it/s]
# pred: ['B-other', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-other', 'O', 'O', 'O']
# true: ['O', 'O', 'O', 'O', 'B-other', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-other', 'O', 'O', 'O']

# pred: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-location', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-location', 'I-location', 'O']
# true: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-location', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-location', 'I-location', 'O']

# pred: ['O', 'O', 'B-person', 'O', 'O', 'B-person', 'O', 'O', 'O', 'O', 'O', 'B-person', 'I-person', 'O', 'O', 'O', 'O', 'O']
# true: ['O', 'O', 'B-art', 'I-art', 'O', 'B-person', 'O', 'O', 'O', 'O', 'O', 'B-person', 'I-person', 'O', 'O', 'O', 'O', 'O']

# pred: ['O', 'O', 'B-person', 'I-person', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-art', 'I-art', 'I-art', 'O', 'O', 'O', 'O', 'O', 'O', 'B-person', 'I-person', 'O', 'O', 'B-art', 'I-art', 'O', 'O', 'O', 'B-art', 'I-art', 'O', 'O', 'O', 'O', 'B-person', 'I-person', 'I-person', 'O', 'O', 'O', 'B-person', 'I-person', 'O', 'O', 'B-person', 'O', 'B-person', 'I-person', 'I-person', 'O', 'B-person', 'O', 'O', 'O', 'O', 'O', 'O', 'B-person', 'I-person', 'I-person', 'O', 'O', 'B-art', 'I-art', 'I-art', 'I-art', 'I-art', 'I-art', 'I-art', 'I-art', 'I-art', 'I-art', 'O', 'B-art', 'I-art', 'I-art', 'I-art', 'O', 'O', 'O', 'O', 'O', 'O']
# true: ['O', 'O', 'B-person', 'I-person', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-art', 'I-art', 'I-art', 'O', 'O', 'O', 'O', 'O', 'O', 'B-person', 'I-person', 'O', 'O', 'B-art', 'I-art', 'O', 'O', 'O', 'B-art', 'I-art', 'O', 'O', 'O', 'O', 'O', 'B-person', 'I-person', 'O', 'O', 'O', 'B-person', 'I-person', 'O', 'O', 'B-art', 'O', 'B-art', 'I-art', 'I-art', 'O', 'B-art', 'O', 'O', 'O', 'O', 'O', 'O', 'B-person', 'I-person', 'I-person', 'O', 'O', 'B-art', 'I-art', 'I-art', 'I-art', 'I-art', 'O', 'B-art', 'I-art', 'I-art', 'I-art', 'O', 'B-art', 'I-art', 'I-art', 'I-art', 'O', 'O', 'O', 'O', 'O', 'O']

# pred: ['O', 'B-other', 'I-other', 'O', 'B-other', 'I-other', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-other', 'O', 'O', 'O', 'O']
# true: ['O', 'B-other', 'O', 'O', 'B-other', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-other', 'O', 'O', 'O', 'O']

# va loss 487.23016816216546
#               precision    recall  f1-score   support

#          art       0.69      0.67      0.68      2063
#     building       0.64      0.59      0.61      2484
#        event       0.62      0.57      0.59      2034
#     location       0.78      0.79      0.78     13649
# organization       0.64      0.63      0.64      9585
#        other       0.63      0.54      0.58      4958
#       person       0.80      0.85      0.82     10954
#      product       0.65      0.44      0.52      2955

#    micro avg       0.72      0.70      0.71     48682
#    macro avg       0.68      0.63      0.65     48682
# weighted avg       0.72      0.70      0.71     48682