In [1]:
import numpy as np
import matplotlib.pyplot as plt
from utils_eval import color_helper
import os
import glob
import json

In [2]:
def extract_metrics(sudir):
    metrics_data = []
    # Search for files matching the pattern *_FULL_metrics.json 
    file_pattern = os.path.join(sudir, '*_Full_metrics.json')
    files = glob.glob(file_pattern)

    if not files:
        print(f"No previous metrics files found in {sudir}")
        return {}
    
    file_path = files[0]

    try:
        with open(file_path, 'r') as file:
            metrics = json.load(file)
        return metrics
    except FileNotFoundError:
        print(f"Error: File not found at {file_path}")
        return {}
    except json.JSONDecodeError:
        print(f"Error: Invalid JSON in file {file_path}")
        return {}
    except Exception as e:
        print(f"An unexpected error occurred: {str(e)}")
        return {}

In [3]:
def add_to_plot(axs, data, color, name):
    #WE NEEED TO ADD NAME
    axs[0, 0].plot(data['lossi'], label=f'Train Loss {name}', color=color)
    axs[0, 0].plot(data['devlossi'], label=f'Validation Loss {name}', color=color, linestyle='--')

    axs[0, 1].plot(data['f1i'], label=f'Train F1 {name}', color=color)
    axs[0, 1].plot(data['devlf1i'], label=f'Validation F1 {name}', color=color, linestyle='--')
    return axs

In [4]:
def main(eval_dir_path):
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))

    axs[0, 0].set_title('Loss vs Epochs')
    axs[0, 0].legend()
    axs[0, 0].set_xlabel('Epochs')
    axs[0, 0].set_ylabel('Loss')
    axs[0, 0].legend()

    axs[0, 1].set_title('F1 Score vs Epochs')
    axs[0, 1].legend()
    axs[0, 1].set_xlabel('Epochs')
    axs[0, 1].set_ylabel('F1')
    axs[0, 1].legend()

    subdirs = [d for d in os.listdir(eval_dir_path) if os.path.isdir(os.path.join(eval_dir_path, d))]
    num_dir = len(subdir)
    colors = color_helper(num_dir)

    for i, subdir in enumerate(subdirs):
        metrics = extract_metrics(subdir)
        name = os.path.basename(subdir)
        axs = add_to_plot(axs=axs, data=metrics, color=colors[i], name=name)


    plt.tight_layout()
    plt.savefig(os.path.join(eval_dir_path,f"_multiEval.png"))
    plt.close()

In [5]:
eval_dir_path = 'path/to/your/eval/directory'
main(eval_dir_path)

AttributeError: module 'matplotlib' has no attribute 'plot'