# Parsing results for recurrent models

In [1]:
import wandb
import plotly.express as px
import numpy as np
import pandas as pd

import pprint
pp = pprint.PrettyPrinter(indent=4)

import sys, os; sys.path.append(os.path.abspath('../..'))
from utils.utils import AttributeDict

In [2]:
api = wandb.Api()

# check differences in config across runs in group
wandb_project = "neural-program-induction/sorting-recurrent-resstream-exploration"
runs = api.runs(wandb_project)

groups = {}
for run in runs:
    group = run.group
    if group not in groups:
        groups[group] = []
    groups[group].append(run)

def find_differences_in_dict(dict1, dict2):
    for key, value in dict1.items():
        if key not in dict2:
            print(f'  {key}: {value}')
        elif isinstance(value, dict):
            find_differences_in_dict(value, dict2[key])
        elif value != dict2[key]:
            print(f'  {key}: {value}')

for group, runs in groups.items():
    print(f'[# {len(runs)}] Group: {group}')
    config = runs[0].config
    for run in runs:
        find_differences_in_dict(config, run.config)
    print('-'*80)

[# 5] Group: L2H4D256_rotary_IRTrue_WTFalse-False_discinterm-NA_prepostdiscnorm-False-True - progressive - MaxVal64-TrainLen16RandLen-BOSEOS
  Total Mult-Adds: 1214784
  total_params: 1215040
  num_params: 1215040
  trainable_params: 1215040
  num_trainable_params: 1215040
  Estimated total size (MB): 5.424592
  Forward/backward pass size  (MB): 0.565248
  Params size (MB): 4.859136
  norm_method: none
  seed: 556468681
  experiment_run_name: 2024-12-24-18:51:37
  Total Mult-Adds: 1214784
  total_params: 1215040
  num_params: 1215040
  trainable_params: 1215040
  num_trainable_params: 1215040
  Estimated total size (MB): 5.424592
  Forward/backward pass size  (MB): 0.565248
  Params size (MB): 4.859136
  norm_method: none
  seed: 556468681
  experiment_run_name: 2024-12-24-18:51:37
  data_config: {'max_value': 64, 'train_min_sequence_length': 2, 'val_sequence_length': 16, 'ood_test_sequence_lengths': [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], 'train_random_sequence_l

In [3]:
runs = api.runs(wandb_project)

# find config keys that are non-constant
def flatten_dict(dd, separator='.', prefix=''):
    return { prefix + separator + k if prefix else k : v
             for kk, vv in dd.items()
             for k, v in flatten_dict(vv, separator, kk).items()
             } if isinstance(dd, dict) else { prefix : dd }

config_key_vals = dict()

for run in runs:
    run_config = flatten_dict(run.config)

    for key, value in run_config.items():
        if key not in config_key_vals:
            config_key_vals[key] = set()
        if isinstance(value, list):
            value = tuple(value)
        config_key_vals[key].add(value)

In [4]:
# find unique config key values (i.e. non-constant) to use for creating interface
unique_config_key_vals = {key: value for key, value in config_key_vals.items() if len(value) > 1}

# remove model_summary keys
for k in list(unique_config_key_vals.keys()):
    if 'model_summary' in k:
        del unique_config_key_vals[k]

# remove train_config.seed, train_config.experiment_run_name, train_config.experiment_group
for k in ['train_config.seed', 'train_config.experiment_run_name', 'train_config.experiment_group']:
    if k in unique_config_key_vals:
        del unique_config_key_vals[k]

pp.pprint(unique_config_key_vals)

{   'model_config.intermediate_discretization.discrete_intermediate': {   False,
                                                                          True},
    'model_config.intermediate_discretization.discretize_map': {   'gumbel-softmax',
                                                                   'hard',
                                                                   'sigmoid',
                                                                   'softmax'},
    'model_config.norm_config.norm_method': {   'hypersphere-interpolation',
                                                'none',
                                                'post-norm',
                                                'pre-norm'},
    'model_config.postdisc_norm': {False, True},
    'model_config.predisc_norm': {False, True}}


In [5]:
def filter_runs(config_filter, runs=None):
    if runs is None:
        runs = api.runs(wandb_project)
    # get runs that match config_filter
    filtered_runs = []
    for run in runs:
        if config_pattern_match(flatten_dict(config_filter), flatten_dict(run.config)):
            filtered_runs.append(run)

    return filtered_runs

def config_pattern_match(pattern, config):
    # if discrete_intermediate is False, ignore discretize_map from pattern
    if not pattern['model_config.intermediate_discretization.discrete_intermediate']:
        del pattern['model_config.intermediate_discretization.discretize_map']

    for key, value in pattern.items():
        if not isinstance(value, dict):
            if key not in config or value != config[key]:
                return False
        else:
            if not config_pattern_match(value, config[key]):
                return False

    return True

def print_run_config_diffs(runs):
    for run in runs:
        print(f'Run: {run.id}')
        find_differences_in_dict(runs[0].config, run.config)
        print('-'*80)

def get_run_plot(pattern):
    print('Searching for runs with the following config pattern:')
    pp.pprint(pattern)

    filtered_runs = filter_runs(pattern)
    print(f'Found {len(filtered_runs)} matching runs; will plot last')
    if len(filtered_runs) == 0:
        print('No matching runs found')
        return

    if len(filtered_runs) > 1:
        print('Differences in config:')
        print_run_config_diffs(filtered_runs)

    run = filtered_runs[-1]

    table = table = api.artifact(f'{wandb_project}/run-{run.id}-testood_eval:latest').get('test/ood_eval')
    data = pd.DataFrame(table.data, columns=table.columns)

    figs = create_plots(data, AttributeDict(run.config))
    return figs

In [6]:
def create_plots(test_df, config):
    figs = {}

    # color scale for lenths
    color_scale = px.colors.sample_colorscale('Viridis', len(test_df.L.unique()))

    # n_iters vs sequence_acc, color = L
    fig = px.line(test_df, x='n_iters', y='sequence_acc', color='L', title='Sequence Accuracy', labels={'sequence_acc': 'Sequence Accuracy', 'n_iters': 'n_iters'}, color_discrete_sequence=color_scale)
    fig.add_vline(x=config.train_config.train_max_n_iters, line_dash='dash', line_color='black', annotation_text='train_max_n_iters', annotation_position='top right')
    figs['sequence_acc'] = fig

    # n_iters vs per_token_acc, color = L
    fig = px.line(test_df, x='n_iters', y='per_token_acc', color='L', title='Token-wise Accuracy', labels={'per_token_acc': 'Token-wise Accuracy', 'n_iters': 'n_iters'}, color_discrete_sequence=color_scale)
    fig.add_vline(x=config.train_config.train_max_n_iters, line_dash='dash', line_color='black', annotation_text='train_max_n_iters', annotation_position='top right')
    figs['per_token_acc'] = fig

    # heatmap of n_iters vs L vs sequence_acc
    heatmap = test_df.pivot(index='n_iters', columns='L')['sequence_acc']
    fig = px.imshow(heatmap, x=heatmap.columns, y=heatmap.index, title='Sequence Accuracy', zmin=0, zmax=1, origin='lower', color_continuous_scale='Hot')
    fig.add_hline(y=config.train_config.train_max_n_iters, line_dash='dash', line_color='black', annotation_text='train_max_n_iters', annotation_position='bottom left')
    fig.add_vline(x=config.data_config.train_sequence_length, line_dash='dash', line_color='black', annotation_text='train_max_seq_len', annotation_position='bottom right')
    figs['sequence_acc_heatmap'] = fig

    # heatmap of n_iters vs L vs per_token_acc
    heatmap = test_df.pivot(index='n_iters', columns='L')['per_token_acc']
    fig = px.imshow(heatmap, x=heatmap.columns, y=heatmap.index, title='Token-wise Accuracy', zmin=0, zmax=1, origin='lower', color_continuous_scale='Hot')
    fig.add_hline(y=config.train_config.train_max_n_iters, line_dash='dash', line_color='black', annotation_text='train_max_n_iters', annotation_position='bottom left')
    fig.add_vline(x=config.data_config.train_sequence_length, line_dash='dash', line_color='black', annotation_text='train_max_seq_len', annotation_position='bottom right')
    figs['per_token_acc_heatmap'] = fig

    # max_{n_iters} sequence_acc vs L
    max_acc = test_df.groupby('L')['sequence_acc'].max().reset_index()
    fig = px.line(max_acc, x='L', y='sequence_acc', title='Sequence Accuracy vs Length (Max over `n_iters`)', labels={'sequence_acc': 'Max Sequence Accuracy', 'L': 'L'})
    fig.update_traces(fill='tozeroy') # fill area under curve
    fig.add_vline(x=config.data_config.train_sequence_length, line_dash='dash', line_color='black', annotation_text='train_max_seq_len', annotation_position='bottom right')
    avg = max_acc['sequence_acc'].mean()
    fig.add_hline(y=avg, line_dash='dash', line_color='black', annotation_text=f'avg={avg:.2f}', annotation_position='top right')
    figs['max_sequence_acc'] = fig

    # max_{n_iters} per_token_acc vs L
    max_acc = test_df.groupby('L')['per_token_acc'].max().reset_index()
    fig = px.line(max_acc, x='L', y='per_token_acc', title='Token-wise Accuracy vs Length (Max over `n_iters`)', labels={'per_token_acc': 'Max Token-wise Accuracy', 'L': 'L'})
    fig.update_traces(fill='tozeroy') # fill area under curve
    fig.add_vline(x=config.data_config.train_sequence_length, line_dash='dash', line_color='black', annotation_text='train_max_seq_len', annotation_position='bottom right')
    avg = max_acc['per_token_acc'].mean()
    fig.add_hline(y=avg, line_dash='dash', line_color='black', annotation_text=f'avg={avg:.2f}', annotation_position='top right')
    figs['max_per_token_acc'] = fig

    # heatmap of n_iters vs L vs emb_norms
    heatmap = test_df.pivot(index='n_iters', columns='L')['emb_norms']
    fig = px.imshow(heatmap, x=heatmap.columns, y=heatmap.index, title='Embedding Norms', origin='lower', color_continuous_scale='Hot')
    fig.add_hline(y=config.train_config.train_max_n_iters, line_dash='dash', line_color='black', annotation_text='train_max_n_iters', annotation_position='bottom left')
    fig.add_vline(x=config.data_config.train_sequence_length, line_dash='dash', line_color='black', annotation_text='train_max_seq_len', annotation_position='bottom right')
    figs['emb_norms_heatmap'] = fig

    # heatmap of n_iters vs L vs delta_norms
    heatmap = test_df.pivot(index='n_iters', columns='L')['delta_norms']
    fig = px.imshow(heatmap, x=heatmap.columns, y=heatmap.index, title='Delta Norms', origin='lower', color_continuous_scale='Hot')
    fig.add_hline(y=config.train_config.train_max_n_iters, line_dash='dash', line_color='black', annotation_text='train_max_n_iters', annotation_position='bottom left')
    fig.add_vline(x=config.data_config.train_sequence_length, line_dash='dash', line_color='black', annotation_text='train_max_seq_len', annotation_position='bottom right')
    figs['delta_norms_heatmap'] = fig

    # heatmap of n_iters vs L vs logits_states_softmax_entropy
    if 'logits_states_softmax_entropy' in test_df.columns:
        heatmap = test_df.pivot(index='n_iters', columns='L')['logits_states_softmax_entropy']
        fig = px.imshow(heatmap, x=heatmap.columns, y=heatmap.index, title='Logits States Softmax Entropy', origin='lower', color_continuous_scale='Hot')
        fig.add_hline(y=config.train_config.train_max_n_iters, line_dash='dash', line_color='black', annotation_text='train_max_n_iters', annotation_position='bottom left')
        fig.add_vline(x=config.data_config.train_sequence_length, line_dash='dash', line_color='black', annotation_text='train_max_seq_len', annotation_position='bottom right')
        figs['logits_states_softmax_entropy_heatmap'] = fig

    return figs

In [7]:
# interactive plot
import ipywidgets as widgets
from IPython.display import display
import plotly.express as px
import pandas as pd


# Create interactive widgets for each field in the pattern
config_widgets = []

for k in unique_config_key_vals.keys():
    config_widget = widgets.ToggleButtons(
        options=unique_config_key_vals[k],
        description=k,
        disabled=False,
    )
    config_widgets.append(config_widget)

plot_type = widgets.ToggleButtons(
    options=['sequence_acc', 'per_token_acc', 'sequence_acc_heatmap', 'per_token_acc_heatmap', 'max_sequence_acc', 'max_per_token_acc', 'emb_norms_heatmap', 'delta_norms_heatmap', 'logits_states_softmax_entropy_heatmap'],
    description='Plot Type',
    disabled=False,
)

output = widgets.Output()

def on_button_click(b):
    with output:
        output.clear_output()

        pattern = {config_widget.description: config_widget.value for config_widget in config_widgets}

        figs = get_run_plot(pattern)
        if figs is not None:
            figs[plot_type.value].show()

button = widgets.Button(
    description='Generate Plot',
    disabled=False,
    button_style='',
    tooltip='Click to generate plot',
    icon='check'
)

button.on_click(on_button_click)

# Display widgets
display(*config_widgets, plot_type, button, output)

ToggleButtons(description='model_config.postdisc_norm', options=(False, True), value=False)

ToggleButtons(description='model_config.intermediate_discretization.discrete_intermediate', options=(False, Tr…

ToggleButtons(description='model_config.predisc_norm', options=(False, True), value=False)

ToggleButtons(description='model_config.norm_config.norm_method', options=('none', 'hypersphere-interpolation'…

ToggleButtons(description='model_config.intermediate_discretization.discretize_map', options=('softmax', 'hard…

ToggleButtons(description='Plot Type', options=('sequence_acc', 'per_token_acc', 'sequence_acc_heatmap', 'per_…

Button(description='Generate Plot', icon='check', style=ButtonStyle(), tooltip='Click to generate plot')

Output()

[34m[1mwandb[0m:   1 of 1 files downloaded.  


[34m[1mwandb[0m:   1 of 1 files downloaded.  


[34m[1mwandb[0m:   1 of 1 files downloaded.  


[34m[1mwandb[0m:   1 of 1 files downloaded.  


[34m[1mwandb[0m:   1 of 1 files downloaded.  
