In [5]:
import os
import pickle
from typing import Optional

import pandas as pd
import matplotlib.pyplot as plt
from pprint import pprint
from transformers.modelcard import parse_log_history

In [9]:
##### INPUTS #####
# Selects which run's artifacts the loader functions below read from `output/`.
dataset_name = "race_pp"  # alternative: "race_pp_4000"
# None -> get_metrics reads directly from output/<dataset_name> (no `seed_<n>` sub-directory).
seed = None
model_name = "majority"  # alternatives: "random", "DistilBERT"
# None -> get_metrics reads eval_metrics_<model_name>.csv (no `_<encoding>` suffix).
encoding = None  # alternative: "question_all"

## Test metrics

In [11]:
def get_metrics(dataset_name: str, model_name: str, encoding: Optional[str], seed: Optional[int]) -> pd.DataFrame:
    """Load the saved evaluation metrics of one run as a (metric, value) table.

    The CSV is read from
    ``output/<dataset_name>[/seed_<seed>]/eval_metrics_<model_name>[_<encoding>].csv``,
    where the bracketed parts are omitted when `seed` / `encoding` is None.

    Parameters
    ----------
    dataset_name : str
        Name of the dataset sub-directory under ``output``.
    model_name : str
        Model name embedded in the metrics file name.
    encoding : Optional[str]
        Encoding suffix of the metrics file name, or None for no suffix.
    seed : Optional[int]
        Seed identifying the ``seed_<seed>`` sub-directory, or None for none.

    Returns
    -------
    pd.DataFrame
        The CSV transposed, with the former column names in column ``metric``
        and the first CSV row in column ``value``.
    """
    # Directory: the seed level only exists for seeded runs.
    dir_parts = ['output', dataset_name]
    if seed is not None:
        dir_parts.append('seed_' + str(seed))

    # File name: the encoding suffix is optional.
    file_stem = 'eval_metrics_' + model_name
    if encoding is not None:
        file_stem = file_stem + '_' + encoding
    csv_path = os.path.join(*dir_parts, file_stem + '.csv')

    wide_table = pd.read_csv(csv_path)
    # One column per metric -> one row per metric.
    long_table = wide_table.transpose().reset_index()
    return long_table.rename(columns={'index': 'metric', 0: 'value'})

# Load the metrics of the run selected in the INPUTS cell; the bare name on the
# last line makes the DataFrame the cell's displayed output.
metrics = get_metrics(dataset_name, model_name, encoding, seed)
metrics

Unnamed: 0,metric,value
0,test_mean_absolute_error,0.380007
1,train_mean_absolute_error,0.379077
2,test_root_mean_squared_error,0.616447
3,train_root_mean_squared_error,0.615692
4,test_r2_score,-0.045821
5,train_r2_score,-0.044054
6,test_spearman_rho,
7,train_spearman_rho,
8,test_pearson_rho,
9,train_pearson_rho,


## Training logs

In [None]:
def get_train_logs(dataset_name: str, model_name: str, encoding: Optional[str], seed: Optional[int]) -> tuple:
    """Load the pickled training log history of one run and parse it.

    The pickle is read from
    ``output/<dataset_name>[/seed_<seed>]/<model_name>[_<encoding>]/train_logs.pickle``.
    The bracketed parts are omitted when `seed` / `encoding` is None, mirroring
    how `get_metrics` treats those arguments (previously `encoding=None` raised
    a TypeError and `seed=None` pointed at a ``seed_None`` directory).

    Parameters
    ----------
    dataset_name : str
        Name of the dataset sub-directory under ``output``.
    model_name : str
        Model name; first part of the run directory name.
    encoding : Optional[str]
        Encoding suffix of the run directory name, or None for no suffix.
    seed : Optional[int]
        Seed identifying the ``seed_<seed>`` sub-directory, or None for none.

    Returns
    -------
    tuple
        ``(train_log, lines, eval_results)`` exactly as returned by
        `transformers.modelcard.parse_log_history` for the unpickled logs.

    Raises
    ------
    FileNotFoundError
        If no ``train_logs.pickle`` exists at the resolved location.
    """
    dir_parts = ['output', dataset_name]
    if seed is not None:
        dir_parts.append('seed_' + str(seed))
    # NOTE(review): the un-suffixed directory name for encoding=None is assumed by
    # analogy with the metrics file naming in get_metrics; confirm against the
    # training script that writes the logs.
    run_dir_name = model_name if encoding is None else model_name + '_' + encoding
    log_path = os.path.join(*dir_parts, run_dir_name, "train_logs.pickle")

    # SECURITY: unpickling can execute arbitrary code. Only load log files
    # produced by your own training runs, never files from untrusted sources.
    with open(log_path, 'rb') as handle:
        logs = pickle.load(handle)
    train_log, lines, eval_results = parse_log_history(logs)  # NOTE: func from transformers.modelcard
    return train_log, lines, eval_results

# Parse the training logs of the selected run, then dump each of the three
# returned pieces under its own "=== name ===" banner.
train_log, lines, eval_results = get_train_logs(dataset_name, model_name, encoding, seed)

_banner = "=" * 3
for _section_name, _section_value in (
    ("train_log", train_log),
    ("lines", lines),
    ("eval_results", eval_results),
):
    print(_banner, _section_name, _banner)
    pprint(_section_value)

In [None]:
def remove_nesting(logs: list[dict[str, object]]) -> list[dict[str, object]]:
    """Flatten redundant single-key nesting in per-epoch log entries.

    A value is unwrapped when it is a dict with exactly one key and that key
    equals the lower-cased outer key, e.g. ``{"Pearsonr": {"pearsonr": 0.5}}``
    becomes ``{"Pearsonr": 0.5}``. Every other value, including dicts that do
    not match this rule, is kept unchanged (previously such dicts were silently
    dropped from the result).

    Parameters
    ----------
    logs : list[dict[str, object]]
        One dictionary per logged epoch/step.

    Returns
    -------
    list[dict[str, object]]
        New list of new dictionaries; the input list and its dictionaries are
        not modified.
    """
    flattened_logs = []
    for log_epoch in logs:
        flat_entry = {}
        for key, value in log_epoch.items():
            is_redundant_wrapper = (
                isinstance(value, dict)
                and len(value) == 1
                and key.lower() in value
            )
            if is_redundant_wrapper:
                # Exactly one entry, keyed by the lower-cased outer key: unwrap it.
                flat_entry[key] = value[key.lower()]
            else:
                # Scalars and non-matching dicts pass through untouched.
                flat_entry[key] = value
        flattened_logs.append(flat_entry)
    return flattened_logs


# Flatten the per-epoch entries returned by get_train_logs (`lines`) and print
# the result for inspection; `all_logs` is the input of the plotting cell below.
all_logs = remove_nesting(lines)
pprint(all_logs)

In [None]:
def plot_history(all_logs: list[dict[str, float]], metric: str,
                 ylim: Optional[tuple[float, float]] = None) -> None:
    """Plot a validation metric and the train/validation loss against epochs.

    Draws a 1x2 figure: the chosen metric on the left, both losses on the
    right, and shows it.

    Parameters
    ----------
    all_logs : list[dict[str, float]]
        List of dictionaries containing the training logs for each epoch.
        Each entry must provide the keys ``'Epoch'``, ``'Training Loss'``,
        ``'Validation Loss'`` and `metric`.
    metric : str
        Metric to plot (in addition to loss).
    ylim : Optional[tuple[float, float]]
        Explicit ``(lower, upper)`` y-limits for the metric panel. When None
        (default) the panel covers at least ``[0, 1]`` and is widened to include
        every finite metric value, so that e.g. negative R-squared or
        correlation values are not cut off.

    Raises
    ------
    KeyError
        If an entry of `all_logs` lacks one of the required keys.
    """
    epochs_arr = [log_epoch['Epoch'] for log_epoch in all_logs]
    train_loss_arr = [log_epoch['Training Loss'] for log_epoch in all_logs]
    val_loss_arr = [log_epoch['Validation Loss'] for log_epoch in all_logs]
    metric_arr = [log_epoch[metric] for log_epoch in all_logs]

    if ylim is None:
        # Only finite plain numbers may widen the axis; the chained comparison is
        # False for NaN and +/-inf, which would otherwise break the limits.
        finite_values = [
            value for value in metric_arr
            if isinstance(value, (int, float)) and float('-inf') < value < float('inf')
        ]
        ylim = (min([0, *finite_values]), max([1, *finite_values]))

    _, (metric_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Validation metric
    metric_ax.plot(epochs_arr, metric_arr)
    metric_ax.set_ylim(*ylim)
    metric_ax.set_title(metric)
    metric_ax.set_ylabel(metric)
    metric_ax.set_xlabel('Epochs')
    metric_ax.legend(['valid'], loc='lower right')

    # Loss
    loss_ax.plot(epochs_arr, train_loss_arr)
    loss_ax.plot(epochs_arr, val_loss_arr)
    loss_ax.set_title('Loss')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_xlabel('Epochs')
    loss_ax.legend(['train', 'valid'], loc='upper right')

    plt.show()

# Metric to plot next to the loss; plot_history looks it up as a key in every
# entry of `all_logs`, so it must match a key present there.
metric = "R Squared"
# Alternative metric key:
# metric = "Pearsonr"
plot_history(all_logs, metric)