In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.table import Table
from utils_eval import color_helper
import os

from utils_eval import extract_metrics

In [3]:
def add_to_plot(axs, data, color, name):
    """Draw one run's train/validation curves on the two top-row panels.

    Args:
        axs: Grid of axes indexable as ``axs[row, col]``; loss curves go to
            ``axs[0, 0]`` and F1 curves to ``axs[0, 1]``.
        data: Mapping that provides the series under the keys
            'train_loss', 'dev_loss', 'train_f1' and 'dev_f1'.
        color: Colour used for every curve of this run.
        name: Run name appended to each legend label.

    Returns:
        The same ``axs`` object that was passed in.
    """
    # (panel position, key in `data`, legend prefix, extra line style)
    # Validation curves share the run colour but are drawn dashed.
    curves = (
        ((0, 0), 'train_loss', 'Train Loss', {}),
        ((0, 0), 'dev_loss', 'Validation Loss', {'linestyle': '--'}),
        ((0, 1), 'train_f1', 'Train F1', {}),
        ((0, 1), 'dev_f1', 'Validation F1', {'linestyle': '--'}),
    )
    for position, key, prefix, extra_style in curves:
        axs[position].plot(data[key], label=f'{prefix} {name}', color=color, **extra_style)
    return axs

In [None]:
def add_table(axs, all_best_metric, f1, top_n=3):
    """Render a ranked table of the best models into one bottom-row panel.

    Args:
        axs: Grid of axes indexable as ``axs[row, col]``. The F1 ranking is
            drawn in ``axs[1, 1]`` and the loss ranking in ``axs[1, 0]``.
        all_best_metric: Mapping of model name to a dict holding the keys
            'best_f1_dev', 'best_loss_dev' and 'color'.
        f1: If truthy, rank by 'best_f1_dev' (highest first); otherwise rank
            by 'best_loss_dev' (lowest first).
        top_n: Maximum number of models to list.

    Returns:
        The table returned by ``ax.table``, or ``None`` when there is no
        model to list (the panel is still blanked and titled in that case).
    """
    if f1:
        metric_key = 'best_f1_dev'
        column_label = 'Best Dev F1 ----'
        ax = axs[1, 1]
        title = f'Top {top_n} Models by F1 Score'
        higher_is_better = True
    else:
        metric_key = 'best_loss_dev'
        column_label = 'Best Dev Loss ----'
        ax = axs[1, 0]
        title = f'Top {top_n} Models by Loss'
        higher_is_better = False

    ranked = sorted(
        all_best_metric.items(),
        key=lambda item: item[1][metric_key],
        reverse=higher_is_better,
    )
    top_models = ranked[:top_n]

    if not top_models:
        # ax.table cannot build a table from zero rows, so only blank and
        # title the panel instead of crashing.
        ax.axis('off')
        ax.set_title(title, fontsize=30)
        return None

    data = []
    cell_colors = []
    text_colors = []  # one entry per data row, aligned with `data`

    for model_name, metrics in top_models:
        value = metrics[metric_key]
        data.append([f"{model_name}: {value:.3f}"])
        cell_color = metrics['color']
        cell_colors.append([cell_color])
        # Use white text when the cell background is black, black otherwise.
        text_colors.append('#FFFFFF' if cell_color == '#000000' else '#000000')

    # Number only the rows that exist: there may be fewer models than top_n,
    # and ax.table requires exactly one row label per row of cellText.
    row_labels = list(range(1, len(top_models) + 1))

    table = ax.table(cellText=data,
                     rowLabels=row_labels,
                     colLabels=[column_label],
                     loc='upper center',
                     cellColours=cell_colors)

    # Recolour the text of the data cells only. Row 0 is the column header and
    # column -1 holds the row-number labels; both keep their default colours.
    for (row, col), cell in table.get_celld().items():
        if row > 0 and col >= 0:
            cell.get_text().set_color(text_colors[row - 1])

    table.auto_set_font_size(False)
    table.set_fontsize(25)
    table.scale(1, 3)
    ax.axis('off')
    ax.set_title(title, fontsize=30)
    return table

In [4]:
def main(eval_dir_path):
    """Plot curves and ranking tables for every run found under a directory.

    Each immediate sub-directory of ``eval_dir_path`` is treated as one run:
    its metrics are loaded with ``extract_metrics``, its train/validation
    curves are added to the top-row panels, and its best dev F1 / dev loss are
    collected for the two ranking tables in the bottom row. The figure is
    written to ``_multiEval.png`` inside ``eval_dir_path``.

    Args:
        eval_dir_path: Directory whose sub-directories are the runs to compare.

    Returns:
        None.
    """
    fig, axs = plt.subplots(2, 2, figsize=(30, 20))
    try:
        axs[0, 0].set_title('Loss vs Epochs')
        axs[0, 0].set_xlabel('Epochs')
        axs[0, 0].set_ylabel('Loss')

        axs[0, 1].set_title('F1 Score vs Epochs')
        axs[0, 1].set_xlabel('Epochs')
        axs[0, 1].set_ylabel('F1')

        # os.listdir yields entries in arbitrary order; sorting keeps the
        # colour given to each run, the legend order and the tie-breaking in
        # the ranking tables identical from one execution to the next.
        subdirs = sorted(
            entry for entry in os.listdir(eval_dir_path)
            if os.path.isdir(os.path.join(eval_dir_path, entry))
        )
        colors = color_helper(len(subdirs))

        # Per-run best dev metrics plus the colour the run is plotted in.
        all_best_metric = {}

        for i, name in enumerate(subdirs):
            metrics = extract_metrics(basedir=eval_dir_path, subdir=name)
            axs = add_to_plot(axs=axs, data=metrics, color=colors[i], name=name)
            all_best_metric[name] = {
                'best_f1_dev': metrics['best_f1_dev'],
                'best_loss_dev': metrics['best_loss_dev'],
                'color': colors[i],
            }

        fig.legend(loc='center right', bbox_to_anchor=(1.12, 0.7), fontsize='x-large')

        add_table(axs=axs, all_best_metric=all_best_metric, f1=True)
        add_table(axs=axs, all_best_metric=all_best_metric, f1=False)

        # Operate on `fig` explicitly rather than pyplot's "current figure",
        # so a helper that opens its own figure cannot redirect the save.
        fig.tight_layout()
        fig.savefig(os.path.join(eval_dir_path, "_multiEval.png"), bbox_inches='tight', pad_inches=0.5)
    finally:
        # Release the figure even when loading or plotting a run fails.
        plt.close(fig)

In [None]:
# Directory whose immediate sub-directories are each passed to extract_metrics by main().
# Placeholder value: replace it with a real path before running this cell.
eval_dir_path = 'path/to/your/eval/directory'
# Builds the comparison figure and saves it as _multiEval.png inside eval_dir_path.
main(eval_dir_path)