In [1]:
import json
import torch
import pickle
import pandas as pd

from tqdm import tqdm
from utils import Vocab
from typing import Dict
from pathlib import Path
from model import SeqTagger
from torch.utils.data import DataLoader
from dataset import SeqTaggingClsDataset

In [8]:
# Global configuration for slot-tagging inference.
# All paths are relative to the notebook's working directory.
cache_dir = Path("./cache/slot/")
test_file = Path("./data/slot/eval.json")
ckpt_dir = Path("./ckpt/slot/")

# Dataset / batching settings.
max_len = 128
batch_size = 512

# Model hyperparameters handed to SeqTagger.
# NOTE(review): presumably these must match the values used when the
# checkpoint was trained, otherwise load_state_dict will reject it -- confirm.
hidden_size = 512
num_layers = 2
dropout = 0.1
bidirectional = True

# Use the GPU when one is available; fall back to CPU so the notebook still
# runs on machines without CUDA instead of failing at the first .to(device).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Training-time settings; not referenced by the inference cells visible here.
num_epoch = 50
lr = 1e-3

# Load the vocabulary.
# SECURITY: pickle.load can execute arbitrary code -- only load cache files
# that were produced by this project's own preprocessing.
with open(cache_dir / "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

# Load the embeddings. map_location places them on the selected device
# even if the file was saved from a different one (e.g. a GPU-saved tensor
# loaded on a CPU-only machine).
embeddings = torch.load(cache_dir / "embeddings.pt", map_location=device)

In [6]:
# Raw evaluation examples, read from the JSON test file.
data = json.loads(test_file.read_text())

# Tag-name -> integer-index mapping stored alongside the other cached artifacts.
tag_idx_path = cache_dir / "tag2idx.json"
tag2idx: Dict[str, int] = json.loads(tag_idx_path.read_text())

# Flatten the examples: one record per token, each carrying the id of the
# example it came from so predictions can be regrouped afterwards.
data_eval = [
    dict(token=token, id=item["id"])
    for item in data
    for token in item["tokens"]
]

dataset = SeqTaggingClsDataset(data_eval, vocab, tag2idx, max_len)

# DataLoader over the evaluation set; shuffling stays off so batches are
# yielded in the same order as the flattened records.
eval_data_loader = DataLoader(
    dataset,
    collate_fn=dataset.collate_fn,
    shuffle=False,
    batch_size=batch_size,
)

In [14]:
# Rebuild the tagger with the configured hyperparameters and restore the
# trained weights from the checkpoint directory.
test_model = SeqTagger(
    embeddings,
    hidden_size,
    num_layers,
    dropout,
    bidirectional,
    len(tag2idx),
)
# map_location lets a checkpoint saved on one device be loaded on another
# (e.g. GPU-trained weights on a CPU-only machine).
test_model.load_state_dict(
    torch.load(ckpt_dir / "slot_model.pt", map_location=device)
)
test_model.to(device)
test_model.eval()  # switch to evaluation mode before predicting

# Parallel lists: result_index_list[i] is the index value from the batch that
# belongs to the predicted tag id result_ids_list[i].
result_index_list = list()
result_ids_list = list()

# Inference only: disable autograd so no computation graph is built, which
# saves memory and time without changing the predictions.
with torch.no_grad():
    for batch in eval_data_loader:
        ids = batch["ids"].to(device)
        # The model output already lives on the model's device; no extra
        # .to(device) is required.
        pred = test_model(ids)
        # assumes dim 1 of the model output is the tag/class dimension -- TODO confirm
        result_ids_list.extend(pred.argmax(dim=1).tolist())
        result_index_list.extend(batch["index"])

# Group the predicted labels by index, preserving prediction order.
# `tag_id` is used instead of `id` to avoid shadowing the builtin.
result_dict = dict()
for index, tag_id in zip(result_index_list, result_ids_list):
    result_dict.setdefault(index, []).append(dataset.idx2label(tag_id))

In [15]:
# Show the predicted labels grouped per index value (each key maps to the
# list of labels appended for it, in prediction order).
result_dict

{'eval-0': ['O', 'O', 'O', 'O', 'O'],
 'eval-1': ['B-time', 'O'],
 'eval-2': ['O', 'O', 'B-people'],
 'eval-3': ['O', 'O', 'O', 'O', 'O', 'O'],
 'eval-4': ['O', 'O', 'B-date', 'O', 'I-date'],
 'eval-5': ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'],
 'eval-6': ['O', 'O', 'O', 'O', 'O', 'O', 'O'],
 'eval-7': ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'],
 'eval-8': ['O', 'O', 'O', 'O', 'B-people', 'O'],
 'eval-9': ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'],
 'eval-10': ['O', 'O', 'B-date'],
 'eval-11': ['O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'B-people',
  'O'],
 'eval-12': ['O', 'O', 'O', 'B-people', 'O'],
 'eval-13': ['O', 'O', 'O', 'O', 'O'],
 'eval-14': ['O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O'],
 'eval-15': ['O', 'B-date', 'O', 'O'],
 'eval-16': ['O', 'O', 'B-date', 'O'],
 'eval-17': ['O', 'O', 'B-people', 'O', 'O', 'O'],
 'eval-18': ['O', 