In [2]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import LinearSegmentedColormap

# ======================================================
# MODEL DIRECTORIES (DEDUPLICATED)
# ======================================================
# Full list of result directories; uncomment to plot every model at once.
# model_dirs = list(dict.fromkeys([
#     r"C:\Users\Gilles\Documents\school\Ma2\Thesis\VSC\IRIS\Own_code\results\ClassicalModel_NoRelu_Bias",
#     r"C:\Users\Gilles\Documents\school\Ma2\Thesis\VSC\IRIS\Own_code\results\ClassicalModel_NoRelu_Nobias",
#     r"C:\Users\Gilles\Documents\school\Ma2\Thesis\VSC\IRIS\Own_code\results\ClassicalModel_Relu_bias",
#     r"C:\Users\Gilles\Documents\school\Ma2\Thesis\VSC\IRIS\Own_code\results\ClassicalModel_ReLu_NoBias",
#     r"C:\Users\Gilles\Documents\school\Ma2\Thesis\VSC\IRIS\Own_code\results\OrthoModel_NoRelu_bias",
#     r"C:\Users\Gilles\Documents\school\Ma2\Thesis\VSC\IRIS\Own_code\results\OrthoModel_NoRelu_NoBias",
#     r"C:\Users\Gilles\Documents\school\Ma2\Thesis\VSC\IRIS\Own_code\results\OrthoModel_Relu_bias",
#     r"C:\Users\Gilles\Documents\school\Ma2\Thesis\VSC\IRIS\Own_code\results\OrthoModel_Relu_Nobias",
# ]))

# dict.fromkeys drops duplicates while keeping the listed order. A set would
# iterate in hash order, and str hashing is randomized per interpreter run,
# so models would be processed in a different order each time.
model_dirs = list(dict.fromkeys([
    r"C:\Users\Gilles\Documents\school\Ma2\Thesis\VSC\IRIS\Own_code\results\OrthoModel_Relu_bias\2nd run",
]))

# ======================================================
# METRICS
# ======================================================
# CSV columns that are aggregated and plotted (one heatmap each).
metrics = [
    "final_train_loss",
    "final_valid_loss",
    "final_train_accuracy",
    "final_valid_accuracy",
    "epochs_to_final_train",
    "epochs_to_final_valid"
]

# Metrics where a larger value is better; every other metric is treated as
# lower-is-better when choosing the colour direction.
higher_is_better = {
    "final_train_accuracy",
    "final_valid_accuracy"
}

# ======================================================
# GLOBAL COLORBAR LIMITS (ROBUST)
# ======================================================
def compute_global_limits(model_dirs, metrics, q_low=5, q_high=95):
    """Compute shared, percentile-based colour limits for each metric.

    Every ``*.csv`` file directly inside each directory of *model_dirs* is
    read, and the non-missing values of each metric column are pooled over
    all files and directories. The ``q_low``/``q_high`` percentiles are used
    instead of min/max so a few outlier runs cannot stretch the shared scale.

    Parameters
    ----------
    model_dirs : iterable of str
        Directories scanned (non-recursively) for ``.csv`` files.
    metrics : iterable of str
        Column names to collect. A CSV lacking a column is skipped for that
        column only.
    q_low, q_high : float, optional
        Lower / upper percentiles (0-100) used as the limits.

    Returns
    -------
    dict
        ``{metric: (low, high)}``. If no value at all was found for a metric,
        its entry is ``(None, None)`` so a ``Normalize(None, None)`` built
        from it autoscales to the plotted data instead of failing.
    """
    values = {m: [] for m in metrics}

    for d in model_dirs:
        for f in sorted(os.listdir(d)):
            if not f.endswith(".csv"):
                continue
            df = pd.read_csv(os.path.join(d, f))
            # Iterate the dict, not `metrics`, so one-shot iterables work.
            for m in values:
                if m in df.columns:
                    values[m].extend(df[m].dropna().to_numpy())

    limits = {}
    for m, v in values.items():
        if not v:
            # np.percentile cannot take percentiles of an empty array.
            limits[m] = (None, None)
            continue
        low, high = np.percentile(np.asarray(v, dtype=float), [q_low, q_high])
        limits[m] = (low, high)
    return limits

# Robust colour limits per metric, pooled over every model directory so all
# heatmaps of the same metric share one scale.
global_limits = compute_global_limits(model_dirs=model_dirs, metrics=metrics)

# ======================================================
# LIGHT GRAYSCALE COLORMAP (NO DARK TILES)
# ======================================================
# matplotlib.cm.get_cmap was deprecated in matplotlib 3.7 and removed in 3.9
# (this line raised the deprecation warning seen in the cell output);
# pyplot.get_cmap returns the same registered colormap.
base_gray = plt.get_cmap("gray")
# Resample only the 0.18-0.95 range of "gray" so even the darkest tile stays
# light enough for the black cell annotations to remain readable.
light_gray = LinearSegmentedColormap.from_list(
    "light_gray",
    base_gray(np.linspace(0.18, 0.95, 256)),  # <- no dark gray
)

# ======================================================
# ARCHITECTURE GRID
# ======================================================
# Heatmap axes: y = number of hidden layers, x = neurons per hidden layer.
layer_vals = list(range(1, 5))
neuron_vals = list(range(1, 5))

# ======================================================
# MAIN LOOP — PER MODEL, PER METRIC
# ======================================================
# For every model directory, average each metric over all runs of the same
# architecture and save one heatmap per metric (layers x neurons) into
# "<model_dir>/graphs_global_scale/<metric>.png". All heatmaps of a metric
# share the limits in `global_limits`, so tiles are comparable across
# models. The light end of the scale always marks the better value: higher
# for metrics in `higher_is_better`, lower for every other metric.
n_layers, n_neurons = len(layer_vals), len(neuron_vals)

# matplotlib's Normalize expects vmin <= vmax; a swapped pair does not invert
# the scale (the colorbar re-sorts it), so lower-is-better metrics use a
# reversed colormap with a normal Normalize instead.
light_gray_reversed = light_gray.reversed()

for model_dir in model_dirs:

    model_name = os.path.basename(model_dir)

    # --- LOAD DATA ---
    csv_files = sorted(f for f in os.listdir(model_dir) if f.endswith(".csv"))
    if not csv_files:
        # pd.concat([]) raises "No objects to concatenate"; skip this
        # directory instead of aborting every remaining model.
        print(f"⚠️ No CSV files found in {model_dir}; skipping.")
        continue

    df = pd.concat(
        [pd.read_csv(os.path.join(model_dir, f)) for f in csv_files],
        ignore_index=True,
    )

    # Architectures are looked up below as str([n] * l), e.g. "[2, 2]".
    # NOTE(review): assumes the CSV stores hidden_sizes in that list-repr form.
    df["hidden_sizes"] = df["hidden_sizes"].astype(str)

    # One row per architecture with "<metric>_mean" and "<metric>_std"
    # (sample std, NaN when an architecture has a single run).
    grouped = df.groupby("hidden_sizes")[metrics]
    summary_df = pd.concat(
        [grouped.mean().add_suffix("_mean"), grouped.std().add_suffix("_std")],
        axis=1,
    )

    save_dir = os.path.join(model_dir, "graphs_global_scale")
    os.makedirs(save_dir, exist_ok=True)

    # ==================================================
    # PLOTTING
    # ==================================================
    for metric in metrics:

        fig, ax = plt.subplots(figsize=(6.8, 6))

        ax.set_xticks(range(n_neurons))
        ax.set_yticks(range(n_layers))
        ax.set_xticklabels(neuron_vals)
        ax.set_yticklabels(layer_vals)
        ax.set_xlabel("Neurons per layer")
        ax.set_ylabel("Number of layers")
        ax.set_title(f"{model_name}\n{metric.replace('_', ' ').title()}")

        # Row j = layer_vals[j], column i = neuron_vals[i]; architectures
        # without results stay NaN (drawn as empty tiles).
        data = np.full((n_layers, n_neurons), np.nan)
        cells = []  # (column, row, mean, std) of architectures with results
        for i, n in enumerate(neuron_vals):
            for j, l in enumerate(layer_vals):
                arch = str([n] * l)
                if arch in summary_df.index:
                    mean = summary_df.loc[arch, f"{metric}_mean"]
                    std = summary_df.loc[arch, f"{metric}_std"]
                    data[j, i] = mean
                    cells.append((i, j, mean, std))

        vmin, vmax = global_limits[metric]
        norm = plt.Normalize(vmin, vmax)
        cmap = light_gray if metric in higher_is_better else light_gray_reversed

        im = ax.imshow(
            data,
            origin="lower",
            cmap=cmap,
            norm=norm,
            extent=(-0.5, n_neurons - 0.5, -0.5, n_layers - 0.5),
            interpolation="nearest",
        )

        # --- CELL ANNOTATIONS (BLACK TEXT) ---
        for i, j, mean, std in cells:
            # A single run yields NaN std; show only the mean then.
            if pd.isna(std):
                label = f"{mean:.4f}"
            else:
                label = f"{mean:.4f}\n±{std:.4f}"
            ax.text(
                i, j,
                label,
                ha="center", va="center",
                fontsize=9, color="black",
            )

        # Outline every grid cell, including architectures without results.
        for i in range(n_neurons):
            for j in range(n_layers):
                ax.add_patch(
                    plt.Rectangle((i - 0.5, j - 0.5), 1, 1,
                                  fill=False, linewidth=0.8)
                )

        # --- COLORBAR (GLOBAL SCALE, PUSHED RIGHT) ---
        cbar = fig.colorbar(
            im,
            ax=ax,
            fraction=0.035,
            pad=0.22,
        )
        cbar.set_label("Global performance scale")

        ax.set_xlim(-0.5, n_neurons - 0.5)
        ax.set_ylim(-0.5, n_layers - 0.5)

        fig.tight_layout()
        fig.savefig(os.path.join(save_dir, f"{metric}.png"), dpi=300)
        plt.close(fig)

print("✅ All models plotted with shared global color scales.")


  base_gray = cm.get_cmap("gray")


✅ All models plotted with shared global color scales.
