# MATS Application - Christophe Thomassin: <br> Factual Knowledge in Transformer Language Models

Note: I will try to back up my code with comments for better understanding as time allows. For an in-depth analysis, please refer to the Google Docs document: [Link](https://docs.google.com/document/d/13CZcQ818lBNqBGEqn3wxq3iVc7uHldprSr2tgTrBUr8/edit?usp=sharing)

**The Task**: <br>
I am interested in understanding how Transformer language models store and retrieve knowledge, specifically factual knowledge. Factual knowledge is essential to virtually all information-processing tasks because it supplies the underlying assumptions for nearly every step of reasoning. If LLMs were unable to acquire factual knowledge, they would be little more than random word generators. I have three reasons for thinking that applying mechanistic interpretability (MI) to understand how LLMs acquire and store factual knowledge is a good idea:
1. I see factual knowledge as the simplest form of intelligence, the entry point to cognitive capabilities. Hence, I find it intuitive to start with factual knowledge when trying to "mechanistically" unwind the complexity of LLMs. Once understood, factual knowledge will open many doors for digging deeper into the "mind" of Transformers.
2. Factual knowledge, or in this case the lack of it, can be considered a major driver of hallucinations. Understanding how and where factual knowledge is stored could allow us to remediate many hallucinations.
3. Factual knowledge, as the name indicates, is based on commonly known facts, which should make it quite easy to evaluate (one would think). A fact can only be True or False.

One of the main drawbacks is that we cannot expect factual knowledge to be "universal". Au contraire, because of its factual nature it is highly dependent on the training data. A model cannot reason its way to Paris being the capital of France without having been told so during pre-training. Hence, some models may have acquired specific factual knowledge while others have not. Yet, my hope is that the mechanisms for storing and retrieving factual knowledge are somewhat universal.

## Setup

In [1]:
# allows to reload packages when reimported
%load_ext autoreload
%autoreload 2

In [63]:
import json
import os

import numpy as np
import pandas as pd
import torch
import transformer_lens.utils as utils
from dotenv import load_dotenv
from fancy_einsum import einsum
from huggingface_hub import login
from transformer_lens import HookedTransformer

from src.config import DATA

In [3]:
load_dotenv()

login(token=os.getenv("HF_TOKEN"))

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /home/chrisoutho_gmail_com/.cache/huggingface/token
Login successful


In [4]:
torch.set_grad_enabled(False)

device: torch.device = utils.get_device()
print(f"Device: {device}")

Device: cuda


### Load data

In [94]:
# load data
df = pd.read_csv(DATA / "fk_samples.csv")
print(df.head())

                                   question  answer
0         Is Paris the capital of France?\n    True
1   Is the Eiffel Tower located in Paris?\n    True
2    Is Paris known as the City of Light?\n    True
3               Is Paris a city in Italy?\n   False
4  Is the Louvre Museum located in Paris?\n    True


In [20]:
with open(DATA / "answer_map.json", "r") as f:
    answer_map = json.load(f)

print(answer_map)

{'True': ['Yes', 'Sure', 'Correct', 'Certainly', 'Absolutely', 'Indeed', 'True', 'Yep'], 'False': ['No', 'no', 'Nope', ' No', ' no', 'Wrong', 'NO', 'False']}


### Load model

In [7]:
model = HookedTransformer.from_pretrained("google/gemma-2b-it", device=device)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Loaded pretrained model google/gemma-2b-it into HookedTransformer


### Get model activations

In [8]:
from src.logit_diff import df_to_logits

logits, cache = df_to_logits(model, df)

In [21]:
from src.utils import answer_map_to_tokens

answer_map_tokens = answer_map_to_tokens(model, answer_map)

## Task definition

As said before, we want to understand how the model stores and retrieves factual knowledge. To do so, as in every MI investigation, we have to define a task that allows us to evaluate LLMs on the concept we care about, here factual knowledge. There are a few things to consider when engineering this task:
1. Most importantly: We want to isolate the concept we are testing for, here factual knowledge, as much as possible. Good performance on our task should imply that the model is good at the exact concept we are interested in. Conversely, if a model does not perform well on our task, this should imply that the model is not good at this concept.
2. Our generation should be only one token long for easy evaluation. It would be nice if one could determine another token that would be expected if the model were not able to answer the query correctly. These two tokens can then be used to introduce logit difference as the evaluation metric for the task.
3. It should be easy to scale the task to many different examples to rule out randomness.
4. It should be possible, by changing only a few tokens, to remove the concept from the task and make it random. This alternative distribution is needed for path patching and ablation.

For this time-constrained assessment, I will have to focus on a specific example. I chose to inspect factual knowledge around the French city "Paris". Paris is well known enough to appear in pretty much every pre-training dataset, yet offers enough diversity to ask multiple questions in order to demonstrate generalisation.

Taking all these requirements into consideration, I selected a question-answer format as the task for factual knowledge. This makes it possible to isolate specific factual knowledge, yields one-token generations, scales perfectly, and makes it possible to "disarm" the query by removing the "key".

An example, which I will use a lot throughout this assessment, is: <br>
    - Query: Is Paris the capital of France? <br>
    - Label: True <br>
    - Opposite Generation: False <br>
    - Clue: capital, France <br>

I admit that this task has the drawback that I am quite sure that question-answering is a unique circuit worth studying by itself. I expect the question-answering circuits and the factual knowledge circuit to have an overlap which means I did not 100% isolate our concept. Nonetheless, after a lot of back-and-forth, I believe this is the best way of testing factual knowledge for now.

### Single example

Let's see if the model is able to confidently solve this task in the first place:

In [30]:
query = "Is Paris the capital of France?\n"
ground_truth = "Yes"

<details><summary>Note</summary>
It seems that I have to add the '\n' token at the end of the query for this model to answer properly.
</details>

In [31]:
utils.test_prompt(query, ground_truth, model, prepend_bos=True, prepend_space_to_answer=False)

Tokenized prompt: ['<bos>', 'Is', ' Paris', ' the', ' capital', ' of', ' France', '?', '\n']
Tokenized answer: ['Yes']


Top 0th token. Logit: 32.54 Prob: 55.63% Token: |Sure|
Top 1th token. Logit: 32.09 Prob: 35.34% Token: |Yes|
Top 2th token. Logit: 29.96 Prob:  4.22% Token: |No|
Top 3th token. Logit: 29.90 Prob:  3.96% Token: |Paris|
Top 4th token. Logit: 26.54 Prob:  0.14% Token: |Answer|
Top 5th token. Logit: 26.24 Prob:  0.10% Token: |Certainly|
Top 6th token. Logit: 26.02 Prob:  0.08% Token: |The|
Top 7th token. Logit: 25.92 Prob:  0.07% Token: |Of|
Top 8th token. Logit: 25.76 Prob:  0.06% Token: |Correct|
Top 9th token. Logit: 25.70 Prob:  0.06% Token: |*|


While this looks good, we see the first "inconvenience" of our task here. The model is more than capable of answering the query. Yet, it does not concentrate all its probability mass on one "True-token" such as "Yes" but spreads it across several, such as "Sure" (56% probability), "Yes" (35% probability), and "Certainly" (0.1% probability). This is most likely because all of these words are used to affirm a statement in the training data. While this does not make the generation wrong, it raises two questions:
1. Is the model able to understand that all these tokens express "True" and build a circuit out of that?
2. How do we evaluate the task if we cannot measure its outcome based on the logit/probability of a single token?

For the first question, I do not have a concrete answer but, for the sake of this experiment, I will just hope that the model chosen here is indeed able to abstract the answers into "Fact is True" and "Fact is False". Larger models such as ChatGPT-4o are definitely able to do so. This model achieved close to 100% accuracy on 10 simple factual knowledge questions around Paris under the condition of answering with True or False only. For the second question, we will have to find a solution that captures all possible True- and False-tokens.

Going through a bunch of example questions, I have mapped all True- and Wrong-tokens to their underlying meaning. Although this is not very elegant computationally (especially because the lists can have different lengths), we will always have to evaluate the task on all of these tokens.

In [29]:
answer_map

{'True': ['Yes',
  'Sure',
  'Correct',
  'Certainly',
  'Absolutely',
  'Indeed',
  'True',
  'Yep'],
 'False': ['No', 'no', 'Nope', ' No', ' no', 'Wrong', 'NO', 'False']}

Now the question is: can we still use logit difference to quantify the probability mass concentrated on each answer if we consider groups of tokens? When comparing the logits of two tokens (say $diff = logit_1 - logit_2$), we know that the token with the higher logit has an $e^{diff}$ times higher token probability. For groups of tokens this is not so straightforward because the softmax is not a linear operator. Yet, under certain assumptions, the logit difference is also useful for comparing the probability mass assigned to groups of tokens. Let's do a little back-of-the-envelope calculation:

Say $T$ is the set of tokens describing True and $W$ is the set of tokens describing wrong. Furthermore, let's assume $|T| = |W| = n$. $U$ is the set of all tokens in the vocabulary.

The difference $diff$ we are now computing is $\text{diff} = \sum_{i \in T} \text{logit}_i - \sum_{j \in W} \text{logit}_j = n \left( \overline{\text{logit}_T} - \overline{\text{logit}_W} \right)$

We can express the token probability of all True-tokens as: $p_T = \frac{\sum_{i \in T} e^{\text{logit}_i}}{\sum_{u \in U} e^{\text{logit}_u}} = \frac{e^{\overline{\text{logit}_T}} \cdot \sum_{i \in T} e^{\Delta_i}}{\sum_{u \in U} e^{\text{logit}_u}}$ where $\Delta_i$ is the difference of logit of token i and the average logit of all True-tokens

Equivalently, we can express the token probability of all Wrong-tokens as: $p_W = \frac{\sum_{j \in W} e^{\text{logit}_j}}{\sum_{u \in U} e^{\text{logit}_u}} = \frac{e^{\overline{\text{logit}_W}} \cdot \sum_{j \in W} e^{\Delta_j}}{\sum_{u \in U} e^{\text{logit}_u}}$ where $\Delta_j$ is the difference of logit of token j and the average logit of all Wrong-tokens

Hence, the probability of True-tokens is a multiple of the probability of Wrong-tokens of magnitude: $\frac{p_T}{p_W} = \frac{\sum_{i \in T} e^{\text{logit}_i}}{\sum_{j \in W} e^{\text{logit}_j}} = e^{\overline{\text{logit}_T} - \overline{\text{logit}_W}} \cdot \frac{\sum_{i \in T} e^{\Delta_i}}{\sum_{j \in W} e^{\Delta_j}} = e^{\frac{\text{diff}}{n}} \cdot \frac{\sum_{i \in T} e^{\Delta_i}}{\sum_{j \in W} e^{\Delta_j}} \approx e^{\frac{\text{diff}}{n}}$

Assuming: $\frac{\sum_{i \in T} e^{\Delta_i}}{\sum_{j \in W} e^{\Delta_j}} \approx 1 \quad$ i.e., logits of tokens in W and T follow similar distributions

Hence, assuming we have a similar token distribution, we can use the average logit difference $\frac{\text{diff}}{n}$ to quantify the difference in token probability of the two groups. Let's try out this new metric.
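
As a sanity check on this approximation, here is a minimal sketch with made-up logits. The helper name `group_ratio_terms` and all numbers are illustrative assumptions and are not taken from `src` or from the model. It computes the exact ratio $p_T / p_W$, the approximation $e^{\text{diff}/n}$, and the correction factor that the approximation drops.

```python
import math


def group_ratio_terms(true_logits, wrong_logits):
    """Return (exact p_T/p_W, e^(diff/n) approximation, dropped correction factor)."""
    assert len(true_logits) == len(wrong_logits), "derivation assumes |T| == |W|"
    n = len(true_logits)
    mean_t = sum(true_logits) / n
    mean_w = sum(wrong_logits) / n

    # The softmax denominator is shared by both groups, so it cancels in the ratio.
    exact = sum(math.exp(l) for l in true_logits) / sum(math.exp(l) for l in wrong_logits)

    # e^(diff / n), i.e. the exponential of the average logit difference.
    approx = math.exp(mean_t - mean_w)

    # Ratio of the within-group spreads: sum_i e^(Delta_i) / sum_j e^(Delta_j).
    correction = (
        sum(math.exp(l - mean_t) for l in true_logits)
        / sum(math.exp(l - mean_w) for l in wrong_logits)
    )
    return exact, approx, correction


# Hypothetical logits: both groups have the same spread around their mean,
# so the correction factor is 1 and the approximation is exact.
exact, approx, correction = group_ratio_terms([5.0, 4.0, 3.0], [2.0, 1.0, 0.0])
print(exact, approx, correction)
```

The identity `exact == approx * correction` always holds; the approximation is only as good as the assumption that `correction` is close to 1.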

In [32]:
df = pd.DataFrame({'question': [query], 'answer': [True]})

In [33]:
from src.logit_diff import logits_to_logit_diff

logits, cache = df_to_logits(model, df)
logit_diff = logits_to_logit_diff(df, logits, answer_map_tokens, device)
print(f"Logit difference of True and False tokens summed up is {logit_diff:.2f} nats")

Logit difference of True and False tokens summed up is 4.99 nats


When accounting for all True- and False-tokens, the model is clearly able to classify the statement as True. The True-tokens have an $e^{5}\approx 148\times$ higher probability than the False-tokens.

### Build dataset

Let's get some more examples testing factual knowledge around Paris to make sure our results are representative. I will use OpenAI's ChatGPT-4o for dataset generation and subsequently filter to include only valid examples. This might seem unnecessary for our experiment, since there are not too many question-answer pairs around Paris that such a small model can answer, but it makes it easy to adapt this script to other use cases in the future, and most of the code is simply recycled...

In [69]:
model_name = "gpt-4o-2024-08-06"
query = [
    {
        "role": "system",
        "content": (
            "Please generate 10 questions about Paris that can be answered with "
            "either Yes or No. Annotate each question with True if the answer is "
            "Yes and False if the answer is No. The questions should be easy to "
            "answer for any human. Return the output as a JSON object with each "
            "entry having the keys 'question' and 'answer' filled with a list of "
            "10 values. There should be 5 questions with answer Yes and 5 "
            "questions with answer no. Each question should end with a question "
            "mark followed by the newline character.\n"
            "Example question: Is Paris the capital of France?\n"
            "Example answer: True"
        )
    }
]

In [74]:
import asyncio

import nest_asyncio

from src.data import invoke_openai

nest_asyncio.apply()

dataset = asyncio.run(invoke_openai(model_name, query))


In [76]:
df = pd.DataFrame(dataset[0])
df["answer"] = df["answer"].apply(lambda x: True if x == "True" else False)

In [77]:
df.head(10).style.set_properties(**{'text-align': 'left'}).set_table_styles([dict(selector='th', props=[('text-align', 'left')])])

| | question | answer |
|---|---|---|
| 0 | Is the Eiffel Tower located in Paris? | True |
| 1 | Is Paris the largest city in Germany? | False |
| 2 | Does the Louvre Museum reside in Paris? | True |
| 3 | Is Paris known as the City of Love? | True |
| 4 | Is Paris situated on the banks of the Thames River? | False |
| 5 | Can you find the Notre-Dame Cathedral in Paris? | True |
| 6 | Is Paris famous for its sushi? | False |
| 7 | Does Paris have more than 10 million residents? | False |
| 8 | Is the Arc de Triomphe a famous monument in Paris? | True |
| 9 | Is the official language of Paris Spanish? | False |


Finally, let's make sure our model is able to answer all questions correctly...

In [35]:
logits, cache = df_to_logits(model, df)

In [22]:
print(
    "Per prompt logit difference:",
    logits_to_logit_diff(df, logits, answer_map_tokens, device, per_prompt=True)
    .cpu()
    .round(decimals=3),
)
print(
    "Average logit difference:",
    logits_to_logit_diff(df, logits, answer_map_tokens, device)
    .clone()
    .detach()
    .cpu()
    .round(decimals=2),
)

Per prompt logit difference: tensor([ 3.5940,  2.3680,  8.9890,  5.6400,  4.4280, -1.1900,  5.2440,  3.6440,
         6.0260,  0.8290])
Average logit difference: tensor(3.9600)


Alright, we see that overall the model is more than able to assign higher probability to the correct answer tokens (the average logit difference is around 4 nats). But for two samples (index 5 and 9) we get a negative or close-to-zero logit difference. Let's have a look at what's happening there:

In [23]:
for i in [5, 9]:
    utils.test_prompt(df["question"].iloc[i], str(df["answer"].iloc[i]), model, prepend_bos=True, prepend_space_to_answer=False)

Tokenized prompt: ['<bos>', 'Is', ' Paris', ' the', ' largest', ' city', ' in', ' Europe', '?', '\n']
Tokenized answer: ['False']


Top 0th token. Logit: 33.59 Prob: 51.47% Token: |No|
Top 1th token. Logit: 33.06 Prob: 30.36% Token: |Sure|
Top 2th token. Logit: 32.00 Prob: 10.51% Token: |Yes|
Top 3th token. Logit: 31.54 Prob:  6.59% Token: |Paris|
Top 4th token. Logit: 29.14 Prob:  0.60% Token: |The|
Top 5th token. Logit: 27.78 Prob:  0.15% Token: |no|
Top 6th token. Logit: 27.13 Prob:  0.08% Token: |True|
Top 7th token. Logit: 26.55 Prob:  0.04% Token: | Paris|
Top 8th token. Logit: 26.11 Prob:  0.03% Token: |Answer|
Top 9th token. Logit: 25.99 Prob:  0.03% Token: |Certainly|


Tokenized prompt: ['<bos>', 'Is', ' Paris', ' home', ' to', ' the', ' Co', 'losseum', '?', '\n']
Tokenized answer: ['False']


Top 0th token. Logit: 35.45 Prob: 95.16% Token: |No|
Top 1th token. Logit: 31.36 Prob:  1.59% Token: |The|
Top 2th token. Logit: 31.13 Prob:  1.26% Token: |Sure|
Top 3th token. Logit: 30.85 Prob:  0.95% Token: |Paris|
Top 4th token. Logit: 30.60 Prob:  0.74% Token: |Yes|
Top 5th token. Logit: 29.13 Prob:  0.17% Token: |no|
Top 6th token. Logit: 27.59 Prob:  0.04% Token: |There|
Top 7th token. Logit: 26.73 Prob:  0.02% Token: | No|
Top 8th token. Logit: 26.23 Prob:  0.01% Token: | no|
Top 9th token. Logit: 26.04 Prob:  0.01% Token: |Answer|


For the first sample, the model does not seem to be able to answer our question with high confidence (the token probability of 'No' is about the same as that of the True-tokens). We should probably get rid of this sample. Yet, we notice that the difference in token probability in this case is still positive. Why do we get a negative average logit difference? For the second sample, we get a 95.2% token probability for the False-token 'No', yet the average logit difference is barely positive. Why is that? In these two cases, the previously described condition of similar logit distributions in both token groups is not fulfilled. In the latter case, the variance among the Wrong-token logits is much larger than among the True-token logits, which makes us underestimate the actual ratio of token probabilities. This is a clear problem with our method. An alternative is to compare token probabilities directly. Let's add this as a feature to our `logits_to_logit_diff` function...
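
To make this failure mode concrete, here is a small sketch on a four-token toy vocabulary. The logits are made up for illustration (loosely shaped like the second sample, where one Wrong-token dominates) and the code is not the implementation in `src.logit_diff`. It compares the average logit difference with a difference in summed token probability.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    peak = max(logits)
    exps = [math.exp(l - peak) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Made-up logits for a four-token vocabulary; the label for this prompt is False.
vocab = ["No", "no", "Yes", "Sure"]
logits = [35.0, 25.0, 30.5, 31.0]
wrong_idx, true_idx = [0, 1], [2, 3]

# Metric 1: average logit difference (correct group minus incorrect group).
mean_wrong = sum(logits[i] for i in wrong_idx) / len(wrong_idx)
mean_true = sum(logits[i] for i in true_idx) / len(true_idx)
avg_logit_diff = mean_wrong - mean_true

# Metric 2: difference in summed token probability between the two groups.
probs = softmax(logits)
prob_diff = sum(probs[i] for i in wrong_idx) - sum(probs[i] for i in true_idx)

print(f"average logit difference: {avg_logit_diff:.2f}")
print(f"probability difference:   {prob_diff:.2f}")
```

In this toy case the average logit difference is negative even though almost all probability mass sits on 'No', because the low logit of 'no' drags the group mean down.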

In [26]:
from src.logit_diff import logits_to_logit_diff

In [28]:
print(
    "Per prompt probability difference:",
    logits_to_logit_diff(df, logits, answer_map_tokens, device, return_probs=True, per_prompt=True)
    .cpu()
    .round(decimals=3),
)
print(
    "Average probability difference:",
    logits_to_logit_diff(df, logits, answer_map_tokens, device, return_probs=True)
    .clone()
    .detach()
    .cpu()
    .round(decimals=2),
)

Per prompt probability difference: tensor([0.5940, 0.2050, 0.8950, 0.9940, 0.8270, 0.4090, 0.7470, 0.9880, 0.8480,
        0.8900])
Average probability difference: tensor(0.7400)


This looks much better. We have now found two insightful metrics for evaluating our task. Let's still remove the sample with index 5 from the dataframe, as the model seems to be unsure about this case, and move on to the fun part...

In [31]:
drop_idx = [5]
df.drop(drop_idx, inplace=True)
df.reset_index(drop=True, inplace=True)

In [32]:
print(f"Final dataset:\n{df}")

Final dataset:
                                          question  answer
0                Is Paris the capital of France?\n    True
1          Is the Eiffel Tower located in Paris?\n    True
2           Is Paris known as the City of Light?\n    True
3                      Is Paris a city in Italy?\n   False
4         Is the Louvre Museum located in Paris?\n    True
5        Does the Seine River run through Paris?\n    True
6  Is Paris situated on the Mediterranean coast?\n   False
7         Is the Arc de Triomphe found in Paris?\n    True
8                Is Paris home to the Colosseum?\n   False


In [33]:
df.to_csv(DATA / 'fk_samples.csv', index=False)

## Direct Logit Attribution

The idea of direct logit attribution is to quantify how the different components of our transformer affect the outcome of our task. When evaluating the logits of only two tokens, $t_1$ and $t_2$, we can use the residual stream direction of our logit difference and observe how the dot product between this direction and the residual stream embedding of the final token behaves, layer by layer. A positive number indicates that the logit difference is pushed up, and vice versa. This works because the final layernorm (normalization plus scaling) and the unembedding $W_U$, the two transformations applied to the final state of the residual stream, are approximately linear, so we can move the evaluation step to the residual stream embeddings.
$$
\text{We are interested in the logit difference } diff = \text{logits}_{t_1} - \text{logits}_{t_2}
$$


$$
\text{We have unembedding matrix }W_U \in \mathbb{R}^{n_{\text{vocab}} \times d_{\text{model}}} \text{ and residual stream embedding } x_n \in \mathbb{R}^{d_{\text{model}} \times 1}
$$

$$
\text{logits} = W_U \cdot \text{Layernorm}(x_n) = W_U \cdot \left(\gamma \odot \text{Norm}(x_n) + \beta\right)
$$

$$
\text{Hence, }\text{logits}_{t_1} - \text{logits}_{t_2} = \left(W_U[t_1] - W_U[t_2]\right) \cdot \left(\gamma \odot \text{Norm}(x_n) + \beta\right)
$$

$$
\text{which, treating the normalization scale as fixed, absorbing } \gamma \text{ into } W_U \text{ and dropping the constant } \beta \text{ term, is } \propto (W_U[t_1] - W_U[t_2]) \cdot x_n
$$
Now the question is, does this also work in our case of groups of tokens? The answer is two-fold. We have determined the average logit difference per group $diff_1 = \frac{diff}{n} = \frac{\sum_{i \in T} \text{logits}_i - \sum_{j \in W} \text{logits}_j}{n}$ and the probability difference $diff_2 = \sum_{i \in T} p_i - \sum_{j \in W} p_j$ to be suitable metrics, where the first metric is slightly flawed by the assumptions of similar logit distributions within the groups and equal group sizes. Only the first metric is of direct use for direct logit attribution. If we substitute $diff$ with $diff_1$ in the calculation above, we see that we can approximate $diff_1$ as follows: $diff_1 \propto \frac{(\sum_{i \in T} W_U[t_i] - \sum_{j \in W} W_U[t_j]) \cdot x_n}{n}$. Unfortunately, softmax, the final transformation applied to get token probabilities (the evaluation metric we deemed most useful for our use case), is not even close to linear. Therefore, we cannot directly translate a positive or negative change in direct logit attribution into an equal change in our second metric. Hence, we move forward with the first metric, bearing in mind the assumptions it rests on.
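
The linearity argument for the first metric can be checked numerically. The sketch below uses random stand-ins for $W_U$ and for an already normalized $x_n$, plus hypothetical token ids; none of these values come from the model. It follows the shape convention of the formulas above ($W_U$ with one row per vocabulary entry), whereas a real model may store the matrix transposed.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_vocab = 16, 50

# Random stand-ins for the unembedding matrix and a normalized final residual vector.
W_U = rng.normal(size=(n_vocab, d_model))
x_n = rng.normal(size=d_model)

# Hypothetical token ids for the two groups, both of size n.
true_ids = np.array([3, 7, 11, 19])
wrong_ids = np.array([2, 5, 13, 23])
n = len(true_ids)

# Route 1: compute all logits, then average within each group and subtract.
logits = W_U @ x_n
diff_1_from_logits = logits[true_ids].mean() - logits[wrong_ids].mean()

# Route 2: build a single direction in residual space and project onto it.
direction = (W_U[true_ids].sum(axis=0) - W_U[wrong_ids].sum(axis=0)) / n
diff_1_from_direction = direction @ x_n

print(diff_1_from_logits, diff_1_from_direction)
```

Both routes give the same number, which is why a single direction per sample is enough to attribute the average logit difference to individual components.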

In [95]:
# in case you deleted samples in the last section, rerun this cell
from src.logit_diff import df_to_logits

logits, cache = df_to_logits(model, df)

### Logit Lens

First, we get the residual stream directions of our True-/False-tokens. This simply means that we retrieve the column vectors of the unembedding matrix that determine the logits of the tokens we are interested in. We then sum those unembedding vectors for the True- and False-tokens respectively and subtract the True-token sum from the False-token sum (or vice versa), according to our label. We end up with a logit_diff_direction vector for every sample in our dataframe.

In [96]:
from src.direct_logit_attribution import get_logit_diff_directions

logit_diff_directions = get_logit_diff_directions(
    model,
    df,
    answer_map_tokens,
    device
)

Finally, we take the residual stream embedding of the last token in our sequence at every point in the network where a layer's output is added back into the residual stream, and plot its projection onto the answer direction after each layer.
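
A toy version of this projection might look as follows. It uses random per-component outputs and a random direction, and it freezes an RMS-style normalization scale taken from the final residual stream; these are simplifying assumptions for illustration and do not reproduce `residual_stack_to_logit_diff` from `src`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_layers, d_model = 3, 8

# Toy outputs written into the residual stream at the last position:
# the embedding, then (attention, MLP) for each layer.
components = rng.normal(size=(1 + 2 * n_layers, d_model))
direction = rng.normal(size=d_model)  # stand-in for a logit-difference direction

# Accumulated residual stream: 0_pre, 0_mid, 1_pre, 1_mid, ..., final_post.
accumulated = np.cumsum(components, axis=0)

# Freeze the normalization scale of the final residual stream and reuse it at
# every depth, so that the projection stays linear in the residual stream.
final_resid = accumulated[-1]
scale = 1.0 / np.sqrt(np.mean(final_resid**2))

logit_lens = (accumulated * scale) @ direction    # running logit difference
per_component = (components * scale) @ direction  # contribution of each component

print(logit_lens)
print(per_component)
```

Because everything after the frozen scale is linear, the per-component contributions add up to the running value at each depth.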

In [97]:
from src.direct_logit_attribution import residual_stack_to_logit_diff
from src.utils import line

tokens_per_group = len(answer_map["True"])

accumulated_residual, labels = cache.accumulated_resid(
    layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
)
logit_lens_logit_diffs = residual_stack_to_logit_diff(
    accumulated_residual, 
    logit_diff_directions, 
    cache,
    tokens_per_group
)
line(
    logit_lens_logit_diffs,
    x=np.arange(model.cfg.n_layers * 2 + 1) / 2,
    hover_name=labels,
    title="Logit Difference From Accumulated Residual Stream",
)

To be completely honest, this is a rather unexpected result. We start with a very high average logit difference, which drops dramatically after the first self-attention layer. This is most likely random: the initial residual stream embedding of the last token is simply the embedding of that token itself, and this token, the newline token ('\n'), does not carry any semantic meaning for the question. After this initial rapid decrease in layers 1-3, the average logit difference stays roughly constant around 0 until the MLP layer of layer 13, where it jumps to half of the final average logit difference. After that, the logit difference continues to increase steadily until the final layer, where a slight drop is recorded.
<details><summary>Notes</summary>
I have plotted this many times with many different samples. Unfortunately, there are few patterns to recognize, hinting that the task might be too vague. I mostly got a monotonically increasing graph, starting from a highly negative average logit difference and ending at a positive one, which intuitively made sense to me. Interestingly, there was always a rapid increase in average logit difference after the MLP of layer 13. I will definitely try to do a deep dive into this layer.
</details>

In [81]:
from src.direct_logit_attribution import increase_per_layer_type

increase_per_layer_type(logit_lens_logit_diffs)

Summed increase in self-attention layer after layer 1: -0.55
Summed increase in MLP layer after layer 1: 4.12



If we neglect the first layer, we observe that the average logit difference increases much more in MLP layers than in self-attention layers. Over the 17 remaining layers, the average logit difference in the residual stream decreases by 0.55 nats over the self-attention layers and increases by about 4.12 nats (92% of the total increase) over the MLP layers. This hints that MLP layers are beneficial for this task, although it is worth noting that layers often collaborate in circuits, which limits the value of this isolated view. For instance, self-attention layers often encode information in residual stream subspaces, which MLP layers might then use to extract valuable information and write it to the residual stream themselves.
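
For readers without access to `src`, one way such a split could be computed is sketched below. The function name `split_increase_by_layer_type` and the input values are hypothetical. It assumes the running logit difference is ordered as `0_pre, 0_mid, 1_pre, 1_mid, ..., final_post`, so that every step into a `mid` entry belongs to an attention layer and every step out of it to an MLP layer.

```python
def split_increase_by_layer_type(running_logit_diff, skip_layers=1):
    """Sum the step-to-step changes attributed to attention and to MLP layers.

    Assumes the ordering [0_pre, 0_mid, 1_pre, 1_mid, ..., final_post]:
    the step into an "x_mid" entry is attention, the step out of it is the MLP.
    The first `skip_layers` layers are ignored.
    """
    values = list(running_logit_diff)
    steps = [b - a for a, b in zip(values[:-1], values[1:])]
    attn_steps = steps[0::2][skip_layers:]
    mlp_steps = steps[1::2][skip_layers:]
    return sum(attn_steps), sum(mlp_steps)


# Made-up running logit difference for a three-layer model.
running = [2.0, 0.5, 0.4, 0.3, 1.5, 1.4, 2.4]
attn_total, mlp_total = split_increase_by_layer_type(running, skip_layers=1)
print(f"attention: {attn_total:+.2f}, MLP: {mlp_total:+.2f}")
```

The two totals sum to the overall change after the skipped layers, so the split is a pure bookkeeping exercise and says nothing about how the layers interact.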

### Layer Attribution

Let's plot the change in logit difference for each layer to confirm this observation.

In [83]:
from src.direct_logit_attribution import get_logit_diff_directions

logit_diff_directions = get_logit_diff_directions(
    model,
    df,
    answer_map_tokens,
    device
)

In [99]:
from src.direct_logit_attribution import residual_stack_to_logit_diff
from src.utils import line

per_layer_residual, labels = cache.decompose_resid(
    layer=-1, pos_slice=-1, return_labels=True
)
per_layer_logit_diffs = residual_stack_to_logit_diff(
    per_layer_residual, 
    logit_diff_directions, 
    cache,
    tokens_per_group
)
line(per_layer_logit_diffs, hover_name=labels, title="Logit Difference From Each Layer")

Indeed, we see the largest changes after the self-attention layer in layer 1 and after the MLP layer in layer 13.

### Head attribution

Let's now decompose the output of each attention layer into its individual heads. This can be done by splitting the output matrix into submatrices, each applied to the attention-weighted value output of one head.
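
The decomposition can be illustrated with random toy tensors. Shapes and names here are assumptions for the sketch, and the output bias is ignored. Concatenating all heads and applying one output matrix gives the same vector as summing the per-head outputs, so each head's output can be projected onto a direction separately.

```python
import numpy as np

rng = np.random.default_rng(0)
n_heads, d_head, d_model = 4, 3, 8

# z[h] is the attention-weighted value vector of head h at the last position.
z = rng.normal(size=(n_heads, d_head))
# W_O[h] is the slice of the output projection that belongs to head h.
W_O = rng.normal(size=(n_heads, d_head, d_model))

# Standard view: concatenate all heads, then apply one big output matrix.
attn_out = z.reshape(-1) @ W_O.reshape(n_heads * d_head, d_model)

# Per-head view: each head writes its own vector into the residual stream.
per_head_out = np.einsum("hd,hdm->hm", z, W_O)

# Projecting each head's output onto a direction gives one attribution per head.
direction = rng.normal(size=d_model)
per_head_attr = per_head_out @ direction

print(per_head_attr)
```

The per-head attributions sum to the attribution of the whole attention layer, which is what makes a layer-by-head heatmap meaningful.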

In [100]:
import einops  # needed for einops.rearrange below

from src.direct_logit_attribution import residual_stack_to_logit_diff
from src.utils import imshow

per_head_residual, labels = cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)
per_head_logit_diffs = residual_stack_to_logit_diff(
    per_head_residual, 
    logit_diff_directions, 
    cache,
    tokens_per_group
)
per_head_logit_diffs = einops.rearrange(
    per_head_logit_diffs,
    "(layer head_index) -> layer head_index",
    layer=model.cfg.n_layers,
    head_index=model.cfg.n_heads,
)
imshow(
    per_head_logit_diffs,
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Each Head",
)

Tried to stack head results when they weren't cached. Computing head results now


### Attention analysis

In [90]:
batch_index = 1
top_k = 3

In [91]:
top_positive_logit_attr_heads = torch.topk(
    per_head_logit_diffs.flatten(), k=top_k
).indices

In [92]:
top_negative_logit_attr_heads = torch.topk(
    -per_head_logit_diffs.flatten(), k=top_k
).indices

In [93]:
from src.attention_analysis import visualize_attention_patterns
from IPython.display import HTML

positive_html = visualize_attention_patterns(
    model,
    top_positive_logit_attr_heads,
    cache,
    df,
    batch_index,
    f"Top {top_k} Positive Logit Attribution Heads",
)

negative_html = visualize_attention_patterns(
    model,
    top_negative_logit_attr_heads,
    cache,
    df,
    batch_index,
    title=f"Top {top_k} Negative Logit Attribution Heads",
)

HTML(positive_html + negative_html)

## Activation patching

Now let's get into activation patching. First, we need some corrupted prompts as a baseline for our path patching algorithm. The choice of corrupted prompts is quite delicate and fundamental if path patching is to help narrow down the components that are part of the circuit in question. We want our corrupted examples to differ only in the aspect we are testing, so that when the model performs significantly better with some paths exchanged for the clean activations, we can be fairly sure the components attached to these paths are part of the circuit. It is also important to change as few tokens as possible so as not to change the nature of the task. In our case, there are two intuitive solutions. <br>

Let's, again, consider the example: <br>
> Is Paris the capital of France?

Grammatically, we are looking at a subject-verb inversion, which means the subject and verb of a sentence (Paris is the capital of France.) have been interchanged to form a question. 'Is' is the verb that signals a question, 'Paris' is the subject, 'the capital' is a noun phrase complementing the subject, and 'of France' modifies the noun phrase. 
While the grammatical point of view might be helpful for some more technical tasks, such as indirect object identification, I feel that in our case it is simpler to think about what we want to test for. We want to test whether the model is able to query its factual knowledge to recognize that Paris is indeed the capital of France and subsequently return a token confirming this. Coming from that angle, which parts of the sentence are essential to the task? As humans, how do we query our brain for the answer to this question?
Before we get into an algorithm that could be used for this task, there is a fundamental question to answer: how is factual knowledge stored in transformers? Coming from the world of computers, there are two answers: relational and non-relational databases. Relational databases save data in a table format where objects (rows) with different attributes (columns) are described. Non-relational databases do not follow this structured approach. Prominent examples are key-value stores and graph databases. Neuroscientists provide evidence that the human brain falls rather into this latter category of non-relational databases, where neurons that ["fire together, wire together"](https://www.brainfacts.org/thinking-sensing-and-behaving/learning-and-memory/2021/what-memories-are-made-of-100121) capture information about related bits of information. While I am not going to claim to know what happens in the human brain, I find the idea of the hippocampus being a big graph of knowledge points, whose connections capture relationships between the bits of information, very intriguing. Since neural networks are heavily inspired by the human brain, let's consider how a neural network that stores information in a graph-like structure could solve the factual knowledge retrieval task at hand:
1. Recognize question: A present-tense form of "to be" at the beginning of the sentence almost always signals a question. The question mark at the end confirms this. This should hint that the answer has to state whether the following statement is True or False, narrowing down the set of possible tokens to True- and False-tokens.
2. What is the statement? The statement is "Paris is the capital of France". This could be modelled as a graph with the nodes "Paris" and "France" connected by the edge/relationship "capital". Note that this is a big assumption. The network could model this relationship completely differently (key-value, ...) or not at all. Here, I would just like to go through an example algorithm end-to-end to find suitable tokens to replace for the corrupted prompt.
3. Is the statement correct? Now comes the tricky part: how does the model check whether the statement is correct? In our case it would simply have to check if the nodes Paris and France are connected by the edge "capital". This graph-like structure could, for instance, be encoded in the embedding space, where the embeddings of Paris and France are "connected" via similar values in a privileged basis which represents the feature "capital".
4. Write output to residual stream: Finally, the circuit should alter the residual stream such that the logit, and thus the probability, of the predicted outcome (True/False) is increased.
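
To make the four steps concrete, here is a minimal toy sketch in plain Python. It is only an illustration of the hypothesised algorithm, not a claim about what the model actually does: the triple store `KNOWLEDGE_GRAPH`, the regular expression, and the function name `answer_question` are all made up for this example, and the pattern only covers questions of the form "Is X the R of Y?".

```python
import re

# Hypothetical toy knowledge graph: (subject, relation, object) triples.
# Illustrative only; a real model need not store facts this way.
KNOWLEDGE_GRAPH = {
    ("Paris", "capital", "France"),
    ("London", "capital", "UK"),
}

# Only handles questions shaped like "Is <subject> the <relation> of <object>?"
QUESTION_PATTERN = re.compile(
    r"^Is (?P<subject>\w+) the (?P<relation>\w+) of (?P<object>[\w ]+)\?\s*$"
)


def answer_question(question: str) -> str:
    """Walk through the four hypothesised steps for a yes/no factual question."""
    # Step 1: recognise the question form, which restricts the answer
    # to a True token or a False token.
    match = QUESTION_PATTERN.match(question)
    if match is None:
        raise ValueError(f"Unsupported question format: {question!r}")

    # Step 2: extract the statement as two nodes and one edge.
    triple = (match["subject"], match["relation"], match["object"])

    # Step 3: check whether that edge exists in the graph.
    is_fact = triple in KNOWLEDGE_GRAPH

    # Step 4: emit the answer (the analogue of boosting the True/False logit).
    return "True" if is_fact else "False"


print(answer_question("Is Paris the capital of France?\n"))
print(answer_question("Is Paris the capital of UK?\n"))
```

The second call differs from the first only in the second node, which is exactly the kind of minimal change a corrupted prompt should make.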

This is all well and good, but what does this tell us about the corrupted prompt? Assuming Paris and France are detected as keys and their relationship is to be examined, we could alter either one of the keys or the relationship. Since we want to explore factual knowledge around Paris, let's rule out altering the first key. That leaves the options of altering either the second key or the relationship. I do not see a clear benefit for either option, so let's just start with option 1, altering the second key, and come back to the alternative if this does not yield valuable insights. 
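
As a small sanity check for option 1, the sketch below swaps the second key and verifies that the clean and corrupted prompts differ in exactly one position. The helper names `corrupt_second_key` and `changed_word_positions` are hypothetical, and the comparison uses a whitespace split as a rough stand-in for tokens. Whether "France" and "UK" map to the same number of tokens under the model's tokenizer is an assumption that would still need to be checked with the real tokenizer.

```python
from typing import List


def corrupt_second_key(question: str, old_key: str, new_key: str) -> str:
    """Replace the last occurrence of old_key and leave everything else untouched."""
    head, sep, tail = question.rpartition(old_key)
    if not sep:
        raise ValueError(f"{old_key!r} not found in {question!r}")
    return head + new_key + tail


def changed_word_positions(clean: str, corrupted: str) -> List[int]:
    """Return the word indices where the prompts differ.

    Uses a whitespace split, NOT the model's tokenizer, so this is only
    a rough proxy for token alignment.
    """
    clean_words, corrupted_words = clean.split(), corrupted.split()
    if len(clean_words) != len(corrupted_words):
        raise ValueError("Prompts have different word counts, positions do not align")
    return [
        i for i, (a, b) in enumerate(zip(clean_words, corrupted_words)) if a != b
    ]


clean = "Is Paris the capital of France?\n"
corrupted = corrupt_second_key(clean, "France", "UK")
print(repr(corrupted))
print(changed_word_positions(clean, corrupted))
```

If the position list contains more than one index, or the helper raises, the corrupted prompt changes more than the aspect under test and should be reworked.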

In [18]:
df

Unnamed: 0,question,answer
0,Is Paris the capital of France?\n,True
1,Is Paris the capital of France?\n,True
2,Is the Eiffel Tower located in Paris?\n,True
3,Is Paris known as the City of Light?\n,True
4,Is the Louvre Museum in Rome?\n,False
5,Is the official language of Paris English?\n,False
6,Is Montmartre a district in Paris?\n,True
7,Does Paris have a desert climate?\n,False
8,Is the Seine River in Paris?\n,True
9,Is the Mona Lisa displayed in Paris?\n,True


In [10]:
df["question"].values

array(['Is Paris the capital of France?\n',
       'Is Paris the capital of France?\n',
       'Is the Eiffel Tower located in Paris?\n',
       'Is Paris known as the City of Light?\n',
       'Is the Louvre Museum in Rome?\n',
       'Is the official language of Paris English?\n',
       'Is Montmartre a district in Paris?\n',
       'Does Paris have a desert climate?\n',
       'Is the Seine River in Paris?\n',
       'Is the Mona Lisa displayed in Paris?\n'], dtype=object)

In [None]:
corrupted_prompt = {
    "corrupted_question": [
        'Is Paris the capital of UK?\n',
        'Is Paris the capital of UK?\n',
        'Is the Eiffel Tower located in Paris?\n',
        'Is Paris known as the City of Dark?\n',
        'Is the Louvre Museum in Rome?\n',
        'Is the official language of Paris English?\n',
        'Is Montmartre a district in Paris?\n',
        'Does Paris have a desert climate?\n',
        'Is the Seine River in Paris?\n',
        'Is the Mona Lisa displayed in Paris?\n'
    ]}
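
Before applying this to the real model, here is a toy sketch of what activation patching measures. It does not use the actual model or TransformerLens: `run_toy_model` and its three components (`subject_reader`, `object_reader`, `comparer`) are invented for illustration, and the numbers are arbitrary. The idea it shows is the general one: cache the activations of a clean run, rerun the corrupted input with one activation overwritten by its clean value, and measure how much of the clean metric is recovered.

```python
from typing import Dict, Optional, Tuple

Cache = Dict[str, float]


def run_toy_model(
    subject: float, obj: float, patch: Optional[Cache] = None
) -> Tuple[float, Cache]:
    """Toy stand-in for a forward pass; returns (metric, cached activations).

    Any component named in `patch` has its activation overwritten, which is
    the role a forward hook plays when patching a real model.
    """
    patch = patch or {}
    cache: Cache = {}
    # Hypothetical components: one reads the subject, one reads the object,
    # and one combines both.
    cache["subject_reader"] = patch.get("subject_reader", 2.0 * subject)
    cache["object_reader"] = patch.get("object_reader", 3.0 * obj)
    cache["comparer"] = patch.get(
        "comparer", cache["subject_reader"] * cache["object_reader"]
    )
    # Stand-in for the logit difference between the True and False tokens.
    metric = cache["comparer"] + 0.5 * cache["subject_reader"]
    return metric, cache


def patching_effects(
    clean_inputs: Tuple[float, float], corrupted_inputs: Tuple[float, float]
) -> Dict[str, float]:
    """Fraction of the clean metric recovered by patching each component.

    0.0 means no recovery, 1.0 means full recovery. Assumes the clean and
    corrupted metrics differ, otherwise the normalisation divides by zero.
    """
    clean_metric, clean_cache = run_toy_model(*clean_inputs)
    corrupted_metric, _ = run_toy_model(*corrupted_inputs)
    effects = {}
    for name, clean_activation in clean_cache.items():
        patched_metric, _ = run_toy_model(
            *corrupted_inputs, patch={name: clean_activation}
        )
        effects[name] = (patched_metric - corrupted_metric) / (
            clean_metric - corrupted_metric
        )
    return effects


# The corrupted input changes only the object, mirroring "France" -> "UK".
effects = patching_effects(clean_inputs=(1.0, 1.0), corrupted_inputs=(1.0, -1.0))
for name, effect in effects.items():
    print(f"{name}: {effect:.2f}")
```

In this toy setting, patching `subject_reader` recovers nothing because the subject is identical in both prompts, while patching `object_reader` or `comparer` recovers the full metric. This is why the corrupted prompt should differ only in the aspect under test: components that do not depend on the changed token cannot show up as important.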