# Fetch Experiment Results from W&B Logs

In [1]:
import wandb
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px
from tqdm import tqdm

from wandb_utils import get_wandb_project_table

In [9]:
# Directory that receives every CSV exported by this notebook. It is created up
# front because `DataFrame.to_csv` refuses to write into a missing directory.
import os

data_dir = 'figure_data'
os.makedirs(data_dir, exist_ok=True)

## Relational Games

In [10]:
# Public W&B API client, plus the project holding the relational-games runs.
api = wandb.Api()
project = (
    'abstract_transformer--relational_games_learning_curves'
)

In [11]:
# Run-config fields and run attributes to pull into the relational-games table;
# every summary metric is requested as well (summary_cols='all').
config_cols = [
    'd_model',
    'n_layers',
    'symbol_retrieval',
    'n_heads_rca',
    'n_heads_sa',
    'rca_type',
    'rca_kwargs.symmetric_rels',
    'n_heads',
    'task',
    'train_size',
]
attr_cols = ['group', 'name']

relgames_data = get_wandb_project_table(
    project_name=project,
    entity='awni00',
    attr_cols=attr_cols,
    config_cols=config_cols,
    summary_cols='all',
)

In [12]:
def process_groupname(group_name):
    '''Return the model-name half of a group name shaped like "<task>__<model>".

    Raises ValueError when `group_name` does not split into exactly two parts
    on "__".
    '''
    pieces = group_name.split('__')
    _task, model = pieces  # strict unpacking enforces the two-part format
    return model

In [13]:
# Derive each run's model name from its W&B group ("<task>__<model>").
relgames_data['model_name'] = relgames_data['group'].map(process_groupname)

In [14]:
# Keep runs with `L` layers whose head count totals `total_n_heads` (either as
# n_heads_rca + n_heads_sa, or as plain n_heads) and train_size <= 25,000.
L, total_n_heads = 2, 2
has_depth = relgames_data['n_layers'] == L
split_heads = (relgames_data['n_heads_rca'] + relgames_data['n_heads_sa']) == total_n_heads
plain_heads = relgames_data['n_heads'] == total_n_heads
small_train = relgames_data['train_size'] <= 25_000
filter_ = has_depth & (split_heads | plain_heads) & small_train
relgames_data = relgames_data[filter_]

In [28]:
# Persist the filtered relational-games table (without the DataFrame index).
relgames_csv = f'{data_dir}/relgames_data.csv'
relgames_data.to_csv(relgames_csv, index=False)

## Math

In [22]:
api = wandb.Api()
# Every project of entity 'awni00' whose name contains 'abstract_transformer--math'.
projects = list(
    filter(lambda proj: 'abstract_transformer--math' in proj.name, api.projects('awni00'))
)

In [23]:
# One summary table per math project, tagged with its task name (the part of the
# project name after "--math--"), then stacked into a single frame.
config_cols = [
    'd_model',
    'n_layers_enc',
    'n_layers_dec',
    'symbol_retrieval',
    'encoder_kwargs',
    'decoder_kwargs',
]
attr_cols = ['group', 'name']

project_dfs = []
for project in projects:
    task_name = project.name.split('--math--')[1]
    project_df = get_wandb_project_table(
        project_name=project.name,
        entity='awni00',
        attr_cols=attr_cols,
        config_cols=config_cols,
        summary_cols='all',
    )
    project_df['task'] = task_name
    project_dfs.append(project_df)

math_eot_data = pd.concat(project_dfs)

In [30]:
# Persist the combined math summary table (without the DataFrame index).
math_eot_csv = f'{data_dir}/math_eot_data.csv'
math_eot_data.to_csv(math_eot_csv, index=False)

In [31]:
def get_project_run_histories(project_name, keys, groups=None, entity='Awni00', attr_cols=('group', 'name'), config_cols='all'):
    '''Fetch the logged history of every (selected) run in a W&B project as one DataFrame.

    Each run's history is fetched with ``run.history(x_axis='epoch')``. It is then
    annotated with the requested run-config values and run attributes, one column
    each and constant within a run. Finally the per-run frames are stacked row-wise.

    Args:
        project_name: W&B project name; combined with ``entity`` as "entity/project".
        keys: metric names of interest. NOTE(review): currently NOT used to restrict
            the fetched history; every column returned by ``run.history`` is kept.
        groups: optional collection of run group names. When given, only runs whose
            ``run.group`` is in it are fetched.
        entity: W&B entity that owns the project.
        attr_cols: run attributes (e.g. ``group``, ``name``) copied onto every row;
            attributes a run lacks become None.
        config_cols: run-config keys copied onto every row, or ``'all'`` for the
            union of config keys across the selected runs. Keys a run lacks
            become None.

    Returns:
        DataFrame with a fresh RangeIndex, or an empty DataFrame when no run matches.
    '''
    api = wandb.Api(timeout=60)

    runs = api.runs(entity + "/" + project_name)
    if groups is not None:
        runs = [run for run in runs if run.group in groups]

    if config_cols == 'all':
        # sorted so the column order is reproducible (set iteration order is not)
        config_cols = sorted(set().union(*(run.config.keys() for run in runs)))

    print(f'fetching run history for {len(runs)} runs in {project_name}')

    run_history_dfs = []
    for run in tqdm(runs):
        run_history = run.history(x_axis='epoch')

        for config_col in config_cols:
            run_history[config_col] = run.config.get(config_col, None)

        for attr_col in attr_cols:
            run_history[attr_col] = getattr(run, attr_col, None)

        run_history_dfs.append(run_history)

    if not run_history_dfs:
        # pd.concat raises ValueError on an empty list (e.g. no run is in `groups`)
        return pd.DataFrame()

    return pd.concat(run_history_dfs, axis=0).reset_index(drop=True)

In [33]:
# Metrics of interest and the W&B run groups (model configurations) to fetch.
metrics = ['epoch', 'train_teacher_forcing_acc', 'interpolate_teacher_forcing_acc', 'extrapolate_teacher_forcing_acc', 'train_loss', 'interpolate_loss', 'extrapolate_loss']
# These must match the W&B group names byte-for-byte (including the ',' before
# 'el=2' in the last two entries), since runs are selected by exact membership.
groups = ['ee=8; ea=0; de=8; da=0; dc=8; el=2; dl=2',
 'e_sa=8; e_rca=0; d_sa=8; d_rca=0; d_cross=8; d=144; rca_type=NA, symbol_type=NA; el=2; dl=2',
 'e_sa=4; e_rca=4; d_sa=8; d_rca=0; d_cross=8; d=128; rca_type=disentangled_v2, symbol_type=pos_relative; el=2; dl=2',
 'e_sa=4; e_rca=4; d_sa=4; d_rca=4; d_cross=8; d=128; rca_type=disentangled_v2, symbol_type=pos_relative; el=2; dl=2',
 'e_sa=4; e_rca=4; d_sa=8; d_rca=0; d_cross=8; rca_dis=True, el=2; dl=2',
 'e_sa=4; e_rca=4; d_sa=4; d_rca=4; d_cross=8; rca_dis=True, el=2; dl=2'
 ]

config_cols = ['d_model', 'n_layers_enc', 'n_layers_dec', 'symbol_retrieval']
attr_cols = ['group', 'name']
project_dfs = []
for project in projects:
    task_name = project.name.split('--math--')[1]
    project_df = get_project_run_histories(
        project_name=project.name, entity='awni00', keys=metrics, groups=groups, attr_cols=attr_cols, config_cols=config_cols)
    # Tag with the task *before* saving, so the per-project CSV has the same columns
    # as the combined one. Skip the RangeIndex, matching the other CSV exports.
    project_df['task'] = task_name
    project_df.to_csv(f'{data_dir}/run_history_{project.name}.csv', index=False)
    project_dfs.append(project_df)

# ignore_index: each per-project frame restarts at 0, so avoid duplicate labels
math_run_histories = pd.concat(project_dfs, ignore_index=True)

fetching run history for 28 runs in abstract_transformer--math--algebra__linear_1d


100%|██████████| 28/28 [00:11<00:00,  2.42it/s]


fetching run history for 28 runs in abstract_transformer--math--algebra__sequence_next_term


100%|██████████| 28/28 [00:12<00:00,  2.33it/s]


fetching run history for 29 runs in abstract_transformer--math--calculus__differentiate


100%|██████████| 29/29 [00:11<00:00,  2.63it/s]


fetching run history for 28 runs in abstract_transformer--math--polynomials__expand


100%|██████████| 28/28 [00:10<00:00,  2.73it/s]


fetching run history for 28 runs in abstract_transformer--math--polynomials__add


100%|██████████| 28/28 [00:09<00:00,  2.82it/s]


In [35]:
# Persist the stacked math run histories for all tasks (without the DataFrame index).
run_history_all_csv = f'{data_dir}/run_history_all.csv'
math_run_histories.to_csv(run_history_all_csv, index=False)

## Vision

## Language Modeling