# What's in a developmental phase? Training dynamics & Behavioural characterizations of grammar learning
*Authors: Oskar van der Wal & Marianne de Heer Kloots*

In this notebook, we are going to explore the grammar learning dynamics of a *decoder-only* language model (LM): Pythia-160m.
The *Pythia* model suite is interesting because it provides intermediate checkpoints from throughout training (Biderman et al., 2023). Moreover, for the smaller models we also have access to 10 different seeds for the same training setup; the seed changes both the random initialization of the model parameters and the order of the training data.

We'll use the performance on the BLiMP dataset as a way to quantify the model's grammar capabilities. To save you time, we have prepared the results on BLiMP for 24 checkpoints in advance. The first part of this notebook shows examples of how to load and visualize the BLiMP results.

The second part shows how Latent State Models, trained on properties of the LM's internals (i.e., the parameters in its weight and bias matrices), can be used to find distinct phases during training (Hu et al., 2023). It would be really interesting if low-level mathematical features of the model told us something about its higher-level capabilities, but it's not clear whether that should be the case! Do these phases overlap with phases in Pythia's grammar learning?

**Relevant papers:**
- 📄 [BLiMP (Warstadt et al., 2020)](https://aclanthology.org/2020.tacl-1.25.pdf)
- 📄 [Pythia (Biderman et al., 2023)](https://proceedings.mlr.press/v202/biderman23a/biderman23a.pdf)
- 📄 [Latent State Models (Hu et al., 2023)](https://arxiv.org/html/2308.09543v3)

© 2024 The authors.

## 0-Setup
You can use the following code to set things up + install the required dependencies when running on Colab.

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")

    # Note: because Altair V4 instead of V5 is used in Colab, some settings
    # will be different in the notebook.

    # Installing some dependencies
    !pip3 install altair hmmlearn

    import altair as alt
    alt.data_transformers.disable_max_rows()

    !wget https://raw.githubusercontent.com/ANN-HumLang/ANN-HumLang-tutorials/main/pythia/training_map.py
    MATRIX_METRICS_PATH="https://raw.githubusercontent.com/ANN-HumLang/ANN-HumLang-tutorials/main/pythia/matrix_metrics_160m.tsv"
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook")

    import altair as alt
    alt.data_transformers.enable("vegafusion")
    MATRIX_METRICS_PATH="matrix_metrics_160m.tsv"

## 1-Loading BLiMP results
Let's have a look at the dataframe containing all BLiMP results for the different Pythia 160m training checkpoints.

In [None]:
import pandas as pd

blimp = pd.read_csv("https://raw.githubusercontent.com/ANN-HumLang/ANN-HumLang-tutorials/main/pythia/blimp_160m.tsv", sep="\t")
# Drop the rows whose field is "Aggregate"
blimp = blimp[blimp["field"] != "Aggregate"]
blimp

BLiMP consists of sentence pairs, each made up of a correct and an incorrect variant. Accuracy on a grammar task is measured by testing whether the LM assigns a higher probability (log-likelihood) to the correct sentence than to the incorrect one. For example (from `determiner_noun_agreement`):
> Craig explored that grocery **store**.

vs.

> Craig explored that grocery **stores**.

These sentence-pairs are grouped into 67 *paradigms*. These paradigms are grouped into different *phenomena*, which in turn are part of one of the 4 *fields*: morphology, syntax, semantics, or syntaxsemantics.
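
To make the scoring rule and the grouping concrete, here is a minimal sketch of how a per-paradigm accuracy could be computed from sentence log-probabilities. The log-probability values, the paradigm names used as keys, and the helper `paradigm_accuracy` are all hypothetical and only serve as an illustration; the pre-computed results used in this notebook come from actually scoring the sentences with Pythia.

```python
from collections import defaultdict

# Hypothetical (paradigm, log P(correct), log P(incorrect)) triples.
# These numbers are made up for illustration; they are not real Pythia scores.
scored_pairs = [
    ("determiner_noun_agreement", -21.3, -24.8),
    ("determiner_noun_agreement", -19.0, -18.2),
    ("adjunct_island", -30.1, -33.5),
    ("adjunct_island", -27.4, -29.9),
]


def paradigm_accuracy(pairs):
    """Per paradigm, the fraction of pairs where the correct sentence scores higher."""
    hits = defaultdict(int)
    totals = defaultdict(int)
    for paradigm, logp_correct, logp_incorrect in pairs:
        totals[paradigm] += 1
        if logp_correct > logp_incorrect:
            hits[paradigm] += 1
    return {paradigm: hits[paradigm] / totals[paradigm] for paradigm in totals}


accuracy = paradigm_accuracy(scored_pairs)
print(accuracy)
```

Averaging these per-paradigm accuracies within a phenomenon or a field would give the coarser scores in the same way.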

In [None]:
print("-"*10)
print("Information about the BLiMP dataset:")
print("Number of paradigms: {}".format(len(blimp.paradigm.unique())))
print("Number of phenomena: {}".format(len(blimp.phenomenon.unique())))
print("Number of fields: {}".format(len(blimp.field.unique())))

print("-"*10)
print("Information about the Pythia checkpoints:")
print("Number of seeds: {}".format(len(blimp.seed.unique())))
print("Number of steps: {}".format(len(blimp.step.unique())))
print("First step: {}".format(blimp.step.min()))
print("Last step: {}".format(blimp.step.max()))

In case you need the full list of all paradigms, phenomena, or fields, you can use the following:

In [None]:
print(list(blimp.paradigm.unique()))
print(list(blimp.phenomenon.unique()))
print(list(blimp.field.unique()))

Since the results dataframe contains both accuracy and std for each BLiMP paradigm, we'll separate these for convenience. We'll also remove the std from `blimp`, so you can continue using this short-hand.

In [None]:
blimp_acc   = blimp[blimp["metric"] == "acc"]
blimp_std   = blimp[blimp["metric"] == "std"]

blimp       = blimp[blimp["metric"] == "acc"]
blimp

## 3-Visualizing the BLiMP results
In the following examples, Altair is used for visualizing the BLiMP results. It is possible to pass a dataframe directly, but using a URL instead reduces the size of the notebook considerably.

Of course, feel free to use the visualization library you're most comfortable with!




In [None]:
BLIMP_URL = "https://raw.githubusercontent.com/ANN-HumLang/ANN-HumLang-tutorials/main/pythia/blimp_160m_acc.csv"

In [None]:
import altair as alt

def get_rule_selecting_step(blimp, y_min=0.0):
    """Helper function used for visualizing the step at your mouse's cursor in the figure."""

    # Create a selection that chooses the nearest point & selects based on x-value
    try:
        nearest = alt.selection_point(nearest=True, on="pointerover",
                              fields=["step"], empty=False)
    except AttributeError:
        nearest = alt.selection(type='single', nearest=True, on='mouseover',
                        fields=['step'], empty='none')

    # Transparent selectors across the chart. This is what tells us
    # the x-value of the cursor
    try:
        selectors = alt.Chart(blimp).mark_point().encode(
            x="step:Q",
            opacity=alt.value(0),
        ).add_params(
            nearest
        )
    except AttributeError:
        selectors = alt.Chart(blimp).mark_point().encode(
            x="step:Q",
            opacity=alt.value(0),
        ).add_selection(
            nearest
        )

    # Draw a rule at the location of the selection
    rules = alt.Chart(blimp).mark_rule(color="gray").encode(
        x="step:Q",
    ).transform_filter(
        nearest
    )

    # Draw text labels near the points, and highlight based on selection
    text = rules.mark_text(align="left", dx=5, dy=-5).encode(
        text=alt.condition(nearest, "step:Q", alt.value(" ")), y=alt.datum(y_min)
    )

    return selectors, rules, text

In [None]:
line = alt.Chart(BLIMP_URL).mark_line(opacity=0.5).encode(
    x=alt.X('step:Q',scale=alt.Scale(type="sqrt")),
    y=alt.Y('mean(score):Q', scale=alt.Scale(domain=[0.5, 0.85]), title="BLiMP (mean)"),
    color=alt.Color("seed:N"),
    #tooltip=["seed","step","mean(score)"],
)

selectors, rules, text = get_rule_selecting_step(blimp, y_min=0.5)

alt.layer(
    line, selectors, rules, text
).properties(
    width=600, height=300
)

In [None]:
alt.Chart(BLIMP_URL).mark_line(opacity=0.5).encode(
    x=alt.X('step:Q',scale=alt.Scale(type="sqrt")),
    y=alt.Y('score:Q'),
    color=alt.Color("seed:N"),
).facet(facet="paradigm:N", columns=7)

In [None]:
# If you only want to look at one paradigm more closely
paradigm="blimp_superlative_quantifiers_2"
y_min=0.3
y_max=0.85
title=f"{paradigm} (mean)"

base = alt.Chart(BLIMP_URL)

bands = base.mark_errorband(extent="ci").encode(
    x="step:Q",
    y=alt.Y('score:Q', scale=alt.Scale(domain=[y_min, y_max])),
)

line = base.mark_line(opacity=0.5).encode(
    x=alt.X('step:Q',scale=alt.Scale(type="sqrt")),
    y=alt.Y('mean(score):Q', scale=alt.Scale(domain=[y_min, y_max])),
)

selectors, rules, text = get_rule_selecting_step(blimp, y_min=y_min)

# Put the five layers into a chart and bind the data
alt.layer(
    bands, line, selectors, rules, text
).transform_filter(
    f'datum.paradigm == "{paradigm}"'  # the paradigm name must be quoted inside the Vega expression
).properties(
    width=600, height=300, title=title
)

In [None]:
line = alt.Chart(blimp).mark_line(opacity=0.8).encode(
    x=alt.X('step:Q',scale=alt.Scale(type="sqrt")),
    y=alt.Y('mean(score):Q'),
    color=alt.Color("phenomenon:N"),
)

selectors, rules, text = get_rule_selecting_step(blimp)

alt.layer(
    line, selectors, rules, text
).properties(
    width=600, height=300
)

In [None]:
lines = alt.Chart(blimp).mark_line(opacity=0.8).encode(
    x=alt.X('step:Q',scale=alt.Scale(type="sqrt")),
    y=alt.Y('mean(score):Q', scale=alt.Scale(domain=[0.5, 1])),
    color=alt.Color("field:N"),
)

selectors, rules, text = get_rule_selecting_step(blimp, y_min=0.5)

# Put the four layers into a chart and bind the data
alt.layer(
    lines, selectors, rules, text
).properties(
    width=600, height=300
)

## 4-Training Latent State Models

Now that we have explored how the various BLiMP phenomena progress during training, we'll turn to another tool in our training dynamics toolbox: Latent State Models. Following Hu et al. (2023), we'll test whether these models, fitted to how general properties of the model's matrices change during training (e.g., the mean of the weight matrices, $\mu_w$), can help us understand the different phases of training.

The first step in fitting such a model, here a Hidden Markov Model (HMM), is choosing the number of hidden states $N$. We'll look for an $N$ that gives a low AIC and BIC while still reaching a high log-likelihood (LL).
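
As a rough sketch of how these criteria trade off fit against model size, the snippet below computes AIC ($2k - 2\,\mathrm{LL}$) and BIC ($k \ln n - 2\,\mathrm{LL}$) for a range of $N$. Everything numeric here is an assumption made for illustration: the log-likelihoods, the number of features and the number of observations are invented, and the parameter count assumes Gaussian emissions with diagonal covariance. The selection code in `training_map.py` may count parameters or pick $N$ differently.

```python
import math


def gaussian_hmm_param_count(n_states, n_features):
    """Free parameters of an HMM with diagonal-covariance Gaussian emissions (assumed)."""
    start = n_states - 1                      # initial state distribution sums to 1
    transitions = n_states * (n_states - 1)   # each transition-matrix row sums to 1
    means = n_states * n_features
    variances = n_states * n_features
    return start + transitions + means + variances


def aic(log_likelihood, n_params):
    return 2 * n_params - 2 * log_likelihood


def bic(log_likelihood, n_params, n_samples):
    return n_params * math.log(n_samples) - 2 * log_likelihood


# Hypothetical log-likelihoods for N = 1..5; not taken from the actual Pythia data.
hypothetical_ll = {1: -5200.0, 2: -4300.0, 3: -3900.0, 4: -3880.0, 5: -3870.0}
n_features, n_samples = 14, 1540  # assumed sizes, for illustration only

scores = {}
for n_states, ll in hypothetical_ll.items():
    k = gaussian_hmm_param_count(n_states, n_features)
    scores[n_states] = {"aic": aic(ll, k), "bic": bic(ll, k, n_samples)}

best_n_aic = min(scores, key=lambda n: scores[n]["aic"])
best_n_bic = min(scores, key=lambda n: scores[n]["bic"])
print(best_n_aic, best_n_bic)
```

In this toy setting the log-likelihood keeps rising with $N$, but the gains after three states are too small to offset the penalty for the extra parameters, so both criteria favour the same $N$.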

In [None]:
import sys

sys.path.insert(0, '.')

In [None]:
from training_map import HMMTrainingMapSelection

matrix_metrics = pd.read_csv(MATRIX_METRICS_PATH, sep="\t")
TS = HMMTrainingMapSelection(matrix_metrics)
TS.show_model_selection()

Based on the graph above, we'll select $N=3$ hidden states, which gives us the following diagram.

In [None]:
# Select N=3 components from model selection plot above
training_map = TS.get_training_map(3)
training_map.show()

Hu et al. (2023) use the distribution over the states that a model visits during training (the "Bag of States") as one way to compare training runs across seeds. For our Pythia-160m models, these distributions turn out to be very similar, so there is not much variation between the seeds. But perhaps the small differences that remain reflect important differences between the models?
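
For intuition, a bag of states can be sketched as the fraction of checkpoints that a run spends in each state. The state sequences below are invented, the helper names are hypothetical, and the total variation distance is just one simple way to compare two distributions; it is not necessarily the comparison used by Hu et al. (2023) or by `training_map.py`.

```python
from collections import Counter


def bag_of_states(state_sequence, n_states):
    """Fraction of checkpoints assigned to each state, in state order."""
    counts = Counter(state_sequence)
    total = len(state_sequence)
    return [counts.get(state, 0) / total for state in range(n_states)]


def total_variation(p, q):
    """Half the L1 distance between two discrete distributions."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))


# Hypothetical decoded state sequences (one label per checkpoint) for two seeds.
seed_a_states = [0, 0, 1, 1, 1, 2, 2, 2, 2, 2]
seed_b_states = [0, 0, 0, 1, 1, 2, 2, 2, 2, 2]

bag_a = bag_of_states(seed_a_states, n_states=3)
bag_b = bag_of_states(seed_b_states, n_states=3)
distance = total_variation(bag_a, bag_b)
print(bag_a, bag_b, distance)
```

A distance close to 0 means that the two seeds spend roughly the same share of training in each state, even if they switch states at different steps.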

In [None]:
training_map.bag_of_states_distributions # shape = (seed x state)

See the slides or the paper for a description of the metrics used to train the HMM Latent State Model; you can also find an overview here.

In [None]:
training_map.data.metrics
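
As a rough illustration of what matrix metrics of this kind measure, the sketch below computes a few summary statistics for a toy weight matrix. The mean and the trace are mentioned in this notebook; the variance and the L2 norm are included as assumed examples of similar statistics. The function and variable names are hypothetical and are not taken from `training_map.py`.

```python
import math


def matrix_metrics(weights):
    """Summary statistics of a weight matrix given as a list of rows."""
    flat = [w for row in weights for w in row]
    mean = sum(flat) / len(flat)
    variance = sum((w - mean) ** 2 for w in flat) / len(flat)
    l2_norm = math.sqrt(sum(w * w for w in flat))
    # Sum of the diagonal entries (up to the shorter side for non-square matrices)
    diagonal_length = min(len(weights), len(weights[0]))
    trace = sum(weights[i][i] for i in range(diagonal_length))
    return {"mean": mean, "variance": variance, "l2_norm": l2_norm, "trace": trace}


# Toy 2x2 "weight matrix"; real model matrices are far larger.
toy_weights = [
    [0.5, -0.5],
    [1.0, 1.0],
]

metrics = matrix_metrics(toy_weights)
print(metrics)
```

Tracking such statistics per checkpoint gives one time series per metric, which is the kind of input a latent state model can be fitted to.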

How do the states change over time? Let's plot them on top of one of the model matrix properties.

Note that we have pre-computed the model metrics for all training checkpoints, instead of only a selection as we did for BLiMP!

In [None]:
training_map.show_training_states("trace")

## 5-Labelling Checkpoints
As we've seen above, we can use the latent state models to get a "training map" that labels each checkpoint according to one of the three states.

In [None]:
# Get df from (seed, step) -> state
training_map.labeled_checkpoints.set_index(["seed","step"])

In [None]:
def reformat_blimp(df, training_map, column="paradigm"):
    """
    Reformat BLiMP dataframe, and add state labels as new column.

    column options ['paradigm', 'phenomenon', 'field']"""
    assert column in ['paradigm', 'phenomenon', 'field']

    if not column=="paradigm":
        df = df[["seed", "step", "score", column]].groupby(["seed", "step", column]).mean().reset_index()
    df_ = df.pivot(index=["seed", "step"], columns=[column], values=["score"]).score.reset_index()
    df_.columns.name = None
    df_ = df_.set_index(["seed","step"])
    df_["average"] = df_.mean(axis=1)

    cols = list(df_.columns)

    # Add the state label for each checkpoint
    labeled_checkpoints = training_map.labeled_checkpoints.set_index(["seed","step"])
    df_ = pd.merge(df_, labeled_checkpoints, left_index=True, right_index=True)
    return df_[["state",]+cols].reset_index()

paradigms_labeled = reformat_blimp(blimp, training_map, column="paradigm")
paradigms_labeled

In [None]:
def plot_labeled_ckpts_blimp(data, metric):
    line = alt.Chart(data).mark_line().encode(
        x=alt.X('step:Q',scale=alt.Scale(type="sqrt"), axis=alt.Axis(title="step")),
        y=alt.Y(metric+':Q', scale=alt.Scale(zero=False), axis=alt.Axis(title=metric)),
    )

    dots = alt.Chart(data).mark_circle(size=100).encode(
        x=alt.X('step:Q',scale=alt.Scale(type="sqrt"), axis=alt.Axis(title="step")),
        y=alt.Y(metric+':Q', scale=alt.Scale(zero=False)),
        color=alt.Color("state:N", scale=alt.Scale(range=training_map.plot_config["state_colors"])),
    )

    selectors, rules, text = get_rule_selecting_step(blimp)

    # rules = alt.Chart(pd.DataFrame({
    #     'step': [10000],
    #     'color': ['red']
    #     })).mark_rule().encode(
    #     x='step',
    #     color=alt.Color('color:N', scale=None)
    #     )

    return (line+dots+selectors+rules+text).facet(column="seed")

plot_labeled_ckpts_blimp(paradigms_labeled, "blimp_adjunct_island")

In [None]:
phenomena_labeled = reformat_blimp(blimp, training_map, column="phenomenon")

# print(blimp.phenomenon.unique())
# Choose from:
# ['Anaphor agreement' 'Argument structure' 'Binding'
# 'Control/raising' 'Determiner-Noun agreement' 'Ellipsis' 'Filler gap'
# 'Irregular forms' 'Island effects' 'NPI licensing' 'Quantifiers'
# 'Subject-Verb agreement']

plot_labeled_ckpts_blimp(phenomena_labeled, "Island effects")

In [None]:
fields_labeled = reformat_blimp(blimp, training_map, column="field")

plots = []
for field in ['morphology','semantics','syntax','syntaxsemantics']:
    plots.append(plot_labeled_ckpts_blimp(fields_labeled, field))
alt.vconcat(*plots)