This script creates confusion matrices from the raw model predictions saved during training. For a simple overview across folds, a single figure containing the five per-fold matrices side by side is saved for each model configuration and task.

Three variants are produced: absolute counts, counts normalized by the total number of examples, and counts normalized by the class support (the number of true examples of each class in the dataset). Normalized values are reported as percentages. The last variant is also shown in the Appendix.

In [1]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from datetime import date
import os

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

from eval_helper_functions import get_task_path_dict, get_pred_label_files, get_pred_label_files_MT

In [2]:

# Root directory holding the saved model outputs, relative to this notebook.
result_collection = os.path.join(os.pardir, "model_output")

# Model identifiers as they appear in the output filenames.
# note: KB-BERT was abbreviated to BERT in the filenames
model_types = [
    "LOGR",
    "CNN",
    "HISAN",
    "BERT",
    "MTCNN",
    "MTHISAN",
    "MTBERT",
]

# Task identifiers for which confusion matrices are plotted.
tasks = [
    "mor",
    "sit2",
    "sit3",
    "his",
    "beh",
    "sit",
]

# Task-group identifiers passed to the helper for the "MT" model types.
mt_tasks = [
    "morsit",
    "behhissit",
]

In [3]:
# Plot one figure per (normalization mode, model, cw flag, task). Each figure
# holds the confusion matrices of all folds side by side, drawn on a shared
# colour scale so that the single colorbar is valid for every panel.
#
# Output layout:
#   ../plots/<date>_confusion_matrix_norm_<mode>/<model>/<cw_flag>_<model>_<task>.png

filename = "pred_true_labels"
n_folds = 5

# Evaluate the date once so that a run crossing midnight does not scatter its
# output over two differently named directories.
run_date = date.today()

# Colorbar caption per normalization mode. Normalized matrices are multiplied
# by 100 below, so their values are percentages rather than counts.
colorbar_labels = {
    None: "Counts",
    "all": "Share of all examples (%)",
    "true": "Share of true-class examples (%)",
}

for normalize_suffix in [None, "all", "true"]:

    output_dir = os.path.join("..", "plots", f"{run_date}_confusion_matrix_norm_{normalize_suffix}")
    # makedirs also creates a missing "../plots" parent and tolerates reruns.
    os.makedirs(output_dir, exist_ok=True)

    for model in model_types:
        for cw_flag in [True, False]:
            if "MT" not in model:
                task_dict = get_task_path_dict(result_collection, model_type=model, tasks=tasks)
                pred_true_label_files = get_pred_label_files(task_path_dict=task_dict,
                                                             pred_true_labels_filename=f"{filename}.csv",
                                                             cw_flag=cw_flag)
            else:
                task_dict = get_task_path_dict(result_collection, model_type=model, tasks=mt_tasks)
                pred_true_label_files = get_pred_label_files_MT(task_path_dict=task_dict,
                                                                pred_true_labels_filename=f"{filename}.csv",
                                                                cw_flag=cw_flag)

            print(f"Creating plots for model {model}, (cws {cw_flag}), tasks {list(pred_true_label_files.keys())}")

            cw_suffix = "(+cw)" if cw_flag else "(-cw)"

            for task in tasks:

                if task not in pred_true_label_files:
                    continue

                # Title prefix: "MT" models get a subscript 2 for the tasks
                # "sit2"/"mor" and a subscript 3 for all other tasks.
                if "MT" in model:
                    subscript = r"$_2$" if task in ["sit2", "mor"] else r"$_3$"
                    model_name = f"{model}" + subscript
                else:
                    model_name = model

                # Smaller cell text for the tasks "mor" and "his"
                # (adjusted based on the number of classes).
                fontsize_values = 5 if task in ("mor", "his") else 10

                # Load the predictions of every fold up front.
                fold_outputs = {
                    fold: pd.read_csv(pred_true_label_files[task][str(fold)], index_col=0)
                    for fold in range(1, n_folds + 1)
                }

                # Use one label set for all folds: the union of true and
                # predicted labels. Without an explicit label list,
                # confusion_matrix infers the labels per fold, so matrix shape
                # and tick labels could differ between folds and disagree with
                # labels taken from a single fold's true column.
                class_labels = sorted(
                    set().union(*(
                        set(fold_df["labels_true_alph"]) | set(fold_df["labels_pred_alph"])
                        for fold_df in fold_outputs.values()
                    ))
                )

                # create dict with confusion matrices for all folds
                conf_matrices_5x = dict()
                for fold, pred_output in fold_outputs.items():
                    conf_matrix = confusion_matrix(y_true=pred_output["labels_true_alph"],
                                                   y_pred=pred_output["labels_pred_alph"],
                                                   labels=class_labels,
                                                   normalize=normalize_suffix)

                    if normalize_suffix:
                        # express normalized values as percentages
                        conf_matrix = (conf_matrix * 100).round(2)
                    conf_matrices_5x[fold] = conf_matrix

                fig, axes = plt.subplots(1, n_folds, figsize=(28, 7),
                                         gridspec_kw={'width_ratios': [1] * n_folds, 'wspace': 0.3})

                # Shared colour range across folds; applied to every panel below.
                vmin = min(cm.min() for cm in conf_matrices_5x.values())
                vmax = max(cm.max() for cm in conf_matrices_5x.values())
                text_threshold = (vmax + vmin) / 2.0

                for ax, (fold, cm) in zip(axes, conf_matrices_5x.items()):

                    display = ConfusionMatrixDisplay(confusion_matrix=cm,
                                                     display_labels=class_labels)

                    image = display.plot(ax=ax, include_values=True, cmap="magma",
                                         xticks_rotation="vertical",
                                         colorbar=False)

                    # Put this panel on the shared scale so the figure-level
                    # colorbar describes all panels, not only the last one.
                    image.im_.set_clim(vmin, vmax)

                    # Title for each plot
                    ax.set_title(f"{model_name} {cw_suffix} - Fold {fold}")

                    # The cell text colour was chosen against the per-panel
                    # range; re-pick it against the shared range so the text
                    # stays readable on the rescaled background.
                    cmap = image.im_.cmap
                    for text, value in zip(image.text_.ravel(), cm.ravel()):
                        text.set_fontsize(fontsize_values)
                        text.set_color(cmap(1.0) if value < text_threshold else cmap(0.0))

                cbar = fig.colorbar(image.im_, ax=axes, location='right', pad=0.05)
                cbar.set_label(colorbar_labels[normalize_suffix])

                model_dir = os.path.join(output_dir, model)
                os.makedirs(model_dir, exist_ok=True)
                fig.savefig(os.path.join(model_dir, f"{cw_flag}_{model}_{task}.png"))
                # Close this specific figure to release its memory.
                plt.close(fig)
                

Creating plots for model LOGR, (cws True), tasks ['mor', 'his', 'beh', 'sit']
Creating plots for model LOGR, (cws False), tasks ['mor', 'his', 'beh', 'sit']
Creating plots for model CNN, (cws True), tasks ['mor', 'his', 'beh', 'sit']
Creating plots for model CNN, (cws False), tasks ['mor', 'his', 'beh', 'sit']
Creating plots for model HISAN, (cws True), tasks ['mor', 'his', 'beh', 'sit']
Creating plots for model HISAN, (cws False), tasks ['mor', 'his', 'beh', 'sit']
Creating plots for model BERT, (cws True), tasks ['mor', 'his', 'beh', 'sit']
Creating plots for model BERT, (cws False), tasks ['mor', 'his', 'beh', 'sit']
Creating plots for model MTCNN, (cws True), tasks ['sit2', 'mor', 'sit3', 'his', 'beh']
Creating plots for model MTCNN, (cws False), tasks ['sit2', 'mor', 'sit3', 'his', 'beh']
Creating plots for model MTHISAN, (cws True), tasks ['sit2', 'mor', 'sit3', 'his', 'beh']
Creating plots for model MTHISAN, (cws False), tasks ['sit2', 'mor', 'sit3', 'his', 'beh']
Creating plots