In [1]:
import pickle
import torch
import os
import sys
import matplotlib.pyplot as plt

# Append `<cwd>/dataloader` to the import path (no directory change happens here).
# NOTE(review): presumably this is so pickle.load below can import the modules
# that define the pickled dataloader objects — TODO confirm.
# Despite its name, `main_dir` is the dataloader folder; later cells use
# `os.path.dirname(main_dir)` (i.e. the cwd) as the project root for `model_save/`.
current_directory = os.getcwd()
main_dir = os.path.dirname(f'{current_directory}/dataloader/')
sys.path.append(main_dir)

from model.baseline.gru import GRU
from model.baseline.bi_gru import biGRU
from model.baseline.lstm import LSTM
from model.baseline.bi_lstm import biLSTM
from model.baseline.rnn import RNN
from model.baseline.bi_rnn import biRNN
from model.transformer import Transformer
from model.dendinet import DendiNet

from utils.loss import rmse, mse, rmse_tsf
from utils.F_train_validate import fit, validate
from utils.F_train_validate_dendinet import fit as fit_DendiNet
from utils.F_train_validate_dendinet import validate as validate_DendiNet

from configs.LSTM_GRU_configs import Configs as Configs_LSTM_GRU
from configs.DendiNet_configs import Configs as Configs_DendiNet


# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'avaiable device: {device}')

avaiable device: cuda


### Load test dataloaders

In [2]:
# Test-set loaders for the recurrent baselines (GRU / LSTM / RNN and their
# bidirectional variants), one pickle per drive cycle.
# The "-10" in the file names presumably identifies the test condition — TODO confirm.
# NOTE(review): pickle.load can execute arbitrary code; only load files this project produced.
with open('dataloader/save/LSTM_GRU/US06_test_-10_loader.pkl', 'rb') as us06_pkl:
    US06_test_loader = pickle.load(us06_pkl)

with open('dataloader/save/LSTM_GRU/HWFTa_test_-10_loader.pkl', 'rb') as hwfta_pkl:
    HWFTa_test_loader = pickle.load(hwfta_pkl)

print('Test dataloaders for GRU, LSTM, RNN are loaded!')

Test dataloaders for GRU, LSTM, RNN are loaded!


In [3]:
# Test-set loaders in the DendiNet format; later cells feed these to both
# the transformer and DendiNet.
# NOTE(review): pickle.load can execute arbitrary code; only load files this project produced.
with open('dataloader/save/DendiNet/US06_test_-10_loader.pkl', 'rb') as us06_pkl:
    US06_test_loader_dendinet = pickle.load(us06_pkl)

with open('dataloader/save/DendiNet/HWFTa_test_-10_loader.pkl', 'rb') as hwfta_pkl:
    HWFTa_test_loader_dendinet = pickle.load(hwfta_pkl)

print('Test dataloaders for DendiNet are loaded!')

Test dataloaders for DendiNet are loaded!


### Load trained models

In [4]:
# Build each recurrent baseline from the shared LSTM/GRU config and restore its
# trained weights from <project root>/model_save/<name>.pth.
# map_location=device remaps the checkpoint tensors onto the device chosen above,
# so weights saved on a GPU can still be loaded on a CPU-only machine.
# (The *_save_dir names actually hold file paths; kept for compatibility.)
configs_LSTM_GRU = Configs_LSTM_GRU()

# GRU
GRU_model_save_dir = os.path.dirname(main_dir) + '/model_save/GRU.pth'
GRU_model = GRU(configs_LSTM_GRU).to(device)
GRU_model.load_state_dict(torch.load(GRU_model_save_dir, map_location=device))

# biGRU
biGRU_model_save_dir = os.path.dirname(main_dir) + '/model_save/biGRU.pth'
biGRU_model = biGRU(configs_LSTM_GRU).to(device)
biGRU_model.load_state_dict(torch.load(biGRU_model_save_dir, map_location=device))

# LSTM
LSTM_model_save_dir = os.path.dirname(main_dir) + '/model_save/LSTM.pth'
LSTM_model = LSTM(configs_LSTM_GRU).to(device)
LSTM_model.load_state_dict(torch.load(LSTM_model_save_dir, map_location=device))

# biLSTM
biLSTM_model_save_dir = os.path.dirname(main_dir) + '/model_save/biLSTM.pth'
biLSTM_model = biLSTM(configs_LSTM_GRU).to(device)
biLSTM_model.load_state_dict(torch.load(biLSTM_model_save_dir, map_location=device))

# RNN
RNN_model_save_dir = os.path.dirname(main_dir) + '/model_save/RNN.pth'
RNN_model = RNN(configs_LSTM_GRU).to(device)
RNN_model.load_state_dict(torch.load(RNN_model_save_dir, map_location=device))

# biRNN
biRNN_model_save_dir = os.path.dirname(main_dir) + '/model_save/biRNN.pth'
biRNN_model = biRNN(configs_LSTM_GRU).to(device)
biRNN_model.load_state_dict(torch.load(biRNN_model_save_dir, map_location=device))

<All keys matched successfully>

In [5]:
# Transformer baseline. It is built from the DendiNet config class and later
# evaluated on the DendiNet-format loaders.
# NOTE(review): presumably it was trained with the same configuration as DendiNet — confirm.
# map_location=device lets a GPU-saved checkpoint load on a CPU-only machine.
configs_transformer = Configs_DendiNet()
transformer_model_save_dir = os.path.dirname(main_dir) + '/model_save/transformer.pth'
transformer_model = Transformer(configs_transformer).to(device)
transformer_model.load_state_dict(torch.load(transformer_model_save_dir, map_location=device))

<All keys matched successfully>

In [6]:
# DendiNet: build from its config class and restore the trained weights.
# map_location=device lets a GPU-saved checkpoint load on a CPU-only machine.
configs_DendiNet = Configs_DendiNet()
DendiNet_model_save_dir = os.path.dirname(main_dir) + '/model_save/DendiNet.pth'
DendiNet_model = DendiNet(configs_DendiNet).to(device)
DendiNet_model.load_state_dict(torch.load(DendiNet_model_save_dir, map_location=device))



<All keys matched successfully>

## Validate on test data

### US06

In [7]:
# Evaluate every recurrent baseline on the same US06 test loader.
# Only the GRU pass's targets are kept as the shared ground truth.
# NOTE(review): this assumes US06_test_loader yields samples in the same order
# on every pass (no shuffling) — TODO confirm; otherwise the other models'
# predictions would not line up with US06_target_log.
_, US06_target_log, GRU_US06_pred_log = validate(
    GRU_model, US06_test_loader, rmse, device=device
)
_, _, biGRU_US06_pred_log = validate(
    biGRU_model, US06_test_loader, rmse, device=device
)
_, _, LSTM_US06_pred_log = validate(
    LSTM_model, US06_test_loader, rmse, device=device
)
_, _, biLSTM_US06_pred_log = validate(
    biLSTM_model, US06_test_loader, rmse, device=device
)
_, _, RNN_US06_pred_log = validate(
    RNN_model, US06_test_loader, rmse, device=device
)
_, _, biRNN_US06_pred_log = validate(
    biRNN_model, US06_test_loader, rmse, device=device
)

In [8]:
# Transformer on the DendiNet-format US06 loader; the first return value
# (presumably the loss under `mse`) is discarded.
(_, trans_US06_target_log, trans_US06_pred_log) = validate_DendiNet(
    transformer_model,
    US06_test_loader_dendinet,
    mse,
    device=device,
)

In [9]:
# DendiNet on the DendiNet-format US06 loader; the first return value
# (presumably the loss under `mse`) is discarded.
(_, DendiNet_US06_target_log, DendiNet_US06_pred_log) = validate_DendiNet(
    DendiNet_model,
    US06_test_loader_dendinet,
    mse,
    device=device,
)

In [10]:
# Bundle the US06 targets and every model's predictions into one dict and
# save it with torch.save under visualization/test_log_save/ (presumably read
# by the plotting code).
# The recurrent baselines share one target log ('target'). The transformer and
# DendiNet outputs come from the DendiNet-format loader, so they keep their own
# targets, flattened with reshape(-1).
US06_test_log = {'target': US06_target_log,
                 'target4trans': trans_US06_target_log.reshape(-1),
                 'target4dendinet': DendiNet_US06_target_log.reshape(-1),
                 'GRU': GRU_US06_pred_log,
                 'biGRU': biGRU_US06_pred_log,
                 'LSTM': LSTM_US06_pred_log,
                 'biLSTM': biLSTM_US06_pred_log,
                 'RNN': RNN_US06_pred_log,
                 'biRNN': biRNN_US06_pred_log,
                 'transformer': trans_US06_pred_log.reshape(-1),
                 'DendiNet': DendiNet_US06_pred_log.reshape(-1)
                 }

# torch.save does not create missing parent directories, so make sure the
# output folder exists (no-op if it already does).
test_log_dir = os.path.join('visualization', 'test_log_save')
os.makedirs(test_log_dir, exist_ok=True)

US06_filename = 'US06_-10.pt'
torch.save(US06_test_log, os.path.join(test_log_dir, US06_filename))

### HWFTa

In [11]:
# Evaluate every recurrent baseline on the same HWFTa test loader.
# Only the GRU pass's targets are kept as the shared ground truth.
# NOTE(review): this assumes HWFTa_test_loader yields samples in the same order
# on every pass (no shuffling) — TODO confirm; otherwise the other models'
# predictions would not line up with HWFTa_target_log.
_, HWFTa_target_log, GRU_HWFTa_pred_log = validate(
    GRU_model, HWFTa_test_loader, rmse, device=device
)
_, _, biGRU_HWFTa_pred_log = validate(
    biGRU_model, HWFTa_test_loader, rmse, device=device
)
_, _, LSTM_HWFTa_pred_log = validate(
    LSTM_model, HWFTa_test_loader, rmse, device=device
)
_, _, biLSTM_HWFTa_pred_log = validate(
    biLSTM_model, HWFTa_test_loader, rmse, device=device
)
_, _, RNN_HWFTa_pred_log = validate(
    RNN_model, HWFTa_test_loader, rmse, device=device
)
_, _, biRNN_HWFTa_pred_log = validate(
    biRNN_model, HWFTa_test_loader, rmse, device=device
)

In [12]:
# Transformer on the DendiNet-format HWFTa loader; the first return value
# (presumably the loss under `mse`) is discarded.
(_, trans_HWFTa_target_log, trans_HWFTa_pred_log) = validate_DendiNet(
    transformer_model,
    HWFTa_test_loader_dendinet,
    mse,
    device=device,
)

In [13]:
# DendiNet on the DendiNet-format HWFTa loader; the first return value
# (presumably the loss under `mse`) is discarded.
(_, DendiNet_HWFTa_target_log, DendiNet_HWFTa_pred_log) = validate_DendiNet(
    DendiNet_model,
    HWFTa_test_loader_dendinet,
    mse,
    device=device,
)

In [14]:
# Bundle the HWFTa targets and every model's predictions into one dict and
# save it with torch.save under visualization/test_log_save/ (presumably read
# by the plotting code).
# The recurrent baselines share one target log ('target'). The transformer and
# DendiNet outputs come from the DendiNet-format loader, so they keep their own
# targets, flattened with reshape(-1).
HWFTa_test_log = {'target': HWFTa_target_log,
                  'target4trans': trans_HWFTa_target_log.reshape(-1),
                  'target4dendinet': DendiNet_HWFTa_target_log.reshape(-1),
                  'GRU': GRU_HWFTa_pred_log,
                  'biGRU': biGRU_HWFTa_pred_log,
                  'LSTM': LSTM_HWFTa_pred_log,
                  'biLSTM': biLSTM_HWFTa_pred_log,
                  'RNN': RNN_HWFTa_pred_log,
                  'biRNN': biRNN_HWFTa_pred_log,
                  'transformer': trans_HWFTa_pred_log.reshape(-1),
                  'DendiNet': DendiNet_HWFTa_pred_log.reshape(-1)
                  }

# torch.save does not create missing parent directories, so make sure the
# output folder exists (no-op if it already does).
test_log_dir = os.path.join('visualization', 'test_log_save')
os.makedirs(test_log_dir, exist_ok=True)

HWFTa_filename = 'HWFTa_-10.pt'
torch.save(HWFTa_test_log, os.path.join(test_log_dir, HWFTa_filename))