In [100]:
import pandas as pd
import numpy as np
import os
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.colors
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", [[1, 1, 1],[31/255,119/255,180/255]])
from sklearn.metrics import r2_score, mean_absolute_percentage_error, mean_absolute_error

In [101]:
def plot_error_per_cat(y_test, y_pred, show_strip=True, show_relative_error=False, 
                       log_scaled_y=False, **kwargs):
    """
    Plots error by using Boxplots. Predicitons where automatically binned on True values.
    
    Relative Error (%) | $\epsilon = \frac{\hat{y}-y}{y}*100$
    
    params:
    ----------
    y_true: np.ndarray
        True values    
        
    y_pred: np.ndarray
        predicted values
        
    show_strip: Bool
        show stripplot if true else not
    
    show_relative_error: Bool
        Shows error on a relative scale which makes it more comparable 
        to bigger sized grains
    
    log_scaled_y: Bool
        Log Scales the y-axis to have a better scaling for lower sizes.
        
    returns:
    ----------
    plt.plot
    
    """
    fig = plt.subplots(figsize=(10, 5))
    sns.set_palette(kwargs.get('cmap') if kwargs.get('cmap') else 'Greys', 15)
    sns.set_style('whitegrid')
    if show_relative_error:
        p = sns.boxplot(y=((y_pred-y_test)/y_test)*100, x=y_test, color='#808080',
                        showfliers=kwargs.get('showfliers') if kwargs.get('showfliers') else False, fliersize=2)
    else:
        if log_scaled_y:
            p = sns.boxplot(y=np.log(y_pred), x=np.round(np.log(y_test), 2), color='#808080', 
                            showfliers=kwargs.get('showfliers') if kwargs.get('showfliers') else False, 
                            fliersize=2)
        else:
            p = sns.boxplot(y=y_pred, x=y_test, color='#808080',
                            showfliers=kwargs.get('showfliers') if kwargs.get('showfliers') else False, 
                            fliersize=2)  
    if show_strip:
        if show_relative_error:
            p2 = sns.stripplot(y=((y_pred-y_test)/y_test)*100, x=y_test, alpha=.5, color='grey', size=3)
        else:
            if log_scaled_y:
                p2 = sns.stripplot(y=np.log(y_pred), x=np.round(np.log(y_test), 2), alpha=.5, color='grey', size=3)
            else:
                p2 = sns.stripplot(y=y_pred, x=y_test, alpha=.5, color='grey', size=3)

    p.set_xlabel(r'Target $y$')
    if show_relative_error:
        p.set_ylabel('Relative Error (%)')
    else:
        if log_scaled_y:
            p.set_ylabel('log(Predicted)')
        else:
            p.set_ylabel(r'Prediction $\hat{y}$')
    plt.suptitle(kwargs.get('title') if kwargs.get('title') else 'True vs. Predicted per Category', fontsize=14)
    if show_relative_error:
        p.set_title(r'', fontsize=9)
    if not show_relative_error:
        plt.ylim(0)
    return p

def plot_resid_plot(df):
    cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", [[128/255,128/255,128/255],[31/255,119/255,180/255]], N=10)
    p = sns.jointplot(x=df['size_mm'], y=df['predictions'], kind='resid',
                      color=cmap(7), xlim=(-5, 180), ratio=6,
                      joint_kws=dict(color=cmap(0), lowess=True,
                                     line_kws=dict(linestyle='dashed'),
                                     scatter_kws=dict(marker='1'))
                     )
    p.set_axis_labels(r'$y$', r'$\epsilon$', fontsize=10)
    
    return p

def make_plots(simple_path, complex_path, system):
    df_complex = pd.read_csv(complex_path, sep= ' ')
    df_simple = pd.read_csv(simple_path, sep=' ')
    
    # Simple Figures
    p = plot_error_per_cat(y_test=df_simple['size_mm'], y_pred=df_simple['predictions'], show_strip=False, showfliers=True,
                            title=' ' )
    plt.savefig(f'./plots/{system}_simple_boxplot.png')
    plt.close()
    p = plot_error_per_cat(y_test=df_simple['size_mm'], y_pred=df_simple['predictions'], show_strip=False, showfliers=True,
                           show_relative_error=True, title=' ')
    plt.savefig(f'./plots/{system}_simple_boxplot_relative.png')
    plt.close()
    p = plot_resid_plot(df=df_simple)
    plt.savefig(f'./plots/{system}_simple_residplot.png')
    plt.close()
    
    # Complex Figures
    p = plot_error_per_cat(y_test=df_complex['size_mm'], y_pred=df_complex['predictions'], show_strip=False, showfliers=True,
                            title=' ' )
    plt.savefig(f'./plots/{system}_complex_boxplot.png')
    plt.close()
    p = plot_error_per_cat(y_test=df_complex['size_mm'], y_pred=df_complex['predictions'], show_strip=False, showfliers=True,
                           show_relative_error=True, title=' ')
    plt.savefig(f'./plots/{system}_complex_boxplot_relative.png')
    plt.close()
    p = plot_resid_plot(df=df_complex)
    plt.savefig(f'./plots/{system}_complex_residplot.png')
    plt.close()


In [102]:
create_plots = False

# MPA

In [103]:
if create_plots:
    make_plots(simple_path='./predictions/mpa_simple.csv', complex_path='./predictions/mpa_comnplex.csv', system='mpa')

# SPG

In [104]:
if create_plots:
    make_plots(simple_path='./predictions/spg_simple.csv', complex_path='./predictions/spg_complex.csv', system='spg')

# SPS

In [105]:
if create_plots:
    make_plots(simple_path='./predictions/sps_simple.csv', complex_path='./predictions/sps_complex.csv', system='sps')

## Metrics

In [112]:
def calculate_metrics(prediction_folder: str):
    prediction_files = os.listdir(prediction_folder)
    prediction_files = [file for file in prediction_files if '.csv' in file]
    results = {}
    for i, file in enumerate(prediction_files):
        measuring_system, tmp = file.split('_')
        model_complexity = tmp.split('.')[0]
        
        df = pd.read_csv(prediction_folder + file, sep=' ')
        y_test, y_pred = df['size_mm'], df['predictions']
        results[i] = dict(
            system=measuring_system,
            complexity=model_complexity,
            r2=round(r2_score(y_true=y_test, y_pred=y_pred), 3),
            mae=mean_absolute_error(y_true=y_test, y_pred=y_pred),
            mape=mean_absolute_percentage_error(y_true=y_test, y_pred=y_pred),
        )
    
    return pd.DataFrame(results).T

In [113]:
calculate_metrics(prediction_folder='./predictions/')

Unnamed: 0,system,complexity,r2,mae,mape
0,mpa,complex,0.824,9.700862,0.265074
1,mpa,simple,0.755,11.986124,0.297918
2,spg,complex,0.804,13.982657,0.226537
3,sps,complex,0.841,11.34907,0.272735
4,sps,simple,0.794,12.468267,0.256883
5,spg,simple,0.796,13.632525,0.216108
