In [51]:
import json
import pickle
import csv
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Dict

import torch

from slot_dataset import SeqClsDataset
from slot_model import SeqClassifier
from utils import Vocab
from torch.utils.data import DataLoader
from seqeval.metrics import classification_report
from seqeval.scheme import IOB2

device = torch.device("cuda:3") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
def cal_joint_acc(pred, label):
    pclass = pred.argmax(dim=2)
    correct = 0
    for i in range(len(label)):
        s_label = label[i][label[i]!=-100]
        length = len(s_label)
        if (s_label==pclass[i][:length]).all():
            correct += 1
    jacc = correct / len(label)
    return jacc

In [3]:
def cal_token_acc(pred, label):
    pclass = pred.argmax(dim=2)
    acc = 0
    for i in range(len(label)):
        s_label = label[i][label[i]!=-100]
        length = len(s_label)
        acc += (s_label==pclass[i][:length]).sum().item()/length
    acc /= len(label)
    return acc

In [4]:
eval_file = "./data/slot/eval.json"
cache_dir = "./cache/slot/"
ckpt_dir = "/data/NFS/andy/course/ADL/hw1/slot_weights.pt"
max_len = 35

In [5]:
with open(cache_dir + "vocab.pkl", "rb") as f:
        vocab: Vocab = pickle.load(f)

tag_idx_path = Path(cache_dir + "tag2idx.json")
tag2idx: Dict[str, int] = json.loads(tag_idx_path.read_text())

data = json.loads(Path(eval_file).read_text())
dataset = SeqClsDataset(data, vocab, tag2idx, max_len)

batch_size = 128
val_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=dataset.collate_fn)

In [6]:
# max_len = 0
# for i in range(len(dataset.data)):
#     sentence = dataset.data[i]["tokens"]
#     if len(sentence) > max_len:
#         max_len = len(sentence)
# print(max_len)

In [7]:
embeddings = torch.load(cache_dir + "embeddings.pt")

In [8]:
model = SeqClassifier(embeddings=embeddings, hidden_size=256, num_layers=2, dropout=0.2, bidirectional=True, num_class=9)
model.to(device)
model.load_state_dict(torch.load(ckpt_dir))

<All keys matched successfully>

In [53]:
joint_acc = 0
token_acc = 0
all_pred = []
all_label = []

model.eval()
with torch.no_grad():
    for i, batch in enumerate(val_loader):
        data = batch[0].to(device)
        label = batch[1].to(device)
        pred = model(data)
        pclass = pred.argmax(dim=2)
        for i in range(len(label)):
            s_label = label[i][label[i]!=-100]
            length = len(s_label)
            sub_pred = []
            sub_label = []
            for j in range(length):
                sub_pred.append(dataset.idx2label(pclass[i][j].item()))
                sub_label.append(dataset.idx2label(label[i][j].item()))
            all_pred.append(sub_pred)
            all_label.append(sub_label)
        
        joint_acc += cal_joint_acc(pred, label)
        token_acc += cal_token_acc(pred, label)

print(joint_acc/val_loader.__len__())
print(token_acc/val_loader.__len__())
print(classification_report(all_label, all_pred, scheme=IOB2, mode="strict"))

0.8007061298076923
0.9639372878719882
              precision    recall  f1-score   support

        date       0.75      0.74      0.74       206
  first_name       0.93      0.88      0.90       102
   last_name       0.89      0.73      0.80        78
      people       0.77      0.71      0.74       238
        time       0.84      0.90      0.87       218

   micro avg       0.81      0.79      0.80       842
   macro avg       0.84      0.79      0.81       842
weighted avg       0.81      0.79      0.80       842

