## Challenge

For this challenge you're given a model trained to classify sequences of numbers as palindromes or non-palindromes (more details below). However, after training finished I zeroed out the weights of one of its four heads (Head 0 in layer 1) to disrupt its performance. Your task is to modify the weights of the missing head (and only the missing head) so that the model performs the task correctly again.
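The sketch below lists which state-dict entries belong to a single head, which makes "only the missing head" concrete. It assumes TransformerLens's standard attention parameter names and layouts:

* `W_Q`, `W_K`, `W_V` with shape `(n_heads, d_model, d_head)`.
* `W_O` with shape `(n_heads, d_head, d_model)`.
* Per-head biases `b_Q`, `b_K`, `b_V` with shape `(n_heads, d_head)`.

`head_param_slices` is a hypothetical helper, not part of the challenge code.

```python
# Hypothetical helper: list the state-dict entries that belong to one attention
# head of a TransformerLens HookedTransformer (standard, non-grouped attention).
# Each head's weights are the slice at index `head` along dimension 0.
PER_HEAD_PARAMS = ("W_Q", "W_K", "W_V", "W_O", "b_Q", "b_K", "b_V")


def head_param_slices(layer: int, head: int) -> list[tuple[str, int]]:
    """Return (state_dict_key, head_index) pairs for one head."""
    return [(f"blocks.{layer}.attn.{name}", head) for name in PER_HEAD_PARAMS]


# Head 0 in layer 1, i.e. the head that was zeroed out.
for key, idx in head_param_slices(layer=1, head=0):
    print(f"{key}[{idx}]")
```

The output bias `b_O` is left out on purpose. It has shape `(d_model,)` and is shared by every head in the layer, so it does not belong to any single head.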

Your solution will be evaluated on a variety of palindrome and non-palindrome sequences to verify that it produces the correct result. In all cases it should produce an average logit difference greater than 4 in the direction of the correct class, and the 90% confidence interval of the logit difference over a batch should not include zero. You can see the specific details of how your model will be evaluated in `test_palindrome_zero_h10.py`.
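The sketch below illustrates the pass criterion. It takes per-sequence logit differences (palindrome logit minus non-palindrome logit), flips their sign so that positive means "correct class", and checks both conditions. The normal-approximation confidence interval (mean ± 1.645 standard errors) and the name `passes_criterion` are assumptions. `test_palindrome_zero_h10.py` remains the authoritative definition.

```python
import math
import statistics

Z_90 = 1.645  # two-sided 90% normal quantile (assumed CI method)


def logit_diff_ci(diffs: list[float], z: float = Z_90) -> tuple[float, float, float]:
    """Return (mean, lower, upper) of a normal-approximation CI for the mean."""
    mean = statistics.fmean(diffs)
    sem = statistics.stdev(diffs) / math.sqrt(len(diffs))  # needs >= 2 values
    return mean, mean - z * sem, mean + z * sem


def passes_criterion(diffs: list[float], label: int, min_mean: float = 4.0) -> bool:
    """diffs are logit[palindrome] - logit[not palindrome] for sequences with this label."""
    # Flip the sign for non-palindromes so that positive always means "correct".
    signed = [d if label == 1 else -d for d in diffs]
    mean, lower, _ = logit_diff_ci(signed)
    # The mean must exceed the threshold, and since it is positive, the CI
    # excludes zero exactly when its lower bound is above zero.
    return mean > min_mean and lower > 0
```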

The original model for this challenge comes from the collection of [monthly algorithmic problems](https://arena-ch1-transformers.streamlit.app/Monthly_Algorithmic_Problems) by Callum McDougall. While the original challenge was more open-ended (reverse engineering the model), this version is adapted to have.

## Model Task & Dataset   

(Taken from Callum's original challenge without quotation marks)

This directory contains all the relevant files for the challenge, including `model.py` (for defining the model), `training.py` (for training the model), and `dataset.py` (for the dataset of palindromes and non-palindromes).

Each sequence in the dataset looks like:

```
[start_token, a_1, a_2, ..., a_N, end_token]
```

where `start_token = 31`, `end_token = 32`, and each value `a_i` is a value in the range `[0, 30]` inclusive. 
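The sketch below builds this token layout in plain Python. It assumes the sequences have even length `N = 2 * half_length`, as implied by the shape of `dataset.toks` below. `make_palindrome` and `make_random_sequence` are hypothetical helpers, not functions from `dataset.py`, which may sample its non-palindromes differently.

```python
import random
from typing import Optional

START_TOKEN, END_TOKEN = 31, 32
MAX_VALUE = 30  # values a_i lie in [0, MAX_VALUE] inclusive


def make_palindrome(half_length: int, rng: Optional[random.Random] = None) -> list[int]:
    """Return [START, a_1, ..., a_N, END] with N = 2 * half_length and a palindromic body."""
    rng = rng if rng is not None else random.Random()
    half = [rng.randint(0, MAX_VALUE) for _ in range(half_length)]  # randint is inclusive
    return [START_TOKEN, *half, *reversed(half), END_TOKEN]


def make_random_sequence(half_length: int, rng: Optional[random.Random] = None) -> list[int]:
    """Return [START, a_1, ..., a_N, END] with each a_i drawn uniformly at random."""
    rng = rng if rng is not None else random.Random()
    body = [rng.randint(0, MAX_VALUE) for _ in range(2 * half_length)]
    return [START_TOKEN, *body, END_TOKEN]


seq = make_palindrome(half_length=10, rng=random.Random(0))
print(len(seq))  # 22
```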

Each sequence has a corresponding label, which is `1` if the sequence is a palindrome (i.e. `(a_1, a_2, ..., a_N) == (a_N, ..., a_2, a_1)`), and `0` otherwise. The model has been trained to classify each sequence according to this label.

We've given you the class `PalindromeDataset` to store your data. You can slice this object to get batches of tokens and labels. You can also use the function `display_seq` to display a sequence in a more readable format (with any tokens that stop it from being a palindrome highlighted). There's an example later on this page. 
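You can read off both the label and the tokens that break the palindrome by comparing each `a_i` with its mirror `a_{N+1-i}`. The sketch below is a hypothetical plain-Python illustration, not the code behind `display_seq`, which may compute its highlights differently. It assumes the input still includes the `START` and `END` tokens.

```python
def mismatched_positions(seq: list[int]) -> list[int]:
    """Indices into seq (START at index 0) whose token differs from its mirror."""
    body = seq[1:-1]  # strip START and END
    n = len(body)
    return [i + 1 for i in range(n) if body[i] != body[n - 1 - i]]


def palindrome_label(seq: list[int]) -> int:
    """1 if (a_1, ..., a_N) == (a_N, ..., a_1), else 0."""
    return int(not mismatched_positions(seq))


seq = [31, 4, 7, 7, 5, 32]
print(mismatched_positions(seq))  # [1, 4]
print(palindrome_label(seq))      # 0
```

Mismatches always come in mirrored pairs, so a non-palindrome has at least two highlighted tokens.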

Some other useful methods and attributes of this dataset (you can inspect `dataset.py` to see for yourself) are:

* `dataset.toks`, to get a batch of all the tokens in the dataset, of shape `(size, 2 * half_length + 2)`.
* `dataset.is_palindrome`, to get a tensor of all the labels in the dataset, of shape `(size,)`.
* `dataset.str_toks`, to get a list of lists, with string representations of each sequence, e.g. `["START", "1", "4", ..., "END"]`. This is useful for visualisation, e.g. circuitsvis.
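The string form renders the special tokens by name and every value as its decimal string. Below is a hypothetical re-implementation of one row of `dataset.str_toks`; the real attribute is defined in `dataset.py`. It assumes 31 and 32 map to `"START"` and `"END"` as in the example above.

```python
START_TOKEN, END_TOKEN = 31, 32


def to_str_toks(seq: list[int]) -> list[str]:
    """Render special tokens by name and every value a_i as its decimal string."""
    special = {START_TOKEN: "START", END_TOKEN: "END"}
    return [special.get(tok, str(tok)) for tok in seq]


print(to_str_toks([31, 1, 4, 4, 1, 32]))  # ['START', '1', '4', '4', '1', 'END']
```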

Our model was trained by minimising cross-entropy loss between its predictions and the true labels. You can inspect the notebook `training_model.ipynb` to see how it was trained.

The model is a 2-layer transformer with 2 attention heads per layer and causal attention. It includes layernorm, but no MLP layers. You can load it as follows:


```python
import os
os.environ["ACCELERATE_ENABLE_RICH"] = "0"
from pathlib import Path
import torch
from dataset import PalindromeDataset, display_seq

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
current_dir = Path(os.getcwd()).resolve()
```
```python
from transformer_lens import HookedTransformer, HookedTransformerConfig
import torch

half_length = 10
max_value = 30

cfg = HookedTransformerConfig(
    n_layers=2,
    n_ctx=2*half_length+2, # also have [START] and [END] tokens
    d_model=28,
    d_head=14,
    n_heads=2,
    d_mlp=None,
    attn_only=True,
    act_fn="relu",

    # vocab is [START] and [END] token plus (0, 1, ..., max_value)
    # result is a True/False classification
    d_vocab=max_value+3,
    d_vocab_out=2,

    # it's a small transformer so may as well use these hooks
    use_attn_result=True,
    use_split_qkv_input=True,
    use_hook_tokens=True,

    # Layernorm makes things way more accurate, even though it makes
    # mech interp a little more annoying!
    normalization_type='LN',
    device=device,
)

model = HookedTransformer(cfg)

filename = current_dir / "models" / "palindrome_classifier.pt" # or palindrome_H0_0_zero.pt

# Load the checkpoint once (map_location lets a GPU-saved file load on CPU),
# then center the weights and fold layernorm and value biases before loading.
state_dict = torch.load(filename, map_location=device)
state_dict = model.center_writing_weights(state_dict)
state_dict = model.center_unembed(state_dict)
state_dict = model.fold_layer_norm(state_dict)
state_dict = model.fold_value_biases(state_dict)
model.load_state_dict(state_dict, strict=False)
```

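```
_IncompatibleKeys(missing_keys=['blocks.0.ln1.w', 'blocks.0.ln1.b', 'blocks.1.ln1.w', 'blocks.1.ln1.b', 'ln_final.w', 'ln_final.b'], unexpected_keys=[])
```

These missing keys are expected. `fold_layer_norm` folds the layernorm weights and biases into the weight matrices that read from each layernorm and removes them from the state dict.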

The model was trained to output the correct classification at the `END` token. In other words, the value of the residual stream at `END` (post-layernorm) is mapped through `model.W_U`, which has shape `(d_model, 2)`, and this gives us our classification logits for `(not palindrome, palindrome)`.
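With only two output classes, the softmax probability of `palindrome` depends only on the logit difference: `softmax([l0, l1])[1] == sigmoid(l1 - l0)`. This links the logit-difference thresholds used in this challenge to probabilities. The stdlib sketch below uses illustrative function names that are not part of the challenge code.

```python
import math


def logit_diff(logits: tuple[float, float]) -> float:
    """logits are (not palindrome, palindrome) at the END position."""
    return logits[1] - logits[0]


def prob_palindrome(logits: tuple[float, float]) -> float:
    """Two-class softmax probability of the 'palindrome' class."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    return exps[1] / sum(exps)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


logits = (-1.5, 2.5)
print(logit_diff(logits))                     # 4.0
print(round(prob_palindrome(logits), 3))      # 0.982
print(round(sigmoid(logit_diff(logits)), 3))  # 0.982
```

So a logit difference of 4 corresponds to roughly 98% probability on the correct class, and a difference of 3 to roughly 95%.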

A demonstration of the model working (and of the `display_seq` function):

```python
dataset = PalindromeDataset(size=100, max_value=30, half_length=10)

toks, is_palindrome = dataset[:5]

logits = model(toks.to(device))[:, -1]  # classification logits at the END position
probs = logits.softmax(-1)
probs_palindrome = probs[:, 1]

for tok, prob in zip(toks, probs_palindrome):
    display_seq(tok, prob)
```

Once you have a working solution, save your model using the following command and evaluate it from the terminal by running `pytest test_palindrome_zero_h10.py`:

```python
model_solution = None # Your modified model
torch.save(model_solution.state_dict(), 'models/solution_pal_h10.pth')
```

In the file `dataset.py` you can see the functions that will be used to generate data for the patches. Your task is to fill in the function `predict_patching_result`, which has to predict, for every path patch from head 0.0 to head 0.1, whether the model will classify the sequences as palindromes or not.

In the file `test_patching_result_fn.py` you can see an example of the tests used to check your function. Your function will only be tested on cases where the model clearly predicts sequences as palindromes or not [1], so you don't have to worry about ambiguous cases where the model would be confused. 

[1] I operationalize this as cases where the absolute value of the logit difference is greater than 3 and the 90% confidence interval of the logit difference does not include zero.
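This footnote can be read as code: a batch counts as a clear case only if the mean logit difference has absolute value above 3 and its 90% confidence interval lies entirely on one side of zero. The sketch below makes two assumptions. It uses a normal-approximation interval, and `clear_prediction` is a hypothetical name. `test_patching_result_fn.py` remains the authoritative definition.

```python
import math
import statistics
from typing import Optional


def clear_prediction(diffs: list[float], threshold: float = 3.0,
                     z: float = 1.645) -> Optional[bool]:
    """Return True (palindrome), False (not palindrome) or None (ambiguous).

    diffs are logit[palindrome] - logit[not palindrome] for each sequence in a batch.
    """
    mean = statistics.fmean(diffs)
    sem = statistics.stdev(diffs) / math.sqrt(len(diffs))  # needs >= 2 values
    lower, upper = mean - z * sem, mean + z * sem
    if abs(mean) <= threshold or lower <= 0 <= upper:
        return None  # not a clear case under this reading of the footnote
    return mean > 0
```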

```python
from jaxtyping import Float
from torch import Tensor


def predict_patching_result(k_dataset: Float[Tensor, 'seq'],
                            q_dataset: Float[Tensor, 'seq'],
                            v_dataset: Float[Tensor, 'seq'],
                            ) -> bool:
    pass
```