In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay
import os
import glob
import json
import numpy as np

In [None]:
def extract_metrics(sudir):
    """Load the metrics stored in a ``*_Full_metrics.json`` file inside ``sudir``.

    Args:
        sudir: Directory that is searched (directly, not recursively) for a
            file whose name ends in ``_Full_metrics.json``.

    Returns:
        The parsed JSON content of the selected file, or an empty dict when
        no file matches or the file cannot be read or parsed. The reason for
        an empty result is printed.
    """
    # Search for files matching the pattern *_Full_metrics.json
    file_pattern = os.path.join(sudir, '*_Full_metrics.json')
    # glob gives no ordering guarantee; sort so the same file is picked on
    # every run when more than one file matches.
    files = sorted(glob.glob(file_pattern))

    if not files:
        print(f"No previous metrics files found in {sudir}")
        return {}

    file_path = files[0]
    if len(files) > 1:
        # Make the choice visible instead of silently ignoring the others.
        print(f"Multiple metrics files found in {sudir}; using {file_path}")

    try:
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(file_path, 'r', encoding='utf-8') as file:
            return json.load(file)
    except FileNotFoundError:
        # The file can disappear between the glob and the open.
        print(f"Error: File not found at {file_path}")
        return {}
    except json.JSONDecodeError:
        print(f"Error: Invalid JSON in file {file_path}")
        return {}
    except Exception as e:
        # Best-effort loader: report anything else and fall back to {}.
        print(f"An unexpected error occurred: {str(e)}")
        return {}

In [None]:
def _plot_metric_panel(ax, train_values, dev_values, best_dev, name, title):
    """Draw the train/validation curves of one metric and mark its best dev value."""
    ax.plot(train_values, label=f'Train {name}')
    ax.plot(dev_values, label=f'Validation {name}')
    ax.axhline(best_dev, color='g', linestyle='--', label=f'Best Dev {name}: {best_dev:.3f}')
    ax.set_title(title)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(name)
    # Build the legend once, after every labelled artist has been added.
    ax.legend()


def single_plot(data):
    """Show a 2x2 summary figure: loss curves, F1 curves and the confusion matrix.

    Args:
        data: Mapping that provides the keys ``lossi``, ``devlossi``,
            ``best_loss_dev``, ``f1i``, ``devlf1i``, ``best_f1_dev`` and
            ``confusion_matrix``.

    Raises:
        KeyError: If any of the required keys is missing. The check runs
            before a figure is created, so no empty figure is left behind.
    """
    required_keys = (
        'lossi', 'devlossi', 'best_loss_dev',
        'f1i', 'devlf1i', 'best_f1_dev',
        'confusion_matrix',
    )
    missing = [key for key in required_keys if key not in data]
    if missing:
        raise KeyError(f"Metrics data is missing required keys: {missing}")

    fig, axs = plt.subplots(2, 2, figsize=(12, 8))

    _plot_metric_panel(
        axs[0, 0], data['lossi'], data['devlossi'], data['best_loss_dev'],
        name='Loss', title='Loss vs Epochs',
    )
    _plot_metric_panel(
        axs[0, 1], data['f1i'], data['devlf1i'], data['best_f1_dev'],
        name='F1', title='F1 Score vs Epochs',
    )

    confusion_matrix = np.array(data['confusion_matrix'])
    ConfusionMatrixDisplay(confusion_matrix).plot(ax=axs[1, 0])
    axs[1, 0].set_title('Confusion Matrix')
    # The fourth cell of the grid is unused; hide its empty axes.
    axs[1, 1].axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
def main(eval_dir_path):
    """Load the metrics found in ``eval_dir_path`` and plot them.

    Args:
        eval_dir_path: Directory handed to ``extract_metrics`` to locate the
            metrics file.
    """
    metrics = extract_metrics(eval_dir_path)
    if not metrics:
        # extract_metrics returns {} (after printing the reason) when nothing
        # could be loaded; plotting an empty dict would fail on the first key.
        print(f"Nothing to plot for {eval_dir_path}")
        return

    single_plot(data=metrics)

In [None]:
# Placeholder: replace with the directory that contains the *_Full_metrics.json file.
eval_dir_path = 'path/to/your/eval/directory'
main(eval_dir_path)