In [1]:
import json
import numpy as np
import plotly.graph_objects as go
from IPython.display import display
import ipywidgets as widgets

### Prep data

In [2]:
# Results file to analyse, relative to the "results/" directory read in the next cell.
experiment_filename = ("rattled-1000/"
                       "experiments_20241209_050645.json")

In [3]:
# Load the experiment results.
# Assumed schema, inferred from the accesses below (TODO confirm against the writer):
#   {dataset_size_str: [{"config": {"num_params": ...},
#                        "losses": {step_key: {"val_loss": ...}, ...}}, ...]}
with open(f"results/{experiment_filename}", "r", encoding="utf-8") as f:
    loaded_data = json.load(f)


def _best_val_loss(run):
    """Return the lowest validation loss recorded over all steps of a single run."""
    return min(step_data["val_loss"] for step_data in run["losses"].values())


# {dataset_size: [(num_params, best_val_loss), ...]}
# Keys are visited in ascending numeric order: JSON key order is whatever the writer
# produced, and Plot-1 connects consecutive dataset sizes with lines, so unsorted
# sizes would make the curve zig-zag.
detailed_runs = {}
for ds_size_str, runs in sorted(loaded_data.items(), key=lambda item: float(item[0])):
    detailed_runs[float(ds_size_str)] = [
        (run["config"]["num_params"], _best_val_loss(run)) for run in runs
    ]

# summary_results for plot-1: the minimal loss at each dataset size.
# Dataset sizes with no runs are skipped (min() of an empty sequence would raise).
summary_results = {
    ds_size: min(best_loss for _, best_loss in runs)
    for ds_size, runs in detailed_runs.items()
    if runs
}

ds_sizes = np.array(list(summary_results.keys()))
losses = np.array(list(summary_results.values()))

### Create plots

In [4]:
############################################
# Create Plot-1 (Scaling Law)
############################################
fig_plot1 = go.FigureWidget()

# Trace 0: best validation loss per dataset size. The click callback is attached to
# fig_plot1.data[0], so this must remain the first trace added.
fig_plot1.add_trace(
    go.Scatter(
        x=ds_sizes,
        y=losses,
        mode='lines+markers',
        name='Best Achievable Loss'
    )
)

# Constant reference baselines, drawn as flat dotted lines across the plotted x-range.
# NOTE(review): these losses are hard-coded; re-check them if experiment_filename
# points at a different dataset.
BASELINES = (
    # (legend name, loss, line color)
    ('Naive Zero', 314.503, '#DDD5C7'),
    ('Naive Mean', 101.496, '#3B7EA1'),
    ('Naive k=1', 100.102, '#C4820E'),
)
for baseline_name, baseline_loss, baseline_color in BASELINES:
    fig_plot1.add_trace(
        go.Scatter(
            x=ds_sizes,
            y=[baseline_loss] * len(ds_sizes),
            mode='lines',
            name=baseline_name,
            line=dict(dash='dot', color=baseline_color),
        )
    )

fig_plot1.update_layout(
    title="Scaling Law: Loss vs Dataset Size",
    xaxis_title="Dataset Size",
    yaxis_title="Loss",
    template="plotly_white",
    width=900,
    height=600,
)

# Button lists that switch an axis between linear and log scale via "relayout".
xaxis_buttons = [
    dict(args=[{"xaxis.type": scale}], label=f"X-{scale.capitalize()}", method="relayout")
    for scale in ("linear", "log")
]
yaxis_buttons = [
    dict(args=[{"yaxis.type": scale}], label=f"Y-{scale.capitalize()}", method="relayout")
    for scale in ("linear", "log")
]

# Two button columns placed just right of the plot area (x > 1 in paper coordinates);
# the enlarged right margin leaves room for them.
fig_plot1.update_layout(
    margin=dict(r=150),
    updatemenus=[
        dict(
            type="buttons",
            direction="up",
            x=menu_x,
            y=0.05,
            xanchor="left",
            yanchor="bottom",
            showactive=True,
            buttons=menu_buttons,
        )
        for menu_x, menu_buttons in ((1.02, xaxis_buttons), (1.145, yaxis_buttons))
    ],
)


############################################
# Create Plot-2 (Saturation Curves)
############################################
# Starts empty; update_plot2 fills it when a point in Plot-1 is clicked.
fig_plot2 = go.FigureWidget()
fig_plot2.update_layout(
    title="Saturation Curves: Click a point in Plot-1 to view",
    xaxis_title="Validation Steps",
    yaxis_title="Loss",
    template="plotly_white",
    width=900,
    height=600,
)

# Linear/log toggles for Plot-2's own axes.
# NOTE: the curves start at x=0, which cannot be shown on a log x-axis.
xaxis_buttons_2 = [
    dict(args=[{"xaxis.type": scale}], label=f"X-{scale.capitalize()}", method="relayout")
    for scale in ("linear", "log")
]
yaxis_buttons_2 = [
    dict(args=[{"yaxis.type": scale}], label=f"Y-{scale.capitalize()}", method="relayout")
    for scale in ("linear", "log")
]

# Same placement as Plot-1's menus, but wired to this figure's own button lists.
fig_plot2.update_layout(
    margin=dict(r=150),
    updatemenus=[
        dict(
            type="buttons",
            direction="up",
            x=menu_x,
            y=0.05,
            xanchor="left",
            yanchor="bottom",
            showactive=True,
            buttons=menu_buttons,
        )
        for menu_x, menu_buttons in ((1.02, xaxis_buttons_2), (1.145, yaxis_buttons_2))
    ],
)

############################################
# Callback: When user clicks on a data point in Plot-1,
# update Plot-2 to show saturation curves for that dataset size.
############################################
def update_plot2(trace, points, selector):
    """Redraw Plot-2 with one validation-loss curve per run for the clicked dataset size.

    Follows plotly's ``on_click`` callback signature ``(trace, points, selector)``;
    only ``points`` is used.

    Args:
        trace: The clicked trace (unused).
        points: Click info; ``points.point_inds`` holds indices into the clicked
            trace's x/y data, which for Plot-1's first trace index ``ds_sizes``.
        selector: Input-device state (unused).
    """
    if not points.point_inds:
        return

    ds_size = ds_sizes[points.point_inds[0]]

    # Recover the original JSON key by numeric comparison instead of re-formatting
    # the float: str(int(ds_size)) only reproduces integer-formatted keys and raises
    # KeyError for keys such as "1000.0", "1e3" or "0.5".
    ds_key = next(key for key in loaded_data if float(key) == ds_size)

    new_traces = []
    for run in loaded_data[ds_key]:
        val_losses = [step_data["val_loss"] for step_data in run["losses"].values()]
        params = run["config"]["num_params"]
        # x is the position of each record within run["losses"] (0, 1, 2, ...), not
        # the raw step key. NOTE(review): if the keys are step numbers they could be
        # plotted directly instead — confirm the schema.
        new_traces.append(
            go.Scatter(
                x=list(range(len(val_losses))),
                y=val_losses,
                mode='lines+markers',
                name=f"{params} params"
            )
        )

    # Drop the previous curves, then add all new ones in a single call.
    fig_plot2.data = []
    fig_plot2.add_traces(new_traces)
    fig_plot2.update_layout(
        title=f"Saturation Curves for Dataset Size = {ds_key}"
    )


# Clicking a marker on Plot-1's first trace (the best-achievable-loss curve)
# redraws Plot-2 for that dataset size.
fig_plot1.data[0].on_click(update_plot2)

############################################
# Display the two figures stacked vertically (Plot-1 above Plot-2)
############################################
container = widgets.VBox([fig_plot1, fig_plot2])
display(container)

VBox(children=(FigureWidget({
    'data': [{'mode': 'lines+markers',
              'name': 'Best Achievable Lo…