In [20]:
from models.cnn_lstm import CNN_LSTM
from models.cnn import CNN
from models.lstm import LSTM
from models.mlp import MLP
from utils import load_model, DATA_DIR, BATCH_SIZE, NUM_WORKER
from loader.fi_loader import FIDataset

import torch
from torch.utils.data import DataLoader

import os
from itertools import product
from tqdm import tqdm

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

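# Generate Testing Statistics

Evaluate each trained model (CNN_LSTM, CNN, LSTM, MLP) on the Z-score test data for every training-data size (CF1, CF3, CF5, CF8) and prediction horizon index (k = 0, 2, 4), recording accuracy and support-weighted precision, recall and F1.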

In [21]:
# Column-oriented accumulator for the evaluation loop: one (initially empty) list per
# result column; each evaluated (model, CF, k) combination appends one entry to every list.
testing_results = {
    column: []
    for column in (
        'model_type',
        'train_data_size',
        'prediction_horizon',
        'accuracy',
        'precision',
        'recall',
        'f1',
    )
}

In [22]:
# Evaluation grid: every architecture x training-data size (CF) x prediction horizon (k).
MODEL_CLASSES = {'CNN_LSTM': CNN_LSTM, 'CNN': CNN, 'LSTM': LSTM, 'MLP': MLP}
TRAIN_DATA_SIZES = [1, 3, 5, 8]
PREDICTION_HORIZONS = [0, 2, 4]
NORMALIZATION = 'Zscore'


def evaluate_model(model, loader):
    """Evaluate `model` on every batch of `loader` and return its classification metrics.

    Args:
        model: trained network exposing a `.device` attribute (as the models returned
            by `load_model` do in this notebook).
        loader: DataLoader yielding `(inputs, targets)` tensor batches.

    Returns:
        dict with keys 'accuracy', 'precision', 'recall' and 'f1'. Precision, recall and
        F1 are support-weighted averages over classes; undefined per-class scores
        (e.g. a class that is never predicted) count as 0 (`zero_division=0`).
    """
    model.eval()
    model.to(model.device)

    all_predictions, all_targets = [], []
    with torch.no_grad():
        for inputs, targets in loader:
            outputs = model(inputs.to(model.device, dtype=torch.float32))
            # Predicted class = index of the max over dim 1 (assumed to be the class dim).
            all_predictions.extend(torch.max(outputs, 1).indices.cpu().numpy())
            # Targets are only consumed by sklearn on the CPU; no device round-trip needed.
            all_targets.extend(targets.to(dtype=torch.int64).cpu().numpy())

    return {
        'accuracy': accuracy_score(all_targets, all_predictions),
        'precision': precision_score(all_targets, all_predictions, average='weighted', zero_division=0),
        'recall': recall_score(all_targets, all_predictions, average='weighted', zero_division=0),
        'f1': f1_score(all_targets, all_predictions, average='weighted', zero_division=0),
    }


# Refuse to run on a stale accumulator: re-running this cell, or running it after the
# save cell has turned `testing_results` into a DataFrame, would not accumulate
# results correctly.
if not isinstance(testing_results, dict) or any(testing_results.values()):
    raise RuntimeError('testing_results must be a dict of empty lists; re-run the cell that initializes it first.')

# Materialize the grid so tqdm knows the total and can show progress and an ETA.
evaluation_grid = list(product(MODEL_CLASSES.items(), TRAIN_DATA_SIZES, PREDICTION_HORIZONS))

for (model_name, model_type), cf, k in tqdm(evaluation_grid):
    trained_model_path = os.path.join(
        '.', 'trained_models', model_name, f'{model_name}_{NORMALIZATION}_CF{cf}_pred_{k}.pth'
    )
    # tqdm.write prints above the progress bar instead of breaking it.
    tqdm.write(trained_model_path)

    trained_model = load_model(model_type, trained_model_path)

    test_data = FIDataset(DATA_DIR, NORMALIZATION, cf, k=k, train=False)
    # Evaluation needs no shuffling; a fixed order keeps runs reproducible.
    test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKER)

    metrics = evaluate_model(trained_model, test_loader)

    row = {'model_type': model_name, 'train_data_size': cf, 'prediction_horizon': k, **metrics}
    for column, value in row.items():
        testing_results[column].append(value)

0it [00:00, ?it/s]

.\trained_models\CNN_LSTM\CNN_LSTM_Zscore_CF1_pred_0.pth


1it [00:10, 10.76s/it]

.\trained_models\CNN_LSTM\CNN_LSTM_Zscore_CF1_pred_2.pth


2it [00:19,  9.51s/it]

.\trained_models\CNN_LSTM\CNN_LSTM_Zscore_CF1_pred_4.pth


3it [00:27,  9.03s/it]

.\trained_models\CNN_LSTM\CNN_LSTM_Zscore_CF3_pred_0.pth


4it [00:35,  8.66s/it]

.\trained_models\CNN_LSTM\CNN_LSTM_Zscore_CF3_pred_2.pth


5it [00:43,  8.44s/it]

.\trained_models\CNN_LSTM\CNN_LSTM_Zscore_CF3_pred_4.pth


6it [00:52,  8.35s/it]

.\trained_models\CNN_LSTM\CNN_LSTM_Zscore_CF5_pred_0.pth


7it [01:00,  8.29s/it]

.\trained_models\CNN_LSTM\CNN_LSTM_Zscore_CF5_pred_2.pth


8it [01:09,  8.42s/it]

.\trained_models\CNN_LSTM\CNN_LSTM_Zscore_CF5_pred_4.pth


9it [01:17,  8.54s/it]

.\trained_models\CNN_LSTM\CNN_LSTM_Zscore_CF8_pred_0.pth


10it [01:28,  9.14s/it]

.\trained_models\CNN_LSTM\CNN_LSTM_Zscore_CF8_pred_2.pth


11it [01:38,  9.41s/it]

.\trained_models\CNN_LSTM\CNN_LSTM_Zscore_CF8_pred_4.pth


12it [01:48,  9.60s/it]

.\trained_models\CNN\CNN_Zscore_CF1_pred_0.pth


13it [01:55,  8.82s/it]

.\trained_models\CNN\CNN_Zscore_CF1_pred_2.pth


14it [02:02,  8.16s/it]

.\trained_models\CNN\CNN_Zscore_CF1_pred_4.pth


15it [02:08,  7.70s/it]

.\trained_models\CNN\CNN_Zscore_CF3_pred_0.pth


16it [02:15,  7.38s/it]

.\trained_models\CNN\CNN_Zscore_CF3_pred_2.pth


17it [02:21,  7.13s/it]

.\trained_models\CNN\CNN_Zscore_CF3_pred_4.pth


18it [02:28,  6.92s/it]

.\trained_models\CNN\CNN_Zscore_CF5_pred_0.pth


19it [02:34,  6.84s/it]

.\trained_models\CNN\CNN_Zscore_CF5_pred_2.pth


20it [02:41,  6.78s/it]

.\trained_models\CNN\CNN_Zscore_CF5_pred_4.pth


21it [02:48,  6.74s/it]

.\trained_models\CNN\CNN_Zscore_CF8_pred_0.pth


22it [02:56,  7.07s/it]

.\trained_models\CNN\CNN_Zscore_CF8_pred_2.pth


23it [03:03,  7.31s/it]

.\trained_models\CNN\CNN_Zscore_CF8_pred_4.pth


24it [03:11,  7.49s/it]

.\trained_models\LSTM\LSTM_Zscore_CF1_pred_0.pth


25it [03:18,  7.38s/it]

.\trained_models\LSTM\LSTM_Zscore_CF1_pred_2.pth


26it [03:25,  7.22s/it]

.\trained_models\LSTM\LSTM_Zscore_CF1_pred_4.pth


27it [03:32,  7.08s/it]

.\trained_models\LSTM\LSTM_Zscore_CF3_pred_0.pth


28it [03:39,  6.97s/it]

.\trained_models\LSTM\LSTM_Zscore_CF3_pred_2.pth


29it [03:46,  6.89s/it]

.\trained_models\LSTM\LSTM_Zscore_CF3_pred_4.pth


30it [03:52,  6.84s/it]

.\trained_models\LSTM\LSTM_Zscore_CF5_pred_0.pth


31it [04:00,  7.17s/it]

.\trained_models\LSTM\LSTM_Zscore_CF5_pred_2.pth


32it [04:08,  7.29s/it]

.\trained_models\LSTM\LSTM_Zscore_CF5_pred_4.pth


33it [04:15,  7.38s/it]

.\trained_models\LSTM\LSTM_Zscore_CF8_pred_0.pth


34it [04:24,  7.65s/it]

.\trained_models\LSTM\LSTM_Zscore_CF8_pred_2.pth


35it [04:32,  7.89s/it]

.\trained_models\LSTM\LSTM_Zscore_CF8_pred_4.pth


36it [04:40,  8.03s/it]

.\trained_models\MLP\MLP_Zscore_CF1_pred_0.pth


37it [04:47,  7.73s/it]

.\trained_models\MLP\MLP_Zscore_CF1_pred_2.pth


38it [04:55,  7.69s/it]

.\trained_models\MLP\MLP_Zscore_CF1_pred_4.pth


39it [05:02,  7.45s/it]

.\trained_models\MLP\MLP_Zscore_CF3_pred_0.pth


40it [05:09,  7.23s/it]

.\trained_models\MLP\MLP_Zscore_CF3_pred_2.pth


41it [05:15,  7.05s/it]

.\trained_models\MLP\MLP_Zscore_CF3_pred_4.pth


42it [05:22,  6.92s/it]

.\trained_models\MLP\MLP_Zscore_CF5_pred_0.pth


43it [05:29,  6.87s/it]

.\trained_models\MLP\MLP_Zscore_CF5_pred_2.pth


44it [05:35,  6.87s/it]

.\trained_models\MLP\MLP_Zscore_CF5_pred_4.pth


45it [05:42,  6.87s/it]

.\trained_models\MLP\MLP_Zscore_CF8_pred_0.pth


46it [05:50,  7.21s/it]

.\trained_models\MLP\MLP_Zscore_CF8_pred_2.pth


47it [05:59,  7.55s/it]

.\trained_models\MLP\MLP_Zscore_CF8_pred_4.pth


48it [06:07,  7.65s/it]


# Save Testing Statistics

In [23]:
import pickle

import pandas as pd

# Turn the column-oriented dict of lists into a table with one row per evaluation run.
# NOTE: this rebinds `testing_results` (dict -> DataFrame); re-initialize the dict
# before re-running the evaluation cell.
testing_results = pd.DataFrame(testing_results)

# Persist the table. The protocol is pinned to pickle's default so the file is what
# `pickle.dump` would write (pandas' to_pickle otherwise uses HIGHEST_PROTOCOL).
testing_results.to_pickle('testing_results.pkl', protocol=pickle.DEFAULT_PROTOCOL)

In [24]:
# Display the full results table (one row per evaluated model / train_data_size / prediction_horizon).
testing_results

Unnamed: 0,model_type,train_data_size,prediction_horizon,accuracy,precision,recall,f1
0,CNN_LSTM,1,0,0.566191,0.458958,0.566191,0.468102
1,CNN_LSTM,1,2,0.379915,0.380346,0.379915,0.379621
2,CNN_LSTM,1,4,0.523683,0.516821,0.523683,0.509413
3,CNN_LSTM,3,0,0.59682,0.495624,0.59682,0.473716
4,CNN_LSTM,3,2,0.605758,0.610327,0.605758,0.585969
5,CNN_LSTM,3,4,0.66125,0.661412,0.66125,0.661034
6,CNN_LSTM,5,0,0.612629,0.523285,0.612629,0.50456
7,CNN_LSTM,5,2,0.685146,0.690854,0.685146,0.675011
8,CNN_LSTM,5,4,0.726705,0.726812,0.726705,0.726628
9,CNN_LSTM,8,0,0.791139,0.789718,0.791139,0.757995
