In [1]:
import sys
sys.path.append("..")
from pathlib import Path
import json

import torch
import torch.nn.functional as F
from pytorch_lightning import seed_everything
from tqdm import tqdm
from seqeval.metrics import classification_report
from seqeval.scheme import IOB2

from src.slot.data_manager import SlotDataManager
from src.slot.models import SlotTagger

In [2]:
data_manager = SlotDataManager(
    cache_dir=Path("../cache/slot"),
    max_len=128,
    batch_size=32,
    num_workers=8,
    data_dir=Path("../data/slot"),
    test_file=Path("../data/slot/test.json")
)

2022-03-18 03:04:09 | INFO | Vocab loaded from /home/jacky/110-2_ADL/homeworks/hw01/cache/slot/vocab.pkl
2022-03-18 03:04:09 | INFO | Tag-2-Index loaded from /home/jacky/110-2_ADL/homeworks/hw01/cache/slot/tag2idx.json
2022-03-18 03:04:09 | INFO | Embeddings loaded from /home/jacky/110-2_ADL/homeworks/hw01/cache/slot/embeddings.pt


In [3]:
model = SlotTagger.load_from_checkpoint(Path("../upload/ckpt/slot/slot-best.ckpt"))
model



SlotTagger(
  (embedding): Embedding(3002, 300, padding_idx=0)
  (rnn): GRU(300, 128, dropout=0.5, bidirectional=True)
  (fc): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=256, out_features=256, bias=True)
    (2): ELU(alpha=1.0)
    (3): Dropout(p=0.5, inplace=False)
    (4): Linear(in_features=256, out_features=9, bias=True)
  )
  (loss): CrossEntropyLoss()
)

In [4]:
seed_everything(1123)
valid_dataloader = data_manager.get_valid_dataloader()
length_list = []
output_list = []
y_list = []
flatten_output_list = []
flatten_y_list = []
for x, length, y in tqdm(valid_dataloader):
    output = model(x, length)
    length_list.append(length)
    output_list.append(output)
    y_list.append(y)
    flatten_output_list.append(torch.cat([
        sen_output[:sen_len, :]
        for sen_output, sen_len, in zip(output, length)
    ]))
    flatten_y_list.append(torch.cat([
        sen_tags[:sen_len]
        for sen_tags, sen_len, in zip(y, length)
    ]))
length = torch.cat(length_list)
max_len = length.max()
output = torch.cat([
    F.pad(o, (0, 0, 0, max_len - o.shape[1]))
    for o in output_list
])
y = torch.cat(y_list)
flatten_output = torch.cat(flatten_output_list)
flatten_y = torch.cat(flatten_y_list)
output.shape, y.shape, length.shape, flatten_output.shape, flatten_y.shape

Global seed set to 1123
100%|██████████| 32/32 [00:00<00:00, 38.44it/s]


(torch.Size([1000, 33, 9]),
 torch.Size([1000, 33]),
 torch.Size([1000]),
 torch.Size([7891, 9]),
 torch.Size([7891]))

In [5]:
token_acc = model.token_acc(pred=flatten_output, target=flatten_y)
print(f"Token Accuracy: {token_acc}")

Token Accuracy: 0.9618552923202515


In [6]:
join_acc = model.join_acc(pred=output, length=length, target=y)
print(f"Join Accuracy: {join_acc}")

Join Accuracy: 0.7730000019073486


In [7]:
clipped_y = [
    [data_manager.idx2tag[idx] for idx in sen_tags[:sen_len].tolist()]
    for sen_tags, sen_len in zip(y, length)
]
clipped_pred = [
    [data_manager.idx2tag[idx] for idx in sen_val[:sen_len].argmax(dim=1).tolist()]
    for sen_val, sen_len in  zip(output, length)
]

In [8]:
print(classification_report(y_true=clipped_y, y_pred=clipped_pred, scheme=IOB2, mode="strict"))

              precision    recall  f1-score   support

        date       0.72      0.74      0.73       206
  first_name       0.94      0.91      0.93       102
   last_name       0.89      0.81      0.85        78
      people       0.69      0.68      0.68       238
        time       0.81      0.79      0.80       218

   micro avg       0.78      0.76      0.77       842
   macro avg       0.81      0.78      0.80       842
weighted avg       0.78      0.76      0.77       842



In [9]:
from collections import Counter
Counter(tag for tags in clipped_y for tag in tags)

Counter({'O': 6458,
         'B-time': 218,
         'B-people': 238,
         'B-date': 206,
         'I-date': 290,
         'B-first_name': 102,
         'B-last_name': 78,
         'I-people': 231,
         'I-time': 70})

In [10]:
Counter(tag for tags in clipped_pred for tag in tags)

Counter({'O': 6475,
         'B-time': 213,
         'B-people': 234,
         'B-date': 210,
         'I-date': 276,
         'I-people': 234,
         'B-first_name': 99,
         'B-last_name': 71,
         'I-time': 79})

In [11]:
from typing import List, Tuple
def get_chunk(tags: List[str]) -> List[Tuple[str, int, int]]:
    result = []
    current = None
    for i, tag in enumerate(tags):
        if tag[:2] != "I-" and current is not None:
            result.append((current[0], current[1], i-1))
            current = None
        if tag[:2] == "B-":
            current = (tag[2:], i)
    if current is not None:
        result.append((current[0], current[1], i))
    return result

In [12]:
chunked_y = [get_chunk(tags) for tags in clipped_y]
chunked_pred = [get_chunk(tags) for tags in clipped_pred]
len(chunked_y), len(chunked_pred)

(1000, 1000)

In [13]:
tags = set(chunk[0] for chunks in chunked_y for chunk in chunks) | set(chunk[0] for chunks in chunked_pred for chunk in chunks)
result = {
    tag: {
        "tp": 0,
        "fp": 0,
        "fn": 0,
        "support": 0,
    }
    for tag in tags
}

In [14]:
for y_chunks, pred_chunks in zip(chunked_y, chunked_pred):
    for chunk in y_chunks:
        tag = chunk[0]
        result[tag]["support"] += 1
        if chunk in pred_chunks:
            result[tag]["tp"] += 1
        else:
            result[tag]["fn"] += 1
    for chunk in pred_chunks:
        tag = chunk[0]
        if chunk not in y_chunks:
            result[tag]["fp"] += 1

In [15]:
result = {
    tag: {
        "precision": value["tp"] / (value["tp"] + value["fp"]),
        "recall": value["tp"] / (value["tp"] + value["fn"]),
        **value
    }
    for tag, value in result.items()
}
result = {
    tag: {
        "f1": 2 * value["precision"] * value["recall"] / (value["precision"] + value["recall"]),
        **value
    }
    for tag, value in result.items()
}

In [16]:
# print(json.dumps(result, indent=2))

In [20]:
print(f'Support: {sum(value["support"] for value in result.values())}')

Support: 842


In [29]:
metrics = ("precision", "recall", "f1")
tp = sum(value["tp"] for value in result.values())
fp = sum(value["fp"] for value in result.values())
fn = sum(value["fn"] for value in result.values())
micro = {
    "precision": tp / (tp + fp),
    "recall": tp / (tp + fn)
}
micro["f1"] = 2 * micro["precision"] * micro["recall"] / (micro["precision"] + micro["recall"])
for m in metrics:
    print(f"{m} (micro): {micro[m]}")

precision (micro): 0.7750906892382105
recall (micro): 0.7612826603325415
f1 (micro): 0.768124625524266


In [24]:
metrics = ("precision", "recall", "f1")
for m in metrics:
    print(f"{m} (macro): {sum(value[m] for value in result.values()) / len(result)}")

precision (macro): 0.8092146663977651
recall (macro): 0.7845565010335862
f1 (macro): 0.7964254380657672


In [28]:
metrics = ("precision", "recall", "f1")
for m in metrics:
    print(f"{m} (weighted): {sum(value[m] * value['support'] for value in result.values()) / sum(value['support'] for value in result.values())}")

precision (weighted): 0.7766317182495965
recall (weighted): 0.7612826603325415
f1 (weighted): 0.7687007353824342
