In [1]:
import sys
sys.path.append("..")
from pathlib import Path

import torch
from seqeval.metrics import classification_report
from seqeval.scheme import IOB2

from src.slot.data_manager import SlotDataManager
from src.slot.models import SlotTagger

In [2]:
data_manager = SlotDataManager(
    cache_dir=Path("../cache/slot"),
    max_len=128,
    batch_size=32,
    num_workers=8,
    data_dir=Path("../data/slot"),
    test_file=Path("../data/slot/test.json")
)

2022-03-05 06:20:24 | INFO | Vocab loaded from /home/jacky/110-2_ADL/homeworks/hw01/cache/slot/vocab.pkl
2022-03-05 06:20:24 | INFO | Tag-2-Index loaded from /home/jacky/110-2_ADL/homeworks/hw01/cache/slot/tag2idx.json
2022-03-05 06:20:24 | INFO | Embeddings loaded from /home/jacky/110-2_ADL/homeworks/hw01/cache/slot/embeddings.pt


In [3]:
valid_dataloader = data_manager.get_valid_dataloader()
x, length, y = next(iter(valid_dataloader))
x.shape, length.shape, y.shape

(torch.Size([32, 33]), torch.Size([32]), torch.Size([32, 128]))

In [4]:
model = SlotTagger.load_from_checkpoint(Path("../ckpt/slot/20220303_1645/slot-epoch=12-val_acc=0.97-val_loss=0.12.ckpt"))
model

SlotTagger(
  (embedding): Embedding(4117, 300, padding_idx=0)
  (rnn): GRU(300, 512, num_layers=2, dropout=0.2, bidirectional=True)
  (fc): Sequential(
    (0): Dropout(p=0.2, inplace=False)
    (1): Linear(in_features=1024, out_features=1024, bias=True)
    (2): PReLU(num_parameters=1)
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=1024, out_features=9, bias=True)
  )
  (loss): CrossEntropyLoss()
)

In [5]:
output = model(x, length)
flatten_output = torch.cat([
    sen_output[:sen_len, :]
    for sen_output, sen_len, in zip(output, length)
])
flatten_y = torch.cat([
    sen_tags[:sen_len]
    for sen_tags, sen_len, in zip(y, length)
])

In [6]:
token_acc = (flatten_output.argmax(dim=1) == flatten_y).float().mean()
print(f"Token Accuracy: {token_acc}")

Token Accuracy: 0.9765258431434631


In [7]:
join_acc = torch.tensor([
    torch.all(sen_val[:sen_len].argmax(dim=1) == sen_tags[:sen_len])
    for sen_val, sen_tags, sen_len in zip(output, y, length)
]).float().mean()
print(f"Join Accuracy: {join_acc}")

Join Accuracy: 0.875


In [8]:
clipped_y = [
    [data_manager.idx2tag[idx] for idx in sen_tags[:sen_len].tolist()]
    for sen_tags, sen_len in zip(y, length)
]
clipped_pred = [
    [data_manager.idx2tag[idx] for idx in sen_val[:sen_len].argmax(dim=1).tolist()]
    for sen_val, sen_len in  zip(output, length)
]

In [9]:
print(classification_report(y_true=clipped_y, y_pred=clipped_pred, scheme=IOB2, mode="strict"))

              precision    recall  f1-score   support

        date       0.71      0.71      0.71         7
  first_name       1.00      1.00      1.00         1
   last_name       0.00      0.00      0.00         1
      people       1.00      0.89      0.94         9
        time       0.83      1.00      0.91         5

   micro avg       0.86      0.83      0.84        23
   macro avg       0.71      0.72      0.71        23
weighted avg       0.83      0.83      0.83        23



  _warn_prf(average, modifier, msg_start, len(result))


In [12]:
from collections import Counter
Counter(tag for tags in clipped_y for tag in tags)

Counter({'O': 180,
         'B-time': 5,
         'B-people': 9,
         'B-date': 7,
         'I-date': 9,
         'B-first_name': 1,
         'B-last_name': 1,
         'I-people': 1})

In [13]:
Counter(tag for tags in clipped_pred for tag in tags)

Counter({'O': 181,
         'B-time': 6,
         'B-people': 8,
         'B-date': 7,
         'I-date': 9,
         'B-first_name': 1,
         'I-people': 1})