In [1]:
import functools
import json
import os
from typing import Any, List, Tuple, Union
import matplotlib.pyplot as plt
import torch
import torch as t
import torch.nn.functional as F
from fancy_einsum import einsum
from sklearn.linear_model import LinearRegression
from torch import nn
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from einops import rearrange, repeat
import pandas as pd
import numpy as np

import w5d5_tests
from w5d5_transformer import ParenTransformer, SimpleTokenizer

# True when this file runs as the top-level program (script or notebook kernel),
# False when it is imported as a module; guards all demo/check cells below.
MAIN = (__name__ == "__main__")
# Device used for the model and all data tensors in this notebook.
DEVICE = t.device("cpu")


In [2]:
if MAIN:
    # Rebuild the architecture the checkpoint was trained with.
    model = ParenTransformer(ntoken=5, nclasses=2, d_model=56, nhead=2, d_hid=56, nlayers=3)
    # map_location lets a checkpoint saved on another device (e.g. CUDA) load onto DEVICE.
    # NOTE: torch.load unpickles its input - only load trusted checkpoint files.
    state_dict = t.load("w5d5_balanced_brackets_state_dict.pt", map_location=DEVICE)
    model.load_simple_transformer_state_dict(state_dict)
    # Move after loading so the model is guaranteed to be on DEVICE.
    model.to(DEVICE)
    model.eval()
    tokenizer = SimpleTokenizer("()")
    # Each entry is a (bracket string, is_balanced label) pair.
    with open("w5d5_brackets_data.json") as f:
        data_tuples: List[Tuple[str, bool]] = json.load(f)
    print(f"loaded {len(data_tuples)} examples")
    assert isinstance(data_tuples, list)

class DataSet:
    '''A dataset containing sequences, is_balanced labels, and tokenized sequences.

    Tokenization uses the module-level ``tokenizer`` and tensors are moved to the
    module-level ``DEVICE``; both must be defined before constructing a DataSet.
    '''

    def __init__(self, data_tuples: list):
        '''
        data_tuples is List[Tuple[str, bool]] signifying sequence and label.

        Sets:
            strs: the raw bracket strings.
            isbal: bool tensor of labels, on DEVICE.
            toks: ``tokenizer.tokenize(strs)`` moved to DEVICE.
            open_proportion: fraction of characters equal to "(" (0.0 for an empty string).
            starts_open: bool tensor, True where the string begins with "(".
        '''
        self.strs = [x[0] for x in data_tuples]
        self.isbal = t.tensor([x[1] for x in data_tuples]).to(device=DEVICE, dtype=t.bool)
        self.toks = tokenizer.tokenize(self.strs).to(DEVICE)
        # Empty strings would otherwise raise ZeroDivisionError here.
        self.open_proportion = t.tensor([s.count("(") / len(s) if s else 0.0 for s in self.strs])
        # s[:1] is "" for an empty string instead of raising IndexError like s[0].
        self.starts_open = t.tensor([s[:1] == "(" for s in self.strs], dtype=t.bool)

    def __len__(self) -> int:
        return len(self.strs)

    def __getitem__(self, idx) -> Union["DataSet", tuple[str, t.Tensor, t.Tensor]]:
        '''A slice returns a new DataSet; any other index returns (str, isbal, toks).

        The sliced DataSet re-tokenizes its strings, so its padded width may differ
        from this dataset's (depending on how the tokenizer pads).
        '''
        if isinstance(idx, slice):
            # .tolist() hands plain Python bools to the constructor instead of 0-d tensors.
            return self.__class__(list(zip(self.strs[idx], self.isbal[idx].tolist())))
        return (self.strs[idx], self.isbal[idx], self.toks[idx])

    @property
    def seq_length(self) -> int:
        '''Width of the tokenized sequences (last dimension of toks).'''
        return self.toks.size(-1)

    @classmethod
    def with_length(cls, data_tuples: list[tuple[str, bool]], selected_len: int) -> "DataSet":
        '''Build a DataSet from only the examples whose string has length selected_len.'''
        return cls([(s, b) for (s, b) in data_tuples if len(s) == selected_len])

    @classmethod
    def with_start_char(cls, data_tuples: list[tuple[str, bool]], start_char: str) -> "DataSet":
        '''Build a DataSet from only the examples whose first character is start_char.'''
        # s[:1] keeps empty strings from raising IndexError.
        return cls([(s, b) for (s, b) in data_tuples if s[:1] == start_char])

if MAIN:
    # Restrict the analysis below to the first N_SAMPLES examples.
    N_SAMPLES = 5000
    data_tuples = data_tuples[:N_SAMPLES]
    data = DataSet(data_tuples)

loaded 100000 examples


In [3]:
def is_balanced_forloop(parens: str) -> bool:
    '''Return True if the bracket string is balanced.

    `parens` contains only "(" and ")" characters, with no begin or end tokens.
    Walks the string keeping a running depth: "(" adds one and any other
    character subtracts one. The string is unbalanced as soon as the depth
    goes negative, and balanced only if the depth ends at exactly zero.
    '''
    depth = 0
    for char in parens:
        depth += 1 if char == "(" else -1
        if depth < 0:
            return False
    return depth == 0

if MAIN:
    examples = [
        "()",
        "))()()()()())()(())(()))(()(()(()(",
        "((()()()()))",
        "(()()()(()(())()",
        "()(()(((())())()))",
    ]
    labels = [True, False, True, False, True]
    # Compare the loop implementation against hand-labelled answers.
    for parens, expected in zip(examples, labels):
        actual = is_balanced_forloop(parens)
        assert expected == actual, f"{parens}: expected {expected} got {actual}"
    print("is_balanced_forloop ok!")

is_balanced_forloop ok!


In [4]:
def is_balanced_vectorized(tokens: t.Tensor) -> bool:
    '''
    Check, without a Python loop over characters, whether one tokenized sequence is balanced.

    tokens: 1-D sequence of tokens including begin, end and pad tokens - recall that
    3 is '(' and 4 is ')', and 2 marks the end. Position 0 (the begin token) is skipped.

    Returns a Python bool. If no end token is present, every token after position 0
    is treated as part of the sequence.
    '''
    start = 1
    end_positions = t.nonzero(tokens == 2, as_tuple=True)[0]
    # Without this guard, argmax over an all-False mask returns 0, the body becomes
    # empty, and any sequence lacking an end token would be reported as balanced.
    end = int(end_positions[0]) if end_positions.numel() > 0 else tokens.shape[0]
    body = tokens[start:end]
    # +1 for '(' and -1 for ')'; any other token contributes 0.
    steps = (body == 3).to(t.int64) - (body == 4).to(t.int64)
    elevation = t.cumsum(steps, dim=0)
    never_negative = not bool((elevation < 0).any())
    return never_negative and int(steps.sum()) == 0



if MAIN:
    tokenized_examples = tokenizer.tokenize(examples)
    # Same hand-labelled examples as the for-loop check, now in token form.
    for tokens, expected in zip(tokenized_examples, labels):
        actual = is_balanced_vectorized(tokens)
        assert expected == actual, f"{tokens}: expected {expected} got {actual}"
    print("is_balanced_vectorized ok!")

is_balanced_vectorized ok!


In [5]:
if MAIN:
    toks = tokenizer.tokenize(examples).to(DEVICE)
    out = model(toks)
    # The model outputs log-probabilities; column 1 is the "balanced" class.
    prob_balanced = out.exp()[:, 1]
    report_lines = [f"{ex:34} : {prob:.4%}" for ex, prob in zip(examples, prob_balanced)]
    print("Model confidence:\n" + "\n".join(report_lines))

def run_model_on_data(model: ParenTransformer, data: DataSet, batch_size: int = 200) -> t.Tensor:
    '''Return probability that each example is balanced.

    Feeds data.toks through the model in chunks of batch_size with gradient
    tracking disabled, then exponentiates the concatenated model outputs.
    Result shape: (len(data), 2).
    '''
    n_examples = len(data.strs)
    log_prob_chunks = []
    with t.no_grad():
        for lo in range(0, n_examples, batch_size):
            log_prob_chunks.append(model(data.toks[lo : lo + batch_size]))
    probs = t.cat(log_prob_chunks).exp()
    assert probs.shape == (len(data), 2)
    return probs

if MAIN:
    test_set = data
    # Predicted class is the argmax over (unbalanced, balanced) probabilities.
    predicted_labels = run_model_on_data(model, test_set).argmax(-1)
    n_correct = (predicted_labels == test_set.isbal).int().sum()
    print(f"\nModel got {n_correct} out of {len(data)} training examples correct!")

Model confidence:
()                                 : 99.9987%
))()()()()())()(())(()))(()(()(()( : 0.0003%
((()()()()))                       : 99.9987%
(()()()(()(())()                   : 0.0006%
()(()(((())())()))                 : 99.9982%

Model got 5000 out of 5000 training examples correct!


In [6]:
def get_post_final_ln_dir(model: ParenTransformer) -> t.Tensor:
    """
    Direction in the decoder's input space (the final LayerNorm's output) that
    corresponds to an 'unbalanced' classification.

    Row 0 of model.decoder.weight produces the 'unbalanced' logit and row 1 the
    'balanced' logit, so the dot product of the LN output with their difference
    equals logit(unbalanced) - logit(balanced), up to the bias terms.
    """
    unbalanced_row, balanced_row = model.decoder.weight[0], model.decoder.weight[1]
    return unbalanced_row - balanced_row

In [7]:
# Define a hook function fn, which takes three inputs: module, input, and output. module is the module that the hook is attached to, input is a tuple of the inputs to the module, 
# and output is the output of the module. The hook function should return None.
# Call module.register_forward_hook(fn) to attach the hook to the module. This will return a handle that you can use to remove the hook later.
# Call the run_model_on_data function to run the model on some data. The hook will be called on every forward pass of the model.
# Call handle.remove() to remove the hook.

def _capture_from_module(model: ParenTransformer, data: DataSet, module: nn.Module, select) -> t.Tensor:
    '''
    Run the model on data and concatenate (along dim 0) the tensor returned by
    select(hook_input, hook_output) on every forward call of module.
    '''
    captured = []

    def hook(_module, hook_input, hook_output):
        # Detach and clone so the stored tensor is independent of later computation.
        captured.append(select(hook_input, hook_output).detach().clone())

    handle = module.register_forward_hook(hook)
    try:
        run_model_on_data(model, data)
    finally:
        # Remove the hook even if the forward pass raises; otherwise it stays
        # attached and keeps capturing on every later call of the model.
        handle.remove()
    return t.cat(captured, dim=0)


def get_inputs(model: ParenTransformer, data: DataSet, module: nn.Module) -> t.Tensor:
    '''
    Get the inputs to a particular submodule of the model when run on the data.
    Returns a tensor of size (data_pts, seq_pos, emb_size).
    '''
    # Forward hooks receive the positional inputs as a tuple; keep the first one.
    return _capture_from_module(model, data, module, lambda hook_input, hook_output: hook_input[0])


def get_outputs(model: ParenTransformer, data: DataSet, module: nn.Module) -> t.Tensor:
    '''
    Get the outputs from a particular submodule of the model when run on the data.
    Returns a tensor of size (data_pts, seq_pos, emb_size).
    '''
    return _capture_from_module(model, data, module, lambda hook_input, hook_output: hook_output)

if MAIN:
    # Run the provided checks for both hook-based capture helpers.
    for check, candidate in (
        (w5d5_tests.test_get_inputs, get_inputs),
        (w5d5_tests.test_get_outputs, get_outputs),
    ):
        check(candidate, model, data)

All tests in `test_get_inputs` passed.
All tests in `test_get_outputs` passed.


In [8]:
# Now, use these functions and the sklearn LinearRegression class to find a linear fit to the inputs and outputs of model.norm.
# The function below takes seq_pos as an argument. If it is an integer, we fit only for that sequence position.
# If seq_pos is None, we fit across all sequence positions (the batch and sequence dimensions are flattened together before the regression).
# You should include an intercept in your linear regression (fit_intercept=True, which is the default for LinearRegression).

def get_ln_fit(
    model: ParenTransformer, data: DataSet, ln_module: nn.LayerNorm, seq_pos: Union[None, int]
) -> Tuple[LinearRegression, t.Tensor]:
    '''
    Fit a linear map (with intercept) from the inputs of ln_module to its outputs.

    If seq_pos is None, find the best fit across all sequence positions (the batch and
    sequence dimensions are flattened together). Otherwise, fit only for the given seq_pos.

    Returns: A tuple of a (fitted) sklearn LinearRegression object and a dimensionless
    tensor containing the r^2 of the fit.
    '''
    ln_inputs = get_inputs(model, data, ln_module)
    ln_outputs = get_outputs(model, data, ln_module)
    if seq_pos is None:
        # (batch, seq, emb) -> (batch * seq, emb); reshape also handles non-contiguous tensors.
        ln_inputs = ln_inputs.reshape(-1, ln_inputs.shape[-1])
        ln_outputs = ln_outputs.reshape(-1, ln_outputs.shape[-1])
    else:
        ln_inputs = ln_inputs[:, seq_pos, :]
        ln_outputs = ln_outputs[:, seq_pos, :]
    # sklearn needs host numpy arrays; .cpu() keeps this working when DEVICE is a GPU.
    X = ln_inputs.detach().cpu().numpy()
    Y = ln_outputs.detach().cpu().numpy()
    fit = LinearRegression().fit(X, Y)
    r2 = fit.score(X, Y)
    return fit, t.tensor(r2)




if MAIN:
    # Fit the final LayerNorm at position 0 and report how linear it is there.
    final_ln_fit, r2 = get_ln_fit(model, data, model.norm, seq_pos=0)
    print("r^2: ", r2)
    w5d5_tests.test_final_ln_fit(model, data, get_ln_fit)

r^2:  tensor(0.9820, dtype=torch.float64)
All tests in `test_final_ln_fit` passed.


In [9]:
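# Question: why is seq_pos=0?
# Answer: the model's balanced/unbalanced classification is read from the residual stream at position 0 (the begin token), so only the final LayerNorm's behaviour at that position matters for the unbalanced direction.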

def get_pre_final_ln_dir(model: ParenTransformer, data: DataSet) -> t.Tensor:
    '''
    Direction in the residual stream *before* model.norm (at sequence position 0)
    corresponding to an 'unbalanced' classification.

    model.norm at position 0 is approximated by a linear fit ln(x) ~= W x + b
    (via get_ln_fit). With d = get_post_final_ln_dir(model), the unbalanced-vs-balanced
    logit difference is d . (W x + b) = (W^T d) . x + const, so W^T d is returned.
    '''
    post_final_ln_dir = get_post_final_ln_dir(model)
    final_ln_fit, _ = get_ln_fit(model, data, model.norm, seq_pos=0)
    # coef_ has shape (n_outputs, n_inputs), i.e. W itself. Match the decoder direction's
    # dtype/device: sklearn may return float64, and torch matmul rejects mixed dtypes.
    W = t.tensor(final_ln_fit.coef_, dtype=post_final_ln_dir.dtype, device=post_final_ln_dir.device)
    return W.T @ post_final_ln_dir

if MAIN:
    # Compare get_pre_final_ln_dir against the reference implementation in w5d5_tests.
    w5d5_tests.test_pre_final_ln_dir(model, data, get_pre_final_ln_dir)

All tests in `test_pre_final_ln_dir` passed.


In [10]:
# These functions can be used to capture the output of an MLP or an attention layer. 
# However, we also want to be able to get the output of an individual attention head.

# Write a function that returns the output by head instead, for a given layer. You'll need the hook functions you wrote earlier.

# Each of the linear layers in the attention layers have bias terms. 
# For getting the output by head, we will ignore the bias of each layer's self_attn.W_O, since it is not cleanly attributable to any individual head.

# Reminder: PyTorch stores weights for linear layers in the shape (out_features, in_features).

def get_out_by_head(model: ParenTransformer, data: DataSet, layer: int) -> t.Tensor:
    '''
    Get the output of the heads in a particular layer when the model is run on the data.
    Returns a tensor of shape (batch, num_heads, seq, embed_width)

    Each head's contribution is its slice of the input to W_O multiplied by the matching
    columns of W_O.weight. W_O's bias is excluded because no single head owns it.
    '''
    self_attn = model.layers[layer].self_attn
    n_heads = self_attn.num_heads
    # Input to W_O is the per-head results concatenated along the last dimension.
    concat_heads = get_inputs(model, data, self_attn.W_O)
    per_head = rearrange(concat_heads, 'b s (h e) -> b h s e', h=n_heads)
    # nn.Linear weights are (out_features, in_features); split in_features by head.
    W_O_by_head = rearrange(self_attn.W_O.weight, 'o (h e) -> h o e', h=n_heads)
    return einsum('b h s e, h o e -> b h s o', per_head, W_O_by_head)

if MAIN:
    # Compare get_out_by_head against the reference implementation in w5d5_tests.
    w5d5_tests.test_get_out_by_head(get_out_by_head, model, data)

All tests in `test_get_out_by_head` passed.


Breaking down the residual stream by component
Use your hook tools to create a tensor of shape [num_components, dataset_size, seq_pos, emb], where the number of components = 10.

This is a termwise representation of the input to the final layer norm from each component (recall that we can see each head as writing something to the residual stream, which is eventually fed into the final layer norm). The order of the components in your function's output should be:

embeddings, i.e. the sum of token and positional embeddings (corresponding to the direct path through the model, which avoids all attention layers and MLPs)
For each of the layers layer = 0, 1, 2:
Head layer.0, i.e. the output of the first attention head in layer `layer`
Head layer.1, i.e. the output of the second attention head in layer `layer`
MLP layer, i.e. the output of the MLP in layer `layer`
(The only terms missing are the W_O biases from each of the attention layers.)

In [11]:
def get_out_by_components(model: ParenTransformer, data: DataSet) -> t.Tensor:
    '''
    Computes a tensor of shape [10, dataset_size, seq_pos, emb] representing the output of the model's components when run on the data.
    The first dimension is  [embeddings, head 0.0, head 0.1, mlp 0, head 1.0, head 1.1, mlp 1, head 2.0, head 2.1, mlp 2]

    In general the first dimension has 1 + n_layers * (n_heads + 1) entries: the
    embeddings, then for each layer its heads in order followed by its MLP output.
    W_O biases are not included in any component.
    '''
    # Direct path: output of the positional encoder (token + positional embeddings).
    components = [get_outputs(model, data, model.pos_encoder)]
    for layer_idx, layer in enumerate(model.layers):
        # (batch, n_heads, seq, emb) -> one (batch, seq, emb) tensor per head
        components.extend(get_out_by_head(model, data, layer_idx).unbind(dim=1))
        # Output of the MLP's final linear layer.
        components.append(get_outputs(model, data, layer.linear2))
    # Stacking keeps the captured activations' device/dtype instead of forcing a CPU buffer.
    return t.stack(components, dim=0)


if MAIN:
    # Compare get_out_by_components against the reference implementation in w5d5_tests.
    w5d5_tests.test_get_out_by_component(get_out_by_components, model, data)

All tests in `test_get_out_by_component` passed.


In [12]:
if MAIN:
    # W_O biases are not attributed to any component, so add them back before comparing.
    biases = sum(model.layers[l].self_attn.W_O.bias for l in (0, 1, 2)).clone()
    out_by_components = get_out_by_components(model, data)
    summed_terms = out_by_components.sum(dim=0) + biases
    pre_final_ln = get_inputs(model, data, model.norm)
    # The decomposition should reconstruct the input to the final LayerNorm.
    t.testing.assert_close(summed_terms, pre_final_ln)

In [13]:
def hists_per_comp(magnitudes, data):
    '''Plot one panel per residual-stream component with overlaid histograms of its
    magnitude on balanced (blue) and unbalanced (red) examples.

    magnitudes: indexed first by component (10 entries, in the order of the panel
    titles below), then by example; data.isbal selects the balanced examples.
    Shows the figure and returns it.
    '''
    panel_titles = {
        (1, 1): "embeddings",
        (2, 1): "head 0.0",
        (2, 2): "head 0.1",
        (2, 3): "mlp 0",
        (3, 1): "head 1.0",
        (3, 2): "head 1.1",
        (3, 3): "mlp 1",
        (4, 1): "head 2.0",
        (4, 2): "head 2.1",
        (4, 3): "mlp 2"
    }
    assert magnitudes.shape[0] == len(panel_titles)

    balanced_mask = data.isbal
    unbalanced_mask = ~data.isbal
    fig = make_subplots(rows=4, cols=3)
    for ((row, col), title), mag in zip(panel_titles.items(), magnitudes):
        # Only the first panel adds legend entries, so each class appears once.
        first_panel = title == "embeddings"
        fig.add_trace(go.Histogram(x=mag[balanced_mask].numpy(), name="Balanced", marker_color="blue", opacity=0.5, legendgroup = '1', showlegend=first_panel), row=row, col=col)
        fig.add_trace(go.Histogram(x=mag[unbalanced_mask].numpy(), name="Unbalanced", marker_color="red", opacity=0.5, legendgroup = '2', showlegend=first_panel), row=row, col=col)
        fig.update_xaxes(title_text=title, range=[-10, 20], row=row, col=col)
    fig.update_layout(width=1200, height=1200, barmode="overlay", legend=dict(yanchor="top", y=0.92, xanchor="left", x=0.4), title="Histograms of component significance")
    fig.show()
    return fig

if MAIN:
    # Output of each component at sequence position 0: (component, sample, emb).
    out_by_components = get_out_by_components(model, data)[:, :, 0, :].detach()
    # Residual-stream direction (before the final LN) that pushes towards 'unbalanced'.
    unbalanced_dir = get_pre_final_ln_dir(model, data).detach()
    # Project each component's output onto that direction.
    magnitudes = einsum("component sample emb, emb -> component sample", out_by_components, unbalanced_dir)
    # Centre each component so its mean over balanced examples is zero.
    balanced_means = magnitudes[:, data.isbal].mean(-1, keepdim=True)
    magnitudes = magnitudes - balanced_means
    hists_per_comp(magnitudes, data)

In [24]:
def get_negative_failure(data):
    '''
    For each example, True if some suffix of its bracket string contains more '('
    than ')'. Equivalently, when reading right to left, the running count of
    ')' minus '(' drops below zero at some point.

    data: iterable of examples whose first element is the bracket string
    (e.g. a DataSet, whose integer indexing yields (str, isbal, toks)).
    Returns a 1-D numpy bool array (dtype=bool even when data is empty, so the
    result can be combined with ~ and &).
    '''
    flags = []
    for example in data:
        parens = example[0]
        level = 0
        dipped = False
        for c in reversed(parens):
            if c == ")":
                level += 1
            elif c == "(":
                level -= 1
            if level < 0:
                dipped = True
                break
        flags.append(dipped)
    return np.array(flags, dtype=bool)

def get_total_elevation_failure(data):
    '''
    For each example, True if its bracket string has a different number of '('
    and ')' characters (the total elevation is non-zero).

    data: iterable of examples whose first element is the bracket string
    (e.g. a DataSet, whose integer indexing yields (str, isbal, toks)).
    Returns a 1-D numpy bool array (dtype=bool even when data is empty).
    '''
    return np.array(
        [example[0].count("(") != example[0].count(")") for example in data],
        dtype=bool,
    )


if MAIN:
    negative_failure = get_negative_failure(data)
    total_elevation_failure = get_total_elevation_failure(data)
    assert len(negative_failure) == len(total_elevation_failure)
    # Components 7 and 8 are heads 2.0 and 2.1; centre them on the balanced examples.
    h20_in_d = magnitudes[7] - magnitudes[7, data.isbal].mean(0)
    h21_in_d = magnitudes[8] - magnitudes[8, data.isbal].mean(0)

    failure_types = np.full(len(h20_in_d), "", dtype=np.dtype("U32"))
    failure_types_dict = {
        "both failures": negative_failure & total_elevation_failure,
        "just neg failure": negative_failure & ~total_elevation_failure,
        "just total elevation failure": ~negative_failure & total_elevation_failure,
        "balanced": ~negative_failure & ~total_elevation_failure
    }
    for name, mask in failure_types_dict.items():
        failure_types = np.where(mask, name, failure_types)
    # Keep only sequences that start with "(".
    failures_df = pd.DataFrame({
        "Head 2.0 contribution": h20_in_d,
        "Head 2.1 contribution": h21_in_d,
        "Failure type": failure_types
    })[data.starts_open.tolist()]
    fig = px.scatter(
        failures_df,
        x="Head 2.0 contribution", y="Head 2.1 contribution", color="Failure type",
        title="h20 vs h21 for different failure types", template="simple_white", height=600, width=800,
        # category_orders is keyed by the column used for `color`; a "color" key is ignored here.
        category_orders={"Failure type": list(failure_types_dict.keys())}
    ).update_traces(marker_size=4)
    fig.show()

In [25]:
if MAIN:
    # Head 2.0's contribution should track the fraction of "(" in the sequence.
    fig = (
        px.scatter(
            x=data.open_proportion, y=h20_in_d, color=failure_types,
            title="Head 2.0 contribution vs proportion of open brackets '('", template="simple_white", height=500, width=800,
            labels={"x": "Open-proportion", "y": "Head 2.0 contribution"}, category_orders={"color": failure_types_dict.keys()}
        )
        .update_traces(marker_size=4, opacity=0.5)
        .update_layout(legend_title_text='Failure type')
    )
    fig.show()

### Attention pattern of the responsible head

In [36]:
def get_attn_probs(model: ParenTransformer, tokenizer: SimpleTokenizer, data: DataSet, layer: int, head: int) -> t.Tensor:
    '''
    Attention probabilities of one head, recomputed from the inputs to its attention layer.

    Returns: (N_SAMPLES, max_seq_len, max_seq_len) tensor that sums to 1 over the last
    dimension (queries on dim 1, keys on dim 2). Padding keys receive ~zero probability.
    '''
    self_attn = model.layers[layer].self_attn
    attn_inputs = get_inputs(model, data, self_attn)
    with t.no_grad():
        # Pre-softmax scores; indexing assumes layout (batch, head, query, key).
        scores = self_attn.attention_pattern_pre_softmax(attn_inputs)[:, head]
    # NOTE(review): pad id is taken from tokenizer.PAD_TOKEN if present, else 1 - confirm.
    pad_token = getattr(tokenizer, "PAD_TOKEN", 1)
    # (batch, key) -> (batch, 1, key) so the mask broadcasts over query positions.
    is_pad_key = (data.toks == pad_token)[:, None, :]
    # Masking must set pad scores to a very negative value before the softmax;
    # multiplying scores by a 0/1 mask would not work because exp(0) is still positive.
    scores = scores.masked_fill(is_pad_key, t.finfo(scores.dtype).min)
    return scores.softmax(dim=-1)

    

if MAIN:
    attn_probs = get_attn_probs(model, tokenizer, data, 2, 0)
    # Mean pattern over sequences starting with "(", keeping only query position 0.
    attn_probs_open = attn_probs[data.starts_open].mean(0)[[0]]
    key_probs = attn_probs_open.squeeze().numpy()
    fig = px.bar(
        y=key_probs, labels={"y": "Probability", "x": "Key Position"},
        template="simple_white", height=500, width=600, title="Avg Attention Probabilities for '(' query from query 0"
    )
    fig = fig.update_layout(showlegend=False, hovermode='x unified')
    fig.show()