# Understanding bracket closing in GPT-Neo

The goal of this notebook is to explore the phenomenon of bracket closing in the [GPT-Neo 125M model](https://www.eleuther.ai/artifacts/gpt-neo), whereby it can correctly match opening brackets `([{<` with their corresponding closing versions `)]}>`.

This is [Problem 2.13](https://www.alignmentforum.org/s/yivyHaCAmMJ3CqSyj/p/XNjRwEX9kxbpzWFWd#block71) in Neel Nanda's [200 Concrete Open Problems in Mechanistic Interpretability](https://www.alignmentforum.org/posts/LbrPTJ4fmABEdEnLf/200-concrete-open-problems-in-mechanistic-interpretability). The first goal is to figure out how the model determines whether an opening or closing bracket is more appropriate, and the second is to figure out how it knows the correct kind: `(`, `[`, `{` or `<`.

I'm using the [TransformerLens library](https://github.com/neelnanda-io/TransformerLens), and a lot of this notebook is copied from Neel's [Exploratory Analysis notebook](https://neelnanda.io/exploratory-analysis-demo).

This notebook lives in my [mechanistic interpretability GitHub repository](https://github.com/SamAdamDay/mechanistic-interpretability-projects), which also has some common utilities which I import below.

# Setup

In [1]:
# @title Config options

DEVELOPMENT_MODE = False  # @param {type:"boolean"}


In [2]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = True
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    # pip needs the git+ prefix to install directly from a git repository
    %pip install git+https://github.com/SamAdamDay/mechanistic-interpretability-projects.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as it's edited without restarting the kernel
    # (run_line_magic replaces the deprecated ipython.magic)
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

Running as a Jupyter notebook - intended for development only!
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload




In [3]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")


Using renderer: notebook_connected


Import things

In [4]:
import random
from pathlib import Path
from typing import List, Union, Optional
from functools import partial
import copy
import itertools
import dataclasses

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import numpy as np

import einops

from fancy_einsum import einsum

import tqdm.auto as tqdm

import plotly.express as px

from jaxtyping import Float, Int

from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import datasets

from IPython.display import HTML

import circuitsvis as cv

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import (
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)


Turn automatic differentiation off

In [5]:
torch.set_grad_enabled(False)


<torch.autograd.grad_mode.set_grad_enabled at 0x7f0e7c15c7d0>

Torch device

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"


Plotting helpers

In [7]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": xaxis, "y": yaxis},
        **kwargs
    ).show(renderer)


def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x": xaxis, "y": yaxis}, **kwargs).show(
        renderer
    )


def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(
        y=y, x=x, labels={"x": xaxis, "y": yaxis, "color": caxis}, **kwargs
    ).show(renderer)


# Task specification

The basic task is as follows.

**Task.** Given a string $s$ containing some brackets, determine: (1) if an opening or closing bracket is more appropriate and (2) which type of bracket is most appropriate.
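To make part (2) of the task concrete, here is a minimal sketch of a reference answer computed with a stack. The function name `expected_closer` is hypothetical and not part of the notebook or of TransformerLens. It naively treats every bracket character as structural, so it would be misled by something like a `>` comparison operator.

```python
from typing import Optional

# Map each opening bracket to its closing partner
PAIRS = {"(": ")", "[": "]", "{": "}", "<": ">"}
CLOSERS = set(PAIRS.values())


def expected_closer(s: str) -> Optional[str]:
    """Return the closer matching the innermost unclosed opener.

    Returns None when the brackets in `s` are balanced, and raises
    ValueError when a closer does not match the innermost opener.
    """
    stack = []  # closers we still expect, innermost last
    for ch in s:
        if ch in PAIRS:
            stack.append(PAIRS[ch])
        elif ch in CLOSERS:
            if not stack or stack[-1] != ch:
                raise ValueError(f"unexpected closer {ch!r}")
            stack.pop()
    return stack[-1] if stack else None


print(expected_closer("def line(tensor, renderer=None"))  # )
print(expected_closer("([])"))  # None
```

This only gives the "which type" half of the answer. Whether an opening or a closing bracket is more appropriate still depends on context that a stack cannot see.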

We'll be using the GPT-Neo 125M model.

In [8]:
model = HookedTransformer.from_pretrained(
    "gpt-neo-125M",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)


Using pad_token, but it is not set yet.


Loaded pretrained model gpt-neo-125M into HookedTransformer


We'll look at the brackets `([{<)]}>`.

In [9]:
brackets = (tuple("([{<"), tuple(")]}>"))
bracket_pairs = tuple(zip(*brackets))
brackets_flat = brackets[0] + brackets[1]

# Get the tokens both for the brackets on their own and with a space before
brackets_tokens = model.to_tokens(brackets_flat, prepend_bos=False).squeeze()
brackets_space_tokens = model.to_tokens(
    [" " + b for b in brackets_flat], prepend_bos=False
).squeeze()
brackets_all_tokens = torch.cat((brackets_tokens, brackets_space_tokens))

print("brackets:", brackets)
print("bracket_pairs:", bracket_pairs)
print("brackets_flat:", brackets_flat)
print("brackets_tokens:", brackets_tokens)
print("brackets_space_tokens:", brackets_space_tokens)
print("brackets_all_tokens:", brackets_all_tokens)


brackets: (('(', '[', '{', '<'), (')', ']', '}', '>'))
bracket_pairs: (('(', ')'), ('[', ']'), ('{', '}'), ('<', '>'))
brackets_flat: ('(', '[', '{', '<', ')', ']', '}', '>')
brackets_tokens: tensor([ 7, 58, 90, 27,  8, 60, 92, 29], device='cuda:0')
brackets_space_tokens: tensor([ 357,  685, 1391, 1279, 1267, 2361, 1782, 1875], device='cuda:0')
brackets_all_tokens: tensor([   7,   58,   90,   27,    8,   60,   92,   29,  357,  685, 1391, 1279,
        1267, 2361, 1782, 1875], device='cuda:0')


# Exploring model capability

How good is GPT-Neo at closing brackets? In this section I explore its capabilities and try to break it. 

I will explore the following variations on the string $s$.
- Whether the brackets are balanced or not.
- The type of brackets used.
- Whether we mix different types.
- The complexity of the bracket structure. This can be thought of as a tree, and we can consider varying both its depth and breadth.
- The complexity of the rest of the string.
- Whether $s$ looks like real code. I'll look at the following ways this could fail.
    * It's actually natural language.
    * It's like a programming language but has syntax errors.
    * It's valid syntax but the symbol names are gibberish/unnatural.
    * It consists only of brackets.
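The tree view of bracket structure mentioned above can be sketched as follows. This is an illustration under my own assumptions rather than something the notebook uses: `bracket_tree_shape` is a hypothetical helper, "depth" means the deepest nesting level, and "breadth" means the largest number of bracket groups sitting directly inside one parent (or at the top level). Bracket types are ignored, so mismatched prompts are still measured.

```python
OPENERS = "([{<"
CLOSERS = ")]}>"


def bracket_tree_shape(s: str) -> tuple:
    """Return (depth, breadth) of the bracket tree in `s`."""
    depth = 0
    breadth = 0
    # child_counts[i] is the number of bracket groups opened so far
    # directly inside the node at nesting level i (level 0 is the root)
    child_counts = [0]
    for ch in s:
        if ch in OPENERS:
            child_counts[-1] += 1
            breadth = max(breadth, child_counts[-1])
            child_counts.append(0)
            depth = max(depth, len(child_counts) - 1)
        elif ch in CLOSERS and len(child_counts) > 1:
            child_counts.pop()
    return depth, breadth


print(bracket_tree_shape("([({},[{()}])])"))  # (6, 2)
print(
    bracket_tree_shape(
        "load_model(build_structure(), get_hyperparameters(), (True, False))"
    )
)  # (2, 3)
```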

## The prompts

I will test the following prompts, to see what the model does.

In [10]:
exploratory_prompts = [
    "def line(tensor, renderer=None",
    "def line(tensor, renderer=None)",
    "exploratory_prompts = ['test'",
    "exploratory_prompts = ['test']",
    "exploratory_dict = {'test': 'four'",
    "exploratory_dict = {'test': 'four'}",
    "<template",
    "<template>",
    "def sieve(num, prime_list = [2, 3]",
    "def sieve(num, prime_list = [2, 3])",
    "exploratory_dict = {'test': [3, 5]",
    "exploratory_dict = {'test': [3, 5]}",
    "exploratory_dict = {'test': get_test()",
    "exploratory_dict = {'test': get_test()}",
    "html_to_markdown('<s>'",
    "html_to_markdown('<s>')",
    "<table id='name()'",
    "<table id='name()'>",
    "load_model(build_structure()",
    "load_model(build_structure())",
    "load_model(build_structure(), get_hyperparameters()",
    "load_model(build_structure(), get_hyperparameters())",
    "load_model(build_structure(), get_hyperparameters(), (True, False)",
    "load_model(build_structure(), get_hyperparameters(), (True, False))",
    "load_model(build_structure(), get_hyperparameters(True), (True, False), get_extra_config(x)",
    "load_model(build_structure(), get_hyperparameters(True), (True, False), get_extra_config(x))",
    "x.detach().cpu().to_numpy(",
    "x.detach().cpu().to_numpy()",
    "enumerate(list(zip([1,3,65], [1, 2, 3])",
    "enumerate(list(zip([1,3,65], [1, 2, 3]",
    "enumerate(list(zip([1,3,65], [1, 2, 3",
    "enumerate(list(zip([1,3,65], [1, 2, 3]))",
    "enumerate(list(zip([1,3,65], [1, 2, 3], [4, 3], [1], list(np.zeros(4))",
    "enumerate(list(zip([1,3,65], [1, 2, 3], [4, 3], [1], list(np.zeros(4)))",
    "zip(enumerate(list(zip([1,3,65], [1, 2, 3], [4, 3], [1], list(np.zeros(4)))), [3, 4]",
    "zip(enumerate(list(zip([1,3,65], [1, 2, 3], [4, 3], [1], list(np.zeros(4)))), [3, 4])",
    "zip(enumerate(list(zip([1,3,65], list(np.zeros(4)))), [3, 4], {1: 2}.items()",
    "zip(enumerate(list(zip([1,3,65], list(np.zeros(4)))), [3, 4], {1: 2}.items())",
    "px.imshow(to_numpy(tensor), ccm=0.0, ccs='RdBu', labels={'x':xaxis, 'y':yaxis}, **kwargs",
    "px.imshow(to_numpy(tensor), ccm=0.0, ccs='RdBu', labels={'x':xaxis, 'y':yaxis}, **kwargs)",
    "In the course our our analysis (which was long",
    "In the course our our analysis (which was long)",
    "He was eating a apple [sic",
    "He was eating a apple [sic]",
    "In the course our our analysis (which was long (though not too long",
    "In the course our our analysis (which was long (though not too long))",
    "def sieve(,num prime_list = 2[, 3]",
    "def sieve(,num prime_list = 2[, 3])",
    "defn line(tensor, renderer===None",
    "defn line(tensor, renderer===None)",
    "exploratory_dict = {'test': [3,} 5]",
    "exploratory_dict = {'test': [3,} 5",
    "exploratory_prompts = ['test'(]",
    "exploratory_prompts = ['test'(])",
    "def safasfd(oubefwef, vcewfec=afuasvfs",
    "def safasfd(oubefwef, vcewfec=afuasvfs)",
    "asdjhvauyrfsac = ['asdasdasd'",
    "asdjhvauyrfsac = ['asdasdasd']",
    "dfc = {'sdasd': 'casdasd'",
    "dfc = {'sdasd': 'casdasd'}",
    "<bwevzcxc",
    "<bwevzcxc>",
    "([]",
    "([])",
    "([({},[{()}])])",
    "([({},[{()}])]",
    "([({},[{()}])",
    "([({},[{()}]",
]
print("len(exploratory_prompts):", len(exploratory_prompts))

len(exploratory_prompts): 68


Let's convert the prompts to padded tokens, keeping track of each unpadded length, so we can find the next predicted token for each one.

In [11]:
# Compute the token length of each prompt, so we know where the next-token
# prediction will be
exp_prompt_token_lengths = []
for prompt in exploratory_prompts:
    prompt_tokens = model.to_tokens(prompt)
    exp_prompt_token_lengths.append(prompt_tokens.shape[1])

# Convert all the prompts to tokens, padding to make them the same length
exp_prompt_tokens = model.to_tokens(exploratory_prompts)
# Tensor.to is not in-place, so the result must be assigned back
exp_prompt_tokens = exp_prompt_tokens.to(device)

print("exp_prompt_tokens:", exp_prompt_tokens.shape)


exp_prompt_tokens: torch.Size([68, 47])
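The bookkeeping in the cell above can be illustrated on toy data. This is a plain-Python sketch, assuming padding is added on the right (which is what the length-based indexing in this notebook relies on). The helper `pad_right` and the token ids are made up for illustration.

```python
def pad_right(sequences, pad_id):
    """Pad each sequence on the right to a common length.

    Returns the padded sequences and the original (unpadded) lengths.
    """
    lengths = [len(seq) for seq in sequences]
    max_len = max(lengths)
    padded = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    return padded, lengths


# Toy token ids; 0 plays the role of the padding token
toy_sequences = [[5, 6, 7], [5, 8], [5, 9, 10, 11]]
padded, lengths = pad_right(toy_sequences, pad_id=0)

# The next-token prediction for each prompt is read off at the position of
# its last real token, which is index length - 1
last_positions = [n - 1 for n in lengths]
last_tokens = [row[i] for row, i in zip(padded, last_positions)]

print(padded)  # [[5, 6, 7, 0], [5, 8, 0, 0], [5, 9, 10, 11]]
print(last_positions)  # [2, 1, 3]
```

Reading the logits at the final column instead would, for the shorter prompts, give the prediction after a padding token rather than after the prompt.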


## Investigating performance

The following function runs the model, looks at the next-token prediction for each prompt, and computes the probability of each possible bracket (combining the versions with and without a leading space), conditioned on the next token actually being a bracket.

In [12]:
def compute_bracket_scores(
    prompt_tokens: Int[torch.Tensor, "batch pos"],
    prompt_token_lengths: Optional[list[int]] = None,
) -> Float[torch.Tensor, "batch n_brackets"]:
    """Computes the conditional prob that the next token is each bracket

    Conditioned on the next token actually being a bracket
    """

    num_prompts = prompt_tokens.shape[0]
    num_brackets = len(brackets_flat)

    all_logits = model(prompt_tokens, return_type="logits")  # batch pos d_vocab

    d_vocab = all_logits.shape[2]

    # Select the last token from each
    if prompt_token_lengths is None:
        logits = all_logits[:, prompt_tokens.shape[1] - 1, :]  # batch d_vocab
    else:
        indices = torch.tensor(prompt_token_lengths, device=device) - 1  # batch
        indices = indices.reshape((num_prompts, 1, 1))  # batch 1 1
        indices = indices.repeat((1, 1, d_vocab))  # batch 1 d_vocab
        # Squeeze only the pos dimension, so a batch of size 1 keeps its batch dim
        logits = torch.gather(all_logits, 1, indices).squeeze(1)  # batch d_vocab

    probs = F.softmax(logits, dim=1)  # batch d_vocab

    # Compute the probability for each bracket and spaced bracket, conditioned
    # on the fact that it is a bracket
    cond_probs = probs[:, brackets_all_tokens]  # batch (2 n_brackets)
    cond_probs = F.normalize(cond_probs, p=1.0, dim=1)

    # Combine the conditional probabilities for each bracket with its spaced
    # version
    cond_probs = cond_probs[:, :num_brackets] + cond_probs[:, num_brackets:]

    return cond_probs


bracket_scores = compute_bracket_scores(exp_prompt_tokens, exp_prompt_token_lengths)
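The conditioning step in `compute_bracket_scores` is easier to see on a small worked example. The sketch below redoes the arithmetic in plain Python with made-up probabilities. `conditional_bracket_probs` is a hypothetical helper, and unlike `F.normalize` it does not guard against a zero total.

```python
BRACKETS_FLAT = ("(", "[", "{", "<", ")", "]", "}", ">")


def conditional_bracket_probs(plain_probs, spaced_probs):
    """Renormalise bracket probabilities and merge spaced/unspaced tokens.

    `plain_probs[i]` is the probability of bracket i on its own and
    `spaced_probs[i]` the probability of the same bracket with a leading space.
    """
    total = sum(plain_probs) + sum(spaced_probs)  # P(next token is a bracket)
    return [(p + q) / total for p, q in zip(plain_probs, spaced_probs)]


# Made-up next-token probabilities: 40% of the mass is on bracket tokens
plain = [0.01, 0.0, 0.0, 0.0, 0.30, 0.02, 0.0, 0.0]
spaced = [0.0, 0.0, 0.0, 0.0, 0.07, 0.0, 0.0, 0.0]

scores = conditional_bracket_probs(plain, spaced)
for bracket, score in zip(BRACKETS_FLAT, scores):
    print(bracket, round(score, 3))
```

Here `)` ends up with (0.30 + 0.07) / 0.40 = 0.925 of the conditional mass, even though it has only 0.37 of the unconditional mass.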


Let's display these in a nice chart. I break it up into two since there are a lot of prompts.

In [15]:
def display_bracket_scores(
    prompts: list[str], bracket_scores: Float[torch.Tensor, "batch n_brackets"]
):
    """Display the bracket scores nicely"""
    num_prompts = len(prompts)
    fig = px.imshow(
        utils.to_numpy(bracket_scores),
        color_continuous_scale="blues",
        labels=dict(x="Bracket", color="Probability"),
        x=brackets_flat,
        y=prompts,
        height=30 * num_prompts,
    )
    for ix, bracket in enumerate(brackets_flat):
        for iy in range(num_prompts):
            fig.add_annotation(
                x=ix,
                y=iy,
                text=bracket,
                showarrow=False,
                font_color="orange",
            )
    fig.show()


In [16]:
display_bracket_scores(exploratory_prompts[:34], bracket_scores[:34, :])

In [17]:
display_bracket_scores(exploratory_prompts[34:], bracket_scores[34:, :])

## Discussion

- The model seems to do very well at predicting the correct closing bracket, and is robust to almost everything I've thrown at it.
- One important case which occurs a few times though is when I'm looking to have the model predict `)` but it actually predicts `[`, with `)` being the second-most likely next bracket. I investigate this a little more below.
- The model struggles a bit on the last prompts made purely of brackets and commas.
- When the brackets are balanced, the model's predictions vary a lot more and tend to be less confident. This seems reasonable: in general there aren't clear-cut best next-bracket answers in these cases.

Let's think about the case where the model predicts `[` instead of `)`. This happens more frequently with more complex prompts. I wouldn't say this is always wrong: it would be plausible to see a `[` in most of these cases, as a way of indexing some object, though usually this would be a bit weird. There are some cases where `[` is definitely not what we'd want, for example:
```python
zip(enumerate(list(zip([1,3,65], list(np.zeros(4)))), [3, 4], {1: 2}.items()
```
(The object returned by `items()` is not subscriptable.)

I would say that this observation raises a larger question about what exactly the task *is* and what the metric should be. The intuition is that a good model should be able to keep track of the open and closed brackets, and should prefer generating text which is *eventually* bracket-balanced. However, in the shorter term this may involve opening new brackets (after all, we wouldn't want the model to be biased towards immediately closing all brackets in the prompt). I can think of the following ways of approaching this.
1. The most direct way is to simply let the model continue to generate tokens, with the aim of seeing if the whole generated text is bracket-balanced and from there trying to understand how the model has done this. This would be a substantial undertaking, and beyond the scope of this small exploration.
2. Another option is to focus only on the bracket type we care about. In this case, we'd only compare the prediction for `(` with `)`, and ignore the comparison with `[` and `]`. Of course there may still be instances where opening with `(` is a reasonable choice for the model to make so this doesn't completely eliminate the problem.
3. The simplest way is to focus on clear-cut examples, where the only reasonable bracket is a closing one. I will go with this direction here, since it isolates more cleanly exactly what we want to investigate, which hopefully also makes the model behaviour more evident.
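For comparison, approach 2 could be sketched as below. This is an illustrative assumption about what such a metric might look like, not something computed in this notebook. `closing_preference` is a hypothetical helper, and the example row is made up to mimic the case where `[` wins overall but `)` still beats `(`.

```python
BRACKETS_FLAT = ("(", "[", "{", "<", ")", "]", "}", ">")
NUM_TYPES = len(BRACKETS_FLAT) // 2


def closing_preference(scores, opener):
    """Share of a bracket type's mass that goes to its closing version.

    `scores` is one row of bracket scores ordered as BRACKETS_FLAT.
    Returns None when neither bracket of that type has any mass.
    """
    i = BRACKETS_FLAT.index(opener)
    open_score = scores[i]
    close_score = scores[i + NUM_TYPES]
    total = open_score + close_score
    if total == 0:
        return None
    return close_score / total


# Made-up row: "[" is the single most likely bracket, ")" is second
row = [0.05, 0.50, 0.0, 0.0, 0.40, 0.05, 0.0, 0.0]
print(closing_preference(row, "("))  # 0.4 / 0.45, roughly 0.889
print(closing_preference(row, "{"))  # None
```

Restricting to one pair hides the competition from `[`, which is exactly the caveat noted in item 2.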