In [1]:
# keep refreshing the env to update experiments.json
%load_ext autoreload
%autoreload 2

# make cells take up the whole width to display graphs better
from IPython.display import display, HTML, Markdown
display(HTML("<style>:root { --jp-notebook-max-width: 100% !important; }</style>"))

import json
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from IPython.display import display
import ipywidgets as widgets
from scipy.optimize import curve_fit
from math import ceil

In [2]:
# Define the experiment filenames as a list
# You can modify this list to include any number of experiment files
experiment_filenames = [
    # "merged.json",                          # Transformer batch_size=8
    "experiments_20250129_162556.json"      # FCN batch_size=64
]

### Helper funcs

In [3]:
def load_and_process_data(filename):
    """
    Load and process the experimental data from a JSON file.

    Args:
        filename (str): The name of the JSON file.

    Returns:
        summary_val (dict): Best achievable validation loss per dataset size.
        detailed_runs (dict): Detailed runs data structured by dataset size.
        loaded_data (dict): The raw loaded JSON data.
    """
    with open(f"results/{filename}", "r") as f:
        loaded_data = json.load(f)

    detailed_runs = {}  # {dataset_size: [(params, best_val_loss, best_train_loss), ...]}
    for ds_size_str, runs in loaded_data.items():
        ds_size = float(ds_size_str)
        run_losses = []
        for run in runs:
            val_losses = [step_data["val_loss"] for step_data in run["losses"].values()]
            train_losses = [step_data["train_loss"] for step_data in run["losses"].values()]
            best_val_loss = min(val_losses)
            best_train_loss = min(train_losses)
            num_params = run["config"]["num_params"]
            run_losses.append((num_params, best_val_loss, best_train_loss, run))  # Store run data for callbacks
        detailed_runs[ds_size] = run_losses

    # Create summary_results for scaling law plot
    summary_val = {}  # best achievable validation loss per dataset size
    for ds_size, runs in detailed_runs.items():
        best_val = min(l[1] for l in runs)
        summary_val[ds_size] = best_val

    return summary_val, detailed_runs, loaded_data

def prepare_scaling_data(summary_val):
    """
    Prepare and sort the scaling data.

    Args:
        summary_val (dict): Best achievable validation loss per dataset size.

    Returns:
        sorted_ds_sizes (np.array): Sorted dataset sizes.
        sorted_val_losses (np.array): Validation losses sorted accordingly.
    """
    ds_sizes = np.array(list(summary_val.keys()))
    val_losses = np.array(list(summary_val.values()))
    sorted_indices = np.argsort(ds_sizes)
    sorted_ds_sizes = ds_sizes[sorted_indices]
    sorted_val_losses = val_losses[sorted_indices]
    return sorted_ds_sizes, sorted_val_losses

def calculate_scaling_law(ds_sizes, val_losses):
    """
    Calculate the scaling law parameters and fitted values.

    Args:
        ds_sizes (np.array): Array of dataset sizes.
        val_losses (np.array): Array of validation losses.

    Returns:
        fitted_vals (np.array): Fitted validation loss values based on scaling law.
        a (float): Scaling law parameter a.
        b (float): Scaling law parameter b.
    """
    # Define the power-law function
    def power_law(x, a, b):
        return a * x ** (-b)

    # Fit the power-law curve to the data
    popt, _ = curve_fit(power_law, ds_sizes, val_losses, p0=(1, 1))  # Initial guess for a and b

    # Extract parameters
    a, b = popt

    fitted_vals = power_law(ds_sizes, a, b)
    return fitted_vals, a, b

### Prep data

In [4]:
# ------------------ buggy ------------------ 
# all FCN models 50 epochs
# experiment_filename = "experiments_20250111_214659.json"

# all Transformer models 50 epochs                          ********
# experiment_filename = "experiments_20250112_062438.json"

# ------------------ bugfix ------------------ 
# all Transformer models 50 epochs  
# experiment_filename = "experiments_20250123_155336.json"

# ------------------ all loss equally scaled ------------------ 
# experiment_filename = "experiments_20250123_162306.json"


# ------------------ low batch size (8) on 0.1 fraction ------------------ 
# experiment_filename = "experiments_20250123_163741.json"

# ------------------ low batch size (8) ------------------ 
# 0.01 0.02 0.04 0.08
# experiment_filename = "experiments_20250124_102758.json"
# 0.1 0.2 0.4 0.8 1.0
# experiment_filename = "experiments_20250123_165735.json"
# experiment_filename_1 = "merged.json"

In [5]:
# Initialize dictionaries to store data for all experiments
all_summary_vals = {}
all_detailed_runs = {}
all_loaded_data = {}
all_sorted_ds_sizes = {}
all_sorted_val_losses = {}
all_fitted_vals = {}
all_scaling_params = {}

# Load and process each experiment file
for filename in experiment_filenames:
    summary_val, detailed_runs, loaded_data = load_and_process_data(filename)
    sorted_ds_sizes, sorted_val_losses = prepare_scaling_data(summary_val)
    # fitted_vals, a, b = calculate_scaling_law(sorted_ds_sizes, sorted_val_losses)
    
    # Store data in dictionaries
    all_summary_vals[filename] = summary_val
    all_detailed_runs[filename] = detailed_runs
    all_loaded_data[filename] = loaded_data
    all_sorted_ds_sizes[filename] = sorted_ds_sizes
    all_sorted_val_losses[filename] = sorted_val_losses
    # all_fitted_vals[filename] = fitted_vals
    # all_scaling_params[filename] = (a, b)

### Create plots

In [6]:
# Create Plot-1 (Scaling Law for Validation Loss)
fig_plot1 = go.FigureWidget()

# Define color and marker styles for different experiments
color_palette = px.colors.qualitative.Plotly
marker_symbols = ['circle', 'square', 'diamond', 'cross', 'triangle-up', 'triangle-down', 'star']

# To keep track of which colors and symbols are used
color_map = {}
marker_map = {}

for idx, filename in enumerate(experiment_filenames):
    color = color_palette[idx % len(color_palette)]
    marker = marker_symbols[idx % len(marker_symbols)]
    color_map[filename] = color
    marker_map[filename] = marker
    
    # Extract data
    sorted_ds_sizes = all_sorted_ds_sizes[filename]
    sorted_val_losses = all_sorted_val_losses[filename]
    # fitted_vals = all_fitted_vals[filename]
    # a, b = all_scaling_params[filename]
    
    # Original Data Trace
    fig_plot1.add_trace(
        go.Scatter(
            x=sorted_ds_sizes, 
            y=sorted_val_losses, 
            mode='lines+markers', 
            name=f'Best Val Loss<br>{filename}',
            marker=dict(symbol=marker, size=8, color=color),
            customdata=[{'experiment': filename, 'dataset_size': ds} for ds in sorted_ds_sizes],
            hovertemplate='Dataset Size: %{x}<br>Val Loss: %{y}<extra></extra>',
        )
    )
    
    # # Scaling Law Trace
    # fig_plot1.add_trace(
    #     go.Scatter(
    #         x=sorted_ds_sizes, 
    #         y=fitted_vals, 
    #         mode='lines', 
    #         name=f'y = {a:.2f}x^(-{b:.2f})',
    #         line=dict(dash='dash', color=color),
    #         hovertemplate='Scaling Law: y = {:.2f}x^(-{:.2f})<extra></extra>'.format(a, b),
    #     )
    # )

fig_plot1.update_layout(
    title="Scaling Law: Validation Loss vs Dataset Size",
    xaxis_title="Dataset Size",
    yaxis_title="Validation Loss",
    template="plotly_white",
    width=900,
    height=600,
    legend=dict(x=1.02, y=1)
)

# Add buttons for toggling axis scales
xaxis_buttons = [
    dict(args=[{"xaxis.type": "linear"}], label="X-Linear", method="relayout"),
    dict(args=[{"xaxis.type": "log"}], label="X-Log", method="relayout")
]
yaxis_buttons = [
    dict(args=[{"yaxis.type": "linear"}], label="Y-Linear", method="relayout"),
    dict(args=[{"yaxis.type": "log"}], label="Y-Log", method="relayout")
]
fig_plot1.update_layout(
    margin=dict(r=150),
    updatemenus=[
        dict(type="buttons", direction="up", x=1.02, y=0.05, xanchor="left", yanchor="bottom",
             showactive=True, buttons=xaxis_buttons, name="xaxis_buttons"),
        dict(type="buttons", direction="up", x=1.145, y=0.05, xanchor="left", yanchor="bottom",
             showactive=True, buttons=yaxis_buttons, name="yaxis_buttons")
    ]
)

# Create Plot-2 (Saturation Curves for Train Loss)
fig_train_loss = go.FigureWidget()
fig_train_loss.update_layout(
    title="Saturation Curves: Click a point in Plot-1 to view Train Loss curves",
    xaxis_title="Epoch",
    yaxis_title="Training Loss",
    template="plotly_white",
    width=720,
    height=480,
)
fig_train_loss.update_layout(
    margin=dict(r=150),
    updatemenus=[
        dict(type="buttons", direction="up", x=1.02, y=0.05, xanchor="left", yanchor="bottom",
             showactive=True, buttons=xaxis_buttons),
        dict(type="buttons", direction="up", x=1.145, y=0.05, xanchor="left", yanchor="bottom",
             showactive=True, buttons=yaxis_buttons)
    ]
)

# Create Plot-3 (Saturation Curves for Validation Loss)
fig_val_loss = go.FigureWidget()
fig_val_loss.update_layout(
    title="Saturation Curves: Click a point in Plot-1 to view Val Loss curves",
    xaxis_title="Epoch",
    yaxis_title="Validation Loss",
    template="plotly_white",
    width=720,
    height=480,
)
fig_val_loss.update_layout(
    margin=dict(r=150),
    updatemenus=[
        dict(type="buttons", direction="up", x=1.02, y=0.05, xanchor="left", yanchor="bottom",
             showactive=True, buttons=yaxis_buttons),
        dict(type="buttons", direction="up", x=1.145, y=0.05, xanchor="left", yanchor="bottom",
             showactive=True, buttons=yaxis_buttons)
    ]
)

# Callback function to update saturation curves based on clicked point in Plot-1
def update_saturation_curves(trace, points, selector):
    if points.point_inds:
        idx = points.point_inds[0]
        # Retrieve custom data to identify the experiment and dataset size
        clicked_data = trace.customdata[idx]
        filename = clicked_data['experiment']
        ds_size = clicked_data['dataset_size']
        
        # Retrieve the corresponding detailed runs and loaded data
        detailed_runs = all_detailed_runs[filename]
        loaded_data = all_loaded_data[filename]
        
        # Clear existing data in saturation plots
        fig_val_loss.data = []
        fig_train_loss.data = []
        
        # Retrieve all runs for the selected dataset size
        runs = detailed_runs.get(ds_size, [])
        
        # Iterate through each run and plot the saturation curves
        for run in runs:
            num_params, best_val_loss, best_train_loss, run_data = run
            batch_size = run_data["config"]["batch_size"]
            dataset_size = run_data["config"]["dataset_size"]
            dataset_size = int(dataset_size)  # Ensure it's an integer

            # Calculate number of batches per epoch
            num_batches = ceil(dataset_size / batch_size)
            
            # Group the logged steps by epoch
            epoch_dict = {}
            for step_str, step_data in run_data["losses"].items():
                step_int = int(step_str)
                epoch_idx = step_int // num_batches  # integer division
                if epoch_idx not in epoch_dict:
                    epoch_dict[epoch_idx] = {"train": [], "val": []}
                epoch_dict[epoch_idx]["train"].append(step_data["train_loss"])
                epoch_dict[epoch_idx]["val"].append(step_data["val_loss"])
            
            # Build arrays of (epoch, train_loss, val_loss) by taking the last logged or average in each epoch
            epochs = sorted(epoch_dict.keys())
            epoch_train = []
            epoch_val = []
            for e in epochs:
                # You can choose to take the last or average; here we take the last
                epoch_train.append(epoch_dict[e]["train"][-1])
                epoch_val.append(epoch_dict[e]["val"][-1])
            
            # Update Train Loss Plot
            fig_train_loss.add_trace(
                go.Scatter(
                    x=epochs,
                    y=epoch_train,
                    mode='lines+markers',
                    name=f"{num_params} params"
                )
            )
            
            # Update Validation Loss Plot
            fig_val_loss.add_trace(
                go.Scatter(
                    x=epochs,
                    y=epoch_val,
                    mode='lines+markers',
                    name=f"{num_params} params"
                )
            )
        
        # Update titles to reflect the selected dataset size and experiment
        fig_train_loss.update_layout(title=f"Train Loss Curves for Dataset Size = {int(ds_size)}<br>({filename})")
        fig_val_loss.update_layout(title=f"Val Loss Curves for Dataset Size = {int(ds_size)}<br>({filename})")

# Attach the callback to the data traces (only the "Best Val Loss" traces)
for trace in fig_plot1.data:
    if 'Best Val Loss' in trace.name:
        trace.on_click(update_saturation_curves)

# Arrange the plots in the desired layout
top_row = widgets.VBox([fig_plot1])  # Scaling Law on top
bottom_row = widgets.HBox([fig_train_loss, fig_val_loss])  # Training loss (left), Validation loss (right)
container = widgets.VBox([top_row, bottom_row])

# Display the updated layout
display(container)

# Optionally, display the scaling law parameters for all experiments
# for filename in experiment_filenames:
#     a, b = all_scaling_params[filename]
#     print(f"Scaling Law Parameters for {filename}:\na = {a:.4f}\nb = {b:.4f}\n")


VBox(children=(VBox(children=(FigureWidget({
    'data': [{'customdata': [{'dataset_size': 3.0, 'experiment': …

In [7]:
scaling_law_md = f"""
### Scaling Law


$y = {a:.4f} \\cdot x^{{-{b:.4f}}}$

Where:
- $a$: The scaling constant, representing the loss when $x = 1$.
- $b$: The scaling exponent, describing how the loss decreases with dataset size.
"""

display(Markdown(scaling_law_md))

NameError: name 'a' is not defined