# Generating figures for paper

In [1]:
import pandas as pd
import numpy as np
import scipy
import seaborn as sns
import plotly.express as px
import matplotlib.pyplot as plt
import plotly.io

save_dir = 'plotly_figs'
theme = 'plotly_white'

In [2]:
# our_model_name = 'Dual-Attn Transformer'# 'Orthrus' # 'AbstractTransformer' # Orthrus
our_model_name = 'DAT'
transformer_name = 'Transformer'

In [3]:
cmap = plt.cm.tab20

In [4]:
def convert_to_plotly_color(color):
    r, g, b, a = color
    r, g, b, a = int(r*255), int(g*255), int(b*255), a
    # return f'rgba({r}, {g}, {b}, {a})'
    return f'rgb({r}, {g}, {b})'

## Relational Games

In [5]:
our_model_name = r'$\textit{DAT}$' # 'Dual-Attn Transformer'# 'Orthrus' # 'AbstractTransformer' # Orthrus
transformer_name = 'Transformer'

In [6]:
# relgames_data = pd.read_csv('figure_data/relgames/relgames_data.csv')

relgames_data1 = pd.read_csv('figure_data/relgames/relgames_data.csv')
relgames_data2 = pd.read_csv('figure_data_tmp/relgames/relgames_data.csv')
relgames_data = pd.concat([relgames_data1, relgames_data2], ignore_index=True)

def process_groupname(group_name):
    task, model_name = group_name.split('__')
    return model_name


# L, total_n_heads = 2, 2
# filter_ = (relgames_data['n_layers'] == L) & ((relgames_data['n_heads_rca'] + relgames_data['n_heads_sa'] == total_n_heads) | (relgames_data['n_heads'] == total_n_heads)) & (relgames_data['train_size'] <= 25_000)
# figure_data = relgames_data[filter_]
filter_ = (relgames_data['task'] != '1task_match_patt') | (relgames_data['train_size'] <= 25_000) & (relgames_data['train_size'] > 2500)
figure_data = relgames_data[filter_]

figure_data['Model'] = figure_data['group'].apply(process_groupname)

figure_data.rename(columns={'train_size': 'Training Set Size', 'test/acc_in_distribution': 'Generalization Accuracy', 'task': 'Task'}, inplace=True)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  figure_data['Model'] = figure_data['group'].apply(process_groupname)
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  figure_data.rename(columns={'train_size': 'Training Set Size', 'test/acc_in_distribution': 'Generalization Accuracy', 'task': 'Task'}, inplace=True)


In [7]:
figure_data.Model.value_counts()

sa=2; d=128; L=2                                                                                             746
sa=1; rca=1; d=128; L=2; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_sym_retriever               491
sa=0; rca=2; d=128; L=2; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_sym_retriever               484
sa=0; ra=8; nr=16; d=128; L=2; ra_type=relational_attention; sym_rel=True; symbol_type=positional_symbols    463
sa=8; d=128; L=2                                                                                             296
sa=4; d=128; L=2                                                                                             275
sa=4; d=144; L=2                                                                                             247
sa=8; d=144; L=2                                                                                             245
sa=1; rca=1; d=128; L=2; rca_type=standard; sym_rel=False; symbol_type=pos_sym_retriever        

In [8]:
def parse_rel_symmetry(group_name):
    if 'sym_rel=True' in group_name:
        return True
    elif 'sym_rel=False' in group_name:
        return False
    else:
        return None

figure_data.loc[:, 'Symmetric RA'] = figure_data['group'].apply(parse_rel_symmetry)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  figure_data.loc[:, 'Symmetric RA'] = figure_data['group'].apply(parse_rel_symmetry)


In [9]:
def get_name(row):

    if row.baseline_model in ['abstractor', 'predinet', 'corelnet']:
        name_map = {'abstractor': 'Abstractor', 'predinet': 'PrediNet', 'corelnet': 'CoRelNet'} # FIXME (swap to CoRelNetSoftmax)
        name = name_map[row.baseline_model]
    elif not np.isnan(row.n_heads):
        name = f'{transformer_name} ($n_h^{{sa}}={int(row.n_heads)}, n_h^{{ra}}=0$)'
    else:
        if row.rca_type == 'standard':
            name = "Abstractor's RCA"
        elif row.rca_type == 'disentangled_v2':
            name = our_model_name

        if not row['Symmetric RA']:
            name += ' [Asymmetric]'

        name += f' ($n_h^{{sa}}={int(row.n_heads_sa)}, n_h^{{ra}}={int(row.n_heads_rca)}$)'

    name += f' [{int(row.num_params//1000)}K]'
    return name


In [10]:
models = [
    'sa=2; d=128; L=2',
    # 'sa=2; d=128; L=2',
    'sa=2; d=144; L=2',
    # 'sa=4; d=128; L=2',
    # 'sa=4; d=144; L=2',
    'sa=8; d=128; L=2',
    'sa=8; d=144; L=2',
    'sa=1; rca=1; d=128; L=2; rca_type=standard; sym_rel=False; symbol_type=pos_sym_retriever',
    'sa=0; rca=2; d=128; L=2; rca_type=standard; sym_rel=False; symbol_type=pos_sym_retriever',
    'sa=0; rca=2; d=128; L=2; rca_type=disentangled_v2; sym_rel=False; symbol_type=pos_sym_retriever',
    'sa=1; rca=1; d=128; L=2; rca_type=disentangled_v2; sym_rel=False; symbol_type=pos_sym_retriever',
    'sa=1; rca=1; d=128; L=2; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_sym_retriever',
    'sa=0; rca=2; d=128; L=2; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_sym_retriever',
    'sa=0; rca=4; d=128; L=2; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_sym_retriever',
    'sa=2; rca=2; d=128; L=2; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_sym_retriever',
    'sa=4; rca=4; d=128; L=2; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_sym_retriever',
    'sa=4; rca=4; d=128; L=1; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_sym_retriever',
    'sa=0; rca=8; d=128; L=1; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_sym_retriever',
    'sa=0; rca=8; d=128; L=2; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_sym_retriever',
    # 'sa=1; ra=1; d=128; L=2; ra_type=rel_attn; sym_rel=True; symbol_type=positional_symbols',
    # 'sa=0; ra=2; d=128; L=2; ra_type=rel_attn; sym_rel=True; symbol_type=positional_symbols',
    # 'sa=1; ra=1; d=128; L=2; ra_type=rel_attn; sym_rel=False; symbol_type=positional_symbols',
    # 'sa=0; ra=2; d=128; L=2; ra_type=rel_attn; sym_rel=False; symbol_type=positional_symbols',
    # 'sa=1; ra=1; d=128; L=2; ra_type=rca; sym_rel=False; symbol_type=positional_symbols',
    # 'sa=0; ra=2; d=128; L=2; ra_type=rca; sym_rel=False; symbol_type=positional_symbols',
    'corelnet',
    'predinet',
    'abstractor',
    ]
# model_name_map = {
#     'sa=2; d=128; L=2': 'Transformer ($n_h^{sa}=2, n_h^{ra}=0$)',
#     'sa=1; rca=1; d=128; L=2; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_sym_retriever': f'{our_model_name} ($n_h^{{sa}}=1, n_h^{{ra}}=1$)',
#     'sa=0; rca=2; d=128; L=2; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_sym_retriever': f'{our_model_name} ($n_h^{{sa}}=0, n_h^{{ra}}=2$)',
#     'sa=1; rca=1; d=128; L=2; rca_type=disentangled_v2; sym_rel=False; symbol_type=pos_sym_retriever': f'{our_model_name} [asymmetric] ($n_h^{{sa}}=1, n_h^{{ra}}=1$)',
#     'sa=0; rca=2; d=128; L=2; rca_type=disentangled_v2; sym_rel=False; symbol_type=pos_sym_retriever': f'{our_model_name} [asymmetric] ($n_h^{{sa}}=0, n_h^{{ra}}=2$)',
#     'sa=1; rca=1; d=128; L=2; rca_type=standard; sym_rel=False; symbol_type=pos_sym_retriever': "Abstractor's RCA ($n_h^{{sa}}=1, n_h^{{ra}}=1$)",
#     'sa=0; rca=2; d=128; L=2; rca_type=standard; sym_rel=False; symbol_type=pos_sym_retriever': "Abstractor's RCA ($n_h^{{sa}}=0, n_h^{{ra}}=2$)",
#     }
# models = [model_name_map[m] for m in models]

tasks = ['same', 'occurs', 'xoccurs', '1task_between', '1task_match_patt']
task_name_map = {
    '1task_between': r'$\texttt{between}$', '1task_match_patt': r'$\texttt{match pattern}$',
    'same': r'$\texttt{same}$', 'occurs': r'$\texttt{occurs}$', 'xoccurs': r'$\texttt{xoccurs}$'}
tasks = [task_name_map[t] for t in tasks]

figure_data = figure_data[figure_data.model_name.isin(models)]

# figure_data['Model'] = pd.Categorical(figure_data['Model'].map(model_name_map), models, ordered=True)

figure_data['Task'] = pd.Categorical(figure_data['Task'].map(task_name_map), tasks, ordered=True)

In [11]:

# bar plot figure
# color_map_ = {
#     f'{our_model_name} ($n_h^{{sa}}=1, n_h^{{ra}}=1$)': cmap(8), # purple
#     f'{our_model_name} ($n_h^{{sa}}=0, n_h^{{ra}}=2$)': cmap(0), # blue
#     f'{our_model_name} [asymmetric] ($n_h^{{sa}}=1, n_h^{{ra}}=1$)': cmap(8), # purple
#     f'{our_model_name} [asymmetric] ($n_h^{{sa}}=0, n_h^{{ra}}=2$)': cmap(0), # blue
#     'Transformer ($n_h^{sa}=2, n_h^{ra}=0$)': cmap(6), # red
#     "Abstractor's RCA ($n_h^{{sa}}=1, n_h^{{ra}}=1$)": cmap(8),
#     "Abstractor's RCA ($n_h^{{sa}}=0, n_h^{{ra}}=2$)": cmap(0),
#     # 'Transformer+': cmap(4),
#     }

cmap = sns.color_palette('tab20c', as_cmap=True)
color_map_ = {
    f'{transformer_name} ($n_h^{{sa}}=8, n_h^{{ra}}=0$) [481K]': cmap(4),
    # f'{transformer_name} ($n_h^{{sa}}=4, n_h^{{ra}}=0$) [481K]': cmap(),
    f'{transformer_name} ($n_h^{{sa}}=2, n_h^{{ra}}=0$) [481K]': cmap(5),
    f'{transformer_name} ($n_h^{{sa}}=8, n_h^{{ra}}=0$) [386K]': cmap(6),
    f'{transformer_name} ($n_h^{{sa}}=2, n_h^{{ra}}=0$) [386K]': cmap(7),
    # f'{transformer_name} ($n_h^{{sa}}=4, n_h^{{ra}}=0$) [386K]': cmap(),

    "Abstractor's RCA [Asymmetric] ($n_h^{sa}=1, n_h^{ra}=1$) [371K]": cmap(12),
    "Abstractor's RCA [Asymmetric] ($n_h^{sa}=0, n_h^{ra}=2$) [388K]": cmap(0),

    f'{our_model_name} [Asymmetric] ($n_h^{{sa}}=0, n_h^{{ra}}=2$) [454K]': cmap(0),
    f'{our_model_name} [Asymmetric] ($n_h^{{sa}}=1, n_h^{{ra}}=1$) [437K]': cmap(12),
    f'{our_model_name} ($n_h^{{sa}}=1, n_h^{{ra}}=1$) [404K]': cmap(12),
    f'{our_model_name} ($n_h^{{sa}}=0, n_h^{{ra}}=2$) [421K]': cmap(0),
    # 'Dual-Attn Transformer [Asymmetric] ($n_h^{sa}=0, n_h^{ra}=2$) [454K]': cmap(0),
    # 'Dual-Attn Transformer [Asymmetric] ($n_h^{sa}=1, n_h^{ra}=1$) [437K]': cmap(12),
    # 'Dual-Attn Transformer ($n_h^{sa}=1, n_h^{ra}=1$) [404K]': cmap(12),
    # 'Dual-Attn Transformer ($n_h^{sa}=0, n_h^{ra}=2$) [421K]': cmap(0),
    # 'Dual-Attn Transformer [Asymmetric] ($n_h^{sa}=0, n_h^{ra}=4$) [421K]': cmap(),
    # 'Dual-Attn Transformer [Asymmetric] ($n_h^{sa}=2, n_h^{ra}=2$) [404K]': cmap(),
    # 'Dual-Attn Transformer [Asymmetric] ($n_h^{sa}=4, n_h^{ra}=4$) [405K]': cmap(),
    # 'Dual-Attn Transformer [Asymmetric] ($n_h^{sa}=4, n_h^{ra}=4$) [232K]': cmap(),
    # 'Dual-Attn Transformer [Asymmetric] ($n_h^{sa}=0, n_h^{ra}=8$) [241K]': cmap(),
    # 'Dual-Attn Transformer [Asymmetric] ($n_h^{sa}=0, n_h^{ra}=8$) [423K]': cmap()
    'Abstractor [469K]': cmap(8),
    'CoRelNet [215K]': cmap(16),
    'PrediNet [376K]': cmap(12),
}

color_map = {k: convert_to_plotly_color(v) for k, v in color_map_.items()}

model_order = [
    f'{our_model_name} ($n_h^{{sa}}=0, n_h^{{ra}}=2$) [421K]',
    f'{our_model_name} ($n_h^{{sa}}=1, n_h^{{ra}}=1$) [404K]',
    f'{our_model_name} [Asymmetric] ($n_h^{{sa}}=0, n_h^{{ra}}=2$) [454K]',
    f'{our_model_name} [Asymmetric] ($n_h^{{sa}}=1, n_h^{{ra}}=1$) [437K]',
    "Abstractor's RCA [Asymmetric] ($n_h^{sa}=1, n_h^{ra}=1$) [371K]",
    "Abstractor's RCA [Asymmetric] ($n_h^{sa}=0, n_h^{ra}=2$) [388K]",
    f'{transformer_name} ($n_h^{{sa}}=8, n_h^{{ra}}=0$) [481K]',
    f'{transformer_name} ($n_h^{{sa}}=2, n_h^{{ra}}=0$) [481K]',
    f'{transformer_name} ($n_h^{{sa}}=8, n_h^{{ra}}=0$) [386K]',
    f'{transformer_name} ($n_h^{{sa}}=2, n_h^{{ra}}=0$) [386K]',
    'PrediNet [376K]',
    'CoRelNet [215K]',
    'Abstractor [469K]',
    ]

figure_data['Model'] = pd.Categorical(figure_data.apply(get_name, axis=1), model_order, ordered=True)

In [12]:
list(figure_data.Model.unique())

['Transformer ($n_h^{sa}=8, n_h^{ra}=0$) [481K]',
 'Transformer ($n_h^{sa}=2, n_h^{ra}=0$) [481K]',
 'Transformer ($n_h^{sa}=8, n_h^{ra}=0$) [386K]',
 'Transformer ($n_h^{sa}=2, n_h^{ra}=0$) [386K]',
 "Abstractor's RCA [Asymmetric] ($n_h^{sa}=1, n_h^{ra}=1$) [371K]",
 "Abstractor's RCA [Asymmetric] ($n_h^{sa}=0, n_h^{ra}=2$) [388K]",
 '$\\textit{DAT}$ [Asymmetric] ($n_h^{sa}=0, n_h^{ra}=2$) [454K]',
 '$\\textit{DAT}$ [Asymmetric] ($n_h^{sa}=1, n_h^{ra}=1$) [437K]',
 '$\\textit{DAT}$ ($n_h^{sa}=1, n_h^{ra}=1$) [404K]',
 '$\\textit{DAT}$ ($n_h^{sa}=0, n_h^{ra}=2$) [421K]',
 nan,
 'Abstractor [469K]',
 'PrediNet [376K]',
 'CoRelNet [215K]']

In [13]:
sns.color_palette('muted')

In [14]:
metric = 'Generalization Accuracy'
model_filter = [f'{our_model_name} ($n_h^{{sa}}=0, n_h^{{ra}}=2$) [421K]', f'{our_model_name} ($n_h^{{sa}}=1, n_h^{{ra}}=1$) [404K]', f'{transformer_name} ($n_h^{{sa}}=2, n_h^{{ra}}=0$) [481K]']

cmap_ = sns.color_palette('muted')
cmap = lambda x: (*cmap_[x], 1)

color_map = {
    f'{our_model_name} ($n_h^{{sa}}=0, n_h^{{ra}}=2$) [421K]': convert_to_plotly_color(cmap(0)),
    f'{our_model_name} ($n_h^{{sa}}=1, n_h^{{ra}}=1$) [404K]': convert_to_plotly_color(cmap(4)),
    f'{transformer_name} ($n_h^{{sa}}=2, n_h^{{ra}}=0$) [481K]': convert_to_plotly_color(cmap(3)),
}

# need everything in one latex expression
name_map = {
    f'{our_model_name} ($n_h^{{sa}}=0, n_h^{{ra}}=2$) [421K]': r'$\textit{DAT}\ (n_h^{sa}=0, n_h^{ra}=2)$',
    f'{our_model_name} ($n_h^{{sa}}=1, n_h^{{ra}}=1$) [404K]': r'$\textit{DAT}\ (n_h^{sa}=1, n_h^{ra}=1)$',
    f'{transformer_name} ($n_h^{{sa}}=2, n_h^{{ra}}=0$) [481K]': r'$\mathrm{Transformer}\ (n_h^{sa}=2, n_h^{ra}=0)$',
}
figure_data_ = figure_data[figure_data['Model'].isin(model_filter)]
figure_data_['Model'] = pd.Categorical(figure_data_['Model'], model_filter, ordered=True)
figure_data_['Model'] = figure_data_['Model'].map(name_map)
# figure_data_['Model'] = figure_data_['Model'].apply(lambda x: x.split(' [')[0])
color_map = {name_map[k]: v for k, v in color_map.items()}
# color_map = {k.split(' [')[0]: v for k, v in color_map.items()}

# figure_data_ = figure_data[figure_data['Symmetric RA']!=False]
figure_data_ = figure_data_.groupby(['Model', 'Task', 'Training Set Size'])[metric].aggregate(['mean', 'std', 'count', 'sem']).reset_index()


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  figure_data_['Model'] = pd.Categorical(figure_data_['Model'], model_filter, ordered=True)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  figure_data_['Model'] = figure_data_['Model'].map(name_map)


In [15]:
fig = px.line(figure_data_,
    x='Training Set Size', y='mean', color='Model', error_y='sem', facet_col='Task',
    facet_col_wrap=3, color_discrete_map=color_map, facet_col_spacing=0.1, facet_row_spacing=0.3,
    title='Relational Games: Data Efficiency in Relational Learning',
    template=theme)
    # width=1200, height=600, 

fig.update_xaxes(matches=None, showticklabels=True, title_text='Training Set Size')
fig.update_yaxes(title_text='Accuracy', matches=None, showticklabels=True)

for axis in fig.layout:
    if axis.startswith('yaxis') or axis.startswith('xaxis'):
        fig.layout[axis].showticklabels = True

fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[1]))

fig.update_xaxes(range=[100, 3000], tickvals=[500, 1e3, 1.5e3, 2e3, 2.5e3])
fig.update_xaxes(range=[2000, 30000], tickvals=[5e3, 10e3, 15e3, 20e3, 25e3], row=1, col=2)

fig.update_layout(
    legend=dict(
        x=0.85,  # Horizontal position (0 to 1)
        y=0.35,  # Vertical position (0 to 1)
        xanchor='center',  # Anchor point for x position
        yanchor='top',  # Anchor point for y position
        orientation='v',  # Horizontal orientation
        title=dict(
            text='Model',  # Title for the legend
            side='top'  # Position the title at the top and center it
        )
    )
)

fig.show()

In [16]:
# save to html
fig_html = plotly.offline.plot(fig, include_plotlyjs=False, output_type='div')
with open(f'{save_dir}/relgames_learning_curves.html', 'w') as f:
    f.write(fig_html)

fig_json = fig.to_json()
# save fig_json to file
with open(f'{save_dir}/relgames_learning_curves.json', 'w') as f:
    f.write(fig_json)

## Math

In [17]:
our_model_name = r'$\textit{DAT}$'
transformer_name = r'$\mathrm{Transformer}$'

figure_data = pd.read_csv('figure_data/math/run_history_all.csv')


Columns (16,24) have mixed types. Specify dtype option on import or set low_memory=False.



In [18]:
figure_data['task'] = figure_data.task.map(lambda x: {'algebra__linear_1': 'algebra__linear_1d'}.get(x, x))

In [19]:
import re
def parse_model_name(group_name):
    # format: enc_sa=8; enc_ra=0; dec_sa=8; dec_ra=0; dec_cross=8; d=128; ra_type=NA, symbol_type=NA; el=2; dl=2
    # extract enc_sa, enc_ra, d, el, etc.
    regex_exprs = [r'enc_sa=(\d+)', r'enc_ra=(\d+)', r'dec_sa=(\d+)', r'dec_ra=(\d+)', r'dec_cross=(\d+)', r'd=(\d+)', r'ra_type=(\w+)', r'symbol_type=(\w+)', r'el=(\d+)', r'dl=(\d+)']
    inferred_config_vals = []

    for expr in regex_exprs:
        match = re.search(expr, group_name)
        if match:
            inferred_config_vals.append(match.group(1))
        else:
            inferred_config_vals.append(None)

    return tuple(inferred_config_vals)

figure_data['enc_sa'], figure_data['enc_ra'], figure_data['dec_sa'], figure_data['dec_ra'], figure_data['dec_cross'], figure_data['d'], figure_data['ra_type'], figure_data['symbol_type'], figure_data['el'], figure_data['dl'] = zip(*figure_data['group'].apply(parse_model_name))

In [20]:
models = {
    'enc_sa=8; enc_ra=0; dec_sa=8; dec_ra=0; dec_cross=8; d=128; ra_type=NA, symbol_type=NA; el=2; dl=2': f'{transformer_name} [692K]',
    # 'ee=8; ea=0; de=8; da=0; dc=8; el=2; dl=2': 'Transformer',
    'enc_sa=8; enc_ra=0; dec_sa=8; dec_ra=0; dec_cross=8; d=144; ra_type=NA, symbol_type=NA; el=2; dl=2': f'{transformer_name} [871K]',
    'enc_sa=8; enc_ra=0; dec_sa=8; dec_ra=0; d_cross=8; d=144; el=3; dl=3': f'{transformer_name} [1.3M]',
    'enc_sa=8; enc_ra=0; dec_sa=8; dec_ra=0; d_cross=8; d=144; el=4; dl=4': f'{transformer_name} [1.7M]',
    'enc_sa=4; enc_ra=4; dec_sa=8; dec_ra=0; dec_cross=8; d=128; ra_type=rel_attn, symbol_type=position_relative; el=2; dl=2': f'{our_model_name} [783K]',
    'enc_sa=4; enc_ra=4; dec_sa=4; dec_ra=4; dec_cross=8; d=128; ra_type=rel_attn, symbol_type=position_relative; el=2; dl=2': f'{our_model_name} [832K]',
    'enc_sa=4; enc_ra=4; dec_sa=4; dec_ra=4; d_cross=8; d=128; el=4; dl=4': f'{our_model_name} [1.46M]',
    'enc_sa=4; enc_ra=4; dec_sa=8; dec_ra=0; d_cross=8; d=128; el=4; dl=4': f'{our_model_name} [1.43M]',
    'enc_sa=4; enc_ra=4; dec_sa=4; dec_ra=4; d_cross=8; d=128; el=3; dl=3': f'{our_model_name} [1.11M]',
    'enc_sa=4; enc_ra=4; dec_sa=8; dec_ra=0; d_cross=8; d=128; el=3; dl=3': f'{our_model_name} [1.09M]',
    # 'Abstractor - L=1, d=128, h=8': 'Abstractor [816K]',
    # 'Abstractor - L=2, d=128, h=8': 'Abstractor [1.54M]',
    # 'e_sa=4; e_rca=4; d_sa=8; d_rca=0; d_cross=8; rca_dis=True, el=2; dl=2': 'AbstractTransformer (v1; OGRCA)',
    # 'e_sa=4; e_rca=4; d_sa=4; d_rca=4; d_cross=8; rca_dis=True, el=2; dl=2': 'AbstractTransformer (v2; OGRCA)',
    }

In [21]:
figure_data.rename(columns={
    'epoch': 'Epoch', 'interpolate_teacher_forcing_acc': 'Accuracy',
    'train_teacher_forcing_acc': 'Accuracy (Training)', 'task': 'Task', 'group': 'Model'}, inplace=True)

def format_task(task):
    string = task.replace('_', r'\_')
    string = r"$\mathtt{" + string + r"}$"
    return string
figure_data['Task'] = pd.Categorical(figure_data['Task'].map(format_task))
figure_data = figure_data[figure_data['Model'].isin(models.keys())]
figure_data['Model'] = pd.Categorical(figure_data['Model'].map(models), models.values(), ordered=True)

In [22]:
cmap = sns.color_palette('tab20c', as_cmap=True)
color_map_ = {
    f'{transformer_name} [692K]': cmap(5), # lighter red
    f'{our_model_name} [783K]': cmap(2),
    # 'Abstractor [816K]': cmap(9),
    f'{transformer_name} [871K]': cmap(6), # red
    f'{our_model_name} [1.09M]': cmap(1),
    f'{transformer_name} [1.3M]': cmap(5),
    f'{our_model_name} [1.43M]': cmap(0),
    # 'Abstractor [1.54M]': cmap(8),
    f'{transformer_name} [1.7M]': cmap(4),
    # f'{our_model_name} [832K]': cmap(2),
    # f'{our_model_name} [1.11M]': cmap(1),
    # f'{our_model_name} [1.46M]': cmap(0),
    # 'Transformer': cmap(6), # red
    # 'Transformer+': cmap(4),
    }
color_map = {k: convert_to_plotly_color(color_map_[k]) for k in color_map_}

model_selection = list(color_map_.keys())

model_filter = figure_data['Model'].isin(model_selection)
# figure_data = figure_data[model_filter]

figure_data.Model.value_counts()

$\textit{DAT}$ [783K]            5775
$\textit{DAT}$ [832K]            5730
$\mathrm{Transformer}$ [1.3M]    4560
$\mathrm{Transformer}$ [1.7M]    4560
$\mathrm{Transformer}$ [871K]    3002
$\textit{DAT}$ [1.11M]           2960
$\textit{DAT}$ [1.09M]           2959
$\textit{DAT}$ [1.43M]           2873
$\textit{DAT}$ [1.46M]           2739
$\mathrm{Transformer}$ [692K]    2400
Name: Model, dtype: int64

In [23]:
figure_data[r'$d_{\text{model}}$'] = figure_data['d']
figure_data[r'Encoder $n_h^{sa}$'] = figure_data['enc_sa']
figure_data[r'Encoder $n_h^{ra}$'] = figure_data['enc_ra']
figure_data[r'Decoder $n_h^{sa}$'] = figure_data['dec_sa']
figure_data[r'Decoder $n_h^{ra}$'] = figure_data['dec_ra']
figure_data['# Layers'] = figure_data['n_layers_dec']
config_cols = ['# Layers', r'$d_{\text{model}}$', 'Encoder $n_h^{sa}$', 'Encoder $n_h^{ra}$', 'Decoder $n_h^{sa}$', 'Decoder $n_h^{ra}$']

depth_comparison_table = figure_data[figure_data.Epoch == figure_data.Epoch.max()][['Task', 'Model'] + config_cols + ['Accuracy']].dropna()

model_remove = [f'{our_model_name} [783K]', f'{our_model_name} [1.43M]', f'{our_model_name} [1.09M]']
depth_comparison_table = depth_comparison_table[~depth_comparison_table['Model'].isin(model_remove)]

# depth_comparison_table['Model'] = pd.Categorical(depth_comparison_table['Model'], model_selection, ordered=True)

depth_comparison_table['Parameter Count'] = depth_comparison_table.Model.apply(lambda x: x.split('[')[1][:-1].strip())
def parse_param_count(param_ct_str):
    number = float(param_ct_str[:-1])
    unit = param_ct_str[-1]
    unit_map = dict(K=1e3, M=1e6)
    return int(number * unit_map[unit])
depth_comparison_table['Parameter Count'] = depth_comparison_table['Parameter Count'].apply(parse_param_count).astype(int)


depth_comparison_table['Model'] = depth_comparison_table['Model'].apply(lambda x: x.split('[')[0].strip())

# depth_comparison_table.groupby(['Task', 'Model']).aggregate(['mean', 'sem'])
depth_comparison_table = depth_comparison_table.groupby(['Task', 'Model', 'Parameter Count'] + config_cols).aggregate(['mean', 'sem', 'count'])
depth_comparison_table[[('Accuracy', 'mean'), ('Accuracy', 'sem')]] = depth_comparison_table[[('Accuracy', 'mean'), ('Accuracy', 'sem')]] * 100 # convert to percentage
depth_comparison_table = depth_comparison_table.dropna()
depth_comparison_table = depth_comparison_table.reset_index()
depth_comparison_table.columns = [''.join(col).strip() for col in depth_comparison_table.columns]

In [24]:
# figure_data['# Layers'] = figure_data['n_layers_dec']
depth_comparison_table = figure_data[figure_data.Epoch == figure_data.Epoch.max()][['Task', 'Model', 'Accuracy']].dropna()
# depth_comparison_table['Model'] = pd.Categorical(depth_comparison_table['Model'], model_selection, ordered=True)

# keepe only DAT models with E=4/4, D=4/4 not E=4/4, D=8/0
# model_remove = [f'{our_model_name} [783K]', f'{our_model_name} [1.43M]', f'{our_model_name} [1.09M]']
# depth_comparison_table = depth_comparison_table[~depth_comparison_table['Model'].isin(model_remove)]

depth_comparison_table['Parameter Count'] = depth_comparison_table.Model.apply(lambda x: x.split('[')[1][:-1].strip())
def parse_param_count(param_ct_str):
    number = float(param_ct_str[:-1])
    unit = param_ct_str[-1]
    unit_map = dict(K=1e3, M=1e6)
    return int(number * unit_map[unit])
depth_comparison_table['Parameter Count'] = depth_comparison_table['Parameter Count'].apply(parse_param_count).astype(int)

model_remove = [f'{our_model_name} [783K]', f'{our_model_name} [1.43M]', f'{our_model_name} [1.09M]']
depth_comparison_table = depth_comparison_table[~depth_comparison_table['Model'].isin(model_remove)]

depth_comparison_table['Model'] = depth_comparison_table['Model'].apply(lambda x: x.split('[')[0].strip())
print(depth_comparison_table.Model.unique())
agg_depth_comparison_table = depth_comparison_table.groupby(['Task', 'Model', 'Parameter Count'])['Accuracy'].aggregate(
    ['mean', 'sem', 'count']).dropna().reset_index()
agg_depth_comparison_table.sort_values(by=['Task', 'Model', 'Parameter Count'], inplace=True)
agg_depth_comparison_table

['$\\textit{DAT}$' '$\\mathrm{Transformer}$']


Unnamed: 0,Task,Model,Parameter Count,mean,sem,count
0,$\mathtt{algebra\_\_linear\_1d}$,$\mathrm{Transformer}$,692000,0.624775,0.01141305,4
1,$\mathtt{algebra\_\_linear\_1d}$,$\mathrm{Transformer}$,871000,0.63993,0.01471048,5
2,$\mathtt{algebra\_\_linear\_1d}$,$\mathrm{Transformer}$,1300000,0.569788,0.02300507,6
3,$\mathtt{algebra\_\_linear\_1d}$,$\mathrm{Transformer}$,1700000,0.532451,0.0106581,6
4,$\mathtt{algebra\_\_linear\_1d}$,$\textit{DAT}$,832000,0.66331,0.01651395,7
5,$\mathtt{algebra\_\_linear\_1d}$,$\textit{DAT}$,1110000,0.762244,0.03537042,4
6,$\mathtt{algebra\_\_linear\_1d}$,$\textit{DAT}$,1460000,0.753228,0.05024529,4
7,$\mathtt{algebra\_\_sequence\_next\_term}$,$\mathrm{Transformer}$,692000,0.910546,0.002026369,4
8,$\mathtt{algebra\_\_sequence\_next\_term}$,$\mathrm{Transformer}$,871000,0.914477,0.00242464,5
9,$\mathtt{algebra\_\_sequence\_next\_term}$,$\mathrm{Transformer}$,1300000,0.96099,0.005069484,8


In [25]:
cmap_ = sns.color_palette('muted')
cmap = lambda x: (*cmap_[x], 1)

In [26]:
# agg_depth_comparison_table = depth_comparison_table[['Task', 'Model', 'Parameter Count', 'Accuracy']].groupby(
    # ['Task', 'Model', 'Parameter Count']).aggregate(['mean', 'sem', 'count']).dropna().reset_index()

color_map = {
    transformer_name: convert_to_plotly_color(cmap(3)),
    our_model_name: convert_to_plotly_color(cmap(0)),
}

agg_depth_comparison_table['Model'] = pd.Categorical(agg_depth_comparison_table['Model'], [transformer_name, our_model_name], ordered=True)

fig = px.line(agg_depth_comparison_table, x='Parameter Count', y='mean', color='Model',
    facet_col='Task', facet_col_wrap=3, facet_row_spacing=0.3, facet_col_spacing=0.1, error_y='sem', hover_data=['count'],
    title='Mathematical Problem Solving', color_discrete_map=color_map,
    template=theme)

fig.update_yaxes(matches=None)
for axis in fig.layout:
    if axis.startswith('yaxis') or axis.startswith('xaxis'):
        fig.layout[axis].showticklabels = True

fig.update_xaxes(tickvals=[750e3, 1e6, 1.25e6, 1.75e6], title_text='Parameter Count')
fig.update_yaxes(title_text='Accuracy')


fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[1]))

fig.update_layout(
    legend=dict(
        x=0.85,  # Horizontal position (0 to 1)
        y=0.25,  # Vertical position (0 to 1)
        xanchor='center',  # Anchor point for x position
        yanchor='top',  # Anchor point for y position
        orientation='v',  # Horizontal orientation
        # title_text='Model'  # Title for the legend
        title=dict(
            text='Model',  # Title for the legend
            side='top'  # Position the title at the top and center it
        )
    )
)

fig.show()

In [27]:
# save to html
fig_html = plotly.offline.plot(fig, include_plotlyjs=False, output_type='div')
with open(f'{save_dir}/math_accuracy_scaling.html', 'w') as f:
    f.write(fig_html)

# save fig
fig_json = fig.to_json()
# save fig_json to file
with open(f'{save_dir}/math_accuracy_scaling.json', 'w') as f:
    f.write(fig_json)

## Language Modeling: Fineweb

In [28]:
figure_data = pd.read_csv('figure_data/fineweb/run_histories.csv')
figure_data.head()


Columns (500,502,503) have mixed types. Specify dtype option on import or set low_memory=False.



Unnamed: 0,grad_norms/blocks.23.ff_block.linear1.weight,grad_norms/blocks.21.norm2.bias,grad_norms/blocks.20.norm1.weight,grad_norms/blocks.4.dual_attn.self_attention.wk.weight,grad_norms/blocks.14.dual_attn.self_attention.wv.weight,grad_norms/blocks.17.ff_block.linear1.bias,grad_norms/blocks.16.dual_attn.self_attention.wk.weight,grad_norms/blocks.16.dual_attn.relational_attention.wq_attn.weight,grad_norms/blocks.4.dual_attn.relational_attention.wk_attn.weight,grad_norms/blocks.17.dual_attn.self_attention.wv.weight,...,grad_norms/symbol_retrievers.8.q_proj.weight,grad_norms/symbol_retrievers.15.q_proj.bias,grad_norms/symbol_retrievers.23.q_proj.weight,grad_norms/symbol_retrievers.22.q_proj.weight,grad_norms/symbol_retrievers.4.template_features,grad_norms/symbol_retrievers.20.q_proj.weight,grad_norms/symbol_retrievers.18.template_features,grad_norms/symbol_retriever.q_proj.weight,grad_norms/symbol_retriever.q_proj.bias,grad_norms/symbol_retriever.template_features
0,0.016342,0.000575,0.000398,0.001477,0.012219,0.00032,0.002339,0.002593,0.003096,0.010029,...,,,,,,,,,,
1,,,,,,,,,,,...,,,,,,,,,,
2,,,,,,,,,,,...,,,,,,,,,,
3,,,,,,,,,,,...,,,,,,,,,,
4,,,,,,,,,,,...,,,,,,,,,,


In [29]:
# NOTE: resumption of DAT-sa16-ra16-nr128-ns2048-sh16-nkvh8-1.27B_2024_07_31_08_52_58 is ongoing TODO: add when completed
name_map = {
    # 1.3B Scale
    'T-sa32-1.3B_2024_07_11_19_05_56': 'Transformer - 1.31B',
    'T-sa32-1.3B_2024_07_11_19_05_56_resumed_2024_07_26_18_41_46': 'Transformer - 1.31B',
    # 'DAT-sa16-ra16-nr64-ns512-sh16-nkvh8-1.27B_2024_07_24_09_41_58': f'{our_model_name} - 1.27B',
    'DAT-sa16-ra16-nr128-ns2048-sh16-nkvh8-1.27B_2024_07_31_08_52_58': f'{our_model_name} - 1.27B',
    'DAT-sa16-ra16-nr128-ns2048-sh16-nkvh8-1.27B_2024_07_31_08_52_58_resumed_2024_08_19_15_26_57': f'{our_model_name} - 1.27B',

    # 'DAT-sa16-ra16-nr64-ns2048-sh8-nkvh8-1.27B_2024_07_28_00_48_29': f'{our_model_name} - 1.27B',
    # 'DAT-sa16-ra16-nr64-ns2048-sh8-1.37B_2024_07_22_18_31_43': f'{our_model_name} - 1.37B',

    # 750M scale
    'T-sa24-757M_2024_08_22_19_01_41': 'Transformer - 757M',
    'DAT-sa12-ra12-nr64-ns1024-sh8-nkvh6-734M_2024_08_21_07_48_32': f'{our_model_name} - 734M',
    'DAT-sa12-ra12-nr64-ns1024-sh8-nkvh6-734M_2024_08_21_07_48_32_resumed_2024_08_23_16_14_12': f'{our_model_name} - 734M',

    # 350M scale
    'T-350M_2024_07_09_17_25_58': 'Transformer - 353M', # TODO check exact param count
    # 'DAT-sa8-ra8-ns1024-sh8-nkvh4-343M_2024_07_19_13_50_14': f'{our_model_name} - 343M',
    # 'DAT-sa8-ra8-ns1024-sh8-nkvh4-343M_2024_07_19_13_50_14_resumed_2024_07_26_18_49_04': f'{our_model_name} - 343M',
    'DAT-sa8-ra8-nr64-ns1024-sh8-nkvh4-343M_2024_07_30_13_58_00': f'{our_model_name} - 343M',
    'DAT-sa8-ra8-nr64-ns1024-sh8-nkvh4-343M_2024_07_30_13_58_00_resumed_2024_08_14_19_34_08': f'{our_model_name} - 343M',
    # 'DAT-ra8sa8nr32-ns1024sh8-368M_2024_07_15_18_38_39': f'{our_model_name} - 368M', # TODO: decide precisely which models want here
}

In [30]:
cmap = plt.cm.tab20

color_map_ = {
    'Transformer - 1.31B': cmap(6), # red
    f'{our_model_name} - 1.27B': cmap(0), # blue
    # f'{our_model_name} - 1.37B': cmap(8), # purple

    'Transformer - 757M': cmap(6), # lighter red
    f'{our_model_name} - 734M': cmap(0),

    'Transformer - 353M': cmap(6), # red
    f'{our_model_name} - 343M': cmap(0), # blue
    # f'{our_model_name} - 368M': cmap(8), # purple
    }
color_map = {k: convert_to_plotly_color(v) for k, v in color_map_.items()}

scale_map = {
    'Transformer - 1.31B': '1.3B Scale',
    f'{our_model_name} - 1.27B': '1.3B Scale',

    'Transformer - 757M': '750M Scale',
    f'{our_model_name} - 734M': '750M Scale',

    'Transformer - 353M': '350M Scale',
    f'{our_model_name} - 343M': '350M Scale',
    }


# models = color_map_.keys()

In [31]:
figure_data = figure_data[figure_data.name.isin(name_map.keys())]
figure_data = figure_data[['name', 'loss/val', 'tokens']].dropna()

figure_data = figure_data.rename(columns={
    # 'symbol_type': 'Symbol Type', 'symmetric_rels': 'Symmetric RA', 
    'loss/val': 'Validation Loss', 'tokens': 'Tokens'})

figure_data['Model'] = pd.Categorical(figure_data['name'].map(name_map), list(set(name_map.values())), ordered=True)

figure_data['Scale'] = figure_data['Model'].map(scale_map)

figure_data['Perplexity'] = figure_data['Validation Loss'].apply(np.exp)

# figure_data = figure_data[['Model', 'Tokens', 'Perplexity']].dropna()

In [32]:
import plotly.graph_objects as go

# Get unique tasks and models
model_scales = ['350M Scale', '750M Scale', '1.3B Scale']

metric = 'Perplexity'

# Create frames for each task
frames = []
for scale in model_scales:
    frame_data = []
    scale_filter = figure_data['Scale'] == scale
    ax_data = figure_data[scale_filter]

    for model in ax_data.Model.unique():
        model_filter = ax_data['Model'] == model
        model_data = ax_data[model_filter]
        model_data.sort_values('Tokens', inplace=True)

        frame_data.append(go.Scatter(
            x=model_data['Tokens'],
            y=model_data['Perplexity'],
            mode='lines',
            name=model,
            line=dict(color=color_map[model])
            ))
    yrange = (12.5, 30)

    frames.append(
        go.Frame(
        data=frame_data,
        name=f'{scale}',
        layout=dict(yaxis=dict(range=yrange))
        ))

# Create steps for the slider
steps = [dict(method='animate',
              args=[[frame['name']]],  # frame name to be shown
              label=frame['name']) for frame in frames]

# Create base frame
fig = go.Figure(
    data=frames[0]['data'],
    layout=go.Layout(
        title='Language Modeling Training Curves',
        xaxis=dict(title='Tokens'), #, type='log', range=(9, 10)),
        yaxis=dict(title=metric, range=(12.5, 30)),
        height=600,
        width=1000,
        sliders=[dict(steps=steps)],  # add the slider
                font_family="Computer Modern",
        template=theme,
        legend_title="Model",
        legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99,
                    title_font_family="Times New Roman", #bgcolor='LightSteelBlue',
                    bordercolor="Black", borderwidth=1),
    ),
    frames=frames
)

toggle_scale_button = dict(
    type="buttons",
    direction="left",
    buttons=list([
        dict(
            args=[{"xaxis.type": "linear", 'xaxis.range': [0, 10e9]}],
            label="Linear Scale",
            method="relayout"
        ),
        dict(
            args=[{"xaxis.type": "log", 'xaxis.range': [9, 10]}],
            label="Logarithmic Scale",
            method="relayout"
        )
    ]),
    pad={"r": 10, "t": 10},
    showactive=True,
    x=1,
    xanchor="right",
    y=1.2,
    yanchor="top"
)

# Add the toggle button to the layout
fig.update_layout(
    updatemenus=[toggle_scale_button]
)

fig.show()



A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



In [33]:
import plotly.offline

fig_html = plotly.offline.plot(fig, include_plotlyjs=False, output_type='div')
with open(f'{save_dir}/language_modeling_fineweb_training_curves.html', 'w') as f:
    f.write(fig_html)

In [34]:
def get_first_tokens_below_perp(line, perp):
    return min([row.Tokens for _, row in line.iterrows() if row.Perplexity < perp], default=np.nan)

def parse_param_count(param_ct_str):
    number = float(param_ct_str[:-1])
    unit = param_ct_str[-1]
    unit_map = dict(K=1e3, M=1e6, B=1e9)
    return int(number * unit_map[unit])

df = pd.DataFrame()

perplexities = np.linspace(15, 30, 6)

for model in figure_data.Model.unique():
    line = figure_data[figure_data.Model == model].sort_values('Tokens')
    tokens = [get_first_tokens_below_perp(line, p) for p in perplexities]
    df_ = pd.DataFrame({'Perplexity': perplexities, 'Tokens': tokens, #'Model': [model]*len(perplexities), 
        'Param Count': [parse_param_count(model.split('-')[1].strip())] * len(perplexities), 'Model': [model.split('-')[0]]*len(perplexities)})
    df = pd.concat([df, df_])
df = df.reset_index(drop=True)
df['Param Count'] = df['Param Count'].astype(float)
df['Perplexity'] = df['Perplexity'].astype(float)
df.sort_values('Param Count', inplace=True)

In [35]:
import plotly.express as px
import plotly.graph_objects as go

fig = px.line(df, x='Param Count', y='Tokens', color='Perplexity', line_dash='Model',
    # width=1200, height=600, 
    template=theme,
    color_discrete_sequence=px.colors.sequential.haline_r[3:],
    title='Language Modeling Scaling Laws')

fig.update_yaxes(type='log')

fig.update_xaxes(range=[300e6, 1.35e9], tickvals=[350e6, 750e6, 1.3e9], ticktext=['350M', '750M', '1.3B'], title_text='Parameter Count')
fig.update_yaxes(tickvals=[1e9, 2e9, 5e9, 10e9], ticktext=['1B', '2B', '5B', '10B'])

def compress_legend(fig):
   group1_base, group2_base  = fig.data[0].name.split(",")
   lines_marker_name = []
   for i, trace in enumerate(fig.data):
       part1,part2 = trace.name.split(',')
       if part1 == group1_base:
           lines_marker_name.append({"line": trace.line.to_plotly_json(), "marker": trace.marker.to_plotly_json(), "mode": trace.mode, "name": part2.lstrip(" ")})
       if part2 != group2_base:
           trace['name'] = ''
           trace['showlegend']=False
       else:
           trace['name'] = part1

   for lmn in lines_marker_name:
       lmn["line"]["color"] = "black"
       lmn["marker"]["color"] = "black"
       fig.add_trace(go.Scatter(y=[None], **lmn))
   fig.update_layout(legend_title_text='Perplexity', 
                     legend_itemclick=False,
                     legend_itemdoubleclick= False)

compress_legend(fig)

fig.show()

In [36]:
# save to html
fig_html = plotly.offline.plot(fig, include_plotlyjs=False, output_type='div')
with open(f'{save_dir}/language_modeling_scaling_laws.html', 'w') as f:
    f.write(fig_html)

# save json
fig_json = fig.to_json()
with open(f'{save_dir}/language_modeling_scaling_laws.json', 'w') as f:
    f.write(fig_json)

## Language Modeling: Tiny Stories

In [41]:
figure_data = pd.read_csv('figure_data/tiny_stories/run_histories.csv')
figure_data.head()

Unnamed: 0,val/perplexity,mfu,tokens,Generated Samples,lr,val/loss,train/loss,_timestamp,_step,train/perplexity,...,n_layers,wandb_log,weight_decay,vocab_source,out_dir,pos_enc_type,beta1,sym_attn_n_symbols,group,name
0,39238.144531,-100.0,0.0,,0.001,10.577385,10.578004,1716068000.0,0,39261.359375,...,6,True,0.1,llama2,out/sa=4; rca=4; d=128; L=6; rca_type=disentan...,RoPE,0.9,512.0,,sa=4; rca=4; d=128; L=6; rca_type=disentangled...
1,5.612392,3.820743,262144000.0,,0.001,1.724343,1.72705,1716069000.0,2000,5.626907,...,6,True,0.1,llama2,out/sa=4; rca=4; d=128; L=6; rca_type=disentan...,RoPE,0.9,512.0,,sa=4; rca=4; d=128; L=6; rca_type=disentangled...
2,4.960096,3.820697,524288000.0,,0.001,1.600834,1.604139,1716071000.0,4000,4.975849,...,6,True,0.1,llama2,out/sa=4; rca=4; d=128; L=6; rca_type=disentan...,RoPE,0.9,512.0,,sa=4; rca=4; d=128; L=6; rca_type=disentangled...
3,4.70429,3.820555,786432000.0,,0.001,1.547904,1.55054,1716072000.0,6000,4.716177,...,6,True,0.1,llama2,out/sa=4; rca=4; d=128; L=6; rca_type=disentan...,RoPE,0.9,512.0,,sa=4; rca=4; d=128; L=6; rca_type=disentangled...
4,4.540239,3.820479,1048576000.0,,0.001,1.512442,1.515324,1716074000.0,8000,4.552852,...,6,True,0.1,llama2,out/sa=4; rca=4; d=128; L=6; rca_type=disentan...,RoPE,0.9,512.0,,sa=4; rca=4; d=128; L=6; rca_type=disentangled...


In [42]:
# bar plot figure
color_map_ = {
    '$\\text{Transformer}\ (n_h^{sa}=8, n_h^{ra}=0)$': cmap(6), # red
    f'$\\text{{{our_model_name}}}\ (n_h^{{sa}}=6, n_h^{{ra}}=2)$': cmap(8), # purple
    f'$\\text{{{our_model_name}}}\ (n_h^{{sa}}=4, n_h^{{ra}}=4)$': cmap(0), # blue
    # 'Transformer+': cmap(4),
    }

color_map = {k: convert_to_plotly_color(color_map_[k]) for k in color_map_}
models = color_map_.keys()

In [43]:
def get_model_name(row):
    if row.rca == 0:
        return f'$\\text{{Transformer}}\ (n_h^{{sa}}={row.sa}, n_h^{{ra}}={row.rca})$'
    else:
        # return f"AbstractTransformer [{row['Symbol Type']}, symm={row['Symmetric RA']}] ($n_h^{{sa}} = {row.sa}, n_h^{{ra}}={row.rca}$)"
        return f"$\\text{{{our_model_name}}}\ (n_h^{{sa}}={row.sa}, n_h^{{ra}}={row.rca})$"

In [44]:
figure_data = figure_data.rename(columns={
    'symbol_type': 'Symbol Type', 'symmetric_rels': 'Symmetric RA', 
    'val/loss': 'Validation Loss', 'val/perplexity': 'Validation Perplexity', 'tokens': 'Tokens'})

figure_data.loc[figure_data['rca']==0, 'Symbol Type'] = 'NA'
figure_data.loc[figure_data['rca']==0, 'Symmetric RA'] = 'NA'

figure_data['Model'] = pd.Categorical(figure_data.apply(get_model_name, axis=1), models, ordered=True)
sym_map = {'sym_attn': 'Symbolic Attention', 'pos_relative': 'Position-Relative Symbols'}
figure_data['Symbol Type'] = pd.Categorical(figure_data['Symbol Type'].map(sym_map), sym_map.values(), ordered=True)

In [45]:
def filter_data(figure_data, d_models=None, layers=None, filter_first_step=False, filter_transformer=False, symbol_types=None, symmetry=None, rca_types=None):
    filter_ = ~figure_data.index.isna()
    if d_models is not None:
        filter_ = filter_ & (figure_data['d_model'].isin(d_models))
    if layers is not None:
        filter_ = filter_ & (figure_data['n_layers'].isin(layers))
    if filter_transformer:
        filter_ = filter_ & (figure_data['rca'] > 0)
    if filter_first_step:
        filter_ = filter_ & (figure_data['_step'] > 0)
    if symbol_types is not None:
        symbol_types = [sym_map[s] for s in symbol_types]
        filter_ = filter_ & ((figure_data['rca'] == 0) | figure_data['Symbol Type'].isin(symbol_types))
    if symmetry is not None:
        filter_ = filter_ & ((figure_data['rca'] == 0) | figure_data['Symmetric RA'].isin(symmetry))
    if rca_types is not None:
        filter_ = filter_ & ((figure_data['rca'] == 0) | (figure_data['rca_type'].isin(rca_types)))
    filtered_data = figure_data.copy()[filter_]

    if filter_transformer:
        filtered_data['Model'] = filtered_data['Model'].cat.remove_unused_categories()

    return filtered_data

### All Plots & Ablations

In [46]:
import plotly.graph_objects as go

# Get unique tasks and models
model_scales = [4, 5, 6]
d = 64
metric = 'Validation Loss'
# metric = 'Validation Perplexity'
# models = figure_data_['Model'].unique()

fig_data = filter_data(figure_data, d_models=[d], layers=model_scales, filter_first_step=True,
    symbol_types=('sym_attn',), symmetry=(False,), rca_types=('disentangled_v2',))
yrange_global = [fig_data[metric].min(), fig_data[metric].max()]

# Create frames for each task
frames = []
for scale in model_scales:
    frame_data = []
    ax_data = filter_data(figure_data, d_models=[d], layers=[scale], filter_first_step=True,
        symbol_types=('sym_attn',), symmetry=(False,), rca_types=('disentangled_v2',))
    ax_data = ax_data.groupby(['Model', 'Tokens'])[metric].aggregate(['mean', 'std', 'count', 'sem']).reset_index()
    models = ax_data['Model'].unique()

    yrange = [ax_data['mean'].min(), ax_data['mean'].max()]

    yrange = yrange_global
    # xrange = [task_data['Epoch'].min(), task_data['Training Set Size'].max()]
    eps_y = 0.025 * (yrange[1] - yrange[0])
    # eps_x = 0.025 * (xrange[1] - xrange[0])
    yrange = [yrange[0] - eps_y, yrange[1] + eps_y]
    # xrange = [xrange[0] - eps_x, xrange[1] + eps_x]

    for model in models:
        model_filter = ax_data['Model'] == model
        model_data = ax_data[model_filter]
        frame_data.append(go.Scatter(
            x=model_data['Tokens'], 
            y=model_data['mean'],
            mode='lines',
            name=model,
            line=dict(color=color_map[model]))) # change color based on model
    frames.append(go.Frame(data=frame_data, 
                           name=f'L = {scale}',
                           layout=dict(yaxis=dict(range=yrange))
                                       ))

# Create steps for the slider
steps = [dict(method='animate',
              args=[[frame['name']]],  # frame name to be shown
              label=frame['name']) for frame in frames]

# Create base frame
fig = go.Figure(
    data=frames[0]['data'],
    layout=go.Layout(
        title='Language Modeling Training Curves',
        xaxis=dict(title='Tokens'),
        yaxis=dict(title=metric),
        height=600,
        width=1000,
        sliders=[dict(steps=steps)],  # add the slider
                font_family="Computer Modern",
        template=theme,
        legend_title="Model",
        legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99,
                    title_font_family="Times New Roman", #bgcolor='LightSteelBlue',
                    bordercolor="Black", borderwidth=1),
    ),
    frames=frames
)

fig.show()

In [47]:
fig_json = fig.to_json()
# save fig_json to file
with open(f'{save_dir}/language_modeling_training_curves.json', 'w') as f:
    f.write(fig_json)

## Vision

In [227]:
figure_data = pd.read_csv('figure_data/imagenet/run_histories.csv')
figure_data.dropna(subset=['train/acc_epoch', 'val/loss'], inplace=True, how='all') # drop step rows and keep epoch rows
figure_data.head()

Unnamed: 0,train/loss_epoch,train/acc_step,trainer/global_step,val/top4_acc,_step,val/top3_acc,val/acc,val/loss,val/top2_acc,_runtime,...,val/top7_acc,val/top8_acc,train/acc_epoch,d_model,n_layers,symbol_retrieval,rca_type,symmetric_rels,group,name
6,,,312,0.126362,6,0.105609,0.047035,5.619035,0.079808,7318.16537,...,0.177324,0.191186,,1024,24,pos_relative,disentangled_v2,,,sa=10; rca=6; d=1024; L=24; rca_type=disentang...
7,6.408384,,312,,7,,,,,7322.367173,...,,,0.021343,1024,24,pos_relative,disentangled_v2,,,sa=10; rca=6; d=1024; L=24; rca_type=disentang...
14,,,625,0.208734,14,0.177083,0.092708,5.096831,0.141026,14620.223963,...,0.277804,0.297035,,1024,24,pos_relative,disentangled_v2,,,sa=10; rca=6; d=1024; L=24; rca_type=disentang...
15,5.503854,,625,,15,,,,,14624.3225,...,,,0.062984,1024,24,pos_relative,disentangled_v2,,,sa=10; rca=6; d=1024; L=24; rca_type=disentang...
22,,,938,0.290545,22,0.257212,0.141186,4.607257,0.211538,21924.534104,...,0.369631,0.390625,,1024,24,pos_relative,disentangled_v2,,,sa=10; rca=6; d=1024; L=24; rca_type=disentang...


In [228]:
model_name_map = {
    'sa=16; d=1024; L=24__2024_05_15_16_38_09': '$\\text{Transformer}\\ (n_h^{sa}=16, n_h^{ra}=0)$',
    'sa=10; rca=6; d=1024; L=24; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_relative__2024_05_15_18_13_54': f'$\\text{{{our_model_name}}}\\ (n_h^{{sa}}=10, n_h^{{ra}}=6)$'
    }

figure_data['Model'] = pd.Categorical(figure_data['name'].map(model_name_map), model_name_map.values(), ordered=True)

color_map_ = {
    '$\\text{Transformer}\\ (n_h^{sa}=16, n_h^{ra}=0)$': cmap(6), # red
    f'$\\text{{{our_model_name}}}\\ (n_h^{{sa}}=10, n_h^{{ra}}=6)$': cmap(8), # purple
    f'$\\text{{{our_model_name}}}\\ (n_h^{{sa}}=10, n_h^{{ra}}=6)$': cmap(0), # blue
    }

color_map = {k: convert_to_plotly_color(color_map_[k]) for k in color_map_}

In [229]:
df_ = figure_data[['epoch', 'train/acc_epoch', 'Model']].dropna().pivot(index='epoch', columns='Model', values='train/acc_epoch')
trainacc_diff = abs(df_.iloc[:,0] - df_.iloc[:,1])
print(f'train/acc mean difference: {trainacc_diff.mean():.2%}')
print(f'train/acc max difference: {trainacc_diff.max():.2%}')
print(f'train/acc end difference: {abs(max(df_.iloc[:,0]) - max(df_.iloc[:,1])):.2%}')
print()
df_ = figure_data[['epoch', 'val/acc', 'Model']].dropna().pivot(index='epoch', columns='Model', values='val/acc')
valacc_diff = abs(df_.iloc[:,0] - df_.iloc[:,1])
print(f'val/acc mean difference: {valacc_diff.mean():.2%}')
print(f'val/acc max difference: {valacc_diff.max():.2%}')
print(f'val/acc end difference: {abs(max(df_.iloc[:,0]) - max(df_.iloc[:,1])):.2%}')

train/acc mean difference: 5.01%
train/acc max difference: 10.08%
train/acc end difference: 2.89%

val/acc mean difference: 4.39%
val/acc max difference: 9.98%
val/acc end difference: 1.46%


In [230]:
figure_data.rename(columns={'train/acc_epoch': 'Training Accuracy', 'val/acc': 'Validation Accuracy',
    'val/loss': 'Validation Loss', 'train/loss_epoch': 'Training Loss', 'epoch': 'Epoch'}, inplace=True)

In [238]:
import plotly.subplots as sp
import plotly.graph_objects as go

# Assuming figure_data is your DataFrame with columns 'Epoch', 'Training Accuracy', 'Validation Accuracy', 'Model'
# and color_map_ is a dictionary mapping 'Model' values to colors

# Get unique models
models = figure_data['Model'].unique()

# Create subplots
fig = sp.make_subplots(rows=1, cols=2, subplot_titles=("Training Accuracy", "Validation Accuracy"), horizontal_spacing=0.1)

# Add line plots to the subplots
for model in models:
    model_data = figure_data[figure_data['Model'] == model]
    model_data_tr = model_data[['Epoch', 'Training Accuracy']].dropna()
    model_data_val = model_data[['Epoch', 'Validation Accuracy']].dropna()
    fig.add_trace(go.Scatter(
        x=model_data_tr['Epoch'], y=model_data_tr['Training Accuracy'],
        mode='lines', line=dict(color=color_map[model]), name=model, showlegend=False
        ), 
        row=1, col=1)
    fig.add_trace(go.Scatter(
        x=model_data_val['Epoch'], y=model_data_val['Validation Accuracy'],
        mode='lines', line=dict(color=color_map[model]), 
        name=model), 
        row=1, col=2)

# Update layout
fig.update_layout(
    height=600, width=1000, title_text="Image Classification Training Curves",
    template=theme,
    font_family="Computer Modern",
    legend_title="Model",
    legend=dict(yanchor="bottom", y=0.01, xanchor="right", x=0.99,
#                     title_font_family="Computer Modern", #bgcolor='LightSteelBlue',
                    bordercolor="Black", borderwidth=1),)

fig.update_xaxes(title_text="Epoch", row=1, col=1)
# fig.update_yaxes(title_text="Training Accuracy", row=1, col=1)
fig.update_xaxes(title_text="Epoch", row=1, col=2)
# fig.update_yaxes(title_text="Validation Accuracy", row=1, col=2)


# Show the figure
fig.show()

In [203]:
fig_json = fig.to_json()
# save fig_json to file
with open(f'{save_dir}/imagenet_training_curves.json', 'w') as f:
    f.write(fig_json)