In [1]:
import csv
import json
import os
import pickle
import warnings
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Dict

import torch

from slot_dataset import SeqClsDataset
from slot_model import SeqClassifier
from utils import Vocab
from torch.utils.data import DataLoader

# Pick the compute device. The original setup ran on GPU index 2; fall back to the
# first GPU when this machine has fewer GPUs, and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device("cuda:2" if torch.cuda.device_count() > 2 else "cuda:0")
else:
    device = torch.device("cpu")

In [2]:
# Paths and settings for slot-tagging inference.
# The checkpoint / prediction paths are machine-specific absolute paths; each value can be
# overridden through an environment variable without editing the notebook.
test_file = os.environ.get("SLOT_TEST_FILE", "./data/slot/test.json")
# Must end with "/": later cells build file paths by string concatenation onto this value.
cache_dir = os.environ.get("SLOT_CACHE_DIR", "./cache/slot/")
# Path to the saved model state_dict file (a file, despite the "_dir" suffix).
ckpt_dir = os.environ.get("SLOT_CKPT_PATH", "/data/NFS/andy/course/ADL/hw1/slot_weights2.pt")
# Output CSV with one "id,tags" row per test sentence.
pred_file = os.environ.get("SLOT_PRED_FILE", "/data/NFS/andy/course/ADL/hw1/pred.slot2.csv")
# Passed to SeqClsDataset; presumably the padded/truncated sequence length — confirm in slot_dataset.
max_len = 35

In [3]:
# Load the cached vocabulary and tag->index mapping, read the test split, and wrap it in a DataLoader.
# NOTE(review): pickle.load can execute arbitrary code from the file — only load caches you built yourself.
with open(cache_dir + "vocab.pkl", "rb") as f:
    vocab: Vocab = pickle.load(f)

# Read JSON explicitly as UTF-8 so results do not depend on the platform's default encoding.
tag_idx_path = Path(cache_dir + "tag2idx.json")
tag2idx: Dict[str, int] = json.loads(tag_idx_path.read_text(encoding="utf-8"))

data = json.loads(Path(test_file).read_text(encoding="utf-8"))
dataset = SeqClsDataset(data, vocab, tag2idx, max_len)

batch_size = 128
# shuffle=False is required: the prediction-writing cell pairs the i-th prediction with dataset.data[i].
test_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=dataset.collate_fn_test)

In [4]:
# Sanity check: length (in tokens) of the longest test sentence, compared against max_len.
# Uses its own variable name so re-running this cell never overwrites the max_len config value.
longest_test_sentence = max((len(sample["tokens"]) for sample in dataset.data), default=0)
if longest_test_sentence > max_len:
    # NOTE(review): how SeqClsDataset treats sentences longer than max_len is not visible here —
    # confirm, since the CSV cell reads one prediction per original token.
    warnings.warn(
        f"longest test sentence has {longest_test_sentence} tokens, more than max_len={max_len}"
    )
longest_test_sentence

In [5]:
# Load the cached embedding tensor onto the CPU so this works even if it was saved from a GPU
# that is not present here; the model built from it is moved to `device` in the next cell.
embeddings = torch.load(cache_dir + "embeddings.pt", map_location="cpu")

In [6]:
# Rebuild the classifier (hyper-parameters presumably must match the ones used in training) and load its weights.
# NOTE(review): num_class=9 is presumably len(tag2idx) — confirm before changing the tag set.
model = SeqClassifier(embeddings=embeddings, hidden_size=256, num_layers=2, dropout=0.2, bidirectional=True, num_class=9)
model.to(device)
# map_location remaps tensors saved on another device (e.g. cuda:2) onto the current device.
model.load_state_dict(torch.load(ckpt_dir, map_location=device))

<All keys matched successfully>

In [7]:
# Run the model over the test set, collecting one array of predicted tag ids per sentence,
# in DataLoader order (which matches dataset.data because shuffle=False).
all_pred = []
model.eval()
with torch.no_grad():
    for batch in test_loader:
        # Distinct name so the test-set `data` list from the loading cell is not overwritten.
        inputs = batch.to(device)
        scores = model(inputs)
        # argmax over dim 2; assumes the output is (batch, seq_len, num_class) — TODO confirm
        pclass = scores.argmax(dim=2)
        # One device->host copy per batch; iterating the 2-D array yields one 1-D row per sentence.
        all_pred.extend(pclass.cpu().numpy())

In [8]:
# Write predictions as CSV rows "test-<i>,<tags>", where <tags> holds space-separated labels,
# one per original token of sentence i.
# `with` closes the file even if a row fails; newline="" is what the csv module requires.
with open(pred_file, "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["id", "tags"])
    for i, sentence_pred in enumerate(all_pred):
        # Only the first len(tokens) positions correspond to real tokens; later ones are presumably padding.
        length = len(dataset.data[i]["tokens"])
        tags = " ".join(dataset.idx2label(sentence_pred[j]) for j in range(length))
        writer.writerow(["test-%d" % i, tags])