# Understanding bracket closing in GPT-Neo

The goal of this notebook is to explore the phenomenon of bracket closing in the [GPT-Neo 125M model](https://www.eleuther.ai/artifacts/gpt-neo), whereby it can correctly match open parentheses `([{<` with their corresponding closing versions `)]}>`.

This is [Problem 2.13](https://www.alignmentforum.org/s/yivyHaCAmMJ3CqSyj/p/XNjRwEX9kxbpzWFWd#block71) in Neel Nanda's [200 Concrete Open Problems in Mechanistic Interpretability](https://www.alignmentforum.org/posts/LbrPTJ4fmABEdEnLf/200-concrete-open-problems-in-mechanistic-interpretability). The first goal is to figure out how the model determines whether an opening or closing bracket is more appropriate, and the second is to figure out how it knows the correct kind: `(`, `[`, `{` or `<`.

I'm using the [TransformerLens library](https://github.com/neelnanda-io/TransformerLens), and a lot of this notebook is copied from Neel's [Exploratory Analysis notebook](https://neelnanda.io/exploratory-analysis-demo).

This notebook lives in my [mechanistic interpretability GitHub repository](https://github.com/SamAdamDay/mechanistic-interpretability-projects).

# Setup

In [1]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install https://github.com/SamAdamDay/mechanistic-interpretability-projects.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Jupyter notebook - intended for development only!
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  ipython.magic("load_ext autoreload")
  ipython.magic("autoreload 2")


In [2]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

if IN_COLAB:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")


Using renderer: notebook_connected


Import things

In [3]:
import random
from pathlib import Path
from typing import List, Union, Optional
from functools import partial
import copy
import itertools
import dataclasses
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import numpy as np

import einops

from fancy_einsum import einsum

import tqdm.auto as tqdm

import plotly.express as px

from jaxtyping import Float, Int

from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import datasets

from IPython.display import HTML

import circuitsvis as cv

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import (
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)


Turn automatic differentiation off

In [4]:
torch.set_grad_enabled(False)


<torch.autograd.grad_mode.set_grad_enabled at 0x7f765360bd90>

Torch device

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"


Plotting helpers

In [6]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": xaxis, "y": yaxis},
        **kwargs
    ).show(renderer)


def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)
# def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
#     px.line(utils.to_numpy(tensor), labels={"x": xaxis, "y": yaxis}, **kwargs).show(
#         renderer
#     )


def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(
        y=y, x=x, labels={"x": xaxis, "y": yaxis, "color": caxis}, **kwargs
    ).show(renderer)


# Task specification

The basic task is as follows.

**Task.** Given a string $s$ containing some brackets, determine: (1) if an opening or closing bracket is more appropriate and (2) which type of bracket is most appropriate.

We'll be using the GPT-Neo 125M model.

In [7]:
model = HookedTransformer.from_pretrained(
    "gpt-neo-125M",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)


Using pad_token, but it is not set yet.


Loaded pretrained model gpt-neo-125M into HookedTransformer


We'll look at the brackets `([{<)]}>`. We want to match all tokens which either begin with a bracket, or a space followed by the bracket. We want to match for example the token ').' for the bracket `)`.

In [8]:
brackets = (tuple("([{<"), tuple(")]}>"))
bracket_pairs = tuple(zip(*brackets))
brackets_flat = brackets[0] + brackets[1]
num_brackets = len(brackets_flat)

# Get all (non-endoftext tokens) tokens which either start with a bracket or a
# space followed by a bracket
bracket_tokens = OrderedDict([(bracket, []) for bracket in brackets_flat])
bracket_token_strs = OrderedDict([(bracket, []) for bracket in brackets_flat])
all_tokens = model.to_str_tokens(np.arange(model.cfg.d_vocab - 1), prepend_bos=False)
for i, token_str in enumerate(all_tokens):
    for bracket in brackets_flat:
        if token_str.startswith(bracket) or token_str.startswith(" " + bracket):
            bracket_tokens[bracket].append(i)
            bracket_token_strs[bracket].append(token_str)
for bracket, tokens in bracket_tokens.items():
    bracket_tokens[bracket] = torch.tensor(tokens)

# Flatten the dict of tokens, and record the sizes of each list
bracket_tokens_flat = torch.cat(list(bracket_tokens.values()))
bracket_tokens_sizes = [tokens.shape[0] for tokens in bracket_tokens.values()]

# Select the open and closed bracket tokens
num_open_bracket_tokens = sum(bracket_tokens_sizes[:num_brackets // 2])
open_bracket_tokens = bracket_tokens_flat[:num_open_bracket_tokens]
closed_bracket_tokens = bracket_tokens_flat[num_open_bracket_tokens:]

print("brackets:", brackets)
print("bracket_pairs:", bracket_pairs)
print("brackets_flat:", brackets_flat)
for bracket, token_strs in bracket_token_strs.items():
    print(f"{bracket} tokens:", token_strs)
print("bracket_tokens_flat:", bracket_tokens_flat.shape)
print("open_bracket_tokens:", open_bracket_tokens.shape)
print("closed_bracket_tokens:", closed_bracket_tokens.shape)
print("bracket_tokens_sizes:", bracket_tokens_sizes)


brackets: (('(', '[', '{', '<'), (')', ']', '}', '>'))
bracket_pairs: (('(', ')'), ('[', ']'), ('{', '}'), ('<', '>'))
brackets_flat: ('(', '[', '{', '<', ')', ']', '}', '>')
( tokens: ['(', ' (', '()', ' (@', ' ("', ' ($', '("', ' ()', '();', "('", ' (+', ' (%)', ' (-', ' ();', ' ((', '({', '($', ' (#', " ('", '((', ' (.', ' (*', '().', ' (!', '(),', ' (£', '([', ' ().', '(_', '())', ' ([', ' (),', ' (~', '(-', ' (?,', ' ())', '():', '());', ' (&', ' (−', ' (%', ' ({', '(\\', ' (<', ' ());', '(&', '(){', ' (_', ' (>', ' ($)', ' (=', '(*', ' (/']
[ tokens: ['[', ' [', '[/', ' ["', '["', ' [[', ' []', "['", '[]', ' […]', ' [];', ' [-', ' [+', ' [...]', '[_', '[[', ' [*', ' [*]', " ['", ' [/', ' [+]', ' [(', ' [|', ' [&']
{ tokens: ['{', ' {', '{"', ' {"', ' {{', ' {}', '{{', '{\\', ' {\\', ' {:', ' {*']
< tokens: ['<', ' <', '</', ' </', ' <<', '<<', ' <=', ' <-', ' <[', ' <@', ' <!--', ' <+', '<?']
) tokens: [')', ').', '),', ' )', ');', '):', '))', ' );', ')(', ' ).', ' ),', ')-', ')|

# Exploring model capability

How good is GPT-Neo at closing brackets? In this section I explore its capabilities and try to break it. 

I will explore the following variations on the string $s$.
- Whether the brackets are balanced or not.
- The type of brackets used.
- Whether we mix different types.
- The complexity of the bracket structure. This can be thought of as a tree, and we can consider varying both its depth and breadth.
- The complexity of the rest of the string.
- Whether $s$ looks like real code. I'll look at the following ways this could fail.
    * It's actually natural language.
    * It's like a programming language but has syntax errors.
    * It's valid syntax but the symbol names are gibberish/unnatural.
    * It consists only of brackets.

## Exploratory prompts

I will test the following prompts, to see what the model does.

In [9]:
exploratory_prompts = [
    "def line(tensor, renderer=None",
    "def line(tensor, renderer=None)",
    "exploratory_prompts = ['test'",
    "exploratory_prompts = ['test']",
    "array[0",
    "array[0]",
    "exploratory_dict = {'test': 'four'",
    "exploratory_dict = {'test': 'four'}",
    "<template",
    "<template>",
    "def sieve(num, prime_list = [2, 3]",
    "def sieve(num, prime_list = [2, 3])",
    "exploratory_dict = {'test': [3, 5]",
    "exploratory_dict = {'test': [3, 5]}",
    "exploratory_dict = {'test': get_test()",
    "exploratory_dict = {'test': get_test()}",
    "html_to_markdown('<s>'",
    "html_to_markdown('<s>')",
    "<table id='name()'",
    "<table id='name()'>",
    "load_model(build_structure()",
    "load_model(build_structure())",
    "load_model(build_structure(), get_hyperparameters()",
    "load_model(build_structure(), get_hyperparameters())",
    "load_model(build_structure(), get_hyperparameters(), (True, False)",
    "load_model(build_structure(), get_hyperparameters(), (True, False))",
    "load_model(build_structure(), get_hyperparameters(True), (True, False), get_extra_config(x)",
    "load_model(build_structure(), get_hyperparameters(True), (True, False), get_extra_config(x))",
    "x.detach().cpu().to_numpy(",
    "x.detach().cpu().to_numpy()",
    "enumerate(list(zip([1,3,65], [1, 2, 3])",
    "enumerate(list(zip([1,3,65], [1, 2, 3]",
    "enumerate(list(zip([1,3,65], [1, 2, 3",
    "enumerate(list(zip([1,3,65], [1, 2, 3]))",
    "enumerate(list(zip([1,3,65], [1, 2, 3], [4, 3], [1], list(np.zeros(4))",
    "enumerate(list(zip([1,3,65], [1, 2, 3], [4, 3], [1], list(np.zeros(4)))",
    "zip(enumerate(list(zip([1,3,65], [1, 2, 3], [4, 3], [1], list(np.zeros(4)))), [3, 4]",
    "zip(enumerate(list(zip([1,3,65], [1, 2, 3], [4, 3], [1], list(np.zeros(4)))), [3, 4])",
    "zip(enumerate(list(zip([1,3,65], list(np.zeros(4)))), [3, 4], {1: 2}.items()",
    "zip(enumerate(list(zip([1,3,65], list(np.zeros(4)))), [3, 4], {1: 2}.items())",
    "list())))).append(x",
    "list())))).append(x)",
    "px.imshow(to_numpy(tensor), ccm=0.0, ccs='RdBu', labels={'x':xaxis, 'y':yaxis}, **kwargs",
    "px.imshow(to_numpy(tensor), ccm=0.0, ccs='RdBu', labels={'x':xaxis, 'y':yaxis}, **kwargs)",
    "In the course our our analysis (which was long",
    "In the course our our analysis (which was long)",
    "He was eating a apple [sic",
    "He was eating a apple [sic]",
    "In the course our our analysis (which was long (though not too long",
    "In the course our our analysis (which was long (though not too long))",
    "def sieve(,num prime_list = 2[, 3]",
    "def sieve(,num prime_list = 2[, 3])",
    "defn line(tensor, renderer===None",
    "defn line(tensor, renderer===None)",
    "exploratory_dict = {'test': [3,} 5]",
    "exploratory_dict = {'test': [3,} 5",
    "exploratory_prompts = ['test'(]",
    "exploratory_prompts = ['test'(])",
    "def safasfd(oubefwef, vcewfec=afuasvfs",
    "def safasfd(oubefwef, vcewfec=afuasvfs)",
    "asdjhvauyrfsac = ['asdasdasd'",
    "asdjhvauyrfsac = ['asdasdasd']",
    "dfc = {'sdasd': 'casdasd'",
    "dfc = {'sdasd': 'casdasd'}",
    "<bwevzcxc",
    "<bwevzcxc>",
    "([]",
    "([])",
    "([({},[{()}])])",
    "([({},[{()}])]",
    "([({},[{()}])",
    "([({},[{()}]",
]
num_exploratory_prompts = len(exploratory_prompts)
print("num_exploratory_prompts:", num_exploratory_prompts)

num_exploratory_prompts: 72


Let's convert the prompts to padded tokens, keeping track of each unpadded length, so we can find the next predicted token for each one.

In [10]:
# Compute the token length of each prompt, so we know where the next-token
# prediction will be
exp_prompt_token_lengths = []
for prompt in exploratory_prompts:
    prompt_tokens = model.to_tokens(prompt)
    exp_prompt_token_lengths.append(prompt_tokens.shape[1])

# Convert all the prompts to tokens, padding to make them the same length
exp_prompt_tokens = model.to_tokens(exploratory_prompts)
exp_prompt_tokens.to(device)

print("exp_prompt_tokens:", exp_prompt_tokens.shape)


exp_prompt_tokens: torch.Size([72, 47])


## Investigating performance

The following function computes runs the model, looks at the predictions for the next tokens for each prompt, and computes the probability that it each possible bracket (including spaces), conditioned on that it actually is a bracket.

In [11]:
def compute_bracket_scores(
    prompt_tokens: Float[torch.Tensor, "batch pos"],
    prompt_token_lengths: Optional[list[int]] = None,
) -> Float[torch.Tensor, "batch n_brackets"]:
    """Computes the conditional prob that the next token is each bracket

    Conditioned on the next token actually being a bracket
    """

    num_prompts = prompt_tokens.shape[0]

    all_logits = model(prompt_tokens, return_type="logits")  # batch pos d_vocab

    d_vocab = all_logits.shape[2]

    # Select the last token from each
    if prompt_token_lengths is None:
        logits = all_logits[:, prompt_tokens.shape[1] - 1, :]  # batch d_vocab
    else:
        indices = torch.tensor(prompt_token_lengths, device=device) - 1  # batch
        indices = indices.reshape((num_prompts, 1, 1))  # batch 1 1
        indices = indices.repeat((1, 1, d_vocab))  # batch 1 d_vocab
        logits = torch.gather(all_logits, 1, indices).squeeze()  # batch d_vocab

    probs = F.softmax(logits, dim=1)  # batch d_vocab

    # Compute the probability for each bracket and spaced bracket, conditioned
    # on the fact that it is a bracket
    cond_probs = probs[:, bracket_tokens_flat]  # batch (2 n_bracket_tokens)
    cond_probs = F.normalize(cond_probs, p=1.0, dim=1)

    # Combine the conditional probabilities for each bracket
    cond_probs_combined = torch.zeros((num_prompts, num_brackets))
    index = 0
    for i, size in enumerate(bracket_tokens_sizes):
        cond_probs_combined[:, i] = cond_probs[:, index : index + size].sum(dim=1)
        index += size

    return cond_probs_combined


bracket_scores = compute_bracket_scores(exp_prompt_tokens, exp_prompt_token_lengths)


Let's display these in a nice chart. I break it up into two since there are a lot of prompts.

In [12]:
def display_bracket_scores(
    prompts: list[str],
    bracket_scores: Float[torch.Tensor, "batch n_brackets"],
    height_scale: int = 30,
):
    """Display the bracket scores nicely"""
    num_prompts = len(prompts)
    fig = px.imshow(
        utils.to_numpy(bracket_scores),
        color_continuous_scale="blues",
        labels=dict(x="Bracket", color="Conditional Probability"),
        x=brackets_flat,
        y=prompts,
        height=height_scale * num_prompts,
    )
    for ix, bracket in enumerate(brackets_flat):
        for iy in range(num_prompts):
            fig.add_annotation(
                x=ix,
                y=iy,
                text=bracket,
                showarrow=False,
                font_color="orange",
            )
    fig.show()


In [13]:
display_bracket_scores(exploratory_prompts[:num_exploratory_prompts // 2], bracket_scores[:num_exploratory_prompts // 2, :])

In [14]:
display_bracket_scores(exploratory_prompts[num_exploratory_prompts // 2:], bracket_scores[num_exploratory_prompts // 2:, :])

## Discussion

- The model seems to do very well at predicting the correct closing bracket, and is robust to most things I've thrown at it.
- One important case which occurs a few times though is when I'm looking to have the model predict `)` but it actually predicts `[`, with `)` being the second-most likely next bracket. I investigate this a little more below.
- The fact that the model predicts `(` on `list())))).append(x` indicates that it is not confused by lots of closing brackets.
- The model struggles a bit on the last prompts made purely of brackets and commas.
- When the brackets are balanced, the model outputs vary a lot. Usually it predicts an opening bracket, though often spreading the probability over several types. Other times the probability is spread over both opening and closing brackets. And sometimes it predicts a closing bracket. 
    * It's not entirely clear what the model *should* predict in these cases. Oftentimes any kind of bracket would be inappropriate.

Let's think about the case where the model predicts `[` instead of `)`. I wouldn't say this is always wrong: it would be plausible to see a `[` in most of these cases, as a way of indexing some object, though usually this would be a bit weird.

This observation raises a larger question about what exactly the task *is* and what the metric should be. The intuition is that a good model should be able to keep track of the open and closed brackets, and should prefer generating text which is *eventually* bracket-balanced. However, in the shorter term this may involve opening new brackets (after all, we wouldn't want to the model to be biased towards immediately closing all brackets it creates). I can think of the following ways of approaching this.
1. The most direct way is to simply let the model continue to generate tokens, with the aim of seeing if the whole generated text is bracket-balanced and from there trying to understand how the model has done this. This would be a substantial undertaking, and beyond the scope of this small exploration.
2. Another option is to focus only on the bracket type we care about. In this case, we'd only compare the prediction for `(` with `)`, and ignore the comparison with `[` and `]`. Of course there may still be instances where opening with `(` is a reasonable choice for the model to make so this doesn't completely eliminate the problem.
3. The simplest way is to focus on clear-cut examples, where the only reasonable bracket is a closing one. I will go with this direction here, since it isolates more cleanly exactly what we want to investigate, which hopefully also makes the model behaviour more evident.

# Hypotheses

How is GPT-Neo able to determine whether to close a bracket? Before I get my hands dirty with the model weights, I'm going briefly elaborate my thoughts for what might be going on.

- A basic component for this capability might be a 'bracket-counting' head. In this head certain tokens (perhaps those where some kind of bracket is likely as the next token, or just all tokens) attend to all the previous brackets. The opening bracket value vectors lie in an opposite direction to the closing bracket value vectors. This way, when we take the weighted sum, its projection onto the line determined by these opposing directions counts the value:
```
    (Number of opening brackets) - (Number of closing brackets)
```
- A simple way the transformer could use a bracket-counting head is by predicting an closing bracket if this number is positive and it is likely that the next token is some kind of bracket.
- Such a simple head doesn't explain:
    1. Why the model doesn't get confused by `list())))).append(x` (note that at `x` this count will be negative).
    2. How the model can determine *which* bracket is appropriate.
- Intuitively, in order to the deal with the first problem, the model needs some way of 'resetting' the count when it encounters the second `(`.
- Here is one way this could be accomplished. There is a second bracket-counting head on a later layer, which works the same way except for the following modification. Any opening bracket which has a negative count from the first bracket-counting head gets the value vector which is the normal opening-bracket vector multiplied by the negative of the bracket count, plus one. This means that opening brackets appearing after a negatively balanced string reset the count, and counting can proceed as normal.
- I can't think of a way to simplify this to a single head. Intuitively, the head which determines the final count already needs to have access to the bracket count computation, in order to determine when to reset. Perhaps there's a way to do it which doesn't involve counting.
- To deal with the second token, the model needs some way of keeping track of the type of the most recent unclosed bracket. I haven't thought of a way this could work.

# Experimental setup

Here I define the prompts which I will be testing, and the metric used to quantify model performance on them. 

## Reference prompts

I choose the prompts according to following criteria.
1. They should have the same number of tokens.
2. The next token, if it is a bracket, should be clearly a closing one.
3. Each should have a corrupted version, which has the same number of tokens, differs only slightly, but after which the model predicts something different (ideally an opening bracket).
4. There should be a variety of kinds of prompt.

The motivations for these are as follows.
1. This makes working with the next predicted token across all prompts simultaneously easier.
2. Cases where there are more opening than closing brackets are more clear-cut.
3. Later I would like to use activation patching as an interpretability tool. This requires a corrupted version.
4. We want to find a mechanism by which the model robustly accomplishes the task, rather than one which might be specific to one kind of prompt.

In [15]:
# The regular prompts and their answers
prompts = [
    "def line(data_tensor, renderer='four'",
    "model.dataset.data = table['responder'",
    "{b(), c(), d(), e(), f(x)",
    "<template id='named_carriage' name='time'",
    "sieve(7777, [2], 'recurse'",
    "[factor(x_new), test(), inspect(p)",
    "specification = {'<xml>': more(True)",
    "<html style='b.blue {color: blue}'",
]
answers_openness = [0] * len(prompts) # 1 if opening bracket
answer_symbols = list(")]}>)]}>")

# The corrupted prompts and their answers
# Note: there aren't clear answers to what the exact symbol should be
corrupted_prompts = [
    "def line(data_tensor, renderer='four')",
    "model.dataset.data = table['responder']",
    "{b(), c(), d(), e(), f(x)}",
    "<template id='named_carriage' name='time'>",
    "sieve(7777, [2], 'recurse')",
    "[factor(x_new), test(), inspect(p)]",
    "specification = {'<xml>': more(True)}",
    "<html style='b.blue {color: blue}'>",
]
corrupted_answers_openness = [1] * len(corrupted_prompts) # 1 if opening bracket

# Combine the non-corrupted and corrupted
all_prompts = prompts + corrupted_prompts
all_answers_openness = answers_openness + corrupted_answers_openness

print ("Prompts")
for prompt in prompts:
    prompt_as_tokens = model.to_str_tokens(prompt)
    print(len(prompt_as_tokens), "#".join(prompt_as_tokens))

print()
print ("Corrupted prompts")
for prompt in corrupted_prompts:
    prompt_as_tokens = model.to_str_tokens(prompt)
    print(len(prompt_as_tokens), "#".join(prompt_as_tokens))

Prompts
14 <|endoftext|>#def# line#(#data#_#t#ensor#,# rend#erer#='#four#'
14 <|endoftext|>#model#.#dat#as#et#.#data# =# table#['#respond#er#'
14 <|endoftext|>#{#b#(),# c#(),# d#(),# e#(),# f#(#x#)
14 <|endoftext|>#<#template# id#='#named#_#car#riage#'# name#='#time#'
14 <|endoftext|>#s#ieve#(#77#77#,# [#2#],# '#rec#urse#'
14 <|endoftext|>#[#factor#(#x#_#new#),# test#(),# inspect#(#p#)
14 <|endoftext|>#spec#ification# =# {#'#<#xml#>#':# more#(#True#)
14 <|endoftext|>#<#html# style#='#b#.#blue# {#color#:# blue#}#'

Corrupted prompts
14 <|endoftext|>#def# line#(#data#_#t#ensor#,# rend#erer#='#four#')
14 <|endoftext|>#model#.#dat#as#et#.#data# =# table#['#respond#er#']
14 <|endoftext|>#{#b#(),# c#(),# d#(),# e#(),# f#(#x#)}
14 <|endoftext|>#<#template# id#='#named#_#car#riage#'# name#='#time#'>
14 <|endoftext|>#s#ieve#(#77#77#,# [#2#],# '#rec#urse#')
14 <|endoftext|>#[#factor#(#x#_#new#),# test#(),# inspect#(#p#)]
14 <|endoftext|>#spec#ification# =# {#'#<#xml#>#':# more#(#True#)}
14 <|end

In [16]:
# Convert all the prompts to tokens, padding to make them the same length
all_prompt_tokens = model.to_tokens(all_prompts).to(device)
prompt_tokens = all_prompt_tokens[:len(prompts), :]
corrupted_prompt_tokens = all_prompt_tokens[len(prompts):, :]

print("prompt_tokens:", prompt_tokens.shape)
print("corrupted_prompt_tokens:", corrupted_prompt_tokens.shape)
print("all_prompt_tokens:", all_prompt_tokens.shape)

prompt_tokens: torch.Size([8, 14])
corrupted_prompt_tokens: torch.Size([8, 14])
all_prompt_tokens: torch.Size([16, 14])


Let's visualise the model performance on the prompts

In [17]:
bracket_scores = compute_bracket_scores(all_prompt_tokens)
display_bracket_scores(all_prompts, bracket_scores, height_scale=40)

## Metrics

I now define the metrics used to evaluate the model performance. The first metric measures the success at predicting openness or closedness of the bracket. The second measures in addition how well the model predicts the actual token. Both use average logit difference.

For the first metric, we take the sum of the logits for open bracket tokens and find the difference with the sum of the logits for closed bracket tokens. When we expect the answer to be an open bracket, the metric is the first of these quantities take away the second. When we expect a closed bracket, it's the other way around.

In [18]:
def openness_metric(
    logits: torch.Tensor, answers_openness: list, per_prompt=False
) -> torch.Tensor:
    """Computes the average difference between the open and closed logits"""

    # Turn the answer openness into a sign tensor
    answers_openness = torch.tensor(answers_openness, device=device)
    answers_openness_sign = torch.sign(2 * answers_openness - 1)

    # Select the final open and closed bracket logits
    open_bracket_logits = logits[:, -1, open_bracket_tokens]
    closed_bracket_logits = logits[:, -1, closed_bracket_tokens]

    # Sum up the logits for open and closed brackets
    open_bracket_logits_sum = open_bracket_logits.sum(dim=-1)
    closed_bracket_logits_sum = closed_bracket_logits.sum(dim=-1)

    # Compute the difference signed by the answer openness
    bracket_logit_diff = open_bracket_logits_sum - closed_bracket_logits_sum
    bracket_logit_diff = bracket_logit_diff * answers_openness_sign

    if per_prompt:
        return bracket_logit_diff
    else:
        return bracket_logit_diff.mean()

Let's test on the reference prompts.

In [19]:
logits = model(prompt_tokens, return_type="logits")
openness_metrics = openness_metric(logits, answers_openness, per_prompt=True)

print(openness_metrics)

tensor([ 101.9095,  332.6285,    6.9961,  223.4639,  176.2359,  295.7979,
        -162.3482,  132.8027], device='cuda:0')


There's a fair bit of variance in the metric for the reference prompts, even though the conditional probabilities all clearly favour one option. This is because:
1. We're looking at logits not probabilities (i.e. they are not 'normalised' by the softmax).
2. Earlier we consider the conditional probability, which has to sum to one over all brackets. It could be that the model predicts a non-bracket higher than any bracket.

Next I define the metric for how well the model predicts the correct bracket. There are several ways of doing this. Here I take the sum of the logits corresponding to the correct bracket, and take away the mean of the sum of the logits for the rest of the brackets.

The motivation for this is as follows. We want the metric to be linear in the logits, because this makes later analysis easier. During training the optimiser tries to minimise the cross entropy loss of the softmax of the logits. If $\{x_i\}$ is the set of all logits, and $x_{\text{true}}$ is the logit for the true next token, this corresponds to maximising:
$$
    x_{\text{true}} - \log \left(\sum_i \exp(x_i) \right)
$$
If we want to focus on just getting the correct bracket, we can see this as maximising:
$$
    x_{\text{true}} - \log \left(\sum_{i \in B} \exp(x_i) \right)
$$
where $B$ is the set of logits corresponding to brackets.

How do approximate this with a linear function? In general, logsumexp is not very linear, but approximating it with the mean seems ok for the purposes of making a metric.

Now the above is a bit of a simplification, since there isn't a 'true' next bracket token, because there are many tokens whose string representation starts with the same token. So instead we combine all logits whose tokens begin with the same bracket, and thus arrive at our metric.

In [20]:
def bracket_symbol_metric(
    logits: torch.Tensor, answer_symbols: list, per_prompt=False
) -> torch.Tensor:
    """Compute the difference from the answer bracket logit to all others"""

    batch_size = logits.shape[0]

    # Sum the logits corresponding to each bracket
    sum_per_bracket = torch.zeros((batch_size, num_brackets))
    for i, tokens in enumerate(bracket_tokens.values()):
        sum_per_bracket[:, i] = logits[:, -1, tokens].sum(dim=-1)

    # Turn the answers_symbol list into a tensor for indexing `sum_per_bracket`
    answer_symbol_indices = [brackets_flat.index(bracket) for bracket in answer_symbols]
    answer_symbol_indices = torch.tensor(answer_symbol_indices)
    answer_symbol_indices = answer_symbol_indices.reshape((batch_size, 1))

    # Compute the logits difference from the answer to the sum of the other
    # brackets
    answer_logits = sum_per_bracket.gather(dim=-1, index=answer_symbol_indices)
    answer_logits = answer_logits.squeeze()
    logit_diff = 2 * answer_logits - sum_per_bracket.mean(dim=-1)

    if per_prompt:
        return logit_diff
    else:
        return logit_diff.mean()


Let's check this metric with reference prompts.

In [21]:
bracket_symbol_metric(logits, answer_symbols, per_prompt=True)

tensor([585.3566, 648.4564, 364.9954, 245.7734, 733.8599, 472.1603, 374.2223,
        320.5742])

Let's compare with some incorrect answers, to make sure the metric is doing what we want.

In [22]:
bracket_symbol_metric(logits, list("<[>){)[]"), per_prompt=True)

tensor([-36.8613,  89.3936, 149.0069, 199.3391, -48.8833, 345.0772, 142.3398,
         77.5326])

# Direct logit attribution

In this section I investigate the model using the 'direct logit attribution' method, which looks at how different parts of the model directly affect the output logits.

Much of this section is copied directly from the [Exploratory Analysis notebook](https://neelnanda.io/exploratory-analysis-demo).

As a first pass, I will make the following simplifications.
1. I will focus solely on the task determining if the next bracket should be opening or closing.
2. Rather than comparing all tokens beginning with a bracket across all bracket types, I will fix a bracket type per prompt and compare only the tokens corresponding to the opening and closing versions. This is to be able to talk about residual directions, looking at the logit difference between the two possible tokens.

Let's first add some wrong answers then tokenise everything.

In [23]:
answer_wrong_symbols = list("([{<([{<")

# Tokenise everything
answer_tokens = [model.to_single_token(b) for b in answer_symbols]
answer_wrong_tokens = [model.to_single_token(b) for b in answer_wrong_symbols]
answer_tokens = torch.tensor(answer_tokens)
answer_wrong_tokens = torch.tensor(answer_wrong_tokens)
answer_both_tokens = torch.stack((answer_tokens, answer_wrong_tokens)).T

print(answer_both_tokens)

tensor([[ 8,  7],
        [60, 58],
        [92, 90],
        [29, 27],
        [ 8,  7],
        [60, 58],
        [92, 90],
        [29, 27]])


Now let's run the model and cache the intermediate activations.

In [24]:
original_logits, cache = model.run_with_cache(prompt_tokens)

We now compute the directions in the residual stream corresponding to moving from the wrong answer to the right one.

In [25]:
answer_residual_directions = model.tokens_to_residual_directions(answer_both_tokens)
print("Answer residual directions shape:", answer_residual_directions.shape)
logit_diff_directions = (
    answer_residual_directions[:, 0] - answer_residual_directions[:, 1]
)
print("Logit difference directions shape:", logit_diff_directions.shape)

Answer residual directions shape: torch.Size([8, 2, 768])
Logit difference directions shape: torch.Size([8, 768])


Let's test to see if this works.

In [26]:
# cache syntax - resid_post is the residual stream at the end of the layer, -1 gets the final layer. The general syntax is [activation_name, layer_index, sub_layer_type].
final_residual_stream = cache["resid_post", -1]
print("Final residual stream shape:", final_residual_stream.shape)
final_token_residual_stream = final_residual_stream[:, -1, :]
# Apply LayerNorm scaling
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream, layer=-1, pos_slice=-1
)

# Get the original logit difference
final_logits = original_logits[:, -1, :]
answer_logits = final_logits.gather(dim=-1, index=answer_both_tokens.to(device))
original_average_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
original_average_logit_diff = original_average_logit_diff.mean()

average_logit_diff = einsum(
    "batch d_model, batch d_model -> ",
    scaled_final_token_residual_stream,
    logit_diff_directions,
) / len(prompts)
print("Calculated average logit diff:", average_logit_diff.item())
print("Original logit difference:", original_average_logit_diff.item())


Final residual stream shape: torch.Size([8, 14, 768])
Calculated average logit diff: 3.6474530696868896
Original logit difference: 3.964486598968506


## Logit lens

Here we look at the residual stream after each layer calculate the logit difference from that. This gives an idea of when the model starts being able to do the task.

In [27]:
def residual_stack_to_logit_diff(
    residual_stack: Float[torch.Tensor, "components batch d_model"],
    cache: ActivationCache,
) -> float:
    scaled_residual_stack = cache.apply_ln_to_stack(
        residual_stack, layer=-1, pos_slice=-1
    )
    return einsum(
        "... batch d_model, batch d_model -> ...",
        scaled_residual_stack,
        logit_diff_directions,
    ) / len(prompts)


In [28]:
accumulated_residual, labels = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
)
print(accumulated_residual.shape)
logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, cache)
print(logit_lens_logit_diffs.shape)
line(
    logit_lens_logit_diffs,
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    hover_name=labels,
    title="Logit Difference From Accumulate Residual Stream",
)


torch.Size([25, 8, 768])
torch.Size([25])


Interestingly, for the first six layers the model actually gets worse at the task. In layer 5 the model performance jumps back up to around baseline performance, and stays there until layers 8 and 9 where it achieves best performance. Afterwards the performance decreases a little bit.

Here are my provisional thoughts on what might be happening.
- It could be that solving the task requires intermediate computation steps, and during these steps the model predicts the wrong token (at least on the prompts on which we're testing).
- Alternatively, the initial dip in performance might be unrelated to the task. Perhaps the model doesn't try to figure out bracket balance until the later layers. Earlier on it might be doing other things with the logit directions; in other words there's some superposition going on, and the different superposed features are computed at different stages of the model.
- The final decrease in performance might be because the sample of prompts is not representative enough. Perhaps in order to get the best performance across all bracket matching tasks (weighted by the data distribution), the optimiser decided to reduce performance on the current set of prompts in favour of others. In other words, while we might see decreasing performance on these prompts in the last layers, on others might still be low at layer 9 and continue increasing. 