In [3]:
import glob

import pandas as pd
import wandb
from tqdm import tqdm

import utils

# Relational Games

## Get OoD Generalization data

In [4]:
# Connect to W&B and keep the 'awni00' projects whose name contains
# 'relational_games-' and does not contain 'old'.
api = wandb.Api()
projects = list(filter(
    lambda proj: 'relational_games-' in proj.name and 'old' not in proj.name,
    api.projects('awni00'),
))

In [5]:
# Column selections handed to utils.get_wandb_project_table.
# presumably run attributes / config entries / summary metrics — confirm against utils
attr_cols = ['group', 'name']
config_cols = ['trial']
summary_cols = ['stripes_acc', 'hexos_acc', 'acc']

# Fetch one table per project and tag every row with the task it belongs to.
project_dfs = []
for project in projects:
    # The task is the second '-'-separated piece of the project name.
    task_name = project.name.split('-')[1]
    project_df = utils.get_wandb_project_table(
        entity='awni00',
        project_name=project.name,
        attr_cols=attr_cols,
        config_cols=config_cols,
        summary_cols=summary_cols,
    )
    project_df['task'] = task_name
    project_dfs.append(project_df)

# Stack the per-project tables into a single frame (original row labels kept).
projects_df = pd.concat(project_dfs)

In [6]:
# Write the combined end-of-training table to disk (row index included).
end_of_training_csv = 'figure_data/relational_games/end-of-training-accuracy.csv'
projects_df.to_csv(end_of_training_csv)

## Get training curve data

In [2]:
# Re-create the W&B client and the filtered project list for this section:
# keep names that contain 'relational_games-' and do not contain 'old'.
api = wandb.Api()
projects = [
    proj
    for proj in api.projects('awni00')
    if 'relational_games-' in proj.name
    if 'old' not in proj.name
]

In [45]:
# Display the name of the last project in the filtered list; this is the
# project that the `projects[-1:]` slice in the download loop selects.
projects[-1].name

'relational_games-1task_between'

In [46]:
# Column selections handed to utils.get_project_run_histories.
attr_cols = ['group', 'name']
config_cols = ['trial']

# NOTE(review): `projects[-1:]` limits this loop to the last project only, so
# the combined CSV written at the end of this cell holds just that project's
# histories — confirm this is intended rather than a leftover from a partial re-run.
project_dfs = []
for project in tqdm(projects[-1:]):
    # The task is the second '-'-separated piece of the project name.
    task_name = project.name.split('-')[1]
    project_df = utils.get_project_run_histories(
        entity='awni00',
        project_name=project.name,
        attr_cols=attr_cols,
        config_cols=config_cols,
    )
    # Saved before the 'task' column is added, so the per-project file does not contain it.
    per_project_csv = f'figure_data/relational_games/run_history_{project.name}.csv'
    project_df.to_csv(per_project_csv)
    project_df['task'] = task_name
    project_dfs.append(project_df)

# Stack the per-project histories and replace the row labels with a fresh 0..n-1 index.
projects_df = pd.concat(project_dfs).reset_index(drop=True)
projects_df.to_csv('figure_data/relational_games/project_run_histories.csv')

  0%|          | 0/1 [00:00<?, ?it/s]




trial = 0 has no history. Skipping...
trial = 0 has no history. Skipping...


 45%|████▍     | 42/94 [07:03<04:58,  5.74s/it][A


trial = 0 has no history. Skipping...
trial = 0 has no history. Skipping...


 47%|████▋     | 44/94 [07:04<02:43,  3.26s/it][A

trial = 0 has no history. Skipping...
trial = 0 has no history. Skipping...




trial = 0 has no history. Skipping...
trial = 0 has no history. Skipping...




trial = 0 has no history. Skipping...
trial = 0 has no history. Skipping...





trial = 0 has no history. Skipping...
trial = 0 has no history. Skipping...


 55%|█████▌    | 52/94 [07:05<00:20,  2.05it/s][A

trial = 0 has no history. Skipping...
trial = 0 has no history. Skipping...




trial = 0 has no history. Skipping...
trial = 0 has no history. Skipping...




trial = 3 has no history. Skipping...
trial = 4 has no history. Skipping...




trial = 1 has no history. Skipping...
trial = 2 has no history. Skipping...
trial = 0 has no history. Skipping...




trial = 3 has no history. Skipping...
trial = 4 has no history. Skipping...




trial = 1 has no history. Skipping...
trial = 2 has no history. Skipping...




trial = 0 has no history. Skipping...


100%|██████████| 1/1 [11:41<00:00, 701.88s/it]


In [47]:
# Collect the per-project run-history CSV files from the figure-data directory.
# glob.glob returns matches in arbitrary (filesystem-dependent) order, so the
# list is sorted to make the row order of the combined table reproducible.
csv_files = sorted(glob.glob('figure_data/relational_games/run_history*.csv'))

In [48]:
# Rebuild the combined run-history table from the per-project CSV files.
project_dfs = []
for csv_file in csv_files:
    # Task name = text after the last '-' in the path, up to the first '.'
    # (e.g. '...run_history_relational_games-1task_between.csv' -> '1task_between').
    task_name = csv_file.split('-')[-1].split('.')[0]
    # The first CSV column is the row index that was written out by to_csv.
    project_df = pd.read_csv(csv_file, index_col=0)
    project_df['task'] = task_name
    project_dfs.append(project_df)

# ignore_index=True gives the combined frame a fresh 0..n-1 index. Rows coming
# from different files can otherwise share index labels, and the cell that
# builds this table directly from W&B also resets its index.
projects_df = pd.concat(project_dfs, ignore_index=True)


# Contains 'SET'

## Get OoD Generalization data

In [4]:
api = wandb.Api()

# Column selections for the 'relconvnet-contains_set' project table;
# summary_cols is passed as the string 'all' instead of an explicit list.
attr_cols = ['group', 'name']
config_cols = ['trial']
summary_cols = 'all'

project_df = utils.get_wandb_project_table(
    entity='awni00',
    project_name='relconvnet-contains_set',
    attr_cols=attr_cols,
    config_cols=config_cols,
    summary_cols=summary_cols,
)

In [5]:
# Write the Contains-SET end-of-training table to disk (row index included).
contains_set_accuracy_csv = 'figure_data/contains_set/end-of-training-accuracy.csv'
project_df.to_csv(contains_set_accuracy_csv)

## Get training curve data

In [6]:
# Create a W&B API client.
# NOTE(review): the visible cells in this section call utils with a project
# name and entity and never reference `api` — confirm whether it is still needed.
api = wandb.Api()

In [7]:
# Column selections handed to utils.get_project_run_histories.
attr_cols = ['group', 'name']
config_cols = ['trial']

# Fetch the run histories for the 'relconvnet-contains_set' project.
project_df = utils.get_project_run_histories(
    entity='awni00',
    project_name='relconvnet-contains_set',
    attr_cols=attr_cols,
    config_cols=config_cols,
)

                                               

In [9]:
# Write the Contains-SET run histories to disk (row index included).
contains_set_history_csv = 'figure_data/contains_set/run_history.csv'
project_df.to_csv(contains_set_history_csv)