# Setup

In [1]:
%%capture
!pip install accelerate

In [2]:
# %%capture
# %pip install git+https://github.com/neelnanda-io/TransformerLens.git

In [3]:
import pickle

import torch as t
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

In [4]:
# model_name = 'gpt2'
# device = 'cuda:0' if t.cuda.is_available() else 'cpu'
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=t.float16).to(device)

## Load Model

In [5]:
from transformers import LlamaForCausalLM, LlamaTokenizer

In [6]:
# Interactive Hugging Face Hub login: prompts for an access token on stdin.
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) n
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [7]:
# Hub repository id used for both the tokenizer and the model weights.
LLAMA_2_7B_CHAT_PATH = "meta-llama/Llama-2-7b-chat-hf"

# Tokenizer is requested with the fast implementation disabled and without
# an automatic prefix space.
tokenizer = LlamaTokenizer.from_pretrained(
    LLAMA_2_7B_CHAT_PATH,
    use_fast=False,
    add_prefix_space=False,
)

# Model weights; `low_cpu_mem_usage=True` is forwarded to `from_pretrained`.
hf_model = LlamaForCausalLM.from_pretrained(
    LLAMA_2_7B_CHAT_PATH,
    low_cpu_mem_usage=True,
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.62k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

In [9]:
# import transformer_lens.utils as utils
# from transformer_lens.hook_points import HookPoint
# from transformer_lens import HookedTransformer

In [None]:
# model = HookedTransformer.from_pretrained(
#     LLAMA_2_7B_CHAT_PATH,
#     hf_model = hf_model,
#     tokenizer = tokenizer,
#     device = "cpu",
#     fold_ln = False,
#     center_writing_weights = False,
#     center_unembed = False,
# )

# del hf_model

# model = model.to("cuda" if t.cuda.is_available() else "cpu")

In [10]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if t.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# `model` is the global that the logit-lens helper functions read.
model = hf_model.to(device)

# test code

In [None]:
class LlamaForCausalLMWithLogitLens(LlamaForCausalLM):
    """Llama causal LM whose forward pass also unembeds every layer's hidden state.

    The constructor is inherited unchanged from ``LlamaForCausalLM``.
    """

    def forward(self, input_ids, attention_mask=None, labels=None, output_attentions=None, output_hidden_states=True, return_dict=True):
        """Run the base forward pass and project each hidden state through ``lm_head``.

        Args:
            input_ids: Token ids, forwarded to ``LlamaForCausalLM.forward``.
            attention_mask: Optional attention mask, forwarded unchanged.
            labels: Optional labels, forwarded unchanged.
            output_attentions: Forwarded unchanged.
            output_hidden_states: Accepted for interface compatibility only.
                Hidden states are always requested from the base model because
                the per-layer logits cannot be computed without them.
            return_dict: If ``False``, the first returned element is the plain
                tuple form of the base model output; otherwise it is the
                ``ModelOutput`` object.

        Returns:
            A pair ``(outputs, logits_per_layer)`` where ``logits_per_layer`` is
            a list holding ``self.lm_head(h)`` for every entry ``h`` of
            ``outputs.hidden_states``.

        NOTE(review): ``lm_head`` is applied to each hidden state as-is; no
        extra normalization is applied here. Confirm this is intended for the
        intermediate layers.
        """
        # Always ask for hidden states and a ModelOutput internally: with
        # output_hidden_states falsy, `outputs.hidden_states` would be None,
        # and with return_dict=False `outputs` would be a tuple without a
        # `.hidden_states` attribute -- both crashed the loop below.
        outputs = super().forward(
            input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )

        # One logits tensor per entry of the hidden-state tuple.
        logits_per_layer = [
            self.lm_head(layer_hidden_state)
            for layer_hidden_state in outputs.hidden_states
        ]

        # Honour an explicit request for the tuple form of the base output.
        if return_dict is False:
            outputs = outputs.to_tuple()

        return outputs, logits_per_layer

# Initialize the modified model
# NOTE(review): this loads the checkpoint a second time (in addition to
# `hf_model`), and the result is not moved to `device` in this cell.
model_with_logit_lens = LlamaForCausalLMWithLogitLens.from_pretrained(LLAMA_2_7B_CHAT_PATH, low_cpu_mem_usage=True)

In [None]:
def get_logits_per_layer(input_text):
    """Tokenize ``input_text`` and return the per-layer logits.

    The text is encoded with the global ``tokenizer`` and fed to the global
    ``model_with_logit_lens``; only the second element of its return value
    (the list of per-layer logits) is handed back.
    """
    encoded = tokenizer(input_text, return_tensors="pt")

    # Inference only, so gradient tracking is switched off.
    with t.no_grad():
        _, per_layer = model_with_logit_lens(encoded['input_ids'])

    return per_layer

# Example usage
input_text = "This is an example sentence."
logits_per_layer = get_logits_per_layer(input_text)


# Functions

In [16]:
def actvs_to_logits(hidden_states):
    """
    Unembed the last-position activation of every entry in `hidden_states`.

    `hidden_states` is iterated once per layer; each entry is indexed as
    (batch_size, seq_len, d_model) and only the final sequence position is kept.
    That slice is contracted with the global `model.lm_head.weight` over the
    shared last axis ('ab, cb -> ac'), so each returned tensor has the weight's
    first axis first and the batch axis second.

    Returns a list with one logits tensor per entry of `hidden_states`.
    No normalization is applied before the projection.
    """
    return [
        t.einsum('ab, cb -> ac', model.lm_head.weight, layer_actvs[:, -1, :])
        for layer_actvs in hidden_states
    ]

def get_logits(input_text):
    """Return the stacked per-layer logits for the last token of ``input_text``.

    The text is encoded with the global ``tokenizer``, moved to ``device`` and
    run through the global ``model`` with hidden states enabled. Every hidden
    state is unembedded by ``actvs_to_logits``; the results are stacked along a
    new leading (layer) axis and a trailing size-1 axis is squeezed away.
    """
    token_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)

    # Inference only: without no_grad every call keeps the autograd graph of
    # the whole forward pass alive, which wastes memory when this is called
    # in a loop over many prompts.
    with t.no_grad():
        outputs = model(token_ids, output_hidden_states=True)
        per_layer = actvs_to_logits(outputs.hidden_states)

    return t.stack(per_layer).squeeze(-1)

In [17]:
def get_decoded_indiv_toks(layer_logits, k=5):
    """
    Decode the `k` highest-scoring tokens for every layer.

    layer_logits: iterable with one entry per unembedded layer; each entry
        holds the scores over the vocabulary for that layer.
    k: how many top tokens to decode per layer. This argument used to be
        ignored in favour of a hard-coded 5; the default is 5 so that calls
        without `k` produce the same lists as before.

    Returns a list (one item per layer) of lists of decoded token strings,
    ordered from highest to lowest score. Tokens are decoded one at a time so
    they stay separate strings rather than being concatenated.
    """
    output_list = []
    for logits in layer_logits:
        # Never ask for more tokens than there are scores (or fewer than zero).
        n_top = max(0, min(k, logits.shape[-1]))
        # Softmax is monotonic, so ranking the raw scores gives the same order
        # without normalizing and fully sorting the whole vocabulary.
        top_ids = logits.topk(n_top, dim=-1).indices
        output_list.append([tokenizer.decode(tok_id) for tok_id in top_ids])
    return output_list

# Test one sample, different intervals

In [19]:
prompts = ["1 2 3 ", "2 4 6 "]
for test_text in prompts:
    layer_logits = get_logits(test_text)
    tok_logit_lens = get_decoded_indiv_toks(layer_logits)
    # Labels start at -1, so the first entry is printed as "-1".
    for layer_label, tokouts in enumerate(tok_logit_lens, start=-1):
        print(layer_label, tokouts)
    print('\n')

-1 ['selves', ' Tam', 'beit', 'fx', 'ker']
0 [' in', ' (', '.', ',', ' for']
1 ['ә', 'ұ', ' for', ' in', ' on']
2 ['kele', 'penas', ' ge', ' estaven', ' personally']
3 ['kele', ' estaven', ' Query', ' reverse', 'ovis']
4 ['хе', 'enum', 'kele', ' Bedeut', 'fte']
5 ['хе', 'ector', 'irc', ' straight', 'fte']
6 [' Gemeins', 'irc', 'chor', 'lear', 'sono']
7 ['chor', 'ъ', ' sel', 'enum', 'fen']
8 ['iani', 'chor', 'fen', ' dol', ' estaven']
9 ['sono', 'chor', 'dex', ' estaven', ' pick']
10 [' estaven', ' pick', ' Bedeut', 'dex', ' multiplication']
11 ['dex', 'chor', ' estaven', ' lad', ' jack']
12 ['dex', ' pick', 'sono', 'uni', ' trigger']
13 ['dex', 'uni', 'рия', 'iten', ' estaven']
14 ['dex', ' trigger', 'рия', ' rip', ' jack']
15 ['dex', ' Franklin', ' Lad', ' rip', 'рия']
16 ['рия', ' jack', ' auto', 'Enable', ' gen']
17 ['isti', 'кт', 'udio', ' estaven', ' jack']
18 ['4', '3', ' forward', 'Enable', 'кт']
19 ['4', '3', ' inside', ' forward', '5']
20 ['4', ' fourth', ' four', '3', '5']
21

In [20]:
prompts = ["0 2 4 "]
for test_text in prompts:
    layer_logits = get_logits(test_text)
    tok_logit_lens = get_decoded_indiv_toks(layer_logits)
    # Labels start at -1, so the first entry is printed as "-1".
    for layer_label, tokouts in enumerate(tok_logit_lens, start=-1):
        print(layer_label, tokouts)
    print('\n')

-1 ['selves', ' Tam', 'beit', 'fx', 'ker']
0 [' in', '.', ' (', ',', ' to']
1 [' in', ' for', 'ұ', 'ә', ' on']
2 ['kele', ' ge', '∘', ' personally', 'penas']
3 ['kele', 'sep', 'Ő', 'ovis', 'expl']
4 ['хе', 'chor', 'kele', 'utt', ' estaven']
5 ['хе', ' straight', 'chor', ' sea', 'fte']
6 [' straight', '̍', 'iele', ' Gemeins', 'lear']
7 ['chor', 'fen', 'chat', 'asha', ' touched']
8 ['fen', 'dex', ' dol', ' estaven', 'aucoup']
9 ['onto', 'chor', 'dex', '전', 'folge']
10 [' Bedeut', ' estaven', 'elm', ' graduated', 'fen']
11 ['chor', 'teen', ' Bedeut', 'Ő', 'daten']
12 ['chor', 'teen', ' Laur', 'uro', ' pac']
13 ['uro', 'chor', 'Ő', ' Laur', 'GR']
14 [' Lad', 'Ő', 'chor', 'dex', ' dil']
15 [' Lad', 'Ő', 'end', 'dex', 'ння']
16 ['elm', ' swe', ' Laur', '4', ' Um']
17 ['chor', '0', '4', '2', '5']
18 ['0', 'chor', '5', '2', '4']
19 ['4', '3', '0', '5', 'chor']
20 ['5', '6', '7', '8', '0']
21 ['5', '6', 'olis', '7', '8']
22 ['5', '6', '7', 'olis', '8']
23 ['5', 'olis', '7', '6', 'Ő']
24 ['6', '

In [21]:
prompts = ["2 4 6 8 "]
for test_text in prompts:
    layer_logits = get_logits(test_text)
    tok_logit_lens = get_decoded_indiv_toks(layer_logits)
    # Labels start at -1, so the first entry is printed as "-1".
    for layer_label, tokouts in enumerate(tok_logit_lens, start=-1):
        print(layer_label, tokouts)
    print('\n')

-1 ['selves', ' Tam', 'beit', 'fx', 'ker']
0 [' in', ' (', '.', ' to', ' for']
1 ['ә', ' in', 'ұ', ' for', '∘']
2 ['kele', ' estaven', ' ge', 'frak', ' personally']
3 ['kele', 'onc', 'sep', 'gresql', ' Query']
4 ['kele', 'fte', 'utt', ' totalité', 'fen']
5 [' Bedeut', 'fen', 'lear', 'irc', 'ardi']
6 ['fen', ' Bedeut', ' later', 'lex', ' Pala']
7 ['fen', ' tam', ' later', 'ytu', ' Pala']
8 ['fen', 'onto', '자', 'TY', ' pac']
9 ['onto', 'ende', 'iel', 'fen', 'TY']
10 [' lif', ' estaven', 'fen', 'onto', ' Bedeut']
11 ['chor', ' Bedeut', 'ende', 'ye', 'entre']
12 ['<0xA7>', 'chor', ' Chor', ' Iz', ' Fir']
13 ['ende', 'üb', 'oir', ' Margaret', ' Chor']
14 [' Margaret', 'uber', 'entre', 'ende', 'üb']
15 [' Margaret', ' push', 'iga', ' repeat', 'lex']
16 ['xt', ' Um', 'rial', '<0x99>', '8']
17 ['xt', '8', 'chor', 'conde', 'end']
18 ['8', 'xt', ' "_', 'penas', 'amba']
19 ['8', ' charm', '9', ' Um', '7']
20 ['9', '8', 'penas', 'chor', ' Chor']
21 ['9', 'penas', '8', '<0xA7>', 'ъ']
22 ['9', 'pena

In [23]:
prompts = ["0 3 6 "]
for test_text in prompts:
    layer_logits = get_logits(test_text)
    tok_logit_lens = get_decoded_indiv_toks(layer_logits)
    # Labels start at -1, so the first entry is printed as "-1".
    for layer_label, tokouts in enumerate(tok_logit_lens, start=-1):
        print(layer_label, tokouts)
    print('\n')

-1 ['selves', ' Tam', 'beit', 'fx', 'ker']
0 [' in', ' (', '.', ',', ' for']
1 ['ә', 'ұ', ' in', ' for', '∘']
2 ['kele', ' ge', ' util', 'ә', ' coinc']
3 ['kele', 'sep', ' pick', 'ovis', ' reverse']
4 ['kele', 'хе', 'chor', 'utt', 'fte']
5 ['ә', 'oce', 'хе', ' straight', ' Maj']
6 ['fen', 'рт', 'lear', ' straight', 'aze']
7 ['fen', 'chor', ' gram', ' touched', 'icas']
8 ['fen', ' estaven', 'aucoup', 'dex', ' dol']
9 ['onto', '전', 'chor', 'fen', 'Ő']
10 [' Bedeut', 'onto', ' estaven', ' Einz', ' graduated']
11 ['chor', ' Bedeut', ' Einz', 'Ő', 'daten']
12 ['chor', 'ople', ' Einz', ' trigger', 'Ő']
13 [' flu', 'Ő', ' Bedeut', ' graduated', 'chor']
14 ['Ő', 'dex', ' flu', ' Lad', ' составе']
15 [' Lad', ' fill', 'Ő', 'dex', 'ide']
16 ['Ő', ' fill', ' pla', '7', ' Lad']
17 ['0', '6', ' fill', '3', '7']
18 ['0', '7', '6', '4', '5']
19 ['7', '8', '6', '3', '0']
20 ['7', '8', '6', '0', '1']
21 ['7', '8', '1', '6', '0']
22 ['7', '8', '9', '1', '0']
23 ['7', '8', '9', '6', 'aucoup']
24 ['7', '9

In [22]:
prompts = ["3 6 9 "]
for test_text in prompts:
    layer_logits = get_logits(test_text)
    tok_logit_lens = get_decoded_indiv_toks(layer_logits)
    # Labels start at -1, so the first entry is printed as "-1".
    for layer_label, tokouts in enumerate(tok_logit_lens, start=-1):
        print(layer_label, tokouts)
    print('\n')

-1 ['selves', ' Tam', 'beit', 'fx', 'ker']
0 [' in', ' (', '.', ',', ' for']
1 ['ә', 'ұ', ' for', ' in', '∘']
2 ['kele', 'ә', ' ge', ' spole', '∘']
3 ['kele', 'sep', 'Ő', ' Query', 'Query']
4 ['kele', 'utt', 'Ő', 'fen', ' Bedeut']
5 ['ector', 'irc', 'fen', 'ә', 'kele']
6 ['irc', 'penas', 'рт', 'fen', 'fte']
7 ['iel', ' noreferrer', ' Einz', 'touch', ' lif']
8 ['dex', 'chor', 'fen', ' Außer', 'пан']
9 ['onto', 'dex', ' Einz', ' Bedeut', '전']
10 [' Bedeut', ' Einz', ' graduated', 'entre', 'lon']
11 ['chor', ' Bedeut', 'Ő', 'dex', 'ye']
12 ['chor', 'Toggle', 'dex', 'trz', 'entre']
13 ['entre', ' Bedeut', 'chor', 'lique', 'iga']
14 ['dex', '9', ' trigger', 'entre', ' stim']
15 ['Ő', ' fill', 'dex', '9', ' Lad']
16 ['9', '7', 'Ő', '8', ' fill']
17 ['9', '7', 'Ő', '8', ' refer']
18 ['9', '7', '8', '4', '1']
19 ['9', '7', '8', '1', ' numer']
20 ['9', '1', '7', '8', ' numer']
21 ['1', '9', '7', ' numer', ' Ten']
22 ['1', '9', '7', ' numer', ' sus']
23 ['1', '9', ' straight', ' numer', '7']
24 

# pure seq prompts

In [None]:
def generate_prompts_list(x ,y):
    """Build one prompt dict for every integer ``i`` in ``range(x, y)``.

    Each dict has three string-valued keys:
        'corr'   -> str(i + 4)
        'incorr' -> str(i + 3)
        'text'   -> "i i+2 i+4 i+3" (space separated)

    NOTE(review): the text already contains i+4 and ends with i+3; confirm
    this is the intended prompt layout.
    """
    return [
        {
            'corr': str(start + 4),
            'incorr': str(start + 3),
            'text': f"{start} {start+2} {start+4} {start+3}",
        }
        for start in range(x, y)
    ]

prompts_list = generate_prompts_list(1, 9)

In [None]:
# for pd in prompts_list:
#     test_text = pd['text']
#     layer_logits = get_logits(test_text)
#     tok_logit_lens = get_decoded_indiv_toks(layer_logits)
#     for i, tokouts in enumerate(tok_logit_lens):
#         print(i-1, tokouts)
#     print('\n')

-1 [' 4', ' livest', ' mathemat', ' myster', ' horizont']
0 ['th', 'teenth', 'x', ' 4', ' 3']
1 ['th', 'teenth', '54', 'x', '39']
2 ['th', 'x', '34', '54', 'GHz']
3 ['34', '54', 'th', 'GHz', '86']
4 ['ts', 'ms', '34', 'th', 'ths']
5 ['ts', ' 4', ' -', ' 3', ' 5']
6 [' 4', '★', ' 3', ' 5', ' 6']
7 [' 4', ' 5', ' 6', ' 3', ' 0']
8 [' 4', ' 5', ' 3', ' 6', ' 1']
9 [' 5', ' 4', ' 6', ' 3', ' 7']
10 [' 5', ' 6', ' 4', ' 3', ' 7']
11 [' 5', ' 4', ' 1', ' 6', ' 3']


-1 [' livest', ' 5', ' destro', 'theless', ' mathemat']
0 ['th', ' 5', ' 3', 'x', ' times']
1 ['th', 'x', '54', '45', '43']
2 ['th', '34', 'x', 'min', '54']
3 ['34', 'th', '90', '54', '85']
4 ['ths', 'ts', ' -', 'min', '34']
5 [' -', 'ts', ' 3', 'ths', ' 5']
6 [' 5', ' 4', ' 3', ' 6', '★']
7 [' 5', ' 6', ' 4', ' 3', ' 9']
8 [' 5', ' 6', ' 4', ' 3', ' 7']
9 [' 6', ' 5', ' 4', ' 7', '6']
10 [' 6', ' 5', ' 4', ' 7', ' 8']
11 [' 6', ' 5', ' 4', ' 7', ' 0']


-1 [' 6', ' livest', ' destro', ' mathemat', ' challeng']
0 ['th', '36', '31

# numerals 1536

In [None]:
task = "digits"
prompts_list = []

temps = ['done', 'lost', 'names']

# Number of prompts kept from each template file (3 x 512 = 1536 in total;
# 768 was the other value tried here).
PROMPTS_PER_TEMPLATE = 512

for template in temps:
    file_name = f'/content/{task}_prompts_{template}.pkl'
    # NOTE(security): pickle.load can execute arbitrary code while loading;
    # only use it on prompt files from a trusted source.
    with open(file_name, 'rb') as file:
        filelist = pickle.load(file)

    # Show the first prompt of each template as a quick sanity check.
    print(filelist[0]['text'])
    prompts_list += filelist[:PROMPTS_PER_TEMPLATE]

len(prompts_list)

Van done in 1. Hat done in 2. Ring done in 3. Desk done in 4. Sun done in
Oil lost in 1. Apple lost in 2. Tree lost in 3. Snow lost in 4. Apple lost in
Marcus born in 1. Victoria born in 2. George born in 3. Brandon born in 4. Jamie born in


1536

In [None]:
num_corr = 0
anomolies = []
for prompt_dict in prompts_list:
    test_text = prompt_dict['text']
    layer_logits = get_logits(test_text)
    tok_logit_lens = get_decoded_indiv_toks(layer_logits)

    # Check whether the top token at index 9 (printed as layer 8, since the
    # printed labels are offset by -1) is a smaller number than the top token
    # at index 10 (printed as layer 9), and whether the latter is the correct
    # next sequence member.
    #
    # The top token may not be a number at all (or the index may be missing),
    # so only those two failure modes are treated as anomalies; anything else
    # (interrupts, device errors, malformed prompt dicts) is allowed to surface.
    try:
        earlier_top = int(tok_logit_lens[9][0].replace(' ', ''))
        later_top = int(tok_logit_lens[10][0].replace(' ', ''))
    except (ValueError, IndexError):
        anomolies.append(prompt_dict)
        continue

    if earlier_top < later_top and tok_logit_lens[10][0] == prompt_dict['corr']:
        num_corr += 1
    else:
        anomolies.append(prompt_dict)

1531

In [None]:
# Number of prompts that passed the layer-ordering check in the loop above.
num_corr

1531

Do a quick scan of which prompts are anomalies, i.e. prompts for which the 8th layer's output is not just "the sequence member one before" the 9th layer's output.

In [None]:
# Show the logit-lens table for the first two anomalous prompts.
for prompt_dict in anomolies[:2]:
    test_text = prompt_dict['text']
    layer_logits = get_logits(test_text)
    tok_logit_lens = get_decoded_indiv_toks(layer_logits)
    # Labels start at -1, so the first entry is printed as "-1".
    for layer_label, tokouts in enumerate(tok_logit_lens, start=-1):
        print(layer_label, tokouts)
    print('\n')

-1 [' challeng', ' mathemat', ' arrang', ' corrid', ' destro']
0 [' order', ' the', ' particular', ' conjunction', ' a']
1 [' order', ' the', ' conjunction', ' particular', ' a']
2 [' order', ' the', ' vain', ' front', ' conjunction']
3 [' order', ' the', ' vain', ' conjunction', ' front']
4 [' order', ' the', ' conjunction', ' vain', ' spite']
5 [' order', ' conjunction', ' 3', ' vain', ' 15']
6 [' 3', ' 5', ' 2', ' 4', ' 7']
7 [' 3', ' 5', ' 7', ' 4', ' 2']
8 [' 7', ' 6', ' 5', ' 8', ' 1']
9 [' 7', ' 8', ' 6', ' 9', '7']
10 [' 7', ' 8', ' 9', ' 6', ' 1']
11 [' 7', ' 8', ' 1', ' 9', ' 6']


-1 [' challeng', ' mathemat', ' arrang', ' corrid', ' destro']
0 [' order', ' the', ' particular', ' a', ' conjunction']
1 [' order', ' the', ' conjunction', ' a', ' particular']
2 [' order', ' the', ' conjunction', ' vain', ' front']
3 [' order', ' the', ' conjunction', ' vain', ' accordance']
4 [' order', ' conjunction', ' the', ' relation', ' accordance']
5 [' order', ' 3', ' 5', ' 18', ' 6']
6 

In [None]:
# Show the logit-lens table for the last two anomalous prompts.
for prompt_dict in anomolies[-2:]:
    test_text = prompt_dict['text']
    layer_logits = get_logits(test_text)
    tok_logit_lens = get_decoded_indiv_toks(layer_logits)
    # Labels start at -1, so the first entry is printed as "-1".
    for layer_label, tokouts in enumerate(tok_logit_lens, start=-1):
        print(layer_label, tokouts)
    print('\n')

-1 [' challeng', ' mathemat', ' arrang', ' corrid', ' destro']
0 [' order', ' the', ' particular', ' a', ' front']
1 [' order', ' the', ' a', ' conjunction', ' front']
2 [' order', ' the', ' a', ' vain', ' accordance']
3 [' order', ' the', ' vain', ' his', ' a']
4 [' order', ' the', ' spite', ' relation', ' vain']
5 [' order', ' 3', ' 2', ' 5', ' 1']
6 [' 3', ' 2', ' 5', ' 1', ' 4']
7 [' 3', ' 2', ' 1', ' 5', ' 4']
8 [' 5', ' 4', ' 3', ' 1', ' 2']
9 [' 5', ' 6', ' 1', ' 7', '5']
10 [' 5', ' 6', ' 1', ' 7', ' 3']
11 [' 5', ' 6', ' 1', ' 7', ' 3']


-1 [' challeng', ' mathemat', ' arrang', ' corrid', ' destro']
0 [' the', ' order', ' a', ' particular', ' front']
1 [' order', ' the', ' a', ' conjunction', ' particular']
2 [' order', ' the', ' a', ' accordance', ' vain']
3 [' order', ' the', ' accordance', ' relation', ' vain']
4 [' order', ' the', ' accordance', ' spite', ' relation']
5 [' order', ' 3', ' 5', ' 18', ' 15']
6 [' 3', ' 5', ' 2', ' 7', ' 4']
7 [' 3', ' 5', ' 7', ' 4', ' 6']
