In [None]:
import os
import pickle
import re
import shutil
from collections import defaultdict
import matplotlib.pyplot as plt
import numpy as np
from tabulate import tabulate

%matplotlib qt

In [None]:
def summarize_metric_sweep_lr(base_dir, algorithms, datasets, lr_values, metric="test_accuracy"):
    """
    Compute mean and std of a metric (e.g. test_accuracy or test_loss) over seeds for each learning rate,
    print them as a table (rows: LR block x algorithm, columns: datasets), and return them.

    Expected layout: ``<base_dir>/<int>_<int>_EpochSeed/<dataset>/lr<lr_value>/<algorithm>.pkl``, where each
    pickle holds a dict; the value stored under ``metric`` is collected (missing/None values are skipped).

    Args:
        base_dir: folder containing the ``<int>_<int>_EpochSeed`` run folders.
        algorithms: algorithm names (pickle file stems); any sequence.
        datasets: dataset folder names; any sequence.
        lr_values: LR strings; must match the ``lr<value>`` folder names exactly.
        metric: key to read from each result dict.

    Returns:
        summary_stats: dict[lr][algorithm][dataset] = (mean, std), or None when no value was found.
        The std is the population std (``np.std`` default, ddof=0).
    """
    folder_pattern = re.compile(r"(\d+)_([\d]+)_EpochSeed")

    # Accept tuples or other sequences: the header below concatenates `datasets` to a list.
    algorithms = list(algorithms)
    datasets = list(datasets)

    # List the run folders once (sorted for a reproducible accumulation order) instead of once per LR.
    run_dirs = [
        os.path.join(base_dir, entry)
        for entry in sorted(os.listdir(base_dir))
        if folder_pattern.match(entry)
    ]

    # Store results: {lr -> algorithm -> dataset -> (mean, std) | None}
    summary_stats = defaultdict(lambda: defaultdict(dict))
    all_rows = []

    for lr_value in lr_values:
        results = defaultdict(lambda: defaultdict(list))  # algorithm -> dataset -> values over seeds

        for run_dir in run_dirs:
            for dataset in datasets:
                lr_folder = os.path.join(run_dir, dataset, f"lr{lr_value}")
                if not os.path.isdir(lr_folder):
                    continue

                for algorithm in algorithms:
                    result_file = os.path.join(lr_folder, f"{algorithm}.pkl")
                    if not os.path.isfile(result_file):
                        continue
                    # NOTE: pickle.load can execute arbitrary code; only point this at trusted result folders.
                    with open(result_file, 'rb') as f:
                        val = pickle.load(f).get(metric)
                    if val is not None:
                        results[algorithm][dataset].append(val)

        # Compute mean/std for this learning rate; the LR label is shown only on the block's first row.
        for row_idx, algorithm in enumerate(algorithms):
            row = [lr_value if row_idx == 0 else "", algorithm]
            for dataset in datasets:
                vals = results[algorithm][dataset]
                if vals:
                    mean_val = np.mean(vals)
                    std_val = np.std(vals)
                    summary_stats[lr_value][algorithm][dataset] = (mean_val, std_val)
                    row.append(f"{mean_val:.4f} ± {std_val:.4f}")
                else:
                    summary_stats[lr_value][algorithm][dataset] = None
                    row.append("-")
            all_rows.append(row)

    headers = ["LR", "Algorithm"] + datasets
    table_str = tabulate(all_rows, headers=headers, tablefmt="fancy_grid")
    print(table_str)
    return summary_stats

In [None]:
def get_algorithm_colors(algorithms):
    """
    Map each algorithm name to a matplotlib color spec.

    Known optimizers get fixed, hand-picked colors so every figure uses the same color for the same
    algorithm (SGD family in blues, Adam family in reds, D-Adaptation in green). Any name in
    ``algorithms`` without a hand-picked color is assigned the next matplotlib property-cycle color
    ("C0", "C1", ..., wrapping after "C9"), so callers can index the result with every requested
    name without a KeyError.

    Args:
        algorithms: iterable of algorithm names to be plotted (None or empty is allowed).

    Returns:
        dict[str, str]: algorithm name -> color. Always contains all hand-picked entries.
    """
    color_map = {
        "SGD": "blue",
        "SGD_CLARA": "dodgerblue",
        "SGD_CLARA_us": "cyan",

        "Adam": "red",
        "Adam_CLARA": "coral",  # "deeppink",
        "Adam_CLARA_us": "magenta",

        "D-Adaptation": "green"
    }

    # Fallback for algorithms without a fixed color: "CN" specs refer to the current property cycle.
    next_cycle_idx = 0
    for name in algorithms or ():
        if name not in color_map:
            color_map[name] = f"C{next_cycle_idx % 10}"
            next_cycle_idx += 1
    return color_map

In [None]:
def plot_algorithm_performance(summary_stats, algorithms, datasets, lr_values, metric_name="Test Accuracy",
                               higher_is_better=True, ylim=(0, 100)):
    """
    Plot grouped bar charts (one figure per dataset) of the mean metric with std error bars for every
    algorithm at every learning rate.

    Bars are grouped by learning rate; within a group each algorithm has its own consistently colored bar.
    The best (algorithm, lr) bar is opaque and all others use alpha 0.5, and a dashed horizontal line
    marks the best mean. Missing (lr, algorithm, dataset) entries are drawn as zero-height bars.
    Datasets with no results at all are skipped without opening a figure.

    Args:
        summary_stats: dict[lr][algorithm][dataset] -> (mean, std) or None (output of
            ``summarize_metric_sweep_lr``).
        algorithms: algorithm names, in bar order within each group.
        datasets: one figure per dataset.
        lr_values: learning-rate keys of ``summary_stats``; also used as x tick labels.
        metric_name: y-axis label and title prefix.
        higher_is_better: if False, the "best" bar is the one with the smallest mean (e.g. for a loss).
        ylim: (bottom, top) y-limits applied to every figure, or None to autoscale.
    """
    num_algos = len(algorithms)
    bar_width = 0.15
    group_spacing = 0.4  # extra gap between learning-rate groups
    group_width = num_algos * bar_width + group_spacing
    group_starts = np.arange(len(lr_values)) * group_width

    color_map = get_algorithm_colors(algorithms)
    select_best = max if higher_is_better else min

    for dataset in datasets:
        # (mean, std) or None for every (algorithm, lr) of this dataset.
        stats_grid = {
            algorithm: [summary_stats.get(lr, {}).get(algorithm, {}).get(dataset) for lr in lr_values]
            for algorithm in algorithms
        }
        candidates = [
            (stats[0], algorithm, lr)
            for algorithm in algorithms
            for lr, stats in zip(lr_values, stats_grid[algorithm])
            if stats is not None
        ]
        if not candidates:
            continue  # checked before creating the figure so no empty figure is left open

        best_mean, best_algo, best_lr = select_best(candidates, key=lambda item: item[0])

        fig, ax = plt.subplots(figsize=(12, 6))

        for i, algorithm in enumerate(algorithms):
            cells = stats_grid[algorithm]
            means = [stats[0] if stats is not None else 0 for stats in cells]
            stds = [stats[1] if stats is not None else 0 for stats in cells]

            bars = ax.bar(group_starts + i * bar_width, means, width=bar_width, label=algorithm, yerr=stds,
                          capsize=5, color=color_map[algorithm])

            # Best bar opaque, all others semi-transparent.
            for bar, lr in zip(bars, lr_values):
                bar.set_alpha(1.0 if (algorithm == best_algo and lr == best_lr) else 0.5)

        # Horizontal line at the best mean, labelled at the left edge of the axes
        # (x in axes fraction, y in data units, so placement does not depend on the bar count).
        ax.axhline(y=best_mean, color='gray', linestyle='--', linewidth=1)
        ax.text(0.005, best_mean, f"{best_mean:.2f}", transform=ax.get_yaxis_transform(),
                va='bottom', ha='left', color='gray')

        # Center each tick under its group of bars.
        ax.set_xticks(group_starts + (num_algos / 2 - 0.5) * bar_width)
        ax.set_xticklabels(lr_values)
        ax.set_xlabel("Learning Rate")
        ax.set_ylabel(metric_name)
        ax.set_title(f"{metric_name} on {dataset}")
        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.legend()
        ax.grid(True, axis='y', linestyle='--', alpha=0.7)
        fig.tight_layout()
        plt.show()

In [None]:
# Aggregate the LR sweep over seeds. LR strings must match the "lr<value>" result folder names exactly.
summary = summarize_metric_sweep_lr(
    "Experiments/selected_results/",
    algorithms=[
        "SGD", "SGD_CLARA", "SGD_CLARA_us",
        "Adam", "Adam_CLARA", "Adam_CLARA_us",
        "D-Adaptation",
    ],
    datasets=["breast_cancer", "iris", "wine", "digits", "mnist", "fmnist", "cifar10", "cifar100"],
    lr_values=["1e-06", "1e-05", "1e-04", "1e-03", "1e-02", "1e-01", "1.00"],
    metric="test_accuracy",  # either "test_accuracy" or "test_loss"
)

In [None]:
# Bar charts of the sweep summary computed in the previous cell (one figure per dataset).
plot_algorithm_performance(
    summary,
    ["SGD", "SGD_CLARA", "SGD_CLARA_us", "Adam", "Adam_CLARA", "Adam_CLARA_us", "D-Adaptation"],
    ["breast_cancer", "iris", "wine", "digits", "mnist", "fmnist", "cifar10", "cifar100"],
    ["1e-06", "1e-05", "1e-04", "1e-03", "1e-02", "1e-01", "1.00"],
    metric_name="Test Accuracy",
)

In [None]:
def plot_algorithm_performance_lines(summary_stats, algorithms, datasets, lr_values, metric_name="Test Accuracy",
                                     save_fig=False, higher_is_better=True, ylim=(0, 100)):
    """
    Plot one line chart per dataset of the mean metric (std error bars) versus initial learning rate on a
    log x axis, with one colored/marked line per algorithm. The best (algorithm, lr) point across all
    algorithms is marked with a black star and its value.

    Args:
        summary_stats: dict[lr][algorithm][dataset] -> (mean, std) or None.
        algorithms: algorithm names; each gets a color and a marker.
        datasets: one figure per dataset.
        lr_values: learning-rate keys of ``summary_stats``; each must be parseable by ``float``.
        metric_name: y-axis label; also used in the saved file name.
        save_fig: if True, save each figure to ``plots/<dataset>_<metric_name>.pdf`` (spaces -> underscores).
        higher_is_better: if False, the star marks the smallest mean instead (e.g. for a loss).
        ylim: (bottom, top) y-limits, or None to autoscale.
    """
    markers = ['o', 's', '^', 'D', 'v', 'P', '>', 'X', 'h', '+']
    color_map = get_algorithm_colors(algorithms)
    lr_floats = [float(lr) for lr in lr_values]  # x positions on the log axis (loop-invariant)
    sign = 1.0 if higher_is_better else -1.0

    for dataset in datasets:
        fig, ax = plt.subplots(figsize=(10, 6))

        best = None  # (value, lr_float, algorithm) of the best non-missing point seen so far

        for i, algorithm in enumerate(algorithms):
            means = []
            stds = []
            for lr in lr_values:
                stats = summary_stats.get(lr, {}).get(algorithm, {}).get(dataset)
                mean, std = stats if stats else (np.nan, np.nan)  # NaN leaves a gap in the line
                means.append(mean)
                stds.append(std)

            ax.errorbar(
                lr_floats,
                means,
                yerr=stds,
                label=algorithm,
                marker=markers[i % len(markers)],
                markersize=8,
                color=color_map[algorithm],
                capsize=4,
                linestyle='-'
            )

            # Strict comparison keeps the first point reached on ties (algorithm order, then lr order).
            for lr_float, mean in zip(lr_floats, means):
                if not np.isnan(mean) and (best is None or sign * mean > sign * best[0]):
                    best = (mean, lr_float, algorithm)

        # Skip the marker when no algorithm has any data for this dataset.
        if best is not None:
            best_val, best_lr, _ = best
            ax.plot(best_lr, best_val, marker='*', color='black', markersize=15,
                    label="Best overall", zorder=5)
            # Offset in points (not data units) so the label sits above the star on any y scale.
            ax.annotate(f"{best_val:.2f}", (best_lr, best_val), xytext=(0, 8), textcoords="offset points",
                        ha='center', va='bottom', fontsize=9, color='black', fontweight='bold')

        ax.set_xscale("log")
        ax.set_xlabel("Initial Learning Rate")
        ax.set_ylabel(metric_name)
        ax.set_title(f"{dataset}")
        if ylim is not None:
            ax.set_ylim(*ylim)
        ax.grid(True, which='both', linestyle='--', alpha=0.7)
        ax.legend()
        fig.tight_layout()

        if save_fig:
            save_dir = "plots"
            os.makedirs(save_dir, exist_ok=True)
            filename = f"{dataset}_{metric_name.replace(' ', '_')}.pdf"
            fig.savefig(os.path.join(save_dir, filename), format='pdf', bbox_inches='tight')

        plt.show()

In [None]:
# Line plots of the same sweep summary (log-scale LR axis); set save_fig=True to write PDFs to plots/.
# Smaller subset that can be passed as `algorithms`: ["SGD_CLARA", "SGD_CLARA_us", "Adam_CLARA", "Adam_CLARA_us"]
plot_algorithm_performance_lines(
    summary,
    ["SGD", "SGD_CLARA", "SGD_CLARA_us", "Adam", "Adam_CLARA", "Adam_CLARA_us", "D-Adaptation"],
    ["breast_cancer", "iris", "wine", "digits", "mnist", "fmnist", "cifar10", "cifar100"],
    ["1e-06", "1e-05", "1e-04", "1e-03", "1e-02", "1e-01", "1.00"],
    save_fig=False,
    metric_name="Test Accuracy",
)

In [None]:
def plot_avg_performance_vs_damping(base_dir, algorithm, datasets, lr_values, metric="test_accuracy", save_fig=True):
    """
    For one algorithm, plot the metric pooled over all datasets and seeds as a function of damping,
    with one line per learning rate.

    Each point is the mean of every value collected for that (damping, lr) pair across all datasets and
    seeds. Its error bar is the std of those same values, so it mixes across-dataset and across-seed spread.

    Run folders in ``base_dir`` must be named ``<epochs>_<seed>_<damping>_EpochSeedDamping``. Each must
    contain ``<dataset>/lr<lr>/<algorithm>.pkl``, a pickled dict holding a ``metric`` entry. Folders
    whose damping part does not parse as a float are skipped.

    If ``save_fig`` is True (default), the figure is also saved to
    ``plots_vs_damping/<algorithm>_<metric>.pdf``.
    """
    folder_pattern = re.compile(r"(\d+)_([\d]+)_([0-9eE\.-]+)_EpochSeedDamping")

    # results[damping][lr] = metric values pooled over datasets and seeds
    results = defaultdict(lambda: defaultdict(list))

    for entry in os.listdir(base_dir):
        match = folder_pattern.match(entry)
        if not match:
            continue
        try:
            damping = float(match.group(3))
        except ValueError:
            continue  # e.g. "e" or "-" match the character class but are not numbers
        run_dir = os.path.join(base_dir, entry)

        for dataset in datasets:
            dataset_path = os.path.join(run_dir, dataset)
            if not os.path.isdir(dataset_path):
                continue

            for lr in lr_values:
                result_file = os.path.join(dataset_path, f"lr{lr}", f"{algorithm}.pkl")
                if not os.path.isfile(result_file):
                    continue
                # NOTE: pickle.load can execute arbitrary code; only use on trusted result folders.
                with open(result_file, "rb") as f:
                    value = pickle.load(f).get(metric)
                if value is not None:
                    results[damping][lr].append(value)

    damping_values = sorted(results)
    fig, ax = plt.subplots(figsize=(10, 6))

    for lr in lr_values:
        means = []
        stds = []
        for damping in damping_values:
            vals = results[damping].get(lr, [])
            means.append(np.mean(vals) if vals else np.nan)  # NaN leaves a gap in the line
            stds.append(np.std(vals) if vals else np.nan)

        ax.errorbar(damping_values, means, yerr=stds, marker='o', label=f"lr={lr}", capsize=4)

    ax.set_xscale("log")
    ax.set_xlabel("Damping")
    ax.set_ylabel(metric.replace("_", " ").capitalize())
    ax.set_title(f"Performance over all datasets for {algorithm}")
    ax.set_ylim(0, 100)
    ax.legend()
    ax.grid(True, linestyle='--', alpha=0.6)
    fig.tight_layout()

    if save_fig:
        save_dir = "plots_vs_damping"
        os.makedirs(save_dir, exist_ok=True)
        filepath = os.path.join(save_dir, f"{algorithm}_{metric}.pdf")
        fig.savefig(filepath, format="pdf", bbox_inches="tight")

    plt.show()

In [None]:
# Damping sensitivity of SGD_CLARA_us, pooled over all datasets (saved to plots_vs_damping/).
plot_avg_performance_vs_damping(
    "Experiments/selected_results/damping_experiments/",
    "SGD_CLARA_us",
    lr_values=["1e-06", "1e-05", "1e-04", "1e-03", "1e-02", "1e-01", "1.00"],
    datasets=["breast_cancer", "iris", "wine", "digits", "mnist", "fmnist", "cifar10", "cifar100"],
    metric="test_accuracy",
)

In [None]:
def plot_performance_vs_damping_per_dataset(base_dir, algorithm, datasets, lr_values, metric="test_accuracy",
                                            copy_best=True, dest_dir="Experiments/selected_results/"):
    """
    For one algorithm and each dataset, produce two PDF plots (saved and closed, not shown):
      * ``plots_vs_damping/<dataset>_<algorithm>_<metric>.pdf``: metric (mean ± std over seeds) versus
        damping, with one line per learning rate;
      * ``plots_best_damping/<dataset>_<algorithm>_<metric>_best_damping_vs_lr.pdf``: the damping with the
        highest mean metric, as a function of the learning rate.

    Run folders in ``base_dir`` must be named ``<epochs>_<seed>_<damping>_EpochSeedDamping``. Folders
    whose damping part does not parse as a float are skipped.

    If ``copy_best`` is True, then for every (dataset, lr) the ``<algorithm>.pkl`` of each run at the best
    damping is copied into ``<dest_dir>/<epochs>_<seed>_EpochSeed/<dataset>/lr<lr>/``. Any existing file of
    the same name is overwritten. This matches the folder layout read by ``summarize_metric_sweep_lr``.
    """
    folder_pattern = re.compile(r"(\d+)_([\d]+)_([0-9eE\.-]+)_EpochSeedDamping")

    # Parse the run folders once: (folder name, epochs string, seed string, damping float).
    runs = []
    for entry in sorted(os.listdir(base_dir)):
        match = folder_pattern.match(entry)
        if not match:
            continue
        epochs, seed, damping_str = match.groups()
        try:
            damping = float(damping_str)
        except ValueError:
            continue  # e.g. "e" or "-" match the character class but are not numbers
        runs.append((entry, epochs, seed, damping))

    for dataset in datasets:
        # results[damping][lr] = metric values over seeds
        results = defaultdict(lambda: defaultdict(list))
        # run_files[damping][lr] = (source pkl path, epochs, seed) of every run found; used for copying
        # the actual files instead of re-deriving folder names from the parsed float.
        run_files = defaultdict(lambda: defaultdict(list))

        for entry, epochs, seed, damping in runs:
            dataset_path = os.path.join(base_dir, entry, dataset)
            if not os.path.isdir(dataset_path):
                continue

            for lr in lr_values:
                result_file = os.path.join(dataset_path, f"lr{lr}", f"{algorithm}.pkl")
                if not os.path.isfile(result_file):
                    continue
                # NOTE: pickle.load can execute arbitrary code; only use on trusted result folders.
                with open(result_file, "rb") as f:
                    value = pickle.load(f).get(metric)
                run_files[damping][lr].append((result_file, epochs, seed))
                if value is not None:
                    results[damping][lr].append(value)

        damping_values = sorted(results)
        best_dampings = []  # (lr as float, best damping)

        # Plot: performance vs. damping (1 plot per dataset)
        fig, ax = plt.subplots(figsize=(10, 6))
        for lr in lr_values:
            means, stds = [], []
            for damping in damping_values:
                vals = results[damping].get(lr, [])
                means.append(np.mean(vals) if vals else np.nan)
                stds.append(np.std(vals) if vals else np.nan)

            # Best damping for this lr = highest mean (first one on ties, i.e. smallest damping).
            valid = [(d, m) for d, m in zip(damping_values, means) if not np.isnan(m)]
            if valid:
                best_damping, _ = max(valid, key=lambda item: item[1])
                best_dampings.append((float(lr), best_damping))

                if copy_best:
                    for source_file, epochs, seed in run_files[best_damping][lr]:
                        dest_path = os.path.join(dest_dir, f"{epochs}_{seed}_EpochSeed", dataset, f"lr{lr}")
                        os.makedirs(dest_path, exist_ok=True)
                        shutil.copy(source_file, dest_path)

            ax.errorbar(damping_values, means, yerr=stds, marker='o', label=f"lr={lr}", capsize=5)

        ax.set_xscale("log")
        ax.set_xlabel("Damping")
        ax.set_ylabel(metric.replace("_", " ").title())
        ax.set_title(f"{algorithm} on {dataset}")
        ax.legend()
        ax.grid(True, linestyle='--', alpha=0.7)
        ax.set_ylim(0, 100)
        fig.tight_layout()

        save_dir = "plots_vs_damping"
        os.makedirs(save_dir, exist_ok=True)
        fig.savefig(os.path.join(save_dir, f"{dataset}_{algorithm}_{metric}.pdf"), format="pdf", bbox_inches="tight")
        plt.close(fig)

        # Plot: best damping vs. learning rate
        if best_dampings:
            lr_floats, best_damp_vals = zip(*sorted(best_dampings))

            fig, ax = plt.subplots(figsize=(7, 5))
            ax.plot(lr_floats, best_damp_vals, marker='o', linestyle='-')
            ax.set_xscale("log")
            ax.set_yscale("log")
            # Fixed y range: points with damping outside [1e-5, 1e-1] fall outside the visible area.
            ax.set_ylim(1e-5, 1e-1)
            ax.set_xlabel("Initial Learning Rate")
            ax.set_ylabel("Best Damping")
            ax.set_title(f"Best Damping vs Learning Rate\n{algorithm} on {dataset}")
            ax.grid(True, linestyle="--", alpha=0.7)
            fig.tight_layout()

            best_plot_dir = "plots_best_damping"
            os.makedirs(best_plot_dir, exist_ok=True)
            best_filename = f"{dataset}_{algorithm}_{metric}_best_damping_vs_lr.pdf"
            fig.savefig(os.path.join(best_plot_dir, best_filename), format="pdf", bbox_inches="tight")
            plt.close(fig)

In [None]:
# Per-dataset damping plots for SGD_CLARA_us (PDFs only); also copies the best-damping runs
# into the regular EpochSeed result folders.
plot_performance_vs_damping_per_dataset(
    "Experiments/selected_results/damping_experiments/",
    "SGD_CLARA_us",
    lr_values=["1e-06", "1e-05", "1e-04", "1e-03", "1e-02", "1e-01", "1.00"],
    datasets=["breast_cancer", "iris", "wine", "digits", "mnist", "fmnist", "cifar10", "cifar100"],
    metric="test_accuracy",
)

In [None]:
def plot_avg_lr_and_accuracy_schedules(base_dir_template, seeds, dataset, algorithms):
    """
    For each algorithm, plot seed-averaged training curves in two side-by-side panels: training accuracy
    (left) and the learning-rate schedule on a log scale (right). Each panel has one line per initial
    learning rate with a ±1 std band over seeds; the legend is on the accuracy panel.

    Args:
        base_dir_template: format string with one ``{}`` placeholder for the seed,
            e.g. ``"Experiments/selected_results/100_{}_EpochSeed"``.
        seeds: seeds to average over; ``seeds[0]`` is used to discover the available ``lr*`` folders.
        dataset: dataset sub-folder name.
        algorithms: algorithm names. Each run is read from ``<dataset>/lr<lr>/<algorithm>.pkl`` and is
            included only if its dict has both ``"lr_history"`` and ``"train_accuracies"``.

    Histories of unequal length across seeds are truncated to the shortest before averaging. Each
    figure is saved to ``learning_curves/<dataset>_<algorithm>_training_acc_lr.pdf`` and then shown.
    """
    def _stack(histories):
        # Truncate to the shortest history so runs that logged fewer entries (e.g. stopped early)
        # can still be averaged instead of producing a ragged array.
        length = min(len(h) for h in histories)
        return np.array([np.asarray(h)[:length] for h in histories])

    # Discover the initial learning rates once (from the first seed), sorted numerically.
    example_dataset_path = os.path.join(base_dir_template.format(seeds[0]), dataset)
    lr_entries = []
    for folder in os.listdir(example_dataset_path):
        if not folder.startswith("lr"):
            continue
        try:
            lr_entries.append((folder[2:], float(folder[2:])))
        except ValueError:
            continue  # not an "lr<number>" folder
    lr_entries.sort(key=lambda entry: entry[1])

    for algorithm in algorithms:
        fig, (ax_acc, ax_lr) = plt.subplots(1, 2, figsize=(14, 5), sharey=False)

        for lr_str, _ in lr_entries:
            lr_histories = []
            acc_histories = []

            for seed in seeds:
                pkl_path = os.path.join(base_dir_template.format(seed), dataset, f"lr{lr_str}", f"{algorithm}.pkl")
                if not os.path.isfile(pkl_path):
                    continue
                # NOTE: pickle.load can execute arbitrary code; only use on trusted result folders.
                with open(pkl_path, "rb") as f:
                    data = pickle.load(f)
                lr_history = data.get("lr_history")
                acc_history = data.get("train_accuracies")
                if lr_history is not None and acc_history is not None:
                    lr_histories.append(lr_history)
                    acc_histories.append(acc_history)

            if not lr_histories:
                continue

            lr_array = _stack(lr_histories)
            acc_array = _stack(acc_histories)
            mean_lr, std_lr = np.mean(lr_array, axis=0), np.std(lr_array, axis=0)
            mean_acc, std_acc = np.mean(acc_array, axis=0), np.std(acc_array, axis=0)
            # Separate x axes: the lr and accuracy histories are not guaranteed to have equal length.
            lr_steps = np.arange(len(mean_lr))
            acc_steps = np.arange(len(mean_acc))

            label = f"lr={lr_str}"
            ax_lr.plot(lr_steps, mean_lr, label=label)
            ax_lr.fill_between(lr_steps, mean_lr - std_lr, mean_lr + std_lr, alpha=0.2)

            ax_acc.plot(acc_steps, mean_acc, label=label)
            ax_acc.fill_between(acc_steps, mean_acc - std_acc, mean_acc + std_acc, alpha=0.2)

        fig.suptitle(f"{algorithm} on {dataset}")

        # Right panel: learning-rate schedule
        ax_lr.set_xlabel("Epoch")
        ax_lr.set_yscale("log")
        ax_lr.set_ylabel("Learning Rate")
        ax_lr.grid(True, linestyle="--", alpha=0.6)

        # Left panel: training accuracy (carries the legend)
        ax_acc.set_xlabel("Epoch")
        ax_acc.set_ylabel("Train Accuracy")
        ax_acc.grid(True, linestyle="--", alpha=0.6)
        ax_acc.set_ylim(0, 100)
        ax_acc.legend(loc="upper right")

        fig.tight_layout()

        # Save the learning-curve figure
        lr_plot_dir = "learning_curves"
        os.makedirs(lr_plot_dir, exist_ok=True)
        lr_filename = f"{dataset}_{algorithm}_training_acc_lr.pdf"
        fig.savefig(os.path.join(lr_plot_dir, lr_filename), format="pdf", bbox_inches="tight")

        plt.show()


In [None]:
# Learning-rate and training-accuracy curves on cifar100, averaged over seeds 0-4.
# For the damping runs use instead: "Experiments/selected_results/damping_experiments/100_{}_1e-03_EpochSeedDamping"
plot_avg_lr_and_accuracy_schedules(
    "Experiments/selected_results/100_{}_EpochSeed",
    list(range(5)),
    "cifar100",
    ["SGD", "SGD_CLARA", "SGD_CLARA_us", "Adam", "Adam_CLARA", "Adam_CLARA_us", "D-Adaptation"],
)