In [92]:
# Automatically reload edited modules (e.g. regmixer) before each cell executes.
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [93]:
import matplotlib.pyplot as plt 
import seaborn as sns
import pandas as pd 
import numpy as np
import os
import json
from collections import defaultdict

from scipy.stats import spearmanr, pearsonr

from regmixer.synthesize_mixture import calculate_priors
from regmixer.utils import config_from_path
from regmixer.eval.utils import (
    build_regression,
    get_output_dir,
    get_runs_from_api,
    mk_run_from_json,
    mk_run_metrics,
    mk_weights_from_config,
    mk_output_prefix,
    plot_correlation,
    simulate2,
    )


import pathlib
import wandb
import boto3


from regmixer.eval.constants import GroupedWandbMetrics


from regmixer.eval.law import ScalingLaw

In [94]:
# Launch configuration describing the data sources of the mixture.
config = "src/regmixer/config/dclm-datadelve-5xC-30m-dolma2tok.yaml"
launch_config = config_from_path(config)

# Per-source priors for the configured sources. The call counts source tokens (see the
# progress bar); use_cache=False presumably forces that count to be redone.
priors = calculate_priors(
    source_configs=launch_config.sources,
    dtype=launch_config.dtype,
    use_cache=False,
)

Counting source tokens: 100%|██████████| 24/24 [00:26<00:00,  1.12s/it]


In [95]:
# Per-run aggregate BPB keyed by run display name: in-loop vs. offline evaluation.
in_loop_final, offline_final = {}, {}

In [96]:
# --- Experiment selection -------------------------------------------------------------
BASE_CACHE_DIR = "cache/"

# Alternative selection used earlier: experiment_groups = ["f6600ba5"]
experiment_groups = ["206f164f", "4318c7a9"]

# Passed to get_runs_from_api; presumably a local cache of the fetched runs.
cache_path = pathlib.Path(BASE_CACHE_DIR) / f"{'_'.join(experiment_groups)}_runs_cache.json"
api = wandb.Api()

workspace = "ai2-llm/regmixer"

num_samples = 1

# Metric group to compare. Others tried: 'hellaswag_v2', 'arc_easy_new', 'mmlu_bpb_new', 'val_loss'.
group_metrics = 'all_bpb'
eval_metric_group = GroupedWandbMetrics[group_metrics]
eval_metric_group_name = group_metrics

# --- Fetch runs -----------------------------------------------------------------------
run_instances = get_runs_from_api(
    api, workspace, experiment_groups, cache_path, False, num_samples, eval_metric_group
)
if not run_instances:
    # Fail with a clear message instead of an IndexError on run_instances[0] below.
    raise ValueError(f"No runs found in {workspace} for experiment groups {experiment_groups}")

# Mixture weights of every run, derived from its config and the source priors.
run_ratios = [
    {"run": run.display_name, "index": idx, **mk_weights_from_config(run.config, priors)}
    for idx, run in enumerate(run_instances)
]

# Use the metric order from the group definition when the sample columns cover the
# same number of tasks.
tasks = list(run_instances[0].samples.columns)
if len(tasks) == len(eval_metric_group.value):
    tasks = eval_metric_group.value

# One metrics row per run that logged samples; "index" ties each row back to run_ratios.
run_metrics = [
    {
        "run": run.display_name,
        "index": idx,
        **mk_run_metrics(
            history=run.samples,
            samples=num_samples,
            metrics=(eval_metric_group_name, tasks),
            average=False,
        ),
    }
    for idx, run in enumerate(run_instances)
    if len(run.samples) > 0
]
if not run_metrics:
    # An empty frame has no "run" column; report the real problem instead.
    raise ValueError("None of the fetched runs have evaluation samples.")

metrics = pd.DataFrame(run_metrics)
# Keep only the mixtures whose runs produced metrics, so both frames cover the same runs.
ratios = pd.DataFrame(run_ratios)
ratios = ratios[ratios["run"].isin(metrics["run"])]

2025-04-10 02:24:12,789 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0009:hsbm57e5 with samples: (1, 21)
2025-04-10 02:24:12,792 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0003:k9k2j1ia with samples: (1, 21)
2025-04-10 02:24:12,795 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0006:lbimptkr with samples: (1, 21)
2025-04-10 02:24:12,797 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0008:j7tpz28p with samples: (1, 21)
2025-04-10 02:24:12,801 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0001:qevd2gzk with samples: (1, 21)
2025-04-10 02:24:12,805 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0007:92o2jfpg with samples: (1, 21)
2025-04-10 02:24:12,808 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-20

finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished


2025-04-10 02:24:12,992 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0042:e7cnja9l with samples: (1, 21)
2025-04-10 02:24:12,996 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0043:c7vtz3tu with samples: (1, 21)
2025-04-10 02:24:13,000 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0045:33fzr1sk with samples: (1, 21)
2025-04-10 02:24:13,008 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0044:lpoqrzf2 with samples: (1, 21)
2025-04-10 02:24:13,012 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0047:fubuc57y with samples: (1, 21)
2025-04-10 02:24:13,015 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0046:mel2x8fy with samples: (1, 21)
2025-04-10 02:24:13,021 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-20

finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished


2025-04-10 02:24:13,201 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0095:dod45l0w with samples: (1, 21)
2025-04-10 02:24:13,205 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0099:c7vnd43j with samples: (1, 21)
2025-04-10 02:24:13,211 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0122:zgbe4zed with samples: (1, 21)
2025-04-10 02:24:13,215 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0124:49kap6rm with samples: (1, 21)
2025-04-10 02:24:13,218 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0100:1bibcneo with samples: (1, 21)
2025-04-10 02:24:13,222 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0101:lelhivv4 with samples: (1, 21)
2025-04-10 02:24:13,227 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-20

finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished


2025-04-10 02:24:13,409 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-5xC-30m-subset-4318c7a9-0027:f0n5bn22 with samples: (1, 21)
2025-04-10 02:24:13,414 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-5xC-30m-subset-4318c7a9-0025:z6sjvy9n with samples: (1, 21)
2025-04-10 02:24:13,419 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-5xC-30m-subset-4318c7a9-0013:4mi0434d with samples: (1, 21)
2025-04-10 02:24:13,425 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-5xC-30m-subset-4318c7a9-0022:2428s0gr with samples: (1, 21)
2025-04-10 02:24:13,431 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-5xC-30m-subset-4318c7a9-0023:fkxevczj with samples: (1, 21)
2025-04-10 02:24:13,435 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-5xC-30m-subset-4318c7a9-0029:23c1kmvq with samples: (1, 21)
2025-04-10 02:24:13,440 - regmixer.eval.utils - INFO - Collected

finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished


In [97]:
def _list_result_files(s3, bucket, prefix):
    """Return ``(json_keys, jsonl_keys)`` for every object stored under ``prefix``.

    Folder placeholder keys (ending in '/') match neither suffix and are ignored.
    """
    json_keys, jsonl_keys = [], []
    paginator = s3.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        # Pages without matching objects carry no 'Contents' entry.
        for obj in page.get('Contents', []):
            key = obj['Key']
            if key.endswith('.jsonl'):
                jsonl_keys.append(key)
            elif key.endswith('.json'):
                json_keys.append(key)
    return json_keys, jsonl_keys


def _read_jsonl_records(s3, bucket, key):
    """Parse each non-empty line of the S3 object ``key`` as UTF-8 JSON.

    Lines that cannot be decoded or parsed are reported and skipped. The response
    body is always closed so its HTTP connection is released.
    """
    records = []
    body = s3.get_object(Bucket=bucket, Key=key)['Body']
    try:
        for line in body.iter_lines():
            if not line:
                continue
            try:
                records.append(json.loads(line.decode('utf-8')))
            except (UnicodeDecodeError, json.JSONDecodeError) as e:
                print(f"Error decoding JSON line in {key}: {e}")
    finally:
        body.close()
    return records


def get_offline_results(display_name, bucket='ai2-llm', step=11400, s3_client=None):
    """Load offline bits-per-byte results for one run from its S3 evaluation folder.

    Lists every object under
    ``s3://{bucket}/evaluation/regmixer/{display_name}_step{step}-hf/``, reads each
    ``metrics-all.jsonl`` file, and keeps one BPB value per task.

    Args:
        display_name: Run name used as the S3 folder prefix.
        bucket: S3 bucket holding the evaluation outputs.
        step: Checkpoint step encoded in the folder name.
        s3_client: Optional pre-built S3 client (e.g. to reuse one across runs);
            a new ``boto3.client('s3')`` is created when omitted.

    Returns:
        dict mapping ``task_name`` to ``metrics['bits_per_byte_corr']``, falling back
        to ``metrics['bits_per_byte_corr_macro']``. Records with neither metric (or
        without a task name) are printed and skipped. If a task appears in several
        files, the record read last wins.
    """
    # Preferred metric first; the macro variant is the fallback.
    bpb_keys = ('bits_per_byte_corr', 'bits_per_byte_corr_macro')

    prefix = f'evaluation/regmixer/{display_name}_step{step}-hf/'
    s3 = s3_client if s3_client is not None else boto3.client('s3')

    json_files, jsonl_files = _list_result_files(s3, bucket, prefix)
    print(f"Found {len(json_files)} .json files and {len(jsonl_files)} .jsonl files.")

    all_jsonl_data = []
    for key in jsonl_files:
        # Only the aggregated per-task metric files are needed here.
        if key.endswith("metrics-all.jsonl"):
            all_jsonl_data.extend(_read_jsonl_records(s3, bucket, key))
            print(f"Loaded JSONL: {key}")

    offline_results = {}
    for data in all_jsonl_data:
        task_name = data.get('task_name')
        task_metrics = data.get('metrics') or {}
        bpb_key = next((k for k in bpb_keys if k in task_metrics), None)
        if task_name is None or bpb_key is None:
            # No BPB metric (or malformed record): report it so gaps are visible.
            print(task_name)
            print(task_metrics.keys())
            continue
        offline_results[task_name] = task_metrics[bpb_key]

    return offline_results


In [98]:
def get_in_loop_results(metrics, display_name):
    """Return one run's in-loop metric columns, ordered by column name.

    Selects the rows of ``metrics`` whose ``run`` column equals ``display_name`` and
    returns ``{column: {row_index: value}}`` (the ``DataFrame.to_dict()`` layout),
    with the columns in sorted order. Columns map to empty dicts if the run is absent.
    """
    run_rows = metrics.loc[metrics["run"] == display_name]
    by_column = run_rows.to_dict()
    return {name: by_column[name] for name in sorted(by_column)}

In [99]:
# Runs from experiment group 206f164f (1xC); superseded by the next selection cell.
display_names = [
    "dclm-datadelve-1xC-30m-206f164f-0114",
    "dclm-datadelve-1xC-30m-206f164f-0048",
    "dclm-datadelve-1xC-30m-206f164f-0099",
]

In [128]:
# Runs from experiment group 4318c7a9 (5xC subset) used in the comparison below.
display_names = [
    "dclm-datadelve-5xC-30m-subset-4318c7a9-0019",
    "dclm-datadelve-5xC-30m-subset-4318c7a9-0054",
    "dclm-datadelve-5xC-30m-subset-4318c7a9-0059",
]

In [80]:
# Runs from the f6600ba5 experiment group. Select them only when that group is among
# `experiment_groups`. Otherwise a Restart & Run All would replace the selection above
# with runs that are absent from `metrics`, and the comparison below would hit empty
# in-loop lookups.
if "f6600ba5" in experiment_groups:
    display_names = [
        "dclm-datadelve-1xC-30m-f6600ba5-0000",
        "dclm-datadelve-1xC-30m-f6600ba5-0001",
        "dclm-datadelve-1xC-30m-f6600ba5-0002",
        "dclm-datadelve-1xC-30m-f6600ba5-0003",
    ]

In [129]:
# Metric columns of `metrics`: everything except the "run"/"index" bookkeeping columns,
# selected by name so the list stays correct if the column order changes.
in_loop_metrics = [col for col in metrics.columns if col not in ("run", "index")]

In [130]:
# For each selected run, gather its in-loop metrics (from `metrics`) and its offline
# evaluation results (from S3), both keyed by display name.
all_in_loop_results = {}
all_offline_results = {}

for display_name in display_names:
    print(display_name)
    all_in_loop_results[display_name] = get_in_loop_results(metrics, display_name)
    all_offline_results[display_name] = get_offline_results(display_name)

# TODO: per-run aggregates for `in_loop_final` / `offline_final` are not computed here.
# If restored, average both sides over the same matched task set. The offline results
# can include tasks (and ':' entries) with no in-loop counterpart.

dclm-datadelve-5xC-30m-subset-4318c7a9-0019
Found 78 .json files and 228 .jsonl files.
Loaded JSONL: evaluation/regmixer/dclm-datadelve-5xC-30m-subset-4318c7a9-0019_step11400-hf/arc_challenge_rc_olmes_full-arc_easy_rc_olm-2b05f9/metrics-all.jsonl
Loaded JSONL: evaluation/regmixer/dclm-datadelve-5xC-30m-subset-4318c7a9-0019_step11400-hf/codex_humaneval_bpb-gsm8k_olmes-mbpp_bpb-mi-93c5e5/metrics-all.jsonl
Loaded JSONL: evaluation/regmixer/dclm-datadelve-5xC-30m-subset-4318c7a9-0019_step11400-hf/copycolors_10way_mc_none/metrics-all.jsonl
dclm-datadelve-5xC-30m-subset-4318c7a9-0054
Found 78 .json files and 228 .jsonl files.
Loaded JSONL: evaluation/regmixer/dclm-datadelve-5xC-30m-subset-4318c7a9-0054_step11400-hf/arc_challenge_rc_olmes_full-arc_easy_rc_olm-2b05f9/metrics-all.jsonl
Loaded JSONL: evaluation/regmixer/dclm-datadelve-5xC-30m-subset-4318c7a9-0054_step11400-hf/codex_humaneval_bpb-gsm8k_olmes-mbpp_bpb-mi-93c5e5/metrics-all.jsonl
Loaded JSONL: evaluation/regmixer/dclm-datadelve-5xC

In [134]:
# Compare in-loop and offline BPB task by task across the selected runs. For each
# in-loop metric, print the per-run absolute gap of every matched offline task, then
# the Spearman rank correlation of the two series. `offline` / `in_loop` are left
# holding the pairs of the last metric for inspection in later cells.
# NOTE(review): with only a handful of runs the rank correlation can take only a few
# values (e.g. +/-1). It is nan when no offline task matched (the MMLU metrics in the
# saved output).
for in_loop_task in in_loop_metrics:
    print(in_loop_task)

    offline = []
    in_loop = []
    for display_name in display_names:
        offline_results = all_offline_results[display_name]
        in_loop_results = all_in_loop_results[display_name]

        for offline_task, offline_result in offline_results.items():
            # Skip task names containing ':' (presumably per-variant/sub-task entries).
            if ":" in offline_task:
                continue

            # NOTE(review): substring match. One in-loop metric can pick up several
            # offline tasks whose names it contains, duplicating that run's pair.
            if offline_task in in_loop_task:
                # Expects one metrics row per run; take its value for this column.
                in_loop_value = list(in_loop_results[in_loop_task].values())[0]
                offline.append(offline_result)
                in_loop.append(in_loop_value)
                print(np.abs(offline_result - in_loop_value))

    print(spearmanr(offline, in_loop)[0])

eval/downstream/mmlu_social_sciences_test_rc_5shot (BPB)
nan
eval/downstream/mmlu_humanities_test_rc_5shot (BPB)
nan
eval/downstream/mmlu_other_test_rc_5shot (BPB)
nan
eval/downstream/mmlu_stem_test_rc_5shot (BPB)
nan
eval/downstream/winogrande_val_rc_5shot (BPB)
0.2628160412999778
0.26228690258027254
0.25649253527455174
1.0
eval/downstream/socialiqa_val_rc_5shot (BPB)
0.1422456055122876
0.17007637560975275
0.14881031327864336
1.0
eval/downstream/piqa_val_rc_5shot (BPB)
0.039283609312411816
0.0420669020028579
0.03862640084337521
1.0
eval/downstream/minerva_math_algebra_gold_bpb_0shot (BPB)
0.013625204508043165
0.015425533501968403
0.016213785940850656
1.0
eval/downstream/minerva_math_counting_and_probability_gold_bpb_0shot (BPB)
0.007724095210176252
0.008337260193857166
0.009223778499173108
1.0
eval/downstream/minerva_math_geometry_gold_bpb_0shot (BPB)
0.006049832652074372
0.007186449992383714
0.008792798906896326
1.0
eval/downstream/minerva_math_intermediate_algebra_gold_bpb_0shot (BP

In [132]:
# Offline BPB values paired in the comparison loop for its last in-loop metric.
offline

[1.8088391577004221, 1.9688668355057606, 1.8277997576281988]

In [133]:
# In-loop BPB values paired with `offline`, in the same order (last metric of the loop).
in_loop

[2.0008933544158936, 2.167832136154175, 2.0140817165374756]

hellaswag eval/downstream/hellaswag_rc_5shot (BPB)
1.3073624167581064 1.3213847875595093


In [34]:
# NOTE(review): `in_loop_final` is initialised to {} earlier and no visible cell writes
# to it, so a fresh Restart & Run All displays {}. The saved output appears to come
# from an earlier session on the f6600ba5 runs.
in_loop_final

{'dclm-datadelve-1xC-30m-f6600ba5-0000': 1.2401875257492065,
 'dclm-datadelve-1xC-30m-f6600ba5-0001': 1.2832578420639038,
 'dclm-datadelve-1xC-30m-f6600ba5-0002': 1.232410192489624,
 'dclm-datadelve-1xC-30m-f6600ba5-0003': 1.3213847875595093}

In [35]:
# NOTE(review): `offline_final` is initialised to {} earlier and no visible cell writes
# to it, so a fresh Restart & Run All displays {}. The saved output appears to come
# from an earlier session on the f6600ba5 runs.
offline_final

{'dclm-datadelve-1xC-30m-f6600ba5-0000': 1.9530376928137323,
 'dclm-datadelve-1xC-30m-f6600ba5-0001': 1.9683715458215059,
 'dclm-datadelve-1xC-30m-f6600ba5-0002': 1.8260503961989123,
 'dclm-datadelve-1xC-30m-f6600ba5-0003': 2.2022111974618457}