In [1]:
import wandb
api = wandb.Api()
all_runs = api.runs(
    "chaosarium/multi",
    {"$and": [
        {"tags": {"$in": ["report1"]}},
    ]},
)


In [74]:
import pandas as pd
from itertools import chain
import numpy as np

In [113]:
def proc_run(run):
    config = run.config
    summary = run.summary
    tags = run.tags
    return {
        # 'run': run,
        'run_id': run.id,
        'finetuned_model': config['finetuned_model'],
        'lang': config['lang'],
        'max_val_en_mc_acc': summary['val/en_mc_acc']['max'],
        'max_val_lang_mc_acc': summary['val/lang_mc_acc']['max'],
        'max_test_lang_mc_acc': summary['test/lang_mc_acc']['max'],
        'max_mlm_perplexity': summary['val/mlm_perplexity']['max'],
        'test_mc_acc': summary['test/mc_acc'],
    }

all_runs_df = pd.DataFrame([proc_run(run) for run in all_runs])

In [114]:
from pandasql import sqldf
pysqldf = lambda q: sqldf(q, globals())

In [123]:
# select top 5 seeds by highest max_val_lang_mc_acc, tested at epoch where (val_lang_mc_acc + 0.5*val_en_mc_acc) is maximized
res = pysqldf("""
SELECT finetuned_model, lang, AVG(test_mc_acc) as 'AVG(ACC)', GROUP_CONCAT(test_mc_acc) as test_mc_acc_s FROM (
    SELECT *, RANK() OVER (PARTITION BY finetuned_model ORDER BY max_val_lang_mc_acc DESC) as rank_in_finetuned
    FROM all_runs_df
    WHERE max_mlm_perplexity <= 1000
)
WHERE rank_in_finetuned <= 5
GROUP BY finetuned_model
""")
res['test_mc_acc_s'] = res['test_mc_acc_s'].map(lambda x: np.std(list(map(float, x.split(',')))))
res = res.rename({'test_mc_acc_s': 'STD(ACC)'}, axis=1)
res

Unnamed: 0,finetuned_model,lang,AVG(ACC),STD(ACC)
0,jv-1-0.003-20000-4106092417-lora,jv,0.55,0.008036
1,jv-1-0.005-10000-4106092417-lora,jv,0.5625,0.016137
2,kn-1-0.003-20000-4106092417-lora,kn,0.535714,0.014411
3,kn-1-0.005-10000-4106092417-lora,kn,0.533835,0.015307
4,su-1-0.003-20000-4106092417-lora,su,0.570833,0.056728
5,su-1-0.005-10000-4106092417-lora,su,0.577,0.067868
6,sw-1-0.003-20000-4106092417-lora,sw,0.52459,0.011098
7,sw-1-0.005-10000-4106092417-lora,sw,0.53051,0.013516
8,yo-2-0.005-10000-4106092417-lora,yo,0.51406,0.014276


In [125]:
# select top 5 seeds by highest max_val_lang_mc_acc, stop training by oracle (magically select the best epoch to stop training)
res = pysqldf("""
SELECT finetuned_model, lang, AVG(max_test_lang_mc_acc) as 'AVG(ACC)', GROUP_CONCAT(max_test_lang_mc_acc) as max_test_lang_mc_acc_s FROM (
    SELECT *, RANK() OVER (PARTITION BY finetuned_model ORDER BY max_val_lang_mc_acc DESC) as rank_in_finetuned
    FROM all_runs_df
    WHERE max_mlm_perplexity <= 1000
)
WHERE max_val_lang_mc_acc <= 5
GROUP BY finetuned_model
""")
res['max_test_lang_mc_acc_s'] = res['max_test_lang_mc_acc_s'].map(lambda x: np.std(list(map(float, x.split(',')))))
res = res.rename({'max_test_lang_mc_acc_s': 'STD(ACC)'}, axis=1)
res

Unnamed: 0,finetuned_model,lang,AVG(ACC),STD(ACC)
0,jv-1-0.003-20000-4106092417-lora,jv,0.59475,0.017046
1,jv-1-0.005-10000-4106092417-lora,jv,0.594875,0.016854
2,kn-1-0.003-20000-4106092417-lora,kn,0.556955,0.011385
3,kn-1-0.005-10000-4106092417-lora,kn,0.563214,0.013849
4,su-1-0.003-20000-4106092417-lora,su,0.595,0.034479
5,su-1-0.005-10000-4106092417-lora,su,0.587875,0.029878
6,sw-1-0.003-20000-4106092417-lora,sw,0.559699,0.013267
7,sw-1-0.005-10000-4106092417-lora,sw,0.56392,0.008669
8,yo-2-0.005-10000-4106092417-lora,yo,0.550593,0.012225
