### Imports

In [None]:
# loading
import os
from glob import glob

# data
import math
import numpy as np
import pandas as pd

# visual
import matplotlib as mpl
import matplotlib.pyplot as plt


# metrics
from sklearn.metrics import classification_report, confusion_matrix

print('All packages imported!')

### Helpers

In [None]:
# Notebook-wide matplotlib defaults: figure resolution and base font size.
mpl.rcParams.update({
    'figure.dpi': 500,
    'font.size': 9,
})

# Text width of the target LaTeX document, in points
latex_width = 469.75502


def set_size(width=latex_width, height=latex_width, fraction=1, subplots=(1, 1)):
    """Return a figure size in inches that needs no rescaling in LaTeX.

    Credit to Jack Walton for the function.
    Source: https://jwalton.info/Embed-Publication-Matplotlib-Latex/

    Args:
        width: Available width in points.
        height: Available height in points.
        fraction: Share of ``width`` and ``height`` the figure should occupy.
        subplots: ``(rows, columns)`` of the subplot grid; the height is
            additionally scaled by ``rows / columns``.

    Returns:
        Tuple ``(width_in, height_in)`` suitable for ``figsize``.
    """
    # Conversion factor from points to inches
    pt_to_in = 1 / 72.27

    width_in = (width * fraction) * pt_to_in
    height_in = (height * fraction) * pt_to_in * (subplots[0] / subplots[1])

    return (width_in, height_in)

In [None]:
def load_task_df(task_number):
    """Load and stack the per-subject result files of one task.

    Every file matching ``ckpt/test/task_<task_number>_*/result.csv``
    (relative to the current working directory) is read with pandas and
    tagged with a ``subject`` column. The subject id is the text after the
    last underscore in the name of the directory containing the file.

    Args:
        task_number: Task identifier; must be 1, 2, 3 or 4.

    Returns:
        One DataFrame holding the rows of all matching files, re-indexed
        from 0, with the extra ``subject`` column.

    Raises:
        ValueError: If ``task_number`` is not between 1 and 4, or if no
            result file matches the pattern.
    """
    if task_number not in range(1, 5):
        raise ValueError("Invalid task number. Please provide a task number between 1 and 4.")

    pattern = f'ckpt/test/task_{task_number}_*/result.csv'
    # glob returns paths in arbitrary order; sort so row order is reproducible
    task_fps = sorted(glob(pattern))
    if not task_fps:
        # Fail with an actionable message instead of pandas' generic
        # "No objects to concatenate"
        raise ValueError(f"No result files found matching '{pattern}'.")

    dfs = []
    for fp in task_fps:
        # Use os.path rather than splitting on '/' so the directory name is
        # extracted correctly whatever separator the platform uses
        run_dir = os.path.basename(os.path.dirname(fp))
        subject = run_dir.split('_')[-1]
        df = pd.read_csv(fp)
        df['subject'] = subject
        dfs.append(df)

    return pd.concat(dfs, ignore_index=True)


def loo_accuracy(df):
    """Compute the accuracy of each subject.

    Args:
        df: DataFrame with ``subject``, ``pred`` and ``true`` columns.

    Returns:
        DataFrame with one row per subject and the columns ``subject`` and
        ``accuracy``, where accuracy is the fraction of rows whose ``pred``
        equals ``true``.
    """
    def _fraction_correct(rows):
        # Mean of the boolean match mask == share of correct predictions
        return (rows['pred'] == rows['true']).mean()

    per_subject = df.groupby(['subject'])[['pred', 'true']]
    acc_df = per_subject.apply(_fraction_correct).reset_index()
    acc_df.columns = ['subject', 'accuracy']
    return acc_df


def filter_outliers(df):
    """Drop the rows of subjects whose accuracy is a low outlier.

    A subject is a low outlier when its accuracy (as computed by
    ``loo_accuracy``) lies below ``Q1 - 1.5 * IQR`` of the per-subject
    accuracies. Only the lower fence is applied; unusually high accuracies
    are kept. The removed subject ids (or a notice that none were removed)
    are printed, followed by a blank line.

    Args:
        df: DataFrame with ``subject``, ``pred`` and ``true`` columns.

    Returns:
        The rows of ``df`` belonging to the subjects that were kept.
    """
    acc_df = loo_accuracy(df)
    q1 = acc_df['accuracy'].quantile(0.25)
    q3 = acc_df['accuracy'].quantile(0.75)
    lower_bound = q1 - 1.5 * (q3 - q1)

    is_kept = acc_df['accuracy'] >= lower_bound
    kept_subjects = acc_df.loc[is_kept, 'subject']
    removed_subjects = acc_df.loc[~is_kept, 'subject']

    if len(removed_subjects) > 0:
        # str.join only accepts strings; convert so numeric subject ids
        # do not raise TypeError
        print(f"Removed subjects: {', '.join(map(str, removed_subjects))}")
    else:
        print("No outliers removed.")
    print()

    return df[df['subject'].isin(kept_subjects)]

In [None]:
def plot_acc_hist(df, lower_bound=0.50, save_fig=False, filename=""):
    """Plot a histogram of per-subject accuracies as proportions of subjects.

    Accuracies come from ``loo_accuracy(df)``. Bins are about 0.025 wide and
    span ``lower_bound`` to 1; each bar height is the share of subjects in
    that bin.

    Args:
        df: DataFrame with the ``subject``, ``pred`` and ``true`` columns
            read by ``loo_accuracy``.
        lower_bound: Left edge of the first bin and of the x-axis.
        save_fig: When true, write the figure to ``filename`` before showing.
        filename: Output path; required when ``save_fig`` is true.

    Raises:
        ValueError: If ``save_fig`` is true and ``filename`` is empty.
    """
    if save_fig and not filename:
        raise ValueError("Please provide a filename to save the figure.")

    acc_df = loo_accuracy(df)

    fig, ax = plt.subplots(figsize=set_size(height=0.5 * latex_width))

    # Build the edges with linspace so the last one is exactly 1.0. With
    # np.arange and a float step the edge nearest 1.0 can fall on either side
    # of it, which could push a perfect accuracy into a bin lying outside the
    # x-limits. The right-most histogram bin is closed, so 1.0 is counted in
    # the last visible bar.
    bin_width = 0.025
    span = 1.0 - lower_bound
    if span > 0:
        upper_bound = 1.0
        n_bins = max(1, int(round(span / bin_width)))
    else:
        # Degenerate request (lower_bound >= 1): fall back to a single bin
        upper_bound = lower_bound + bin_width
        n_bins = 1
    bins = np.linspace(lower_bound, upper_bound, n_bins + 1)

    # Equal weights summing to 1 turn counts into proportions
    weights = np.ones(len(acc_df)) / len(acc_df)
    proportions, _, _ = ax.hist(acc_df['accuracy'], bins=bins, weights=weights, edgecolor='black')
    ax.set_xlim([lower_bound, upper_bound])

    # Keep the common 0-0.3 scale whenever the bars fit, but never clip a
    # taller bar
    tallest = float(np.max(proportions))
    ax.set_ylim([0, max(0.3, 1.05 * tallest)])

    ax.set_xlabel('Accuracy')
    ax.set_ylabel('Proportion')

    if save_fig:
        fig.savefig(filename, bbox_inches='tight')
    plt.show()


def plot_classification_report(df, save_fig=False, filename=""):
    """Plot the row-normalised confusion matrix of ``df`` as an annotated heatmap.

    Each row (true label) is divided by its total, so a cell holds the share
    of that true class predicted as the column's class. Cells of at least
    0.01 are annotated with their value rounded to two decimals.

    Args:
        df: DataFrame with ``true`` and ``pred`` label columns.
        save_fig: When true, write the figure to ``filename`` before showing.
        filename: Output path; required when ``save_fig`` is true.

    Raises:
        ValueError: If ``save_fig`` is true and ``filename`` is empty.
    """
    # Function-scope import: scikit-learn is already a dependency of this file
    from sklearn.utils.multiclass import unique_labels

    if save_fig and not filename:
        raise ValueError("Please provide a filename to save the figure.")

    # Pin the label order explicitly so matrix rows/columns and the tick
    # labels are guaranteed to refer to the same classes
    labels = unique_labels(df['true'], df['pred'])
    counts = confusion_matrix(df['true'], df['pred'], labels=labels).astype(float)

    # A class that is only ever predicted has an all-zero row; leave that row
    # at 0 instead of dividing by zero and producing NaN
    row_totals = counts.sum(axis=1, keepdims=True)
    cm = np.divide(counts, row_totals, out=np.zeros_like(counts), where=row_totals != 0)

    fig, ax = plt.subplots(figsize=set_size())

    mesh = ax.pcolormesh(cm, cmap='viridis', vmin=0, vmax=1, edgecolors='k', linewidth=0.5)
    ax.set_aspect('equal')
    # Put the first class at the top, as in a printed confusion matrix
    ax.invert_yaxis()

    cbar = fig.colorbar(mesh, ax=ax, fraction=0.046, pad=0.04, shrink=0.5)
    cbar.set_label('Proportion', rotation=270)
    cbar.set_ticks(np.arange(0, 1.1, 1))

    ax.set_xlabel('Predicted labels')
    ax.set_ylabel('True labels')

    # Annotate every non-negligible cell; text colour is chosen for contrast
    # against the cell colour
    for (row, col), value in np.ndenumerate(cm):
        if value >= 0.01:
            ax.text(col + 0.5, row + 0.5, round(value, 2),
                    ha='center', va='center',
                    color='black' if value > 0.5 else 'white', fontsize=7)

    # One tick per class, centred on its cell and labelled with the actual
    # class label rather than its position
    tick_positions = np.arange(0.5, len(cm), 1)
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(labels)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(labels)

    if save_fig:
        fig.savefig(filename, bbox_inches='tight')

    plt.show()

### Results

In [None]:
# Set to True to also write each figure to a PDF in the working directory
save_fig = False

for i in range(1, 5):
    # Banner separating the output of each task
    for banner_line in ("##########", f"# TASK {i} #", "##########"):
        print(banner_line)

    task_df = load_task_df(i)
    acc_df = loo_accuracy(task_df)
    # NOTE(review): the filtered frame is not used below, so every statistic
    # and plot is computed on the unfiltered data; the call only reports which
    # subjects would be removed. Confirm this is intended.
    task_clean_df = filter_outliers(task_df)

    print(acc_df.describe())
    print(classification_report(task_df['true'], task_df['pred']))

    # Start the histogram at the lowest accuracy rounded down to one decimal
    lowest_accuracy = acc_df['accuracy'].min()
    lower_bound = math.floor(lowest_accuracy * 10.0) / 10.0

    plot_acc_hist(
        task_df,
        lower_bound=lower_bound,
        save_fig=save_fig,
        filename=f'task_{i}_acc_hist.pdf',
    )
    plot_classification_report(
        task_df,
        save_fig=save_fig,
        filename=f'task_{i}_class_report.pdf',
    )