# Fine-tuning NERGrit
NERGrit is a Named Entity Recognition dataset with 3 entity types (`PERSON`, `PLACE`, `ORGANISATION`) annotated in IOB chunking format, i.e. 7 labels in total (`B-`/`I-` for each type, plus `O`).

In [1]:
import os, sys
sys.path.append('../')
os.chdir('../')

import random
import numpy as np
import pandas as pd
import torch
from torch import optim
from tqdm import tqdm

from transformers import BertConfig, BertTokenizer
from nltk.tokenize import word_tokenize

from modules.word_classification import BertForWordClassification
from utils.forward_fn import forward_word_classification
from utils.metrics import ner_metrics_fn
from utils.data_utils import NerGritDataset, NerDataLoader

In [2]:
###
# common functions
###
def set_seed(seed):
    """Seed every RNG the notebook relies on.

    Covers Python's `random`, NumPy's legacy global RNG, PyTorch's CPU generator
    and the current CUDA device's generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    
def count_param(module, trainable=False):
    """Return the number of scalar parameters in `module`.

    With `trainable=True`, only parameters whose `requires_grad` is set are counted.
    """
    params = module.parameters()
    if trainable:
        params = (p for p in params if p.requires_grad)
    return sum(p.numel() for p in params)
    
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first param group.

    Returns None when the optimizer has no param groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    return None if first_group is None else first_group['lr']

def metrics_to_string(metric_dict):
    """Render metrics as space-separated `NAME:VALUE` pairs with two decimals.

    Example: {'ACC': 0.8, 'F1': 0.5} -> 'ACC:0.80 F1:0.50'.
    """
    return ' '.join('{}:{:.2f}'.format(name, score) for name, score in metric_dict.items())

In [3]:
# Fix the seeds of every RNG used below so the run is reproducible
SEED = 26092020
set_seed(SEED)

# Load IndoBERT Model

In [4]:
# Tokenizer, config and weights all come from the same pretrained checkpoint
PRETRAINED_MODEL = 'indobenchmark/indobert-base-p1'

tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL)
config = BertConfig.from_pretrained(PRETRAINED_MODEL)
# Size the classification head to the NERGrit label set
config.num_labels = NerGritDataset.NUM_LABELS

# Instantiate model (the classifier head is newly initialised, as the load warning reports)
model = BertForWordClassification.from_pretrained(PRETRAINED_MODEL, config=config)

Some weights of BertForWordClassification were not initialized from the model checkpoint at indobenchmark/indobert-base-p1 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# Display the model's module tree (its repr) to inspect the architecture
model

BertForWordClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(50000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise

In [6]:
# Total number of parameters (trainable or not)
count_param(model, trainable=False)

124446727

# Prepare Named Entity Recognition Dataset (NERGrit)

In [7]:
# All NERGrit splits live in one directory. Per its file name, the test split
# has masked labels, so it is only used for prediction below.
DATASET_DIR = './dataset/nergrit_ner-grit'
train_dataset_path = f'{DATASET_DIR}/train_preprocess.txt'
valid_dataset_path = f'{DATASET_DIR}/valid_preprocess.txt'
test_dataset_path = f'{DATASET_DIR}/test_preprocess_masked_label.txt'

In [8]:
# Build the three splits with the same tokenizer and lowercase=True setting
train_dataset = NerGritDataset(train_dataset_path, tokenizer, lowercase=True)
valid_dataset = NerGritDataset(valid_dataset_path, tokenizer, lowercase=True)
test_dataset = NerGritDataset(test_dataset_path, tokenizer, lowercase=True)

# Loader settings shared by all splits; only the training split is shuffled
MAX_SEQ_LEN = 512
BATCH_SIZE = 16
NUM_WORKERS = 16

train_loader = NerDataLoader(dataset=train_dataset, max_seq_len=MAX_SEQ_LEN, batch_size=BATCH_SIZE,
                             num_workers=NUM_WORKERS, shuffle=True)
valid_loader = NerDataLoader(dataset=valid_dataset, max_seq_len=MAX_SEQ_LEN, batch_size=BATCH_SIZE,
                             num_workers=NUM_WORKERS, shuffle=False)
test_loader = NerDataLoader(dataset=test_dataset, max_seq_len=MAX_SEQ_LEN, batch_size=BATCH_SIZE,
                            num_workers=NUM_WORKERS, shuffle=False)

In [9]:
# Label -> index and index -> label maps; i2w decodes model predictions into tag strings
w2i = NerGritDataset.LABEL2INDEX
i2w = NerGritDataset.INDEX2LABEL
print(w2i)
print(i2w)

{'I-PERSON': 0, 'B-ORGANISATION': 1, 'I-ORGANISATION': 2, 'B-PLACE': 3, 'I-PLACE': 4, 'O': 5, 'B-PERSON': 6}
{0: 'I-PERSON', 1: 'B-ORGANISATION', 2: 'I-ORGANISATION', 3: 'B-PLACE', 4: 'I-PLACE', 5: 'O', 6: 'B-PERSON'}


# Test the model on sample sentences before fine-tuning

In [10]:
def word_subword_tokenize(sentence, tokenizer):
    """Convert a list of words into BERT subword ids plus a subword -> word alignment.

    Args:
        sentence: sequence of word strings (e.g. the output of `word_tokenize`).
        tokenizer: a HuggingFace tokenizer exposing `cls_token_id`, `sep_token_id`,
            `unk_token_id` and `encode(text, add_special_tokens=False)`.

    Returns:
        (subwords, subword_to_word_indices): two lists of equal length.
        `subwords` is [CLS] + the subwords of each word, in order, + [SEP].
        `subword_to_word_indices[k]` is the index of the word that subword k
        belongs to, or -1 for the [CLS]/[SEP] special tokens.
    """
    subwords = [tokenizer.cls_token_id]
    subword_to_word_indices = [-1]  # [CLS] belongs to no word

    for word_idx, word in enumerate(sentence):
        # A word made only of characters the tokenizer strips encodes to no subwords.
        # Map it to [UNK] so every word keeps at least one aligned subword.
        # Otherwise the prediction list would be shorter than the word list.
        subword_list = tokenizer.encode(word, add_special_tokens=False) or [tokenizer.unk_token_id]
        subwords.extend(subword_list)
        subword_to_word_indices.extend([word_idx] * len(subword_list))

    subwords.append(tokenizer.sep_token_id)
    subword_to_word_indices.append(-1)  # [SEP] belongs to no word

    return subwords, subword_to_word_indices

In [11]:
# Predict a label for every word of a sample sentence with the not-yet-fine-tuned model
text = word_tokenize('Bung Tomo adalah pahlawan nasional Republik Indonesia')
subwords, subword_to_word_indices = word_subword_tokenize(text, tokenizer)

subwords = torch.LongTensor(subwords).view(1, -1).to(model.device)
subword_to_word_indices = torch.LongTensor(subword_to_word_indices).view(1, -1).to(model.device)
with torch.no_grad():  # inference only: no autograd graph needed
    logits = model(subwords, subword_to_word_indices)[0]

# .cpu() makes this work wherever the model lives (it is moved to CUDA later).
# reshape(-1) instead of squeeze() keeps a 1-D array even for a one-word sentence.
preds = torch.topk(logits, k=1, dim=-1)[1].reshape(-1).cpu().numpy()
labels = [i2w[pred] for pred in preds]

pd.DataFrame({'words': text, 'label': labels})

Unnamed: 0,words,label
0,Bung,O
1,Tomo,B-PERSON
2,adalah,B-PLACE
3,pahlawan,B-PERSON
4,nasional,O
5,Republik,O
6,Indonesia,O


In [12]:
# Predict a label for every word of a sample sentence with the not-yet-fine-tuned model
text = word_tokenize('Budi pergi ke mall kelapa gading membeli kue bantal')
subwords, subword_to_word_indices = word_subword_tokenize(text, tokenizer)

subwords = torch.LongTensor(subwords).view(1, -1).to(model.device)
subword_to_word_indices = torch.LongTensor(subword_to_word_indices).view(1, -1).to(model.device)
with torch.no_grad():  # inference only: no autograd graph needed
    logits = model(subwords, subword_to_word_indices)[0]

# reshape(-1) instead of squeeze() keeps a 1-D array even for a one-word sentence
preds = torch.topk(logits, k=1, dim=-1)[1].reshape(-1).cpu().numpy()
labels = [i2w[pred] for pred in preds]

pd.DataFrame({'words': text, 'label': labels})

Unnamed: 0,words,label
0,Budi,B-ORGANISATION
1,pergi,I-PLACE
2,ke,B-PERSON
3,mall,B-PERSON
4,kelapa,I-PERSON
5,gading,I-PERSON
6,membeli,B-ORGANISATION
7,kue,B-PERSON
8,bantal,I-PERSON


In [13]:
# Predict a label for every word of a sample sentence with the not-yet-fine-tuned model
text = word_tokenize('Saya sudah sampai di depan menara bca')
subwords, subword_to_word_indices = word_subword_tokenize(text, tokenizer)

subwords = torch.LongTensor(subwords).view(1, -1).to(model.device)
subword_to_word_indices = torch.LongTensor(subword_to_word_indices).view(1, -1).to(model.device)
with torch.no_grad():  # inference only: no autograd graph needed
    logits = model(subwords, subword_to_word_indices)[0]

# reshape(-1) instead of squeeze() keeps a 1-D array even for a one-word sentence
preds = torch.topk(logits, k=1, dim=-1)[1].reshape(-1).cpu().numpy()
labels = [i2w[pred] for pred in preds]

pd.DataFrame({'words': text, 'label': labels})

Unnamed: 0,words,label
0,Saya,B-PLACE
1,sudah,B-ORGANISATION
2,sampai,I-PERSON
3,di,I-PLACE
4,depan,I-PERSON
5,menara,I-ORGANISATION
6,bca,B-PLACE


# Fine-Tuning & Evaluation

In [14]:
# Move the model to the GPU *before* constructing the optimizer, as the PyTorch
# docs recommend, so the optimizer is built over the parameters used in training.
LEARNING_RATE = 5e-6
model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [15]:
# Train for `n_epochs` epochs, reporting train and validation loss/metrics after each epoch.
# Grad mode is set with context managers, so this cell neither depends on nor leaks
# the global autograd state of other cells.
n_epochs = 8
for epoch in range(n_epochs):
    # ---- Training pass ----
    model.train()
    total_train_loss = 0
    list_hyp, list_label = [], []

    train_pbar = tqdm(train_loader, leave=True, total=len(train_loader))
    with torch.enable_grad():
        for i, batch_data in enumerate(train_pbar):
            # Forward model (the last batch element is not passed to the model)
            loss, batch_hyp, batch_label = forward_word_classification(model, batch_data[:-1], i2w=i2w, device='cuda')

            # Update model
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

            # Accumulate hypotheses/labels for the epoch-level metrics
            list_hyp += batch_hyp
            list_label += batch_label

            train_pbar.set_description("(Epoch {}) TRAIN LOSS:{:.4f} LR:{:.8f}".format((epoch+1),
                total_train_loss/(i+1), get_lr(optimizer)))

    # Calculate train metric (max(..., 1) avoids a NameError/ZeroDivision on an empty loader)
    metrics = ner_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) TRAIN LOSS:{:.4f} {} LR:{:.8f}".format((epoch+1),
        total_train_loss/max(len(train_loader), 1), metrics_to_string(metrics), get_lr(optimizer)))

    # ---- Validation pass ----
    model.eval()
    total_loss = 0
    list_hyp, list_label = [], []

    pbar = tqdm(valid_loader, leave=True, total=len(valid_loader))
    with torch.no_grad():
        for i, batch_data in enumerate(pbar):
            loss, batch_hyp, batch_label = forward_word_classification(model, batch_data[:-1], i2w=i2w, device='cuda')
            total_loss += loss.item()

            list_hyp += batch_hyp
            list_label += batch_label
            # Running metrics over all validation batches seen so far, for the progress bar
            running_metrics = ner_metrics_fn(list_hyp, list_label)
            pbar.set_description("VALID LOSS:{:.4f} {}".format(total_loss/(i+1), metrics_to_string(running_metrics)))

    metrics = ner_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) VALID LOSS:{:.4f} {}".format((epoch+1),
        total_loss/max(len(valid_loader), 1), metrics_to_string(metrics)))

(Epoch 1) TRAIN LOSS:1.4168 LR:0.00000500: 100%|██████████| 105/105 [00:19<00:00,  5.31it/s]
  0%|          | 0/14 [00:00<?, ?it/s]

(Epoch 1) TRAIN LOSS:1.4168 ACC:0.80 F1:0.03 REC:0.01 PRE:0.02 LR:0.00000500


VALID LOSS:1.2833 ACC:0.84 F1:0.19 REC:0.08 PRE:0.12: 100%|██████████| 14/14 [00:01<00:00, 10.93it/s]
  0%|          | 0/105 [00:00<?, ?it/s]

(Epoch 1) VALID LOSS:1.2833 ACC:0.84 F1:0.19 REC:0.08 PRE:0.12


(Epoch 2) TRAIN LOSS:1.1296 LR:0.00000500: 100%|██████████| 105/105 [00:20<00:00,  5.16it/s]
  0%|          | 0/14 [00:00<?, ?it/s]

(Epoch 2) TRAIN LOSS:1.1296 ACC:0.86 F1:0.37 REC:0.21 PRE:0.27 LR:0.00000500


VALID LOSS:1.2114 ACC:0.90 F1:0.53 REC:0.50 PRE:0.51: 100%|██████████| 14/14 [00:01<00:00, 10.48it/s]
  0%|          | 0/105 [00:00<?, ?it/s]

(Epoch 2) VALID LOSS:1.2114 ACC:0.90 F1:0.53 REC:0.50 PRE:0.51


(Epoch 3) TRAIN LOSS:1.0481 LR:0.00000500: 100%|██████████| 105/105 [00:21<00:00,  4.83it/s]
  0%|          | 0/14 [00:00<?, ?it/s]

(Epoch 3) TRAIN LOSS:1.0481 ACC:0.91 F1:0.57 REC:0.51 PRE:0.54 LR:0.00000500


VALID LOSS:1.1528 ACC:0.92 F1:0.55 REC:0.60 PRE:0.58: 100%|██████████| 14/14 [00:01<00:00, 10.77it/s]
  0%|          | 0/105 [00:00<?, ?it/s]

(Epoch 3) VALID LOSS:1.1528 ACC:0.92 F1:0.55 REC:0.60 PRE:0.58


(Epoch 4) TRAIN LOSS:0.9969 LR:0.00000500: 100%|██████████| 105/105 [00:20<00:00,  5.04it/s]
  0%|          | 0/14 [00:00<?, ?it/s]

(Epoch 4) TRAIN LOSS:0.9969 ACC:0.93 F1:0.64 REC:0.62 PRE:0.63 LR:0.00000500


VALID LOSS:1.0662 ACC:0.93 F1:0.60 REC:0.71 PRE:0.65: 100%|██████████| 14/14 [00:01<00:00, 10.85it/s]
  0%|          | 0/105 [00:00<?, ?it/s]

(Epoch 4) VALID LOSS:1.0662 ACC:0.93 F1:0.60 REC:0.71 PRE:0.65


(Epoch 5) TRAIN LOSS:0.9409 LR:0.00000500: 100%|██████████| 105/105 [00:20<00:00,  5.12it/s]
  0%|          | 0/14 [00:00<?, ?it/s]

(Epoch 5) TRAIN LOSS:0.9409 ACC:0.94 F1:0.68 REC:0.68 PRE:0.68 LR:0.00000500


VALID LOSS:1.0823 ACC:0.94 F1:0.64 REC:0.72 PRE:0.67: 100%|██████████| 14/14 [00:01<00:00, 10.83it/s]
  0%|          | 0/105 [00:00<?, ?it/s]

(Epoch 5) VALID LOSS:1.0823 ACC:0.94 F1:0.64 REC:0.72 PRE:0.67


(Epoch 6) TRAIN LOSS:0.8976 LR:0.00000500: 100%|██████████| 105/105 [00:20<00:00,  5.08it/s]
  0%|          | 0/14 [00:00<?, ?it/s]

(Epoch 6) TRAIN LOSS:0.8976 ACC:0.95 F1:0.73 REC:0.74 PRE:0.73 LR:0.00000500


VALID LOSS:1.0892 ACC:0.94 F1:0.71 REC:0.71 PRE:0.71: 100%|██████████| 14/14 [00:01<00:00, 10.94it/s]
  0%|          | 0/105 [00:00<?, ?it/s]

(Epoch 6) VALID LOSS:1.0892 ACC:0.94 F1:0.71 REC:0.71 PRE:0.71


(Epoch 7) TRAIN LOSS:0.8610 LR:0.00000500: 100%|██████████| 105/105 [00:20<00:00,  5.03it/s]
  0%|          | 0/14 [00:00<?, ?it/s]

(Epoch 7) TRAIN LOSS:0.8610 ACC:0.96 F1:0.76 REC:0.78 PRE:0.77 LR:0.00000500


VALID LOSS:0.8833 ACC:0.94 F1:0.70 REC:0.77 PRE:0.73: 100%|██████████| 14/14 [00:01<00:00, 10.84it/s]
  0%|          | 0/105 [00:00<?, ?it/s]

(Epoch 7) VALID LOSS:0.8833 ACC:0.94 F1:0.70 REC:0.77 PRE:0.73


(Epoch 8) TRAIN LOSS:0.8347 LR:0.00000500: 100%|██████████| 105/105 [00:20<00:00,  5.12it/s]
  0%|          | 0/14 [00:00<?, ?it/s]

(Epoch 8) TRAIN LOSS:0.8347 ACC:0.96 F1:0.78 REC:0.80 PRE:0.79 LR:0.00000500


VALID LOSS:0.8616 ACC:0.94 F1:0.72 REC:0.76 PRE:0.74: 100%|██████████| 14/14 [00:01<00:00, 10.73it/s]

(Epoch 8) VALID LOSS:0.8616 ACC:0.94 F1:0.72 REC:0.76 PRE:0.74





In [16]:
# Predict labels for the test split. Per its file name its labels are masked, so only the
# hypotheses are collected and saved; the returned loss and labels are ignored.
model.eval()

list_hyp = []
pbar = tqdm(test_loader, leave=True, total=len(test_loader))
with torch.no_grad():  # local grad mode instead of mutating the global one
    for batch_data in pbar:
        _, batch_hyp, _ = forward_word_classification(model, batch_data[:-1], i2w=i2w, device='cuda')
        list_hyp += batch_hyp

# Save prediction: one row per test sentence with columns `index` and `label` (its word labels)
df = pd.DataFrame({'label': list_hyp}).reset_index()
df.to_csv('pred.txt', index=False)

df

100%|██████████| 14/14 [00:01<00:00, 11.35it/s]


     index                                              label
0        0  [B-PERSON, I-PERSON, O, O, O, O, B-ORGANISATIO...
1        1  [O, O, O, O, O, O, O, B-PERSON, O, O, O, O, O,...
2        2  [O, O, O, O, O, O, O, O, B-ORGANISATION, I-ORG...
3        3  [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ...
4        4  [O, O, O, O, O, O, B-PERSON, I-PERSON, O, O, O...
..     ...                                                ...
204    204  [O, O, O, O, O, O, B-PLACE, O, O, O, O, B-PLAC...
205    205      [O, O, O, O, O, O, O, O, O, O, O, O, O, O, O]
206    206  [O, O, O, O, B-PLACE, I-PLACE, O, O, O, B-PLAC...
207    207  [O, O, O, O, O, O, O, B-PERSON, O, O, O, B-PLA...
208    208  [O, O, O, O, O, O, O, O, B-PLACE, I-PLACE, O, ...

[209 rows x 2 columns]


# Test the fine-tuned model on sample sentences

In [17]:
# Predict a label for every word of a sample sentence with the fine-tuned model
text = word_tokenize('Bung Tomo adalah pahlawan nasional Republik Indonesia')
subwords, subword_to_word_indices = word_subword_tokenize(text, tokenizer)

subwords = torch.LongTensor(subwords).view(1, -1).to(model.device)
subword_to_word_indices = torch.LongTensor(subword_to_word_indices).view(1, -1).to(model.device)
# no_grad() here instead of relying on a grad mode left behind by an earlier cell
with torch.no_grad():
    logits = model(subwords, subword_to_word_indices)[0]

# reshape(-1) instead of squeeze() keeps a 1-D array even for a one-word sentence
preds = torch.topk(logits, k=1, dim=-1)[1].reshape(-1).cpu().numpy()
labels = [i2w[pred] for pred in preds]

pd.DataFrame({'words': text, 'label': labels})

Unnamed: 0,words,label
0,Bung,B-PERSON
1,Tomo,B-PERSON
2,adalah,O
3,pahlawan,O
4,nasional,O
5,Republik,O
6,Indonesia,B-PLACE


In [18]:
# Predict a label for every word of a sample sentence with the fine-tuned model
text = word_tokenize('Budi pergi ke mall kelapa gading membeli kue bantal')
subwords, subword_to_word_indices = word_subword_tokenize(text, tokenizer)

subwords = torch.LongTensor(subwords).view(1, -1).to(model.device)
subword_to_word_indices = torch.LongTensor(subword_to_word_indices).view(1, -1).to(model.device)
# no_grad() here instead of relying on a grad mode left behind by an earlier cell
with torch.no_grad():
    logits = model(subwords, subword_to_word_indices)[0]

# reshape(-1) instead of squeeze() keeps a 1-D array even for a one-word sentence
preds = torch.topk(logits, k=1, dim=-1)[1].reshape(-1).cpu().numpy()
labels = [i2w[pred] for pred in preds]

pd.DataFrame({'words': text, 'label': labels})

Unnamed: 0,words,label
0,Budi,B-PERSON
1,pergi,O
2,ke,O
3,mall,B-PLACE
4,kelapa,I-PLACE
5,gading,I-PLACE
6,membeli,O
7,kue,O
8,bantal,O


In [19]:
# Predict a label for every word of a sample sentence with the fine-tuned model
text = word_tokenize('Saya sudah sampai di depan menara bca')
subwords, subword_to_word_indices = word_subword_tokenize(text, tokenizer)

subwords = torch.LongTensor(subwords).view(1, -1).to(model.device)
subword_to_word_indices = torch.LongTensor(subword_to_word_indices).view(1, -1).to(model.device)
# no_grad() here instead of relying on a grad mode left behind by an earlier cell
with torch.no_grad():
    logits = model(subwords, subword_to_word_indices)[0]

# reshape(-1) instead of squeeze() keeps a 1-D array even for a one-word sentence
preds = torch.topk(logits, k=1, dim=-1)[1].reshape(-1).cpu().numpy()
labels = [i2w[pred] for pred in preds]

pd.DataFrame({'words': text, 'label': labels})

Unnamed: 0,words,label
0,Saya,O
1,sudah,O
2,sampai,O
3,di,O
4,depan,O
5,menara,B-PLACE
6,bca,I-PLACE
