In [1]:
import os
import json
from tqdm import tqdm
from collections import Counter
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.utils.tensorboard as tensorboard
from seqeval.metrics import classification_report

from utils import set_random_seed, Config, load_ner_config
from dataset import io2df, io2bio, padding, NERDataset
from model import BiLSTM_CRF

%load_ext autoreload
%autoreload 2

# Seed once, before any data loading or model construction, so reruns are comparable.
# NOTE(review): which RNGs are seeded (python / numpy / torch) is decided inside utils.set_random_seed — confirm there.
set_random_seed(seed=0)

In [2]:
# Load paths and hyper-parameters from the YAML file, then display them as a plain dict.
config = Config('config.yaml')
vars(config)

{'TR_PATH': 'data/supervised/train.txt',
 'VA_PATH': 'data/supervised/dev.txt',
 'TE_PATH': 'data/supervised/test.txt',
 'SEQ_LEN': 64,
 'BATCH_SIZE': 128,
 'LR': 0.001,
 'REG_LAMBDA': 0.0001,
 'MAX_GRAD_NORM': 100,
 'NUM_EPOCHS': 10,
 'EMBED_SIZE': 128,
 'HIDDEN_SIZE': 128,
 'DROPOUT': 0.5}

In [3]:
# Parse the train and dev splits (in that order) into one frame per split.
tr_titles, va_titles = [io2df(split_path) for split_path in (config.TR_PATH, config.VA_PATH)]

100%|██████████| 3359329/3359329 [00:04<00:00, 724331.02it/s]
100%|██████████| 482037/482037 [00:00<00:00, 486813.09it/s]


In [4]:
# Preview the parsed training frame: one row per sentence with its token list and tag lists.
tr_titles

Unnamed: 0,id,tokens,tags_fine_grained,tags_coarse_grained
0,0,"[Paul, International, airport, .]","[O, O, O, O]","[O, O, O, O]"
1,1,"[It, starred, Hicks, 's, wife, ,, Ellaline, Te...","[O, O, person-artist/author, O, O, O, person-a...","[O, O, person, O, O, O, person, person, O, per..."
2,2,"[``, Time, ``, magazine, said, the, film, was,...","[O, art-writtenart, O, O, O, O, O, O, O, O, O,...","[O, art, O, O, O, O, O, O, O, O, O, O, O, O, O..."
3,3,"[Pakistani, scientists, and, engineers, ', wor...","[O, O, O, O, O, O, O, organization-other, O, O...","[O, O, O, O, O, O, O, organization, O, O, O, O..."
4,4,"[In, February, 2008, ,, Church, 's, Chicken, e...","[O, O, O, O, organization-company, organizatio...","[O, O, O, O, organization, organization, organ..."
...,...,...,...,...
131762,131762,"[In, response, ,, the, states, who, had, ratif...","[O, O, O, O, O, O, O, O, O, other-law, O, O, o...","[O, O, O, O, O, O, O, O, O, other, O, O, organ..."
131763,131763,"[They, have, long, been, used, as, containers,...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ..."
131764,131764,"[In, 1911, he, came, into, possession, of, the...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O]","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O]"
131765,131765,"[The, Lutici, tribes, in, 983, formed, the, Li...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ..."


# Tag -> tagID

In [5]:
# Fine-grained tag vocabulary (tag -> id) and its inverse (id -> tag), the latter
# being what the evaluation code uses to turn predicted ids back into tag strings.
TAG2IDX = load_ner_config('ner_tags/ner_fine_grained.json')
IDX2TAG = dict(zip(TAG2IDX.values(), TAG2IDX.keys()))

In [6]:
# Show the tag vocabulary; note that id 0 is the 'PAD' tag and id 1 is 'O'.
TAG2IDX

{'PAD': 0,
 'O': 1,
 'art-broadcastprogram': 2,
 'art-film': 3,
 'art-music': 4,
 'art-other': 5,
 'art-painting': 6,
 'art-writtenart': 7,
 'building-airport': 8,
 'building-hospital': 9,
 'building-hotel': 10,
 'building-library': 11,
 'building-other': 12,
 'building-restaurant': 13,
 'building-sportsfacility': 14,
 'building-theater': 15,
 'event-attack/battle/war/militaryconflict': 16,
 'event-disaster': 17,
 'event-election': 18,
 'event-other': 19,
 'event-protest': 20,
 'event-sportsevent': 21,
 'location-GPE': 22,
 'location-bodiesofwater': 23,
 'location-island': 24,
 'location-mountain': 25,
 'location-other': 26,
 'location-park': 27,
 'location-road/railway/highway/transit': 28,
 'organization-company': 29,
 'organization-education': 30,
 'organization-government/governmentagency': 31,
 'organization-media/newspaper': 32,
 'organization-other': 33,
 'organization-politicalparty': 34,
 'organization-religion': 35,
 'organization-showorganization': 36,
 'organization-sportsl

In [7]:
def _tags_to_ids(tags):
    """Map one sentence's tag strings to their ids; an unknown tag raises KeyError."""
    return [TAG2IDX[tag] for tag in tags]


tr_titles['tags_ids'] = tr_titles['tags_fine_grained'].transform(_tags_to_ids)
va_titles['tags_ids'] = va_titles['tags_fine_grained'].transform(_tags_to_ids)

# Token -> tokenID

In [8]:
def calc_token_cntr(filepath):
    """Count how often each token occurs in a two-column annotation file.

    Every non-blank line must split (on whitespace) into exactly two fields,
    the token and its tag; blank lines are skipped.

    Args:
        filepath: path to the UTF-8 encoded annotation file.

    Returns:
        collections.Counter mapping token -> number of occurrences.

    Raises:
        ValueError: if a non-blank line does not have exactly two fields.
    """
    token_cntr = Counter()

    # First pass only counts lines so the progress bar has a total.
    # Opened in a `with` block so the handle is closed deterministically
    # instead of being left to the garbage collector.
    with open(filepath, "r", encoding="utf-8") as f:
        num_lines = sum(1 for _ in f)

    with open(filepath, "r", encoding="utf-8") as f:
        for line in tqdm(f, total=num_lines):
            fields = line.strip().split()
            if fields:
                # The tag is not needed here; unpacking still enforces the two-field format.
                token, _ = fields
                token_cntr[token] += 1

    return token_cntr

# Token frequencies are computed on the training split only.
token_cntr = calc_token_cntr(filepath=config.TR_PATH)

# Tokens seen fewer than MIN_TOKEN_COUNT times are dropped from the vocabulary
# and will be encoded as 'UKN'.
MIN_TOKEN_COUNT = 5
top_tokens = [token for token, cnt in token_cntr.most_common() if cnt >= MIN_TOKEN_COUNT]

# Reserve the special ids first, then hand out dense ids 2, 3, ... in frequency order.
# setdefault() keeps the reserved ids intact if the corpus itself contains a literal
# 'PAD' or 'UKN' token; without that, an id would be overwritten and len(TOKEN2IDX)
# (used as the embedding size) would end up smaller than the largest id.
TOKEN2IDX = {'PAD': 0, 'UKN': 1}
for token in top_tokens:
    TOKEN2IDX.setdefault(token, len(TOKEN2IDX))

# Persist the vocabulary; create the target folder so a fresh checkout does not fail here.
token2idx_path = 'tokenizers/token2idx.json'
os.makedirs(os.path.dirname(token2idx_path), exist_ok=True)
with open(token2idx_path, 'w', encoding='utf-8') as f:
    json.dump(TOKEN2IDX, f, indent=4)

len(TOKEN2IDX)

100%|██████████| 3359329/3359329 [00:03<00:00, 902990.85it/s] 


34705

In [9]:
def _tokens_to_ids(tokens):
    """Map one sentence's tokens to ids, falling back to the 'UKN' id for out-of-vocabulary tokens."""
    unknown_id = TOKEN2IDX['UKN']
    return [TOKEN2IDX.get(token, unknown_id) for token in tokens]


tr_titles['tokens_ids'] = tr_titles['tokens'].transform(_tokens_to_ids)
va_titles['tokens_ids'] = va_titles['tokens'].transform(_tokens_to_ids)

# Padding

In [10]:
# Bring every id sequence (tokens and tags, train and dev) to config.SEQ_LEN via `padding`.
for titles in (tr_titles, va_titles):
    for column in ('tokens_ids', 'tags_ids'):
        titles[column] = titles[column].transform(padding, max_len=config.SEQ_LEN)

In [11]:
# Inspect the frame again: it now also carries the padded 'tags_ids' and 'tokens_ids' columns.
tr_titles

Unnamed: 0,id,tokens,tags_fine_grained,tags_coarse_grained,tags_ids,tokens_ids
0,0,"[Paul, International, airport, .]","[O, O, O, O]","[O, O, O, O]","[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[586, 170, 711, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,1,"[It, starred, Hicks, 's, wife, ,, Ellaline, Te...","[O, O, person-artist/author, O, O, O, person-a...","[O, O, person, O, O, O, person, person, O, per...","[1, 1, 52, 1, 1, 1, 51, 51, 1, 51, 51, 1, 0, 0...","[35, 1601, 15202, 22, 659, 3, 1, 1, 6, 5586, 8..."
2,2,"[``, Time, ``, magazine, said, the, film, was,...","[O, art-writtenart, O, O, O, O, O, O, O, O, O,...","[O, art, O, O, O, O, O, O, O, O, O, O, O, O, O...","[1, 7, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[10, 2065, 10, 1045, 381, 2, 76, 11, 10, 9, 1,..."
3,3,"[Pakistani, scientists, and, engineers, ', wor...","[O, O, O, O, O, O, O, organization-other, O, O...","[O, O, O, O, O, O, O, organization, O, O, O, O...","[1, 1, 1, 1, 1, 1, 1, 33, 1, 1, 1, 1, 1, 1, 1,...","[5323, 5587, 6, 6537, 59, 603, 21, 1, 78, 1076..."
4,4,"[In, February, 2008, ,, Church, 's, Chicken, e...","[O, O, O, O, organization-company, organizatio...","[O, O, O, O, organization, organization, organ...","[1, 1, 1, 1, 29, 29, 29, 1, 1, 1, 1, 1, 1, 1, ...","[25, 187, 139, 3, 340, 22, 6282, 922, 2, 541, ..."
...,...,...,...,...,...,...
131762,131762,"[In, response, ,, the, states, who, had, ratif...","[O, O, O, O, O, O, O, O, O, other-law, O, O, o...","[O, O, O, O, O, O, O, O, O, other, O, O, organ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 48, 1, 1, 33, 33, ...","[25, 1235, 3, 2, 1110, 49, 40, 8312, 2, 14192,..."
131763,131763,"[They, have, long, been, used, as, containers,...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[198, 53, 470, 47, 68, 17, 19121, 16, 8367, 38..."
131764,131764,"[In, 1911, he, came, into, possession, of, the...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O]","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[25, 2509, 29, 383, 57, 5299, 5, 2, 1, 28249, ..."
131765,131765,"[The, Lutici, tribes, in, 983, formed, the, Li...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[12, 1, 3938, 7, 1, 391, 2, 1, 17656, 3, 5239,..."


# Data loader

In [12]:
# Wrap the encoded frames in datasets that yield {'tokens_ids', 'tags_ids'} tensors.
tr_dataset = NERDataset(tr_titles)
va_dataset = NERDataset(va_titles)

# Training batches are reshuffled every epoch.
tr_dataloader = DataLoader(dataset=tr_dataset, batch_size=config.BATCH_SIZE, shuffle=True, num_workers=4)
# Validation keeps a fixed order: shuffling adds nothing to evaluation and would make
# the sample predictions printed after each epoch incomparable between epochs.
va_dataloader = DataLoader(dataset=va_dataset, batch_size=config.BATCH_SIZE, shuffle=False, num_workers=4)

In [13]:
# Sanity-check one sample: a dict of padded token ids and tag ids.
tr_dataset[0]

{'tokens_ids': tensor([586, 170, 711,   4,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0]),
 'tags_ids': tensor([1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}

# Model

In [14]:
# Prefer the GPU when one is visible, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Vocabulary sizes come from the mappings built above; the rest from the config.
model = BiLSTM_CRF(
    token_voc_size=len(TOKEN2IDX),
    tag_voc_size=len(TAG2IDX),
    embed_size=config.EMBED_SIZE,
    hidden_size=config.HIDDEN_SIZE,
    dropout=config.DROPOUT,
)
model = model.to(device)
print(model)

token_embedding.weight         initialized w with Xavier            parameters #: 4442240
lstm.weight_ih_l0              initialized w with Xavier            parameters #: 32768
lstm.weight_hh_l0              initialized w with Xavier            parameters #: 16384
lstm.weight_ih_l0_reverse      initialized w with Xavier            parameters #: 32768
lstm.weight_hh_l0_reverse      initialized w with Xavier            parameters #: 16384
fc.weight                      initialized w with Xavier            parameters #: 8704
fc.bias                        initialized b with zero              parameters #: 68
crf.start_transitions          initialized b with zero              parameters #: 68
crf.end_transitions            initialized b with zero              parameters #: 68
crf.transitions                initialized w with Xavier            parameters #: 4624
BiLSTM_CRF(
  (token_embedding): Embedding(34705, 128)
  (lstm): LSTM(128, 64, bias=False, batch_first=True, bidirectional=True)


# Train

In [15]:
# Folder that receives one checkpoint per epoch.
weights_folder = 'weights'
os.makedirs(weights_folder, exist_ok=True)

# Folder for the TensorBoard event files. The SummaryWriter in the training cell logs
# to './runs', so create that folder (the previous '.runs' was a different, unused path).
runs_folder = 'runs'
os.makedirs(runs_folder, exist_ok=True)

In [38]:
# Adam over every model parameter; the learning rate comes from the config.
optimizer = torch.optim.Adam(model.parameters(), lr=config.LR)

# TensorBoard writer for the loss curves and the per-parameter histograms.
writer = tensorboard.SummaryWriter(log_dir='./runs')

# try/finally guarantees the writer is flushed and closed even when the run is
# interrupted or a batch raises.
try:
    for epoch in range(config.NUM_EPOCHS):

        # ---------------- TRAINING PHASE ----------------

        tr_losses = []

        model.train()
        for batch_num, tr_batch in tqdm(enumerate(tr_dataloader), total=len(tr_dataloader)):
            optimizer.zero_grad()

            tr_xs = tr_batch['tokens_ids'].to(device)  # [batch, SEQ_LEN] token ids
            tr_ys = tr_batch['tags_ids'].to(device)    # [batch, SEQ_LEN] tag ids
            # Tag id 0 is PAD, so strictly positive ids mark the real positions.
            tr_mask = tr_ys > 0

            # CRF loss on the emission scores
            # (presumably [batch, SEQ_LEN, num_tags] — TODO confirm against model.forward).
            tr_emission_scores = model(x=tr_xs).to(device)
            tr_loss = model.loss_fn(emission_scores=tr_emission_scores, tags=tr_ys, mask=tr_mask)
            tr_losses.append(tr_loss.item())

            # Total loss = CRF loss + weight regularization
            regularization_loss = model.regularization_loss_fn(lam=config.REG_LAMBDA)
            total_loss = tr_loss + regularization_loss
            if batch_num % 100 == 0:
                print(f'crf loss: {tr_loss:.2f} | regularization_loss {regularization_loss:.2f} | total_loss {total_loss:.2f}')

            # Backward pass: gradient of the total loss w.r.t. all learnable parameters
            total_loss.backward()

            # Clip the global gradient norm before the update
            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=config.MAX_GRAD_NORM)

            # Update the weights
            optimizer.step()

        # ---------------- END OF TRAINING PHASE: UPDATE LOG ----------------

        mean_tr_loss = np.mean(tr_losses)
        print(f"Epoch: {epoch:02d} | current NLL:{tr_loss.item():.2f} | avg NLL over batch:{mean_tr_loss:.2f}")
        writer.add_scalar('tr/loss', mean_tr_loss, global_step=epoch)

        for name, param in model.named_parameters():
            writer.add_histogram('tr/' + name + '_weight', param.data, global_step=epoch)
            # Parameters that did not take part in the last backward pass have no gradient.
            if param.grad is not None:
                writer.add_histogram('tr/' + name + '_grad', param.grad, global_step=epoch)

        # ---------------- VALIDATION PHASE ----------------

        va_losses = []

        batch_preds = []
        batch_trues = []

        model.eval()
        with torch.no_grad():
            for va_batch in tqdm(va_dataloader, total=len(va_dataloader)):
                va_xs = va_batch['tokens_ids'].to(device)  # [batch, SEQ_LEN] token ids
                va_ys = va_batch['tags_ids'].to(device)    # [batch, SEQ_LEN] tag ids
                va_mask = va_ys > 0                        # True on non-padded positions

                # Forward pass, decoded tag ids and validation loss
                va_emission_scores = model(x=va_xs).to(device)
                va_preds = torch.tensor(model.decode(va_emission_scores)).to(device)
                va_loss = model.loss_fn(emission_scores=va_emission_scores, tags=va_ys, mask=va_mask)
                va_losses.append(va_loss.item())

                for true_row, pred_row, row_mask in zip(va_ys, va_preds, va_mask):
                    # Drop padding, map ids back to tag strings, then convert to the
                    # BIO format expected by seqeval.
                    true_tags = [IDX2TAG[idx] for idx in true_row[row_mask].tolist()]
                    pred_tags = [IDX2TAG[idx] for idx in pred_row[row_mask].tolist()]
                    batch_trues.append(io2bio(true_tags))
                    batch_preds.append(io2bio(pred_tags))

        # Show up to five decoded examples (slicing is safe for tiny validation sets).
        for pred_tags, true_tags in zip(batch_preds[:5], batch_trues[:5]):
            print('pred:', pred_tags)
            print('true:', true_tags)
            print()

        mean_va_loss = np.mean(va_losses)
        print("va loss", mean_va_loss)
        writer.add_scalar('va/loss', mean_va_loss, global_step=epoch)

        report = classification_report(y_true=batch_trues, y_pred=batch_preds, zero_division=0)
        print(report)

        # One checkpoint per epoch, written to the folder created in the cell above.
        torch.save(model.state_dict(), os.path.join(weights_folder, f"model_epoch_{epoch:02d}.pt"))
finally:
    writer.close()

  0%|          | 0/1030 [00:00<?, ?it/s]

crf loss: 13973.81 | regularization_loss 1.89 | total_loss 13975.70


  1%|▏         | 13/1030 [00:07<09:20,  1.81it/s]


In [20]:
# version 45
# !tensorboard --logdir=artefacts/fine_grained/runs