In [None]:
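# auto-reload imported Python modules before each cell runs so edits to local code are picked up
# (this does not re-read experiments.json; re-run the data-loading cell to refresh the results)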
%load_ext autoreload
%autoreload 2

# make cells take up the whole width to display graphs better
from IPython.display import display, HTML, Markdown
display(HTML("<style>:root { --jp-notebook-max-width: 100% !important; }</style>"))

import json
import numpy as np
import plotly.graph_objects as go
from IPython.display import display
import ipywidgets as widgets
from scipy.optimize import curve_fit

### Prep data

In [None]:
# Results file to analyze; it is opened from the results/ directory by the data-loading cell.
# Active choice: all FCN models, 50 epochs
experiment_filename = "experiments_20250111_214659.json"

# Alternative: all Transformer models, 50 epochs (uncomment to analyze this file instead)
# experiment_filename = "experiments_20250112_062438.json"

In [None]:
# Read the raw experiment log for the selected results file
with open(f"results/{experiment_filename}", "r") as f:
    loaded_data = json.load(f)


# {dataset_size: [(num_params, best_val_loss, best_train_loss), ...]} with one tuple per run;
# "best" is the minimum over every logged step of that run.
detailed_runs = {}
for ds_size_str, runs in loaded_data.items():
    per_run_best = []
    for run in runs:
        logged_steps = list(run["losses"].values())
        lowest_val = min(step["val_loss"] for step in logged_steps)
        lowest_train = min(step["train_loss"] for step in logged_steps)
        per_run_best.append((run["config"]["num_params"], lowest_val, lowest_train))
    detailed_runs[float(ds_size_str)] = per_run_best

# Best validation loss reached by any run at each dataset size (input to the scaling-law fit)
summary_val = {
    ds_size: min(entry[1] for entry in entries)
    for ds_size, entries in detailed_runs.items()
}

# Parallel arrays: dataset sizes and their best validation losses, in dict insertion order
ds_sizes = np.array(list(summary_val.keys()))
val_losses = np.array(list(summary_val.values()))

In [None]:
def calculate_scaling_law(ds_sizes, val_losses):
    """
    Fit the power law ``loss = a * size ** (-b)`` to the given points.

    The fit is done with ``scipy.optimize.curve_fit`` starting from a = 1, b = 1.
    This function only fits; it does not draw anything.

    Args:
        ds_sizes (np.array): Array of dataset sizes.
        val_losses (np.array): Array of validation losses, one per dataset size.

    Returns:
        tuple: ``(fitted_vals, a, b)`` where ``fitted_vals`` is the fitted curve
        evaluated at ``ds_sizes`` and ``a``, ``b`` are the fitted parameters.
    """

    def _power_law(size, scale, exponent):
        # Model being fitted: scale * size^(-exponent)
        return scale * size ** (-exponent)

    initial_guess = (1, 1)  # starting values for (a, b)
    fit_params, _covariance = curve_fit(_power_law, ds_sizes, val_losses, p0=initial_guess)
    a, b = fit_params

    # Evaluate the fitted curve at the input sizes so it can be plotted against the data
    return _power_law(ds_sizes, a, b), a, b

# Order both arrays by ascending dataset size so the line plots are drawn left to right
sorted_indices = ds_sizes.argsort()
sorted_ds_sizes, sorted_val_losses = ds_sizes[sorted_indices], val_losses[sorted_indices]

# Fit the power law to the ordered points; a and b are reused by the plots and the summary below
fitted_vals, a, b = calculate_scaling_law(
    ds_sizes=sorted_ds_sizes,
    val_losses=sorted_val_losses,
)

### Create plots

In [None]:
# Plot-1: best validation loss per dataset size, overlaid with the fitted power law
fig_plot1 = go.FigureWidget()

# Trace 0: measured best validation loss at each dataset size
measured_trace = go.Scatter(
    x=sorted_ds_sizes,
    y=sorted_val_losses,
    mode='lines+markers',
    name='Best Val Loss',
)
fig_plot1.add_trace(measured_trace)

# Trace 1: the power-law fit evaluated at the same dataset sizes
fitted_trace = go.Scatter(
    x=sorted_ds_sizes,
    y=fitted_vals,
    mode='lines',
    name=f'Scaling Law: y = {a:.2f}x^(-{b:.2f})',
    line={"dash": "dash", "color": "red"},
)
fig_plot1.add_trace(fitted_trace)

fig_plot1.update_layout(
    template="plotly_white",
    title="Scaling Law: Validation Loss vs Dataset Size",
    xaxis_title="Dataset Size",
    yaxis_title="Validation Loss",
    width=900,
    height=600,
    legend={"x": 1.02, "y": 1},
)

# Relayout buttons that switch an axis between linear and log scale.
# These two lists are also used by the saturation-curve plots.
xaxis_buttons = [
    {"args": [{"xaxis.type": axis_type}], "label": label, "method": "relayout"}
    for axis_type, label in (("linear", "X-Linear"), ("log", "X-Log"))
]
yaxis_buttons = [
    {"args": [{"yaxis.type": axis_type}], "label": label, "method": "relayout"}
    for axis_type, label in (("linear", "Y-Linear"), ("log", "Y-Log"))
]

# Two button groups side by side in the widened right margin: x-axis scale, then y-axis scale
plot1_scale_menus = [
    {
        "type": "buttons",
        "direction": "up",
        "x": menu_x,
        "y": 0.05,
        "xanchor": "left",
        "yanchor": "bottom",
        "showactive": True,
        "buttons": menu_buttons,
    }
    for menu_x, menu_buttons in ((1.02, xaxis_buttons), (1.145, yaxis_buttons))
]
fig_plot1.update_layout(margin={"r": 150}, updatemenus=plot1_scale_menus)

# Plot-2: training-loss curves per epoch. Created without traces; the Plot-1 click
# callback adds one curve per run.
fig_train_loss = go.FigureWidget()
fig_train_loss.update_layout(
    template="plotly_white",
    title="Saturation Curves: Click a point in Plot-1 to view Train Loss curves",
    xaxis_title="Epoch",
    yaxis_title="Training Loss",
    width=720,
    height=480,
)

# Axis-scale toggles in the right margin: x-axis group first, y-axis group next to it
train_scale_menus = [
    {
        "type": "buttons",
        "direction": "up",
        "x": menu_x,
        "y": 0.05,
        "xanchor": "left",
        "yanchor": "bottom",
        "showactive": True,
        "buttons": menu_buttons,
    }
    for menu_x, menu_buttons in ((1.02, xaxis_buttons), (1.145, yaxis_buttons))
]
fig_train_loss.update_layout(margin={"r": 150}, updatemenus=train_scale_menus)

# Create Plot-3 (Saturation Curves for Validation Loss)
# Created without traces; the Plot-1 click callback adds one curve per run.
fig_val_loss = go.FigureWidget()
fig_val_loss.update_layout(
    title="Saturation Curves: Click a point in Plot-1 to view Val Loss curves",
    xaxis_title="Epoch",
    yaxis_title="Validation Loss",
    template="plotly_white",
    width=720,
    height=480,
)
# Axis-scale toggles in the right margin. The first group must carry the x-axis
# buttons (it previously reused the y-axis buttons, leaving no way to switch the
# x-axis scale on this plot); the second group carries the y-axis buttons.
fig_val_loss.update_layout(
    margin=dict(r=150),
    updatemenus=[
        dict(type="buttons", direction="up", x=1.02, y=0.05, xanchor="left", yanchor="bottom",
             showactive=True, buttons=xaxis_buttons),
        dict(type="buttons", direction="up", x=1.145, y=0.05, xanchor="left", yanchor="bottom",
             showactive=True, buttons=yaxis_buttons)
    ]
)

def _last_losses_per_epoch(run):
    """Collapse a run's per-step loss log into one (train, val) pair per epoch.

    Steps are assigned to epochs with ``step // batches_per_epoch``, where
    ``batches_per_epoch = ceil(dataset_size / batch_size)`` from the run config.
    Within an epoch, the entry with the highest step number is kept.

    Args:
        run (dict): One run record with ``"config"`` (``batch_size``,
            ``dataset_size``) and ``"losses"`` (``{step: {"train_loss", "val_loss"}}``).

    Returns:
        tuple: ``(epochs, train_losses, val_losses)``, three equal-length lists
        ordered by ascending epoch index.
    """
    from math import ceil

    batch_size = run["config"]["batch_size"]
    # dataset_size may be stored as int, float or a numeric string; normalize to int
    dataset_size = int(float(run["config"]["dataset_size"]))

    # Approximate number of batches per epoch; clamp to 1 so the floor division
    # below cannot divide by zero when dataset_size is 0.
    num_batches = max(1, ceil(dataset_size / batch_size))

    # Visit steps in numeric order so "last in the epoch" does not depend on the
    # order in which the JSON keys happen to be stored.
    ordered_steps = sorted(run["losses"].items(), key=lambda item: int(item[0]))

    last_in_epoch = {}
    for step_str, step_data in ordered_steps:
        epoch_idx = int(step_str) // num_batches
        # Later steps overwrite earlier ones, leaving the epoch's final logged values
        last_in_epoch[epoch_idx] = (step_data["train_loss"], step_data["val_loss"])

    epochs = sorted(last_in_epoch)
    epoch_train = [last_in_epoch[e][0] for e in epochs]
    epoch_val = [last_in_epoch[e][1] for e in epochs]
    return epochs, epoch_train, epoch_val


# Callback function to update saturation curves based on clicked point in Plot-1
def update_saturation_curves(trace, points, selector):
    """Redraw the train/val loss curves for the dataset size clicked in Plot-1.

    Clears both saturation plots, then adds one curve per run recorded for the
    clicked dataset size and updates both plot titles. Does nothing when the
    click carries no point indices.

    Args:
        trace: The clicked Plot-1 trace (not used).
        points: Click payload; only ``points.point_inds`` is read, and its first
            entry indexes into ``sorted_ds_sizes``.
        selector: Not used.
    """
    if not points.point_inds:
        return

    idx = points.point_inds[0]
    ds_size = sorted_ds_sizes[idx]

    fig_val_loss.data = []
    fig_train_loss.data = []

    # ds_size was parsed from the JSON key with float(), so match keys numerically.
    # Rebuilding the key as str(int(ds_size)) would miss keys such as "1000.0".
    runs = next(
        (size_runs for size_key, size_runs in loaded_data.items() if float(size_key) == ds_size),
        [],
    )

    for run in runs:
        epochs, epoch_train, epoch_val = _last_losses_per_epoch(run)
        curve_name = f"{run['config']['num_params']} params"

        fig_train_loss.add_trace(
            go.Scatter(x=epochs, y=epoch_train, mode='lines+markers', name=curve_name)
        )
        fig_val_loss.add_trace(
            go.Scatter(x=epochs, y=epoch_val, mode='lines+markers', name=curve_name)
        )

    fig_train_loss.update_layout(title=f"Train Loss Curves for Dataset Size = {int(ds_size)}")
    fig_val_loss.update_layout(title=f"Val Loss Curves for Dataset Size = {int(ds_size)}")

# Clicking a point on trace 0 of Plot-1 (the measured "Best Val Loss" points)
# refreshes the two saturation-curve plots
measured_points_trace = fig_plot1.data[0]
measured_points_trace.on_click(update_saturation_curves)

# Layout: scaling-law plot on top; training loss (left) and validation loss (right) below
top_row = widgets.VBox(children=[fig_plot1])
bottom_row = widgets.HBox(children=[fig_train_loss, fig_val_loss])
container = widgets.VBox(children=[top_row, bottom_row])

# Show the assembled widget
display(container)

# Also report the fitted power-law parameters as plain text
print(f"Scaling Law Parameters:\na = {a:.4f}\nb = {b:.4f}")

In [None]:
# Render the fitted power law as Markdown with LaTeX math, using the fitted a and b.
# Inside this f-string, doubled braces produce literal LaTeX braces and "\\cdot"
# produces the LaTeX command \cdot.
scaling_law_md = f"""
### Scaling Law


$y = {a:.4f} \\cdot x^{{-{b:.4f}}}$

Where:
- $a$: The scaling constant, representing the loss when $x = 1$.
- $b$: The scaling exponent, describing how the loss decreases with dataset size.
"""

# Display the formatted summary as rich Markdown output
display(Markdown(scaling_law_md))