In [30]:
import wandb

# Initialize API
api = wandb.Api()

# Replace with your project name and run ID
project_name = "regmixer"
run_id = "pt9bj68g"

# Fetch the run
run = api.run(f"{project_name}/{run_id}")

# Get all metric names that start with "eval/downstream"
eval_downstream_metrics = [k for k in run.summary.keys() if k.startswith("eval/downstream")]

# Print the metric names
print(eval_downstream_metrics)


['eval/downstream/arc_challenge (BPB)', 'eval/downstream/arc_challenge (CE loss)', 'eval/downstream/arc_challenge (length-normalized accuracy)', 'eval/downstream/arc_challenge (log soft loss)', 'eval/downstream/arc_challenge (soft loss)', 'eval/downstream/arc_challenge_mc_5shot (BPB)', 'eval/downstream/arc_challenge_mc_5shot (CE loss)', 'eval/downstream/arc_challenge_mc_5shot (accuracy)', 'eval/downstream/arc_challenge_mc_5shot (log soft loss)', 'eval/downstream/arc_challenge_mc_5shot (soft loss)', 'eval/downstream/arc_challenge_mc_5shot_bpb (BPB)', 'eval/downstream/arc_challenge_rc_5shot (BPB)', 'eval/downstream/arc_challenge_rc_5shot (CE loss)', 'eval/downstream/arc_challenge_rc_5shot (length-normalized accuracy)', 'eval/downstream/arc_challenge_rc_5shot (log soft loss)', 'eval/downstream/arc_challenge_rc_5shot (soft loss)', 'eval/downstream/arc_challenge_rc_5shot_bpb (BPB)', 'eval/downstream/arc_easy (BPB)', 'eval/downstream/arc_easy (CE loss)', 'eval/downstream/arc_easy (accuracy)'

In [33]:
{k: run.summary[k] for k in eval_downstream_metrics}

{'eval/downstream/arc_challenge (BPB)': 2.252821922302246,
 'eval/downstream/arc_challenge (CE loss)': 1.5615724325180054,
 'eval/downstream/arc_challenge (length-normalized accuracy)': 0.2040133774280548,
 'eval/downstream/arc_challenge (log soft loss)': -1.416715145111084,
 'eval/downstream/arc_challenge (soft loss)': 0.2532850503921509,
 'eval/downstream/arc_challenge_mc_5shot (BPB)': 4.161789417266846,
 'eval/downstream/arc_challenge_mc_5shot (CE loss)': 2.8847720623016357,
 'eval/downstream/arc_challenge_mc_5shot (accuracy)': 0.21739129722118378,
 'eval/downstream/arc_challenge_mc_5shot (log soft loss)': -1.6992669105529785,
 'eval/downstream/arc_challenge_mc_5shot (soft loss)': 0.23509830236434937,
 'eval/downstream/arc_challenge_mc_5shot_bpb (BPB)': 4.161789417266846,
 'eval/downstream/arc_challenge_rc_5shot (BPB)': 2.071514368057251,
 'eval/downstream/arc_challenge_rc_5shot (CE loss)': 1.4361021518707275,
 'eval/downstream/arc_challenge_rc_5shot (length-normalized accuracy)': 0

In [8]:
eval_tasks = [metric.replace("eval/downstream/", "").split("(")[0].strip() for metric in eval_downstream_metrics]

In [10]:
import numpy as np
eval_tasks = np.unique(np.array(eval_tasks))

In [12]:
# For training runs where we don't expect the model to acquire MC (e.g., 1B-5xC, short 7B training runs)
tasks_small_compute = [
# OLMES Core 9(-ish) RC
"arc_challenge_test_rc_5shot",
"arc_easy_test_rc_5shot",
"hellaswag_rc_5shot", # 1K subset of HellaSwag
"winogrande_val_rc_5shot", # Helpful after 750M-5xC scale
"csqa_val_rc_5shot",
"piqa_val_rc_5shot",
"socialiqa_val_rc_5shot",

# Too noisy to be worth tracking
# "boolq_val_rc_5shot",
# "openbookqa_test_rc_5shot",

# MMLU RC
"mmlu_stem_val_rc_5shot",
"mmlu_humanities_val_rc_5shot",
"mmlu_social_sciences_val_rc_5shot",
"mmlu_other_val_rc_5shot",
"mmlu_stem_test_rc_5shot",
"mmlu_humanities_test_rc_5shot",
"mmlu_social_sciences_test_rc_5shot",
"mmlu_other_test_rc_5shot",

# Gen tasks BPB
"gsm8k_gold_bpb_5shot",
"minerva_math_algebra_gold_bpb_0shot",
"minerva_math_counting_and_probability_gold_bpb_0shot",
"minerva_math_geometry_gold_bpb_0shot",
"minerva_math_intermediate_algebra_gold_bpb_0shot",
"minerva_math_number_theory_gold_bpb_0shot",
"minerva_math_prealgebra_gold_bpb_0shot",
"minerva_math_precalculus_gold_bpb_0shot",
"codex_humaneval_gold_bpb_0shot",
"codex_mbpp_gold_bpb_0shot",

# Sanity check for MCQA ability
"copycolors_10way",
]

# For training runs where we expect the model to acquire MC
tasks_large_compute = [
# OLMES Core 9(-ish) MC
"arc_challenge_test_mc_5shot",
"arc_easy_test_mc_5shot",
"hellaswag_rc_5shot", # 1K subset of HellaSwag
"csqa_val_mc_5shot",
"piqa_val_mc_5shot",
"socialiqa_val_mc_5shot",
"winogrande_val_rc_5shot",

# Too noisy to be worth tracking
# "boolq_val_mc_5shot",
# "openbookqa_test_mc_5shot",

# MMLU MC BPB
"mmlu_stem_val_mc_5shot",
"mmlu_humanities_val_mc_5shot",
"mmlu_social_sciences_val_mc_5shot",
"mmlu_other_val_mc_5shot",
"mmlu_stem_test_mc_5shot",
"mmlu_humanities_test_mc_5shot",
"mmlu_social_sciences_test_mc_5shot",
"mmlu_other_test_mc_5shot",

# Gen tasks BPB
"gsm8k_gold_bpb_5shot",
"minerva_math_algebra_gold_bpb_0shot",
"minerva_math_counting_and_probability_gold_bpb_0shot",
"minerva_math_geometry_gold_bpb_0shot",
"minerva_math_intermediate_algebra_gold_bpb_0shot",
"minerva_math_number_theory_gold_bpb_0shot",
"minerva_math_prealgebra_gold_bpb_0shot",
"minerva_math_precalculus_gold_bpb_0shot",
"codex_humaneval_gold_bpb_0shot",
"codex_mbpp_gold_bpb_0shot",

# Sanity check for MCQA ability
"copycolors_10way",
]

In [17]:
tasks_small_compute = [task.replace("_test", "").replace("_val", "") for task in tasks_small_compute]

In [24]:
sorted(tasks_small_compute)

['arc_challenge_rc_5shot',
 'arc_easy_rc_5shot',
 'codex_humaneval_gold_bpb_0shot',
 'codex_mbpp_gold_bpb_0shot',
 'copycolors_10way',
 'csqa_rc_5shot',
 'gsm8k_gold_bpb_5shot',
 'hellaswag_rc_5shot',
 'minerva_math_algebra_gold_bpb_0shot',
 'minerva_math_counting_and_probability_gold_bpb_0shot',
 'minerva_math_geometry_gold_bpb_0shot',
 'minerva_math_intermediate_algebra_gold_bpb_0shot',
 'minerva_math_number_theory_gold_bpb_0shot',
 'minerva_math_prealgebra_gold_bpb_0shot',
 'minerva_math_precalculus_gold_bpb_0shot',
 'mmlu_humanities_rc_5shot',
 'mmlu_humanities_rc_5shot',
 'mmlu_other_rc_5shot',
 'mmlu_other_rc_5shot',
 'mmlu_social_sciences_rc_5shot',
 'mmlu_social_sciences_rc_5shot',
 'mmlu_stem_rc_5shot',
 'mmlu_stem_rc_5shot',
 'piqa_rc_5shot',
 'socialiqa_rc_5shot',
 'winogrande_rc_5shot']

In [25]:
sorted(eval_tasks)

['arc_challenge',
 'arc_challenge_mc_5shot',
 'arc_challenge_mc_5shot_bpb',
 'arc_challenge_rc_5shot',
 'arc_challenge_rc_5shot_bpb',
 'arc_easy',
 'arc_easy_mc_5shot',
 'arc_easy_mc_5shot_bpb',
 'arc_easy_ppl',
 'arc_easy_rc_5shot',
 'arc_easy_rc_5shot_bpb',
 'basic_arithmetic',
 'boolq',
 'boolq_mc_5shot',
 'boolq_mc_5shot_bpb',
 'boolq_rc_5shot',
 'boolq_rc_5shot_bpb',
 'commonsense_qa',
 'copa',
 'csqa_mc_5shot',
 'csqa_mc_5shot_bpb',
 'csqa_rc_5shot',
 'csqa_rc_5shot_bpb',
 'hellaswag',
 'hellaswag_mc_5shot',
 'hellaswag_mc_5shot_bpb',
 'hellaswag_rc_5shot',
 'hellaswag_rc_5shot_bpb',
 'mmlu_humanities_bpb',
 'mmlu_humanities_mc_5shot',
 'mmlu_humanities_mc_5shot_test',
 'mmlu_humanities_var',
 'mmlu_humanities_var_bpb',
 'mmlu_other_bpb',
 'mmlu_other_mc_5shot',
 'mmlu_other_mc_5shot_test',
 'mmlu_other_var',
 'mmlu_other_var_bpb',
 'mmlu_social_sciences_bpb',
 'mmlu_social_sciences_mc_5shot',
 'mmlu_social_sciences_mc_5shot_test',
 'mmlu_social_sciences_var',
 'mmlu_social_scien

In [23]:
sorted(set(tasks_small_compute).intersection(set(eval_tasks.tolist())))

['arc_challenge_rc_5shot',
 'arc_easy_rc_5shot',
 'csqa_rc_5shot',
 'hellaswag_rc_5shot',
 'piqa_rc_5shot',
 'socialiqa_rc_5shot',
 'winogrande_rc_5shot']