# Scaling law analysis and visualization

This notebook has been written to run on the TU Ilmenau cluster with my specific setup.
However, it can also be used simply to visualize the scaling laws. 
A summary of all results is also stored in the directory `/research/scaling_data`.

In [None]:
from __future__ import annotations

import os
import re
import json
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Callable, Any
from collections import defaultdict
from scipy.optimize import curve_fit
from matplotlib.lines import Line2D
from matplotlib.colors import LogNorm
from matplotlib.ticker import FuncFormatter
from matplotlib.ticker import FixedLocator, ScalarFormatter
from tensorboard.backend.event_processing import event_accumulator

import torch_geometric

from optimetal.evaluation import Evaluator
from optimetal.data.loader import load_torch_data, create_dataloader

def init_empty_results() -> dict:
    """
    Build the empty nested container used by the scaling law study.

    The top level separates data scaling ("data") from parameter scaling ("parameter").
    Below that, the model families ("optimate", "2b" with its two variants, "3b") each
    end in an empty dict that is filled later on.
    """
    def _two_body() -> dict:
        # fresh dicts on every call so that no leaf is shared between branches
        return {"variant1": {}, "variant2": {}}

    results = {"data": {}, "parameter": {}}
    # data scaling branches
    results["data"]["optimate"] = {}
    results["data"]["2b"] = _two_body()
    results["data"]["3b"] = {}
    # parameter scaling branches
    results["parameter"]["2b"] = _two_body()
    results["parameter"]["3b"] = {}
    return results

def init_empty_results_grid() -> dict:
    """
    Build the empty container for the scaling law grid study.

    One independent, empty dict is created per model family ("2b" and "3b").
    """
    return {family: {} for family in ("2b", "3b")}

def extract_processed_dirs(results: dict) -> set[str]:
    """
    Collect the 'study_dir' of every entry stored in the leaf lists of 'results'.

    Dictionaries are descended into at any depth. Lists are treated as leaves:
    their items are inspected for a 'study_dir' key but not searched any further.
    """
    found = set()
    pending = [results] # explicit stack instead of recursion
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            pending.extend(node.values())
        elif isinstance(node, list):
            found.update(item["study_dir"] for item in node if "study_dir" in item)
    return found

def load_tb_scalars(logdir: str) -> dict:
    """
    Read all scalar series from the tensorboard event files in 'logdir'.
    Handy for inspecting training and validation loss curves.

    Returns a dict that maps each scalar tag to the list of its logged values
    (steps and wall times are dropped).
    """
    accumulator = event_accumulator.EventAccumulator(
        logdir,
        # NOTE(review): a size guidance of 0 should keep every event instead of
        # subsampling them, confirm against the tensorboard documentation
        size_guidance={event_accumulator.SCALARS: 0},
    )
    accumulator.Reload()
    scalar_tags = accumulator.Tags().get("scalars", [])
    return {
        tag: [event.value for event in accumulator.Scalars(tag)]
        for tag in scalar_tags
    }

def eval_model(
    best_model_path: str, 
    dataloader: torch_geometric.loader.DataLoader,
    device_index: int = 0,
) -> tuple[Any, dict]:
    """
    Use the Evaluator class to gather metrics for a given model.
    Input:
        best_model_path:    Path to the stored model that should be evaluated
        dataloader:         Dataloader that provides the evaluation data
        device_index:       Forwarded unchanged to the Evaluator
    Output:
        num_parameter:      The evaluator's 'num_parameter' attribute
        metric_dict:        Dict with the evaluator's 'mean_metrics', 'median_metrics', 
                            'std_metrics', and 'drude_r2' attributes
    """
    evaluator = Evaluator(
        best_model_path=best_model_path, 
        dataloader=dataloader, 
        device_index=device_index,
        turn_off_progress_bar=True,
    )
    # read before 'evaluate()' is called, same order as before
    num_parameter = evaluator.num_parameter
    evaluator.evaluate()
    metric_dict = {
        "mean_metrics": evaluator.mean_metrics,
        "median_metrics": evaluator.median_metrics,
        "std_metrics": evaluator.std_metrics,
        "drude_r2": evaluator.drude_r2,
    }
    return num_parameter, metric_dict
    
def _load_result_entry(
        study_path: str,
        study_dir: str,
        dataloader: torch_geometric.loader.DataLoader,
) -> dict | None:
    """
    Shared first half of 'process_one' and 'process_one_grid': load the logged losses 
    of a single run and evaluate its best model.
    Input:
        study_path:     Path to the root directory containing study result subdirectories
        study_dir:      Name of the study directory in 'study_path'
        dataloader:     Used to evaluate the best model in the study directory
    Output:
        result_entry:   Dict with 'study_dir', 'seed', 'val_loss', 'eps_loss', 'drude_loss', 
                        'num_parameter', and 'metric_dict', or None if the run has not 
                        written 'val_loss.txt' and 'best_model.pt' yet
    """
    # path setup and checks
    study_dir_path = os.path.join(study_path, study_dir)
    val_loss_path = os.path.join(study_dir_path, "val_loss.txt")
    best_model_path = os.path.join(study_dir_path, "best_model.pt")
    if not os.path.exists(val_loss_path) or not os.path.exists(best_model_path):
        print(f"Skipping {study_dir_path:s}, probably still running")
        return None
    # parse the seed (metadata)
    seed = int(re.search(r"seed(\d+)", study_dir).group(1))
    # load the data from the tensorboard log and validation loss file
    best_val_loss = float(np.loadtxt(val_loss_path))
    tb_log = load_tb_scalars(study_dir_path)
    val_loss = tb_log.get("val/loss", [])
    if len(val_loss) == 0:
        # np.argmin would raise a ValueError here as well, just with a less helpful message
        raise ValueError(f"No 'val/loss' scalars found in the tensorboard log of {study_dir_path:s}")
    # the partial losses are taken at the step with the lowest total validation loss
    min_idx = np.argmin(val_loss)
    best_eps_loss = tb_log.get("val/eps", [])[min_idx]
    best_drude_loss = tb_log.get("val/drude", [])[min_idx]
    result_entry = {
        "study_dir": study_dir,
        "seed": seed,
        "val_loss": best_val_loss,
        "eps_loss": best_eps_loss,
        "drude_loss": best_drude_loss,
    }
    # use the "Evaluator" class to evaluate the best model and obtain more detailed metrics
    # (this can take some time...)
    num_parameter, metric_dict = eval_model(
        best_model_path=best_model_path,
        dataloader=dataloader,
    )
    result_entry["num_parameter"] = num_parameter
    result_entry["metric_dict"] = metric_dict
    return result_entry

def process_one(
        study_path: str,
        study_dir: str,
        results: dict,
        dataloader: torch_geometric.loader.DataLoader,
) -> None:
    """
    Load and evaluate a single 'study_dir', then insert its entry into the right place in 'results'.
    Input:
        study_path:     Path to the root directory containing study result subdirectories
        study_dir:      Name of the study directory in 'study_path'
        results:        Nested dict as created by 'init_empty_results', modified in place
        dataloader:     Used to evaluate the best model in the study directory
    Output:
        None, the new entry is appended to the matching list inside 'results'. 
        Runs that are not finished yet are skipped, and directories whose name 
        matches none of the known categories are not stored.
    """
    result_entry = _load_result_entry(study_path, study_dir, dataloader)
    if result_entry is None:
        return
    # data scaling
    if "hestness" in study_dir:
        num_data = re.search(r"data(\d+)", study_dir).group(1)
        if "optimate" in study_dir:
            results["data"]["optimate"].setdefault(num_data, []).append(result_entry)
        elif "2b" in study_dir:
            if "variant1" in study_dir:
                results["data"]["2b"]["variant1"].setdefault(num_data, []).append(result_entry)
            elif "variant2" in study_dir:
                results["data"]["2b"]["variant2"].setdefault(num_data, []).append(result_entry)
        elif "3b" in study_dir:
            results["data"]["3b"].setdefault(num_data, []).append(result_entry)
    # parameter scaling
    elif "kaplan" in study_dir:
        width = re.search(r"width(\d+)", study_dir).group(1)
        if "2b" in study_dir:
            if "variant1" in study_dir:
                results["parameter"]["2b"]["variant1"].setdefault(width, []).append(result_entry)
            elif "variant2" in study_dir:
                results["parameter"]["2b"]["variant2"].setdefault(width, []).append(result_entry)
        elif "3b" in study_dir:
            results["parameter"]["3b"].setdefault(width, []).append(result_entry)

def process_one_grid(
        study_path: str,
        study_dir: str,
        results: dict,
        dataloader: torch_geometric.loader.DataLoader,
) -> None:
    """
    Load and evaluate a single 'study_dir', then insert its entry into the right place in 'results'.
    This version is for the scaling law grid study, where the data and parameter scaling are combined.
    Input:
        study_path:     Path to the root directory containing study result subdirectories
        study_dir:      Name of the study directory in 'study_path'
        results:        Nested dict as created by 'init_empty_results_grid', modified in place
        dataloader:     Used to evaluate the best model in the study directory
    Output:
        None, the new entry is appended to results[model][num_data][width]. 
        Runs that are not finished yet are skipped, and directories that contain 
        neither '2b' nor '3b' in their name are not stored.
    """
    result_entry = _load_result_entry(study_path, study_dir, dataloader)
    if result_entry is None:
        return
    # put the result entry into the right place in the results dictionary
    num_data = re.search(r"data(\d+)", study_dir).group(1)
    width = re.search(r"width(\d+)", study_dir).group(1)
    if "2b" in study_dir:
        results["2b"].setdefault(num_data, {}).setdefault(width, []).append(result_entry)
    elif "3b" in study_dir:
        results["3b"].setdefault(num_data, {}).setdefault(width, []).append(result_entry)

def _filter_nonempty(data: dict) -> dict:
    """
    Return only keys whose value list is non-empty.
    """
    return {k: v for k, v in data.items() if v}

def _regroup_by_entry_key(data: dict, x_key: str) -> dict:
    """
    Convert a dictionary that is keyed by hyperparameters into a dictionary that is keyed by entries.
    This is useful when you want the x-axis to display a value stored in each entry. For example, you could use 'num_parameter'.
    """
    regroup = {}
    for _, entries in data.items():
        for entry in entries:
            x_val = entry[x_key]
            regroup.setdefault(x_val, []).append(entry)
    return regroup

def plot_scaling_law(
    ax: plt.Axes, 
    data: dict, 
    label: str, 
    color: str,
    key: str = "val_loss", 
    x_from_entry: str | None = None,
    error_bars = False,
    ms: int = 4,
) -> None:
    """
    Draw the points of one scaling law curve onto 'ax'.

    The x-values are the integer-converted keys of 'data' (or, if 'x_from_entry' is given, 
    the values stored under that name in the entries). The y-values are the per-key means 
    of entry[key]. With 'error_bars', the per-key standard deviation is drawn as well.
    A ValueError is raised if fewer than two non-empty keys remain.
    """
    groups = _filter_nonempty(data)
    if x_from_entry is not None:
        groups = _regroup_by_entry_key(groups, x_from_entry)
    if len(groups) < 2:
        raise ValueError("Need at least two points for a fit")
    # order the groups by the integer value of their key
    group_keys = list(groups.keys())
    x_int = [int(k) for k in group_keys]
    order = np.argsort(x_int)
    x = np.array(sorted(x_int), dtype=float)
    samples = [[entry[key] for entry in groups[group_keys[i]]] for i in order]
    y_mean = np.array([np.mean(sample) for sample in samples], dtype=float)
    if not error_bars:
        ax.plot(x, y_mean, "o", markersize=ms, label=label, color=color)
        return
    y_std = np.array([np.std(sample) for sample in samples], dtype=float)
    ax.errorbar(
        x,
        y_mean,
        yerr=y_std,
        fmt="o",
        markersize=ms,
        markeredgecolor=color,
        markerfacecolor=color,
        ecolor=color,
        capsize=4,
        linestyle="none",
        label=label,
    )

"""
Scaling law function block.
######################################################################################################
"""

def power_law(x: float, alpha: float, x0: float) -> float:
    """
    Simple power law, L(x) = (x0 / x)^alpha.
    """
    ratio = x0 / x
    return ratio ** alpha

def power_law_with_floor(x: float, alpha: float, x0: float, l_0: float) -> float:
    """
    Power law with a constant offset, L(x) = l_0 + (x0 / x)^alpha.
    """
    ratio = x0 / x
    return l_0 + ratio ** alpha

def broken_power_law(x: float, alpha1: float, alpha2: float, xc: float) -> float:
    """
    Broken power law, L(x) = (xc / x)^alpha2 * (1 + xc / x)^(alpha1 - alpha2).
    """
    ratio = xc / x
    far_term = ratio ** (alpha2)
    bend_term = (1 + ratio) ** (alpha1 - alpha2)
    return far_term * bend_term

def broken_power_law_with_amp(x: float, alpha1: float, alpha2: float, xc: float, A: float) -> float:
    """
    Broken power law with a free amplitude, 
    L(x) = A * (xc / x)^alpha2 * (1 + xc / x)^(alpha1 - alpha2).
    """
    ratio = xc / x
    scaled_far_term = A * (ratio ** (alpha2))
    bend_term = (1 + ratio) ** (alpha1 - alpha2)
    return scaled_far_term * bend_term

"""
######################################################################################################
"""

def calc_aicc(residuals: np.ndarray, k: int) -> float:
    """
    Corrected Akaike information criterion (AICc) of a least-squares fit.
    https://en.wikipedia.org/wiki/Akaike_information_criterion#Comparison_with_least_squares
    Input:
        residuals:  Differences between the observed and the fitted values
        k:          Number of fitted parameters
    Output:
        aicc:       AIC plus the small-sample correction (2k^2 + 2k) / (n - k - 1). 
                    The correction is only defined for n > k + 1, so 'inf' is returned 
                    if there are too few points for a model with k parameters.
    """
    residuals = np.asarray(residuals, dtype=float)
    n = residuals.size
    dof = n - k - 1
    if dof <= 0:
        # dof == 0 used to raise a ZeroDivisionError, dof < 0 produced a negative 
        # "penalty" that rewarded over-parameterized models
        return np.inf
    rss = np.sum(residuals**2)
    sigma2_hat = rss / n
    aic = 2 * k + n * np.log(sigma2_hat)
    return aic + ((2 * k ** 2 + 2 * k) / dof)

def fit_scaling_law(
    data: dict, 
    func_type: str, # see below
    key: str = "best_val_loss",
    x_from_entry: str | None = None,
) -> dict:
    """
    Fit the specified power law function to the scaling data.
    Input:
        data:           Dict that maps an x-value (key convertible to int) to a list of entries
        func_type:      One of 'simple', 'floor', 'broken', or 'broken_with_amp'
        key:            Name of the entry value that is fitted. The default is kept for backward 
                        compatibility. Entries built by 'process_one' store the loss under 
                        'val_loss', so pass the key explicitly for those.
        x_from_entry:   If given, the x-values are taken from this entry value instead of the keys
    Output:
        Dict with the optimal parameters 'popt', their covariance 'pcov', and the 'aicc' of the fit 
        (in this order)
    Raises a ValueError for fewer than two non-empty x-values or an unknown 'func_type'.
    """
    data = _filter_nonempty(data)
    if x_from_entry is not None:
        data = _regroup_by_entry_key(data, x_from_entry)
    if len(data) < 2:
        raise ValueError("Need at least two points for a fit")
    # order the groups by the integer value of their key
    keys = list(data.keys())
    x_int = [int(k) for k in keys]
    sort_idx = np.argsort(x_int)
    x = np.array(sorted(x_int), dtype=float)
    samples = [[entry[key] for entry in data[keys[i]]] for i in sort_idx]
    y_mean = np.array([np.mean(sample) for sample in samples], dtype=float)
    y_std = np.array([np.std(sample) for sample in samples], dtype=float)
    # select power law function type
    if func_type == "simple":
        func = power_law
        alpha0 = 0.1 # typical starting value
        x0_0 = x[0] # characteristic scale
        p0 = [alpha0, x0_0]
        bounds = ([-np.inf, -np.inf], [np.inf, np.inf])
    elif func_type == "floor":
        func = power_law_with_floor
        alpha0 = 0.1 # typical starting value
        x0_0 = x[0] # characteristic scale
        l_00 = y_mean[-1] # mean loss at the largest x
        p0 = [alpha0, x0_0, l_00]
        bounds = ([-np.inf, -np.inf, 0.0], [np.inf, np.inf, np.inf])
    elif func_type == "broken":
        func = broken_power_law
        alpha0 = 0.1 # typical starting value
        xc_0 = x[0] # critical size, where scaling changes
        p0 = [alpha0, alpha0, xc_0]
        bounds = ([-np.inf, -np.inf, -np.inf], [np.inf, np.inf, np.inf])
    elif func_type == "broken_with_amp":
        func = broken_power_law_with_amp
        alpha0 = 0.1 # typical starting value
        xc_0 = x[0] # critical size, where scaling changes
        A0 = 1.0 # our validation loss is around 1.0
        p0 = [alpha0, alpha0, xc_0, A0]
        bounds = ([-np.inf, -np.inf, -np.inf, -np.inf], [np.inf, np.inf, np.inf, np.inf])
    else:
        raise ValueError("Unsupported power law type")
    # curve_fit divides the residuals by sigma, so a group with zero spread (e.g. a single seed) 
    # would produce non-finite residuals; fall back to an unweighted fit in that case
    if np.all(np.isfinite(y_std)) and np.all(y_std > 0):
        sigma = y_std
    else:
        sigma = None
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        popt, pcov = curve_fit(
            f=func,
            xdata=x,
            ydata=y_mean,
            p0=p0,
            sigma=sigma,
            absolute_sigma=False, # default
            bounds=bounds,
            maxfev=10000,
        )
    y_fit = func(x, *popt)
    res = y_mean - y_fit
    aicc = calc_aicc(res, len(popt))
    return {
        "popt": popt, 
        "pcov": pcov, 
        "aicc": aicc,
    }
    
def get_power_law(func_type: str) -> Callable[..., float]:
    """
    Map the name of a scaling law form to the function that implements it.
    Raises a ValueError for unknown names.
    """
    known_forms = {
        "simple": power_law,
        "floor": power_law_with_floor,
        "broken": broken_power_law,
        "broken_with_amp": broken_power_law_with_amp,
    }
    if func_type not in known_forms:
        raise ValueError("Unsupported power law type")
    return known_forms[func_type]

def get_fit_func_str(func_type: str, X: str) -> str:
    """
    Return the latex equation that belongs to a scaling law form.

    'X' is the symbol of the scaled variable (e.g. 'D' or 'N'). It is only used by 
    the single-variable forms. 'kaplan' and 'global' are fixed L(N, D) expressions.
    Raises a ValueError for unknown forms.
    """
    # two-variable forms, independent of X
    if func_type == "kaplan":
        return (r"$L(N, D) = \left[\left(\frac{N_0}{N}\right)^{\frac{\alpha_N}{\alpha_D}} + "
                r"\frac{D_0}{D}\right]^{\alpha_D}$")
    if func_type == "global":
        return (r"$L(N, D) = \left[\left(\frac{N_0}{N}\right)^{\frac{\alpha_N}{\alpha_D}} + \left(\frac{D_c}{D}\right)"
                r"\left(1 + \frac{D_c}{D}\right)^{\frac{\alpha_{D,1}}{\alpha_{D,2}} - 1}\right]^{\alpha_{D,2}}$")
    if func_type not in ("simple", "floor", "broken", "broken_with_amp"):
        raise ValueError("Unsupported power law type")
    # single-variable forms, assembled from a shared left-hand side and core term
    lhs = rf"$L({X:s}) = "
    if func_type in ("simple", "floor"):
        core = rf"\left({X:s}_0/{X:s}\right)^{{\alpha_{{{X:s}}}}}"
        if func_type == "simple":
            return lhs + core + "$"
        return lhs + r"L_\infty + " + core + "$"
    core = (rf"\left({X:s}_c/{X:s}\right)^{{\alpha_{{{X:s}, 2}}}}"
            rf"\left(1+{X:s}_c/{X:s}\right)^{{\alpha_{{{X:s}, 1}}-\alpha_{{{X:s}, 2}}}}")
    if func_type == "broken":
        return lhs + core + "$"
    return lhs + r"A \cdot \left[" + core + r"\right]$"

def format_val_err(val: float, err: float, digits: int = 2) -> str:
    """
    Format a value and its uncertainty as '$val \\pm err$' with a fixed number 
    of decimals. Used for the latex tables.
    """
    spec = f".{digits:d}f"
    return "$" + format(val, spec) + " \\pm " + format(err, spec) + "$"

def format_val_err_log10(x: float, s: float | None = None, digits: int = 2) -> str:
    """
    Format large positive parameters with uncertainty in scientific notation.
    """
    if x <= 0:
        raise ValueError("x must be positive.")
    xl = np.log10(x)
    if s is None:
        return rf"$10^{{{xl:.{digits:d}f}}}$"
    if s <= 0:
        raise ValueError("s must be positive.")
    sl = np.log10(s)
    return rf"$10^{{{xl:.{digits:d}f}}} \pm 10^{{{sl:.{digits:d}f}}}$"

def build_rows_from_fits(fits: list, func_type: str, X: str = "D") -> list[dict]:
    """ 
    Helper function that turns the fit results of the single scaling laws into table rows.
    Input:
        fits:       List of (name, popt, pcov, aicc) items, one per fitted model
        func_type:  Scaling law form that was fitted ('simple', 'floor', 'broken', 'broken_with_amp')
        X:          Symbol of the scaled variable, used in the names of the scale columns
    Output:
        rows:       List with one dict per fit. Parameters that do not exist for the given 
                    form stay None, all others are formatted as latex strings.
    Raises a ValueError for an unknown 'func_type'.
    """
    rows = []
    for name, popt, pcov, aicc in fits:
        # standard errors of the parameters from the diagonal of the covariance matrix
        se = np.sqrt(np.diag(pcov))
        # default fields
        record = {
            "Model": name, 
            "Form": func_type,
            "alpha": None, 
            "alpha1": None, 
            "alpha2": None,
            f"{X:s}_0": None, 
            f"{X:s}_c": None,
            "A": None, 
            "L_0": None,
            "AICc": f"{aicc:.2f}",
        }
        # populate depending on scaling law function type
        if func_type == "simple":
            alpha, x0 = popt
            se_alpha, se_x0 = se
            record["alpha"] = format_val_err(alpha, se_alpha)
            record[f"{X:s}_0"] = format_val_err_log10(x0, se_x0, digits=2)
        elif func_type == "floor":
            alpha, x0, linf = popt
            se_alpha, se_x0, se_linf = se
            record["alpha"] = format_val_err(alpha, se_alpha)
            record[f"{X:s}_0"] = format_val_err_log10(x0, se_x0, digits=2)
            record["L_0"] = format_val_err(linf, se_linf)
        elif func_type == "broken":
            a1, a2, xc = popt
            se_a1, se_a2, se_xc = se
            record["alpha1"] = format_val_err(a1, se_a1)
            record["alpha2"] = format_val_err(a2, se_a2)
            record[f"{X:s}_c"] = format_val_err_log10(xc, se_xc, digits=2)
        elif func_type == "broken_with_amp":
            a1, a2, xc, A = popt
            se_a1, se_a2, se_xc, se_A = se
            record["alpha1"] = format_val_err(a1, se_a1)
            record["alpha2"] = format_val_err(a2, se_a2)
            record[f"{X:s}_c"] = format_val_err_log10(xc, se_xc, digits=2)
            record["A"] = format_val_err(A, se_A)
        else:
            raise ValueError(f"Unsupported form: {func_type:s}")
        rows.append(record)
    return rows

def latex_table_from_rows(rows: list[dict], X: str, fit_func_display: bool = False) -> str:
    """
    Helper function for building latex tables for the single scaling law fits.
    Input:
        rows:               Row dicts with the keys listed in 'cols' below. Must not be empty, 
                            because the fit form is read from the first row.
        X:                  Symbol of the scaled variable, used for the column names and headers
        fit_func_display:   If True, the latex equation of the fit function is placed as a 
                            second header level above the parameter columns
    Output:
        latex:              The table as a latex string
    """
    cols = [
        "Model", 
        "Form",
        "alpha", 
        "alpha1", 
        "alpha2",
        f"{X:s}_0",
        f"{X:s}_c",
        "A", 
        "L_0", 
        "AICc",
    ]
    df = pd.DataFrame(rows, columns=cols)
    # mark missing parameters with "---" and drop the columns that contain nothing else
    display_cols = [c for c in cols if c not in ("Model", "Form")]
    df[display_cols] = df[display_cols].where(pd.notnull(df[display_cols]), "---")
    keep_cols = [c for c in df.columns if not (df[c] == "---").all()]
    df = df[keep_cols]
    # internal column names -> latex headers
    col_map = {
        "Model": "Model",
        "Form": "Form",
        "alpha": rf"$\alpha_{{{X:s}}}$",
        "alpha1": rf"$\alpha_{{{X:s},1}}$",
        "alpha2": rf"$\alpha_{{{X:s},2}}$",
        f"{X:s}_0": rf"${X:s}_0$",
        f"{X:s}_c": rf"${X:s}_c$",
        "A": r"$A$",
        "L_0": r"$L_\infty$",
        "AICc": "AICc",
    }
    df = df.rename(columns=col_map)
    # redundant information
    # (only the form of the first row is used, presumably all rows share the same form)
    func_type = rows[0]["Form"]
    if "Form" in df.columns:
        df = df.drop(columns=["Form"])
    if fit_func_display:
        # determine parameter columns
        cols_order = list(df.columns)
        if "Model" not in cols_order or "AICc" not in cols_order:
            raise ValueError("Expected 'Model' and 'AICc' columns after processing")
        param_cols = [c for c in cols_order if c not in ("Model", "AICc")]
        # build a MultiIndex for columns
        # (top level: "" over Model, "Forms" over parameters, "" over AICc)
        top = []
        bottom = []
        for c in ["Model"] + param_cols + ["AICc"]:
            if c == "Model" or c == "AICc":
                top.append("")
                bottom.append(c)
            else:
                top.append(get_fit_func_str(func_type, X))
                bottom.append(c)

        df = df[["Model"] + param_cols + ["AICc"]]
        df.columns = pd.MultiIndex.from_arrays([top, bottom])
    latex = df.to_latex(
        index=False,
        escape=False,
        longtable=False,
        multicolumn=True,
        multicolumn_format="c",
        bold_rows=False,
        column_format=len(df.columns) * r"c@{\hspace{1em}}",
    )
    if fit_func_display:
        # insert a \cmidrule that spans the parameter columns directly before 
        # the first line of the table that starts with "Model"
        lines = latex.splitlines()
        for i, line in enumerate(lines):
            if line.strip().startswith("Model"):
                lines.insert(i, rf"\cmidrule(lr){{2-{len(param_cols) + 1}}}")
                break
        latex = "\n".join(lines)
    return latex

def grid_to_vectors(
    grid_dict: dict,
    metric_key: str = "val_loss",
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Helper function that converts the nested grid results for the scaling law L(D,N) into arrays.

    'grid_dict' maps data size -> width -> list of entries (both keys are strings of integers). 
    The widths are taken from the smallest data size and have to exist for every data size.
    Five arrays of shape (number of data sizes, number of widths) are returned: the data sizes, 
    the widths, the 'num_parameter' of the first entry of each cell, and the mean and standard 
    deviation of entry[metric_key] over the entries of each cell.
    """
    d_values = sorted(int(k) for k in grid_dict.keys())
    w_values = sorted(int(k) for k in grid_dict[str(d_values[0])].keys())
    shape = (len(d_values), len(w_values))
    flat = {"d": [], "w": [], "n": [], "mu": [], "std": []}
    for d in d_values:
        row = grid_dict[str(d)]
        for w in w_values:
            cell = row[str(w)]
            losses = np.array([run[metric_key] for run in cell])
            flat["d"].append(d)
            flat["w"].append(w)
            flat["n"].append(cell[0]["num_parameter"])
            flat["mu"].append(np.mean(losses))
            flat["std"].append(np.std(losses))
    datapoints, widths, num_parameters, mus, stds = (
        np.array(flat[name]).reshape(*shape) for name in ("d", "w", "n", "mu", "std")
    )
    return datapoints, widths, num_parameters, mus, stds
    
def fmt_param(v: float, digits: int = 1) -> str:
    """
    Parameter count formatter function.
    Input:
        v:          Parameter count
        digits:     Maximum number of decimals that is kept for the 'k' and 'M' units
    Output:
        Latex-ready string: '---' for non-finite values, the rounded integer below 1000, 
        and otherwise the count in units of thousands ('\\,k') or millions ('\\,M') 
        without trailing zeros.
    """
    if not np.isfinite(v):
        return "---"
    v = float(v)
    if v >= 1_000_000:
        val, unit = v / 1_000_000.0, r"\,M"
    elif v >= 1_000:
        val, unit = v / 1_000.0, r"\,k"
        # counts just below one million round up to "1000\,k", show them as "1\,M" instead
        if float(f"{val:.{digits:d}f}") >= 1_000:
            val, unit = v / 1_000_000.0, r"\,M"
    else:
        return f"{int(round(v)):d}"
    if np.isclose(val, np.round(val), rtol=0.0, atol=1e-12):
        s = f"{int(np.round(val)):d}"
    else:
        s = f"{val:.{digits:d}f}"
        # only strip zeros behind a decimal point, otherwise "10" would turn into "1" for digits=0
        if "." in s:
            s = s.rstrip("0").rstrip(".")
    return f"{s:s}{unit:s}"

def global_scaling_fit_kaplan(X, alpha_n, alpha_d1, alpha_d2, Nc, Dc):
    """
    Joint scaling law L(N, D) in the form of Kaplan et al.: a simple power law in the 
    parameter count N combined with a broken power law in the data size D.

    'X' is unpacked as (D, N). Warnings raised while evaluating the expression are suppressed.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        D, N = X
        param_term = (Nc / N) ** (alpha_n / alpha_d2)
        data_term = (Dc / D) * (1 + Dc / D) ** (alpha_d1/alpha_d2 - 1)
        return (param_term + data_term) ** alpha_d2
    
def global_scaling_fit_hoffmann(X, alpha_n, alpha_d1, alpha_d2, Nc, Dc, l_0):
    """
    Joint scaling law L(N, D) in the form of Hoffmann et al.: the sum of a constant offset, 
    a simple power law in the parameter count N, and a broken power law in the data size D.

    'X' is unpacked as (D, N). Warnings raised while evaluating the expression are suppressed.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        D, N = X
        param_term = (Nc / N) ** (alpha_n)
        data_term = ((Dc / D) ** (alpha_d2)) * ((1 + Dc / D) ** (alpha_d1 - alpha_d2))
        return l_0 + param_term + data_term

# Load the results from the "line search"

In [None]:
# locations used throughout the notebook
study_path = "/scratch/magr4985/Scaling_Base" # directory containing the scaling law study
output_dir = "./scaling_data" # directory to save the results
fig_dir = "./scaling_data/scaling" # directory where to store the figures

# make sure that both output locations exist before anything is written
for _out_dir in (output_dir, fig_dir):
    os.makedirs(_out_dir, exist_ok=True)

In [None]:
# check that all training runs converged
if os.path.exists(study_path):
    # load all study directories, validation loss files, and tensorboard logs
    study_dirs = [d for d in os.listdir(study_path) if os.path.isdir(os.path.join(study_path, d))]
    print(f"[INFO] Found {len(study_dirs):d} study directories in {study_path:s}")
    checks = []
    for study_dir in study_dirs:
        run_path = os.path.join(study_path, study_dir)
        val_loss_file = os.path.join(run_path, "val_loss.txt")
        val_loss = None # stays None if the run has no 'val_loss.txt'
        if os.path.exists(val_loss_file):
            val_loss = float(np.loadtxt(val_loss_file))
        checks.append([study_dir, val_loss, load_tb_scalars(run_path)])
    # print the study directories with diverging gradients, i.e., where something went wrong
    print("\n[DIVERGING GRADIENTS CHECK]")
    max_grad_norm_thr = 1e2
    for study_dir, val_loss, tb_log in checks:
        grad_norms = tb_log.get("train/grad_norm", [])
        if len(grad_norms) == 0:
            continue # no gradient norms logged (yet), nothing to check
        max_grad_norm = np.max(grad_norms)
        # NaN compares False against any threshold, so non-finite norms are tested explicitly
        if not np.isfinite(max_grad_norm) or max_grad_norm > max_grad_norm_thr:
            print("[DIVERGING GRADIENTS] ", study_dir, max_grad_norm)
    # print study directories with training not finished
    print("\n[TRAINING FINISHED CHECK]")
    for study_dir, val_loss, tb_log in checks:
        num_val_points = len(tb_log.get("val/loss", []))
        if num_val_points < 500:
            print("[TRAINING NOT FINISHED] ", study_dir, num_val_points)
    # seed-to-seed validation-loss consistency
    print("\n[SEED VARIANCE CHECK]")
    max_seed_variance_thr = 0.05
    _seed_suffix_re = re.compile(r'([_-])seed\d+$')
    _seed_extract_re = re.compile(r'seed(\d+)$')
    def _key_without_seed(name: str) -> str:
        # runs that only differ in their trailing seed share the same key
        return _seed_suffix_re.sub('', name)
    def _seed_or_inf(name: str) -> int:
        m = _seed_extract_re.search(name)
        return int(m.group(1)) if m else 10**9 # for nice sorting when seed missing
    groups = defaultdict(list)
    for study_dir, val_loss, _ in checks:
        if val_loss is not None and np.isfinite(val_loss):
            groups[_key_without_seed(study_dir)].append((study_dir, float(val_loss)))
    for cfg_key, items in groups.items():
        if len(items) < 2:
            continue # need at least two seeds to compare
        sorted_vals = sorted(v for _, v in items)
        low = sorted_vals[0]
        mid = sorted_vals[len(sorted_vals) // 2] # (upper) median
        high = sorted_vals[-1]
        if low <= 0:
            continue # avoid divide-by-zero/undefined relative difference
        max_rel_spread = max((mid - low) / mid, (high - mid) / mid)
        if max_rel_spread > max_seed_variance_thr:
            # the label follows the threshold instead of a hard-coded "5%"
            print(f"[SEED VARIANCE > {100*max_seed_variance_thr:g}%] {cfg_key:s}: max(spread)={100*max_rel_spread:.2f}% (low={low:.6g}, mid={mid:.6g} , high={high:.6g})")
            for sd, v in sorted(items, key=lambda p: _seed_or_inf(p[0])):
                print(f"    {sd:s}: val_loss={v:.6g}")

In [None]:
"""
Incrementally load and evaluate all models in the scaling-law study, caching results to disk.

Evaluating every model can be time-consuming, so this process saves intermediate results
to a JSON file and, on subsequent runs, will only process any newly added models.
If execution is interrupted, you can simply rerun and it will resume from the last saved state.
"""

def _save_results(results: dict, path: str) -> None:
    """
    Write 'results' to 'path' via a temporary file, so that an interruption 
    during the write cannot leave a truncated JSON file behind.
    """
    tmp_path = f"{path:s}.tmp"
    with open(tmp_path, "w") as f:
        json.dump(results, f, indent=4)
    os.replace(tmp_path, path)

# path of the JSON file where results are stored
json_path = os.path.join(output_dir, "scaling_results.json")

if os.path.exists(study_path):
    # batch size (speeds up the process a bit)
    batch_size = 128 # adjust this according to your GPU memory
    # path to the validation dataset
    eval_path = "../graph/val.pt"
    # load the data on which we want to evaluate the models
    eval_data = load_torch_data(eval_path)
    dataloader = create_dataloader(
        eval_data, 
        num_data=-1, # use the whole dataset 
        batch_size=batch_size,
        shuffle=False, # do not shuffle the validation set
    )
    print(f"Loaded evaluation data from '{eval_path:s}'", flush=True)
    # load or initialize results
    if os.path.exists(json_path):
        with open(json_path, "r") as f:
            results = json.load(f)
        print(f"Loaded existing results ({len(extract_processed_dirs(results)):d} runs)")
    else:
        results = init_empty_results()
        print("Initialized new results store")
    # find what has already been done
    processed = extract_processed_dirs(results)
    # scan for all study subdirectories
    all_dirs = sorted(d for d in os.listdir(study_path) if os.path.isdir(os.path.join(study_path, d)))
    # only process the new ones
    new_dirs = [d for d in all_dirs if d not in processed]
    save_every = 4  
    total = len(new_dirs)
    if total == 0:
        print("No new studies to process, everything is up to date")
    else:
        print(f"Processing {len(new_dirs)} new studies (saving every {save_every:d})")
        unsaved = 0 # number of runs handled since the last checkpoint
        try:
            for idx, study_dir in enumerate(new_dirs, start=1):
                process_one(
                    study_path=study_path,
                    study_dir=study_dir, 
                    results=results, 
                    dataloader=dataloader,
                )
                unsaved += 1
                # checkpoint every 'save_every' or on the very last one
                if (idx % save_every == 0) or (idx == total):
                    _save_results(results, json_path)
                    unsaved = 0
                    print(f"    Checkpointed after {idx:d}/{total:d} runs")
        finally:
            # keep the runs that were finished since the last checkpoint if the loop is 
            # interrupted or a run fails (entries are only appended once a run is complete)
            if unsaved > 0:
                _save_results(results, json_path)
        print(f"All {total:d} new runs appended and final JSON saved")
else:
    # the study directory is not available (e.g. not on the cluster), fall back to the cached results
    with open(json_path, "r") as f:
        results = json.load(f)
    print(f"Loaded results for {len(extract_processed_dirs(results)):d} models from JSON file")

# Scaling law plots: Single Power Law Tests (Data)

In [None]:
# draw the standard deviation of each point as error bars (or not)
error_bars = True

# row labels (one per scaling law form) and column order (one per model) of the aicc table
scaling_law_labels = dict(
    simple=r"Power Law",
    floor=r"Power Law + Floor",
    broken=r"Broken Power Law",
    broken_with_amp=r"Broken power law + Amplitude",
)
model_order = [
    f"OptiMetal{body:s} ({conv:s})"
    for body, conv in (("2B", "CGC"), ("2B", "TC"), ("3B", "TC"))
]

# loop over all metrics and data scaling function types 
for metric_key in ["val_loss"]:
    
    # aicc table setup
    df = pd.DataFrame(
        index=[scaling_law_labels[k] for k in ["simple","floor","broken","broken_with_amp"]],
        columns=model_order, 
        dtype=float,
    )
    
    for func_type in ["simple", "floor", "broken", "broken_with_amp"]:
        
        # aicc table setup
        aicc_row = {m: np.nan for m in model_order}
        row_label = scaling_law_labels[func_type]
        
        # figure setup
        fig, ax = plt.subplots(figsize=(3.5, 3.0))

        # parameters and useful variable
        var = "D"
        func = get_power_law(func_type)
        fits = [] # to store (label, alpha, const, color) for the legend
        fit_params = []
        fname = f"{metric_key:s}_{func_type:s}_scaling_laws_data"
        
        # some fits can be unstable...
        try:

            # plot setup
            num_data = sorted([int(x) for x in results["data"]["2b"]["variant1"].keys()]) # tick positions for the x-axis
            x_fit = np.logspace(np.log10(2000), np.log10(200000), 100)

            # OptiMetal2B CGC
            color = "k"
            name = "OptiMetal2B (CGC)"
            data = results["data"]["2b"]["variant1"]
            plot_scaling_law(ax, data, name, key=metric_key, color=color, error_bars=error_bars)
            fit_dict = fit_scaling_law(data, func_type=func_type, key=metric_key)
            y_fit = func(x_fit, *fit_dict["popt"])
            ax.plot(x_fit, y_fit, "--", color=color, zorder=-1)
            fits.append([name, fit_dict["popt"], color])
            fit_params.append([name, *list(fit_dict.values())])
            aicc_row[name] = np.round((fit_dict["aicc"]), 2)

            # OptiMetal2B CGC
            color = "tab:orange"
            name = "OptiMetal2B (TC)"
            data = results["data"]["2b"]["variant2"]
            plot_scaling_law(ax, data, name, key=metric_key, color=color, error_bars=error_bars)
            fit_dict = fit_scaling_law(data, func_type=func_type, key=metric_key)
            y_fit = func(x_fit, *fit_dict["popt"])
            ax.plot(x_fit, y_fit, "--", color=color, zorder=-1)
            fits.append([name, fit_dict["popt"], color])
            fit_params.append([name, *list(fit_dict.values())])
            aicc_row[name] = np.round((fit_dict["aicc"]), 2)

            # OptiMetal3B
            color = "tab:blue"
            name = "OptiMetal3B (TC)"
            data = results["data"]["3b"]
            plot_scaling_law(ax, data, name, key=metric_key, color=color, error_bars=error_bars)
            fit_dict = fit_scaling_law(data, func_type=func_type, key=metric_key)
            y_fit = func(x_fit, *fit_dict["popt"])
            ax.plot(x_fit, y_fit, "--", color=color, zorder=-1)
            fits.append([name, fit_dict["popt"], color])
            fit_params.append([name, *list(fit_dict.values())])
            aicc_row[name] = np.round((fit_dict["aicc"]), 2)
                
            # rows for the latex table
            rows = build_rows_from_fits(fit_params, func_type=func_type, X=var)

            # log-log plot and axis ticks
            yticks = ax.get_yticks()
            ax.set_xlim(2000, 200000)
            ax.set_xscale("log")
            ax.set_yscale("log")
            ax.set_xticks(num_data)
            ax.xaxis.set_major_locator(FixedLocator(num_data))
            ax.xaxis.set_major_formatter(ScalarFormatter())
            ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: fmt_param(int(x))))
            ax.tick_params(axis="x", which="minor", length=0)
            ax.yaxis.set_minor_locator(FixedLocator(yticks))
            ax.yaxis.set_major_locator(FixedLocator(yticks))
            ax.yaxis.set_minor_formatter(ScalarFormatter())
            ax.yaxis.set_major_formatter(ScalarFormatter())
            ax.tick_params(axis="y", which="minor", length=0)

            # axis labels and legends
            ax.set_xlabel(r"$D$")
            ax.set_ylabel(r"$L_\mathrm{val}$")
            leg_models = ax.legend(loc="lower left", handletextpad=0.25, handlelength=1.25)
            fit_handle = [Line2D([], [], ls="--", color="0.5")]
            fit_label = [get_fit_func_str(func_type, var) for _ in fits]
            ax.legend(fit_handle, fit_label, title=r"$N \approx 10\mathrm{M}$ ($d_\mathrm{h}=256$)", loc="upper right", handletextpad=0.25, handlelength=1.25)
            ax.add_artist(leg_models)

            # save the figure
            fig.tight_layout()
            fig.align_labels()
            fig.savefig(os.path.join(fig_dir, fname + ".pdf"))
            fig.savefig(os.path.join(fig_dir, fname + ".svg"))

            # print and save the latex tables
            latex = latex_table_from_rows(rows, X="D", fit_func_display=False)
            print(f"\nData scaling: {metric_key:s} & {func_type:s}")
            print(latex)
            with open(os.path.join(fig_dir, fname + ".txt"), "w") as f:
                f.write(latex)
                
        except Exception as e:
            
            print(f"[WARN] {fname:s}: An error occurred, probably due to an unstable fit ({e})\n")
            with open(os.path.join(fig_dir, fname + ".txt"), "w") as f:
                f.write("Unstable fit...")
                
        finally:
            plt.close(fig)        
                
        # aicc table dataframe
        df.loc[row_label, model_order] = [aicc_row[m] for m in model_order]

    # latex table to compare the aicc of different scaling fit functions
    latex_aicc = df.to_latex(
        index=True,
        escape=False,
        multicolumn=True,
        multicolumn_format="c",
        bold_rows=False,
        float_format=lambda x: f"{int(round(x))}" if np.isclose(x, round(x)) else f"{x:.2f}",
        column_format="l" + len(df.columns) * r"c@{\hspace{1em}}",
    )
    print("\nAICc table:")
    print(latex_aicc)
    with open(os.path.join(fig_dir, f"{metric_key:s}_aicc_table_data.txt"), "w") as f:
        f.write(latex_aicc)

# Scaling law plots: Single Power Law Tests (Parameter)

In [None]:
# plot with or without errorbars
error_bars = True

# helper variables for aicc table (the insertion order fixes the row order of the table)
scaling_law_labels = {
    "simple": r"Power Law",
    "floor": r"Power Law + Floor",
    "broken": r"Broken Power Law",
    "broken_with_amp": r"Broken power law + Amplitude",
}
model_order = ["OptiMetal2B (CGC)", "OptiMetal2B (TC)", "OptiMetal3B (TC)"]

# (legend name, color, key path below results["parameter"]) of every model, in plotting order
_parameter_series = [
    ("OptiMetal2B (CGC)", "k", ("2b", "variant1")),
    ("OptiMetal2B (TC)", "tab:orange", ("2b", "variant2")),
    ("OptiMetal3B (TC)", "tab:blue", ("3b",)),
]

# loop over all metrics
for metric_key in ["val_loss"]:

    # aicc table setup: one row per fit function, one column per model
    df = pd.DataFrame(
        index=list(scaling_law_labels.values()),
        columns=model_order,
        dtype=float,
    )

    for func_type, row_label in scaling_law_labels.items():

        # aicc values stay NaN for every model that is not reached because a fit failed
        aicc_row = {m: np.nan for m in model_order}

        # figure setup
        fig, ax = plt.subplots(figsize=(3.5, 3.0))

        # parameters and useful variables
        var = "N"
        func = get_power_law(func_type)
        fits = [] # to store [name, popt, color] of every fitted model
        fit_params = []
        fname = f"{metric_key:s}_{func_type:s}_scaling_laws_parameter"

        # some fits may be unstable...
        try:

            # plot setup
            num_parameter = [5e5, 1e6, 5e6, 1e7, 5e7, 1e8] # tick positions for the x-axis
            x_fit = np.logspace(np.log10(2e5), np.log10(2e8), 100)

            # plot the data and the fitted scaling law of every model
            for name, color, key_path in _parameter_series:
                # the results are looked up inside the loop so that a missing entry
                # only aborts the models that have not been processed yet
                data = results["parameter"]
                for key in key_path:
                    data = data[key]
                plot_scaling_law(ax, data, name, key=metric_key, color=color, error_bars=error_bars, x_from_entry="num_parameter")
                fit_dict = fit_scaling_law(data, func_type=func_type, key=metric_key, x_from_entry="num_parameter")
                y_fit = func(x_fit, *fit_dict["popt"])
                ax.plot(x_fit, y_fit, "--", color=color, zorder=-1)
                fits.append([name, fit_dict["popt"], color])
                fit_params.append([name, *fit_dict.values()])
                aicc_row[name] = np.round(fit_dict["aicc"], 2)

            # rows for the latex table
            rows = build_rows_from_fits(fit_params, func_type=func_type, X=var)

            # log-log plot and axis ticks
            # (the y ticks are read while the axis is still linear and are re-used on the log axis)
            yticks = ax.get_yticks()
            ax.set_xlim(2e5, 2e8)
            ax.set_xscale("log")
            ax.set_yscale("log")
            ax.set_xticks(num_parameter)
            ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: fmt_param(int(x))))
            ax.yaxis.set_minor_locator(FixedLocator(yticks))
            ax.yaxis.set_major_locator(FixedLocator(yticks))
            ax.yaxis.set_minor_formatter(ScalarFormatter())
            ax.yaxis.set_major_formatter(ScalarFormatter())
            ax.tick_params(axis="y", which="minor", length=0)

            # axis labels and a single legend: the generic fit entry first, then the models
            ax.set_xlabel(r"$N$")
            ax.set_ylabel(r"$L_\mathrm{val}$")
            model_handles, model_labels = ax.get_legend_handles_labels()
            fit_handle = Line2D([], [], ls="--", color="0.5")
            fit_label = get_fit_func_str(func_type, var)
            handles = [fit_handle] + model_handles
            labels =  [fit_label] + model_labels
            ax.legend(handles, labels, title=r"$D=20000$", loc="upper right", handletextpad=0.25, handlelength=1.25)

            # save the figure
            fig.tight_layout()
            fig.align_labels()
            fig.savefig(os.path.join(fig_dir, fname + ".pdf"))
            fig.savefig(os.path.join(fig_dir, fname + ".svg"))

            # print and save the latex tables
            latex = latex_table_from_rows(rows, X="N", fit_func_display=False)
            print(f"\nParameter scaling: {metric_key:s} & {func_type:s}")
            print(latex)
            with open(os.path.join(fig_dir, fname + ".txt"), "w") as f:
                f.write(latex)

        except Exception as e:

            # best effort: report the failure and leave a marker file instead of the table
            print(f"[WARN] {fname:s}: An error occurred, probably due to an unstable fit ({e})\n")
            with open(os.path.join(fig_dir, fname + ".txt"), "w") as f:
                f.write("Unstable fit...")

        finally:
            plt.close(fig)

        # aicc table dataframe
        df.loc[row_label, model_order] = [aicc_row[m] for m in model_order]

    # latex table to compare the aicc of different scaling fit functions
    latex_aicc = df.to_latex(
        index=True,
        escape=False,
        multicolumn=True,
        multicolumn_format="c",
        bold_rows=False,
        float_format=lambda x: f"{int(round(x))}" if np.isclose(x, round(x)) else f"{x:.2f}",
        column_format="l" + len(df.columns) * r"c@{\hspace{1em}}",
    )
    print("\nAICc table:")
    print(latex_aicc)
    with open(os.path.join(fig_dir, f"{metric_key:s}_aicc_table_parameter.txt"), "w") as f:
        f.write(latex_aicc)

# Scaling law plots: Single Power Laws

In [None]:
# plot with or without errorbars
error_bars = True

# helper variables for aicc table
model_order = ["OptiMetal2B (CGC)", "OptiMetal2B (TC)", "OptiMetal3B (TC)"]

# (legend name, color, key path below results["data"] / results["parameter"]) of every
# model, in plotting order
_scaling_series = [
    ("OptiMetal2B (CGC)", "k", ("2b", "variant1")),
    ("OptiMetal2B (TC)", "tab:orange", ("2b", "variant2")),
    ("OptiMetal3B (TC)", "tab:blue", ("3b",)),
]

# loop over all metrics and function types
for metric_key in ["val_loss"]:

    # figure setup: data scaling in the upper panel, parameter scaling in the lower panel
    fig, axes = plt.subplots(2, 1, figsize=(3.5, 5))
    fname = f"{metric_key:s}_scaling_laws"

    # ---------------------------------------------------------------------
    # Data scaling
    # ---------------------------------------------------------------------

    # parameters and useful variables
    var = "D"
    data_func_type = "broken"
    func = get_power_law(data_func_type)
    fits = [] # to store [name, popt, color] of every fitted model
    fit_params_data = []

    # some fits may be unstable...
    try:

        # plot setup
        ax = axes[0]
        num_data = sorted(int(x) for x in results["data"]["2b"]["variant1"]) # tick positions for the x-axis
        x_fit = np.logspace(np.log10(2000), np.log10(200000), 100)

        # plot the data and the fitted scaling law of every model
        for name, color, key_path in _scaling_series:
            data = results["data"]
            for key in key_path:
                data = data[key]
            plot_scaling_law(ax, data, name, key=metric_key, color=color, error_bars=error_bars)
            fit_dict = fit_scaling_law(data, func_type=data_func_type, key=metric_key)
            y_fit = func(x_fit, *fit_dict["popt"])
            ax.plot(x_fit, y_fit, "--", color=color, zorder=-1)
            fits.append([name, fit_dict["popt"], color])
            fit_params_data.append([name, *fit_dict.values()])

            # add "unbroken" power-law fits for comparison (first model only): straight
            # lines in log-log space through the three smallest and the three largest D
            if name == "OptiMetal2B (CGC)":
                data = _filter_nonempty(data)
                if len(data) < 2:
                    raise ValueError("Need at least two points for a fit")
                sorted_keys = sorted(data.keys(), key=int)
                x = np.array([int(k) for k in sorted_keys], dtype=float)
                y_mean = np.array(
                    [np.mean([entry[metric_key] for entry in data[k]]) for k in sorted_keys],
                    dtype=float,
                )
                y_std = np.array(
                    [np.std([entry[metric_key] for entry in data[k]]) for k in sorted_keys],
                    dtype=float,
                ) # not used by the fits below
                # polyfit returns (slope, intercept), i.e. log10(y) = alpha * log10(x) + d0
                alpha_1, d0_1 = np.polyfit(np.log10(x[:3]), np.log10(y_mean[:3]), deg=1)
                alpha_2, d0_2 = np.polyfit(np.log10(x[-3:]), np.log10(y_mean[-3:]), deg=1)
                y_fit_unbroken_1 = 10**d0_1 * x_fit**alpha_1
                ax.plot(x_fit, y_fit_unbroken_1, ":", color="k", zorder=-1)
                y_fit_unbroken_2 = 10**d0_2 * x_fit**alpha_2
                ax.plot(x_fit, y_fit_unbroken_2, ":", color="k", zorder=-1, label="Asymptotic fit")

        # rows for the latex table
        rows_data = build_rows_from_fits(fit_params_data, func_type=data_func_type, X=var)

        # log-log plot and axis ticks
        # (the y ticks are read while the axis is still linear and are re-used on the log axis)
        yticks = ax.get_yticks()
        ax.set_xlim(2000, 200000)
        ax.set_ylim(top=2.5)
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_xticks(num_data)
        ax.xaxis.set_major_locator(FixedLocator(num_data))
        ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: fmt_param(int(x))))
        ax.tick_params(axis="x", which="minor", length=0)
        ax.yaxis.set_minor_locator(FixedLocator(yticks))
        ax.yaxis.set_major_locator(FixedLocator(yticks))
        ax.yaxis.set_minor_formatter(ScalarFormatter())
        ax.yaxis.set_major_formatter(ScalarFormatter())
        ax.tick_params(axis="y", which="minor", length=0)

        # axis labels and legends
        ax.set_xlabel(r"$D$")
        ax.set_ylabel(r"$L_\mathrm{val}$")
        leg_models = ax.legend(loc="lower left", handletextpad=0.25, handlelength=1.25)
        fit_handle = [Line2D([], [], ls="--", color="tab:gray")]
        fit_label = [get_fit_func_str(data_func_type, var)]
        leg = ax.legend(fit_handle, fit_label, title=r"$N \approx 10\mathrm{M}$ ($d_\mathrm{h}=256$)", loc="upper right", handletextpad=0.25, handlelength=1.25)
        # the second legend call replaces the first one, so the model legend is added back
        ax.add_artist(leg_models)

        # -----------------------------------------------------------------
        # Parameter scaling
        # -----------------------------------------------------------------

        # parameters and useful variables
        var = "N"
        param_func_type = "floor"
        func = get_power_law(param_func_type)
        fit_params_parameter = []

        # plot setup
        ax = axes[1]
        num_parameter = [5e5, 1e6, 5e6, 1e7, 5e7, 1e8] # tick positions for the x-axis
        x_fit = np.logspace(np.log10(2e5), np.log10(2e8), 100)

        # plot the data and the fitted scaling law of every model
        # (no model labels here, the model legend is only shown in the upper panel)
        for name, color, key_path in _scaling_series:
            data = results["parameter"]
            for key in key_path:
                data = data[key]
            plot_scaling_law(ax, data, None, key=metric_key, color=color, error_bars=error_bars, x_from_entry="num_parameter")
            fit_dict = fit_scaling_law(data, func_type=param_func_type, key=metric_key, x_from_entry="num_parameter")
            y_fit = func(x_fit, *fit_dict["popt"])
            # only the fit line of the first model carries the fit function label
            label_kwargs = {"label": get_fit_func_str(param_func_type, var)} if name == "OptiMetal2B (CGC)" else {}
            ax.plot(x_fit, y_fit, "--", color=color, zorder=-1, **label_kwargs)
            fit_params_parameter.append([name, *fit_dict.values()])

        # rows for the latex table
        rows_params = build_rows_from_fits(fit_params_parameter, func_type=param_func_type, X=var)

        # log-log plot and axis ticks
        yticks = ax.get_yticks()
        ax.set_xlim(2e5, 2e8)
        ax.set_xscale("log")
        ax.set_yscale("log")
        ax.set_xticks(num_parameter)
        ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: fmt_param(int(x))))
        ax.yaxis.set_minor_locator(FixedLocator(yticks))
        ax.yaxis.set_major_locator(FixedLocator(yticks))
        ax.yaxis.set_minor_formatter(ScalarFormatter())
        ax.yaxis.set_major_formatter(ScalarFormatter())
        ax.tick_params(axis="y", which="minor", length=0)

        # axis labels and legends
        ax.set_xlabel(r"$N$")
        ax.set_ylabel(r"$L_\mathrm{val}$")
        fit_handle = [Line2D([], [], ls="--", color="tab:gray")]
        fit_label = [get_fit_func_str(param_func_type, var)]
        leg = ax.legend(fit_handle, fit_label, title=r"$D=20000$", loc="upper right", handletextpad=0.25, handlelength=1.25)

        # save the figure
        fig.tight_layout()
        fig.align_labels()
        fig.savefig(os.path.join(fig_dir, fname + ".pdf"))
        fig.savefig(os.path.join(fig_dir, fname + ".svg"))

        # print and save the latex tables
        latex_data = latex_table_from_rows(rows_data, X="D", fit_func_display=False)
        latex_params = latex_table_from_rows(rows_params, X="N", fit_func_display=False)
        print("\nData scaling:")
        print(latex_data)
        print("\nParameter scaling:")
        print(latex_params) # the parameter table, not the data table a second time
        with open(os.path.join(fig_dir, fname + ".txt"), "w") as f:
            f.write(latex_data)
            f.write("\n")
            f.write(latex_params)

    except Exception as e:

        # best effort: report the failure and leave a marker file instead of the tables
        print(f"[WARN] {fname:s}: An error occurred, probably due to an unstable fit ({e})\n")
        with open(os.path.join(fig_dir, fname + ".txt"), "w") as f:
            f.write("Unstable fit...")

    finally:
        plt.close(fig)

# Load the results from the "grid search"

In [None]:
# location of the training runs of the scaling law grid study
study_path = "/scratch/magr4985/Scaling_Grid"

# output locations: cached results and the figures of the grid study
output_dir = "./scaling_data"
fig_dir = "./scaling_data/scaling_grid"

# create the output directories if they do not exist yet
os.makedirs(output_dir, exist_ok=True)
os.makedirs(fig_dir, exist_ok=True)

In [None]:
# check that all training runs converged (only possible where the raw study data is available)
if os.path.exists(study_path):
    # load all study directories, validation loss files, and tensorboard logs
    study_dirs = [d for d in os.listdir(study_path) if os.path.isdir(os.path.join(study_path, d))]
    print(f"[INFO] Found {len(study_dirs):d} study directories in {study_path:s}")
    checks = []
    for study_dir in study_dirs:
        run_path = os.path.join(study_path, study_dir)
        val_loss_path = os.path.join(run_path, "val_loss.txt")
        val_loss = None # stays None when the run did not write a validation loss file
        if os.path.exists(val_loss_path):
            val_loss = float(np.loadtxt(val_loss_path))
        checks.append([study_dir, val_loss, load_tb_scalars(run_path)])
    # print the study directories with diverging gradients, i.e., where something went wrong
    print("\n[DIVERGING GRADIENTS CHECK]")
    grad_norm_thr = 1e2
    for study_dir, val_loss, tb_log in checks:
        # a missing or empty tag is reported instead of aborting the whole check
        grad_norms = np.asarray(tb_log.get("train/grad_norm", []))
        if grad_norms.size == 0:
            print("[NO GRADIENT NORM LOGGED] ", study_dir)
            continue
        max_grad_norm = np.max(grad_norms)
        # NaN compares False against any threshold, so non-finite values are flagged explicitly
        if not np.isfinite(max_grad_norm) or max_grad_norm > grad_norm_thr:
            print("[DIVERGING GRADIENTS] ", study_dir, max_grad_norm)
    # print study directories with training not finished
    print("\n[TRAINING FINISHED CHECK]")
    for study_dir, val_loss, tb_log in checks:
        num_val_points = len(tb_log.get("val/loss", []))
        if num_val_points < 500:
            print("[TRAINING NOT FINISHED] ", study_dir, num_val_points)
    # seed-to-seed validation-loss consistency
    print("\n[SEED VARIANCE CHECK]")
    max_seed_variance_thr = 0.05
    _seed_suffix_re = re.compile(r'([_-])seed\d+$')
    _seed_extract_re = re.compile(r'seed(\d+)$')
    def _key_without_seed(name: str) -> str:
        """
        Strip a trailing '_seed<k>' / '-seed<k>' suffix so that all seeds share one key.
        """
        return _seed_suffix_re.sub('', name)
    def _seed_or_inf(name: str) -> int:
        """
        Return the trailing seed number of a directory name, or a large value if there is none.
        """
        m = _seed_extract_re.search(name)
        return int(m.group(1)) if m else 10**9 # for nice sorting when seed missing
    # group the finite validation losses by configuration (directory name without seed)
    groups = defaultdict(list)
    for study_dir, val_loss, _ in checks:
        if val_loss is not None and np.isfinite(val_loss):
            groups[_key_without_seed(study_dir)].append((study_dir, float(val_loss)))
    for cfg_key, items in groups.items():
        if len(items) < 2:
            continue # need at least two seeds to compare
        sorted_vals = sorted(v for _, v in items)
        low = sorted_vals[0]
        mid = sorted_vals[len(sorted_vals) // 2] # upper median for an even number of seeds
        high = sorted_vals[-1]
        if low <= 0:
            continue # mid >= low, so a positive low keeps the relative spread well defined
        max_rel_spread = max((mid - low) / mid, (high - mid) / mid)
        if max_rel_spread > max_seed_variance_thr:
            # the reported threshold follows 'max_seed_variance_thr' instead of a fixed text
            print(f"[SEED VARIANCE > {100*max_seed_variance_thr:g}%] {cfg_key:s}: max(spread)={100*max_rel_spread:.2f}% (low={low:.6g}, mid={mid:.6g} , high={high:.6g})")
            for sd, v in sorted(items, key=lambda p: _seed_or_inf(p[0])):
                print(f"    {sd:s}: val_loss={v:.6g}")

In [None]:
"""
Incrementally load and evaluate all models in the scaling-law study, caching results to disk.

Evaluating every model can be time-consuming, so this process saves intermediate results
to a JSON file and, on subsequent runs, will only process any newly added models.
If execution is interrupted, you can simply rerun and it will resume from the last saved state.
"""

# path of the JSON file where results are stored
json_path = os.path.join(output_dir, "scaling_grid_results.json")

if os.path.exists(study_path):
    # batch size (speeds up the process a bit)
    batch_size = 128 # adjust this according to your GPU memory
    # path to the validation dataset
    eval_path = "../graph/val.pt"
    # load the data on which we want to evaluate the models
    eval_data = load_torch_data(eval_path)
    dataloader = create_dataloader(
        eval_data, 
        num_data=-1, # use the whole dataset 
        batch_size=batch_size,
        shuffle=False, # do not shuffle the validation set
    )
    print(f"Loaded evaluation data from '{eval_path:s}'", flush=True)
    # load or initialize results
    if os.path.exists(json_path):
        with open(json_path, "r") as f:
            results = json.load(f)
        print(f"Loaded existing results ({len(extract_processed_dirs(results)):d} runs)")
    else:
        results = init_empty_results_grid()
        print("Initialized new results store")
    # find what has already been done
    processed = extract_processed_dirs(results)
    # scan for all study subdirectories
    all_dirs = sorted(d for d in os.listdir(study_path) if os.path.isdir(os.path.join(study_path, d)))
    # only process the new ones
    new_dirs = [d for d in all_dirs if d not in processed]
    save_every = 4  
    total = len(new_dirs)
    if total == 0:
        print("No new studies to process, everything is up to date")
    else:
        print(f"Processing {len(new_dirs)} new studies (saving every {save_every:d})")
        # checkpoints are written to a temporary file first and then moved into place,
        # so an interruption while writing cannot leave a truncated JSON file behind
        tmp_json_path = json_path + ".tmp"
        for idx, study_dir in enumerate(new_dirs, start=1):
            process_one_grid(
                study_path=study_path,
                study_dir=study_dir, 
                results=results, 
                dataloader=dataloader,
            )
            # checkpoint every 'save_every' or on the very last one
            if (idx % save_every == 0) or (idx == total):
                with open(tmp_json_path, "w") as f:
                    json.dump(results, f, indent=4)
                os.replace(tmp_json_path, json_path)
                print(f"    Checkpointed after {idx:d}/{total:d} runs")
        print(f"All {total:d} new runs appended and final JSON saved")
else:
    # without access to the raw study data, fall back to the cached results
    if not os.path.exists(json_path):
        raise FileNotFoundError(
            f"Neither the study directory '{study_path:s}' nor the cached results '{json_path:s}' exist"
        )
    with open(json_path, "r") as f:
        results = json.load(f)
    print(f"Loaded results for {len(extract_processed_dirs(results)):d} models from JSON file")

# Scaling law grid maps

In [None]:
# color map name
cmap_name = "viridis_r"
cmap = plt.get_cmap(cmap_name)

# plot with or without errorbars
error_bars = True

# unified validation loss axis limits
lmin = 0.40
lmax = 1.85

# helper function for latex table
model_label_helper = lambda key: {"2b": "OptiMetal2B (TC)", "3b": "OptiMetal3B (TC)"}.get(key, str(key))

# loop over all metrics 
for metric_key in ["val_loss"]:
    
    # loop over all fit types
    for fit_type in ["kaplan", "hoffmann"]:
        
        # latex table row setup
        param_rows = []

        # loop over all models
        for model in results.keys():
            
            # process the data into a more manageable format
            datapoints, widths, num_parameters, mus, stds = grid_to_vectors(results[model], metric_key=metric_key)

            # fit the loss surface and store the results to create a latex table later one
            X = (datapoints.ravel(), num_parameters.ravel())
            if fit_type == "kaplan":
                func = global_scaling_fit_kaplan
                popt, pcov = curve_fit(
                    f=func,
                    xdata=X,
                    ydata=mus.ravel(),
                    p0=[0.5, 0.5, 0.5, 1e4, 1e4],
                    sigma=stds.ravel(),
                    absolute_sigma=True,
                    maxfev=100000,
                )
            elif fit_type == "hoffmann":
                func = global_scaling_fit_hoffmann
                popt, pcov = curve_fit(
                    f=func,
                    xdata=X,
                    ydata=mus.ravel(),
                    p0=[0.5, 0.5, 0.5, 1e4, 1e4, 0],
                    sigma=stds.ravel(),
                    absolute_sigma=True,
                    maxfev=100000,
                )
            else:
                raise ValueError("Unsupported fit type")
            perr = np.sqrt(np.diag(pcov))
            y_fit = func(X, *popt)
            res = mus.ravel() - y_fit
            aicc = calc_aicc(res, len(popt))
            if fit_type == "kaplan":
                row = {
                    "Model": model_label_helper(model),
                    r"$\alpha_N$": format_val_err(popt[0], perr[0]),
                    r"$\alpha_{D,1}$": format_val_err(popt[1], perr[1]),
                    r"$\alpha_{D,2}$": format_val_err(popt[2], perr[2]),
                    r"$N_c$": format_val_err_log10(popt[3], perr[3]),
                    r"$D_c$": format_val_err_log10(popt[4], perr[4]),
                    r"AICc": f"{aicc:.2f}",
                }
            elif fit_type == "hoffmann":
                row = {
                    "Model": model_label_helper(model),
                    r"$\alpha_N$": format_val_err(popt[0], perr[0]),
                    r"$\alpha_{D,1}$": format_val_err(popt[1], perr[1]),
                    r"$\alpha_{D,2}$": format_val_err(popt[2], perr[2]),
                    r"$N_c$": format_val_err_log10(popt[3], perr[3]),
                    r"$D_c$": format_val_err_log10(popt[4], perr[4]),
                    r"L_0": fmt_param(popt[5], perr[5]),
                    r"AICc": f"{aicc:.2f}",
                }
            else:
                raise ValueError("Unsupported fit type")
            param_rows.append(row)

            # setup the figure
            fig = plt.figure(figsize=(3.5, 6), constrained_layout=True)
            gs = fig.add_gridspec(3, 2, width_ratios=[1.0, 0.1])
            
            # plot the valdation loss over D for every N
            ax = fig.add_subplot(gs[1, 0]) 
            norm_n = LogNorm(vmin=np.min(num_parameters), vmax=np.max(num_parameters))
            for i in range(datapoints.shape[1]):
                if error_bars:
                    ax.errorbar(
                        datapoints[:, i],
                        mus[:, i],
                        yerr=stds[:, i],
                        fmt="o",
                        markersize=4,
                        markeredgecolor=cmap(norm_n(num_parameters[0, i])),
                        markerfacecolor=cmap(norm_n(num_parameters[0, i])),
                        ecolor=cmap(norm_n(num_parameters[0, i])),
                        capsize=4,
                        linestyle="none",
                    ) 
                else:
                    ax.plot(datapoints[:, i], mus[:, i], "o", ms=4, color=cmap(norm_n(num_parameters[0, i])))
                data_col = datapoints[:, i]
                dmin = np.min(data_col)
                dmax = np.max(data_col)
                xfit = np.logspace(np.log10(2000), np.log10(200000), 200)
                yfit = func((xfit, np.full_like(xfit, num_parameters[0, i])), *popt)
                ax.plot(xfit, yfit, "--", color=cmap(norm_n(num_parameters[0, i])), zorder=-1)
            ax.set_xlabel(r"$D$")
            ax.set_ylabel(r"$L_\mathrm{val}$")
            yticks = [0.50, 0.75, 1.00, 1.25, 1.50, 1.75]
            ax.set_xlim(2000, 200000)
            ax.set_xscale("log")
            ax.set_yscale("log")
            ax.set_xticks(np.unique(datapoints))
            ax.xaxis.set_major_locator(FixedLocator(np.unique(datapoints)))
            ax.xaxis.set_major_formatter(ScalarFormatter())
            ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: fmt_param(int(x))))
            ax.yaxis.set_minor_locator(FixedLocator(yticks))
            ax.yaxis.set_major_locator(FixedLocator(yticks))
            ax.yaxis.set_minor_formatter(ScalarFormatter())
            ax.yaxis.set_major_formatter(ScalarFormatter())
            ax.tick_params(axis="y", which="minor", length=0)
            sm_np = plt.cm.ScalarMappable(norm=norm_n, cmap=cmap)
            sm_np.set_array([])
            cbar_top1 = fig.colorbar(sm_np, ax=ax, pad=0.02, label=r"$N$")
            cbar_top1.set_ticks(num_parameters[0, :])
            cbar_top1.ax.set_ylim(np.min(num_parameters[0, :]), np.max(num_parameters[0, :]))
            cbar_top1.ax.yaxis.set_major_formatter(FuncFormatter(lambda x, _: fmt_param(x)))
            xlim = ax.get_xlim()
            if error_bars:
                ax.errorbar(
                    [1, 2],
                    [101, 102],
                    yerr=[0.1, 0.2],
                    fmt="o",
                    markersize=4,
                    markeredgecolor="k",
                    markerfacecolor="k",
                    ecolor="k",
                    capsize=4,
                    linestyle="none",
                    label="Data",
                ) 
            else:
                ax.plot([1, 2], [101, 102], "o", ms=4, color="k", label="Data")
            ax.plot([1, 2], [101, 102], "--", color="k", label="Fit")
            ax.legend(handlelength=1.25, loc="lower left")
            ax.set_xlim(xlim)
            ax.set_ylim(lmin, lmax)
            handles, labels = ax.get_legend_handles_labels()
            order = [labels.index("Data"), labels.index("Fit")] # desired ordering
            ax.legend(
                [handles[i] for i in order], [labels[i] for i in order], 
                handlelength=1.25, 
                loc="lower left",
            )
            
            # plot the validation loss over N for every D
            ax = fig.add_subplot(gs[0, 0])  
            norm_d = LogNorm(vmin=np.min(datapoints), vmax=np.max(datapoints))
            for i in range(num_parameters.shape[0]):
                if error_bars:
                    ax.errorbar(
                        num_parameters[i, :],
                        mus[i, :],
                        yerr=stds[i, :],
                        fmt="o",
                        markersize=4,
                        markeredgecolor=cmap(norm_d(datapoints[i, 0])),
                        markerfacecolor=cmap(norm_d(datapoints[i, 0])),
                        ecolor=cmap(norm_d(datapoints[i, 0])),
                        capsize=4,
                        linestyle="none",
                        label=fmt_param(datapoints[i, 0]),
                    ) 
                else:
                    ax.plot(
                        num_parameters[i, :], 
                        mus[i, :], 
                        "o", 
                        label=fmt_param(datapoints[i, 0]), 
                        color=cmap(norm_d(datapoints[i, 0])),
                    )
                nmin = np.nanmin(num_parameters[i, :])
                nmax = np.nanmax(num_parameters[i, :])
                xfit = np.logspace(np.log10(2e5), np.log10(2e8), 200)
                yfit = func((np.full_like(xfit, datapoints[i, 0]), xfit), *popt)
                ax.plot(xfit, yfit, "--", color=cmap(norm_d(datapoints[i, 0])), zorder=-1)
            ax.set_xlabel(r"$N$")
            ax.set_ylabel(r"$L_\mathrm{val}$")
            ax.set_title(model_label_helper(model), pad=3)
            yticks = ax.get_yticks()
            if model == "2b":
                ax.set_xlim(2e5, 1e8)
            else:
                ax.set_xlim(2e5, 2e8)
            ax.set_xscale("log")
            ax.set_yscale("log")
            ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: fmt_param(int(x))))
            ax.set_xlim(200000, 200000000)
            ax.yaxis.set_minor_locator(FixedLocator(yticks))
            ax.yaxis.set_major_locator(FixedLocator(yticks))
            ax.yaxis.set_minor_formatter(ScalarFormatter())
            ax.yaxis.set_major_formatter(ScalarFormatter())
            ax.tick_params(axis="y", which="minor", length=0)
            ax.set_ylim(lmin, lmax)
            sm_d = plt.cm.ScalarMappable(norm=norm_d, cmap=cmap)
            sm_d.set_array([])
            cbar_top2 = fig.colorbar(sm_d, ax=ax, pad=0.02, label=r"$D$")
            cbar_top2.set_ticks(datapoints[:, 0])
            cbar_top2.ax.set_ylim(np.min(datapoints[:, 0]), np.max(datapoints[:, 0]))
            cbar_top2.ax.yaxis.set_major_formatter(FuncFormatter(lambda x, _: fmt_param(x)))
        
            # plot the full L(D,N)
            ax = fig.add_subplot(gs[2, 0]) 
            x_centers = datapoints[:, 0]
            y_centers = num_parameters[0, :]
            def _make_edges(centers: np.ndarray) -> np.ndarray:
                """
                Turn bin centers into bin edges, working in log10 space.

                Interior edges sit halfway (in log10) between neighbouring centers.
                The two outer edges are obtained by mirroring the first and the last
                half-gap outwards, so the result holds one more entry than `centers`.
                """
                log_centers = np.log10(centers)
                # half of the log-distance between each pair of neighbouring centers
                half_gaps = 0.5 * np.diff(log_centers)
                lower_edge = log_centers[0] - half_gaps[0]
                upper_edge = log_centers[-1] + half_gaps[-1]
                inner_edges = log_centers[:-1] + half_gaps
                log_edges = np.concatenate([[lower_edge], inner_edges, [upper_edge]])
                # map the edges back from log10 space to linear space
                return np.power(10.0, log_edges)
            d_edges = _make_edges(x_centers)
            n_edges = _make_edges(y_centers)
            cmin = 0.45
            cmax = 1.77
            cstep = 0.33
            pcm = ax.pcolormesh(d_edges, n_edges, mus, cmap=cmap_name.rstrip("_r"), vmin=cmin, vmax=cmax, shading="auto")
            pcm.set_rasterized(True)
            ax.set_xlabel(r"$D$")
            ax.set_ylabel(r"$N$")
            ax.set_xscale("log")
            ax.set_yscale("log")
            ax.set_xlim(d_edges[0], d_edges[-1])
            ax.set_ylim(n_edges[0], n_edges[-1])
            ax.set_xticks(np.unique(datapoints))
            ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: fmt_param(int(x))))
            ax.yaxis.set_major_formatter(FuncFormatter(lambda x, _: fmt_param(int(x))))
            ax.tick_params(axis="x", which="minor", length=0)
            cbar = fig.colorbar(pcm, ax=ax, pad=0.02, label=r"$L_\mathrm{val}$") 
            cbar.ax.set_ylim([cmin, cmax])
            cbar.locator = FixedLocator(np.arange(cmin, cmax + cstep, cstep))
            cbar.formatter = ScalarFormatter() 
            cbar.update_ticks()
            
            # align labels and save the figure
            fig.align_labels()
            fig.savefig(os.path.join(fig_dir, f"{model:s}_{fit_type:s}_{metric_key:s}_scaling_map_v2.pdf"))
            fig.savefig(os.path.join(fig_dir, f"{model:s}_{fit_type:s}_{metric_key:s}_scaling_map_v2.svg"))
            
        # build and export the latex table for the fit parameters & aicc
        if param_rows:
            df = pd.DataFrame(param_rows)
            latex_params_global = df.to_latex(
                index=True,
                escape=False,
                longtable=False,
                multicolumn=True,
                multicolumn_format="c",
                bold_rows=False,
                column_format=len(df.columns) * r"c@{\hspace{1em}}",
            )
            print(f"Fit parameters: {fit_type:s} & {metric_key:s}")
            print(latex_params_global)
            with open(os.path.join(fig_dir, f"{fit_type:s}_{metric_key:s}_scaling_fits.txt"), "w") as f:
                f.write(latex_params_global)