In [1]:
import os
import numpy as np
import torch
from transformers import AutoModelForCausalLM
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display
from tqdm.notebook import tqdm

def load_checkpoint(checkpoint_path, step):
    """Load the causal-LM checkpoint saved for one training step.

    Args:
        checkpoint_path: Directory containing one ``checkpoint-<step>`` subdirectory per saved step.
        step: Step number selecting which ``checkpoint-<step>`` subdirectory to load.

    Returns:
        The model returned by ``AutoModelForCausalLM.from_pretrained``, with float32 weights.
    """
    step_dir = os.path.join(checkpoint_path, f"checkpoint-{step}")
    model = AutoModelForCausalLM.from_pretrained(step_dir, torch_dtype=torch.float32)
    return model

# Lower-case substrings that mark a parameter name as a bias or normalization weight.
# 'ln' matches short names such as 'ln_1' / 'ln_f'; 'norm' is needed separately
# because names like 'layernorm', 'layer_norm' or 'input_layernorm' do not contain 'ln'.
_BIAS_OR_NORM_MARKERS = ('bias', 'ln', 'norm')

def is_bias_or_layernorm(param_name, markers=_BIAS_OR_NORM_MARKERS):
    """Return True if ``param_name`` looks like a bias or layer-norm parameter.

    This is a heuristic case-insensitive substring match: any name whose
    lower-cased form contains one of ``markers`` is reported as a match.

    Args:
        param_name: Parameter name as yielded by ``named_parameters()``.
        markers: Lower-case substrings to search for.

    Returns:
        True if any marker occurs in the lower-cased name, else False.
    """
    lowered = param_name.lower()
    return any(marker in lowered for marker in markers)

@torch.no_grad()
def precompute_differences(checkpoint_path, start_step=0, end_step=999, omit_bias_layernorm=True):
    """Compute per-parameter weight changes between consecutive checkpoints.

    For every ``step`` in ``range(start_step, end_step)`` this stores
    ``params(step + 1) - params(step)``, so checkpoints ``start_step`` through
    ``end_step`` (inclusive) are read. Each checkpoint is loaded from disk once:
    the parameters of checkpoint ``step + 1`` are kept and reused as the
    "before" side of the next difference, instead of reloading them.

    Args:
        checkpoint_path: Directory containing ``checkpoint-<step>`` subdirectories.
        start_step: Checkpoint on the "before" side of the first difference.
        end_step: Checkpoint on the "after" side of the last difference.
        omit_bias_layernorm: If True, skip parameters for which
            ``is_bias_or_layernorm`` returns True.

    Returns:
        Tuple ``(all_differences, param_names, param_shapes)``:
        - ``all_differences``: list with one dict per step, mapping parameter
          name -> numpy array of the difference.
        - ``param_names``: kept parameter names, in ``named_parameters()`` order.
        - ``param_shapes``: dict mapping each kept name to its ``param.shape``.
        If the step range is empty, returns ``([], None, None)``.

    Raises:
        AssertionError: if two checkpoints expose different kept parameter names.
    """
    print("Precomputing differences...")

    def load_params(step):
        # Load one checkpoint and pull out the kept parameters as numpy arrays,
        # dropping the model reference before returning.
        model = load_checkpoint(checkpoint_path, step)
        values, shapes = {}, {}
        for name, param in model.named_parameters():
            if omit_bias_layernorm and is_bias_or_layernorm(name):
                continue
            values[name] = param.detach().cpu().numpy()
            shapes[name] = param.shape
        del model
        torch.cuda.empty_cache()
        return values, shapes

    steps = range(start_step, end_step)
    if len(steps) == 0:
        return [], None, None

    prev_values, param_shapes = load_params(start_step)
    param_names = list(prev_values)

    all_differences = []
    for step in tqdm(steps):
        next_values, _ = load_params(step + 1)
        # Explicit full-list comparison: a pairwise zip would silently truncate
        # if the two checkpoints had different numbers of parameters.
        assert list(next_values) == param_names, (
            f"Parameter names do not match between checkpoints {step} and {step + 1}"
        )
        all_differences.append({name: next_values[name] - prev_values[name] for name in param_names})
        prev_values = next_values

    return all_differences, param_names, param_shapes

In [12]:
@torch.no_grad()
def precompute_parameter_values(checkpoint_path, start_step=0, end_step=999, omit_bias_layernorm=False):
    """Snapshot parameter values for every checkpoint from start_step to end_step, inclusive.

    Args:
        checkpoint_path: Directory containing ``checkpoint-<step>`` subdirectories.
        start_step: First checkpoint to read.
        end_step: Last checkpoint to read (inclusive).
        omit_bias_layernorm: If True, skip parameters for which
            ``is_bias_or_layernorm`` returns True.

    Returns:
        Tuple ``(all_parameter_values, param_names, param_shapes)``:
        - ``all_parameter_values``: list with one dict per checkpoint, mapping
          parameter name -> numpy array of its values.
        - ``param_names``: kept names from the first checkpoint, in ``named_parameters()`` order.
        - ``param_shapes``: dict mapping each kept name to its ``param.shape``.
        Names and shapes are ``None`` if the range is empty.
    """
    print("Precomputing parameter values...")

    def is_kept(name):
        return not (omit_bias_layernorm and is_bias_or_layernorm(name))

    snapshots = []
    kept_names = None
    kept_shapes = None

    for step in tqdm(range(start_step, end_step + 1)):
        model = load_checkpoint(checkpoint_path, step)
        kept_params = [(name, param) for name, param in model.named_parameters() if is_kept(name)]

        if kept_names is None:
            # Names and shapes are recorded from the first checkpoint only.
            kept_names = [name for name, _ in kept_params]
            kept_shapes = {name: param.shape for name, param in kept_params}

        snapshots.append({name: param.cpu().numpy() for name, param in kept_params})

        del model
        torch.cuda.empty_cache()

    return snapshots, kept_names, kept_shapes

In [13]:
def get_histogram_ranges(all_differences, param_names, bins=50, percentiles=(0.1, 99.9)):
    """Compute histogram axis ranges that stay fixed across every step of an animation.

    Works for any list of per-step ``{name: ndarray}`` dicts, i.e. parameter
    differences or raw parameter values.

    Args:
        all_differences: List of dicts, one per step, mapping parameter name -> array.
        param_names: Names whose values are pooled into the histogram.
        bins: Number of equal-width bins used when measuring the tallest bar.
        percentiles: Lower/upper percentiles (0-100) of the values pooled over all
            steps. The x range spans these rather than the full min/max, so
            extreme outliers fall outside it.

    Returns:
        ``(x_min, x_max, y_max)``: the x range from ``percentiles``, and the
        largest single-bin count among all per-step histograms over that range.

    Raises:
        ValueError: if there are no steps or no parameter names to pool.
    """
    if len(all_differences) == 0 or len(param_names) == 0:
        raise ValueError("get_histogram_ranges needs at least one step and one parameter name")

    # Flatten each step once, then reuse it for both the global percentiles and
    # the per-step histograms, instead of flattening everything twice.
    per_step_flat = [
        np.concatenate([step_dict[name].ravel() for name in param_names])
        for step_dict in all_differences
    ]

    x_min, x_max = np.percentile(np.concatenate(per_step_flat), percentiles)

    y_max = 0
    for flat in per_step_flat:
        counts, _ = np.histogram(flat, bins=bins, range=(x_min, x_max))
        y_max = max(y_max, counts.max())

    return x_min, x_max, y_max

def create_interactive_plot_diffs(all_differences, param_names, param_shapes, interval, step, start_step=0):
    """Display an animated grid of per-parameter difference heatmaps plus a pooled histogram.

    Each parameter gets one RdBu heatmap (three per row). The bottom row holds a
    histogram of every difference at the current step. A ``widgets.Play`` button is
    jslinked to an ``IntSlider``; changing the slider redraws all traces for that step.

    Args:
        all_differences: List with one dict per step, mapping parameter name -> numpy
            array of differences (as returned by ``precompute_differences``).
        param_names: Names to plot, in grid order.
        param_shapes: Dict mapping each name to its shape; 1-D parameters are drawn
            as a single-row heatmap.
        interval: Passed to ``widgets.Play`` as its frame interval.
        step: Number of list indices the Play widget advances per frame.
        start_step: Checkpoint number on the "before" side of ``all_differences[0]``.
            It is used only in the figure title, so the title shows real checkpoint
            numbers when differences were precomputed from a non-zero start step.

    Notes:
        The heatmap color range is fixed from the 1st/99th percentiles of step 0,
        so later steps with larger changes saturate. The histogram axes are fixed
        across all steps via ``get_histogram_ranges``.
    """
    n_params = len(param_names)
    n_cols = 3
    n_param_rows = (n_params + n_cols - 1) // n_cols
    n_rows = n_param_rows + 1  # +1 for the histogram

    # make_subplots assigns titles to grid cells in row-major order. Pad the unused
    # cells of the last heatmap row with "" so the histogram title lands on the
    # histogram (row n_rows, col 1) rather than on an empty cell.
    title_padding = [""] * (n_param_rows * n_cols - n_params)
    subplot_titles = list(param_names) + title_padding + ["Distribution of All Parameter Differences"]

    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=subplot_titles)
    fig_widget = go.FigureWidget(fig)

    def to_heatmap_z(name, array):
        # Heatmaps need 2-D input, so 1-D parameters are wrapped as a single row.
        return [array] if len(param_shapes[name]) == 1 else array

    first_flat = np.concatenate([all_differences[0][name].flatten() for name in param_names])
    vmin, vmax = np.percentile(first_flat, [1, 99])
    # Plotly ignores zmid once zmin/zmax are given, so center the diverging
    # colorscale on 0 explicitly by using a symmetric range.
    vabs = max(abs(vmin), abs(vmax))

    for idx, name in enumerate(param_names):
        row = idx // n_cols + 1
        col = idx % n_cols + 1
        heatmap = go.Heatmap(
            z=to_heatmap_z(name, all_differences[0][name]),
            colorscale='RdBu',
            zmin=-vabs,
            zmax=vabs,
            zmid=0,
            # All heatmaps share one color range, so one colorbar is enough;
            # otherwise every trace draws its own bar at the same position.
            showscale=(idx == 0),
        )
        fig_widget.add_trace(heatmap, row=row, col=col)
        fig_widget.update_xaxes(title_text=f"Shape: {param_shapes[name]}", row=row, col=col)
        fig_widget.update_yaxes(title_text=name, row=row, col=col)

    # Ranges fixed over all steps so the histogram axes do not jump while playing.
    x_min, x_max, y_max = get_histogram_ranges(all_differences, param_names)

    # Histogram of all differences; it is the last trace added.
    fig_widget.add_trace(
        go.Histogram(
            x=first_flat,
            nbinsx=50,
            autobinx=False,
            xbins=dict(start=x_min, end=x_max, size=(x_max - x_min) / 50),
        ),
        row=n_rows,
        col=1,
    )
    fig_widget.update_xaxes(title_text="Difference Value", range=[x_min, x_max], row=n_rows, col=1)
    fig_widget.update_yaxes(title_text="Frequency", range=[0, y_max * 1.1], row=n_rows, col=1)

    fig_widget.update_layout(height=300 * n_rows, width=1200)

    def update_plot(frame):
        diff_dict = all_differences[frame]
        frame_flat = np.concatenate([diff_dict[name].flatten() for name in param_names])
        checkpoint = start_step + frame
        # Group all trace/layout edits so the front end receives a single update.
        with fig_widget.batch_update():
            for idx, name in enumerate(param_names):
                fig_widget.data[idx].z = to_heatmap_z(name, diff_dict[name])
            fig_widget.data[-1].x = frame_flat
            fig_widget.layout.title.text = f"Parameter Differences between Checkpoints {checkpoint} and {checkpoint+1}"

    step_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=len(all_differences)-1,
        step=1,
        description='Step:',
        continuous_update=False
    )

    play_button = widgets.Play(
        value=0,
        min=0,
        max=len(all_differences)-1,
        step=step,
        interval=interval,
        description="Play",
        disabled=False
    )

    # Play drives the slider in the front end; the slider's observer redraws in Python.
    widgets.jslink((play_button, 'value'), (step_slider, 'value'))

    def on_value_change(change):
        update_plot(change['new'])

    step_slider.observe(on_value_change, names='value')

    controls = widgets.HBox([play_button, step_slider])
    output = widgets.VBox([controls, fig_widget])
    display(output)

    # Initial plot
    update_plot(0)

In [17]:
def create_interactive_plot_params(all_parameter_values, param_names, param_shapes, interval, step, start_step=0):
    """Display an animated grid of per-parameter value heatmaps plus a pooled histogram.

    Each parameter gets one RdBu heatmap (three per row). The bottom row holds a
    histogram of every parameter value at the current checkpoint. A ``widgets.Play``
    button is jslinked to an ``IntSlider``; changing the slider redraws all traces.

    Args:
        all_parameter_values: List with one dict per checkpoint, mapping parameter
            name -> numpy array of values (as returned by ``precompute_parameter_values``).
        param_names: Names to plot, in grid order.
        param_shapes: Dict mapping each name to its shape; 1-D parameters are drawn
            as a single-row heatmap.
        interval: Passed to ``widgets.Play`` as its frame interval.
        step: Number of list indices the Play widget advances per frame.
        start_step: Checkpoint number of ``all_parameter_values[0]``. It is used only
            in the figure title, so the title shows real checkpoint numbers when
            values were precomputed from a non-zero start step.

    Notes:
        The heatmap color range is fixed from the 1st/99th percentiles of the first
        checkpoint. The histogram axes are fixed across all checkpoints via
        ``get_histogram_ranges``.
    """
    n_params = len(param_names)
    n_cols = 3
    n_param_rows = (n_params + n_cols - 1) // n_cols
    n_rows = n_param_rows + 1  # +1 for the histogram

    # make_subplots assigns titles to grid cells in row-major order. Pad the unused
    # cells of the last heatmap row with "" so the histogram title lands on the
    # histogram (row n_rows, col 1) rather than on an empty cell.
    title_padding = [""] * (n_param_rows * n_cols - n_params)
    subplot_titles = list(param_names) + title_padding + ["Distribution of All Parameter Values"]

    fig = make_subplots(rows=n_rows, cols=n_cols, subplot_titles=subplot_titles)
    fig_widget = go.FigureWidget(fig)

    def to_heatmap_z(name, array):
        # Heatmaps need 2-D input, so 1-D parameters are wrapped as a single row.
        return [array] if len(param_shapes[name]) == 1 else array

    first_flat = np.concatenate([all_parameter_values[0][name].flatten() for name in param_names])
    vmin, vmax = np.percentile(first_flat, [1, 99])
    # Plotly ignores zmid once zmin/zmax are given, so center the diverging
    # colorscale on 0 explicitly by using a symmetric range.
    vabs = max(abs(vmin), abs(vmax))

    for idx, name in enumerate(param_names):
        row = idx // n_cols + 1
        col = idx % n_cols + 1
        heatmap = go.Heatmap(
            z=to_heatmap_z(name, all_parameter_values[0][name]),
            colorscale='RdBu',
            zmin=-vabs,
            zmax=vabs,
            zmid=0,
            # All heatmaps share one color range, so one colorbar is enough;
            # otherwise every trace draws its own bar at the same position.
            showscale=(idx == 0),
        )
        fig_widget.add_trace(heatmap, row=row, col=col)
        fig_widget.update_xaxes(title_text=f"Shape: {param_shapes[name]}", row=row, col=col)
        fig_widget.update_yaxes(title_text=name, row=row, col=col)

    # Ranges fixed over all checkpoints so the histogram axes do not jump while playing.
    x_min, x_max, y_max = get_histogram_ranges(all_parameter_values, param_names)

    # Histogram of all parameter values; it is the last trace added.
    fig_widget.add_trace(
        go.Histogram(
            x=first_flat,
            nbinsx=50,
            autobinx=False,
            xbins=dict(start=x_min, end=x_max, size=(x_max - x_min) / 50),
        ),
        row=n_rows,
        col=1,
    )
    fig_widget.update_xaxes(title_text="Parameter Value", range=[x_min, x_max], row=n_rows, col=1)
    fig_widget.update_yaxes(title_text="Frequency", range=[0, y_max * 1.1], row=n_rows, col=1)

    fig_widget.update_layout(height=300 * n_rows, width=1200)

    def update_plot(frame):
        value_dict = all_parameter_values[frame]
        frame_flat = np.concatenate([value_dict[name].flatten() for name in param_names])
        # Group all trace/layout edits so the front end receives a single update.
        with fig_widget.batch_update():
            for idx, name in enumerate(param_names):
                fig_widget.data[idx].z = to_heatmap_z(name, value_dict[name])
            fig_widget.data[-1].x = frame_flat
            fig_widget.layout.title.text = f"Parameter Values at Checkpoint {start_step + frame}"

    step_slider = widgets.IntSlider(
        value=0,
        min=0,
        max=len(all_parameter_values)-1,
        step=1,
        description='Step:',
        continuous_update=False
    )

    play_button = widgets.Play(
        value=0,
        min=0,
        max=len(all_parameter_values)-1,
        step=step,
        interval=interval,
        description="Play",
        disabled=False
    )

    # Play drives the slider in the front end; the slider's observer redraws in Python.
    widgets.jslink((play_button, 'value'), (step_slider, 'value'))

    def on_value_change(change):
        update_plot(change['new'])

    step_slider.observe(on_value_change, names='value')

    controls = widgets.HBox([play_button, step_slider])
    output = widgets.VBox([controls, fig_widget])
    display(output)

    # Initial plot
    update_plot(0)

In [22]:
# Root directory of saved runs. Set the CHECKPOINT_DIR environment variable to
# override this machine-specific default instead of editing the notebook.
checkpoint_dir = os.environ.get('CHECKPOINT_DIR', '/home/hoyeon/circuit-analysis/checkpoints')
# Alternative run: exp_name = 'L1_H1_E16_wd1_len10_bsize1_code10_lr-4_pure-mixed'
exp_name = 'L1_H1_E16_len10_code10_pos5_bsize64_lr0.0001_wd0.1_pure-mixed'
checkpoint_path = os.path.join(checkpoint_dir, exp_name)

# Fail fast with a clear message instead of deep inside the first checkpoint load.
if not os.path.isdir(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_path}")

# Checkpoints start_step..end_step (inclusive) are read; set end_step to the last saved step.
start_step = 0
end_step = 19999

# Precompute differences between consecutive checkpoints
all_differences, param_names, param_shapes = precompute_differences(checkpoint_path, start_step, end_step)

Precomputing differences...


  0%|          | 0/19999 [00:00<?, ?it/s]

In [23]:
# Re-binds param_names / param_shapes. With omit_bias_layernorm=True they should
# match the ones from precompute_differences, which the diff plot below also uses.
all_parameter_values, param_names, param_shapes = precompute_parameter_values(
    checkpoint_path,
    start_step=start_step,
    end_step=end_step,
    omit_bias_layernorm=True,
)

Precomputing parameter values...


  0%|          | 0/20000 [00:00<?, ?it/s]

In [24]:
# Animated view of the per-step parameter differences
create_interactive_plot_diffs(
    all_differences,
    param_names,
    param_shapes,
    interval=10,
    step=10,
)

VBox(children=(HBox(children=(Play(value=0, description='Play', interval=10, max=19998, step=10), IntSlider(va…

In [25]:
# Animated view of raw parameter values at each checkpoint
create_interactive_plot_params(
    all_parameter_values,
    param_names,
    param_shapes,
    interval=10,
    step=10,
)

VBox(children=(HBox(children=(Play(value=0, description='Play', interval=10, max=19999, step=10), IntSlider(va…