In [1]:
# keep refreshing the env to update experiments.json
%load_ext autoreload
%autoreload 2

# make cells take up the whole width to display graphs better
from IPython.display import display, HTML
display(HTML("<style>:root { --jp-notebook-max-width: 100% !important; }</style>"))

import json
import numpy as np
import plotly.graph_objects as go
from IPython.display import display
import ipywidgets as widgets

### Prep data

In [2]:
experiment_filename = "rattled-1000/experiments_20241209_050645.json"

In [3]:
# Load the data
with open(f"results/{experiment_filename}", "r") as f:
    loaded_data = json.load(f)


detailed_runs = {}  # {dataset_size: [(params, best_val_loss, best_train_loss), ...]}
for ds_size_str, runs in loaded_data.items():
    ds_size = float(ds_size_str)
    run_losses = []
    for run in runs:
        val_losses = [step_data["val_loss"] for step_data in run["losses"].values()]
        train_losses = [step_data["train_loss"] for step_data in run["losses"].values()]
        best_val_loss = min(val_losses)
        best_train_loss = min(train_losses)
        num_params = run["config"]["num_params"]
        run_losses.append((num_params, best_val_loss, best_train_loss))
    detailed_runs[ds_size] = run_losses

# Create summary_results for scaling law plot
summary_val = {}  # best achievable validation loss per dataset size
for ds_size, runs in detailed_runs.items():
    best_val = min(l[1] for l in runs)
    summary_val[ds_size] = best_val

ds_sizes = np.array(list(summary_val.keys()))
val_losses = np.array(list(summary_val.values()))

In [4]:
ds_size_str

'1196'

In [5]:
ds_size_str

'1196'

### Create plots

In [6]:
############################################
# Create Plot-1 (Scaling Law for Validation Loss)
############################################
fig_plot1 = go.FigureWidget()
fig_plot1.add_trace(
    go.Scatter(
        x=ds_sizes, 
        y=val_losses, 
        mode='lines+markers', 
        name='Best Val Loss'
    )
)
# Baselines for Validation Loss
fig_plot1.add_trace(
    go.Scatter(
        x=ds_sizes, 
        y=[314.503]*len(ds_sizes),
        mode='lines',
        name='Naive Zero',
        line=dict(dash='dot', color='#DDD5C7')
    )
)
fig_plot1.add_trace(
    go.Scatter(
        x=ds_sizes, 
        y=[101.496]*len(ds_sizes),
        mode='lines',
        name='Naive Mean',
        line=dict(dash='dot', color='#3B7EA1')
    )
)
fig_plot1.add_trace(
    go.Scatter(
        x=ds_sizes, 
        y=[100.102]*len(ds_sizes),
        mode='lines',
        name='Naive k=1',
        line=dict(dash='dot', color='#C4820E')
    )
)
fig_plot1.update_layout(
    title="Scaling Law: Validation Loss vs Dataset Size",
    xaxis_title="Dataset Size",
    yaxis_title="Validation Loss",
    template="plotly_white",
    width=900,
    height=600
)
# Add buttons for toggling axis scales (reuse same layout as before)
xaxis_buttons = [
    dict(args=[{"xaxis.type": "linear"}], label="X-Linear", method="relayout"),
    dict(args=[{"xaxis.type": "log"}], label="X-Log", method="relayout")
]
yaxis_buttons = [
    dict(args=[{"yaxis.type": "linear"}], label="Y-Linear", method="relayout"),
    dict(args=[{"yaxis.type": "log"}], label="Y-Log", method="relayout")
]
fig_plot1.update_layout(
    margin=dict(r=150),
    updatemenus=[
        dict(type="buttons", direction="up", x=1.02, y=0.05, xanchor="left", yanchor="bottom",
             showactive=True, buttons=xaxis_buttons),
        dict(type="buttons", direction="up", x=1.145, y=0.05, xanchor="left", yanchor="bottom",
             showactive=True, buttons=yaxis_buttons)
    ]
)

############################################
# Create Plot-2 (Saturation Curves for Train Loss)
############################################
fig_train_loss = go.FigureWidget()
fig_train_loss.update_layout(
    title="Saturation Curves: Click a point in Plot-1 to view Train Loss curves",
    xaxis_title="Epoch",
    yaxis_title="Training Loss",
    template="plotly_white",
    width=720,
    height=480,
)
fig_train_loss.update_layout(
    margin=dict(r=150),
    updatemenus=[
        dict(type="buttons", direction="up", x=1.02, y=0.05, xanchor="left", yanchor="bottom",
             showactive=True, buttons=xaxis_buttons),
        dict(type="buttons", direction="up", x=1.145, y=0.05, xanchor="left", yanchor="bottom",
             showactive=True, buttons=yaxis_buttons)
    ]
)

############################################
# Create Plot-3 (Saturation Curves for Validation Loss)
############################################
fig_val_loss = go.FigureWidget()
fig_val_loss.update_layout(
    title="Saturation Curves: Click a point in Plot-1 to view Val Loss curves",
    xaxis_title="Epoch",
    yaxis_title="Validation Loss",
    template="plotly_white",
    width=720,
    height=480,
)
fig_val_loss.update_layout(
    margin=dict(r=150),
    updatemenus=[
        dict(type="buttons", direction="up", x=1.02, y=0.05, xanchor="left", yanchor="bottom",
             showactive=True, buttons=xaxis_buttons),
        dict(type="buttons", direction="up", x=1.145, y=0.05, xanchor="left", yanchor="bottom",
             showactive=True, buttons=yaxis_buttons)
    ]
)

# Callback function to update saturation curves based on clicked point in Plot-1
def update_saturation_curves(trace, points, selector):
    if points.point_inds:
        idx = points.point_inds[0]
        ds_size = ds_sizes[idx]
        
        # Clear current data
        fig_val_loss.data = []
        fig_train_loss.data = []
        
        # Get runs for the selected dataset size (convert key to int if needed)
        runs = loaded_data.get(str(int(ds_size)), [])
        
        for run in runs:
            epochs = []
            val_losses = []
            train_losses = []
            # Iterate over epochs (assuming keys are numeric strings)
            for epoch_str in sorted(run["losses"].keys(), key=lambda x: int(x)):
                epochs.append(int(epoch_str))
                val_losses.append(run["losses"][epoch_str]["val_loss"])
                train_losses.append(run["losses"][epoch_str]["train_loss"])
            
            params = run["config"]["num_params"]
            fig_val_loss.add_trace(
                go.Scatter(
                    x=epochs,
                    y=val_losses,
                    mode='lines+markers',
                    name=f"{params} params"
                )
            )
            fig_train_loss.add_trace(
                go.Scatter(
                    x=epochs,
                    y=train_losses,
                    mode='lines+markers',
                    name=f"{params} params"
                )
            )
        
        fig_val_loss.update_layout(title=f"Val Loss Curves for Dataset Size = {int(ds_size)}")
        fig_train_loss.update_layout(title=f"Train Loss Curves for Dataset Size = {int(ds_size)}")

# Attach the callback to the first trace (scaling law) of Plot-1
fig_plot1.data[0].on_click(update_saturation_curves)

# Arrange the plots in the desired layout
top_row = widgets.VBox([fig_plot1])  # Scaling Law on top
bottom_row = widgets.HBox([fig_train_loss, fig_val_loss])  # Training loss (left), Validation loss (right)
container = widgets.VBox([top_row, bottom_row])

# Display the updated layout
display(container)

VBox(children=(VBox(children=(FigureWidget({
    'data': [{'mode': 'lines+markers',
              'name': 'Bes…