In [1]:
import numpy as np
import torch
import torch.nn as nn
import pickle
import pandas as pd

from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split 

from utils.dataset import MatricesDataset
from utils.models import CNN, DeepCNN
from utils.train_val_test import test

In [16]:
# Evaluation settings shared by the cells below.
config = {
    'dataset':{
        # NOTE(review): 'class_label' and 'capacity' are not read by any cell
        # visible here; presumably consumed inside utils — confirm.
        'class_label': 'original',
        'capacity': 'full',
        'test_size': 0.4,    # forwarded to train_test_split as test_size
        'test_batch': 2048,  # DataLoader batch size used during evaluation
    },
    'model':{
        # Prefer the GPU, but fall back to the CPU so the notebook still runs
        # on machines without CUDA instead of failing at `.to('cuda')`.
        'device': 'cuda' if torch.cuda.is_available() else 'cpu'
    }
}

# One instance per architecture; weights are loaded later, per target.
model_cnn = CNN().to(config['model']['device'])
model_deepcnn = DeepCNN().to(config['model']['device'])
# MLP evaluation is disabled: neither `MLP` nor `model_config` is imported or
# defined in the cells shown in this notebook.
#model_mlp = MLP(model_config).to(config['model']['device'])

model_list = [model_cnn, model_deepcnn]
# Target columns evaluated one at a time (one checkpoint per model/target pair).
targets_list = ['h11', 'h21', 'h31', 'h22']
# Names used to index the last axis of the results table.
metrics_name_list = ['loss','mse','accuracy','balanced_accuracy','f1','precision','recall']

In [3]:
# Load the padded matrices from disk.
# NOTE(review): pickle.load can execute arbitrary code while unpickling —
# only open files from a trusted source.
with open('data/padded_matrices', 'rb') as f:
    matrices = pickle.load(f)

# Read the table and turn the literal 'Null' placeholder into real NaNs.
df = pd.read_csv('data/cicy4folds_extended.csv').replace('Null', np.nan)

# Index labels of every row that has at least one missing value.
nan_indeces = df[df.isna().any(axis=1)].index.tolist()

# Remove the same rows from both containers so they stay aligned.
# np.delete interprets the labels as positions; this matches because the
# frame keeps the default integer index it was read with.
matrices_clear = np.delete(matrices, nan_indeces, axis=0)
df_clear = df.drop(index=nan_indeces)

assert len(matrices_clear) == len(df_clear)

In [4]:
# Keep only the held-out portion; the training portions are discarded here.
# A fixed random_state makes the shuffled split repeatable across kernel
# restarts, so the reported metrics are reproducible.
# NOTE(review): the seed must match the one used when the checkpoints were
# trained, otherwise rows seen during training can end up in this test set —
# confirm against the training code.
_, df_test, _, matrices_test = train_test_split(
    df_clear,
    matrices_clear,
    test_size=config['dataset']['test_size'],
    shuffle=True,
    random_state=config['dataset'].get('seed', 42),
)

In [6]:
criterion = nn.MSELoss()

# results_table[i, j, k] holds metric k of model i evaluated on target j.
results_table = torch.zeros((len(model_list), len(targets_list), len(metrics_name_list)))

for i, model in enumerate(model_list):
    # The class name is part of the checkpoint file name; it does not change per target.
    model_name = model.__class__.__name__
    for j, target_name in enumerate(targets_list):
        print(f'Model {model_name} on {target_name} hodge number:' )
        test_ds = MatricesDataset(df_test, matrices_test, target_name)
        test_dataloader = DataLoader(test_ds, batch_size=config['dataset']['test_batch'], num_workers=4)

        # map_location remaps the stored tensors onto the configured device, so
        # a checkpoint written from a GPU can still be loaded elsewhere.
        state_dict = torch.load(
            f"models/{model_name}_for_{target_name}.pth",
            map_location=config['model']['device'],
        )
        model.load_state_dict(state_dict)

        test_metrics = test(model, criterion, test_dataloader, config)

        # NOTE(review): relies on test() returning its metrics in the same
        # order as metrics_name_list — confirm against utils.train_val_test.
        results_table[i,j,:] = torch.tensor(list(test_metrics.values()))

Model CNN on h11 hodge number:


100%|██████████| 177/177 [00:22<00:00,  7.98it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Model CNN on h21 hodge number:


100%|██████████| 177/177 [00:20<00:00,  8.52it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Model CNN on h31 hodge number:


100%|██████████| 177/177 [00:20<00:00,  8.54it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Model CNN on h22 hodge number:


100%|██████████| 177/177 [00:20<00:00,  8.45it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Model DeepCNN on h11 hodge number:


100%|██████████| 177/177 [00:25<00:00,  7.07it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Model DeepCNN on h21 hodge number:


100%|██████████| 177/177 [00:25<00:00,  6.93it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Model DeepCNN on h31 hodge number:


100%|██████████| 177/177 [00:25<00:00,  6.94it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Model DeepCNN on h22 hodge number:


100%|██████████| 177/177 [00:25<00:00,  6.88it/s]
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [43]:
metric_name = 'mse' #'loss','mse','accuracy','balanced_accuracy','f1','precision','recall'

# Slice out the chosen metric for every (model, target) pair and widen it to
# float64 before rounding: rounding float32 values leaves representation
# noise in the rendered table (e.g. 401.269989 instead of 401.27).
metric_values = results_table[:, :, metrics_name_list.index(metric_name)].double().numpy()

# Rows are models, columns are targets; the bare expression renders the table.
pd.DataFrame(
    metric_values,
    columns=targets_list,
    index=[model.__class__.__name__ for model in model_list],
).round(2)

Unnamed: 0,h11,h21,h31,h22
CNN,0.22,1.11,12.69,401.269989
DeepCNN,0.04,0.33,3.2,72.260002
