In [2]:
import pandas as pd
import wandb
import utils
from tqdm import tqdm

# Relational Games

## Get OoD Generalization Data (End-of-Training Accuracy)

In [2]:
# Every wandb project of user 'awni00' whose name contains 'relational_games-' (one per task).
api = wandb.Api()
projects = list(filter(lambda proj: 'relational_games-' in proj.name, api.projects('awni00')))
projects

[<Project awni00/relational_games-same>,
 <Project awni00/relational_games-xoccurs>,
 <Project awni00/relational_games-occurs>,
 <Project awni00/relational_games-1task_match_patt>,
 <Project awni00/relational_games-1task_between>]

In [3]:
# Collect end-of-training metrics (stripes/hexos/overall accuracy) for every
# relational-games project into a single table with a 'task' column.
summary_cols = ['stripes_acc', 'hexos_acc', 'acc']
config_cols = ['trial']
attr_cols = ['group', 'name']

project_dfs = []
for project in projects:
    # project names look like 'relational_games-<task>'; split only on the first '-'
    # so a task name that itself contains '-' is kept intact
    task_name = project.name.split('-', 1)[1]
    project_df = utils.get_wandb_project_table(
        project_name=project.name, entity='awni00', attr_cols=attr_cols, config_cols=config_cols, summary_cols=summary_cols)
    project_df['task'] = task_name
    project_dfs.append(project_df)

# ignore_index: each per-project table may carry its own index labels; give the
# combined table unique 0..N-1 labels so the saved CSV index is not duplicated
projects_df = pd.concat(project_dfs, ignore_index=True)

In [4]:
# Persist the combined end-of-training table (index column included).
projects_df.to_csv('figure_data/relational_games/end-of-training-accuracy.csv')

## Get training curve data

In [27]:
# Rebuild the API handle and the relational-games project list (same query as the first section).
api = wandb.Api()
projects = list(filter(lambda proj: 'relational_games-' in proj.name, api.projects('awni00')))

In [28]:
# Download the logged history of every run in each relational-games project.
# Slow: the recorded run took ~53 min for 5 projects, so each project's history is
# also cached to its own CSV for the reload cells below.
config_cols = ['trial']
attr_cols = ['group', 'name']

project_dfs = []
for project in tqdm(projects):
    # project names look like 'relational_games-<task>'; split only on the first '-'
    task_name = project.name.split('-', 1)[1]
    project_df = utils.get_project_run_histories(
        project_name=project.name, entity='awni00', attr_cols=attr_cols, config_cols=config_cols)
    # cached before the 'task' column is added; the task is recovered from the file name on reload
    project_df.to_csv(f'figure_data/relational_games/run_history_{project.name}.csv')
    project_df['task'] = task_name
    project_dfs.append(project_df)

# ignore_index gives the same 0..N-1 index as concat + reset_index(drop=True), without in-place mutation
projects_df = pd.concat(project_dfs, ignore_index=True)
projects_df.to_csv('figure_data/relational_games/project_run_histories.csv')

100%|██████████| 5/5 [52:45<00:00, 633.08s/it]


In [29]:
# get csv files in directory
import glob
csv_files = glob.glob('figure_data/relational_games/run_history*.csv')

In [30]:
import os

# Rebuild the combined run-history table from the cached per-project CSVs (no wandb calls).
assert csv_files, 'no cached run_history*.csv files found; run the download cell first'

project_dfs = []
for csv_file in csv_files:
    # base name is 'run_history_<project name>.csv' and project names look like
    # 'relational_games-<task>': take everything after the first '-' of the file stem,
    # so neither a '-' in the directory path nor one inside the task name breaks parsing
    stem = os.path.splitext(os.path.basename(csv_file))[0]
    task_name = stem.split('-', 1)[1]
    project_df = pd.read_csv(csv_file, index_col=0)  # first column is the index written by to_csv
    project_df['task'] = task_name
    project_dfs.append(project_df)

# ignore_index: each cached file carries its own per-project index; give the combined table unique labels
projects_df = pd.concat(project_dfs, ignore_index=True)

## Relational Games Group Attention Exploration

In [9]:
# Every summary metric and every config value for each run of the group-attention
# parameter search, cached to CSV.
api = wandb.Api()  # NOTE(review): `api` is not referenced in this cell

attr_cols = ['group', 'name']
config_cols = 'all'
summary_cols = 'all'

project_df = utils.get_wandb_project_table(
    project_name='relgames_groupattn_param_search',
    entity='awni00',
    attr_cols=attr_cols,
    config_cols=config_cols,
    summary_cols=summary_cols,
)
project_df.to_csv('figure_data/relational_games/group_attn/end-of-training-accuracy.csv')

In [11]:
# List the columns returned by the 'all' summary/config query.
project_df.columns

Index(['group', 'name', 'epoch/learning_rate', 'epoch/loss',
       'group_attn_scores', 'epoch/val_loss', 'hexos_acc', '_runtime',
       'epoch/group_attn_entropy', 'group_attn_entropy', '_step',
       'epoch/val_acc', 'stripes_loss', 'epoch/epoch',
       'stripes_group_attn_entropy', 'loss', 'acc', 'hexos_loss', '_wandb',
       'hexos_group_attn_entropy', '_timestamp',
       'epoch/val_group_attn_entropy', 'epoch/acc', 'stripes_acc',
       'learning_rate', 'entropy_reg', 'run_name', 'wandb_project_name',
       'n_groups', 'n_filters', 'seed', 'symmetric_inner_prod', 'two_layer',
       'train_size', 'mdipr_proj_dim', 'group_attn_key', 'ignore_gpu_assert',
       'mdipr_symmetric', 'group_attn_key_dim', 'train_split', 'test_size',
       'task', 'n_epochs', 'val_size', 'graphlet_size', 'batch_size',
       'test_split_size', 'mdipr_rel_dim', 'early_stopping',
       'entropy_reg_scale'],
      dtype='object')

In [25]:
# Logged history of every run in the group-attention search, with every config value attached.
attr_cols = ['group', 'name']
config_cols = 'all'

project_df = utils.get_project_run_histories(
    project_name='relgames_groupattn_param_search',
    entity='awni00',
    attr_cols=attr_cols,
    config_cols=config_cols,
)

                                             

In [26]:
# Cache the group-attention run histories (without the index column).
project_df.to_csv('figure_data/relational_games/group_attn/run_history.csv', index=False)

# Contains 'SET'

## Get End-of-Training Metrics data

In [83]:
# All end-of-training summary metrics plus the 'trial' config value for every run
# in the contains-SET project.
api = wandb.Api()  # NOTE(review): `api` is not referenced in this cell

attr_cols = ['group', 'name']
config_cols = ['trial']
summary_cols = 'all'

project_df = utils.get_wandb_project_table(
    project_name='relconvnet-contains_set',
    entity='awni00',
    attr_cols=attr_cols,
    config_cols=config_cols,
    summary_cols=summary_cols,
)

In [84]:
# Cache contains-SET end-of-training metrics (without the index column).
project_df.to_csv('figure_data/contains_set/end-of-training-accuracy.csv', index=False)

In [85]:
# Mean of the summary 'acc' metric per wandb run group.
project_df.groupby(['group'])['acc'].mean()

group
abstractor                                                          0.506610
cnn                                                                 0.506173
corelnet                                                            0.562577
gat                                                                 0.517181
gat - L=1 - opt_adamw - weight_dec_0.032 - lr_sched_cosine          0.674640
gat - L=1 - opt_adamw - weight_dec_0.032 - lr_sched_none            0.652443
gcn                                                                 0.594727
gcn - L=2 - opt_adamw - weight_dec_1.024 - lr_sched_none            0.635339
gin                                                                 0.589660
gin - L=2 - opt_adamw - weight_dec_0.032 - lr_sched_none            0.593313
gru                                                                 0.592978
lstm                                                                0.602315
lstm - L=2 - opt_adamw - weight_dec_1.024 - lr_sched_none           0.

## Get training curve data

In [86]:
# NOTE(review): `api` is not referenced by the remaining cells of this section.
api = wandb.Api()

In [87]:
# Logged run histories for the contains-SET project, tagged with the 'trial' config value.
attr_cols = ['group', 'name']
config_cols = ['trial']

project_df = utils.get_project_run_histories(
    project_name='relconvnet-contains_set',
    entity='awni00',
    attr_cols=attr_cols,
    config_cols=config_cols,
)

                                                 

In [88]:
# NOTE(review): unlike group_attn/run_history.csv (saved with index=False), this writes
# the index column, so read it back with index_col=0.
project_df.to_csv('figure_data/contains_set/run_history.csv')

# Contains 'SET' (Hyperparameter Sweep)

## Get End-of-Training Metrics data

In [96]:
import wandb

In [97]:
# End-of-training metrics and the full config for every run of the contains-SET sweep.
# NOTE(review): only 'name' is requested as a run attribute, yet the groupby('group')
# below works per its recorded output, presumably because 'group' is logged as a
# config key and arrives via config_cols='all'. Confirm.
api = wandb.Api()  # NOTE(review): `api` is not referenced in this cell

attr_cols = ['name']
config_cols = 'all'  # alternative: ['trial']
summary_cols = 'all'

project_df = utils.get_wandb_project_table(
    project_name='relconvnet-contains_set-hyperparamsweep',
    entity='awni00',
    attr_cols=attr_cols,
    config_cols=config_cols,
    summary_cols=summary_cols,
)

In [98]:
# Cache the sweep's end-of-training metrics (without the index column).
project_df.to_csv('figure_data/contains_set/hyperparam_sweep_end-of-training-accuracy.csv', index=False)

In [99]:
# Mean summary 'acc' per sweep group; NaN where no run in the group has an 'acc' value.
project_df.groupby(['group'])['acc'].mean()

group
cnn - L=-1 - opt_adamw - weight_dec_0.0 - lr_sched_none            0.500000
cnn - L=-1 - opt_adamw - weight_dec_0.002 - lr_sched_none          0.500000
cnn - L=-1 - opt_adamw - weight_dec_0.004 - lr_sched_none               NaN
cnn - L=-1 - opt_adamw - weight_dec_0.008 - lr_sched_none               NaN
cnn - L=-1 - opt_adamw - weight_dec_0.032 - lr_sched_none          0.500000
                                                                     ...   
transformer - L=6 - opt_adam - weight_dec_0.0 - lr_sched_none      0.585031
transformer - L=6 - opt_adamw - weight_dec_0.04 - lr_sched_none    0.570628
transformer - L=8 - opt_adam - weight_dec_0.0                      0.564043
transformer - L=8 - opt_adam - weight_dec_0.0 - lr_sched_none      0.566358
transformer - L=8 - opt_adamw - weight_dec_0.04 - lr_sched_none    0.574228
Name: acc, Length: 363, dtype: float64

## Get training curve data

In [100]:
# NOTE(review): `api` is not referenced below; get_project_run_histories creates its
# own wandb.Api(timeout=60).
api = wandb.Api()

In [101]:
def get_project_run_histories(project_name, entity='Awni00', attr_cols=('group', 'name'), config_cols='all'):
    '''Gets the logged history of every run in a wandb project as one long DataFrame.

    Each run's `run.history()` frame is tagged with the requested config values and
    run attributes (the same value on every row of that run), and the per-run frames
    are stacked.

    Args:
        project_name: wandb project name, queried as '<entity>/<project_name>'.
        entity: wandb entity owning the project. NOTE(review): the default 'Awni00'
            differs in case from the 'awni00' every call in this notebook passes.
        attr_cols: run attributes copied onto each row via getattr (missing -> None).
        config_cols: config keys copied onto each row (missing -> None), or 'all' for
            the union of config keys across all runs, in sorted order.

    Returns:
        DataFrame with one row per history row per run and a fresh 0..N-1 index;
        an empty DataFrame if the project has no runs.
    '''
    api = wandb.Api(timeout=60)

    # materialize once: the runs are iterated twice (config-key union, then the main loop)
    runs = list(api.runs(entity + "/" + project_name))

    if config_cols == 'all':
        # sorted: iterating a set of strings gives a hash-seed-dependent order, which
        # made the result's column order vary between kernel sessions
        config_cols = sorted(set().union(*(run.config.keys() for run in runs)))

    run_history_dfs = []
    for run in tqdm(runs, leave=False):
        # NOTE(review): per the wandb public API docs, Run.history() returns a sampled
        # subset of rows (samples=500 by default); use run.scan_history() if every
        # logged step is needed.
        run_history = run.history()
        n_rows = len(run_history)

        # repeat each value once per row: assigning a bare list/dict value would make
        # pandas treat it as per-row values (wrong data, or a length-mismatch error)
        # instead of broadcasting one value to every row
        for config_col in config_cols:
            run_history[config_col] = [run.config.get(config_col, None)] * n_rows

        for attr_col in attr_cols:
            run_history[attr_col] = [getattr(run, attr_col, None)] * n_rows

        run_history_dfs.append(run_history)

    if not run_history_dfs:
        # pd.concat raises ValueError on an empty list
        return pd.DataFrame()

    return pd.concat(run_history_dfs, axis=0, ignore_index=True)

In [102]:
# Fetch the sweep's run histories with the notebook-local get_project_run_histories
# (defined in the previous cell) rather than utils.get_project_run_histories.
attr_cols = ['name']
config_cols = 'all'  # alternative: ['trial']

project_df = get_project_run_histories(
    project_name='relconvnet-contains_set-hyperparamsweep',
    entity='awni00',
    attr_cols=attr_cols,
    config_cols=config_cols,
)

                                                   

In [103]:
# Cache the sweep run histories (index column included; read back with index_col=0).
project_df.to_csv('figure_data/contains_set/hyperparam_sweep_run_history.csv')