<a href="https://colab.research.google.com/github/Gusanidas/TicTacToe/blob/main/NeelNandaSeriMats.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Neel Nanda SERI MATS Application**, by Alejandro Alvarez.

The problem I have decided to study is that of famous pairs of names and their context. I thought of it while thinking about Alice and Bob: in computer science and physics it is very common to use the names *Alice* and *Bob* when presenting thought experiments. If an LLM reads a sentence involving cryptography and someone named Bob, it may assume the recipient is named Alice.
Some questions come to mind:

*   Does this actually happen? Will Alice be the most probable name if cryptography (or some other area) and Bob are involved?
*   Is the context important? Or are the names Alice and Bob so commonly written together that they are always associated?
*   Do other examples work similarly? Examples: Adam and Eve, Tom and Jerry, Bonnie and Clyde...
*   Is there a recognisable circuit involved, and is it the same across sentences and name pairs?

I am a bit skeptical about the last question; I am not sure this will be computed by a simple circuit. It is a small instance of a bigger, more general mechanism: given some context, some words become more closely related to each other.
If there is a wizard and the context is fantasy, words like orc or elf are more likely to appear. If we are talking about physics and the word "proton" is in the sentence, the word "neutron" becomes more likely.
However, I think the case of Alice and Bob has some differences: in the right context they either appear together or they don't. And the context, if it's cryptography, is easy to identify.


I will write in a semi-stream-of-consciousness style. As I don't think I will arrive at any concrete solution, I will just write up different explorations of the problem.





In [None]:
%pip install git+https://github.com/neelnanda-io/TransformerLens.git
%pip install circuitsvis

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-2l9l5v3z
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-2l9l5v3z
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 036327c59349e07af5ba935bbe92384776680f2b
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# Some imports (probably too many)

import circuitsvis as cv
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import time


import functools

import transformer_lens.patching as patching

In [None]:
# I will use this to load the weights of the attention-only transformer.

from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"


In [None]:
# The two models I will use: gpt2_small and attn_only.

gpt2_small = HookedTransformer.from_pretrained("gpt2-small", device=device)

weights_path = "/content/drive/MyDrive/Data/attn_only_2L_half.pth"

cfg = HookedTransformerConfig(
        d_model=768,
        d_head=64,
        n_heads=12,
        n_layers=2,
        n_ctx=2048,
        d_vocab=50278,
        attention_dir="causal", # defaults to "bidirectional"
        attn_only=True, # defaults to False
        tokenizer_name="EleutherAI/gpt-neox-20b",
        seed=398,
        use_attn_result=True,
        normalization_type=None, # defaults to "LN", i.e. layernorm with weights & biases
        positional_embedding_type="shortformer"
)

attn_only = HookedTransformer(cfg)
pretrained_weights = torch.load(weights_path, map_location=device)
attn_only.load_state_dict(pretrained_weights)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


Using pad_token, but it is not set yet.


<All keys matched successfully>

In [None]:
# I'll answer the first question by looking at which tokens score highest as the next token.

def print_top_k_last_token(model, text, k=10):
  # Logits at the final position are the model's prediction for the next token
  logits = model(text, return_type="logits")

  k_logits, k_largest = torch.topk(logits[0, -1, :], k)
  for i, logit in enumerate(k_logits):
    print(f"{i+1} | {model.to_string(k_largest[i])} | {logit}")
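
Raw logits are only meaningful relative to the other logits of the same forward pass (adding a constant to every logit leaves the distribution unchanged), so comparing 22.9 for " Bob" in one prompt with 13.9 in another is not quite like for like. A minimal pure-Python sketch, not part of the original notebook, that converts a final-position logit vector into softmax probabilities before ranking; `softmax`, `top_k_probs` and the toy vocabulary are hypothetical:

```python
import math
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_probs(logits: Sequence[float], vocab: Sequence[str], k: int = 3) -> List[Tuple[str, float]]:
    """Return the k most likely tokens with their probabilities (hypothetical helper)."""
    probs = softmax(logits)
    # sorted() is stable, so ties keep their original vocabulary order
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(vocab[i], probs[i]) for i in order[:k]]


# Toy vocabulary and logits, invented for illustration only.
toy_vocab = [" Bob", " Bill", " Wonderland", " the"]
toy_logits = [22.0, 17.0, 17.0, 16.0]
for rank, (tok, p) in enumerate(top_k_probs(toy_logits, toy_vocab), start=1):
    print(f"{rank} | {tok} | {p:.3f}")
```

With the real model, the same idea would amount to calling `logits[0, -1].softmax(-1)` before `torch.topk`.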



In [None]:
# An easy example:

text = "In cryptography, secure communication is often illustrated using Alice and"
print_top_k_last_token(gpt2_small, text)

1 |  Bob | 22.862083435058594
2 |  Bill | 17.67469596862793
3 |  Wonderland | 17.009260177612305
4 |  Ron | 16.987939834594727
5 |  the | 16.929946899414062
6 |  Nick | 16.80923080444336
7 |  Jacob | 16.568485260009766
8 |  Ben | 16.163223266601562
9 |  Rob | 16.112186431884766
10 |  Jack | 16.059738159179688


In [None]:
text = "In cryptography, secure communication is often illustrated using Bob and"
print_top_k_last_token(gpt2_small, text)

1 |  Bob | 16.70122528076172
2 |  Alice | 16.315021514892578
3 |  the | 14.989341735839844
4 |  Bill | 14.882482528686523
5 |  Ell | 14.282912254333496
6 |  Charlie | 14.117244720458984
7 |  his | 14.106197357177734
8 |  Jerry | 13.459790229797363
9 |  Eve | 13.332417488098145
10 |  Jack | 13.332345962524414


There seems to be at least some association between the two names. However, in this example they are adjacent, so the result may simply reflect that the trigram "Alice and Bob" is very common.
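
To check whether the association survives when the names are not adjacent, it may be easier to track one target token's rank and log-probability across prompts than to scan top-k lists by eye. A sketch under that assumption; `target_stats` is a hypothetical helper and the logit vectors are invented (index 1 stands in for " Alice"):

```python
import math
from typing import List, Sequence, Tuple


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Stable log-softmax via the log-sum-exp trick."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def target_stats(logits: Sequence[float], target_idx: int) -> Tuple[int, float]:
    """Return (1-based rank, log-probability) of the token at target_idx."""
    target_logit = logits[target_idx]
    rank = 1 + sum(1 for x in logits if x > target_logit)
    return rank, log_softmax(logits)[target_idx]


# Invented final-position logits for two prompts.
adjacent_prompt = [5.0, 4.0, 3.0, 2.0]  # e.g. "... Bob and"
distant_prompt = [5.0, 2.0, 4.0, 3.0]   # e.g. "... Bob wants to send a message to"
for name, logits in [("adjacent", adjacent_prompt), ("distant", distant_prompt)]:
    rank, logp = target_stats(logits, target_idx=1)
    print(f"{name}: rank={rank}, log p={logp:.2f}")
```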

In [None]:
text = "Lets go over public-key cryptography, lets assume Bob wants to send a secure message to"
print_top_k_last_token(gpt2_small, text)

1 |  the | 14.01496410369873
2 |  a | 13.859108924865723
3 |  his | 13.481128692626953
4 |  someone | 12.860709190368652
5 |  Alice | 12.83881950378418
6 |  Bob | 12.267598152160645
7 |  an | 12.203741073608398
8 |  everyone | 12.10692310333252
9 |  Satoshi | 11.7417573928833
10 |  some | 11.360934257507324


In [None]:
text = "In professional soccer, plays are normaly ilustrated using Alice and"
print_top_k_last_token(gpt2_small, text)

1 |  Bob | 13.890080451965332
2 |  the | 13.800544738769531
3 |  Bill | 12.446436882019043
4 |  Alice | 12.242537498474121
5 |  Wonderland | 12.05294418334961
6 |  Jack | 12.052478790283203
7 |  John | 11.858514785766602
8 |  her | 11.801490783691406
9 |  Dave | 11.691384315490723
10 |  Adam | 11.58998966217041


In [None]:
text = "In professional soccer, plays are normaly ilustrated using Bob and"
print_top_k_last_token(gpt2_small, text)

1 |  the | 13.082548141479492
2 |  his | 12.803836822509766
3 |  Bob | 12.440136909484863
4 |  Dave | 12.226999282836914
5 |  Bill | 11.755409240722656
6 |  Jerry | 11.678823471069336
7 |  Mike | 11.398613929748535
8 |  Jack | 11.354839324951172
9 |  I | 11.330429077148438
10 |  Joe | 11.306892395019531


In [None]:
text = "This afternoon, Bob is going to the cinema with"
print_top_k_last_token(gpt2_small, text, k = 20)

1 |  his | 14.619878768920898
2 |  a | 14.229316711425781
3 |  the | 13.395949363708496
4 |  my | 13.11326789855957
5 |  some | 13.030168533325195
6 |  me | 12.26167106628418
7 |  an | 12.221197128295898
8 |  friends | 11.897590637207031
9 |  us | 11.554744720458984
10 |  our | 11.477663040161133
11 |  another | 11.465581893920898
12 |  two | 11.315040588378906
13 |  one | 11.293984413146973
14 |  her | 11.093805313110352
15 |  Bob | 10.999310493469238
16 |  all | 10.662812232971191
17 |  Bill | 10.657220840454102
18 |  you | 10.634076118469238
19 |  John | 10.62600326538086
20 |  The | 10.525693893432617


There always seems to be some connection between Alice and Bob, especially when Alice comes first.
However, the context does have a visible effect: the names are much more strongly associated when the context is cryptography.

Let's try other pairs.

In [None]:
#text = 'From the dust of the ground, the Lord God formed Adam, and for his companion, He created'
#text = 'From the dust of the ground, the Lord God formed Bob, and for his companion, He created'
#text = "In cryptography, secure communication is often illustrated using Adam and"
#text = "This afternoon, Adam is going to the cinema with"
#text = "When the serpent beguiled Eve, she partook of the forbidden fruit and shared it with"
#text = "When the serpent beguiled Alice, she partook of the forbidden fruit and shared it with"
#text = "Despite their criminal reputation, Bonnie was fiercely loyal to"
#text = "Tom constantly devises cunning plans to catch"
#text = "In Hogwarts, Harry wants to send a secure message to"
#text = "In the Garden of Eden, Alice wants to send a message to"
text = "In the Garden of Eden, if Alice wants to send a secure message to"

print_top_k_last_token(gpt2_small, text, k = 20)

1 |  the | 15.496297836303711
2 |  her | 15.360969543457031
3 |  God | 14.738523483276367
4 |  a | 14.463143348693848
5 |  someone | 12.884888648986816
6 |  an | 12.837688446044922
7 |  Jesus | 12.71864128112793
8 |  your | 12.607561111450195
9 |  all | 12.55396842956543
10 |  Adam | 12.546812057495117
11 |  everyone | 12.442309379577637
12 |  his | 12.24174690246582
13 |  Bob | 12.176631927490234
14 |  Abraham | 12.13146686553955
15 |  another | 12.063867568969727
16 |  one | 11.873207092285156
17 |  anyone | 11.858263969421387
18 |  Zeus | 11.688714981079102
19 |  Alice | 11.67710018157959
20 |  David | 11.650980949401855


After some very minor exploration, it seems that *Adam* and *Eve* behave similarly: they are often associated, but mainly in a given context. Other name pairs I have tried, like "Bonnie and Clyde" and "Tom and Jerry", don't seem to be associated strongly enough for one name to be predicted when the other is far away in the sentence.

Let's look at the smaller model.

In [None]:
#text = 'From the dust of the ground, the Lord God formed Adam, and for his companion, He created'
#text = "Lets go over public-key cryptography, lets assume Bob wants to send a secure message to"
#text = "In cryptography, secure communication is often illustrated using Alice and"
text = "Two character of the Lord of the Rings are Adam and"

print_top_k_last_token(attn_only, text, k = 20)

1 |  Eve | 12.149096488952637
2 |  the | 9.916333198547363
3 |  his | 9.643230438232422
4 |  Adam | 9.612616539001465
5 |  I | 8.75866985321045
6 |  Jon | 8.377252578735352
7 |  Will | 8.364314079284668
8 |  David | 8.277421951293945
9 |  Mary | 7.792721271514893
10 |  Noah | 7.790463447570801
11 |  are | 7.593355655670166
12 |  Peter | 7.571305751800537
13 |  Paul | 7.5581159591674805
14 |  Luke | 7.537231922149658
15 |  we | 7.3723297119140625
16 |  Michael | 7.350755214691162
17 |  two | 7.2414116859436035
18 |  Lord | 7.241343975067139
19 |  Sam | 7.231867790222168
20 |  John | 7.219054222106934


In the smaller attention-only model, the names seem to be associated only when written next to each other, regardless of context (here "Eve" is the top prediction even in a Lord of the Rings prompt). So the behaviour I am looking for is not present, and from now on I will only look at gpt2_small.

A first step in probing the model is plotting the attention patterns, if only because it is easy.

In [None]:
from IPython.display import display

def display_attn_pattern(model, layer, text):
  # With remove_batch_dim=True the pattern has shape [n_heads, query_pos, key_pos]
  _, cache = model.run_with_cache(text, remove_batch_dim=True)

  pattern = cache["pattern", layer]
  text_str_tokens = model.to_str_tokens(text)
  display(cv.attention.attention_patterns(tokens=text_str_tokens, attention=pattern))

In [None]:
text = "Lets go over public-key cryptography, lets assume Bob wants to send a secure message to"
display_attn_pattern(gpt2_small, 0, text)

It is hard to conclude anything from the attention patterns of layer 0.
Some things are interesting, but may not be relevant:

*   Head 1 attends to repeated words.
*   Heads 0 and 6 attend from "secure" to "cryptography", so they may be involved in computing the context "talking about cryptography".

Next, some attention patterns from the last layer.




In [None]:
display_attn_pattern(gpt2_small,11, text)

In the last layer, most heads attend to "Bob" from the final position when predicting the next token.
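
A claim like "most heads attend to Bob" can be turned into a count: for each head, take the source token that receives the most attention from the final query position. A pure-Python sketch over a nested list shaped `[n_heads][query_pos][key_pos]`, the layout of `cache["pattern", layer]` when the batch dimension is removed; the helper names and the toy pattern are hypothetical:

```python
from collections import Counter
from typing import List, Sequence

Pattern = Sequence[Sequence[Sequence[float]]]  # [n_heads][query_pos][key_pos]


def top_source_from_last_query(pattern: Pattern, tokens: Sequence[str]) -> List[str]:
    """For each head, return the source token the final query position attends to most."""
    result = []
    for head in pattern:
        last_row = head[-1]
        best = max(range(len(last_row)), key=lambda k: last_row[k])
        result.append(tokens[best])
    return result


def make_head(last_row: Sequence[float]) -> List[List[float]]:
    """Toy causal head: uniform attention on earlier rows, a chosen final row."""
    n = len(last_row)
    rows = [[1.0 / (q + 1) if k <= q else 0.0 for k in range(n)] for q in range(n - 1)]
    return rows + [list(last_row)]


toy_tokens = ["<|endoftext|>", " Bob", " wants", " to"]
toy_pattern = [
    make_head([0.1, 0.7, 0.1, 0.1]),
    make_head([0.6, 0.2, 0.1, 0.1]),
    make_head([0.2, 0.5, 0.2, 0.1]),
]
tops = top_source_from_last_query(toy_pattern, toy_tokens)
print(tops)
print(Counter(tops).most_common(1))
```

With the real cache, `cache["pattern", 11].tolist()` and `model.to_str_tokens(text)` could be passed in. Since many GPT-2 heads put a large share of their attention on the first token, it may also be worth ignoring position 0 when counting.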

In [None]:
text = "When the serpent beguiled Eve, she partook of the forbidden fruit and shared it with"
display_attn_pattern(gpt2_small, 11, text)

In [None]:
text = "In Hogwarts, Harry wants to send a secure message to"
display_attn_pattern(gpt2_small, 11, text)

In [None]:
text = "In the Diffie-Hellman Key Exchange, Alice exchanges public keys with"
display_attn_pattern(gpt2_small, 11, text)

An interesting observation is that in the prompts involving Adam and Eve or Alice and Bob, the name in the prompt ("Bob" or "Eve") is a focus of attention in the last layer. In the Harry Potter prompt, however, it is "Hogwarts" and not "Harry" that is most attended to, possibly because the Harry Potter context admits many plausible names, so attending to the name that has already occurred is less useful.

I will now use activation patching to see whether we can find at which stage the name is computed.

A high-level hypothesis, consistent with the attention patterns and plausible a priori, is that the context is computed first (is the sentence about the Bible, Harry Potter or cryptography?) and only then are the names present in the prompt taken into account.
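
Before running the TransformerLens `patching` utilities, it may help to see the idea in miniature. The sketch below is a hypothetical toy, not GPT-2: the residual stream is one scalar per position (a name position and a final position), the three "layers" are invented, and we do denoising patching, i.e. overwrite one `resid_pre` entry of the corrupted run with its value from the clean run, scoring the result with the same normalized metric as `ioi_metric` (1 = clean behaviour, 0 = corrupted behaviour):

```python
from typing import Callable, List, Optional, Tuple

Resid = List[float]  # one scalar "feature" per position (a toy stand-in for d_model vectors)
Layer = Callable[[Resid], Resid]

# Hypothetical 3-layer toy model over two positions: [name position, final position].
LAYERS: List[Layer] = [
    lambda r: [r[0], 0.0],  # layer 0: processes the name at its own position
    lambda r: [0.0, r[0]],  # layer 1: an "attention head" moves the name feature to the final position
    lambda r: [0.0, r[1]],  # layer 2: processes the moved feature at the final position
]


def run_toy(x: Resid, patch: Optional[Tuple[int, int, float]] = None,
            cache: Optional[List[Resid]] = None) -> float:
    """Run the toy model and return the final-position "logit difference".

    patch = (layer, pos, value) overwrites resid_pre[layer][pos] before that layer runs;
    cache, if given, collects resid_pre for every layer.
    """
    resid = list(x)
    for i, layer in enumerate(LAYERS):
        if patch is not None and patch[0] == i:
            resid[patch[1]] = patch[2]
        if cache is not None:
            cache.append(list(resid))
        out = layer(resid)
        resid = [r + o for r, o in zip(resid, out)]  # residual connection
    return resid[-1]


def patch_grid(clean: Resid, corrupted: Resid) -> List[List[float]]:
    """Denoising-patch every (layer, position) resid_pre and score it with the normalized metric."""
    clean_cache: List[Resid] = []
    clean_ld = run_toy(clean, cache=clean_cache)
    corrupted_ld = run_toy(corrupted)
    grid = []
    for layer in range(len(LAYERS)):
        row = []
        for pos in range(len(clean)):
            patched_ld = run_toy(corrupted, patch=(layer, pos, clean_cache[layer][pos]))
            row.append((patched_ld - corrupted_ld) / (clean_ld - corrupted_ld))
        grid.append(row)
    return grid


for layer, row in enumerate(patch_grid(clean=[1.0, 0.0], corrupted=[-1.0, 0.0])):
    print(f"resid_pre layer {layer}: {row}")
```

In this toy, patching the name position restores the clean behaviour for layers 0 and 1, and from layer 2 onward only the final position matters: the kind of hand-off from the name token to the last token one might hope to see in the `resid_pre` plots.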

In [None]:
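# GPT-2 tokenizes mid-sentence names with a leading space (" Bob", not "Bob"), which is
# also how they appear in the top-k lists above, so the answer tokens need the space.
answer_token_indices = torch.tensor([gpt2_small.to_single_token(x) for x in [" Bob", " Alice"]], device=device)
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    if len(logits.shape)==3:
        # Get final logits only
        logits = logits[:, -1, :]
    correct_logits = logits[:, answer_token_indices[0]]
    incorrect_logits = logits[:, answer_token_indices[1]]
    # Logit difference "correct minus incorrect", averaged over the batch
    return (correct_logits - incorrect_logits).mean()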

In [None]:
text_a = "In public-key cryptography, if Alice wants to send a secure message to"
text_b = "In public-key cryptography, if Bob wants to send a secure message to"


In [None]:
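def pre_patcher(text_a, text_b, answer_token_indices=answer_token_indices):
  """Set up denoising activation patching: text_a is the clean prompt, text_b the corrupted one.

  Returns (corrupted tokens, clean cache, metric), the argument order expected by the
  transformer_lens.patching functions. The metric is 1 on the clean run and 0 on the corrupted run.
  """
  tokens_b = gpt2_small.to_tokens(text_b)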

  logits_a, cache_a = gpt2_small.run_with_cache(text_a)
  logits_b, cache_b = gpt2_small.run_with_cache(text_b)

  CLEAN_BASELINE = get_logit_diff(logits_a, answer_token_indices=answer_token_indices)
  CORRUPTED_BASELINE = get_logit_diff(logits_b, answer_token_indices=answer_token_indices)
  def ioi_metric(logits, answer_token_indices=answer_token_indices):
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE  - CORRUPTED_BASELINE)

  return tokens_b, cache_a, ioi_metric
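
The metric built inside `pre_patcher` (`ioi_metric`) rescales the logit difference so that the clean prompt scores 1 and the corrupted prompt scores 0. Written out, with $\ell_t(x)$ the final-position logit of token $t$ on input $x$:

$$
\mathrm{LD}(x) = \ell_{\text{correct}}(x) - \ell_{\text{incorrect}}(x), \qquad
M(x) = \frac{\mathrm{LD}(x) - \mathrm{LD}(x_{\text{corrupted}})}{\mathrm{LD}(x_{\text{clean}}) - \mathrm{LD}(x_{\text{corrupted}})}
$$

For instance, with made-up values $\mathrm{LD}(x_{\text{clean}}) = 2$, $\mathrm{LD}(x_{\text{corrupted}}) = -1$ and a patched run giving $\mathrm{LD} = 0.5$, the metric is $(0.5 - (-1)) / (2 - (-1)) = 1.5 / 3 = 0.5$: the patch restores half of the clean behaviour. The metric is undefined when the two baselines coincide.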





In [None]:
# Some plotting utils.


def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

In [None]:
tokens_a, cache_b, ioi_metric = pre_patcher(text_a, text_b)
resid_pre_act_patch_results = patching.get_act_patch_resid_pre(gpt2_small, tokens_a, cache_b, ioi_metric)


In [None]:
imshow(resid_pre_act_patch_results ,
       yaxis="Layer",
       xaxis="Position",
       x=[f"{tok} {i}" for i, tok in enumerate(gpt2_small.to_str_tokens(tokens_a))],
       title="resid_pre Activation Patching")

In [None]:
attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(gpt2_small, *pre_patcher(text_a, text_b))


In [None]:
imshow(attn_head_out_all_pos_act_patch_results,
       yaxis="Layer",
       xaxis="Head",
       title="attn_head_out Activation Patching (All Pos)")

From these graphs I make two observations:

*   The name "Bob" seems to be taken into account in later layers, as expected.
*   The transition around layer 8 from the "Bob" token to the last token doesn't look as clean as in other examples I have seen before, and I don't know why.
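
The transition from the name token to the final token can be read off a `[n_layers][n_positions]` patching grid by listing, for each layer, the position whose patch restores the most of the metric. A sketch with an invented grid; `strongest_position_per_layer` is a hypothetical helper:

```python
from typing import List, Sequence, Tuple


def strongest_position_per_layer(grid: Sequence[Sequence[float]],
                                 tokens: Sequence[str]) -> List[Tuple[int, str, float]]:
    """For each layer (row), return (layer, token, metric) for the position with the largest metric."""
    summary = []
    for layer, row in enumerate(grid):
        pos = max(range(len(row)), key=lambda p: row[p])
        summary.append((layer, tokens[pos], row[pos]))
    return summary


# Invented 4-layer grid over 3 positions, shaped like a resid_pre patching result.
toy_tokens = [" if", " Bob", " to"]
toy_grid = [
    [0.0, 1.0, 0.0],
    [0.0, 0.9, 0.1],
    [0.0, 0.3, 0.7],
    [0.0, 0.0, 1.0],
]
for layer, tok, value in strongest_position_per_layer(toy_grid, toy_tokens):
    print(f"layer {layer}: strongest at {tok!r} ({value:.2f})")
```

On the real result, one could presumably pass `resid_pre_act_patch_results.tolist()` and `gpt2_small.to_str_tokens(tokens_a)`; a messy transition would show up as the strongest position flipping back and forth across several layers.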



In [None]:
text_a = "In public-key cryptography, if Alice wants to send a secure message to"
text_b = "In the Garden of Eden, if Alice wants to send a secure message to"
# Leading spaces again: the answers are the mid-sentence tokens " Bob" and " Adam".
answer_token_indices = torch.tensor([gpt2_small.to_single_token(x) for x in [" Bob", " Adam"]], device=device)

In [None]:
tokens_a, cache_b, ioi_metric = pre_patcher(text_a, text_b, answer_token_indices=answer_token_indices)
resid_pre_act_patch_results = patching.get_act_patch_resid_pre(gpt2_small, tokens_a, cache_b, ioi_metric)


In [None]:
imshow(resid_pre_act_patch_results ,
       yaxis="Layer",
       xaxis="Position",
       x=[f"{tok} {i}" for i, tok in enumerate(gpt2_small.to_str_tokens(tokens_a))],
       title="resid_pre Activation Patching")

In [None]:
attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(gpt2_small, tokens_a, cache_b, ioi_metric)


In [None]:
imshow(attn_head_out_all_pos_act_patch_results,
       yaxis="Layer",
       xaxis="Head",
       title="attn_head_out Activation Patching (All Pos)")

These two graphs are more confusing, and I don't know what most of them mean. But it does look a bit as if the change of context has an effect in earlier layers.
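
One rough way to quantify "an effect in earlier layers" is to compare, across experiments, the first layer at which patching the final position alone recovers at least half of the metric. A sketch under that assumption; the 0.5 threshold and the helper name `first_layer_above` are arbitrary choices, and the grids are invented:

```python
from typing import Optional, Sequence


def first_layer_above(grid: Sequence[Sequence[float]], pos: int,
                      threshold: float = 0.5) -> Optional[int]:
    """Return the first layer whose patch at `pos` reaches `threshold`, or None if none does."""
    for layer, row in enumerate(grid):
        if row[pos] >= threshold:
            return layer
    return None


# Invented grids (rows = layers, last column = final position) for two experiments.
name_swap = [[0.0, 0.0], [0.0, 0.1], [0.1, 0.6], [0.0, 1.0]]
context_swap = [[0.2, 0.1], [0.1, 0.5], [0.0, 0.8], [0.0, 1.0]]
print("name swap:", first_layer_above(name_swap, pos=-1))
print("context swap:", first_layer_above(context_swap, pos=-1))
```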

This concludes my exploration of the problem, as it has already taken more than 10 hours. It's been lots of fun.

I am still not sure whether there is an easily understandable circuit underneath, but there are many things to try. If I continued, I would first gain a better understanding of the IOI paper and then try the same approach.

