In [1]:
# Requires Python 3.10 or higher
# "plotnine>=0.13.6",
# "polars>=1.5.0,<1.7.0",
# "srsly>=2.4.8",
from pathlib import Path
import polars as pl

import plotnine as pn
import srsly
import json
import numpy as np

In [2]:
# Output directory for every figure written by this notebook (created if missing).
plot_path = Path("plots")
plot_path.mkdir(exist_ok=True)

# Evaluation result JSON files, available at
# https://huggingface.co/datasets/EleutherAI/polypythias-evals
path = Path("data/pythia-results")

In [3]:
# Run metadata parsed from each result file's path: evaluation timestamp, seed,
# model size and training step. It is joined onto the scores via `filepath`.
fl_df = (
    pl.DataFrame({"filepath": [str(i) for i in path.rglob("*.json")]})
    .with_columns(
        # The fractional-seconds separator is escaped so that it matches only a
        # literal "."; that is what the `%.f` directive in the format expects.
        timestamp=pl.col("filepath").str.extract(r"results_(\d{4}-\d{2}-\d{2}T\d{2}-\d{2}-\d{2}\.\d{6})").str.to_datetime(r"%Y-%m-%dT%H-%M-%S%.f"),
        # `\d+` captures the whole seed number; a single `\d` would read "seed10" as 1.
        seed=pl.col("filepath").str.extract(r"seed(\d+)").cast(pl.Int64),
        # e.g. "pythia-14m" -> "14M"; sizes outside the Enum fail the cast loudly.
        model_size=pl.col("filepath").str.extract(r"pythia-(\d+m)").str.to_uppercase().cast(pl.Enum(["14M", "31M", "70M", "160M", "410M"])),
        step=pl.col("filepath").str.extract(r"step(\d+)").cast(pl.Int64),
    )
)

In [4]:
# Flatten every results file into one record per task entry of its "results"
# mapping, tagged with the source file so it can be joined to `fl_df` later.
res = []
for json_file in path.rglob("*.json"):
    payload = json.loads(json_file.read_text())
    res.extend(
        {"filepath": str(json_file), **task_scores}
        for task_scores in payload["results"].values()
    )

In [5]:
# Long-format scores: one row per (file, dataset, metric) holding a numeric value,
# enriched with the run metadata from `fl_df`. When the same evaluation exists in
# several files, only the rows from the most recent timestamp are kept.
df = (
    pl.from_dicts(res)
    # wide (one column per metric) -> long (metric, value)
    .unpivot(index=["filepath", "alias"], variable_name="metric")
    .drop_nulls()
    # keep only entries containing a digit, dropping non-numeric placeholders
    .filter(pl.col("value").str.contains(r"\d+"))
    .with_columns(
        value=pl.col("value").cast(pl.Float64),
        metric=pl.col("metric").str.replace(",none", ""),
    )
    .rename({"alias": "dataset"})
    .join(fl_df, on="filepath")
    # De-duplicate re-runs: the window max is computed over the whole frame
    # before filtering, so each group keeps only its latest-timestamp rows.
    .filter(
        pl.col("timestamp")
        == pl.col("timestamp").max().over(["dataset", "metric", "seed", "model_size", "step"])
    )
)

In [None]:
# Coverage sanity check: every (dataset, metric) pair should have results for all
# model sizes and seeds, and ideally for every checkpoint.
EXPECTED_SEEDS = 10
EXPECTED_MODEL_SIZES = 5
EXPECTED_STEPS = 27

check = df.group_by(["dataset", "metric"]).agg(
    pl.col("model_size").n_unique(),
    pl.col("seed").n_unique(),
    pl.col("step").n_unique(),
)

# `(... == n).all()` reports False when the counts differ across groups, whereas
# `.unique().item()` would raise ValueError in exactly that case.
print(f"{(check['seed'] == EXPECTED_SEEDS).all()=}")
print(f"{(check['model_size'] == EXPECTED_MODEL_SIZES).all()=}")
print(f"{(check['step'] == EXPECTED_STEPS).all()=}")

incomplete_datasets = check.filter(pl.col("step") < EXPECTED_STEPS)["dataset"].unique()
print(f"\nDatasets with fewer than {EXPECTED_STEPS} steps: {incomplete_datasets.len()}")
print(incomplete_datasets)

In [9]:
# Display labels for benchmark names, used with `.replace(...)` before plotting.
# Entries that map to themselves keep the raw task name in the figures.
dataset_map = {
    # Reasoning / knowledge
    "arc_challenge": "ARC (Challenge)",
    "arc_easy": "ARC (Easy)",
    # BBQ
    "bbq": "BBQ",
    "bbq_ambig": "bbq_ambig",
    "bbq_disambig": "bbq_disambig",
    # BLiMP
    "blimp_anaphor_gender_agreement": "BLiMP (Gender Agreement)",
    # CrowS-Pairs
    "crows_pairs_english": "CrowS-Pairs",
    "crows_pairs_english_age": "CrowS-Pairs (Age)",
    "crows_pairs_english_autre": "crows_pairs_english_autre",
    "crows_pairs_english_disability": "crows_pairs_english_disability",
    "crows_pairs_english_gender": "CrowS-Pairs (Gender)",
    "crows_pairs_english_nationality": "crows_pairs_english_nationality",
    "crows_pairs_english_physical_appearance": "crows_pairs_english_physical_appearance",
    "crows_pairs_english_race_color": "crows_pairs_english_race_color",
    "crows_pairs_english_religion": "crows_pairs_english_religion",
    "crows_pairs_english_sexual_orientation": "crows_pairs_english_sexual_orientation",
    "crows_pairs_english_socioeconomic": "crows_pairs_english_socioeconomic",
    # Other tasks
    "lambada_openai": "lambada_openai",
    "logiqa": "logiqa",
    "piqa": "piqa",
    "realtoxicityprompts_tiny": "realtoxicityprompts_tiny",
    "sciq": "SciQ",
    "simple_cooccurrence_bias": "Simple Co-occurrence Bias",
    "winogrande": "winogrande",
    "wsc": "wsc",
}

# One colour per model size; keys match the `model_size` Enum labels.
cmap = dict(zip(
    ["14M", "31M", "70M", "160M", "410M"],
    ["#003049", "#2a9d8f", "#d62828", "#9b5de5", "#ff9f1c"],
))

### Performance

In [10]:
def _load_agreement(tsv_path: str, value_col: str, metric: str) -> pl.DataFrame:
    """Load one agreement TSV in the (step, seed, value, dataset, model_size, metric) layout.

    Args:
        tsv_path: tab-separated file with `step`, `seed`, `benchmark`, `size` and `value_col`.
        value_col: column holding the agreement score; renamed to `value`.
        metric: constant label stored in the `metric` column.
    """
    return (
        pl.read_csv(tsv_path, separator="\t")
        .rename({value_col: "value", "benchmark": "dataset", "size": "model_size"})
        .select(["step", "seed", "value", "dataset", "model_size"])
        .with_columns(metric=pl.lit(metric))
    )


# Self-consistency and inter-seed agreement scores, stacked in one long frame.
dd = (
    pl.concat([
        _load_agreement("data/self_consistency.tsv", "kappa_w_last_step", "Self-Consistency"),
        _load_agreement("data/inter_seed.tsv", "kappa_w_0", "Inter-Seed Agreement"),
    ])
    .with_columns(
        # Sizes are presumably stored as bare numbers (e.g. 14); suffix "M" to match
        # the Enum labels. A native string expression replaces the per-row Python UDF
        # (nulls stay null, as before).
        model_size=pl.format("{}M", pl.col("model_size")).cast(
            pl.Enum(["14M", "31M", "70M", "160M", "410M"])
        ),
    )
)

In [11]:
# Performance data: accuracy from `df` plus the two agreement metrics from `dd`,
# summarised across seeds per (dataset, model size, step, metric).
plot_dataset = ["arc_easy", "sciq"]
pdata = (
    pl.concat([
        dd.filter(pl.col("dataset").is_in(plot_dataset)),
        (
            df
            # same column order as `dd` so the vertical concat lines up
            .select(dd.columns)
            .filter((pl.col("dataset").is_in(plot_dataset)) & (pl.col("metric") == "acc"))
            .with_columns(pl.col("metric").replace({"acc": "Accuracy"}))
        )
    ])
    .group_by(["dataset", "model_size", "step", "metric"])
    .agg(
        median=pl.col("value").median(),
        # Linear interpolation for the band edges, the same scheme `median()` uses
        # (polars' `quantile` default is "nearest").
        q25=pl.col("value").quantile(.25, interpolation="linear"),
        q75=pl.col("value").quantile(.75, interpolation="linear"),
    )
    .with_columns(pl.col("dataset").cast(pl.String).replace(dataset_map))
)

In [None]:
# Performance / consistency figure. Each (dataset, metric) panel shows the median
# across seeds per model size and checkpoint, with a ribbon spanning q25..q75.
# NOTE: uses the model-size `cmap` defined earlier; the training-maps section later
# rebinds `cmap`, so run the cells in order.
p = (
    pn.ggplot(pdata, pn.aes("step", "median", fill="model_size"))
    + pn.geom_line(pn.aes(colour="model_size"))
    + pn.geom_point(pn.aes(colour="model_size"), size=.8)
    + pn.geom_ribbon(pn.aes(ymin="q25", ymax="q75"), alpha=0.3)
    + pn.facet_wrap(["dataset", "metric"], scales="free_y")
    # NOTE(review): a log axis cannot place step <= 0. Confirm that no such
    # checkpoints are present, or that dropping them from this figure is intended.
    + pn.scale_x_log10(
        breaks=[1, 10, 100, 1000, 10_000, 100_000],
        labels=lambda x: [f"$10^{np.log10(v):.0f}$" if v > 0 else "0" for v in x]
    )
    # Only an exact zero is shortened to "0". Negative ticks (e.g. a kappa below
    # zero) keep their sign and value instead of being mislabelled "0".
    + pn.scale_y_continuous(labels=lambda x: [f"{v:.2f}" if v != 0 else "0" for v in x])
    + pn.scale_colour_manual(cmap)
    + pn.scale_fill_manual(cmap)
    + pn.labs(x="Checkpoints Across Training", y="", colour="", fill="")
    + pn.theme_bw(base_size=12)
    + pn.theme(
        plot_margin=0.005,
        plot_background=None,
        legend_box_spacing=0.005,
        legend_box_margin=0,
        figure_size=(8, 4),
        legend_position="top"
    )
    + pn.guides(colour=pn.guide_legend(nrow=1))
)
p.show()
p.save(plot_path / "performance.pdf")

### Gender bias

In [17]:
# Gender-bias benchmarks to plot, and the metrics to keep for them.
plot_dataset = [
    "blimp_anaphor_gender_agreement",
    "crows_pairs_english_gender",
    "simple_cooccurrence_bias",
]
metrics = [
    "acc",
    "pct_stereotype",
    "pct_male_preferred",
]

In [20]:
# Gender-bias summary: median and quartiles across seeds for every
# (dataset, metric, model size, step), with datasets relabelled for display.
pdata = (
    df
    .filter((pl.col("dataset").is_in(plot_dataset)) & (pl.col("metric").is_in(metrics)))
    # `metric` is part of the key, so a dataset reporting more than one of the
    # selected metrics is never pooled into a single median.
    .group_by(["dataset", "metric", "model_size", "step"])
    .agg(
        median=pl.col("value").median(),
        # Linear interpolation for the band edges, the same scheme `median()` uses
        # (polars' `quantile` default is "nearest").
        q25=pl.col("value").quantile(.25, interpolation="linear"),
        q75=pl.col("value").quantile(.75, interpolation="linear"),
    )
    .with_columns(pl.col("dataset").cast(pl.String).replace(dataset_map))
)

# The figure has one panel per dataset, so each dataset must contribute exactly one metric.
assert (pdata.group_by("dataset").agg(pl.col("metric").n_unique())["metric"] <= 1).all(), (
    "a plotted dataset reports more than one of the selected metrics"
)

In [None]:
# Gender-bias figure: one panel per dataset. The line is the median across seeds
# per model size and checkpoint; the ribbon spans q25..q75.
# NOTE: uses the model-size `cmap` defined earlier; the training-maps section later
# rebinds `cmap`, so run the cells in order.
p = (
    pn.ggplot(pdata, pn.aes("step", "median", fill="model_size"))
    + pn.geom_line(pn.aes(colour="model_size"))
    + pn.geom_point(pn.aes(colour="model_size"), size=.8)
    + pn.geom_ribbon(pn.aes(ymin="q25", ymax="q75"), alpha=0.3)
    + pn.facet_wrap("dataset", ncol=3, scales="free_y")
    # NOTE(review): a log axis cannot place step <= 0. Confirm that no such
    # checkpoints are present, or that dropping them from this figure is intended.
    + pn.scale_x_log10(
        breaks=[1, 10, 100, 1000, 10_000, 100_000],
        labels=lambda x: [f"$10^{np.log10(v):.0f}$" if v > 0 else "0" for v in x]
    )
    # Only an exact zero is shortened to "0"; negative ticks keep their sign and
    # value instead of being mislabelled "0".
    + pn.scale_y_continuous(labels=lambda x: [f"{v:.2f}" if v != 0 else "0" for v in x])
    + pn.scale_colour_manual(cmap)
    + pn.scale_fill_manual(cmap)
    + pn.labs(x="Checkpoints Across Training", y="", colour="", fill="")
    + pn.theme_bw(base_size=12)
    + pn.theme(
        plot_margin=0.005,
        plot_background=None,
        legend_box_spacing=0.005,
        legend_box_margin=0,
        figure_size=(8, 2.5),
        legend_position="top"
    )
    + pn.guides(colour=pn.guide_legend(nrow=1))
)
p.show()
p.save(plot_path / "bias.pdf")

### Training maps

In [30]:
# Palette for the training-map states, keyed by state id as a string ("0".."4").
# The keys are strings because `state` is cast to String before plotting.
# NOTE: this rebinds the model-size `cmap` used by the earlier figures.
cmap = dict(zip(map(str, range(5)), ["#377eb8", "#ff7f00", "#4daf4a", "#f781bf", "#a65628"]))

In [None]:
# Load per-checkpoint training-map scores. For each (state, seed), derive the step span
# [shaded_area_start, shaded_area_end] used to shade the plot background; only
# seed "0" keeps a span in the end.
tmap = (
    pl.read_csv(path.parent / "training_maps.tsv", separator="\t")
    # All expressions in this with_columns read the input columns. `is_in([3, 4])`
    # therefore sees the seed before it is cast to String, and the min/max windows
    # group on the original column values.
    .with_columns(
        # Seeds 3 and 4 go into the "Outlier" facet row; all other seeds go into "Stable".
        outlier=pl.when(pl.col("seed").is_in([3, 4])).then(pl.lit("Outlier")).otherwise(pl.lit("Stable")),
        # Strings, so seed/state act as discrete aesthetics and match `cmap`'s string keys.
        seed=pl.col("seed").cast(pl.String),
        state=pl.col("state").cast(pl.String),
        # First and last step observed for each (state, seed) pair.
        shaded_area_start=pl.col("step").min().over(["state", "seed"]),
        shaded_area_end=pl.col("step").max().over(["state", "seed"]),
    )
    # Sort by state, seed, and step to ensure the order is correct
    # (state and seed are Strings here, so this ordering is lexicographic).
    .sort(["state", "seed", "step"])
    # Adjust shaded_area_end to avoid gaps
    .with_columns(
        # Within each seed, the sorted start values are shifted up by one, and the
        # result is assigned back to the seed's rows in their current order: row i
        # receives the (i+1)-th smallest start of its seed (null on the last row).
        # NOTE(review): this equals "the next state's start" only where the row order
        # agrees with start order. Confirm against the data.
        shaded_area_other=pl.col("shaded_area_start").sort().shift(-1).over(["seed"]),
    )
    .with_columns(
        # Extend the span end to that next start when it lies further right. A null
        # comparison (no next value) falls through to the existing end.
        # Strings in then/otherwise are column names, not literals.
        shaded_area_end=pl.when(pl.col("shaded_area_other") > pl.col("shaded_area_end")).then("shaded_area_other").otherwise("shaded_area_end")
    )
    .drop("shaded_area_other")
    .with_columns(
        # Only seed "0" keeps its spans; all other seeds get nulls.
        shaded_area_start=pl.when(pl.col("seed") == "0").then("shaded_area_start").otherwise(None),
        # `.max()` applies to the whole when/then/otherwise expression, so every row of a
        # (seed, state) group gets the largest (possibly extended) end. This is null for
        # non-zero seeds.
        shaded_area_end=pl.when(pl.col("seed") == "0").then("shaded_area_end").otherwise(None).max().over(["seed", "state"]),
    )
)

# Background rectangles. The span and facet columns are constant within a
# (state, seed) pair, so collapsing to distinct rows leaves one row per pair.
rect = tmap.select(
    pl.col("state", "seed", "shaded_area_start", "shaded_area_end", "outlier")
).unique()

tmap

In [None]:
p = (
    pn.ggplot(tmap, pn.aes(shape="seed"))
    + pn.geom_line(pn.aes("step", "score"), alpha=0.2)
    + pn.geom_point(pn.aes("step", "score", colour="state"), size=2)
    + pn.facet_grid(rows="outlier")
    + pn.geom_rect(rect, pn.aes(ymin=-np.inf, ymax=np.inf, xmin="shaded_area_start", xmax="shaded_area_end", fill="state"), alpha=0.15)
    + pn.scale_x_sqrt(
        breaks=[0, 1000, 3000, 13000, 33000, 53000, 83000, 123000, 143000],
        labels=lambda x: [f"{v / 1000:.0f}k" if v > 0 else "0" for v in x]
    )
    + pn.scale_colour_manual(cmap, guide=None)
    + pn.scale_fill_manual(cmap, guide=None)
    + pn.scale_shape_manual(["o"] * 9, guide=None)
    + pn.labs(x="Checkpoints Across Training", y="", colour="", fill="")
    # Theme
    + pn.theme_bw(base_size=12)
    + pn.theme(
        plot_margin=0.005, 
        plot_background=None, 
        legend_box_spacing=0.005, 
        legend_box_margin=0, 
        figure_size=(4.5, 4), 
        legend_position="top"
    )
    + pn.guides(fill=pn.guide_legend(nrow=1), colour=pn.guide_legend(nrow=1))
)
p.show()
p.save(plot_path / "training_maps.pdf")