In [4]:
import os
from tqdm import tqdm
from collections import Counter

import torch
from torch.utils.data import DataLoader
from seqeval.metrics import classification_report

from utils import set_random_seed, Config, IO2BIO
from dataset import IO2df, MyDataset
from model import BiLSTM_CRF

%load_ext autoreload
%autoreload 2

set_random_seed(seed=0)

In [21]:
config = Config('config.yaml')

tr_titles = IO2df(config.TR_PATH)
va_titles = IO2df(config.VA_PATH)

100%|██████████| 3359329/3359329 [00:03<00:00, 968164.94it/s] 
100%|██████████| 482037/482037 [00:00<00:00, 1087393.18it/s]


In [22]:
tr_titles

Unnamed: 0,id,tokens,tags
0,0,"[Paul, International, airport, .]","[O, O, O, O]"
1,1,"[It, starred, Hicks, 's, wife, ,, Ellaline, Te...","[O, O, person, O, O, O, person, person, O, per..."
2,2,"[``, Time, ``, magazine, said, the, film, was,...","[O, art, O, O, O, O, O, O, O, O, O, O, O, O, O..."
3,3,"[Pakistani, scientists, and, engineers, ', wor...","[O, O, O, O, O, O, O, organization, O, O, O, O..."
4,4,"[In, February, 2008, ,, Church, 's, Chicken, e...","[O, O, O, O, organization, organization, organ..."
...,...,...,...
131762,131762,"[In, response, ,, the, states, who, had, ratif...","[O, O, O, O, O, O, O, O, O, other, O, O, organ..."
131763,131763,"[They, have, long, been, used, as, containers,...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ..."
131764,131764,"[In, 1911, he, came, into, possession, of, the...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O]"
131765,131765,"[The, Lutici, tribes, in, 983, formed, the, Li...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ..."


# Tag -> tagID

In [10]:
TAG2IDX = {
    "O": 1,
    "art": 2,
    "building": 3,
    "event": 4,
    "location": 5,
    "organization": 6,
    "other": 7,
    "person": 8,
    "product": 9,
}

IDX2TAG = {i: t for t, i in TAG2IDX.items()}
IDX2TAG[0] = '[PAD]' # in case if model predicts padding id (0)

In [11]:
tr_titles['tags_ids'] = tr_titles['tags'].transform(lambda x: [TAG2IDX[tag] for tag in x])
va_titles['tags_ids'] = va_titles['tags'].transform(lambda x: [TAG2IDX[tag] for tag in x])

# Token -> tokenID

In [12]:
def calc_token_cntr(filepath):

    token_cntr = Counter()
    num_lines = sum(1 for _ in open(filepath, encoding="utf-8"))

    with open(filepath, "r", encoding="utf-8") as f:
        for line in tqdm(f, total=num_lines):
            line = line.strip().split()
            if line:
                token, fine_tag = line
                token_cntr[token] += 1
    
    return token_cntr

tr_token_cntr = calc_token_cntr(filepath=config.TR_PATH)
va_token_cntr = calc_token_cntr(filepath=config.VA_PATH)
token_cntr = tr_token_cntr + va_token_cntr

MC = 50_000
top_tokens = [token for token, _ in token_cntr.most_common(MC)]
TOKEN2IDX = {token: i + 1 for i, token in enumerate(top_tokens)}
for token in set(token_cntr) - set(top_tokens):
    TOKEN2IDX[token] = len(top_tokens) + 1


100%|██████████| 3359329/3359329 [00:02<00:00, 1397117.61it/s]
100%|██████████| 482037/482037 [00:00<00:00, 1389592.41it/s]


In [13]:
tr_titles['tokens_ids'] = tr_titles['tokens'].transform(lambda x: [TOKEN2IDX[token] for token in x])
va_titles['tokens_ids'] = va_titles['tokens'].transform(lambda x: [TOKEN2IDX[token] for token in x])

# Padding

In [14]:
def padding(ids, max_len = 100):
    if len(ids) >= max_len:
        return ids[:max_len]
    ids.extend([0]*(max_len-len(ids)))
    return ids

tr_titles['tokens_ids'] = tr_titles['tokens_ids'].transform(padding, max_len=100)
tr_titles['tags_ids'] = tr_titles['tags_ids'].transform(padding, max_len=100)

va_titles['tokens_ids'] = va_titles['tokens_ids'].transform(padding, max_len=100)
va_titles['tags_ids'] = va_titles['tags_ids'].transform(padding, max_len=100)

In [15]:
tr_titles

Unnamed: 0,id,tokens,tags,tags_ids,tokens_ids
0,0,"[Paul, International, airport, .]","[O, O, O, O]","[1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...","[555, 169, 737, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,1,"[It, starred, Hicks, 's, wife, ,, Ellaline, Te...","[O, O, person, O, O, O, person, person, O, per...","[1, 1, 8, 1, 1, 1, 8, 8, 1, 8, 8, 1, 0, 0, 0, ...","[34, 1616, 10001, 21, 668, 2, 10001, 10001, 5,..."
2,2,"[``, Time, ``, magazine, said, the, film, was,...","[O, art, O, O, O, O, O, O, O, O, O, O, O, O, O...","[1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[9, 2042, 9, 1073, 383, 1, 75, 10, 9, 8, 10001..."
3,3,"[Pakistani, scientists, and, engineers, ', wor...","[O, O, O, O, O, O, O, organization, O, O, O, O...","[1, 1, 1, 1, 1, 1, 1, 6, 1, 1, 1, 1, 1, 1, 1, ...","[5219, 5672, 5, 6847, 57, 602, 20, 10001, 76, ..."
4,4,"[In, February, 2008, ,, Church, 's, Chicken, e...","[O, O, O, O, organization, organization, organ...","[1, 1, 1, 1, 6, 6, 6, 1, 1, 1, 1, 1, 1, 1, 6, ...","[24, 186, 145, 2, 320, 21, 6098, 920, 1, 531, ..."
...,...,...,...,...,...
131762,131762,"[In, response, ,, the, states, who, had, ratif...","[O, O, O, O, O, O, O, O, O, other, O, O, organ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 7, 1, 1, 6, 6, 6, ...","[24, 1226, 2, 1, 1140, 48, 39, 8961, 1, 10001,..."
131763,131763,"[They, have, long, been, used, as, containers,...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[198, 52, 466, 46, 68, 16, 10001, 15, 8214, 36..."
131764,131764,"[In, 1911, he, came, into, possession, of, the...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, O]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[24, 2527, 28, 386, 58, 5345, 4, 1, 10001, 100..."
131765,131765,"[The, Lutici, tribes, in, 983, formed, the, Li...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[11, 10001, 4090, 6, 10001, 378, 1, 10001, 100..."


# Data loader

In [16]:
tr_dataset = MyDataset(tr_titles)
va_dataset = MyDataset(va_titles)

tr_dataloader = DataLoader(dataset=tr_dataset, batch_size=128, shuffle=True, num_workers=4)
va_dataloader = DataLoader(dataset=va_dataset, batch_size=128, shuffle=True, num_workers=4)

tr_dataset.__getitem__(0)

{'tokens_ids': array([555, 169, 737,   3,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0]),
 'tags_ids': array([1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}

# Model

In [18]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = BiLSTM_CRF(
    embed_size     = 256,
    hidden_size    = 256, 
    dropout        = 0.5,
    token_voc_size = len(TOKEN2IDX) + 1, 
    tag_voc_size   = len(TAG2IDX) + 1
).to(device)

print(model)

BiLSTM_CRF(
  (token_embedding): Embedding(184112, 256)
  (dropout): Dropout(p=0.2, inplace=False)
  (lstm): LSTM(256, 128, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=256, out_features=10, bias=True)
  (crf): CRF(num_tags=10)
)


In [24]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

for epoch in range(5):

    # TRAIN

    model.train()
    for tr_batch in tqdm(tr_dataloader, total=tr_dataloader.__len__()):
        optimizer.zero_grad()
        
        tr_xs = tr_batch['tokens_ids'].to(device)
        tr_ys = tr_batch['tags_ids'].to(device)
        
        loss = model.loss_fn(tr_xs, tr_ys) + model.regularization_loss_fn(lam=5e-3, alpha=0.5)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=3)
        optimizer.step()
    
    print(f"Epoch: {epoch:02d} NLL:{loss.item()}")
    
    # VALIDATION

    batch_preds = []
    batch_trues = []

    model.eval()
    with torch.no_grad():
        for va_batch in tqdm(va_dataloader, total=va_dataloader.__len__()):
            va_xs = va_batch['tokens_ids'].to(device) # size: [batch=128, seq_len=100]
            va_ys = va_batch['tags_ids'].to(device) # size: [batch=128, seq_len=100]

            # model predictions for the batch
            out = model(va_xs, va_ys) # size: [batch=128, seq_len=100]

            # find max length without PADDING (padding value is 0) for each row in a batch
            title_length = torch.sum(va_ys > 0, dim=1) # size: [batch=128]

            for row_id, true in enumerate(va_ys.tolist()):
                # do not count padding
                true_tags = [IDX2TAG[idx] for idx in true[:title_length[row_id]]]
                # convert to the format expected by seqeval
                true_tags = IO2BIO(true_tags)
                batch_trues.append(true_tags)

            for row_id, pred in enumerate(out):
                # do not count padding
                pred_tags = [IDX2TAG[idx] for idx in pred[:title_length[row_id]]]
                # convert to the format expected by seqeval
                pred_tags = IO2BIO(pred_tags)
                batch_preds.append(pred_tags)

        for i in range(5):
            print('pred:', batch_preds[i])
            print('true:', batch_trues[i])
            print()
        
        report = classification_report(y_true=batch_trues, y_pred=batch_preds, zero_division=0)
        print(report)
    
    weights_folder = 'weights'
    if not os.path.exists(weights_folder):
        os.makedirs(weights_folder)
    
    torch.save(model.state_dict(), f"weights/model_epoch_{epoch:02d}.pt")

In [None]:
# Epoch: 03 NLL:852.8304443359375
# 100%|██████████| 295/295 [00:40<00:00,  7.28it/s]
# pred: ['B-person', 'O', 'O', 'O', 'O', 'O', 'B-person', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
# true: ['B-person', 'O', 'O', 'O', 'O', 'O', 'B-person', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']

# pred: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-location', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-organization', 'I-organization', 'O', 'O', 'O', 'O', 'O', 'O', 'B-product', 'I-product', 'O', 'O', 'O', 'B-product', 'O', 'O']
# true: ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-product', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-product', 'O', 'O', 'O', 'O', 'O', 'B-organization', 'I-organization', 'O', 'O', 'O', 'O', 'O', 'O', 'B-product', 'O', 'O', 'O', 'O', 'B-product', 'O', 'O']

# pred: ['B-person', 'I-person', 'O', 'O', 'B-organization', 'I-organization', 'I-organization', 'O', 'B-art', 'I-art', 'I-art', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
# true: ['B-person', 'I-person', 'O', 'O', 'B-organization', 'I-organization', 'I-organization', 'O', 'B-organization', 'I-organization', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']

# pred: ['O', 'O', 'O', 'O', 'O', 'O', 'B-location', 'O', 'B-location', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
# true: ['O', 'O', 'O', 'O', 'O', 'O', 'B-location', 'O', 'B-location', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']

# pred: ['O', 'O', 'O', 'O', 'O', 'B-location', 'I-location', 'I-location', 'I-location', 'O', 'B-location', 'I-location', 'O', 'O', 'O', 'O', 'B-location', 'O']
# true: ['O', 'O', 'O', 'O', 'O', 'B-location', 'I-location', 'I-location', 'I-location', 'O', 'B-location', 'I-location', 'O', 'O', 'O', 'O', 'B-location', 'O']

#               precision    recall  f1-score   support

#          art       0.69      0.68      0.68      2063
#     building       0.63      0.59      0.61      2484
#        event       0.61      0.54      0.57      2034
#     location       0.77      0.79      0.78     13649
# organization       0.69      0.59      0.63      9585
#        other       0.64      0.51      0.57      4958
#       person       0.83      0.83      0.83     10954
#      product       0.60      0.46      0.52      2955

#    micro avg       0.73      0.68      0.71     48682
#    macro avg       0.68      0.62      0.65     48682
# weighted avg       0.73      0.68      0.70     48682