In [None]:
import os
import json
import torch
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Helper functions

In [None]:
def load_grid_summary(base_path):
    """Read the grid-search overview table.

    Parameters
    ----------
    base_path : str or os.PathLike
        Directory containing ``grid_summary.csv``.

    Returns
    -------
    pandas.DataFrame
        The contents of ``grid_summary.csv`` as parsed by ``pd.read_csv``.
    """
    summary_csv = os.path.join(base_path, "grid_summary.csv")
    return pd.read_csv(summary_csv)

def load_combination_summary(config_path):
    """Read the summary JSON stored in one configuration's directory.

    Parameters
    ----------
    config_path : str or os.PathLike
        Directory of a single configuration; must contain
        ``combination_summary.json``.

    Returns
    -------
    object
        The parsed JSON document, as returned by ``json.load``.

    Raises
    ------
    FileNotFoundError
        If ``combination_summary.json`` is missing.
    json.JSONDecodeError
        If the file is not valid JSON.
    """
    summary_file = os.path.join(config_path, "combination_summary.json")
    # JSON text is UTF-8 (RFC 8259); be explicit so decoding does not depend
    # on the platform's locale encoding.
    with open(summary_file, encoding="utf-8") as f:
        return json.load(f)

def load_log_csv(config_path):
    """Read the training log stored in one configuration's directory.

    Parameters
    ----------
    config_path : str or os.PathLike
        Directory of a single configuration; must contain ``log.csv``.

    Returns
    -------
    pandas.DataFrame
        The contents of ``log.csv`` as parsed by ``pd.read_csv``.
    """
    log_file = os.path.join(config_path, "log.csv")
    return pd.read_csv(log_file)

def load_model_fold(config_path, fold_number, weights_only=None):
    """Load the checkpoint file saved for one fold.

    Parameters
    ----------
    config_path : str or os.PathLike
        Directory of a single configuration; the file read is
        ``model_fold{fold_number}.pt`` inside it.
    fold_number : int
        Fold index used to build the file name.
    weights_only : bool or None, optional
        Forwarded to ``torch.load`` only when not ``None``. With the default
        ``None`` the installed torch version's own default applies. ``True``
        restricts unpickling to tensors and primitive containers. ``False``
        allows arbitrary pickled objects and must only be used on trusted files.

    Returns
    -------
    object
        Whatever was stored with ``torch.save``, with tensors mapped to CPU.
    """
    model_path = os.path.join(config_path, f"model_fold{fold_number}.pt")
    load_kwargs = {"map_location": "cpu"}
    # Forward weights_only only on request: older torch releases do not accept
    # it, and its default changed to True in torch 2.6.
    if weights_only is not None:
        load_kwargs["weights_only"] = weights_only
    # SECURITY: torch.load is pickle-based; only load checkpoints you trust.
    return torch.load(model_path, **load_kwargs)

In [None]:
def plot_loss_accuracy(log_df, config_id):
    """Plot per-fold training curves for one configuration.

    Produces two figures from a per-epoch log:

    1. training loss (blue shades) and validation loss (orange shades)
       against epoch, one line per fold;
    2. validation accuracy (green shades) against epoch, one line per fold.

    Parameters
    ----------
    log_df : pandas.DataFrame
        Must have the columns ``epoch``, ``fold``, ``train_loss``,
        ``val_loss`` and ``val_accuracy``.
    config_id : str
        Label used in the figure titles.
    """
    # Plot `fold` as a category. With a numeric hue, seaborn maps a palette
    # name to a continuous colormap, so the lowest fold lands on the
    # near-white end of "Blues"/"Oranges"/"Greens", and a "brief" legend only
    # lists a sample of fold values. As categories, every fold gets a visible
    # shade and its own legend entry, in sorted fold order.
    fold_order = [str(fold) for fold in sorted(log_df["fold"].unique())]
    curves = log_df.assign(fold=log_df["fold"].astype(str))
    line_kwargs = dict(data=curves, x="epoch", hue="fold",
                       hue_order=fold_order, linewidth=1.5)

    _, ax = plt.subplots(figsize=(10, 5))
    sns.lineplot(y="train_loss", palette="Blues", legend="full", ax=ax, **line_kwargs)
    sns.lineplot(y="val_loss", palette="Oranges", legend=False, ax=ax, **line_kwargs)
    fold_legend = ax.get_legend()
    if fold_legend is not None:
        # Only the train lines carry legend entries; state the colour scheme
        # so the validation lines are not left unexplained.
        fold_legend.set_title("fold (blue = train, orange = val)")
    ax.set(title=f"Train/Val Loss - {config_id}", xlabel="Epoch", ylabel="Loss")
    ax.grid(True)
    plt.show()

    _, ax = plt.subplots(figsize=(8, 4))
    sns.lineplot(y="val_accuracy", palette="Greens", legend="full", ax=ax, **line_kwargs)
    ax.set(title=f"Validation Accuracy - {config_id}", xlabel="Epoch", ylabel="Accuracy")
    ax.grid(True)
    plt.show()


In [None]:
def plot_fold_accuracies(combination_summary, config_id):
    """Show the spread of the per-fold accuracies for one configuration.

    Parameters
    ----------
    combination_summary : dict
        Summary mapping with a ``"fold_accuracies"`` entry (presumably one
        accuracy value per fold; confirm against the writer of the file).
    config_id : str
        Label used in the figure title.
    """
    fold_accuracies = combination_summary["fold_accuracies"]

    # Draw on a fresh figure instead of whatever axes happen to be current.
    _, ax = plt.subplots()
    sns.boxplot(data=fold_accuracies, ax=ax)
    # A cross-validation run has only a few folds, which a box alone
    # summarizes poorly; overlay the individual fold values.
    sns.stripplot(data=fold_accuracies, ax=ax, color="black", size=5)
    ax.set(title=f"Fold Accuracy Distribution - {config_id}", ylabel="Accuracy")
    ax.grid(True)
    plt.show()


# Results

Load the grid-search summary and show the `top_n` configurations with the highest mean validation accuracy.

In [None]:
# Read the grid-search overview and rank every configuration by mean
# validation accuracy, best first (index renumbered 0..n-1).
base_path = "tuning_results"
summary_df = load_grid_summary(base_path)
summary_df_sorted = summary_df.sort_values(
    by="mean_val_accuracy",
    ascending=False,
    ignore_index=True,
)

# Keep only the best `top_n` configurations for display.
top_n = 5
top_configs = summary_df_sorted.iloc[:top_n]
top_configs

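Plots for the best configuration: per-fold loss and validation-accuracy curves, followed by the distribution of fold accuracies.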

In [None]:
# Inspect the highest-ranked configuration (first row after sorting).
selected_config = top_configs.iloc[0]
config_path = selected_config["path"]

# Per-epoch training log for that configuration.
log_df = load_log_csv(config_path)
plot_loss_accuracy(log_df, selected_config["config_id"])

In [None]:
# Spread of the per-fold accuracies for the same configuration.
comb_summary = load_combination_summary(config_path)
plot_fold_accuracies(
    combination_summary=comb_summary,
    config_id=selected_config["config_id"],
)

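Load the checkpoint saved for fold 1 of the selected configuration and unpack its state dict, metrics and parameters.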

In [None]:
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = load_model_fold(config_path, fold_number=1)

# Pull out the entries of the checkpoint dict used below.
model_state, metrics, params = (
    checkpoint["state_dict"],
    checkpoint["metrics"],
    checkpoint["params"],
)
