# Understanding bracket closing in GPT-Neo

The goal of this notebook is to explore the phenomenon of bracket closing in the [GPT-Neo 125M model](https://www.eleuther.ai/artifacts/gpt-neo), whereby it can correctly match opening brackets `([{<` with their corresponding closing counterparts `)]}>`.

This is [Problem 2.13](https://www.alignmentforum.org/s/yivyHaCAmMJ3CqSyj/p/XNjRwEX9kxbpzWFWd#block71) in Neel Nanda's [200 Concrete Open Problems in Mechanistic Interpretability](https://www.alignmentforum.org/posts/LbrPTJ4fmABEdEnLf/200-concrete-open-problems-in-mechanistic-interpretability). The first goal is to figure out how the model determines whether an opening or closing bracket is more appropriate, and the second is to figure out how it knows the correct kind: `(`, `[`, `{` or `<`.
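
To make "correct" concrete, here is a minimal sketch of the ground truth the model's predictions could be compared against: a stack-based helper that returns the bracket needed to close the innermost unclosed opener. The name `expected_closer` is hypothetical and not part of this repository or of TransformerLens, and the helper assumes every `<` and `>` is a bracket rather than a comparison operator.

```python
from typing import Optional

# Map each opening bracket to its closing counterpart.
PAIRS = {"(": ")", "[": "]", "{": "}", "<": ">"}
CLOSERS = set(PAIRS.values())


def expected_closer(text: str) -> Optional[str]:
    """Return the bracket that closes the innermost unclosed opener in `text`,
    or None if every bracket is already closed.

    Hypothetical helper for illustration only. Closing brackets that do not
    match the innermost opener are ignored rather than treated as errors.
    """
    stack = []
    for ch in text:
        if ch in PAIRS:
            # Remember which closer this opener will eventually need.
            stack.append(PAIRS[ch])
        elif ch in CLOSERS and stack and stack[-1] == ch:
            # The innermost opener has been closed.
            stack.pop()
    return stack[-1] if stack else None


print(expected_closer("f(a[0], {1: 2"))  # }
print(expected_closer("f(a[0]"))         # )
print(expected_closer("(done)"))         # None
```

The two goals above map onto this helper: whether it returns `None` corresponds to the "opening or closing" question, and which character it returns corresponds to the "correct kind" question.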

I'm using the [TransformerLens library](https://github.com/neelnanda-io/TransformerLens), and a lot of this notebook is copied from Neel's [Exploratory Analysis notebook](https://neelnanda.io/exploratory-analysis-demo).

This notebook lives in my [mechanistic interpretability GitHub repository](https://github.com/SamAdamDay/mechanistic-interpretability-projects), which also contains some common utilities that I import below.

# Setup

In [1]:
#@title Config options

DEVELOPMENT_MODE = False #@param {type:"boolean"}

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
# DEVELOPMENT_MODE comes from the config cell above; it is not reset here
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    # pip needs the `git+` prefix to install directly from a Git repository
    %pip install git+https://github.com/SamAdamDay/mechanistic-interpretability-projects.git
except ImportError:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Automatically reload the HookedTransformer code as it's edited, without restarting the kernel
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
print(f"Using renderer: {pio.renderers.default}")

Import things

In [None]:
import random
from pathlib import Path
from typing import List, Union, Optional
from functools import partial
import copy
import itertools
import dataclasses

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

import numpy as np

import einops

from fancy_einsum import einsum

import tqdm.auto as tqdm

import plotly.express as px

from jaxtyping import Float, Int

from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import datasets

from IPython.display import HTML

import circuitsvis as cv

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

Turn automatic differentiation off

In [None]:
torch.set_grad_enabled(False)

Torch device

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

Plotting helpers

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)