# 1.5.1: Balanced Bracket Classifier

In [2]:
import json
import sys
import os
from functools import partial
from pathlib import Path

import circuitsvis as cv
import einops
import torch as t
from IPython.display import display
from jaxtyping import Bool, Float, Int
from sklearn.linear_model import LinearRegression
from torch import Tensor, nn
from tqdm import tqdm
from transformer_lens import ActivationCache, HookedTransformer, HookedTransformerConfig, utils
from transformer_lens.hook_points import HookPoint

# Locate the chapter's `exercises` directory relative to the current working
# directory, and put it on `sys.path` so the section's modules can be imported.
chapter = r"chapter1_transformer_interp"
exercises_dir = Path(f"{os.getcwd().partition(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir.joinpath("part51_balanced_bracket_classifier")
if str(exercises_dir) not in sys.path:
    sys.path.append(str(exercises_dir))

import plotly_utils
from plotly_utils import hist, bar, imshow
import part51_balanced_bracket_classifier.tests as tests
from part51_balanced_bracket_classifier.brackets_datasets import SimpleTokenizer, BracketsDataset

# True when this code is executing as the `__main__` module.
MAIN = (__name__ == "__main__")

# Device preference order: Apple MPS, then CUDA, then CPU.
if t.backends.mps.is_available():
    device = t.device('mps')
elif t.cuda.is_available():
    device = t.device('cuda')
else:
    device = t.device('cpu')
print("using device: ", device)

using device:  mps


# 1. Bracket classifier

## Loading and running the model

In [7]:
VOCAB = "()"

# Architecture hyperparameters for the classifier whose weights are loaded below.
cfg = HookedTransformerConfig(
    # model size
    n_layers=3,
    n_heads=2,
    d_model=56,
    d_head=28,
    d_mlp=56,
    n_ctx=42,
    act_fn="relu",
    # attention is not causal here (the library default is "causal")
    attention_dir="bidirectional",
    # vocabulary: the bracket characters plus the [start], [pad] and [end] tokens
    d_vocab=len(VOCAB) + 3,
    # two output classes, since this is binary classification
    d_vocab_out=2,
    # hook-related switches
    use_attn_result=True,
    use_hook_tokens=True,
    device=device,
)

model = HookedTransformer(cfg)
model = model.eval()

state_dict = t.load(
    section_dir / "brackets_model_state_dict.pt",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [19]:
# Build the tokenizer from the same alphabet that sizes `d_vocab` in the model
# config, so the two cannot drift apart.
tokenizer = SimpleTokenizer(VOCAB)

def pretty_print_dict(d):
    """Print `d` between braces, one `key : value,` entry per line, in sorted key order."""
    entries = [f"  {key} : {d[key]}," for key in sorted(d)]
    print('{', *entries, '}', sep="\n")

# Tokenization demo: a single string, then a batch of two strings.
# (The batch gets padded, because its sequences have different lengths.)
print(
    tokenizer.tokenize("()"),
    tokenizer.tokenize(["()", "()()"]),
    sep="\n",
)

# The index -> token mapping, followed by the token -> index mapping
pretty_print_dict(tokenizer.i_to_t)
pretty_print_dict(tokenizer.t_to_i)

# Decoding demo (padding tokens are dropped from the decoded string)
print(tokenizer.decode(t.tensor([0, 3, 4, 2, 1, 1]).unsqueeze(0)))

tensor([[0, 3, 4, 2]])
tensor([[0, 3, 4, 2, 1, 1],
        [0, 3, 4, 3, 4, 2]])
{
  0 : [start],
  1 : [pad],
  2 : [end],
  3 : (,
  4 : ),
}
{
  ( : 3,
  ) : 4,
  [end] : 2,
  [pad] : 1,
  [start] : 0,
}
['()']


In [20]:
def add_perma_hooks_to_mask_pad_tokens(model: HookedTransformer, pad_token: int) -> HookedTransformer:
    """
    Attach permanent hooks to `model` that stop attention from attending to padding.

    Two hook functions are registered:
      * on the hook named "hook_tokens": records a boolean mask of the positions
        whose token equals `pad_token`;
      * on every hook whose name ends with "attn_scores": overwrites the scores at
        the masked key positions with -1e5, in place.

    Args:
        model: the transformer to modify; hooks are added to this object directly.
        pad_token: integer id of the padding token.

    Returns:
        The same `model` object that was passed in.

    NOTE(review): the masking hook reads the mask from the "hook_tokens" hook's
    context, so this presumably requires the model to expose that hook
    (cf. `use_hook_tokens=True` in the config) - confirm against TransformerLens.
    """

    # Hook which operates on the tokens, and stores a mask where tokens equal [pad]
    # The mask is reshaped to (batch, 1, 1, seq_K) so that it broadcasts over the
    # head and query-position dimensions of the attention scores.
    def cache_padding_tokens_mask(tokens: Int[Tensor, "batch seq"], hook: HookPoint) -> None:
        hook.ctx["padding_tokens_mask"] = einops.rearrange(tokens == pad_token, "b sK -> b 1 1 sK")

    # Apply masking, by referencing the mask stored in the `hook_tokens` hook context
    def apply_padding_tokens_mask(
        attn_scores: Float[Tensor, "batch head seq_Q seq_K"],
        hook: HookPoint,
    ) -> None:
        # In-place fill: the hook returns None and relies on mutating `attn_scores`.
        attn_scores.masked_fill_(model.hook_dict["hook_tokens"].ctx["padding_tokens_mask"], -1e5)

        # delete cached padding mask if this is the last attention layer
        if hook.layer() == model.cfg.n_layers - 1:
            del model.hook_dict["hook_tokens"].ctx["padding_tokens_mask"]

    # Add these hooks as permanent hooks (i.e. they aren't removed after functions like run_with_hooks)
    for name, hook in model.hook_dict.items():
        if name == "hook_tokens":
            hook.add_perma_hook(cache_padding_tokens_mask)
        elif name.endswith("attn_scores"):
            hook.add_perma_hook(apply_padding_tokens_mask)

    return model


# Clear every existing hook (permanent ones included) before adding the padding
# hooks, so that re-running this cell does not register them a second time.
model.reset_hooks(including_permanent=True)
model = add_perma_hooks_to_mask_pad_tokens(model, tokenizer.PAD_TOKEN)

In [22]:
# Number of examples (taken from the front of the file) to keep for this notebook.
N_SAMPLES = 5000

# The JSON file holds a list of (bracket string, label) pairs; JSON decodes each
# pair as a two-element list.
with open(section_dir / "brackets_data.json", encoding="utf-8") as f:
    data_tuples: list[tuple[str, bool]] = json.load(f)
print(f"loaded {len(data_tuples)} examples")
assert isinstance(data_tuples, list)

# Keep only the first N_SAMPLES examples; `data_mini` is a 100-example subset of those.
data_tuples = data_tuples[:N_SAMPLES]
data = BracketsDataset(data_tuples).to(device)
data_mini = BracketsDataset(data_tuples[:100]).to(device)

loaded 100000 examples


In [23]:
# Histogram of the bracket-string lengths, using `data.seq_length` bins.
hist(
    [len(brackets) for brackets, _label in data_tuples],
    title="Sequence lengths of brackets in dataset",
    labels=dict(x="Seq len"),
    nbins=data.seq_length,
)

Note how many length-2 examples there are -- I guess it's useful to really drill in that `()` is good and `)(`, `((`, and `))` are potentially troublesome?

Also, all the examples are even-length.

Reason: every odd-length sequence is trivially unbalanced, and a single attention head can tell odd-length from even-length sequences, so including them would give the model a really easy way to get those right.

In [26]:
# Define and tokenize examples
examples = ["()()", "(())", "))((", "()", "((()()()()))", "(()()()(()(())()", "()(()(((())())()))"]
labels = [True, True, False, True, True, False, True]
toks = tokenizer.tokenize(examples)

# Get output logits for the 0th sequence position (i.e. the [start] token)
logits = model(toks)[:, 0]

# Get the probabilities via softmax, then get the balanced probability (which is the second element)
prob_balanced = logits.softmax(-1)[:, 1]

# Display output
print("Model confidence:\n" + "\n".join([f"{ex:18} : {prob:<8.4%} : label={int(label)}" for ex, prob, label in zip(examples, prob_balanced, labels)]))

Model confidence:
()()               : 99.9986% : label=1
(())               : 99.9989% : label=1
))((               : 0.0005%  : label=0
()                 : 99.9987% : label=1
((()()()()))       : 99.9987% : label=1
(()()()(()(())()   : 0.0006%  : label=0
()(()(((())())())) : 99.9982% : label=1


In [27]:
def run_model_on_data(model: HookedTransformer, data: BracketsDataset, batch_size: int = 200) -> Float[Tensor, "batch 2"]:
    """
    Run `model` over every example in `data`, in mini-batches, and return the logits.

    Args:
        model: the classifier; called as `model(toks)` on a batch of token ids.
        data: dataset providing the tokenized examples as `data.toks`.
        batch_size: number of examples per forward pass.

    Returns:
        The logits read at sequence position 0, one row per example and one
        column per class. These are raw logits, not probabilities: apply
        `.softmax(-1)` for probabilities or `.argmax(-1)` for predictions.
    """
    n_examples = len(data)
    all_logits = []
    for start in tqdm(range(0, n_examples, batch_size)):
        toks = data.toks[start : start + batch_size]
        # Only position 0 of the model output is used for classification.
        all_logits.append(model(toks)[:, 0])
    all_logits = t.cat(all_logits)
    assert all_logits.shape == (n_examples, 2)
    return all_logits


# Evaluate accuracy: an argmax of 1 means "balanced", compared with the stored labels.
test_set = data
n_correct = (run_model_on_data(model, test_set).argmax(-1).bool() == test_set.isbal).sum()
# Report against the size of the set that was actually evaluated.
print(f"\nModel got {n_correct} out of {len(test_set)} training examples correct!")

100%|██████████| 25/25 [00:01<00:00, 14.74it/s]



Model got 5000 out of 5000 training examples correct!


## Algorithmic Solutions

In [28]:
def is_balanced_forloop(parens: str) -> bool:
    """
    Return True if `parens` is a balanced bracket string.

    Tracks the nesting depth while scanning left to right: '(' increases it and
    any other character decreases it. The string is unbalanced as soon as a
    decrease is attempted at depth zero, or if the final depth is not zero.
    """
    depth = 0
    for ch in parens:
        if ch != '(':
            if depth <= 0:
                # Nothing is open, so this character cannot close anything.
                return False
            depth -= 1
        else:
            depth += 1
    return depth == 0

def is_balanced_forloop_cleaner(parens: str) -> bool:
    """
    From solutions: return True if `parens` is a balanced bracket string.

    Keeps a running sum that goes up by one for '(' and down by one for any
    other character. The string is balanced if the sum never goes negative and
    ends at exactly zero.
    """
    cumsum = 0
    for p in parens:
        # The loop variable is `p`; the original compared an undefined name here.
        cumsum += 1 if p == '(' else -1
        if cumsum < 0:
            return False
    return cumsum == 0

# Check the for-loop implementation against the hand-labelled examples
for parens, expected in zip(examples, labels):
    actual = is_balanced_forloop(parens)
    assert actual == expected, f"{parens}: expected {expected} got {actual}"
print("is_balanced_forloop ok!")

is_balanced_forloop ok!


In [51]:
# Scratch check on one padded sequence: [start] ( ) ( ) [end] followed by [pad]s.
tokens = t.tensor([0, 3, 4, 3, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
opens = tokens.eq(3).int()
closes = tokens.eq(4).int()
# Running (opens - closes) count, read off at the position of the [end] token
(opens - closes).cumsum(dim=0)[tokens.eq(2)]

tensor([0])

In [54]:
def is_balanced_vectorized(tokens: Int[Tensor, "seq_len"], open_token: int = 3, close_token: int = 4) -> bool:
    """
    Return True if the brackets encoded in `tokens` are balanced.

    Args:
        tokens: 1-D tensor of token ids. Besides the bracket ids it may contain
            the start/pad/end ids (0/1/2); any id that is neither `open_token`
            nor `close_token` is ignored.
        open_token: id of the left bracket (default 3).
        close_token: id of the right bracket (default 4).

    Returns:
        True when no prefix has more closes than opens and the total number of
        opens equals the total number of closes. An empty tensor is balanced.
    """
    opens = (tokens == open_token).int()
    closes = (tokens == close_token).int()
    delta = opens - closes

    # Any negative prefix sum means a bracket was closed before it was opened.
    if bool((delta.cumsum(0) < 0).any()):
        return False

    # Summing (rather than indexing the last cumsum entry) also handles an
    # empty tensor, and special tokens contribute zero either way.
    return int(delta.sum()) == 0

# Check the vectorized implementation on the tokenized hand-labelled examples
for tokens, expected in zip(tokenizer.tokenize(examples), labels):
    actual = is_balanced_vectorized(tokens)
    assert actual == expected, f"{tokens}: expected {expected} got {actual}"
print("is_balanced_vectorized ok!")

is_balanced_vectorized ok!
