In [15]:
import numpy as np
import pandas as pd
import os
from typing import List, Optional
import ast

from results_processing import get_groups, get_dataset, flatten_results

# 1. Data post-processing

In [2]:
# Raw result tables, one CSV per model family (judging by the file names:
# LCN, MLP, multi-modal). Unpacking keeps the same read order as before.
lcn_data, mlp_data, multi_data = (
    pd.read_csv(csv_name)
    for csv_name in (
        'LcnTotalResults.csv',
        'MlpPartialResults.csv',
        'MultiModalResults.csv',
    )
)

# Dataset-group definitions from the project helper; indexed by group name below.
data_groups = get_groups()

In [3]:
# Dataset groups split by task (regression / classification) and by feature
# type. Presumably 'numcat' = numerical + categorical features and
# 'purnum' = purely numerical features; confirm against get_groups().
regression_cat_groups = data_groups['opml_reg_numcat_group']
regression_num_groups = data_groups['opml_reg_purnum_group']
classification_cat_groups = data_groups['opml_class_numcat_group']
classification_num_groups = data_groups['opml_class_purnum_group']

# Each task's full group list is the union of both feature-type groups.
# The classification list must combine the *categorical* and the *numerical*
# groups. Adding the categorical group to itself would drop every purely
# numerical classification dataset and duplicate the mixed ones downstream.
regression_groups = regression_cat_groups + regression_num_groups
classification_groups = classification_cat_groups + classification_num_groups

In [4]:
# Pull the regression runs, then the classification runs, out of each model
# family's results (same call order as one-by-one assignment).
lcn_reg, mlp_reg, multi_reg = (
    get_dataset(results_frame, regression_groups)
    for results_frame in (lcn_data, mlp_data, multi_data)
)
lcn_cls, mlp_cls, multi_cls = (
    get_dataset(results_frame, classification_groups)
    for results_frame in (lcn_data, mlp_data, multi_data)
)

# Stack the per-model frames row-wise into one frame per task; ignore_index
# relabels the rows 0..n-1 instead of keeping the per-model indices.
reg_datasets = pd.concat([lcn_reg, mlp_reg, multi_reg], ignore_index=True)
cls_datasets = pd.concat([lcn_cls, mlp_cls, multi_cls], ignore_index=True)

In [21]:
# Expand the metric fields into flat columns (e.g. `train_metrics.RMSE`),
# one result frame per task.
reg_results, cls_results = (
    flatten_results(task_frame, ['metrics'])
    for task_frame in (reg_datasets, cls_datasets)
)

In [22]:
# Column Index of the flattened regression results (DataFrame.keys() is the columns).
reg_results.keys()

Index(['dataset', 'hyperparameters', 'model', 'train_loss', 'epoch',
       'val_loss', 'test_loss', 'epoch_time', 'train_metrics.RMSE',
       'train_metrics.r2_score', 'train_metrics.se_quant', 'val_metrics.RMSE',
       'val_metrics.r2_score', 'val_metrics.se_quant', 'test_metrics.RMSE',
       'test_metrics.r2_score', 'test_metrics.se_quant',
       'validate_metrics.r2_score', 'validate_metrics.RMSE',
       'validate_metrics.se_quant'],
      dtype='object')

In [23]:
# Column Index of the flattened classification results (DataFrame.keys() is the columns).
cls_results.keys()

Index(['dataset', 'hyperparameters', 'model', 'train_loss', 'val_loss',
       'test_loss', 'epoch', 'epoch_time', 'train_metrics.accuracy_score',
       'train_metrics.roc_auc_score', 'train_metrics.confusion_matrix',
       'val_metrics.accuracy_score', 'val_metrics.roc_auc_score',
       'val_metrics.confusion_matrix', 'test_metrics.accuracy_score',
       'test_metrics.roc_auc_score', 'test_metrics.confusion_matrix',
       'validate_metrics.accuracy_score', 'validate_metrics.confusion_matrix',
       'validate_metrics.roc_auc_score'],
      dtype='object')