In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import warnings
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np

from regmixer.utils import config_from_path
from regmixer.eval.utils import (
    calculate_priors_with_manual,
    )

In [None]:
# Path to the swarm config used to build the prior -- replace with your own.
config = "src/regmixer/config/superswarm_conditional_dclm_stackedu.yaml"
launch_config = config_from_path(config)

# priors[1][0] holds a dictionary with the leaf-level prior (leaf domain name -> weight);
# it is compared against the proposed mixes as the 'manual prior' further down.
priors = calculate_priors_with_manual(
    source_configs=launch_config.sources,
    dtype=launch_config.dtype,
    use_cache=True,
    # Optional launch-config fields: pass None when the config does not define them.
    manual_prior=getattr(launch_config, "manual_prior", None),
    fixed_source_weights=getattr(launch_config, "fixed_source_weights", None),
)

# Format proposed mixes

Note: this notebook currently supports only the natural distribution and the mixes produced by regmixer. It does not yet support other hard-coded mixes, but adding them is easy: load the mix from its YAML file.

In [None]:
def get_source_mix(mix):
    """Collapse a leaf-level mix into a source-level mix.

    Keys of the form "<source>:<topic>" for the known sources (dclm, s2pdf,
    stack-edu, pes2o) are summed under "<source>"; any other key is carried
    over unchanged.

    Args:
        mix: mapping of leaf domain name -> weight.

    Returns:
        defaultdict(float) mapping source name (or unrecognized key) -> summed weight.
    """
    known_sources = ("dclm", "s2pdf", "stack-edu", "pes2o")
    totals = defaultdict(float)
    for domain, weight in mix.items():
        # First known source whose "<source>:" prefix matches, else the key itself.
        bucket = next(
            (src for src in known_sources if domain.startswith(src + ":")),
            domain,
        )
        totals[bucket] += weight
    return totals

def get_topic_mix(source, mix):
    """Return the topic-level mix of one source, renormalized to sum to 1.

    Args:
        source: source name, e.g. "pes2o"; leaf keys are matched by the
            prefix "<source>:".
        mix: mapping of leaf domain name -> weight.

    Returns:
        dict mapping each matching leaf key to its share of the source's total
        weight. Empty if no key matches. If the matching weights sum to zero
        (a mix that gives this source no weight), every share is 0.0 rather
        than raising ZeroDivisionError.
    """
    prefix = source + ":"
    topic_weights = {k: v for k, v in mix.items() if k.startswith(prefix)}
    total_weight = sum(topic_weights.values())
    if total_weight == 0:
        # Nothing to normalize by; report zero shares so downstream plotting still works.
        return {k: 0.0 for k in topic_weights}
    return {k: v / total_weight for k, v in topic_weights.items()}

In [None]:
def plot_mixes(mix_dict, desc):
    """Draw a grouped bar chart comparing several mixes domain by domain.

    Args:
        mix_dict: mapping of mix name -> mix, where a mix maps domain -> weight.
            Domains missing from a mix are drawn as 0 and left unannotated.
        desc: plot title.

    Mixes are ordered with numeric names (e.g. repetition factors "2", "10")
    first, by numeric value, followed by the remaining names alphabetically.

    Raises:
        ValueError: if mix_dict is empty (there is nothing to plot).
    """
    if not mix_dict:
        raise ValueError("plot_mixes: mix_dict is empty, nothing to plot")

    def _mix_order(name):
        # Plain string sorting would put "10" before "2"; compare numeric names by value.
        try:
            return (0, float(name), str(name))
        except (TypeError, ValueError):
            return (1, 0.0, str(name))

    ordered = dict(sorted(mix_dict.items(), key=lambda item: _mix_order(item[0])))

    # Union of domains across all mixes, so each x position has one bar per mix.
    all_keys = sorted(set().union(*(mix.keys() for mix in ordered.values())))
    values_matrix = [[mix.get(k, 0) for k in all_keys] for mix in ordered.values()]

    n_mixes = len(ordered)
    x = np.arange(len(all_keys))
    bar_width = 0.8 / n_mixes  # each group spans 80% of an x slot

    fig, ax = plt.subplots(figsize=(20, 8))

    for i, (label, values) in enumerate(zip(ordered.keys(), values_matrix)):
        # Shift each mix's bars so the group is centered on its tick.
        offset = (i - (n_mixes - 1) / 2) * bar_width
        bars = ax.bar(x + offset, values, width=bar_width, label=label)

        # Annotate each non-zero bar with its value.
        for bar in bars:
            height = bar.get_height()
            if height > 0:
                ax.text(
                    bar.get_x() + bar.get_width() / 2,
                    height + 0.01,  # small gap above the bar, in data units
                    f"{height:.3f}",
                    ha="center",
                    va="bottom",
                    fontsize=10,
                    rotation=90,
                )

    ax.set_xticks(x)
    ax.set_xticklabels(all_keys, rotation=90, fontsize=16)
    ax.set_ylabel("Weight")
    ax.set_title(desc, fontsize=20)
    # Legend entries are the mix names (keys of mix_dict), not data sources.
    ax.legend(title="Mix", fontsize=16)
    fig.tight_layout()
    plt.show()


Load the proposed mixes from the rmc-eval output.

In [None]:
# Leaf-level mix from the unconstrained rmc-eval run -- replace REPLACE_THIS with the run directory.
# The file holds a list of {"domain": ..., "weight": ...} records.
path = "output/0cb55cb5/REPLACE_THIS/olmo3_offline_tasks_log_linear_reg_1_samples_optimal.json"
with open(path, "r") as f:
    unconstrained_mix = {record['domain']: record['weight'] for record in json.load(f)}

In [None]:
# This folder holds one sub-directory per rmc-eval run where --repetition-factor was set;
# each run directory must contain config.json and the optimal-mix JSON. Replace the placeholder.
folder = "output/0cb55cb5/?????"
runs = sorted(os.listdir(folder))  # sorted so the run order (and printout) is reproducible

# Maps str(repetition_factor) -> proposed leaf-level mix (domain -> weight).
sweep_repetition_factor = {}
for run in runs:
    run_dir = os.path.join(folder, run)
    if not os.path.isdir(run_dir):
        # Skip stray files (e.g. .DS_Store) that are not run directories.
        continue
    config_path = os.path.join(run_dir, "config.json")
    mix_path = os.path.join(run_dir, "olmo3_offline_tasks_log_linear_reg_1_samples_optimal.json")

    # Named run_config so it does not overwrite `config` (the swarm config path).
    with open(config_path, "r") as f:
        run_config = json.load(f)

    # Runs whose config has no repetition_factor are treated as factor 1.
    repetition_factor = run_config.get('repetition_factor', 1)

    print(mix_path, repetition_factor)

    with open(mix_path, "r") as f:
        mix = json.load(f)

    rep_key = str(repetition_factor)
    if rep_key in sweep_repetition_factor:
        # Two runs with the same factor would silently replace each other; make it visible.
        warnings.warn(
            f"Duplicate repetition_factor {rep_key} in {run_dir}; keeping this (later) run."
        )
    sweep_repetition_factor[rep_key] = {m['domain']: m['weight'] for m in mix}


In [None]:
def build_topic_mixes(source, sweep_mixes, unconstrained, prior):
    """Collect the normalized topic-level mixes of one source for every mix to compare.

    Args:
        source: source name whose topics are compared, e.g. "pes2o".
        sweep_mixes: mapping of repetition-factor label -> leaf-level mix.
        unconstrained: the 'unconstrained' leaf-level mix.
        prior: the leaf-level prior mix.

    Returns:
        dict mapping mix name -> topic mix (as returned by get_topic_mix): the
        sweep entries in their existing order, then 'unconstrained' and 'manual prior'.
    """
    topic_mixes = {rep: get_topic_mix(source, mix) for rep, mix in sweep_mixes.items()}
    topic_mixes['unconstrained'] = get_topic_mix(source, unconstrained)
    topic_mixes['manual prior'] = get_topic_mix(source, prior)
    return topic_mixes


# One comparison dict per source, each plotted below.
all_pes2o_mixes = build_topic_mixes("pes2o", sweep_repetition_factor, unconstrained_mix, priors[1][0])
all_s2pdf_mixes = build_topic_mixes("s2pdf", sweep_repetition_factor, unconstrained_mix, priors[1][0])

In [None]:
# One chart per source comparing its topic-level mixes across all proposed mixes.
for topic_mixes, title in ((all_s2pdf_mixes, "s2pdf"), (all_pes2o_mixes, "pes2o")):
    plot_mixes(topic_mixes, title)