In [1]:
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import os
import pandas as pd
import pickle
import re
import seaborn as sns

# Apply the project's Matplotlib style sheet, then override the x tick label size.
plt.style.use('./stylefiles/plot.mplstyle')
mpl.rcParams.update({'xtick.labelsize': 10})

In [2]:
# Root directory of the pickled results: one sub-folder per (N, M) run configuration.
PATH = "./results/base/"

In [3]:
# Load every pickled run under PATH into one frame with one row per run and
# columns M, N, MRSEcov, MRSEpsplines, MRSEgram.
# NOTE(review): pickle.load can execute arbitrary code -- only point PATH at
# result files produced by this project.

def _parse_folder_name(folder):
    """Return ``(M, N)`` encoded in a results folder name.

    The name is split on ``_``. The first field minus its leading character is
    the integer ``N``; the second field minus its leading character is the
    grid label ``M``, kept as a string. Presumably names look like
    ``N<int>_M<grid>`` -- TODO confirm.
    """
    fields = folder.split("_")
    return fields[1][1:], int(fields[0][1:])


def _load_mrse(file_path):
    """Return the three MRSE values stored in one pickled result file."""
    with open(file_path, "rb") as f:
        data = pickle.load(f)
    errors = data['errors']
    return {
        'MRSEcov': errors['MRSE_cov'],
        'MRSEpsplines': errors['MRSE_psplines'],
        'MRSEgram': errors['MRSE_gram'],
    }


def load_results(path):
    """Collect the MRSE of every run stored under ``path`` into a DataFrame.

    Each non-hidden sub-directory of ``path`` holds the pickled runs for one
    ``(N, M)`` configuration. Hidden entries (names starting with ``.``, such
    as ``.DS_Store``) are skipped at both levels. At the top level only
    directories are read; inside them only regular files are read. Entries
    are visited in sorted order so the row order is reproducible.

    Returns:
        DataFrame with columns M, N, MRSEcov, MRSEpsplines, MRSEgram and a
        fresh RangeIndex.

    Raises:
        ValueError: if no result file is found under ``path``.
    """
    records = []
    for folder in sorted(os.listdir(path)):
        folder_path = os.path.join(path, folder)
        if folder.startswith('.') or not os.path.isdir(folder_path):
            continue
        M, N = _parse_folder_name(folder)
        for file in sorted(os.listdir(folder_path)):
            file_path = os.path.join(folder_path, file)
            if file.startswith('.') or not os.path.isfile(file_path):
                continue
            records.append({'M': M, 'N': N, **_load_mrse(file_path)})
    if not records:
        raise ValueError(f"No result files found under {path!r}")
    # One construction from a list of records instead of concatenating many
    # one-row frames.
    return pd.DataFrame.from_records(records)


results = load_results(PATH)

In [4]:
# Order in which the grid configurations (M) appear when sorting and faceting.
SORT_VALUES = ['11-11-21', '26-26-51', '101-51-201']

# Long format: one row per (run, method). The method's column name goes in
# `variable` and its MRSE goes in `value`, which is what the faceted plot
# consumes.
results_pp = pd.melt(
    results,
    id_vars=['M', 'N'],
    value_vars=['MRSEcov', 'MRSEpsplines', 'MRSEgram'],
)

# Any M missing from SORT_VALUES would silently become NaN below and drop out
# of the figure, so fail loudly instead.
unknown_M = set(results_pp['M'].unique()) - set(SORT_VALUES)
assert not unknown_M, f"M values missing from SORT_VALUES: {sorted(unknown_M)}"

# An ordered categorical makes sort_values (and seaborn's row facets) follow
# SORT_VALUES rather than lexical string order.
results_pp['M'] = pd.Categorical(results_pp['M'], categories=SORT_VALUES, ordered=True)
results_pp = results_pp.sort_values(by=['N', 'M'])

In [7]:
# Display label for each MRSE column, in the order the boxes are drawn
# on the y axis.
METHOD_LABELS = {
    'MRSEcov': '(Tensor) PCA',
    'MRSEpsplines': '2D/1D B-Splines',
    'MRSEgram': 'Gram',
}

# Horizontal box plots of MRSE per method, faceted by N (columns) and M (rows).
gg = sns.catplot(
    data=results_pp,
    x="value", y="variable",
    col="N", row="M",
    kind="box",
    # Pin the box order so the tick labels set below are guaranteed to match.
    # Without it, seaborn uses order of first appearance in the data.
    order=list(METHOD_LABELS),
    # NOTE(review): "+" is an unfilled marker, so markerfacecolor likely has
    # no visible effect; markeredgecolor may have been intended -- confirm.
    flierprops=dict(marker="+", markerfacecolor="gray", markersize=1),
    fill=False,
    color="#111111",
    height=2,
    aspect=2,
)
# Raw string: `\:` is a LaTeX medium space, not a Python escape sequence.
gg.set_titles(template=r"$N = {col_name} \:|\: M = {row_name}$", size=12)
gg.set(xscale="log", xlim=(1e-3, 1))
gg.set_xlabels("MRSE (log scale)", fontsize=10)
gg.set_ylabels("")
gg.set_yticklabels(list(METHOD_LABELS.values()), size=10)
gg.figure.tight_layout()

# Save and close the grid's own figure instead of pyplot's implicit
# "current figure".
gg.figure.savefig(
    'MRSE.eps',
    format='eps',
    bbox_inches='tight'
)
plt.close(gg.figure)