# Comparison of test metrics (accuracy, balanced accuracy, living precision and recall) across datasets and models

In [None]:
import glob
import pickle
from pathlib import Path

import pandas as pd
from plotnine import *

List the test-result pickle files, one per run directory under `output/`

In [None]:
# glob returns paths in arbitrary, filesystem-dependent order; sort them so the
# printed listing and the row order of the results table are reproducible.
test_files = sorted(glob.glob('output/*/test_results.pickle'))
print("\n".join(test_files))

Initialize an empty dict of lists (one list per column) to collect the results

In [None]:
# Column name -> list of values; every processed test file appends exactly one
# value to each list, so the lists stay aligned row by row.
results = {
    column: []
    for column in (
        'dataset',
        'weights',
        'model',
        'accuracy',
        'balanced_accuracy',
        'living_precision',
        'living_recall',
    )
}

Loop over the test files, parse the model, weights and dataset from each run's directory name (`<model>_<weights>_<dataset>`), and extract the metric values

In [None]:
for test_file in test_files:
    # The run directory is named "<model>_<weights>_<dataset>". Use pathlib
    # rather than splitting on '/' so this also works with the backslash
    # separators glob returns on Windows.
    condition = Path(test_file).parent.name
    # maxsplit=2 keeps underscores inside the dataset name intact instead of
    # silently truncating it (which would merge distinct datasets and make the
    # later pivots fail on duplicate entries). A directory name with fewer
    # than three parts raises ValueError.
    model, weights, dataset = condition.split('_', 2)

    # NOTE(security): pickle.load can execute arbitrary code; only load
    # trusted result files.
    with open(test_file, 'rb') as fh:
        test_results = pickle.load(fh)

    results['dataset'].append(dataset)
    results['weights'].append(weights.upper())
    results['model'].append(model.upper())
    # A metric missing from a result file is recorded as None.
    for metric in ('accuracy', 'balanced_accuracy', 'living_precision', 'living_recall'):
        results[metric].append(test_results.get(metric))

Convert to a DataFrame and merge the model and weights columns into a single model label (e.g. `CNN_W`)

In [None]:
# Build the table, then fold the weights label into the model label
# ("<MODEL>_<WEIGHTS>") and drop the now-redundant weights column.
df_results = (
    pd.DataFrame(results)
    .assign(model=lambda frame: frame.model + '_' + frame.weights)
    .drop(columns='weights')
)

## Accuracy

In [None]:
# Wide table: one row per model label, one column per dataset.
df_results.pivot(
    index='model',
    columns='dataset',
    values='accuracy',
)

Plot results

In [None]:
# Grouped bar chart: bars dodged side by side per dataset, filled by model label.
(
    ggplot(df_results, aes(x='dataset', y='accuracy', fill='model'))
    + geom_col(stat='identity', position='dodge')
    + labs(
        fill='Model',
        title='Accuracy value per dataset per model',
        x='Dataset',
        y='Accuracy',
    )
    + theme_classic()
    # Grey-scale palette keyed by the combined "<MODEL>_<WEIGHTS>" label.
    + scale_fill_manual(
        values={
            'CNN_NW': 'lightgray',
            'CNN_W': 'darkgray',
            'RF_NW': 'dimgray',
            'RF_W': 'black',
        }
    )
    # Fixed y-axis range [0, 1].
    + ylim(0, 1)
    + theme(
        axis_text_x=element_text(size=10),
        axis_text_y=element_text(size=10),
    )
)

## Balanced accuracy

In [None]:
# Wide table: one row per model label, one column per dataset.
df_results.pivot(
    index='model',
    columns='dataset',
    values='balanced_accuracy',
)

In [None]:
# Grouped bar chart: bars dodged side by side per dataset, filled by model label.
(
    ggplot(df_results, aes(x='dataset', y='balanced_accuracy', fill='model'))
    + geom_col(stat='identity', position='dodge')
    + labs(
        fill='Model',
        title='Balanced accuracy value per dataset per model',
        x='Dataset',
        y='Balanced accuracy',
    )
    + theme_classic()
    # Grey-scale palette keyed by the combined "<MODEL>_<WEIGHTS>" label.
    + scale_fill_manual(
        values={
            'CNN_NW': 'lightgray',
            'CNN_W': 'darkgray',
            'RF_NW': 'dimgray',
            'RF_W': 'black',
        }
    )
    # Fixed y-axis range [0, 1].
    + ylim(0, 1)
    + theme(
        axis_text_x=element_text(size=10),
        axis_text_y=element_text(size=10),
    )
)

## Living precision

In [None]:
# Wide table: one row per model label, one column per dataset.
df_results.pivot(
    index='model',
    columns='dataset',
    values='living_precision',
)

In [None]:
# Grouped bar chart: bars dodged side by side per dataset, filled by model label.
(
    ggplot(df_results, aes(x='dataset', y='living_precision', fill='model'))
    + geom_col(stat='identity', position='dodge')
    + labs(
        fill='Model',
        title='Living precision value per dataset per model',
        x='Dataset',
        y='Living precision',
    )
    + theme_classic()
    # Grey-scale palette keyed by the combined "<MODEL>_<WEIGHTS>" label.
    + scale_fill_manual(
        values={
            'CNN_NW': 'lightgray',
            'CNN_W': 'darkgray',
            'RF_NW': 'dimgray',
            'RF_W': 'black',
        }
    )
    # Fixed y-axis range [0, 1].
    + ylim(0, 1)
    + theme(
        axis_text_x=element_text(size=10),
        axis_text_y=element_text(size=10),
    )
)

## Living recall

In [None]:
# Wide table: one row per model label, one column per dataset.
df_results.pivot(
    index='model',
    columns='dataset',
    values='living_recall',
)

In [None]:
# Grouped bar chart: bars dodged side by side per dataset, filled by model label.
(
    ggplot(df_results, aes(x='dataset', y='living_recall', fill='model'))
    + geom_col(stat='identity', position='dodge')
    + labs(
        fill='Model',
        title='Living recall value per dataset per model',
        x='Dataset',
        y='Living recall',
    )
    + theme_classic()
    # Grey-scale palette keyed by the combined "<MODEL>_<WEIGHTS>" label.
    + scale_fill_manual(
        values={
            'CNN_NW': 'lightgray',
            'CNN_W': 'darkgray',
            'RF_NW': 'dimgray',
            'RF_W': 'black',
        }
    )
    # Fixed y-axis range [0, 1].
    + ylim(0, 1)
    + theme(
        axis_text_x=element_text(size=10),
        axis_text_y=element_text(size=10),
    )
)