# Generating figures for paper

In [31]:
import pandas as pd
import numpy as np
import scipy
import plotly.express as px
import matplotlib.pyplot as plt

save_dir = 'plotly_figs'
theme = 'plotly_white'

In [32]:
our_model_name = 'Dual-Attn Transformer'# 'Orthrus' # 'AbstractTransformer' # Orthrus

In [33]:
cmap = plt.cm.tab20

In [34]:
def convert_to_plotly_color(color):
    r, g, b, a = color
    r, g, b, a = int(r*255), int(g*255), int(b*255), a
    # return f'rgba({r}, {g}, {b}, {a})'
    return f'rgb({r}, {g}, {b})'

## Relational Games

In [35]:
relgames_data = pd.read_csv('figure_data/relgames/relgames_data.csv')

def process_groupname(group_name):
    task, model_name = group_name.split('__')
    return model_name


L, total_n_heads = 2, 2
filter_ = (relgames_data['n_layers'] == L) & ((relgames_data['n_heads_rca'] + relgames_data['n_heads_sa'] == total_n_heads) | (relgames_data['n_heads'] == total_n_heads)) & (relgames_data['train_size'] <= 25_000)
figure_data = relgames_data[filter_]

figure_data['Model'] = figure_data['group'].apply(process_groupname)

figure_data.rename(columns={'train_size': 'Training Set Size', 'test/acc_in_distribution': 'Generalization Accuracy', 'task': 'Task'}, inplace=True)

In [36]:
def parse_rel_symmetry(group_name):
    if 'sym_rel=True' in group_name:
        return True
    elif 'sym_rel=False' in group_name:
        return False
    else:
        return None

figure_data['Symmetric RA'] = figure_data['group'].apply(parse_rel_symmetry)

In [37]:
models = [
    'sa=2; d=128; L=2',
    'sa=1; rca=1; d=128; L=2; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_sym_retriever',
    'sa=0; rca=2; d=128; L=2; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_sym_retriever',
    'sa=1; rca=1; d=128; L=2; rca_type=disentangled_v2; sym_rel=False; symbol_type=pos_sym_retriever',
    'sa=0; rca=2; d=128; L=2; rca_type=disentangled_v2; sym_rel=False; symbol_type=pos_sym_retriever',
    'sa=1; rca=1; d=128; L=2; rca_type=standard; sym_rel=False; symbol_type=pos_sym_retriever',
    'sa=0; rca=2; d=128; L=2; rca_type=standard; sym_rel=False; symbol_type=pos_sym_retriever',
    ]
model_name_map = {
    'sa=2; d=128; L=2': '$\\text{Transformer}\ (n_h^{sa}=2, n_h^{ra}=0)$',
    'sa=1; rca=1; d=128; L=2; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_sym_retriever': f'$\\text{{{our_model_name}}}\\ (n_h^{{sa}}=1, n_h^{{ra}}=1)$',
    'sa=0; rca=2; d=128; L=2; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_sym_retriever': f'$\\text{{{our_model_name}}}\\ (n_h^{{sa}}=0, n_h^{{ra}}=2)$',
    'sa=1; rca=1; d=128; L=2; rca_type=disentangled_v2; sym_rel=False; symbol_type=pos_sym_retriever': f'{our_model_name} [asymmetric] ($n_h^{{sa}}=1, n_h^{{ra}}=1)$',
    'sa=0; rca=2; d=128; L=2; rca_type=disentangled_v2; sym_rel=False; symbol_type=pos_sym_retriever': f'{our_model_name} [asymmetric] ($n_h^{{sa}}=0, n_h^{{ra}}=2)$',
    'sa=1; rca=1; d=128; L=2; rca_type=standard; sym_rel=False; symbol_type=pos_sym_retriever': "Abstractor's RCA ($n_h^{{sa}}=1, n_h^{{ra}}=1)$",
    'sa=0; rca=2; d=128; L=2; rca_type=standard; sym_rel=False; symbol_type=pos_sym_retriever': "Abstractor's RCA ($n_h^{{sa}}=0, n_h^{{ra}}=2)$",
    }
models = [model_name_map[m] for m in models]

tasks = ['same', 'occurs', 'xoccurs', '1task_between', '1task_match_patt']
# task_name_map = {
#     '1task_between': r'$\texttt{between}$', '1task_match_patt': r'$\texttt{match pattern}$',
#     'same': r'$\texttt{same}$', 'occurs': r'$\texttt{occurs}$', 'xoccurs': r'$\texttt{xoccurs}$'}
task_name_map = {
    '1task_between': 'between', '1task_match_patt': 'match pattern',
    'same': 'same', 'occurs': 'occurs', 'xoccurs': 'xoccurs'}
tasks = [task_name_map[t] for t in tasks]

figure_data['Model'] = pd.Categorical(figure_data['Model'].map(model_name_map), models, ordered=True)
figure_data['Task'] = pd.Categorical(figure_data['Task'].map(task_name_map), tasks, ordered=True)


color_map_ = {
    f'$\\text{{{our_model_name}}}\\ (n_h^{{sa}}=1, n_h^{{ra}}=1)$': cmap(8), # purple
    f'$\\text{{{our_model_name}}}\\ (n_h^{{sa}}=0, n_h^{{ra}}=2)$': cmap(0), # blue
    f'{our_model_name} [asymmetric] ($n_h^{{sa}}=1, n_h^{{ra}}=1)$': cmap(8), # purple
    f'{our_model_name} [asymmetric] ($n_h^{{sa}}=0, n_h^{{ra}}=2)$': cmap(0), # blue
    '$\\text{Transformer}\ (n_h^{sa}=2, n_h^{ra}=0)$': cmap(6), # red
    "Abstractor's RCA ($n_h^{{sa}}=1, n_h^{{ra}}=1)$": cmap(8),
    "Abstractor's RCA ($n_h^{{sa}}=0, n_h^{{ra}}=2)$": cmap(0),
    # 'Transformer+': cmap(4),
    }

color_map = {k: convert_to_plotly_color(color_map_[k]) for k in color_map_}

In [38]:
metric = 'Generalization Accuracy'
figure_data_ = figure_data[figure_data['Symmetric RA']!=False]
figure_data_ = figure_data_.groupby(['Model', 'Task', 'Training Set Size'])[metric].aggregate(['mean', 'std', 'count', 'sem']).reset_index()

In [39]:
import plotly.graph_objects as go

# Assuming figure_data_ is your DataFrame with columns 'Training Set Size', 'mean', 'sem', 'Model', 'Task'

# Get unique tasks and models
tasks = figure_data_['Task'].unique()
models = figure_data_['Model'].unique()

# Create frames for each task
frames = []
for task in tasks:
    frame_data = []
    task_filter = figure_data_['Task'] == task
    task_data = figure_data_[task_filter].dropna()
    yrange = [task_data['mean'].min(), task_data['mean'].max()]
    xrange = [task_data['Training Set Size'].min(), task_data['Training Set Size'].max()]
    eps_y = 0.025 * (yrange[1] - yrange[0])
    eps_x = 0.025 * (xrange[1] - xrange[0])
    yrange = [yrange[0] - eps_y, yrange[1] + eps_y]
    xrange = [xrange[0] - eps_x, xrange[1] + eps_x]

    for model in models:
        model_filter = task_data['Model'] == model
        model_data = task_data[model_filter]
        frame_data.append(go.Scatter(
            x=model_data['Training Set Size'], 
            y=model_data['mean'], 
            error_y=dict(type='data', array=model_data['sem'], visible=True),
            mode='lines',
            name=model,
            line=dict(color=color_map[model]))) # change color based on model
    frames.append(go.Frame(data=frame_data, 
                           name=str(task),
                           layout=dict(xaxis=dict(range=xrange),
                                       yaxis=dict(range=yrange))
                                       ))

# Create steps for the slider
steps = [dict(method='animate',
              args=[[frame['name']]],  # frame name to be shown
              label=frame['name']) for frame in frames]

# Create base frame
fig = go.Figure(
    data=frames[0]['data'],
    layout=go.Layout(
        title='Relational Games Learning Curves',
        xaxis=dict(title='Training Set Size'),
        yaxis=dict(title='Generalization Accuracy'),
        height=600,
        width=1000,
        sliders=[dict(steps=steps)],  # add the slider
        legend_title="Model",
        legend=dict(yanchor="bottom", y=0.01, xanchor="right", x=0.99,
                    title_font_family="Times New Roman", #bgcolor='LightSteelBlue',
                    bordercolor="Black", borderwidth=1),
        template='plotly_white'
#         legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="center", x=0.5)
    ),
    frames=frames
)

fig.show()

In [40]:
import plotly.subplots as sp
import plotly.graph_objects as go

# Assuming figure_data is your DataFrame with columns 'Epoch', 'Training Accuracy', 'Validation Accuracy', 'Model'
# and color_map is a dictionary mapping 'Model' values to colors

# Get unique models
models = figure_data['Model'].unique()

# Create subplots with horizontal spacing
fig = sp.make_subplots(rows=1, cols=2, subplot_titles=("Training Accuracy", "Validation Accuracy"), horizontal_spacing=0.1)  # adjust horizontal_spacing as needed

# Add line plots to the subplots
for model in models:
    model_data = figure_data[figure_data['Model'] == model]
    model_data_tr = model_data[['Epoch', 'Training Accuracy']].dropna()
    model_data_val = model_data[['Epoch', 'Validation Accuracy']].dropna()
    fig.add_trace(go.Scatter(
        x=model_data_tr['Epoch'], y=model_data_tr['Training Accuracy'],
        mode='lines', line=dict(color=color_map[model]), name=model, showlegend=False
        ), 
        row=1, col=1)
    fig.add_trace(go.Scatter(
        x=model_data_val['Epoch'], y=model_data_val['Validation Accuracy'],
        mode='lines', line=dict(color=color_map[model]), 
        name=model), 
        row=1, col=2)

# Update layout
fig.update_layout(
    height=600, width=1000, title_text="Image Classification Training Curves",
    updatemenus=[
        dict(
            buttons=list([
                dict(
                    args=["visible", [True, True]],
                    label="Both",
                    method="update"
                ),
                dict(
                    args=["visible", [True, False]],
                    label="Training Accuracy",
                    method="update"
                ),
                dict(
                    args=["visible", [False, True]],
                    label="Validation Accuracy",
                    method="update"
                )
            ]),
            direction="down",
            pad={"r": 10, "t": 10},
            showactive=True,
            x=0.1,
            xanchor="left",
            y=1.1,
            yanchor="top"
        ),
    ]
)

fig.update_xaxes(title_text="Epoch", row=1, col=1)
fig.update_yaxes(title_text="Training Accuracy", row=1, col=1)
fig.update_xaxes(title_text="Epoch", row=1, col=2)
fig.update_yaxes(title_text="Validation Accuracy", row=1, col=2)

# Show the figure
fig.show()

# Save the figure
# fig.write_image(f'{save_dir}/imagenet/imagenet_acc_curves.pdf')

KeyError: "None of [Index(['Epoch', 'Training Accuracy'], dtype='object')] are in the [columns]"

In [104]:
fig_json = fig.to_json()
# save fig_json to file
with open(f'{save_dir}/relgames_learning_curves.json', 'w') as f:
    f.write(fig_json)

## Math

In [245]:
figure_data = pd.read_csv('figure_data/math/run_history_all.csv')

In [246]:
# bar plot figure
color_map_ = {
    f'{our_model_name} (config 1)': cmap(8), # purple
    f'{our_model_name} (config 2)': cmap(0), # blue
    # 'Transformer': cmap(7), # lighter red
    'Transformer': cmap(6), # red
    # 'Transformer+': cmap(6), # red # NOTE: removed Transformer+ for now
    # 'Transformer+': cmap(4),
    }

color_map = {k: convert_to_plotly_color(color_map_[k]) for k in color_map_}

models = {
    'e_sa=8; e_rca=0; d_sa=8; d_rca=0; d_cross=8; d=128; rca_type=NA, symbol_type=NA; el=2; dl=2': 'Transformer',
    # 'ee=8; ea=0; de=8; da=0; dc=8; el=2; dl=2': 'Transformer',
    # 'e_sa=8; e_rca=0; d_sa=8; d_rca=0; d_cross=8; d=144; rca_type=NA, symbol_type=NA; el=2; dl=2': 'Transformer+', # NOTE: removed Transformer+ for now
    'e_sa=4; e_rca=4; d_sa=8; d_rca=0; d_cross=8; d=128; rca_type=disentangled_v2, symbol_type=pos_relative; el=2; dl=2': f'{our_model_name} (config 1)',
    'e_sa=4; e_rca=4; d_sa=4; d_rca=4; d_cross=8; d=128; rca_type=disentangled_v2, symbol_type=pos_relative; el=2; dl=2': f'{our_model_name} (config 2)',
    # 'e_sa=4; e_rca=4; d_sa=8; d_rca=0; d_cross=8; rca_dis=True, el=2; dl=2': 'AbstractTransformer (v1; OGRCA)',
    # 'e_sa=4; e_rca=4; d_sa=4; d_rca=4; d_cross=8; rca_dis=True, el=2; dl=2': 'AbstractTransformer (v2; OGRCA)',
    }

In [247]:
figure_data.rename(columns={
    'epoch': 'Epoch', 'interpolate_teacher_forcing_acc': 'Accuracy (Interpolation)', 'extrapolate_teacher_forcing_acc': 'Accuracy (Extrapolation)', 
    'train_teacher_forcing_acc': 'Accuracy (Training)', 'task': 'Task', 'group': 'Model'}, inplace=True)

def format_task(task):
    string = task.replace('__', r'\_\_')
    string = r"$\texttt{" + string + r"}$"
    return string
# figure_data['Task'] = pd.Categorical(figure_data['Task'].map(format_task))
figure_data = figure_data[figure_data['Model'].isin(models.keys())]
figure_data['Model'] = pd.Categorical(figure_data['Model'].map(models), models.values(), ordered=True)

In [248]:
# TODO: need to run more trials for some, etc.
figure_data.groupby(['Model', 'Task'])['Accuracy (Training)'].aggregate('count')

Model                             Task                       
Transformer                       algebra__linear_1d             400
                                  algebra__sequence_next_term    400
                                  calculus__differentiate        400
                                  polynomials__add               400
                                  polynomials__expand            400
Dual-Attn Transformer (config 1)  algebra__linear_1d             964
                                  algebra__sequence_next_term    952
                                  calculus__differentiate        940
                                  polynomials__add               976
                                  polynomials__expand            988
Dual-Attn Transformer (config 2)  algebra__linear_1d             955
                                  algebra__sequence_next_term    940
                                  calculus__differentiate        970
                                  polynom

In [249]:
metric = 'Accuracy (Interpolation)'
figure_data_ = figure_data.groupby(['Model', 'Task', 'Epoch'])[metric].aggregate(['mean', 'std', 'count', 'sem']).reset_index()

In [251]:
import plotly.graph_objects as go

# Get unique tasks and models
tasks = figure_data_['Task'].unique()
models = figure_data_['Model'].unique()

# Create frames for each task
frames = []
for task in tasks:
    frame_data = []
    task_filter = figure_data_['Task'] == task
    task_data = figure_data_[task_filter].dropna()
    yrange = [task_data['mean'].min(), task_data['mean'].max()]
    # xrange = [task_data['Epoch'].min(), task_data['Training Set Size'].max()]
    eps_y = 0.025 * (yrange[1] - yrange[0])
    # eps_x = 0.025 * (xrange[1] - xrange[0])
    yrange = [yrange[0] - eps_y, yrange[1] + eps_y]
    # xrange = [xrange[0] - eps_x, xrange[1] + eps_x]

    for model in models:
        model_filter = task_data['Model'] == model
        model_data = task_data[model_filter]
        frame_data.append(go.Scatter(
            x=model_data['Epoch'], 
            y=model_data['mean'],
            error_y=dict(type='data', array=model_data['sem'], visible=True),
            mode='lines',
            name=model,
            line=dict(color=color_map[model]))) # change color based on model
    frames.append(go.Frame(data=frame_data, 
                           name=str(task),
                           layout=dict(yaxis=dict(range=yrange))
                                       ))

# Create steps for the slider
steps = [dict(method='animate',
              args=[[frame['name']]],  # frame name to be shown
              label=frame['name']) for frame in frames]

# Create base frame
fig = go.Figure(
    data=frames[0]['data'],
    layout=go.Layout(
        title='Mathematical Problem Solving Training Curves',
        xaxis=dict(title='Epoch'),
        yaxis=dict(title='Accuracy'),
        height=600,
        width=1000,
        sliders=[dict(steps=steps)],  # add the slider
        font_family="Computer Modern",
        template=theme,
        legend_title="Model",
        legend=dict(yanchor="bottom", y=0.01, xanchor="right", x=0.99,
                    title_font_family="Times New Roman", #bgcolor='LightSteelBlue',
                    bordercolor="Black", borderwidth=1),
    ),
    frames=frames
)

fig.show()

In [136]:
fig_json = fig.to_json()
# save fig_json to file
with open(f'{save_dir}/math_training_curves.json', 'w') as f:
    f.write(fig_json)

## Language Modeling: Tiny Stories

In [41]:
figure_data = pd.read_csv('figure_data/tiny_stories/run_histories.csv')
figure_data.head()

Unnamed: 0,val/perplexity,mfu,tokens,Generated Samples,lr,val/loss,train/loss,_timestamp,_step,train/perplexity,...,n_layers,wandb_log,weight_decay,vocab_source,out_dir,pos_enc_type,beta1,sym_attn_n_symbols,group,name
0,39238.144531,-100.0,0.0,,0.001,10.577385,10.578004,1716068000.0,0,39261.359375,...,6,True,0.1,llama2,out/sa=4; rca=4; d=128; L=6; rca_type=disentan...,RoPE,0.9,512.0,,sa=4; rca=4; d=128; L=6; rca_type=disentangled...
1,5.612392,3.820743,262144000.0,,0.001,1.724343,1.72705,1716069000.0,2000,5.626907,...,6,True,0.1,llama2,out/sa=4; rca=4; d=128; L=6; rca_type=disentan...,RoPE,0.9,512.0,,sa=4; rca=4; d=128; L=6; rca_type=disentangled...
2,4.960096,3.820697,524288000.0,,0.001,1.600834,1.604139,1716071000.0,4000,4.975849,...,6,True,0.1,llama2,out/sa=4; rca=4; d=128; L=6; rca_type=disentan...,RoPE,0.9,512.0,,sa=4; rca=4; d=128; L=6; rca_type=disentangled...
3,4.70429,3.820555,786432000.0,,0.001,1.547904,1.55054,1716072000.0,6000,4.716177,...,6,True,0.1,llama2,out/sa=4; rca=4; d=128; L=6; rca_type=disentan...,RoPE,0.9,512.0,,sa=4; rca=4; d=128; L=6; rca_type=disentangled...
4,4.540239,3.820479,1048576000.0,,0.001,1.512442,1.515324,1716074000.0,8000,4.552852,...,6,True,0.1,llama2,out/sa=4; rca=4; d=128; L=6; rca_type=disentan...,RoPE,0.9,512.0,,sa=4; rca=4; d=128; L=6; rca_type=disentangled...


In [42]:
# bar plot figure
color_map_ = {
    '$\\text{Transformer}\ (n_h^{sa}=8, n_h^{ra}=0)$': cmap(6), # red
    f'$\\text{{{our_model_name}}}\ (n_h^{{sa}}=6, n_h^{{ra}}=2)$': cmap(8), # purple
    f'$\\text{{{our_model_name}}}\ (n_h^{{sa}}=4, n_h^{{ra}}=4)$': cmap(0), # blue
    # 'Transformer+': cmap(4),
    }

color_map = {k: convert_to_plotly_color(color_map_[k]) for k in color_map_}
models = color_map_.keys()

In [43]:
def get_model_name(row):
    if row.rca == 0:
        return f'$\\text{{Transformer}}\ (n_h^{{sa}}={row.sa}, n_h^{{ra}}={row.rca})$'
    else:
        # return f"AbstractTransformer [{row['Symbol Type']}, symm={row['Symmetric RA']}] ($n_h^{{sa}} = {row.sa}, n_h^{{ra}}={row.rca}$)"
        return f"$\\text{{{our_model_name}}}\ (n_h^{{sa}}={row.sa}, n_h^{{ra}}={row.rca})$"

In [44]:
figure_data = figure_data.rename(columns={
    'symbol_type': 'Symbol Type', 'symmetric_rels': 'Symmetric RA', 
    'val/loss': 'Validation Loss', 'val/perplexity': 'Validation Perplexity', 'tokens': 'Tokens'})

figure_data.loc[figure_data['rca']==0, 'Symbol Type'] = 'NA'
figure_data.loc[figure_data['rca']==0, 'Symmetric RA'] = 'NA'

figure_data['Model'] = pd.Categorical(figure_data.apply(get_model_name, axis=1), models, ordered=True)
sym_map = {'sym_attn': 'Symbolic Attention', 'pos_relative': 'Position-Relative Symbols'}
figure_data['Symbol Type'] = pd.Categorical(figure_data['Symbol Type'].map(sym_map), sym_map.values(), ordered=True)

In [45]:
def filter_data(figure_data, d_models=None, layers=None, filter_first_step=False, filter_transformer=False, symbol_types=None, symmetry=None, rca_types=None):
    filter_ = ~figure_data.index.isna()
    if d_models is not None:
        filter_ = filter_ & (figure_data['d_model'].isin(d_models))
    if layers is not None:
        filter_ = filter_ & (figure_data['n_layers'].isin(layers))
    if filter_transformer:
        filter_ = filter_ & (figure_data['rca'] > 0)
    if filter_first_step:
        filter_ = filter_ & (figure_data['_step'] > 0)
    if symbol_types is not None:
        symbol_types = [sym_map[s] for s in symbol_types]
        filter_ = filter_ & ((figure_data['rca'] == 0) | figure_data['Symbol Type'].isin(symbol_types))
    if symmetry is not None:
        filter_ = filter_ & ((figure_data['rca'] == 0) | figure_data['Symmetric RA'].isin(symmetry))
    if rca_types is not None:
        filter_ = filter_ & ((figure_data['rca'] == 0) | (figure_data['rca_type'].isin(rca_types)))
    filtered_data = figure_data.copy()[filter_]

    if filter_transformer:
        filtered_data['Model'] = filtered_data['Model'].cat.remove_unused_categories()

    return filtered_data

### All Plots & Ablations

In [46]:
import plotly.graph_objects as go

# Get unique tasks and models
layers = [4, 5, 6]
d = 64
metric = 'Validation Loss'
# metric = 'Validation Perplexity'
# models = figure_data_['Model'].unique()

fig_data = filter_data(figure_data, d_models=[d], layers=layers, filter_first_step=True,
    symbol_types=('sym_attn',), symmetry=(False,), rca_types=('disentangled_v2',))
yrange_global = [fig_data[metric].min(), fig_data[metric].max()]

# Create frames for each task
frames = []
for l in layers:
    frame_data = []
    ax_data = filter_data(figure_data, d_models=[d], layers=[l], filter_first_step=True,
        symbol_types=('sym_attn',), symmetry=(False,), rca_types=('disentangled_v2',))
    ax_data = ax_data.groupby(['Model', 'Tokens'])[metric].aggregate(['mean', 'std', 'count', 'sem']).reset_index()
    models = ax_data['Model'].unique()

    yrange = [ax_data['mean'].min(), ax_data['mean'].max()]

    yrange = yrange_global
    # xrange = [task_data['Epoch'].min(), task_data['Training Set Size'].max()]
    eps_y = 0.025 * (yrange[1] - yrange[0])
    # eps_x = 0.025 * (xrange[1] - xrange[0])
    yrange = [yrange[0] - eps_y, yrange[1] + eps_y]
    # xrange = [xrange[0] - eps_x, xrange[1] + eps_x]

    for model in models:
        model_filter = ax_data['Model'] == model
        model_data = ax_data[model_filter]
        frame_data.append(go.Scatter(
            x=model_data['Tokens'], 
            y=model_data['mean'],
            mode='lines',
            name=model,
            line=dict(color=color_map[model]))) # change color based on model
    frames.append(go.Frame(data=frame_data, 
                           name=f'L = {l}',
                           layout=dict(yaxis=dict(range=yrange))
                                       ))

# Create steps for the slider
steps = [dict(method='animate',
              args=[[frame['name']]],  # frame name to be shown
              label=frame['name']) for frame in frames]

# Create base frame
fig = go.Figure(
    data=frames[0]['data'],
    layout=go.Layout(
        title='Language Modeling Training Curves',
        xaxis=dict(title='Tokens'),
        yaxis=dict(title=metric),
        height=600,
        width=1000,
        sliders=[dict(steps=steps)],  # add the slider
                font_family="Computer Modern",
        template=theme,
        legend_title="Model",
        legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99,
                    title_font_family="Times New Roman", #bgcolor='LightSteelBlue',
                    bordercolor="Black", borderwidth=1),
    ),
    frames=frames
)

fig.show()

In [47]:
fig_json = fig.to_json()
# save fig_json to file
with open(f'{save_dir}/language_modeling_training_curves.json', 'w') as f:
    f.write(fig_json)

## Vision

In [227]:
figure_data = pd.read_csv('figure_data/imagenet/run_histories.csv')
figure_data.dropna(subset=['train/acc_epoch', 'val/loss'], inplace=True, how='all') # drop step rows and keep epoch rows
figure_data.head()

Unnamed: 0,train/loss_epoch,train/acc_step,trainer/global_step,val/top4_acc,_step,val/top3_acc,val/acc,val/loss,val/top2_acc,_runtime,...,val/top7_acc,val/top8_acc,train/acc_epoch,d_model,n_layers,symbol_retrieval,rca_type,symmetric_rels,group,name
6,,,312,0.126362,6,0.105609,0.047035,5.619035,0.079808,7318.16537,...,0.177324,0.191186,,1024,24,pos_relative,disentangled_v2,,,sa=10; rca=6; d=1024; L=24; rca_type=disentang...
7,6.408384,,312,,7,,,,,7322.367173,...,,,0.021343,1024,24,pos_relative,disentangled_v2,,,sa=10; rca=6; d=1024; L=24; rca_type=disentang...
14,,,625,0.208734,14,0.177083,0.092708,5.096831,0.141026,14620.223963,...,0.277804,0.297035,,1024,24,pos_relative,disentangled_v2,,,sa=10; rca=6; d=1024; L=24; rca_type=disentang...
15,5.503854,,625,,15,,,,,14624.3225,...,,,0.062984,1024,24,pos_relative,disentangled_v2,,,sa=10; rca=6; d=1024; L=24; rca_type=disentang...
22,,,938,0.290545,22,0.257212,0.141186,4.607257,0.211538,21924.534104,...,0.369631,0.390625,,1024,24,pos_relative,disentangled_v2,,,sa=10; rca=6; d=1024; L=24; rca_type=disentang...


In [228]:
model_name_map = {
    'sa=16; d=1024; L=24__2024_05_15_16_38_09': '$\\text{Transformer}\\ (n_h^{sa}=16, n_h^{ra}=0)$',
    'sa=10; rca=6; d=1024; L=24; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_relative__2024_05_15_18_13_54': f'$\\text{{{our_model_name}}}\\ (n_h^{{sa}}=10, n_h^{{ra}}=6)$'
    }

figure_data['Model'] = pd.Categorical(figure_data['name'].map(model_name_map), model_name_map.values(), ordered=True)

color_map_ = {
    '$\\text{Transformer}\\ (n_h^{sa}=16, n_h^{ra}=0)$': cmap(6), # red
    f'$\\text{{{our_model_name}}}\\ (n_h^{{sa}}=10, n_h^{{ra}}=6)$': cmap(8), # purple
    f'$\\text{{{our_model_name}}}\\ (n_h^{{sa}}=10, n_h^{{ra}}=6)$': cmap(0), # blue
    }

color_map = {k: convert_to_plotly_color(color_map_[k]) for k in color_map_}

In [229]:
df_ = figure_data[['epoch', 'train/acc_epoch', 'Model']].dropna().pivot(index='epoch', columns='Model', values='train/acc_epoch')
trainacc_diff = abs(df_.iloc[:,0] - df_.iloc[:,1])
print(f'train/acc mean difference: {trainacc_diff.mean():.2%}')
print(f'train/acc max difference: {trainacc_diff.max():.2%}')
print(f'train/acc end difference: {abs(max(df_.iloc[:,0]) - max(df_.iloc[:,1])):.2%}')
print()
df_ = figure_data[['epoch', 'val/acc', 'Model']].dropna().pivot(index='epoch', columns='Model', values='val/acc')
valacc_diff = abs(df_.iloc[:,0] - df_.iloc[:,1])
print(f'val/acc mean difference: {valacc_diff.mean():.2%}')
print(f'val/acc max difference: {valacc_diff.max():.2%}')
print(f'val/acc end difference: {abs(max(df_.iloc[:,0]) - max(df_.iloc[:,1])):.2%}')

train/acc mean difference: 5.01%
train/acc max difference: 10.08%
train/acc end difference: 2.89%

val/acc mean difference: 4.39%
val/acc max difference: 9.98%
val/acc end difference: 1.46%


In [230]:
figure_data.rename(columns={'train/acc_epoch': 'Training Accuracy', 'val/acc': 'Validation Accuracy',
    'val/loss': 'Validation Loss', 'train/loss_epoch': 'Training Loss', 'epoch': 'Epoch'}, inplace=True)

In [238]:
import plotly.subplots as sp
import plotly.graph_objects as go

# Assuming figure_data is your DataFrame with columns 'Epoch', 'Training Accuracy', 'Validation Accuracy', 'Model'
# and color_map_ is a dictionary mapping 'Model' values to colors

# Get unique models
models = figure_data['Model'].unique()

# Create subplots
fig = sp.make_subplots(rows=1, cols=2, subplot_titles=("Training Accuracy", "Validation Accuracy"), horizontal_spacing=0.1)

# Add line plots to the subplots
for model in models:
    model_data = figure_data[figure_data['Model'] == model]
    model_data_tr = model_data[['Epoch', 'Training Accuracy']].dropna()
    model_data_val = model_data[['Epoch', 'Validation Accuracy']].dropna()
    fig.add_trace(go.Scatter(
        x=model_data_tr['Epoch'], y=model_data_tr['Training Accuracy'],
        mode='lines', line=dict(color=color_map[model]), name=model, showlegend=False
        ), 
        row=1, col=1)
    fig.add_trace(go.Scatter(
        x=model_data_val['Epoch'], y=model_data_val['Validation Accuracy'],
        mode='lines', line=dict(color=color_map[model]), 
        name=model), 
        row=1, col=2)

# Update layout
fig.update_layout(
    height=600, width=1000, title_text="Image Classification Training Curves",
    template=theme,
    font_family="Computer Modern",
    legend_title="Model",
    legend=dict(yanchor="bottom", y=0.01, xanchor="right", x=0.99,
#                     title_font_family="Computer Modern", #bgcolor='LightSteelBlue',
                    bordercolor="Black", borderwidth=1),)

fig.update_xaxes(title_text="Epoch", row=1, col=1)
# fig.update_yaxes(title_text="Training Accuracy", row=1, col=1)
fig.update_xaxes(title_text="Epoch", row=1, col=2)
# fig.update_yaxes(title_text="Validation Accuracy", row=1, col=2)


# Show the figure
fig.show()

In [203]:
fig_json = fig.to_json()
# save fig_json to file
with open(f'{save_dir}/imagenet_training_curves.json', 'w') as f:
    f.write(fig_json)