In [None]:
import os
import sys
# Put the parent directory on the import path (presumably where the project's
# `utils` package lives — the imports in the next cell rely on it).
# Guarded so that re-running the cell does not append a duplicate entry.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

In [None]:
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
import torch
import torch.nn as nn
from utils.logger import PrettyLogger
from utils.io_func import save_to_csv, load_from_pkl, load_from_pth
from utils.helper import DCMHelper

In [None]:
# Project logger; used below to report the test accuracy.
logger = PrettyLogger()
# Project helper that the rest of the notebook uses for data loading,
# normalization, model construction, prediction and timing records.
helper = DCMHelper()

In [None]:
# --- Experiment configuration ---------------------------------------------
# Site(s) used to locate the saved scaler/model; joined with "_" to form the
# artifact directory name.
BASE_SITES = ["Site_A"]
# Site whose data is evaluated in this notebook.
TEST_SITE = "Site_B"
# Years are stored as strings because they are substituted into file names.
TEST_YEARS = [str(year) for year in [2018]]

# Input arrays of the test site; "{year}" is filled in by get_paths().
DATA_DIR = "../preprocessing/out/{}".format(TEST_SITE)
X_PATH_TEMPLATE = os.path.join(DATA_DIR, "x-{year}.npy")
Y_PATH_TEMPLATE = os.path.join(DATA_DIR, "y-{year}.npy")

# Artifacts read from the training output directory of BASE_SITES.
SCALER_PATH = "./out/training/{}/scaler.pkl".format("_".join(BASE_SITES))
MODEL_PATH = "./out/training/{}/atbilstm.pth".format("_".join(BASE_SITES))

# Where this notebook writes its predictions and metrics.
RESULT_DIR = "./out/spatial_tran/{}/{}".format("_".join(BASE_SITES), TEST_SITE)

# Prefer the first GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at `net.to(DEVICE)`.
# NOTE(review): on a CPU-only host the checkpoint loader may additionally need
# a `map_location` — confirm against utils.io_func.load_from_pth.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Input

In [None]:
def get_paths(path_template, years):
    """Expand a path template once per year.

    Args:
        path_template: A ``str.format`` template containing a ``{year}`` field.
        years: Iterable of values substituted for ``{year}``.

    Returns:
        A list with one formatted path per entry of ``years``, in the same order.
    """
    expanded = []
    for current_year in years:
        expanded.append(path_template.format(year=current_year))
    return expanded


# Resolve the per-year file names first, then hand them to the helper loaders.
x_test_paths = get_paths(X_PATH_TEMPLATE, TEST_YEARS)
y_test_paths = get_paths(Y_PATH_TEMPLATE, TEST_YEARS)

# Features and labels of the test site for the configured years.
x_test = helper.input_x(x_test_paths)
y_test = helper.input_y(y_test_paths)

# Normalization

In [None]:
# Load the scaler stored in the training output directory of BASE_SITES and
# apply it to the test features, i.e. the test site is normalized with the
# scaler saved at training time rather than one computed here.
# NOTE(review): load_from_pkl presumably unpickles the file — only load
# artifacts from a trusted source.
scaler = load_from_pkl(SCALER_PATH)
# NOTE: x_test is overwritten by its normalized version, so this cell is not
# idempotent — re-run the input cell before re-running this one.
x_test = helper.normalize_with_scaler(scaler, x_test)

# Prediction

In [None]:
# shuffle=False: predictions are compared position-by-position with y_test
# below, so the loader must not reorder samples.
test_dataloader = helper.make_data_loader(x_test, y_test, shuffle=False)

# Load the weights into the bare model *before* wrapping it, because the
# state dict is loaded against the unwrapped module here.
net = helper.build_model()
net.load_state_dict(load_from_pth(MODEL_PATH))

# Use up to four GPUs (the original setup), but never more than are visible on
# this host — a hard-coded [0, 1, 2, 3] fails on machines with fewer GPUs.
# With no GPU at all the list is empty; passing None lets DataParallel pick
# its own default, and it acts as a plain pass-through wrapper in that case.
device_ids = list(range(min(4, torch.cuda.device_count())))
net = nn.DataParallel(net, device_ids=device_ids or None)
net.to(DEVICE)

# NOTE(review): presumably helper.predict switches the model to eval mode and
# disables gradients — confirm in utils.helper.
y_test_soft_pred, y_test_hard_pred, attn_test = helper.predict(
    net, test_dataloader, DEVICE
)
acc_test = accuracy_score(y_test, y_test_hard_pred)
logger.info(TEST_SITE, "test acc:", acc_test)

# Saving all

In [None]:
# RESULT_DIR is a nested, per-site path that nothing earlier in the notebook
# creates. Create it up front (idempotent) so the writes below cannot fail on
# a missing directory, whether or not save_to_csv creates parents itself.
os.makedirs(RESULT_DIR, exist_ok=True)

# Soft and hard predictions as returned by helper.predict.
save_to_csv(
    y_test_soft_pred, os.path.join(RESULT_DIR, "y_test_soft_pred.csv")
)
save_to_csv(
    y_test_hard_pred, os.path.join(RESULT_DIR, "y_test_hard_pred.csv")
)
# Single-row summary table holding the test accuracy.
save_to_csv(
    np.array([[acc_test]]),
    os.path.join(RESULT_DIR, "perf_abstract.csv"),
    header=["acc_test"]
)
# Timing records collected by the helper (presumably filled during
# helper.predict — verify in utils.helper).
save_to_csv(
    helper.test_time_list,
    os.path.join(RESULT_DIR, "test_time.csv"),
    header=["test_start_time", "test_end_time", "duration"]
)