In [None]:
import matplotlib.pyplot as plt 
import seaborn as sns
import pandas as pd 
import numpy as np
import os
import json
from collections import defaultdict

from scipy.stats import spearmanr, pearsonr

from regmixer.synthesize_mixture import calculate_priors
from regmixer.utils import config_from_path
from regmixer.eval.utils import (
    build_regression,
    get_output_dir,
    get_runs_from_api,
    mk_run_from_json,
    mk_run_metrics,
    mk_weights_from_config,
    mk_output_prefix,
    plot_correlation,
    simulate2,
    )


import pathlib
import wandb
import boto3


from regmixer.eval.constants import GroupedWandbMetrics


from regmixer.eval.law import ScalingLaw

In [None]:

bucket = 'ai2-llm'
prefix = 'evaluation/regmixer/dclm-datadelve-1xC-30m-206f164f-0114_step11400-hf/'

s3 = boto3.client('s3')

# Get list of all files under the prefix
paginator = s3.get_paginator('list_objects_v2')
pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

json_files = []
jsonl_files = []

for page in pages:
    for obj in page.get('Contents', []):
        key = obj['Key']
        if not key.endswith('/'):
            if key.endswith('.jsonl'):
                jsonl_files.append(key)
            elif key.endswith('.json'):
                json_files.append(key)

print(f"Found {len(json_files)} .json files and {len(jsonl_files)} .jsonl files.")



Found 78 .json files and 228 .jsonl files.


In [6]:
all_jsonl_data = []

for key in jsonl_files:
    if key.endswith("metrics-all.jsonl"):
        obj = s3.get_object(Bucket=bucket, Key=key)
        for line in obj['Body'].iter_lines():
            if line:
                try:
                    data = json.loads(line.decode('utf-8'))
                    all_jsonl_data.append(data)
                except json.JSONDecodeError as e:
                    print(f"Error decoding JSON line in {key}: {e}")

        print(f"Loaded JSONL: {key}")


Loaded JSONL: evaluation/regmixer/dclm-datadelve-1xC-30m-206f164f-0114_step11400-hf/arc_challenge_rc_olmes_full-arc_easy_rc_olm-2b05f9/metrics-all.jsonl
Loaded JSONL: evaluation/regmixer/dclm-datadelve-1xC-30m-206f164f-0114_step11400-hf/codex_humaneval_bpb-gsm8k_olmes-mbpp_bpb-mi-93c5e5/metrics-all.jsonl
Loaded JSONL: evaluation/regmixer/dclm-datadelve-1xC-30m-206f164f-0114_step11400-hf/copycolors_10way_mc_none/metrics-all.jsonl


In [None]:
offline_results = {}
for data in all_jsonl_data:
    print(data['task_name'])
    try:
        print(data['metrics']['bits_per_byte_corr'])
        offline_results[data['task_name']] = data['metrics']['bits_per_byte_corr']
    except KeyError:
        print(data['metrics'])

mmlu:rc::olmes
{'logits_per_token_corr_micro': -5.803046457116754, 'logits_per_token_corr_macro': -6.003525432412653, 'primary_score_micro': 0.24982196268337845, 'primary_score_macro': 0.24397789362373604, 'logits_per_char_corr_micro': -1.4047309685393543, 'logits_per_char_corr_macro': -1.553624989274837, 'acc_per_byte_micro': 0.24982196268337845, 'acc_per_byte_macro': 0.24390116003976514, 'bits_per_byte_corr_micro': 2.008864397102229, 'bits_per_byte_corr_macro': 2.217158337604244, 'acc_per_char_micro': 0.24982196268337845, 'acc_per_char_macro': 0.24397789362373604, 'sum_logits_corr_micro': -44.97763202594838, 'sum_logits_corr_macro': -43.55604781917607, 'acc_uncond_micro': 0.2590086882210511, 'acc_uncond_macro': 0.25957308084745895, 'acc_per_token_micro': 0.25373878364905283, 'acc_per_token_macro': 0.24901827502602275, 'acc_raw_micro': 0.23344252955419456, 'acc_raw_macro': 0.22929394894307456, 'primary_score': 0.24397789362373604, 'extra_metrics': {'no_answer_micro': 0.0, 'no_answer_m

In [16]:
config = "src/regmixer/config/dclm-datadelve-5xC-30m-dolma2tok.yaml"
launch_config = config_from_path(config)


priors = calculate_priors(
        source_configs=launch_config.sources,
        dtype=launch_config.dtype,
        use_cache=False,
    )

Counting source tokens: 100%|██████████| 24/24 [00:32<00:00,  1.35s/it]


In [18]:
BASE_CACHE_DIR = "cache/"

experiment_groups = ["206f164f"]


cache_path = pathlib.Path(BASE_CACHE_DIR) / f"{'_'.join(experiment_groups)}_runs_cache.json"
api = wandb.Api() 

workspace = "ai2-llm/regmixer"

num_samples=1

group_metrics= 'all_bpb' #'arc_easy_new' #'mmlu_bpb_new' #'val_loss' #'mmlu_bpb_new'
eval_metric_group = GroupedWandbMetrics[group_metrics]
eval_metric_group_name = group_metrics

run_instances = get_runs_from_api(
        api, workspace, experiment_groups, cache_path, False, num_samples, eval_metric_group
    )
    
run_ratios = [
        {"run": run.display_name, "index": idx, **mk_weights_from_config(run.config, priors)}
        for idx, run in enumerate(run_instances)
    ]

2025-04-08 23:25:43,568 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0009:hsbm57e5 with samples: (1, 21)
2025-04-08 23:25:43,571 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0003:k9k2j1ia with samples: (1, 21)
2025-04-08 23:25:43,574 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0006:lbimptkr with samples: (1, 21)
2025-04-08 23:25:43,576 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0008:j7tpz28p with samples: (1, 21)
2025-04-08 23:25:43,580 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0001:qevd2gzk with samples: (1, 21)
2025-04-08 23:25:43,583 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0007:92o2jfpg with samples: (1, 21)
2025-04-08 23:25:43,587 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-20

finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished


2025-04-08 23:25:43,776 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0033:gkfm7f2n with samples: (1, 21)
2025-04-08 23:25:43,780 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0039:c7miu1ei with samples: (1, 21)
2025-04-08 23:25:43,787 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0036:by67gl2y with samples: (1, 21)
2025-04-08 23:25:43,792 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0040:zq9jw5bj with samples: (1, 21)
2025-04-08 23:25:43,797 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0041:q5dhzgbl with samples: (1, 21)
2025-04-08 23:25:43,801 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0042:e7cnja9l with samples: (1, 21)
2025-04-08 23:25:43,806 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-20

finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished


2025-04-08 23:25:43,981 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0094:6nz4jmao with samples: (1, 21)
2025-04-08 23:25:43,985 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0092:oljat221 with samples: (1, 21)
2025-04-08 23:25:43,988 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0091:n52v6vgn with samples: (1, 21)
2025-04-08 23:25:43,992 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0096:b2cmlop0 with samples: (1, 21)
2025-04-08 23:25:43,995 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0095:dod45l0w with samples: (1, 21)
2025-04-08 23:25:43,998 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-206f164f-0099:c7vnd43j with samples: (1, 21)
2025-04-08 23:25:44,002 - regmixer.eval.utils - INFO - Collected RunInstance for dclm-datadelve-1xC-30m-20

finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished
finished


In [22]:
run_metrics = []
for idx, run in enumerate(run_instances):
    if len(run.samples) == 0:
        continue 
    entry =    {
        "run": run.display_name,
        "index": idx,
        **mk_run_metrics(
            history=run.samples,
            samples=num_samples,
            metrics=(eval_metric_group_name, eval_metric_group.value),
            average=False,
        ),
    }
    run_metrics.append(entry)


In [23]:
ratios = pd.DataFrame(run_ratios)
metrics = pd.DataFrame(run_metrics)
ratios = ratios[ratios['run'].isin(metrics.run)]