In [1]:
import re
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict
import matplotlib.patches as mpatches

In [2]:
# Global seaborn styling for every figure in this notebook:
# "whitegrid" background style, "paper" plotting context.
sns.set_theme(style="whitegrid", context="paper")

In [3]:
# Signed decimal with optional exponent: "0.12", "-3", ".5", "1e-05", "2.5E+01".
# The previous `[\d\.]+` accepted "." or "1.2.3" (crashing float()) and
# truncated "1e-05" to "1".
_FLOAT_PATTERN = r"[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?"

# Compiled once; the groups are (model, epoch, val_rmse). Loss is matched
# for positional anchoring but not captured.
_TRIAL_LINE_RE = re.compile(
    r"INFO:\s+(\w+)\s+\|\s+Epoch\s+(\d+)/\d+\s+\|\s+loss\s+"
    + _FLOAT_PATTERN
    + r"\s+\|\s+val RMSE\s+("
    + _FLOAT_PATTERN
    + r")"
)


def parse_trial_line(line):
    """Extract ``(model, epoch, val_rmse)`` from one training-log line.

    Expected shape (anywhere in the line)::

        INFO: <model> | Epoch <e>/<total> | loss <float> | val RMSE <float>

    Parameters
    ----------
    line : str
        A single log line.

    Returns
    -------
    tuple[str, int, float] or None
        The model name, epoch number and validation RMSE, or ``None`` when the
        line does not match the expected format (including non-numeric values
        such as ``nan``).
    """
    match = _TRIAL_LINE_RE.search(line)
    if match is None:
        return None
    model, epoch, val_rmse = match.groups()
    return model, int(epoch), float(val_rmse)

In [4]:
def extract_models(trials_data):
    """Return every distinct model name across all trials, sorted.

    ``trials_data`` is the ``{seed: {(model, epoch): val_rmse}}`` mapping built
    by ``get_trials_data``; only the model part of each inner key is used.
    """
    # Iterating a dict yields its keys, i.e. the (model, epoch) tuples.
    return sorted({
        model_name
        for per_seed in trials_data.values()
        for model_name, _epoch in per_seed
    })

In [5]:
def get_trials_data(file_path):
    """Parse a training log into per-seed validation-RMSE tables.

    A trial starts at any line containing both ``"Trial"`` and ``"Seed="``.
    Following lines containing ``"val RMSE"`` are parsed with
    ``parse_trial_line`` and attributed to that trial's seed until the next
    header line.

    Parameters
    ----------
    file_path : str or path-like
        Path to the log file.

    Returns
    -------
    dict[int, dict[tuple[str, int], float]]
        ``{seed: {(model, epoch): val_rmse}}``. Trials that produced no
        parseable metric lines are omitted. A later non-empty trial with the
        same seed replaces an earlier one. Metric lines that appear before the
        first header, or after a header whose seed cannot be parsed, are
        ignored rather than attributed to a neighbouring seed.
    """
    trials_data = {}
    current_trial = {}
    seed = None

    with open(file_path, "r") as f:
        for raw_line in f:
            line = raw_line.strip()
            if "Trial" in line and "Seed=" in line:
                if seed is not None and current_trial:
                    trials_data[seed] = current_trial
                # Always start fresh so no metrics carry over across headers.
                current_trial = {}
                seed_match = re.search(r"Seed=(\d+)", line)
                seed = int(seed_match.group(1)) if seed_match else None

            elif "val RMSE" in line and seed is not None:
                parsed = parse_trial_line(line)
                if parsed:
                    model, epoch, val_rmse = parsed
                    current_trial[(model, epoch)] = val_rmse

    if seed is not None and current_trial:
        trials_data[seed] = current_trial

    return trials_data

In [6]:
def aggregate_rmse(trials_data):
    """Summarise validation RMSE across seeds for every (model, epoch) pair.

    Parameters
    ----------
    trials_data : dict
        ``{seed: {(model, epoch): val_rmse}}`` as returned by ``get_trials_data``.

    Returns
    -------
    pandas.DataFrame
        One row per (model, epoch) with columns ``Model``, ``Epoch``,
        ``Mean_RMSE`` and ``Std_RMSE``. ``Std_RMSE`` is ``np.std`` with its
        default ``ddof=0`` (population standard deviation). It is taken over
        the seeds that logged that pair; pairs missing from some seeds use only
        the seeds that have them. The four columns exist even when
        ``trials_data`` is empty, so downstream column lookups do not raise
        ``KeyError``.
    """
    rmse_by_model_epoch = defaultdict(list)
    for trial in trials_data.values():
        for key, val in trial.items():
            rmse_by_model_epoch[key].append(val)

    aggregated = [
        {
            "Model": model,
            "Epoch": epoch,
            "Mean_RMSE": np.mean(rmse_list),
            "Std_RMSE": np.std(rmse_list),
        }
        for (model, epoch), rmse_list in rmse_by_model_epoch.items()
    ]

    # Explicit columns keep the schema stable for an empty input.
    return pd.DataFrame(
        aggregated, columns=["Model", "Epoch", "Mean_RMSE", "Std_RMSE"]
    )

In [7]:
def separate_plots(df_summary, save_prefix=None):
    """Draw one validation-RMSE-vs-epoch figure per model.

    Each figure shows the mean RMSE line and a shaded band of plus/minus one
    standard deviation.

    Parameters
    ----------
    df_summary : pandas.DataFrame
        Table with columns ``Model``, ``Epoch``, ``Mean_RMSE``, ``Std_RMSE``
        (as produced by ``aggregate_rmse``).
    save_prefix : str, optional
        If given, each figure is written to
        ``f"{save_prefix}_{model}_separate.pdf"``, creating the parent
        directory when needed. If ``None``, each figure is shown instead.

    Notes
    -----
    Mutates the global ``plt.rcParams`` (LaTeX text rendering, font family
    and size). Colours and linestyles are assigned in sorted model-name order.
    """
    plt.rcParams["font.family"] = "Computer Modern"
    plt.rcParams["text.usetex"] = True
    plt.rcParams["font.size"] = 22

    colors = sns.color_palette("colorblind")
    linestyles = ["-", "--", "-.", ":"]
    models = sorted(df_summary["Model"].unique())

    for i, model in enumerate(models):
        model_df = df_summary[df_summary["Model"] == model].sort_values("Epoch")
        epochs = model_df["Epoch"]
        mean_rmse = model_df["Mean_RMSE"].values
        std_rmse = model_df["Std_RMSE"].values
        lower = mean_rmse - std_rmse
        upper = mean_rmse + std_rmse

        color = colors[i % len(colors)]
        linestyle = linestyles[i % len(linestyles)]

        fig, ax = plt.subplots(figsize=(14, 12))
        ax.plot(
            epochs,
            mean_rmse,
            # With text.usetex, a bare "_" is a LaTeX error outside math mode.
            label=str(model).replace("_", r"\_"),
            color=color,
            linestyle=linestyle,
            linewidth=3,
        )
        ax.fill_between(epochs, lower, upper, alpha=0.3, color=color)

        pad = 0.05 * np.ptp(upper)
        y_min = lower.min() - pad
        y_max = upper.max() + pad
        if y_max <= y_min:
            # Zero-extent band (e.g. one epoch with zero spread): widen it so
            # matplotlib never receives identical limits.
            widen = 0.05 * abs(y_max) or 0.05
            y_min, y_max = y_min - widen, y_max + widen
        ax.set_ylim(y_min, y_max)

        if y_min < 0:
            ax.axhline(0, color="gray", linestyle="--", linewidth=1)

        max_epoch = int(epochs.max())
        ax.set_xticks(range(0, max_epoch + 1, 2))
        ax.tick_params(axis="both", labelsize=24)

        ax.set_xlabel("Epoch", fontsize=26)
        ax.set_ylabel("Validation RMSE", fontsize=26)
        ax.grid(axis="y", linestyle="--", linewidth=0.6, alpha=0.6)
        ax.legend(
            fontsize=24,
            title="Model",
            title_fontsize=26,
            loc="upper right",
        )
        fig.tight_layout()

        if save_prefix is None:
            plt.show()
        else:
            save_path = f"{save_prefix}_{model}_separate.pdf"
            save_dir = os.path.dirname(save_path)
            if save_dir:  # os.makedirs("") raises for bare filenames
                os.makedirs(save_dir, exist_ok=True)
            fig.savefig(save_path, dpi=300, bbox_inches="tight")

        plt.close(fig)

In [8]:
def combined_plot(df_summary, save_path=None):
    """Overlay every model's validation-RMSE curve on a single figure.

    Each model gets a mean RMSE line plus a shaded band of plus/minus one
    standard deviation. The legend is laid out horizontally below the axes.

    Parameters
    ----------
    df_summary : pandas.DataFrame
        Table with columns ``Model``, ``Epoch``, ``Mean_RMSE``, ``Std_RMSE``
        (as produced by ``aggregate_rmse``).
    save_path : str, optional
        Output file. Its parent directory is created when needed. If ``None``,
        the figure is shown instead.

    Raises
    ------
    ValueError
        If ``df_summary`` has no rows.

    Notes
    -----
    Mutates the global ``plt.rcParams`` (LaTeX text rendering, font family
    and size). Colours and linestyles are assigned in sorted model-name order,
    the same mapping ``separate_plots`` uses, so a model looks identical in
    both kinds of figure.
    """
    if df_summary.empty:
        raise ValueError("df_summary is empty; nothing to plot")

    plt.rcParams["font.family"] = "Computer Modern"
    plt.rcParams["text.usetex"] = True
    plt.rcParams["font.size"] = 22

    fig, ax = plt.subplots(figsize=(14, 12))

    colors = sns.color_palette("colorblind")
    linestyles = ["-", "--", "-.", ":"]
    models = sorted(df_summary["Model"].unique())

    all_lowers = []
    all_uppers = []

    for i, model in enumerate(models):
        model_df = df_summary[df_summary["Model"] == model].sort_values("Epoch")
        epochs = model_df["Epoch"]
        mean_rmse = model_df["Mean_RMSE"].values
        std_rmse = model_df["Std_RMSE"].values
        lower = mean_rmse - std_rmse
        upper = mean_rmse + std_rmse

        all_lowers.extend(lower)
        all_uppers.extend(upper)

        color = colors[i % len(colors)]
        linestyle = linestyles[i % len(linestyles)]

        ax.plot(
            epochs,
            mean_rmse,
            # With text.usetex, a bare "_" is a LaTeX error outside math mode.
            label=str(model).replace("_", r"\_"),
            color=color,
            linestyle=linestyle,
            linewidth=3,
        )
        ax.fill_between(epochs, lower, upper, alpha=0.3, color=color)

    pad = 0.05 * np.ptp(all_uppers)
    y_min = min(all_lowers) - pad
    y_max = max(all_uppers) + pad
    if y_max <= y_min:
        # Zero-extent band: widen it so matplotlib never gets identical limits.
        widen = 0.05 * abs(y_max) or 0.05
        y_min, y_max = y_min - widen, y_max + widen
    ax.set_ylim(y_min, y_max)

    if y_min < 0:
        ax.axhline(0, color="gray", linestyle="--", linewidth=1)

    max_epoch = int(df_summary["Epoch"].max())
    ax.set_xticks(range(0, max_epoch + 1, 2))
    ax.tick_params(axis="both", labelsize=24)

    ax.set_xlabel("Epoch", fontsize=26)
    ax.set_ylabel("Validation RMSE", fontsize=26)
    ax.grid(axis="y", linestyle="--", linewidth=0.6, alpha=0.6)

    # An invisible patch acts as a leading "Model" caption inside the
    # horizontal legend (a real legend title would sit above the row).
    title_patch = mpatches.Patch(color="none", label="Model")
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(
        handles=[title_patch] + handles,
        labels=[title_patch.get_label()] + labels,
        loc="upper center",
        bbox_to_anchor=(0.5, -0.12),  # just below the x-axis label
        ncol=4,
        handletextpad=1.0,
        columnspacing=1.5,
        fontsize=24,
    )
    # Reserve a strip at the bottom of the figure for the legend.
    fig.tight_layout(rect=[0, 0.06, 1, 1])

    if save_path:
        save_dir = os.path.dirname(save_path)
        if save_dir:  # os.makedirs("") raises for bare filenames
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    else:
        plt.show()

    plt.close(fig)

In [9]:
# Input logs live under data_dir/<dataset>/<depth><mask>_log.txt;
# figures are written under figures_dir/<dataset>/.
data_dir = "../out"
figures_dir = "../figures"

datasets = ["url", "lcld", "news", "faulty-steel-plates"]
depths = ["deep", "shallow"]
# Datasets plotted only in their unmasked variant
# (presumably no "_mask" logs exist for them — confirm).
unmasked_only = {"lcld"}

for dataset in datasets:
    # Choosing the mask list per dataset avoids processing the unmasked
    # run twice for unmasked-only datasets.
    masks = [""] if dataset in unmasked_only else ["", "_mask"]
    out_dir = os.path.join(figures_dir, dataset)
    for depth in depths:
        for mask in masks:
            run_name = f"{depth}{mask}"
            data = get_trials_data(
                os.path.join(data_dir, dataset, f"{run_name}_log.txt")
            )
            df_summary = aggregate_rmse(data)
            combined_plot(
                df_summary,
                save_path=os.path.join(out_dir, f"{run_name}_combined.pdf"),
            )
            separate_plots(
                df_summary, save_prefix=os.path.join(out_dir, run_name)
            )