In [4]:
from models.cnn_lstm import CNN_LSTM
from models.cnn import CNN
from models.lstm import LSTM
from models.mlp import MLP
from utils import load_model, DATA_DIR, BATCH_SIZE, NUM_WORKER
from loader.fi_loader import FIDataset

import torch
from torch.utils.data import DataLoader

import os
from itertools import product

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

# Generate Testing Statistics

In [5]:
testing_results = {
    'model_type': [],
    'train_data_size': [],
    'prediction_horizon': [],
    'accuracy': [],
    'precision': [],
    'recall': [],
    'f1': []
}

In [6]:
for model_name, model_type in zip(['CNN_LSTM', 'CNN', 'LSTM', 'MLP'], [CNN_LSTM, CNN, LSTM, MLP]):
    model_path = os.path.join('.', 'trained_models', model_name)

    for cf, k in product([1, 3, 5, 8], [0, 2, 4]):

        trained_model_path = os.path.join(model_path, f'CNN_LSTM_Zscore_CF{cf}_pred_{k}.pth')
        print(trained_model_path)

        trained_model = load_model(model_type, trained_model_path)

        test_data = FIDataset(DATA_DIR, 'Zscore', cf, k=k, train=False)
        test_loader = DataLoader(dataset=test_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKER)

        trained_model.eval()
        trained_model.to(trained_model.device)

        all_predictions = []
        all_targets = []

        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(trained_model.device, dtype=torch.float32), targets.to(trained_model.device, dtype=torch.int64)
                outputs = trained_model(inputs)
                _, predictions = torch.max(outputs, 1)

                all_predictions.extend(predictions.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

        accuracy = accuracy_score(all_targets, all_predictions)
        precision = precision_score(all_targets, all_predictions, average='weighted', zero_division=0)
        recall = recall_score(all_targets, all_predictions, average='weighted', zero_division=0)
        f1 = f1_score(all_targets, all_predictions, average='weighted', zero_division=0)

        testing_results['model_type'].append(model_name)
        testing_results['train_data_size'].append(cf)
        testing_results['prediction_horizon'].append(k)
        testing_results['accuracy'].append(accuracy)
        testing_results['precision'].append(precision)
        testing_results['recall'].append(recall)
        testing_results['f1'].append(f1)
        print(testing_results)

./trained_models/CNN_LSTM/CNN_LSTM_Zscore_CF1_pred_0.pth
{'model_type': [<class 'models.cnn_lstm.CNN_LSTM'>], 'train_data_size': [1], 'prediction_horizon': [0], 'accuracy': [0.5661914460285132], 'precision': [0.5661914460285132], 'recall': [0.5661914460285132], 'f1': [0.468106838291607]}
./trained_models/CNN_LSTM/CNN_LSTM_Zscore_CF1_pred_2.pth
{'model_type': [<class 'models.cnn_lstm.CNN_LSTM'>, <class 'models.cnn_lstm.CNN_LSTM'>], 'train_data_size': [1, 1], 'prediction_horizon': [0, 2], 'accuracy': [0.5661914460285132, 0.3798892892579247], 'precision': [0.5661914460285132, 0.3798892892579247], 'recall': [0.5661914460285132, 0.3798892892579247], 'f1': [0.468106838291607, 0.3795995871420431]}
./trained_models/CNN_LSTM/CNN_LSTM_Zscore_CF1_pred_4.pth


Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1076dc9a0>
Traceback (most recent call last):
  File "/Users/alexzhang/Desktop/Academic/4B/CS 480/.venv/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1618, in __del__
    self._shutdown_workers()
  File "/Users/alexzhang/Desktop/Academic/4B/CS 480/.venv/lib/python3.13/site-packages/torch/utils/data/dataloader.py", line 1582, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/multiprocessing/popen_fork.py", line 41, in wait
    if not wait([self.sentinel], timeout):
  File "/Library/Frameworks/Python.framework/Versions/3.13/lib/python3.13/multiprocessing/connection.py", line 1148, in wait
    ready = selector.select(timeout)
  File "/Library/F