# Fetch Experiment Results from W&B Logs

In [25]:
import wandb
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
from tqdm import tqdm

from wandb_utils import get_wandb_project_table

In [26]:
# Root directory used in the CSV export paths of this notebook.
# NOTE: the Relational Games and Fineweb sections reassign this to 'figure_data'.
data_dir = 'figure_data_tmp'

In [27]:
def get_project_run_histories(project_name, keys=None, groups=None, runs_filter=None, entity='Awni00', attr_cols=('group', 'name'), config_cols='all', retrieval_method='history', **kwargs):
    '''Gets the log history of all selected runs in a W&B project as one DataFrame.

    Each selected run's history is fetched, annotated with constant columns taken
    from the run's config and attributes, and the per-run frames are stacked with a
    fresh 0..n-1 index.

    Parameters
    ----------
    project_name : str
        Project to read; queried as "<entity>/<project_name>".
    keys : list of str, optional
        Forwarded as `keys` to `run.history` / `run.scan_history`. With
        retrieval_method='scan' these are also the columns of each run's frame;
        when None, every field of every scanned row is kept.
    groups : container of str, optional
        If not None, only runs whose `group` is in this container are kept.
    runs_filter : container of str, optional
        If non-empty, only runs whose `name` is in this container are kept.
    entity : str
        W&B entity that owns the project.
    attr_cols : iterable of str
        Run attributes copied into columns (a missing attribute becomes None).
    config_cols : 'all' or iterable of str
        Config keys copied into columns (a missing key becomes None). 'all' uses
        the union of config keys over the selected runs.
    retrieval_method : {'history', 'scan'}
        'history' calls `run.history`; 'scan' iterates `run.scan_history`.
    **kwargs
        Forwarded to `run.history` only; not used by the 'scan' path.

    Returns
    -------
    pandas.DataFrame
        Stacked histories, or an empty DataFrame when no run is selected.

    Raises
    ------
    ValueError
        If `retrieval_method` is neither 'history' nor 'scan'.

    Notes
    -----
    A config or attribute column that shares its name with a logged metric
    overwrites that metric's column (attributes are written after config keys).
    '''

    # Reject a bad method before any network traffic, and even when no run matches.
    if retrieval_method not in ('scan', 'history'):
        raise ValueError(f'invalid retrieval method: {retrieval_method}')

    def get_run_history(run):
        if retrieval_method == 'history':
            return run.history(keys=keys, **kwargs)

        rows = list(run.scan_history(keys=keys))
        if keys is None:
            # No key subset requested: keep the scanned rows whole instead of
            # failing while iterating over `None`.
            return pd.DataFrame(rows)
        return pd.DataFrame({key: [row[key] for row in rows] for key in keys})

    def as_column(value, index):
        # Container values (e.g. nested config dicts or lists) must be repeated
        # explicitly, one object per row; assigned directly, pandas would try to
        # align or length-match them against the rows rather than store them.
        if isinstance(value, (dict, list, tuple, set)):
            return pd.Series([value] * len(index), index=index, dtype=object)
        return value

    api = wandb.Api(timeout=60)

    runs = api.runs(entity + "/" + project_name)
    if groups is not None:
        runs = [run for run in runs if run.group in groups]
    if runs_filter:
        runs = [run for run in runs if run.name in runs_filter]

    if config_cols == 'all':
        config_cols = set().union(*(run.config.keys() for run in runs))

    print(f'fetching run history for {len(runs)} runs in {project_name}')

    run_history_dfs = []
    for run in tqdm(runs):
        run_history = get_run_history(run)

        for config_col in config_cols:
            run_history[config_col] = as_column(run.config.get(config_col), run_history.index)

        for attr_col in attr_cols:
            run_history[attr_col] = as_column(getattr(run, attr_col, None), run_history.index)

        run_history_dfs.append(run_history)

    # pd.concat refuses an empty list; an empty selection yields an empty table.
    if not run_history_dfs:
        return pd.DataFrame()

    return pd.concat(run_history_dfs, axis=0, ignore_index=True)

## Relational Games

In [44]:
# W&B API client for this section.
api = wandb.Api()
# Project whose runs are tabulated in the cells below.
project = 'dual_attention--relational_games_learning_curves'
# Root directory for the CSV exports; replaces the 'figure_data_tmp' value assigned near the top.
data_dir = 'figure_data'

In [45]:
# Request every config key; the trailing comment keeps an explicit subset for reference.
config_cols = 'all' #['d_model', 'n_layers', 'symbol_retrieval', 'n_heads_rca', 'n_heads_sa', 'rca_type', 'rca_kwargs.symmetric_rels', 'n_heads', 'task', 'train_size', 'num_params']
attr_cols = ['name']

# Build the project table with the helper imported from wandb_utils.
# The commented-out argument line is the same call against the 'awni00' entity.
# NOTE(review): later cells read 'group' and 'task' columns although attr_cols lists only
# 'name' — presumably they come from the run config or the helper itself; confirm in wandb_utils.
relgames_data = get_wandb_project_table(
    # project_name=project, entity='awni00', attr_cols=attr_cols, config_cols=config_cols, summary_cols='all')
    project_name=project, entity='dual-attention', attr_cols=attr_cols, config_cols=config_cols, summary_cols='all')

In [46]:
def process_groupname(group_name):
    '''Return the model-name half of a "<task>__<model_name>" group name.

    The name must contain exactly one '__' separator; any other count makes the
    two-way unpacking below raise ValueError.
    '''
    pieces = group_name.split('__')
    _, model_name = pieces
    return model_name

In [58]:
# Derive model_name from each row's group name (the part after the '__' separator).
relgames_data['model_name'] = relgames_data['group'].apply(process_groupname)

In [59]:
# Model names whose symbol_type value is missing or cut short ('', 'pos', 'positiona',
# 'positional', 'positional_s'); presumably truncated upstream — TODO confirm the cause.
trimmed_names = [
    'sa=0; ra=8; nr=8; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=',
    'sa=0; ra=4; nr=4; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=',
    'sa=0; ra=8; nr=8; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=pos',
    'sa=0; ra=4; nr=4; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=positiona',
    'sa=0; ra=4; nr=4; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=pos',
    'sa=0; ra=8; nr=8; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=positiona',
    'sa=0; ra=4; nr=4; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=positional',
    'sa=0; ra=8; nr=8; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=positional',
    'sa=0; ra=8; nr=8; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=positional_s',
    'sa=0; ra=4; nr=4; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=positional_s',
    'sa=4; ra=4; nr=8; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=',
    'sa=2; ra=2; nr=4; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=',
    'sa=0; ra=2; nr=2; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=',
    'sa=1; ra=1; nr=2; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=',
    'sa=4; ra=4; nr=8; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=pos',
    'sa=2; ra=2; nr=4; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=pos',
    'sa=1; ra=1; nr=2; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=pos',
    'sa=0; ra=2; nr=2; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=pos'
]

# Number of characters up to and including 'symbol_type='. Every entry above shares a
# prefix of this length, so slicing to it drops whatever partial value follows.
trim_length = len('sa=0; ra=8; nr=8; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=')

# Pair each truncated name with its repaired form: intact prefix + full symbol type.
trimmed_name_map = dict(zip(
    trimmed_names,
    (name[:trim_length] + 'positional_symbols' for name in trimmed_names)))
trimmed_name_map

{'sa=0; ra=8; nr=8; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=': 'sa=0; ra=8; nr=8; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=positional_symbols',
 'sa=0; ra=4; nr=4; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=': 'sa=0; ra=4; nr=4; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=positional_symbols',
 'sa=0; ra=8; nr=8; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=pos': 'sa=0; ra=8; nr=8; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=positional_symbols',
 'sa=0; ra=4; nr=4; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=positiona': 'sa=0; ra=4; nr=4; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=positional_symbols',
 'sa=0; ra=4; nr=4; d=128; L=2; ra_type=relational_attention; sym_rel=Tr

In [60]:
# Swap truncated model names for their repaired form; names not in the map pass through unchanged.
relgames_data['model_name'] = relgames_data['model_name'].map(lambda x: trimmed_name_map.get(x, x))

In [61]:
# Display the distinct model names as a visual check.
relgames_data['model_name'].unique()

array(['sa=0; ra=8; nr=8; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=positional_symbols',
       'sa=0; ra=4; nr=4; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=positional_symbols',
       'sa=4; ra=4; nr=8; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=positional_symbols',
       'sa=2; ra=2; nr=4; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=positional_symbols',
       'sa=0; ra=2; nr=2; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=positional_symbols',
       'sa=1; ra=1; nr=2; d=128; L=2; ra_type=relational_attention; sym_rel=True; sym_attn=True; symbol_type=positional_symbols',
       'sa=4; d=144; L=2; sym_attn', 'sa=8; d=144; L=2; sym_attn',
       'sa=0; ra=2; nr=2; d=128; L=2; ra_type=relational_attention; sym_rel=True; symbol_type=positional_symbols',
       'sa=1; ra=1; nr=2; d=128; L=2; 

In [62]:
# Display each distinct (task, model_name) pair, ordered by task.
relgames_data[['task', 'model_name']].drop_duplicates().sort_values(by='task')

Unnamed: 0,task,model_name
923,1task_between,sa=4; d=144; L=2; sym_attn
3998,1task_between,sa=4; d=144; L=2
4004,1task_between,sa=2; d=144; L=2
4670,1task_between,sa=1; rca=1; d=128; L=2; rca_type=standard; sy...
4679,1task_between,sa=0; rca=2; d=128; L=2; rca_type=standard; sy...
...,...,...
4776,xoccurs,sa=0; rca=2; d=128; L=2; rca_type=standard; sy...
4765,xoccurs,sa=1; rca=1; d=128; L=2; rca_type=standard; sy...
913,xoccurs,sa=4; ra=4; nr=8; d=128; L=2; ra_type=relation...
6955,xoccurs,sa=0; rca=4; d=128; L=2; rca_type=disentangled...


In [51]:
# L, total_n_heads = 2, 2
# filter_ = (relgames_data['n_layers'] == L) & ((relgames_data['n_heads_rca'] + relgames_data['n_heads_sa'] == total_n_heads) | (relgames_data['n_heads'] == total_n_heads)) & (relgames_data['train_size'] <= 25_000)
# relgames_data = relgames_data[filter_]

In [64]:
# Export the relational-games table; index=False keeps the DataFrame index out of the file.
relgames_data.to_csv(f'{data_dir}/relgames/relgames_data.csv', index=False)

## Math

In [86]:
api = wandb.Api()
# Keep the projects under the 'awni00' entity whose name contains 'dual_attention--math'.
projects = [project for project in api.projects('awni00') if 'dual_attention--math' in project.name]

In [87]:
# Config keys and run attributes requested for the math tables.
config_cols = ['d_model', 'n_layers_enc', 'n_layers_dec', 'symbol_retrieval', 'encoder_kwargs', 'decoder_kwargs']
attr_cols = ['group', 'name']

# Build one table per math project and tag it with the project's task.
project_dfs = []
for project in projects:
    # The task is whatever follows '--math--' in the project name.
    # NOTE(review): a project that matched the 'dual_attention--math' filter but lacks
    # '--math--' would raise IndexError here.
    task_name = project.name.split('--math--')[1]
    project_df = get_wandb_project_table(
        project_name=project.name, entity='awni00', attr_cols=attr_cols, config_cols=config_cols, summary_cols='all')
    project_df['task'] = task_name
    project_dfs.append(project_df)

# Stack the per-project tables; row indices are not reset here.
math_eot_data = pd.concat(project_dfs)

In [88]:
# Export the combined math table; index=False keeps the DataFrame index out of the file.
math_eot_data.to_csv(f'{data_dir}/math/math_eot_data.csv', index=False)

In [89]:
# Forwarded as `keys`; None requests no metric subset. The trailing comment keeps an explicit list.
metrics = None # ['epoch', 'train_teacher_forcing_acc', 'interpolate_teacher_forcing_acc', 'extrapolate_teacher_forcing_acc', 'train_loss', 'interpolate_loss', 'extrapolate_loss']
# None disables the group filter; the commented list below is an inactive explicit selection.
groups = None
# groups = [
#     'ee=8; ea=0; de=8; da=0; dc=8; el=2; dl=2',
#  'e_sa=8; e_rca=0; d_sa=8; d_rca=0; d_cross=8; d=128; rca_type=NA, symbol_type=NA; el=2; dl=2',
#  'e_sa=8; e_rca=0; d_sa=8; d_rca=0; d_cross=8; d=144; rca_type=NA, symbol_type=NA; el=2; dl=2',
#  'e_sa=4; e_rca=4; d_sa=8; d_rca=0; d_cross=8; d=128; rca_type=disentangled_v2, symbol_type=pos_relative; el=2; dl=2',
#  'e_sa=4; e_rca=4; d_sa=4; d_rca=4; d_cross=8; d=128; rca_type=disentangled_v2, symbol_type=pos_relative; el=2; dl=2',
#  'e_sa=4; e_rca=4; d_sa=8; d_rca=0; d_cross=8; rca_dis=True, el=2; dl=2',
#  'e_sa=4; e_rca=4; d_sa=4; d_rca=4; d_cross=8; rca_dis=True, el=2; dl=2'
#  ]

config_cols = 'all' # ['d_model', 'n_layers_enc', 'n_layers_dec', 'symbol_retrieval']
attr_cols = ['group', 'name']
project_dfs = []
for project in projects:
    # The task is whatever follows '--math--' in the project name.
    task_name = project.name.split('--math--')[1]
    project_df = get_project_run_histories(
        project_name=project.name, entity='awni00', keys=metrics, groups=groups, attr_cols=attr_cols, config_cols=config_cols)
    # NOTE(review): this per-project file is written before the 'task' column is added and
    # without index=False, unlike the other exports in this notebook — confirm that is intended.
    project_df.to_csv(f'{data_dir}/math/run_history_{project.name}.csv')
    project_df['task'] = task_name
    project_dfs.append(project_df)

# Stack the run histories of all math projects; row indices are not reset here.
math_run_histories = pd.concat(project_dfs)

fetching run history for 18 runs in dual_attention--math--algebra__sequence_next_term


100%|██████████| 18/18 [00:08<00:00,  2.04it/s]


fetching run history for 19 runs in dual_attention--math--calculus__differentiate


100%|██████████| 19/19 [00:09<00:00,  2.11it/s]


fetching run history for 19 runs in dual_attention--math--polynomials__expand


100%|██████████| 19/19 [00:11<00:00,  1.68it/s]


fetching run history for 19 runs in dual_attention--math--polynomials__add


100%|██████████| 19/19 [00:11<00:00,  1.72it/s]


fetching run history for 19 runs in dual_attention--math--algebra__linear_1d


100%|██████████| 19/19 [00:08<00:00,  2.16it/s]


In [90]:
# Export the stacked math run histories; index=False keeps the DataFrame index out of the file.
math_run_histories.to_csv(f'{data_dir}/math/run_history_all.csv', index=False)

## Language Modeling (Tiny Stories)

In [15]:
api = wandb.Api()
project = 'abstract_transformer--tiny_stories-LM'

# Metric names forwarded as `keys` to the history fetch below.
metrics = ['iter', 'tokens', 'train/loss', 'val/loss', 'train/perplexity', 'val/perplexity']

# None disables the group filter.
groups = None

# Request every config key; the trailing comment keeps an explicit subset for reference.
config_cols = 'all' # ['d_model', 'n_layers', 'symbol_retrieval', 'rca_type', 'symmetric_rels']
attr_cols = ['group', 'name']
# retrieval_method is left at its default ('history').
project_df = get_project_run_histories(
    project_name=project, entity='awni00', keys=metrics, groups=groups, attr_cols=attr_cols, config_cols=config_cols)

fetching run history for 57 runs in abstract_transformer--tiny_stories-LM


100%|██████████| 57/57 [00:29<00:00,  1.96it/s]


In [16]:
# `project_df` is reused by several sections; here it is expected to hold the Tiny Stories
# result from the cell directly above.
project_df.to_csv(f'{data_dir}/tiny_stories/run_histories.csv', index=False)

## Language Modeling (Fineweb)

In [96]:
data_dir = 'figure_data'

api = wandb.Api()
project = 'fineweb'

# NOTE(review): this list is not passed to the call below (it uses keys=None) — confirm
# whether fetching without a key subset is deliberate.
metrics = ['step', 'tokens', 'loss/train', 'loss/val', 'norm', 'lr']

# None disables the group filter.
groups = None

config_cols = 'all' # ['d_model', 'n_layers', 'symbol_retrieval', 'rca_type', 'symmetric_rels']
attr_cols = ['group', 'name']
# x_axis and samples are not named parameters of get_project_run_histories; they travel
# through **kwargs into run.history(...) because retrieval_method='history'.
project_df = get_project_run_histories(
    project_name=project, entity='awni00', x_axis='_step', samples=20_000, keys=None, groups=groups, attr_cols=attr_cols, config_cols=config_cols, retrieval_method='history')

fetching run history for 49 runs in fineweb


100%|██████████| 49/49 [02:04<00:00,  2.55s/it]


In [97]:
# `project_df` is reused by several sections; here it is expected to hold the Fineweb
# result from the cell directly above.
project_df.to_csv(f'{data_dir}/fineweb/run_histories.csv', index=False)

## Vision

In [17]:
api = wandb.Api()
project = 'abstract_transformer--Vision-IMAGENET'

# None disables the group filter; runs are selected by name instead.
groups = None

# Exact run names to keep (passed as `runs_filter` below).
run_filter = [
    'sa=16; d=1024; L=24__2024_05_15_16_38_09',
    'sa=10; rca=6; d=1024; L=24; rca_type=disentangled_v2; sym_rel=True; symbol_type=pos_relative__2024_05_15_18_13_54']

# Unlike the language-modeling sections, only an explicit subset of config keys is requested.
config_cols = ['d_model', 'n_layers', 'symbol_retrieval', 'rca_type', 'symmetric_rels']
attr_cols = ['group', 'name']
# keys=None requests no metric subset; retrieval_method is left at its default ('history').
project_df = get_project_run_histories(
    project_name=project, entity='awni00', keys=None, groups=groups, runs_filter=run_filter, attr_cols=attr_cols, config_cols=config_cols)

fetching run history for 2 runs in abstract_transformer--Vision-IMAGENET


100%|██████████| 2/2 [00:00<00:00,  2.57it/s]


In [18]:
# `project_df` is reused by several sections; here it is expected to hold the ImageNet
# result from the cell directly above.
project_df.to_csv(f'{data_dir}/imagenet/run_histories.csv', index=False)