In [1]:
import pandas as pd
import numpy as np
from scipy.stats import pearsonr, spearmanr, norm
import matplotlib.pyplot as plt
import seaborn as sns
import statsmodels.api as sm

# ----------------------------------------------------------------------
# Global Plotting Configuration
# ----------------------------------------------------------------------
# Apply the seaborn theme first (it resets rcParams), then layer the
# project-wide matplotlib overrides on top in a single update.
sns.set_theme(style="whitegrid")
plt.rcParams.update(
    {
        "font.family": "serif",
        "figure.figsize": (10, 6),
        "figure.dpi": 100,
    }
)


# ----------------------------------------------------------------------
# 1. Data Loading and Aggregation
# ----------------------------------------------------------------------
def load_set_results(filepath: str, set_label: str) -> pd.DataFrame | None:
    """
    Load results for a particular model set and add basic derived columns.

    Expected columns (if present):
      - params: number of trainable parameters
      - train_error, test_error: classification error rates
      - gen_gap: (optional) generalisation gap; if absent, computed as test - train
    """
    try:
        df = pd.read_csv(filepath)
    except FileNotFoundError:
        print(f"[Warning] Could not find file: {filepath}. Skipping Set {set_label}.")
        return None

    df["set"] = set_label

    if "gen_gap" not in df.columns and {"train_error", "test_error"} <= set(df.columns):
        df["gen_gap"] = df["test_error"] - df["train_error"]

    if "params" in df.columns:
        df["log_params"] = np.log10(df["params"])

    return df


def aggregate_set_c_data(df_c_raw: pd.DataFrame | None) -> pd.DataFrame | None:
    """
    Pass-through for Set C results (single-run, per architecture).
    """
    if df_c_raw is None:
        return None

    if "id" not in df_c_raw.columns:
        raise KeyError("Set C results must contain an 'id' column for architecture grouping.")

    df_c = df_c_raw.copy()
    df_c["set"] = "C"
    if "params" in df_c.columns:
        df_c["log_params"] = np.log10(df_c["params"])

    print(f"\n[Set C Loaded] {len(df_c)} architectures (single-run results).")
    return df_c


def load_and_aggregate_all_sets() -> tuple[
    pd.DataFrame,
    pd.DataFrame | None,
    pd.DataFrame | None,
    pd.DataFrame | None,
    pd.DataFrame | None,
]:
    """
    Load the result files for Sets A, B and C and combine them.

    Returns a 5-tuple:

      - concatenation of every set that could be loaded (A, B, C order)
      - Set A frame, or None if its file is missing
      - Set B frame, or None if its file is missing
      - Set C frame after `aggregate_set_c_data`, or None
      - Set C frame exactly as loaded, or None

    Raises RuntimeError when none of the three sets is available.
    """
    sources = {
        "A": "dissertation_results_set_a.csv",
        "B": "dissertation_results_set_b.csv",
        "C": "dissertation_results_set_c.csv",
    }
    # Dict order guarantees the files are read in A, B, C sequence.
    loaded = {name: load_set_results(path, name) for name, path in sources.items()}

    set_c_raw = loaded["C"]
    set_c = aggregate_set_c_data(set_c_raw)

    available = []
    for candidate in (loaded["A"], loaded["B"], set_c):
        if candidate is not None:
            available.append(candidate)

    if len(available) == 0:
        raise RuntimeError("No result files found for Sets A, B, or C.")

    combined = pd.concat(available, ignore_index=True)
    return combined, loaded["A"], loaded["B"], set_c, set_c_raw


# ----------------------------------------------------------------------
# 2. Statistical Analysis
# ----------------------------------------------------------------------
def fisher_ci_for_correlation(r: float, n: int, alpha: float = 0.05) -> tuple[float, float]:
    """
    Return the (1 - alpha) confidence interval (low, high) for a correlation
    coefficient via Fisher's z-transform.

    The interval is built in z-space as arctanh(r) +/- z_crit / sqrt(n - 3)
    and mapped back with tanh. (NaN, NaN) is returned when n <= 3 or when
    |r| is numerically 1, where the transform is not usable.
    """
    # |r| ~ 1 covers both the r ~ +1 and r ~ -1 degenerate cases.
    if n <= 3 or np.isclose(abs(r), 1.0):
        return np.nan, np.nan

    centre = np.arctanh(r)
    std_err = 1.0 / np.sqrt(n - 3)
    margin = norm.ppf(1 - alpha / 2.0) * std_err

    return np.tanh(centre - margin), np.tanh(centre + margin)


def run_complexity_analysis(
    df: pd.DataFrame,
    label: str,
    metrics: list[str],
    output_csv: str | None = None,
) -> pd.DataFrame:
    """
    Perform correlation and simple linear regression analyses of each
    complexity metric against the generalisation gap (`gen_gap`).

    Implements Methodology Section 4.5:
      - Pearson and Spearman correlations (with p-values)
      - 95% CI for Pearson r and Spearman rho (Fisher z)
      - Simple linear regression: R^2 + F-test p-value
      - 95% CI for R^2 derived from Pearson r CI (R^2 = r^2 in 1D regression),
        with lower bound set to 0 when the r-CI crosses zero.

    Parameters:
      - df: results table; must contain `gen_gap` plus the metric columns
      - label: human-readable name used in printed messages
      - metrics: column names to analyse; absent columns are skipped
      - output_csv: if given, the summary table is also written to this path

    A metric is skipped (with an info message) when its column is absent,
    when fewer than two complete (metric, gen_gap) pairs remain after
    dropping NaNs, or when either variable is constant.

    Returns one row per analysed metric, or an empty DataFrame when
    `gen_gap` is missing or no metric could be analysed.
    """
    print(f"\n--- Complexity Analysis: {label} (N={len(df)}) ---")

    # Without the response variable nothing can be analysed; report it in the
    # same warn-and-return-empty style used below instead of a bare KeyError.
    if "gen_gap" not in df.columns:
        print(f"[Warning] Column 'gen_gap' not found in DataFrame for {label}; skipping analysis.")
        return pd.DataFrame()

    rows: list[dict] = []

    for metric in metrics:
        if metric not in df.columns:
            print(f"[Info] Metric '{metric}' not found in DataFrame for {label}; skipping.")
            continue

        sub = df[["gen_gap", metric]].dropna()
        n = len(sub)
        if n < 2:
            print(f"[Info] Not enough data for metric '{metric}' in {label} (N={len(sub)}); skipping.")
            continue

        y = sub["gen_gap"]
        x = sub[metric]

        # Correlations are undefined for a constant variable. In addition,
        # sm.add_constant skips adding the intercept when the regressor is
        # itself constant (default has_constant="skip"), which would silently
        # fit a different model and report a misleading R^2.
        if x.nunique() < 2 or y.nunique() < 2:
            print(
                f"[Info] Metric '{metric}' or 'gen_gap' is constant in {label} "
                f"(N={n}); correlation undefined, skipping."
            )
            continue

        r_pearson, p_pearson = pearsonr(x, y)
        r_spearman, p_spearman = spearmanr(x, y)

        r_ci_low, r_ci_high = fisher_ci_for_correlation(r_pearson, n)
        rho_ci_low, rho_ci_high = fisher_ci_for_correlation(r_spearman, n)

        X_reg = sm.add_constant(x)
        model = sm.OLS(y, X_reg).fit()
        r_squared = model.rsquared
        f_pvalue = model.f_pvalue

        if np.isnan(r_ci_low):
            r2_ci_low = np.nan
            r2_ci_high = np.nan
        else:
            squared_bounds = (r_ci_low**2, r_ci_high**2)
            # If the r-interval straddles zero, r^2 can be as small as 0.
            crosses_zero = r_ci_low <= 0 <= r_ci_high
            r2_ci_low = 0.0 if crosses_zero else min(squared_bounds)
            r2_ci_high = max(squared_bounds)

        rows.append(
            {
                "Metric": metric,
                "Pearson_r": r_pearson,
                "Pearson_p": p_pearson,
                "Pearson_r_CI_low": r_ci_low,
                "Pearson_r_CI_high": r_ci_high,
                "Spearman_rho": r_spearman,
                "Spearman_p": p_spearman,
                "Spearman_rho_CI_low": rho_ci_low,
                "Spearman_rho_CI_high": rho_ci_high,
                "R_squared": r_squared,
                "R2_sig_p": f_pvalue,
                "R_squared_CI_low": r2_ci_low,
                "R_squared_CI_high": r2_ci_high,
                "N": n,
            }
        )

    if not rows:
        print(f"[Warning] No valid metrics for analysis in {label}.")
        return pd.DataFrame()

    stats_df = pd.DataFrame(rows)

    print(f"\n[Correlation and Regression Summary – {label}]")
    print(stats_df.set_index("Metric").round(4))

    if output_csv is not None:
        stats_df.to_csv(output_csv, index=False)

    return stats_df


def run_hypothesis_tests_set_c(df_c: pd.DataFrame) -> pd.DataFrame:
    """
    Run the Set C complexity analysis and print a hypothesis-oriented summary.

    Analyses gscm_score (proposed metric), l2_norm, spectral_norm and
    sharpness against the generalisation gap, writes the summary table to
    "table_5_1_complexity_metrics_summary_set_c.csv", and reports which
    metric attains the highest R^2.

    Returns the summary table (empty when no metric could be analysed, in
    which case nothing further is printed).
    """
    summary = run_complexity_analysis(
        df_c,
        label="Set C (Core Experiment)",
        metrics=["gscm_score", "l2_norm", "spectral_norm", "sharpness"],
        output_csv="table_5_1_complexity_metrics_summary_set_c.csv",
    )

    if not summary.empty:
        # Row of the metric with the largest regression R^2.
        top = summary.loc[summary["R_squared"].idxmax()]

        print("\n--- Hypothesis-Oriented Summary for Set C (H3 & H4) ---")
        metric_name = top["Metric"]
        top_r2 = top["R_squared"]
        top_p = top["R2_sig_p"]
        print(
            f"H4 (Predictive Power): '{metric_name}' achieves the highest "
            f"R^2 (R^2 = {top_r2:.3f}, p = {top_p:.4f})."
        )

    return summary


def export_set_c_descriptives(
    df_c: pd.DataFrame,
    output_csv: str = "table_4_1_set_c_descriptive_stats.csv",
) -> None:
    """
    Write the per-architecture Set C values (single-run) to `output_csv`.

    Only the known descriptive columns that actually exist in `df_c` are
    exported, in a fixed order, without the index.
    """
    export_order = (
        "id",
        "params",
        "log_params",
        "train_error",
        "test_error",
        "gen_gap",
        "gscm_score",
        "l2_norm",
        "spectral_norm",
        "sharpness",
    )

    selected = []
    for name in export_order:
        if name in df_c.columns:
            selected.append(name)

    subset = df_c[selected]
    subset.to_csv(output_csv, index=False)
    print(f"\n[Descriptives Exported] Set C descriptive stats saved to {output_csv}.")


# ----------------------------------------------------------------------
# 3. Visualisations (Figures 5.1–5.4)
# ----------------------------------------------------------------------
def plot_double_descent_set_c(
    df_c: pd.DataFrame,
    output_path: str = "figure_5_1_double_descent_set_c.png",
) -> None:
    """
    Figure 5.1: Empirical Double Descent Curve (Model Set C)

    Draws test and train error against log10(params), with rows ordered by
    log_params, and saves the figure to `output_path`. Prints a warning and
    does nothing when `df_c` is None or empty.
    """
    if df_c is None or df_c.empty:
        print("[Warning] Set C DataFrame is empty; skipping Figure 5.1.")
        return

    ordered = df_c.sort_values("log_params")
    log_params = ordered["log_params"]

    plt.figure(figsize=(10, 6))

    # (column, marker, line style, legend label) -- test curve is drawn first.
    curves = (
        ("test_error", "o", "-", "Test Error"),
        ("train_error", "x", "--", "Train Error"),
    )
    for column, marker, line_style, legend_label in curves:
        plt.plot(
            log_params,
            ordered[column],
            marker=marker,
            linestyle=line_style,
            label=legend_label,
        )

    plt.title("Figure 5.1: Empirical Double Descent Curve (Model Set C)")
    plt.xlabel("Model Complexity (log10 of Trainable Parameters)")
    plt.ylabel("Classification Error Rate")
    plt.legend()
    plt.grid(True, which="both", linestyle="--", alpha=0.7)
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()


def plot_complexity_vs_gen_gap_set_c(
    df_c: pd.DataFrame,
    output_path: str = "figure_5_2_complexity_vs_gen_gap_set_c.png",
) -> None:
    """
    Figure 5.2: Complexity Metrics versus Generalisation Gap (Model Set C)

    One panel per metric, left to right:
      - L2 Norm vs GenGap
      - Spectral Norm vs GenGap
      - GSCM vs GenGap

    Each panel shows a scatter with a fitted regression line and the Pearson
    r in its title when at least two complete points exist; a panel whose
    columns are missing or hold no data is switched off with an N/A title.
    Prints a warning and does nothing when `df_c` is None or empty.
    """
    if df_c is None or df_c.empty:
        print("[Warning] Set C DataFrame is empty; skipping Figure 5.2.")
        return

    panel_specs = [
        ("l2_norm", r"$\ell_2$ Norm ($N$)"),
        ("spectral_norm", "Spectral Norm Proxy"),
        ("gscm_score", "GSCM (Proposed Metric)"),
    ]

    _, axes = plt.subplots(1, 3, figsize=(18, 6))

    for panel, (column, pretty_name) in zip(axes, panel_specs):
        columns_present = column in df_c.columns and "gen_gap" in df_c.columns
        if not columns_present:
            panel.set_title(f"{pretty_name}\nPearson r = N/A (missing data)")
            panel.axis("off")
            continue

        pairs = df_c[[column, "gen_gap"]].dropna()
        if pairs.empty:
            panel.set_title(f"{pretty_name}\nPearson r = N/A (no data)")
            panel.axis("off")
            continue

        metric_values = pairs[column]
        gap_values = pairs["gen_gap"]

        # A regression line and a correlation both need at least two points.
        if len(pairs) >= 2:
            sns.regplot(
                x=metric_values,
                y=gap_values,
                ax=panel,
                scatter_kws={"s": 40, "alpha": 0.7},
                line_kws={"linewidth": 2},
            )
            pearson_r, _ = pearsonr(metric_values, gap_values)
            subtitle = f"Pearson r = {pearson_r:.3f}"
        else:
            panel.scatter(metric_values, gap_values, s=40, alpha=0.7)
            subtitle = "Pearson r = N/A (N < 2)"

        panel.set_title(f"{pretty_name}\n{subtitle}")
        panel.set_xlabel(pretty_name)
        panel.set_ylabel("Generalisation Gap")
        panel.grid(True, which="both", linestyle="--", alpha=0.5)

    plt.tight_layout()
    # The suptitle sits above the canvas (y > 1); bbox_inches="tight" below
    # expands the saved area so it is not cut off.
    plt.suptitle(
        "Figure 5.2: Complexity Metrics versus Generalisation Gap (Model Set C)",
        y=1.02,
        fontsize=16,
    )
    plt.savefig(output_path, bbox_inches="tight")
    plt.close()


def plot_gen_gap_vs_log_params_all_sets(
    df_all: pd.DataFrame,
    output_path: str = "figure_5_3_gen_gap_vs_log_params_all_sets.png",
) -> None:
    """
    Figure 5.3: Generalisation Gap vs Model Complexity (All Sets)

    Scatter of gen_gap against log_params, with colour and marker style
    keyed on the `set` column, saved to `output_path`. Prints a warning and
    does nothing when `df_all` is None or empty.
    """
    if df_all is None or df_all.empty:
        print("[Warning] Global DataFrame is empty; skipping Figure 5.3.")
        return

    plt.figure(figsize=(10, 6))

    scatter_options = {
        "x": "log_params",
        "y": "gen_gap",
        "hue": "set",
        "style": "set",
        "s": 60,
    }
    sns.scatterplot(data=df_all, **scatter_options)

    plt.title("Figure 5.3: Generalisation Gap vs Model Complexity (All Sets)")
    plt.xlabel("Model Complexity (log10 of Trainable Parameters)")
    plt.ylabel("Generalisation Gap")
    plt.grid(True, which="both", linestyle="--", alpha=0.6)
    plt.legend(title="Model Set")
    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()


def plot_gen_gap_vs_log_params_by_set(
    df_a: pd.DataFrame | None,
    df_b: pd.DataFrame | None,
    df_c: pd.DataFrame | None,
    output_path: str = "figure_5_4_gen_gap_vs_log_params_by_set.png",
) -> None:
    """
    Figure 5.4: Generalisation Gap vs Model Complexity (Per Set)
    """
    sets_data = [("A", df_a), ("B", df_b), ("C", df_c)]
    available = [(label, df) for (label, df) in sets_data if df is not None and not df.empty]

    if not available:
        print("[Warning] No per-set data available; skipping Figure 5.4.")
        return

    n_sets = len(available)
    fig, axes = plt.subplots(1, n_sets, figsize=(6 * n_sets, 5), sharey=True)

    if n_sets == 1:
        axes = [axes]

    for ax, (label, df) in zip(axes, available):
        sns.scatterplot(data=df, x="log_params", y="gen_gap", s=60, ax=ax)
        ax.set_title(f"Set {label}")
        ax.set_xlabel("log10(Trainable Parameters)")
        ax.set_ylabel("Generalisation Gap")
        ax.grid(True, which="both", linestyle="--", alpha=0.6)

    plt.tight_layout()
    plt.suptitle("Figure 5.4: Generalisation Gap vs Model Complexity (Per Set)", y=1.03, fontsize=16)
    plt.savefig(output_path)
    plt.close()


# ----------------------------------------------------------------------
# 4. Main orchestration
# ----------------------------------------------------------------------
def main() -> None:
    """
    Run the end-to-end analysis for Sets A, B, and C.

    Loads all sets, runs the pooled complexity analysis, draws the cross-set
    figures, and then - only when Set C is available and non-empty - exports
    its descriptives, runs its hypothesis summary, and draws its figures.
    """
    combined, set_a, set_b, set_c, _set_c_raw = load_and_aggregate_all_sets()

    run_complexity_analysis(
        combined,
        label="All Sets (A + B + C)",
        metrics=["gscm_score", "l2_norm", "spectral_norm"],
        output_csv="table_5_2_complexity_metrics_summary_all_sets.csv",
    )

    plot_gen_gap_vs_log_params_all_sets(combined)
    plot_gen_gap_vs_log_params_by_set(set_a, set_b, set_c)

    if set_c is None or set_c.empty:
        print("[Warning] Set C results not available or empty; skipping Set C-specific analysis and plots.")
    else:
        export_set_c_descriptives(set_c)
        run_hypothesis_tests_set_c(set_c)
        plot_double_descent_set_c(set_c)
        plot_complexity_vs_gen_gap_set_c(set_c)

    print("\nAnalysis complete.")

# Run the full pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()



[Set C Loaded] 15 architectures (single-run results).

--- Complexity Analysis: All Sets (A + B + C) (N=18) ---

[Correlation and Regression Summary – All Sets (A + B + C)]
               Pearson_r  Pearson_p  Pearson_r_CI_low  Pearson_r_CI_high  \
Metric                                                                     
gscm_score        0.0860     0.7607           -0.4459             0.5730   
l2_norm           0.7935     0.0001            0.5189             0.9197   
spectral_norm     0.9316     0.0000            0.8224             0.9746   

               Spearman_rho  Spearman_p  Spearman_rho_CI_low  \
Metric                                                         
gscm_score           0.0500      0.8595              -0.4744   
l2_norm              0.1063      0.6746              -0.3794   
spectral_norm        0.7358      0.0005               0.4097   

               Spearman_rho_CI_high  R_squared  R2_sig_p  R_squared_CI_low  \
Metric                                        