This notebook creates the plots of labelwise F-scores (all five folds) for each model and task, as shown in the Appendix.

In [1]:
from matplotlib import pyplot as plt
import pandas as pd
from datetime import date
import seaborn as sns
import os
from eval_helper_functions import get_task_path_dict, get_pred_label_files, get_pred_label_files_MT


In [2]:
# Root folder with the exported model outputs. The printed flag is a quick
# sanity check that this relative path resolves from the notebook's cwd.
result_collection = os.path.join("..", "model_output")
print(os.path.exists(result_collection))

# Single-task models first, then their multi-task (MT) counterparts.
# KB-BERT appears as "BERT" in all output filenames.
model_types = [
    "LOGR", "CNN", "HISAN", "BERT",
    "MTCNN", "MTHISAN", "MTBERT",
]

# Seaborn palette name per model. MT models are looked up with a "_2" or "_3"
# suffix that depends on the task being plotted (see the plotting cell).
model_color = {
    "LOGR": "Greys",
    # CNN family
    "CNN": "Oranges",
    "MTCNN_2": "YlOrBr",
    "MTCNN_3": "YlOrRd",
    # HISAN family
    "HISAN": "Greens",
    "MTHISAN_2": "YlGn",
    "MTHISAN_3": "BuGn",
    # BERT family
    "BERT": "Purples",
    "MTBERT_2": "RdPu",
    "MTBERT_3": "PuRd",
}

# Individual task names plotted per model, and the task groups used to
# locate the outputs of the MT models.
tasks = ["mor", "sit2", "sit3", "his", "beh", "sit"]
mt_tasks = ["morsit", "behhissit"]

True


In [3]:
# Plots of this run go into a dated folder, e.g.
# ../plots/2024-01-31_f1_labelwise_plots_all_folds.
# makedirs also creates ../plots itself on a fresh checkout (os.mkdir would
# raise FileNotFoundError there) and is a no-op when the cell is re-run.
output_dir = os.path.join("..", "plots", f"{date.today()}_f1_labelwise_plots_all_folds")
os.makedirs(output_dir, exist_ok=True)

In [4]:
# Stem of the per-fold labelwise F1 result files.
filename = "labelwise_f1"

# Five cross-validation folds; the per-task fold files are keyed by str(fold).
FOLDS = range(1, 6)

# For MT models the task selects the multi-task variant that is plotted:
# mor/sit2 use the "_2" palette key and "$_2$" title subscript,
# sit3/beh/his the "_3" ones.
MT_TASK_SUFFIX = {"sit2": "2", "mor": "2", "sit3": "3", "beh": "3", "his": "3"}


def _load_all_folds(fold_files):
    """Stack the labelwise F1 tables of all folds into one DataFrame.

    Args:
        fold_files: mapping from fold number as a string ("1".."5") to the CSV
            path of that fold (the first CSV column is read as the index).

    Returns:
        DataFrame with the rows of every fold plus an int ``Fold`` column,
        which the grouped bar plot uses as ``hue``.
    """
    per_fold = [
        pd.read_csv(fold_files[str(fold)], index_col=0).assign(Fold=fold)
        for fold in FOLDS
    ]
    return pd.concat(per_fold, ignore_index=True)


def _title_and_palette_key(model, task):
    """Return (name shown in the plot title, key into ``model_color``).

    Single-task models use their plain name for both. MT models get the
    suffix from ``MT_TASK_SUFFIX``: a LaTeX subscript in the title and a
    ``_N`` suffix in the palette key.

    Raises:
        ValueError: if ``model`` is an MT model and ``task`` has no suffix.
    """
    if "MT" not in model:
        return model, model
    suffix = MT_TASK_SUFFIX.get(task)
    if suffix is None:
        raise ValueError(f"no multi-task variant defined for task {task!r} of model {model}")
    return f"{model}$_{suffix}$", f"{model}_{suffix}"


def _plot_labelwise_f1(scores, model, task, cw_flag, out_path):
    """Save a grouped bar plot of labelwise F1 scores (one bar per fold and class).

    Args:
        scores: DataFrame from ``_load_all_folds`` with the columns ``labels``,
            ``labelwise_f1_scores`` and ``Fold``.
        model: model type, e.g. "CNN" or "MTCNN".
        task: task name, e.g. "mor".
        cw_flag: cw setting of the run (presumably class weighting — confirm);
            shown as "(+cw)" / "(-cw)" in the title.
        out_path: PNG path; its parent directory is created if missing.

    Raises:
        ValueError: propagated from ``_title_and_palette_key``.
    """
    # Resolve names before creating the figure so a bad task cannot leak one.
    model_name, palette_key = _title_and_palette_key(model, task)
    cw_suffix = "(+cw)" if cw_flag else "(-cw)"
    # sit plots are slightly flatter so they fit on one page in the report
    height = 3 if "sit" in task else 4

    fig, ax = plt.subplots(figsize=(10, height))
    try:
        palette = sns.color_palette(model_color[palette_key], n_colors=len(FOLDS))
        sns.barplot(x="labels", y="labelwise_f1_scores", hue="Fold", data=scores,
                    legend=False, palette=palette, ax=ax)
        ax.set_xlabel("Classes", fontsize=14)
        ax.set_ylabel("Labelwise F-Scores", fontsize=14)
        ax.set_title(f"{model_name} {cw_suffix}", fontsize=16)
        ax.tick_params(axis="x", labelsize=16, labelrotation=45)
        fig.tight_layout()
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        fig.savefig(out_path)
    finally:
        # Always release the figure, even if plotting fails, so figures do not pile up.
        plt.close(fig)


for model in model_types:
    is_mt = "MT" in model
    for cw_flag in [True, False]:
        if is_mt:
            task_dict = get_task_path_dict(result_collection, model_type=model, tasks=mt_tasks)
            # NOTE(review): MT files are requested with a trailing underscore
            # ("labelwise_f1_.csv"); presumably get_pred_label_files_MT completes
            # the name per task — confirm.
            pred_true_label_files = get_pred_label_files_MT(task_path_dict=task_dict,
                                                            pred_true_labels_filename=f"{filename}_.csv",
                                                            cw_flag=cw_flag)
        else:
            task_dict = get_task_path_dict(result_collection, model_type=model, tasks=tasks)
            pred_true_label_files = get_pred_label_files(task_path_dict=task_dict,
                                                         pred_true_labels_filename=f"{filename}.csv",
                                                         cw_flag=cw_flag)

        print(f"Creating plots for model {model}, (cws {cw_flag}), tasks {list(pred_true_label_files.keys())}...")
        # Plot in the order of `tasks`, skipping tasks this model has no results for.
        for task in tasks:
            if task not in pred_true_label_files:
                continue
            combined_df = _load_all_folds(pred_true_label_files[task])
            _plot_labelwise_f1(combined_df, model, task, cw_flag,
                               os.path.join(output_dir, model, f"{cw_flag}_{model}_{task}.png"))

Creating plots for model LOGR, (cws True), tasks ['mor', 'his', 'beh', 'sit']...
Creating plots for model LOGR, (cws False), tasks ['mor', 'his', 'beh', 'sit']...
Creating plots for model CNN, (cws True), tasks ['mor', 'his', 'beh', 'sit']...
Creating plots for model CNN, (cws False), tasks ['mor', 'his', 'beh', 'sit']...
Creating plots for model HISAN, (cws True), tasks ['mor', 'his', 'beh', 'sit']...
Creating plots for model HISAN, (cws False), tasks ['mor', 'his', 'beh', 'sit']...
Creating plots for model BERT, (cws True), tasks ['mor', 'his', 'beh', 'sit']...
Creating plots for model BERT, (cws False), tasks ['mor', 'his', 'beh', 'sit']...
Creating plots for model MTCNN, (cws True), tasks ['sit2', 'mor', 'sit3', 'his', 'beh']...
Creating plots for model MTCNN, (cws False), tasks ['sit2', 'mor', 'sit3', 'his', 'beh']...
Creating plots for model MTHISAN, (cws True), tasks ['sit2', 'mor', 'sit3', 'his', 'beh']...
Creating plots for model MTHISAN, (cws False), tasks ['sit2', 'mor', 'si