# Import Libraries

In [1]:
import json
import pickle
from pathlib import Path
from pprint import pprint
from typing import Dict

import pandas as pd
import torch
from IPython.display import display
from seqeval.metrics import accuracy_score
from seqeval.metrics import classification_report
from seqeval.metrics import f1_score
from seqeval.scheme import IOB2
from torch.utils.data import DataLoader
from tqdm import tqdm

from dataset import SeqTaggingClsDataset
from model import SeqTagger
from utils import Vocab

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Restore the cached vocabulary and the embedding data produced for the slot task.
# NOTE: pickle executes arbitrary code on load; only use cache files you created yourself.
vocab_path = Path("cache/slot/vocab.pkl")
vocab: Vocab = pickle.loads(vocab_path.read_bytes())

embeddings = torch.load("cache/slot/embeddings.pt")

In [3]:
# Read the tag -> index mapping from the cached JSON file and display it.
tag2idx_path = Path("cache/slot/tag2idx.json")
with tag2idx_path.open() as tag2idx_file:
    tag2idx: Dict[str, int] = json.load(tag2idx_file)
tag2idx

{'I-date': 0,
 'I-people': 1,
 'B-people': 2,
 'B-first_name': 3,
 'B-time': 4,
 'B-date': 5,
 'B-last_name': 6,
 'O': 7,
 'I-time': 8}

In [4]:
# Evaluation settings.
max_len = 40
# The inference loop in this notebook looks up one eval_data entry per batch,
# so it expects exactly one sample per batch.
batch_size = 1
eval_file_path = Path("data/slot/eval.json")

# Parse the raw evaluation samples from disk.
with eval_file_path.open() as eval_file:
    eval_data = json.load(eval_file)

# Wrap the samples in the project dataset; shuffle stays off so batches follow
# the order of eval_data.
eval_dataset = SeqTaggingClsDataset(eval_data, vocab, tag2idx, max_len)
eval_data_loader = DataLoader(
    eval_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=eval_dataset.collate_fn,
)

In [5]:
# Model hyper-parameters; they have to match the ones the checkpoint was trained with,
# otherwise load_state_dict below fails on mismatched shapes.
hidden_size = 512
num_layers = 4
dropout = 0.1
bidirectional = True
# Use the GPU when one is present, otherwise fall back to CPU so the notebook
# still runs end to end on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the tagger with one output class per entry in tag2idx.
model = SeqTagger(
    embeddings,
    hidden_size,
    num_layers,
    dropout,
    bidirectional,
    len(tag2idx),
)

# map_location remaps tensors that were saved from a GPU, so the checkpoint
# also loads when CUDA is unavailable.
model.load_state_dict(torch.load("ckpt/slot/slot_model.pt", map_location=device))
model.to(device)
# eval() switches off dropout for inference; as the last expression it also
# displays the model architecture.
model.eval()

SeqTagger(
  (embed): Embedding(4117, 300)
  (rnn): GRU(300, 512, num_layers=4, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=1024, out_features=9, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear): Linear(in_features=1024, out_features=9, bias=True)
)

In [6]:
# The loop pairs the i-th batch with the i-th entry of eval_data. That pairing
# only holds when every batch carries exactly one sample (and the loader is not
# shuffled), so fail loudly instead of silently producing misaligned tags.
assert batch_size == 1, "inference loop assumes one sample per batch"

# Collected predictions: sample ids and the predicted tag sequence per sample.
result_dict = {
    "id": list(),
    "tags": list()
}

# Inference only: no_grad avoids building the autograd graph for every forward pass.
with torch.no_grad():
    for sample, batch in zip(eval_data, eval_data_loader):
        ids = batch["ids"].to(device)
        # The model already lives on `device`, so its output needs no extra transfer.
        pred = model(ids)

        # Keep at most as many predictions as the sample has tokens
        # (presumably this drops padded positions - TODO confirm against the collate_fn).
        sentence_len = len(sample.get("tokens"))
        # NOTE(review): argmax over dim=1 assumes `pred` has one row of class
        # scores per token position - confirm against SeqTagger.forward.
        pred_res = pred.argmax(dim=1).tolist()[:sentence_len]
        pred_res = [eval_dataset.idx2label(res) for res in pred_res]

        result_dict["id"].append(sample["id"])
        result_dict["tags"].append(pred_res)

  return F.log_softmax(affine_out)


In [7]:
# Predicted tag sequences gathered in result_dict, and the reference tag
# sequences taken from eval_data in the same order.
y_pred = result_dict.get("tags")
y_true = [sample.get("tags") for sample in eval_data]

# Show the first five predicted sequences, a separator, then the first five references.
display(y_pred[:5])
print("-" * 59)
display(y_true[:5])

[['O', 'O', 'O', 'O', 'O'],
 ['B-time', 'O'],
 ['O', 'O', 'B-people'],
 ['O', 'O', 'O', 'O', 'O', 'O'],
 ['O', 'O', 'B-date', 'I-date', 'I-date']]

-----------------------------------------------------------


[['O', 'O', 'O', 'O', 'O'],
 ['B-time', 'O'],
 ['O', 'O', 'B-people'],
 ['O', 'O', 'O', 'O', 'O', 'O'],
 ['O', 'O', 'B-date', 'I-date', 'I-date']]

In [8]:
# Build the seqeval classification report in strict mode under the IOB2 scheme.
seqeval_test_report = classification_report(
    y_true,
    y_pred,
    mode="strict",
    scheme=IOB2,
)

In [9]:
# Pretty-print the report one line per list entry.
# (pprint is imported in the notebook's top import cell.)
pprint(seqeval_test_report.split("\n"))

['              precision    recall  f1-score   support',
 '',
 '        date       0.76      0.67      0.71       206',
 '  first_name       0.84      0.87      0.86       102',
 '   last_name       0.68      0.74      0.71        78',
 '      people       0.61      0.60      0.60       238',
 '        time       0.72      0.84      0.78       218',
 '',
 '   micro avg       0.71      0.73      0.72       842',
 '   macro avg       0.72      0.75      0.73       842',
 'weighted avg       0.71      0.73      0.72       842',
 '']


In [10]:
# Sanity check: does the first predicted sequence match its reference exactly?
y_true[0] == y_pred[0]

True

In [11]:
# Joint accuracy: a sample counts as correct only if its whole predicted tag
# sequence equals the reference sequence.
joint_correct_count = sum(
    pred_tags == true_tags for pred_tags, true_tags in zip(y_pred, y_true)
)

print(f"Joint Accuracy = {joint_correct_count} / {len(y_true)}")

Joint Accuracy = 728 / 1000


In [12]:
# Token accuracy: count position-wise matches between predicted and reference
# tags; the denominator is the total number of predicted tags.
token_correct_count = 0
all_token_count = 0
for pred_tags, true_tags in zip(y_pred, y_true):
    all_token_count += len(pred_tags)
    token_correct_count += sum(
        pred_tag == true_tag for pred_tag, true_tag in zip(pred_tags, true_tags)
    )

print(f"Token Accuracy = {token_correct_count} / {all_token_count}")

Token Accuracy = 7511 / 7891


In [13]:
# Token accuracy as a percentage, derived from the counters computed in the
# token-accuracy cell rather than from numbers copied by hand out of its output.
100 * token_correct_count / all_token_count

95.18438727664427

In [14]:
# Cross-check the hand-computed token accuracy against seqeval's accuracy_score.
seqeval_token_accuracy = accuracy_score(y_true, y_pred)
seqeval_token_accuracy

0.9518438727664428