# Plotting notebook for the matrix factorization experiments
This notebook summarizes and visualizes the behavior of various metrics under different matrix factorization settings.
The plots cover how accuracy, reconstruction error, and correlation metrics evolve depending on the sampling rate `p`, the number of latent dimensions `d`, the scaling factor `s`, and selection strategies.
Some sections also analyze the distribution of per-row metrics and the influence of outliers or structured noise.

This notebook contains no commentary on the results; please refer to the accompanying PDF report for details.

## Summary

- [Plots vs $s$](#plots-vs-s)
  - [Plots vs $s$ by $p$](#plots-vs-s-by-p)
  - [Plots vs $s$ by $k$](#plots-vs-s-by-k)
- [Plots explaining the difference in Reconstruction Error Scaled per Row and the Pearson coefficient](#plots-explaining-the-difference-in-reconstruction-error-scaled-per-row-and-the-pearson-coefficient)
  - [Comparison of the distribution of the values per row](#comparison-of-the-distribution-of-the-values-per-row)
  - [Histogram of the distribution of the $\alpha_u$](#histogram-of-the-distribution-of-the-alpha_u)
  - [Effect of the outlier and of the noise on the metrics](#effect-of-the-outlier-and-of-the-noise-on-the-metrics)
  - [Plots for $p \cdot k = \text{const}$](#plots-for-p-cdot-k-const)
- [Plots vs $p$](#plots-vs-p)
  - [Plots for $p \cdot s = \text{const}$](#plots-for-p-cdot-s--const)
  - [Plots for $p$ vs $d$](#plots-for-p-vs-d)
- [Plots for Strategies](#plots-for-strategies)
  - [Plots for strategies vs $s$](#plots-for-strategies-vs-s)
  - [Plots for strategies vs $p$](#plots-for-strategies-vs-p)
- [Plots for Ground Truth Analysis](#plots-for-ground-truth-analysis)
  - [Plots vs $p$](#plots-vs-p-1)
  - [Plots vs $d$](#plots-vs-d)


# Plots vs $s$

## Plots vs $s$ by $p$

In [None]:
import pickle
from visualization import plot_metrics_vs_param, display_experiment_indices

# === Load the results of the latest p / s scan ===
file_path_1 = "Data_final/scan_K1_fixedLR_varS_varP_full_4.pkl"
with open(file_path_1, "rb") as f:
    results = pickle.load(f)

# Options shared by every plot of this cell: s on a log x-axis, one curve per p,
# max_overall=True and enlarged fonts.
s_by_p_options = dict(param_x="s", group_by="p", log_scale_x=True, max_overall=True, font_scale=1.5)

# --- Accuracy vs s, full s range, all weight_decay values ---
plot_metrics_vs_param(
    results=results,
    metrics=["accuracy"],
    save_path="Results_final/accuracy_vs_s_by_p_full",
    **s_by_p_options,
)

# --- Reconstruction error vs s, full s range, weight_decay fixed to 5e-6 ---
result_4 = [exp for exp in results if exp["params"]["weight_decay"] == 5e-6]
plot_metrics_vs_param(
    results=result_4,
    metrics=["reconstruction_errors"],
    log_scale_y=True,
    save_path="Results_final/reconstruction_error_vs_s_by_p_full",
    **s_by_p_options,
)

# --- Same metrics restricted to 0.1 <= s <= 100 for readability ---
results_2 = [exp for exp in results if 0.1 <= exp["params"]["s"] <= 100]
plot_metrics_vs_param(
    results=results_2,
    metrics=["accuracy"],
    save_path="Results_final/accuracy_vs_s_by_p",
    **s_by_p_options,
)
plot_metrics_vs_param(
    results=results_2,
    metrics=["reconstruction_errors"],
    log_scale_y=True,
    save_path="Results_final/reconstruction_error_vs_s_by_p",
    **s_by_p_options,
)

# --- Restricted s range AND weight_decay fixed to 5e-6 ---
results_5 = [
    exp for exp in results
    if exp["params"]["weight_decay"] == 5e-6 and 0.1 <= exp["params"]["s"] <= 100
]
plot_metrics_vs_param(
    results=results_5,
    metrics=["reconstruction_errors"],
    log_scale_y=True,
    save_path="Results_final/reconstruction_error_vs_s_by_p_wd_fixed",
    **s_by_p_options,
)

# --- Final validation loss: take the last entry of val_losses when missing ---
for exp in results:
    run = exp["results"]
    if "final_loss" not in run:
        run["final_loss"] = run["val_losses"][-1]

plot_metrics_vs_param(
    results=results,
    metrics=["final_loss"],
    save_path="Results_final/final_loss_vs_s_by_p_full",
    **s_by_p_options,
)


In [None]:
import pickle
import copy
from collections import defaultdict
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import sem
from matplotlib import colors as mcolors
from visualization import (
    plot_metrics_vs_param,
    plot_optimal_param_vs_x,
    smart_formatter
)

# === Load experiment results ===
file_path_1 = "Data_final/scan_K1_fixedLR_varS_varP_full_4.pkl"
with open(file_path_1, "rb") as f:
    results = pickle.load(f)

# === Two deep-copied subsets: one fixed weight decay vs weight decay optimized per s ===
# Deep copies so the "wd" label added below does not leak into `results`.
# The label of the fixed subset must describe FIXED_WD, the value actually filtered on.
FIXED_WD = 5e-6
SELECTED_P = [0.1, 0.2, 0.5]
results_one_wd = [
    copy.deepcopy(exp) for exp in results
    if exp["params"]["weight_decay"] == FIXED_WD and exp["params"]["p"] in SELECTED_P
]
results_other = [copy.deepcopy(exp) for exp in results if exp["params"]["p"] in SELECTED_P]

# === Assign labels for grouping in plots ===
for exp in results_one_wd:
    exp["params"]["wd"] = r"$5 \cdot 10^{-6}$"
for exp in results_other:
    exp["params"]["wd"] = "Optimized"

# === Merge both sets for potential group plotting (grouped by the "wd" label) ===
results_3 = results_one_wd + results_other

# === For each p and each s, keep the experiment with the lowest mean reconstruction error ===
grouped_by_p = defaultdict(list)
for exp in results_other:
    grouped_by_p[exp["params"]["p"]].append(exp)

p_to_s_best = {}
for p, exps in grouped_by_p.items():
    s_to_errors = defaultdict(list)
    for exp in exps:
        s_to_errors[exp["params"]["s"]].append(exp["results"]["reconstruction_errors"])

    s_best_vals = {}
    for s, error_lists in s_to_errors.items():
        # One entry per experiment at this (p, s) — presumably one per weight decay.
        # Its SEM is computed over that experiment's own error values.
        best_mean = float("inf")
        best_sem = None
        for errs in error_lists:
            mean_err = np.mean(errs)
            if mean_err < best_mean:
                best_mean = mean_err
                best_sem = sem(errs) if len(errs) > 1 else 0.0
        s_best_vals[s] = (best_mean, best_sem)
    p_to_s_best[p] = s_best_vals

# === Figure and one viridis color per p ===
fig, ax = plt.subplots(figsize=(9, 6))
font_scale = 1.5
colors = plt.cm.viridis(np.linspace(0, 1, len(p_to_s_best)))
p_to_color = {p: colors[i] for i, p in enumerate(sorted(p_to_s_best))}

def shift_color(color, factor=0.85):
    """Return ``color`` as RGBA with its RGB channels multiplied by ``factor``.

    A factor below 1 darkens the color and a factor above 1 lightens it.
    Channels are capped at 1, and the alpha channel is kept as is.
    """
    *rgb, alpha = mcolors.to_rgba(color)
    scaled = tuple(min(channel * factor, 1) for channel in rgb)
    return scaled + (alpha,)

def _wd_to_latex(value):
    """Format a weight-decay value as LaTeX, e.g. 5e-6 -> ``5 \\cdot 10^{-6}``, 1e-5 -> ``10^{-5}``."""
    mantissa, exponent = f"{value:.2e}".split("e")
    mantissa = mantissa.rstrip("0").rstrip(".")
    power = rf"10^{{{int(exponent)}}}"
    return power if mantissa == "1" else rf"{mantissa} \cdot {power}"


# Optimized curves: for each s, the experiment with the lowest mean error (dashed, markers)
for p, s_dict in sorted(p_to_s_best.items()):
    s_sorted = sorted(s_dict)
    means = [s_dict[s][0] for s in s_sorted]
    sems = [s_dict[s][1] for s in s_sorted]
    ax.errorbar(
        s_sorted, means, yerr=sems,
        label=rf"Optimized ($p$={p})",
        capsize=4, marker='o', linestyle='--',
        color=p_to_color[p]
    )

# Fixed-weight-decay curves (results_one_wd), solid and in a darker shade
grouped_fixed_by_p = defaultdict(list)
for exp in results_one_wd:
    grouped_fixed_by_p[exp["params"]["p"]].append(exp)

for p, exps in sorted(grouped_fixed_by_p.items()):
    s_to_errors = defaultdict(list)
    for exp in exps:
        s_to_errors[exp["params"]["s"]].append(exp["results"]["reconstruction_errors"])

    s_sorted = sorted(s_to_errors)
    # Pool the error values of every experiment at this s.
    # np.hstack accepts both lists and arrays, whereas sum(..., []) needed lists.
    pooled = [np.hstack(s_to_errors[s]) for s in s_sorted]
    means = [np.mean(vals) for vals in pooled]
    # The SEM needs more than one pooled *value*. Testing the number of experiments
    # instead forced a zero error bar whenever a single experiment existed for this s,
    # which is the usual case at a fixed weight decay.
    sems = [sem(vals) if vals.size > 1 else 0.0 for vals in pooled]

    # Build the label from the data actually plotted, so it cannot drift from the filter.
    wd_label = _wd_to_latex(exps[0]["params"]["weight_decay"])
    ax.errorbar(
        s_sorted, means, yerr=sems,
        label=rf"$w_d={wd_label}$ ($p$={p})",
        capsize=4, linestyle='-',
        color=shift_color(p_to_color[p], 0.7)
    )

# === Plot formatting ===
ax.set_xlabel(r"$s$", fontsize=14 * font_scale)
ax.set_ylabel(r"Reconstruction Error (scaled)", fontsize=14 * font_scale)
ax.set_xscale("log")
ax.set_yscale("log")
ax.grid(True, linestyle="--", alpha=0.5)
ax.legend(fontsize=11 * font_scale)
ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda val, _: smart_formatter(val)))
ax.tick_params(axis="both", labelsize=12 * font_scale)

plt.tight_layout()
plt.savefig("Results_final/reconstruction_error_vs_s_wd_scan.png", dpi=300)
plt.show()

# === Additional analysis and plots ===

# 1. Weight decay giving the lowest reconstruction error, as a function of s (one curve per p)
plot_optimal_param_vs_x(
    results=results,
    param_x="s",
    metric="reconstruction_errors",
    parameter="weight_decay",
    group_by="p",
    save_path="Results_final/optimal_param_vs_s_groupp_reconstruction",
    font_scale=2,
    log_scale_x=True,
    log_scale_y=True,
)

# 2.-3. One figure per metric, all with the same layout (log s axis, one curve per p,
# max_overall=True, fill_between=True). Figures are produced in the order listed.
metric_figures = [
    ("reconstruction_error_scaled_per_row", "Results_final/reconstruction_error_scaled_per_row_vs_s_by_p"),
    ("reconstruction_errors", "Results_final/reconstruction_errors_vs_s_by_p"),
    ("pearson_corr", "Results_final/pearson_correlation_vs_s_by_p"),
    ("spearman_corr", "Results_final/spearman_correlation_vs_s_by_p"),
    ("reconstruction_error_scaled", "Results_final/reconstruction_error_scaled_vs_s_by_p"),
]
for metric_name, out_path in metric_figures:
    plot_metrics_vs_param(
        results=results,
        param_x="s",
        metrics=[metric_name],
        group_by="p",
        log_scale_x=True,
        save_path=out_path,
        max_overall=True,
        font_scale=1.5,
        fill_between=True,
    )

# 4. Plot α vs s with reference 1/s line
def plot_alpha_vs_s(results, s_min=0.1, s_max=1e5, save_path="Results_final/alpha_vs_s_group_p",
                    weight_decay=5e-6):
    """
    Plot the ``alpha`` metric vs s (one curve per p) and overlay a dashed 1/s reference.

    Args:
        results (list[dict]): Experiments with ``params`` and ``results`` entries.
        s_min (float): Exclusive lower bound on s.
        s_max (float): Exclusive upper bound on s.
        save_path (str): Base path passed to ``plot_metrics_vs_param``. The figure with
            the reference line is saved to ``save_path + "_.png"``.
        weight_decay (float): Only experiments with exactly this weight decay are kept.
            Defaults to 5e-6, the value that was previously hard-coded.
    """
    filtered = [
        exp for exp in results
        if s_min < exp['params'].get('s') < s_max and exp['params'].get('weight_decay') == weight_decay
    ]

    plot_metrics_vs_param(
        results=filtered,
        param_x="s",
        metrics=["alpha"],
        group_by="p",
        log_scale_x=True,
        log_scale_y=True,
        save_path=save_path,
        sub_plot=False,
        font_scale=2,
        show_plot=False,
        max_overall=True,
        fill_between=True,
    )

    fig = plt.gcf()
    for ax in fig.get_axes():
        lines = ax.get_lines()
        if not lines:
            # Axes without any plotted curve: nothing to take the s grid from,
            # and indexing lines[0] would raise IndexError.
            continue
        x_vals = np.asarray(lines[0].get_xdata(), dtype=float)
        ax.plot(x_vals, 1 / x_vals, 'k--', label=r"$1/s$")
        ax.legend(fontsize=12 * 1.5)
        ax.set_xlabel(r"$s$", fontsize=12 * 1.5)
        ax.set_ylabel(r"$\alpha$", fontsize=12 * 1.5)

    plt.savefig(save_path + "_.png", dpi=300)
    plt.show()

# α vs s: first the whole scanned s range, then a zoom on moderate s values
alpha_ranges = [
    (0.1, 1e5, "Results_final/alpha_vs_s_group_p"),
    (0.5, 50.1, "Results_final/alpha_vs_s_group_p_small"),
]
for lo, hi, out in alpha_ranges:
    plot_alpha_vs_s(results, lo, hi, out)


## Plots vs $s$ by $k$

In [None]:
import pickle
from visualization import plot_metrics_vs_param, plot_optimal_param_vs_x

# === Load the K / s / weight_decay scan ===
with open("Data_final/scan_K_logspaceS_wdScan_p0.2_centered_soft_label_True_2.pkl", "rb") as f:
    results = pickle.load(f)

# Options shared by every plot_metrics_vs_param call of this cell:
# s on a log x-axis, one curve per K, subplots enabled, large fonts.
by_K = dict(param_x="s", group_by="K", log_scale_x=True, sub_plot=True, font_scale=2)

# --- Two representative weight decays, one panel each ---
results_1 = [exp for exp in results if exp['params']['weight_decay'] in [1e-6, 1e-4]]

# Plot 1: accuracy vs s
plot_metrics_vs_param(
    results=results_1,
    metrics=["accuracy"],
    split_by="weight_decay",
    save_path="Results_final/accuracy_vs_s_by_wd_groupK",
    **by_K,
)
# Plot 2: reconstruction error vs s (log y-axis)
plot_metrics_vs_param(
    results=results_1,
    metrics=["reconstruction_errors"],
    split_by="weight_decay",
    log_scale_y=True,
    save_path="Results_final/reconstruction_vs_s_by_wd_groupK",
    **by_K,
)
# Plot 3: best reconstruction error over all weight decays (linear y-axis)
plot_metrics_vs_param(
    results=results,
    metrics=["reconstruction_errors"],
    max_overall=True,
    log_scale_y=False,
    save_path="Results_final/reconstruction_vs_s_groupK_optimal",
    **by_K,
)

# --- Optimal weight decay vs s for a few K values ---
results_1 = [exp for exp in results if exp['params']['K'] in [1, 2, 50]]
optimal_wd = dict(param_x="s", parameter="weight_decay", group_by="K", font_scale=2,
                  log_scale_x=True, log_scale_y=True)

# Plot 4: weight decay minimizing the reconstruction error
plot_optimal_param_vs_x(
    results=results_1,
    metric="reconstruction_errors",
    save_path="Results_final/optimal_param_vs_s_groupK_reconstruction",
    **optimal_wd,
)
# Plot 5: weight decay maximizing the accuracy
plot_optimal_param_vs_x(
    results=results_1,
    metric="accuracy",
    save_path="Results_final/optimal_param_vs_s_groupK_accuracy",
    **optimal_wd,
)

# --- Weight decay fixed at 1e-5, s restricted to [0.1, 100] ---
results_3 = [
    exp for exp in results
    if 0.1 <= exp['params']['s'] <= 100 and exp['params']['weight_decay'] == 1e-5
]

# Plot 6: accuracy vs s
plot_metrics_vs_param(
    results=results_3,
    metrics=["accuracy"],
    max_overall=True,
    save_path="Results_final/accuracy_vs_s_groupK",
    **by_K,
)
# Plot 7: reconstruction error vs s (log y-axis)
plot_metrics_vs_param(
    results=results_3,
    metrics=["reconstruction_errors"],
    max_overall=True,
    log_scale_y=True,
    save_path="Results_final/reconstruction_vs_s_groupK",
    **by_K,
)


In [None]:
import pickle
import copy
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from scipy.stats import sem
from matplotlib import colors as mcolors
from visualization import plot_metrics_vs_param, smart_formatter

# === Load the K / s / weight_decay scan ===
file_path = "Data_final/scan_K_logspaceS_wdScan_p0.2_centered_soft_label_True_2.pkl"
with open(file_path, "rb") as f:
    results = pickle.load(f)


def _keep_for_wd_comparison(exp):
    """Keep experiments with p = 0.2, K in {1, 10, 50} and 0.01 <= s <= 100."""
    params = exp["params"]
    return params["p"] == 0.2 and params["K"] in [1, 10, 50] and 0.01 <= params["s"] <= 100


results_p_02 = list(filter(_keep_for_wd_comparison, results))

# Deep copies, so that the "wd" legend label added below does not leak into `results`
results_fixed_wd = [copy.deepcopy(exp) for exp in results_p_02 if exp["params"]["weight_decay"] == 1e-5]
results_optimized = [copy.deepcopy(exp) for exp in results_p_02]

for group, wd_label in ((results_fixed_wd, "$10^{-5}$"), (results_optimized, "Optimized")):
    for exp in group:
        exp["params"]["wd"] = wd_label

# Both groups together, for the alternative single-call plot sketched below
results_combined = results_fixed_wd + results_optimized

# Alternative (disabled): let plot_metrics_vs_param draw both groups at once.
# plot_metrics_vs_param(
#     results=results_combined,
#     param_x="s",
#     metrics=["reconstruction_errors"],
#     max_overall=True,
#     group_by=["wd", "K"],
#     log_scale_x=True,
#     log_scale_y=True,
#     save_path="Results_final/reconstruction_error_vs_s_wd_by_K",
#     font_scale=1.5,
# )

# === Optimized experiments grouped by K ===
grouped_by_K = defaultdict(list)
for exp in results_optimized:
    grouped_by_K[exp["params"]["K"]].append(exp)


def _best_mean_and_sem(error_lists):
    """Return (mean, SEM) of the error list with the smallest mean (first one wins ties).

    Returns (inf, None) if no list has a mean below infinity.
    """
    best = (float("inf"), None)
    for errs in error_lists:
        candidate_mean = np.mean(errs)
        if candidate_mean < best[0]:
            best = (candidate_mean, sem(errs) if len(errs) > 1 else 0.0)
    return best


# === Best mean ± SEM per s for each K ===
K_to_s_best = {}
for K, exps in grouped_by_K.items():
    errors_by_s = defaultdict(list)
    for exp in exps:
        errors_by_s[exp["params"]["s"]].append(exp["results"]["reconstruction_errors"])
    K_to_s_best[K] = {s: _best_mean_and_sem(lists) for s, lists in errors_by_s.items()}

# === Figure and one viridis color per K (in increasing K order) ===
fig, ax = plt.subplots(figsize=(9, 6))
font_scale = 1.5

colors = plt.cm.viridis(np.linspace(0, 1, len(K_to_s_best)))
K_to_color = dict(zip(sorted(K_to_s_best), colors))

def shift_color(color, factor=0.85):
    """Scale the RGB part of ``color`` by ``factor`` and return an RGBA tuple.

    ``factor < 1`` darkens and ``factor > 1`` lightens. Each channel is clipped at 1,
    and alpha is left untouched.
    """
    rgba = mcolors.to_rgba(color)
    red, green, blue = (min(value * factor, 1) for value in rgba[:3])
    return (red, green, blue, rgba[3])

# === Optimized curves: best experiment per s (dashed, with markers) ===
for K, s_dict in sorted(K_to_s_best.items()):
    s_sorted = sorted(s_dict)
    means = [s_dict[s][0] for s in s_sorted]
    sems = [s_dict[s][1] for s in s_sorted]
    base_color = K_to_color[K]
    ax.errorbar(
        s_sorted, means, yerr=sems,
        label=f"Optimized (K={K})",
        capsize=4,
        marker='o',
        linestyle='--',
        color=base_color
    )

# === Fixed wd = 1e-5 curves (solid, darker shade of the K color) ===
grouped_fixed_by_K = defaultdict(list)
for exp in results_fixed_wd:
    grouped_fixed_by_K[exp["params"]["K"]].append(exp)

for K, exps in sorted(grouped_fixed_by_K.items()):
    s_to_errors = defaultdict(list)
    for exp in exps:
        s_to_errors[exp["params"]["s"]].append(exp["results"]["reconstruction_errors"])

    s_sorted = sorted(s_to_errors)
    # Pool the error values of every experiment at this s.
    # np.hstack handles lists and arrays alike, whereas sum(..., []) required lists.
    pooled = [np.hstack(s_to_errors[s]) for s in s_sorted]
    means = [np.mean(vals) for vals in pooled]
    # The SEM needs more than one pooled *value*. Counting experiments instead
    # forced a zero error bar whenever a single experiment existed for this s.
    sems = [sem(vals) if vals.size > 1 else 0.0 for vals in pooled]
    shifted_color = shift_color(K_to_color[K], 0.7)

    ax.errorbar(
        s_sorted, means, yerr=sems,
        label=rf"$w_d=10^{{-5}}$ (K={K})",
        capsize=4,
        linestyle='-',
        color=shifted_color
    )

# === Format axes, grid, and legend ===
ax.set_xlabel(r"$s$", fontsize=14 * font_scale)
ax.set_ylabel(r"Reconstruction Error (scaled)", fontsize=14 * font_scale)
ax.set_xscale("log")
ax.set_yscale("log")
ax.grid(True, linestyle="--", alpha=0.5)
ax.legend(fontsize=11 * font_scale)

ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda val, _: smart_formatter(val)))
ax.tick_params(axis="both", labelsize=12 * font_scale)

plt.tight_layout()
plt.savefig("Results_final/reconstruction_error_vs_s_wd_by_K.png", dpi=300)
plt.show()


In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from visualization import plot_metrics_vs_param
from visualization import smart_formatter

# === Load the K / s / weight_decay scan ===
file_path = "Data_final/scan_K_logspaceS_wdScan_p0.2_centered_soft_label_True_2.pkl"
with open(file_path, "rb") as f:
    results = pickle.load(f)

# Layout shared by both reconstruction_error_scaled figures below:
# log s axis, linear y axis, one curve per K, max_overall=True, subplots, large fonts.
scaled_layout = dict(param_x="s", group_by="K", max_overall=True, log_scale_x=True,
                     log_scale_y=False, sub_plot=True, font_scale=2)

# One panel per weight_decay in {1e-6, 1e-4}
results_2 = [exp for exp in results if exp['params']['weight_decay'] in [1e-6, 1e-4]]
plot_metrics_vs_param(
    results=results_2,
    metrics=["reconstruction_error_scaled"],
    split_by="weight_decay",
    save_path="Results_final/reconstruction_scaled_vs_s_by_wd_groupK",
    **scaled_layout,
)

# All weight_decay values together
plot_metrics_vs_param(
    results=results,
    metrics=["reconstruction_error_scaled"],
    save_path="Results_final/reconstruction_scaled_vs_s_groupK",
    **scaled_layout,
)

# === Define reusable function to plot alpha vs s and overlay 1/s ===
def plot_alpha_vs_s(results, s_min=-1, s_max=1e5, save_path="Results_final/alpha_vs_s_groupK",
                    weight_decays=(1e-5, 5e-5, 1e-4, 5e-4)):
    """
    Plot the ``alpha`` metric vs s, one subplot per weight decay and one curve per K,
    then overlay a dashed 1/s reference on every subplot that holds data.

    Args:
        results (list[dict]): Experiments with ``params`` and ``results`` entries.
        s_min (float): Exclusive lower bound on s. The default -1 applies no lower
            bound when s is positive.
        s_max (float): Exclusive upper bound on s.
        save_path (str): Base path passed to ``plot_metrics_vs_param``. The final
            figure is saved to ``save_path + ".png"``.
        weight_decays (tuple[float]): Weight decays to keep. The default matches the
            values that were previously hard-coded.
    """
    filtered = [
        exp for exp in results
        if s_min < exp['params'].get('s') < s_max and exp['params'].get('weight_decay') in weight_decays
    ]

    plot_metrics_vs_param(
        results=filtered,
        param_x="s",
        metrics=["alpha"],
        group_by="K",
        log_scale_x=True,
        log_scale_y=True,
        split_by="weight_decay",
        save_path=save_path,
        sub_plot=True,
        font_scale=2,
        show_plot=False,
    )

    # Overlay the 1/s reference on each subplot
    fig = plt.gcf()
    for ax in fig.get_axes():
        lines = ax.get_lines()
        if not lines:
            # Unused subplot or axes without curves: lines[0] would raise IndexError
            continue
        x_vals = np.asarray(lines[0].get_xdata(), dtype=float)
        ax.plot(x_vals, 1 / x_vals, 'k--', label=r"$1/s$")
        ax.legend(fontsize=12 * 2)
    plt.savefig(save_path + ".png", dpi=300)
    plt.show()

# === alpha vs s: wide s range, then zoom on small s ===
plot_alpha_vs_s(results, s_min=0.15, s_max=1e5)
plot_alpha_vs_s(results, s_min=0.15, s_max=10.1, save_path="Results_final/alpha_vs_s_groupK_small")

# === Subsets for the correlation metrics: 0.009 < s < 10 ===
corr_wds = [1e-5, 5e-5, 1e-4, 5e-4]


def _in_corr_s_range(exp):
    """True when 0.009 < s < 10 (strict bounds)."""
    return 0.009 < exp['params'].get('s') < 10


results_2 = [exp for exp in results if _in_corr_s_range(exp) and exp['params'].get('weight_decay') in corr_wds]
results_3 = [exp for exp in results if _in_corr_s_range(exp)]

# === Pearson then Spearman: one panel per weight decay, then all weight decays together ===
corr_layout = dict(param_x="s", group_by="K", log_scale_x=True, sub_plot=True,
                   max_overall=True, font_scale=2, fill_between=True)
for metric in ("pearson_corr", "spearman_corr"):
    plot_metrics_vs_param(
        results=results_2,
        metrics=[metric],
        split_by="weight_decay",
        save_path=f"Results_final/{metric}_vs_s_by_wd_groupK",
        **corr_layout,
    )
    plot_metrics_vs_param(
        results=results_3,
        metrics=[metric],
        save_path=f"Results_final/{metric}_vs_s_groupK",
        **corr_layout,
    )

# === Row-wise scaled reconstruction error vs s (one curve per K) ===
plot_metrics_vs_param(
    results=results,
    param_x="s",
    metrics=["reconstruction_error_scaled_per_row"],
    group_by="K",
    log_scale_x=True,
    save_path="Results_final/reconstruction_error_scaled_per_row_vs_s_groupK",
    max_overall=True,
    font_scale=1.5,
    # ylim=(0, 0.4)  # Uncomment to set custom y-axis limits
)


# Plots explaining the difference in Reconstruction Error Scaled per Row and the Pearson coefficient

These plots focus on the high-$s$ regime.

## Comparison of the distribution of the values per row

In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from random import choice

# === Load the K / s / weight_decay scan (file name indicates p = 0.2) ===
results_path = "Data_final/scan_K_logspaceS_wdScan_p0.2_centered_soft_label_True_2.pkl"
with open(results_path, "rb") as handle:
    results = pickle.load(handle)

def plot_sampled_comparison_aligned(UVT_row, X_row, title=None, save_path=None, font_scale=1.5, real_index=None):
    """
    Compare one sampled row of UV^T with the matching row of X.

    Both rows are reordered by increasing X and drawn against their rank.
    UV^T is a solid red line on the left y-axis and X a dashed blue line on a
    twin right y-axis, so the two shapes can be compared despite different scales.

    Args:
        UVT_row (array-like): Predicted values (row of UV^T).
        X_row (array-like): Ground-truth values (row of X).
        title (str): Figure title. A default title is used when empty or None.
        save_path (str): If given, the figure is also saved to this path.
        font_scale (float): Multiplier for all font sizes.
        real_index (int): Index of the source experiment. Accepted for reference
            only; it does not appear in the figure.
    """
    predicted = np.array(UVT_row)
    truth = np.array(X_row)
    order = np.argsort(truth)
    ranks = np.arange(len(truth))
    label_size = 12 * font_scale

    fig, left_ax = plt.subplots(figsize=(8, 5))

    red = 'tab:red'
    left_ax.set_ylabel(r'$UV^\top$', color=red, fontsize=label_size)
    left_ax.plot(ranks, predicted[order], color=red, label=r'$UV^\top$')
    left_ax.tick_params(axis='y', labelcolor=red, labelsize=label_size)
    left_ax.tick_params(axis='x', labelsize=label_size)

    right_ax = left_ax.twinx()
    blue = 'tab:blue'
    right_ax.set_ylabel(r'$X$', color=blue, fontsize=label_size)
    right_ax.plot(ranks, truth[order], color=blue, linestyle='--', label=r'$X$')
    right_ax.tick_params(axis='y', labelcolor=blue, labelsize=label_size)

    heading = title or "$UV^\\top$ vs $X$ (sorted)"
    fig.suptitle(heading, fontsize=14 * font_scale)
    left_ax.set_xlabel("Sorted Index", fontsize=label_size)
    fig.tight_layout()
    left_ax.grid(True, linestyle="--", alpha=0.5, which='both')

    if save_path:
        plt.savefig(save_path, bbox_inches="tight")
    plt.show()

def find_closest_index_by_s(results, s_target):
    """
    Return the position in ``results`` of the experiment whose s is nearest to ``s_target``.

    Experiments without an ``s`` parameter (missing or None) are ignored. On ties,
    the earliest experiment wins.

    Args:
        results (list): Experiment dictionaries, each with a ``params`` dict.
        s_target (float): Value of s to approach.

    Returns:
        int: Index of the closest experiment, or -1 when no experiment has an s.
        Note that -1 is also a valid Python index (the last element), so callers
        should check for it explicitly.
    """
    best_gap, best_idx = float('inf'), -1
    for idx, exp in enumerate(results):
        s_val = exp["params"].get("s")
        if s_val is None:
            continue
        gap = abs(s_val - s_target)
        if gap < best_gap:
            best_gap, best_idx = gap, idx
    return best_idx

# === Plot comparison for a few s values ===
s_targets = [0.1, 5, 100]
i_indices = [find_closest_index_by_s(results, s_target) for s_target in s_targets]
print(f"Closest indices found for s={s_targets}: {i_indices}")

# find_closest_index_by_s returns -1 when no experiment has an "s" parameter.
# -1 is a valid list index, so fail loudly instead of silently plotting results[-1].
if -1 in i_indices:
    raise ValueError("No experiment with an 's' parameter was found in results.")

rep = [0]  # Repetition indices
for k, i in enumerate(i_indices):
    # The scanned s grid need not contain the targets exactly. Title the figure
    # with the s actually used, and keep the target in the file name so the
    # output paths stay the same.
    actual_s = results[i]["params"]["s"]
    for reps in rep:
        sampled_UVT_rows = results[i]["results"]["sampled_UVT_rows"][reps]
        sampled_X_rows = results[i]["results"]["sampled_X_rows"][reps]
        row_id = choice(range(len(sampled_UVT_rows)))  # random row (unseeded)

        UVT_row = sampled_UVT_rows[row_id]
        X_row = sampled_X_rows[row_id]

        plot_sampled_comparison_aligned(
            UVT_row,
            X_row,
            title=f"s = {actual_s:.3g}",
            real_index=i,
            save_path=f"Results_final/sample_comparison_s_{s_targets[k]}.png"
        )


## Histogram of the distribution of the $\alpha_u$

In [None]:
import pickle
import matplotlib.pyplot as plt
import numpy as np

# === Load results ===
file_path_1 = "Data_final/scan_K_logspaceS_wdScan_p0.2_centered_soft_label_True_2.pkl"
with open(file_path_1, "rb") as f:
    results = pickle.load(f)

# Filter experiments for K = 1
results = [exp for exp in results if exp['params'].get('K') == 1]

# Target values of s for which we want histograms
target_s_values = [0.1, 5, 10, 100]

# For each target, pick the first experiment (in list order) whose s is closest.
# NOTE(review): several experiments (e.g. different weight_decay values) may share
# that s; only the first one is shown — confirm this is the intended run.
selected_experiments = []
selected_s_values = []
for s_target in target_s_values:
    closest_exp = min(results, key=lambda exp: abs(exp["params"]["s"] - s_target))
    s_val = closest_exp["params"]["s"]
    if s_val not in selected_s_values:  # two targets can map to the same s
        selected_experiments.append(closest_exp)
        selected_s_values.append(s_val)

print("Selected s values:", selected_s_values)

# === One panel per selected s on a two-column grid (2x2 for four targets) ===
font_scale = 0.7
n_panels = len(selected_experiments)
n_rows = max(1, (n_panels + 1) // 2)
fig, axes = plt.subplots(n_rows, 2, figsize=(12, 4 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, exp, s_val in zip(axes, selected_experiments, selected_s_values):
    # alpha_per_row may be flat or hold one sequence per repetition. np.hstack
    # flattens lists and arrays alike; the former isinstance(list) check
    # did not flatten per-repetition arrays.
    all_slopes = np.hstack(exp["results"]["alpha_per_row"])

    # Mean of the per-row slopes vs the mean of the global alpha values
    the_mean = np.mean(all_slopes)
    the_alpha = np.mean(exp["results"]["alpha"])

    # === Histogram of per-row slopes with vertical reference lines ===
    ax.hist(all_slopes, bins=50, alpha=0.7, color='blue')
    ax.axvline(the_mean, color='red', linestyle='--', label=fr"Mean of slopes = {the_mean:.5f}")
    ax.axvline(the_alpha, color='black', linestyle='--', label=fr"Global Alpha = {the_alpha:.5f}")
    ax.set_title(rf"$s = {s_val}$", fontsize=22 * font_scale)
    ax.set_xlabel(r"$\alpha_u$", fontsize=25 * font_scale)
    ax.set_ylabel("Count", fontsize=20 * font_scale)
    ax.grid(True, linestyle='--', alpha=0.5)
    ax.tick_params(axis='both', labelsize=12 * 2 * font_scale)
    ax.legend(fontsize=20 * font_scale)

# Hide panels left empty when several targets collapsed onto the same s
for ax in axes[n_panels:]:
    ax.set_visible(False)

# Final layout and saving
plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.savefig("Results_final/Slopes_histogram_4_different_s.png", dpi=300, bbox_inches='tight')
plt.show()


## Effect of the outlier and of the noise on the metrics

In [None]:
import numpy as np
from scipy.stats import pearsonr
import matplotlib.pyplot as plt

def compute_metrics(x, y):
    """
    Compute the Pearson correlation and the normalized reconstruction error between x and y.

    Args:
        x (array-like): Reference vector; its norm normalizes the error.
        y (array-like): Vector compared against ``x`` (same length).

    Returns:
        tuple: ``(pearson, reconstruction_error)``, where ``pearson`` is the Pearson
        correlation coefficient and ``reconstruction_error = ||x - y|| / ||x||``.
    """
    # Convert to arrays so plain lists work too: ``x - y`` is undefined for lists.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    pearson, _ = pearsonr(x, y)
    reconstruction_error = np.linalg.norm(x - y) / np.linalg.norm(x)
    return pearson, reconstruction_error

def plot_outlier_impact(font_scale=1.5):
    """
    Plot how the magnitude of a single outlier affects the Pearson correlation and the
    normalized reconstruction error.

    The reference is x = linspace(0, 10, 100). It is compared with y = x plus Gaussian
    noise (std 0.5), whose last entry is shifted by an outlier magnitude going from 0
    to 100. Both metrics from ``compute_metrics`` are drawn on twin y-axes. The figure
    is saved to ``Results_final/outlier_impact_pearson_reconstruction.png``.

    Args:
        font_scale (float): Multiplier applied to the font sizes of this figure only.
    """
    font_sizes = {
        'axes.titlesize': 14 * font_scale,
        'axes.labelsize': 12 * font_scale,
        'xtick.labelsize': 10 * font_scale,
        'ytick.labelsize': 10 * font_scale,
        'legend.fontsize': 10 * font_scale,
        'figure.titlesize': 15 * font_scale
    }
    # rc_context restores the previous rcParams on exit. A global plt.rcParams.update
    # would change the fonts of every figure created later in the session.
    with plt.rc_context(font_sizes):
        # === Generate reference vectors ===
        # Local generator seeded like np.random.seed(0): same noise values, but the
        # global NumPy random state is left untouched.
        rng = np.random.RandomState(0)
        x = np.linspace(0, 10, 100)
        y_base = x + rng.normal(0, 0.5, size=x.shape)

        # === Define outlier magnitudes ===
        outlier_magnitudes = np.linspace(0, 100, 200)
        pearsons = []
        errors = []

        # === Evaluate metrics as the outlier grows ===
        for magnitude in outlier_magnitudes:
            y = y_base.copy()
            y[-1] += magnitude  # Inject outlier
            p, err = compute_metrics(x, y)
            pearsons.append(p)
            errors.append(err)

        # === Create the plot ===
        fig, ax1 = plt.subplots(figsize=(8, 5))

        color1 = 'tab:blue'
        ax1.set_xlabel("Outlier magnitude")
        ax1.set_ylabel("Pearson correlation", color=color1)
        ax1.plot(outlier_magnitudes, pearsons, color=color1, label="Pearson")
        ax1.tick_params(axis='y', labelcolor=color1)
        ax1.set_ylim(0.0, 1.01)
        ax1.grid(True, linestyle="--", alpha=0.3)

        ax2 = ax1.twinx()
        color2 = 'tab:red'
        ax2.set_ylabel("Reconstruction error", color=color2)
        ax2.plot(outlier_magnitudes, errors, color=color2, linestyle='--', label="Reconstruction Error")
        ax2.tick_params(axis='y', labelcolor=color2)
        # NOTE(review): with outlier magnitudes up to 100 the error can exceed 1.05
        # and is then clipped by this fixed range — confirm this is intended.
        ax2.set_ylim(0, 1.05)
        ax2.invert_yaxis()  # Optional: inverted for visual emphasis

        fig.suptitle("Effect of Outlier Magnitude on Pearson and Reconstruction Error")
        fig.tight_layout()
        plt.savefig("Results_final/outlier_impact_pearson_reconstruction.png", bbox_inches="tight")
        plt.show()

# === Draw the outlier-impact figure with the default font scale ===
plot_outlier_impact(font_scale=1.5)


# Plots for $p \cdot k = \text{const}$

In [None]:
from visualization import plot_metrics_vs_param
import pickle

# === Load results ===
file_path = "Data_final/scan_pK_constant_Final_s_wd_sweep.pkl"
with open(file_path, "rb") as f:
    all_results = pickle.load(f)

# === Tag every experiment with the product p*K (rounded to 4 decimals);
# it is the grouping key of every figure below ===
for exp in all_results:
    run_params = exp['params']
    run_params['pxK'] = round(run_params['p'] * run_params['K'], 4)

# === Split experiments by scaling factor s (dict order = plotting order) ===
results_by_s = {
    s_val: [exp for exp in all_results if exp['params'].get('s') == s_val]
    for s_val in (1, 3, 5, 8)
}
results_s1, results_s3, results_s5, results_s8 = results_by_s.values()

# Keyword arguments shared by every figure in this cell.
pxK_plot_kwargs = dict(
    param_x="K",
    group_by="pxK",
    max_overall=True,
    grid=True,
    sub_plot=False,
    show_plot=True,
    font_scale=1.5,
)

# (metric, file-name stem, extra keyword arguments) for each figure drawn per s.
base_specs = [
    ("accuracy", "accuracy", {"GT_plot": False}),
    ("reconstruction_errors", "reconstruction_errors", {}),
]
# The per-row scaled reconstruction error is only plotted for s = 5 and s = 8.
per_row_spec = (
    "reconstruction_error_scaled_per_row",
    "reconstruction_errors_scaled_per_row",
    {},
)

# === Plot each metric vs K, grouped by p*K, for every value of s ===
for s_val, subset in results_by_s.items():
    specs = base_specs + ([per_row_spec] if s_val in (5, 8) else [])
    for metric_name, file_stem, extra_kwargs in specs:
        plot_metrics_vs_param(
            results=subset,
            metrics=[metric_name],
            save_path=f"Results_final/{file_stem}_vs_K_grouped_by_pxK_s{s_val}",
            **pxK_plot_kwargs,
            **extra_kwargs,
        )

# Other metrics (e.g. "pearson_corr") can be plotted the same way: call
# plot_metrics_vs_param(results=results_s8, metrics=["pearson_corr"],
# save_path=..., **pxK_plot_kwargs).


# Plots vs $p$

In [None]:
from visualization import plot_metrics_vs_param
import matplotlib.pyplot as plt
import pickle

# === Load results ===
filename = "Data_final/scan_pK_Final.pkl"
with open(filename, "rb") as f:
    results = pickle.load(f)

# Options common to the three accuracy / error figures in this cell.
p_scan_kwargs = dict(
    param_x="p",
    group_by="K",
    log_scale_x=True,
    grid=True,
    max_overall=True,
    fill_between=True,
)

# === Plot 1: Accuracy vs p for every K ===
plot_metrics_vs_param(
    results=results,
    metrics=["accuracy"],
    title="Accuracy vs p grouped by K",
    save_path="Results_final/accuracy_vs_p_grouped_by_K",
    sub_plot=False,
    show_plot=True,
    font_scale=1.5,
    **p_scan_kwargs,
)

# === Keep K in {1, 3, 10} with 0.05 <= p <= 0.5 ===
selected_K = {1, 3, 10}
results_k = [
    exp
    for exp in results
    if exp['params']['K'] in selected_K and 0.05 <= exp['params']['p'] <= 0.5
]

# === Plot 2: Accuracy vs p for K = 1, 3, 10 ===
# No save_path and show_plot=False: the legend is added and the figure is
# saved/shown by hand right after the call.
plot_metrics_vs_param(
    results=results_k,
    metrics=["accuracy"],
    title="Accuracy vs p (K = 1, 3, 10)",
    sub_plot=True,
    show_plot=False,
    ylim=(0.49, 0.85),
    font_scale=1.3,
    **p_scan_kwargs,
)
legend_fontsize = 12 * 1.3
plt.legend(loc="lower right", fontsize=legend_fontsize, title_fontsize=legend_fontsize)
plt.savefig("Results_final/accuracy_vs_p_K1_to_10.png", dpi=300, bbox_inches='tight')
plt.show()

# === Plot 3: Reconstruction error vs p for K = 1, 3, 10 ===
plot_metrics_vs_param(
    results=results_k,
    metrics=["reconstruction_errors"],
    title="Reconstruction Error vs p (K = 1, 3, 10)",
    save_path="Results_final/reconstruction_errors_vs_p_grouped_by_K",
    sub_plot=True,
    show_plot=True,
    font_scale=1.3,
    **p_scan_kwargs,
)


## Plots for $p \cdot s = \text{const}$

In [None]:
import pickle
import matplotlib.pyplot as plt
from visualization import plot_metrics_vs_param

# === Load results from both experiment files and merge them (file order kept) ===
results = []
for pkl_path in (
    "Data_final/scan_ps_constant_Final_2.pkl",
    "Data_final/scan_ps_constant_Final.pkl",
):
    with open(pkl_path, "rb") as f:
        results.extend(pickle.load(f))
print("✅ Results loaded successfully!")

# === Preprocess: add the product p*s (rounded to 2 decimals) to each run's params ===
for exp in results:
    p = exp['params']['p']
    s = exp['params']['s']
    exp['params']['p*s'] = round(p * s, 2)

# Keep only the selected p*s levels and 1 <= s <= 9. Rounding p*s above is what
# makes the float membership test against these literals reliable.
target_pxs = [0.9, 0.7, 0.6, 0.55, 0.5, 0.35, 0.25]
results = [
    exp for exp in results
    if exp['params']['p*s'] in target_pxs and 1 <= exp['params']['s'] <= 9
]

# Keyword arguments shared by the three figures below.
pxs_plot_kwargs = dict(
    param_x="s",
    group_by="p*s",
    max_overall=True,
    grid=True,
    sub_plot=False,
    font_scale=1.5,
)

# === Plot 1: Accuracy vs s, grouped by p*s ===
# show_plot=False so the legend and x-range can be adjusted on the current
# figure before it is saved and shown by hand.
plot_metrics_vs_param(
    results,
    metrics=["accuracy"],
    save_path="Results_final/accuracy_vs_s_grouped_by_pxs_2",
    show_plot=False,
    fill_between=True,
    **pxs_plot_kwargs,
)
plt.legend(loc='upper left', fontsize=12 * 1.3)
plt.xlim(1, 9.5)
# Saved under Results_final/ like every other figure in this notebook.
plt.savefig("Results_final/accuracy_vs_s_grouped_by_pxs_2.png", dpi=300, bbox_inches='tight')
plt.show()

# === Plot 2: Reconstruction Error vs s, grouped by p*s ===
plot_metrics_vs_param(
    results,
    metrics=["reconstruction_errors"],
    save_path="Results_final/reconstruction_error_vs_s_grouped_by_pxs_2",
    fill_between=True,
    **pxs_plot_kwargs,
)

# === Plot 3: Scaled Reconstruction Error per row vs s, grouped by p*s ===
plot_metrics_vs_param(
    results,
    metrics=["reconstruction_error_scaled_per_row"],
    save_path="Results_final/reconstruction_error_scaled_per_row_vs_s_grouped_by_pxs_2",
    **pxs_plot_kwargs,
)


## Plots for $p$ vs $d$

In [None]:
import pickle
from visualization import plot_all_heatmaps

# === Load the (p, d) scan results ===
pickle_path = "./Data_final/p_d_1.pkl"
with open(pickle_path, "rb") as handle:
    results = pickle.load(handle)
print("✅ Results successfully loaded from './Data_final/p_d_1.pkl'.")

# === Heatmap of accuracy over the (p, d) grid ===
heatmap_options = {
    "result_metric": "accuracy",
    "param_x": "p",
    "param_y": "d",
    "fig_size": (10, 5),
    "font_scale": 1.3,
}
plot_all_heatmaps(
    results,
    save_path="./Results_final/p_d_accuracy_heatmap",
    **heatmap_options,
)


# Plots for Strategies

## Plots for strategies vs $s$

In [None]:
import pickle
from visualization import plot_metrics_vs_param

# === Strategies whose scans over s are loaded ===
strategies = ["random", "proximity", "margin", "variance", "popularity", "top_k", "svd"]

# === Concatenate every strategy's scan results (strategy order kept) ===
results = []
for strategy in strategies:
    with open(f"Data_strategies/run_vs_s_K1_{strategy}_wd_sweep.pkl", "rb") as f:
        results.extend(pickle.load(f))

# === Strategy groups compared on separate figures (file suffix set1 / set2) ===
group_1 = {"random", "proximity", "svd", "margin", "top_k"}
group_2 = {"random", "popularity"}

# === Metrics to plot ===
metrics = [
    "gt_accuracy",
    "accuracy",
    "reconstruction_error_scaled",
    "pearson_corr",
    "spearman_corr"
]

# === One figure per (metric, group), group 1 before group 2 ===
for metric in metrics:
    # Only "alpha"-type metrics are restricted to weight_decay == 1e-5.
    needs_wd_filter = "alpha" in metric

    for set_number, group in enumerate((group_1, group_2), start=1):
        group_runs = [
            r for r in results
            if r["params"]["strategy"] in group
            and (not needs_wd_filter or r["params"].get("weight_decay") == 1e-5)
        ]
        # NOTE(review): `metrics` has no "reconstruction_errors" entry, so this
        # s > 0.2 filter never fires; confirm whether it was meant for
        # "reconstruction_error_scaled".
        if metric == "reconstruction_errors":
            group_runs = [r for r in group_runs if r["params"].get("s", 0) > 0.2]

        plot_metrics_vs_param(
            results=group_runs,
            param_x="s",
            metrics=[metric],
            group_by="strategy",
            log_scale_x=True,
            log_scale_y=("loss" in metric or "alpha" in metric),
            sub_plot=True,
            font_scale=1.6,
            max_overall=True,
            save_path=f"Results_strategies/{metric}_vs_s_group_strategy_set{set_number}",
            use_color_gradient=False,
            GT_plot=False,
            fill_between=True,
        )


## Plots for strategies vs $p$

In [None]:
import pickle
from visualization import plot_metrics_vs_param

# === Load results for scans over p ===
strategies = ["random", "proximity", "margin", "variance", "popularity", "top_k", "svd"]
results_p = []
# Each strategy has two run files, the original and a "_2" one; both are
# appended, original first.
for strategy in strategies:
    for run_suffix in ("", "_2"):
        with open(f"Data_strategies/run_vs_p_{strategy}{run_suffix}.pkl", "rb") as f:
            results_p.extend(pickle.load(f))

# === Strategy groups compared on separate figures (file suffix set1 / set2) ===
group_1 = {"random", "proximity", "svd"}
group_2 = {"random", "popularity"}

# === Metrics to plot ===
metrics = ["accuracy", "reconstruction_error_scaled", "pearson_corr", "spearman_corr"]

# === One figure per (metric, group), group 1 before group 2 ===
for metric in metrics:
    for set_number, group in enumerate((group_1, group_2), start=1):
        group_runs = [run for run in results_p if run["params"]["strategy"] in group]
        plot_metrics_vs_param(
            results=group_runs,
            param_x="p",
            metrics=[metric],
            group_by="strategy",
            log_scale_x=True,
            # log y-axis only for loss/alpha-type metrics (none in the list above)
            log_scale_y=("loss" in metric or "alpha" in metric),
            sub_plot=True,
            font_scale=1.6,
            max_overall=True,
            save_path=f"Results_strategies/{metric}_vs_p_group_strategy_set{set_number}",
            use_color_gradient=False,
            GT_plot=False,
            fill_between=True,
        )


# Plots for Ground Truth Analysis

## Plots vs $p$

In [None]:
import pickle
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import sem
from visualization import plot_metrics_vs_param

# === Load the ground-truth experiment results ===
with open("Data_final/gt_scan_s5_Ksweep_pSweep_n1000.pkl", "rb") as f:
    results = pickle.load(f)
print("✅ Results successfully loaded from Data_final/gt_scan_s5_Ksweep_pSweep_n1000.pkl")

# === Keep only K in {1, 4, 9} ===
K_values = [1, 4, 9]
results_subset = [run for run in results if run['params']['K'] in K_values]

# === GT accuracy as a function of p (log x-axis), one curve per K ===
# NOTE(review): the file name says "vs_K" although the x-axis is p.
gt_plot_kwargs = {
    "param_x": "p",
    "metrics": ["gt_accuracy"],
    "group_by": "K",
    "sub_plot": True,
    "log_scale_x": True,
    "font_scale": 1.5,
    "save_path": "Results_final/gt_accuracy_vs_K",
}
plot_metrics_vs_param(results=results_subset, **gt_plot_kwargs)

# === Function to aggregate a per-run metric (mean and SEM) by a parameter ===
def aggregate_by_param(results, param_key, metric="gt_accuracy"):
    """
    Group runs by a parameter value and summarize a metric for each group.

    Each run's metric values (``res['results'][metric]``) are first averaged
    with ``np.mean``. Those per-run averages are then combined for every
    distinct value of ``res['params'][param_key]``.

    Parameters
    ----------
    results : list of dict
        Experiment records, each with ``'params'`` and ``'results'`` entries.
    param_key : str
        Key in ``res['params']`` to group by.
    metric : str, optional
        Key in ``res['results']`` to aggregate (default ``"gt_accuracy"``).

    Returns
    -------
    param_values : list
        Sorted distinct values of ``param_key``.
    means : list of float
        Mean of the per-run averages for each parameter value.
    errors : list of float
        Standard error of the mean (``scipy.stats.sem``, ddof=1) of the per-run
        averages. It is NaN for a group that holds a single run.
    """
    # Single pass over the runs: bucket per-run averages by parameter value.
    # Dict keys use the same hash/equality as the previous set + == filter.
    grouped = {}
    for res in results:
        run_avg = np.mean(res['results'][metric])
        grouped.setdefault(res['params'][param_key], []).append(run_avg)

    param_values = sorted(grouped)
    means = [np.mean(grouped[val]) for val in param_values]
    errors = [sem(grouped[val]) for val in param_values]
    return param_values, means, errors

# === SEM of the GT accuracy as a function of p (log x-axis), pooled over all K ===
p_vals, _, p_sems = aggregate_by_param(results, 'p')
font_scale = 1.5
title_fs = 14 * font_scale
label_fs = 12 * font_scale
tick_fs = 10 * font_scale

plt.figure(figsize=(7, 5))
# NOTE(review): a label is set but no legend is drawn.
plt.plot(p_vals, p_sems, 'd-', label="SEM of GT Accuracy")
plt.title("GT Error vs $p$", fontsize=title_fs)
plt.xlabel("$p$", fontsize=label_fs)
plt.ylabel("Error on Accuracy", fontsize=label_fs)
# Switch to log scale before styling the tick labels (same order as before).
plt.xscale('log')
plt.xticks(fontsize=tick_fs)
plt.yticks(fontsize=tick_fs)
plt.grid(True, linestyle="--", alpha=0.5)
plt.tight_layout()
plt.savefig("Results_final/gt_error_vs_p.png", dpi=300)
plt.show()


## Plots vs $d$

In [None]:
import pickle
from visualization import plot_metrics_vs_param

# === Load the experiment results ===
with open("Data_final/scan_d_s_gt.pkl", "rb") as f:
    results = pickle.load(f)
print("✅ Results successfully loaded from Data_final/scan_d_s_gt.pkl")

# === Plot: GT accuracy vs d, grouped by s ===
# save_path is given without a file extension, like every other
# plot_metrics_vs_param call in this notebook. The manual plt.savefig calls
# are the ones that spell out ".png".
# NOTE(review): with the previous "….png" value the extension may be doubled
# if plot_metrics_vs_param appends one. Confirm against visualization.py.
plot_metrics_vs_param(
    results=results,
    param_x="d",
    metrics=["gt_accuracy"],
    group_by="s",
    save_path="Results_final/gt_accuracy_d_vs_s",
    ylim=(0.5, 1),
    font_scale=1.5,
    # log_scale_x=True  # Uncomment if log scale is desired
)
