In [1]:
import json
import pickle
import csv
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Dict

import torch
from torch.utils.data import DataLoader

from intent_dataset import SeqClsDataset
from intent_model import SeqClassifier
from utils import Vocab

# Select the compute device for the whole notebook.
# GPU index 3 stays the preferred card (it was hard-coded originally), but a
# CUDA host exposing fewer than four GPUs would otherwise produce a device
# that fails on first use, so fall back to GPU 0 there. Hosts without CUDA
# use the CPU.
preferred_gpu_index = 3
if torch.cuda.is_available():
    gpu_index = preferred_gpu_index if torch.cuda.device_count() > preferred_gpu_index else 0
    device = torch.device(f"cuda:{gpu_index}")
else:
    device = torch.device("cpu")

In [2]:
# --- Run configuration: everything a reader may need to change lives here ---

# Fourth argument handed to SeqClsDataset.
# NOTE(review): presumably the maximum token length per utterance — confirm
# against intent_dataset.SeqClsDataset.
max_len = 28

# Directory with vocab.pkl, intent2idx.json and embeddings.pt. The trailing
# slash is required because later cells build file paths by plain string
# concatenation (cache_dir + "vocab.pkl").
cache_dir = "./cache/intent/"

# Test split to run inference on (JSON).
test_file = "./data/intent/test.json"

# NOTE(review): the two paths below are absolute and machine-specific; adjust
# them before running on another host.
# Destination CSV for the predictions.
pred_file = "/data/NFS/andy/course/ADL/hw1/pred.intent3.csv"
# Despite the "_dir" suffix this is a checkpoint *file*; it is passed straight
# to torch.load when the model weights are restored.
ckpt_dir = "/data/NFS/andy/course/ADL/hw1/intent_weights2.pt"

In [3]:
# Load the cached vocabulary.
# SECURITY NOTE: pickle.load can execute arbitrary code while deserializing;
# only point cache_dir at a vocab.pkl that you produced yourself or trust.
with open(cache_dir + "vocab.pkl", "rb") as vocab_file:
    vocab: Vocab = pickle.load(vocab_file)

# Label-name -> class-index mapping. JSON is decoded explicitly as UTF-8 so
# the result does not depend on the locale's default encoding.
intent_idx_path = Path(cache_dir + "intent2idx.json")
intent2idx: Dict[str, int] = json.loads(intent_idx_path.read_text(encoding="utf-8"))

# Raw test records, then the dataset wrapper around them.
data = json.loads(Path(test_file).read_text(encoding="utf-8"))
dataset = SeqClsDataset(data, vocab, intent2idx, max_len)

batch_size = 256
test_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    # Order must be preserved: the output cell numbers rows by position
    # ("test-0", "test-1", ...), so shuffling would mislabel predictions.
    shuffle=False,
    collate_fn=dataset.collate_fn_test,
)

In [4]:
# Restore the cached embeddings object that is later handed to SeqClassifier.
# NOTE(review): presumably a pre-trained word-embedding matrix aligned with
# the vocabulary — confirm against the preprocessing script.
embeddings_file = cache_dir + "embeddings.pt"
embeddings = torch.load(embeddings_file)

In [5]:
# Build the classifier. These hyper-parameters have to describe the same
# architecture the checkpoint was saved from; load_state_dict (strict by
# default) raises on missing/unexpected keys or mismatched shapes.
model = SeqClassifier(
    embeddings=embeddings,
    hidden_size=256,
    num_layers=2,
    dropout=0.3,
    bidirectional=True,
    num_class=150,
)
model.to(device)

# map_location remaps the stored tensors onto the device selected for this
# run. Without it, torch.load tries to restore each tensor to the device it
# was saved from and fails when that GPU does not exist here (or on a
# CPU-only host).
state_dict = torch.load(ckpt_dir, map_location=device)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [6]:
# Run inference over the whole test loader and collect one predicted class
# index per example, in loader order.
all_pred = []
model.eval()  # switch the model to evaluation mode
with torch.no_grad():  # no gradients are needed for prediction
    for batch in test_loader:
        # A fresh name is used on purpose: rebinding `data` here would
        # clobber the raw test JSON loaded in the data-loading cell.
        # NOTE(review): assumes collate_fn_test returns a single tensor-like
        # object (it must expose `.to`) — confirm in intent_dataset.
        inputs = batch.to(device)
        logits = model(inputs)
        # argmax over dim 1 picks the best class per row.
        # NOTE(review): assumes the model output is (batch, num_class).
        # tolist() converts the whole batch in one call instead of calling
        # .item() once per element.
        all_pred.extend(logits.argmax(dim=1).tolist())

In [7]:
# Write the predictions as a two-column CSV ("id", "intent").
# The context manager closes the file even if idx2label or a write raises;
# newline="" is what the csv module requires of file objects so that row
# terminators are not translated a second time on platforms that rewrite "\n".
with open(pred_file, "w", newline="") as pred_csv:
    writer = csv.writer(pred_csv)
    writer.writerow(["id", "intent"])
    # Row ids are generated from the position in all_pred.
    # NOTE(review): this assumes the test records are numbered test-0,
    # test-1, ... in file order — confirm against the ids in the test JSON.
    for i, label_idx in enumerate(all_pred):
        writer.writerow(["test-%d" % i, dataset.idx2label(label_idx)])