In [None]:
import os
import sys
# Make the parent directory importable so the project-local `utils` and
# `config` packages imported in the next cell can be resolved. The guard
# keeps re-running this cell from appending duplicates.
module_path = os.path.abspath(os.pardir)
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import numpy as np
from sklearn.metrics import accuracy_score
import torch
import torch.nn as nn
from utils.logger import PrettyLogger
from utils.io_func import save_to_csv, save_to_pkl, save_to_pth
from config import SEED
from utils.helper import LSTMHelper

In [None]:
# Shared logger plus the LSTM helper used below for data loading,
# normalization, model building, training and prediction.
logger, helper = PrettyLogger(), LSTMHelper()

In [None]:
# Sequence lengths to evaluate, 1..22 inclusive. For each value the training
# loop keeps only the first `time_step` entries along axis 1 of the inputs
# (presumably the temporal axis; 22 is presumably the number of observations
# per season — TODO confirm against the preprocessing output).
TIME_STEP_LIST = range(1, 23)
TRAIN_YEARS = [str(year) for year in [2015, 2016, 2017]]
TEST_YEARS = [str(year) for year in [2018]]
DATA_DIR = "../preprocessing/out/Site_1/"
# `{year}` is filled in by get_paths() for each entry of TRAIN/TEST_YEARS.
X_PATH_TEMPLATE = os.path.join(DATA_DIR, "x-corn_soybean-{year}.npy")
Y_PATH_TEMPLATE = os.path.join(DATA_DIR, "y-corn_soybean-{year}.npy")
# One output directory per time step.
RESULT_DIR_TEMPLATE = (
    "./out/in_season/AtLSTM-corn_soybean/time_step_{time_step}/"
)
# Fall back to CPU so the notebook still runs (slowly) on machines without
# CUDA; a hard-coded "cuda:0" makes `net.to(DEVICE)` fail there.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Preparation

In [None]:
def get_paths(path_template, years):
    """Expand ``path_template`` once for every entry in ``years``.

    Each ``{year}`` placeholder is substituted with the year value. The
    resulting paths are returned as a list in the same order as ``years``.
    """
    paths = []
    for yr in years:
        paths.append(path_template.format(year=yr))
    return paths

In [None]:
def save_all(
    result_dir,
    y_train_soft_pred, y_test_soft_pred, y_train_hard_pred, y_test_hard_pred,
    attn_train, attn_test,
    loss_train_list, loss_test_list, acc_train_list, acc_test_list,
    acc_train, acc_test,
    scaler, net
):
    """Write every artifact of one time-step run into ``result_dir``.

    Files written:
      * ``y_train_soft_pred.csv``, ``y_test_soft_pred.csv``,
        ``y_train_hard_pred.csv``, ``y_test_hard_pred.csv``,
        ``attn_train.csv``, ``attn_test.csv``: one CSV per array argument.
      * ``training_record.csv``: the four record lists stacked as columns
        (training loss, test loss, training acc, test acc), one row per
        list entry.
      * ``perf_abstract.csv``: a single row with ``acc_train``/``acc_test``.
      * ``scaler.pkl``: ``scaler`` via ``save_to_pkl``.
      * ``atlstm.pth``: ``net`` via ``save_to_pth``.

    ``result_dir`` is created (with parents) if it does not exist yet.

    Raises:
        ValueError: if the four record lists differ in length, since they
            must be stacked column-wise into one table.
    """
    # A fresh checkout has no ./out/... tree; create it up front rather than
    # relying on the save_to_* helpers to do so.
    os.makedirs(result_dir, exist_ok=True)

    # One CSV per array; the file stem equals the argument name.
    named_arrays = {
        "y_train_soft_pred": y_train_soft_pred,
        "y_test_soft_pred": y_test_soft_pred,
        "y_train_hard_pred": y_train_hard_pred,
        "y_test_hard_pred": y_test_hard_pred,
        "attn_train": attn_train,
        "attn_test": attn_test,
    }
    for name, array in named_arrays.items():
        save_to_csv(array, os.path.join(result_dir, name + ".csv"))

    # Stacking ragged lists would yield an object array (or a numpy error);
    # fail with a clear message instead.
    records = [loss_train_list, loss_test_list, acc_train_list, acc_test_list]
    if len({len(record) for record in records}) != 1:
        raise ValueError(
            "training record lists must have equal length, got {}".format(
                [len(record) for record in records]
            )
        )
    save_to_csv(
        np.array(records).T,
        os.path.join(result_dir, "training_record.csv"),
        header=["training loss", "test loss", "training acc", "test acc"]
    )
    save_to_csv(
        np.array([[acc_train, acc_test]]),
        os.path.join(result_dir, "perf_abstract.csv"),
        header=["acc_train", "acc_test"]
    )
    save_to_pkl(scaler, os.path.join(result_dir, "scaler.pkl"))
    # NOTE(review): the training loop passes the nn.DataParallel wrapper here;
    # confirm that whatever loads atlstm.pth expects the wrapped model.
    save_to_pth(net, os.path.join(result_dir, "atlstm.pth"))

# Training

In [None]:
for time_step in TIME_STEP_LIST:
    logger.info("Time step {} starts".format(time_step))

    # Re-seed at the start of every time step so each run (weight init and
    # shuffling) is reproducible on its own, independent of how many time
    # steps ran before it. SEED is imported from config in the imports cell.
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # input: keep only the first `time_step` entries along axis 1
    # (presumably the temporal axis of (samples, time, features) — confirm).
    x_train = helper.input_x(get_paths(X_PATH_TEMPLATE, TRAIN_YEARS))[
        :, :time_step, :
    ]
    y_train = helper.input_y(get_paths(Y_PATH_TEMPLATE, TRAIN_YEARS))
    x_test = helper.input_x(get_paths(X_PATH_TEMPLATE, TEST_YEARS))[
        :, :time_step, :
    ]
    y_test = helper.input_y(get_paths(Y_PATH_TEMPLATE, TEST_YEARS))

    # normalization: the returned scaler is saved alongside the model
    scaler, x_train, x_test = helper.normalize_without_scaler(x_train, x_test)

    # training preparation
    train_dataloader = helper.make_data_loader(x_train, y_train, shuffle=True)
    test_dataloader = helper.make_data_loader(x_test, y_test, shuffle=False)

    net = helper.build_model()
    helper.init_parameters(net)
    # Use up to 4 GPUs as before, but only ones that actually exist: a
    # hard-coded [0, 1, 2, 3] raises on machines with fewer GPUs. With no
    # CUDA device, device_ids=None lets DataParallel call the wrapped module
    # directly.
    n_gpus = min(4, torch.cuda.device_count())
    net = nn.DataParallel(net, device_ids=list(range(n_gpus)) or None)
    net.to(DEVICE)

    # training: train_model fills these lists in place
    loss_train_list, acc_train_list, attn_train_list = [], [], []
    loss_test_list, acc_test_list, attn_test_list = [], [], []
    helper.train_model(
        net, train_dataloader, test_dataloader, DEVICE, logger,
        loss_train_list, acc_train_list, attn_train_list,
        loss_test_list, acc_test_list, attn_test_list,
    )

    # prediction: use a non-shuffled train loader so predictions line up
    # with y_train for accuracy_score below
    y_train_soft_pred, y_train_hard_pred, attn_train = helper.predict(
        net, helper.make_data_loader(x_train, y_train, shuffle=False), DEVICE
    )
    y_test_soft_pred, y_test_hard_pred, attn_test = helper.predict(
        net, test_dataloader, DEVICE
    )
    acc_train = accuracy_score(y_train, y_train_hard_pred)
    acc_test = accuracy_score(y_test, y_test_hard_pred)
    # Pass one pre-formatted message, like every other logger.info call here;
    # extra positional args would be treated as %-format args by a
    # logging-style logger.
    logger.info("train acc: {} test acc: {}".format(acc_train, acc_test))

    save_all(
        RESULT_DIR_TEMPLATE.format(time_step=time_step),
        y_train_soft_pred, y_test_soft_pred,
        y_train_hard_pred, y_test_hard_pred,
        attn_train, attn_test,
        loss_train_list, loss_test_list, acc_train_list, acc_test_list,
        acc_train, acc_test,
        scaler, net
    )

    logger.info("Time step {} ends".format(time_step))