In [1]:
import torch
import random
from transformer_lens import HookedTransformer, HookedTransformerConfig 
import pickle
from dataclasses import asdict

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import circuitsvis as cv

In [3]:
# Character-level vocabulary. The ten digits come first so that every digit's
# token id equals its numeric value; the formatting symbols follow:
# ' ' -> 10, newline -> 11, '>' -> 12, '=' -> 13, '[' -> 14, ']' -> 15.
str2int = {ch: tok for tok, ch in enumerate("0123456789 \n>=[]")}

# Reverse lookup used for display. The newline token maps to the two-character
# escape sequence (backslash + 'n') so that labels stay on one line.
int2str = {tok: ('\\n' if ch == '\n' else ch) for ch, tok in str2int.items()}

# Prefer the GPU when one is available.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def gen_list(n, without_rep=True):
    """Return a list of ``n`` random digits in the range 0-9.

    When ``without_rep`` is true (the default) the digits are distinct and are
    drawn with ``random.sample``; otherwise every digit is drawn independently
    with ``random.randint`` and repeats may occur.
    """
    if not without_rep:
        digits = []
        for _ in range(n):
            digits.append(random.randint(0, 9))
        return digits

    return random.sample(range(10), n)

def gen_ex():
    """Build one complete example (prompt plus answer) as a string.

    The result has three lines: three distinct digits, an assignment of the
    form ``[idx]=y`` where ``y`` is a digit not present in the list, and ``>``
    followed by the list with position ``idx`` replaced by ``y``.
    """
    original = gen_list(3)
    position = random.randint(0, 2)

    # Rejection-sample a replacement digit that is not already in the list.
    while True:
        new_digit = random.randint(0, 9)
        if new_digit not in original:
            break

    modified = list(original)
    modified[position] = new_digit

    source_line = "".join(str(d) for d in original)
    target_line = "".join(str(d) for d in modified)
    return '{}\n[{}]={}\n>{}'.format(source_line, position, new_digit, target_line)

def str_to_tokens(ex):
    """Convert the string ``ex`` into a list of token ids, one per character.

    Lookup goes through ``str2int``; a character outside the vocabulary raises
    ``KeyError``.
    """
    token_ids = []
    for ch in ex:
        token_ids.append(str2int[ch])
    return token_ids

def tokens_to_str(tokens):
    """Concatenate the ``int2str`` entry of every token id in ``tokens``.

    An id missing from ``int2str`` raises ``KeyError``.
    """
    pieces = (int2str[tok] for tok in tokens)
    return "".join(pieces)


In [4]:
def load_model(pth):
    """Rebuild a ``HookedTransformer`` from the pickled checkpoint at ``pth``.

    The pickle is expected to hold a mapping with two entries: ``'cfg'``,
    which is passed to ``HookedTransformerConfig.from_dict``, and ``'model'``,
    which is passed to ``load_state_dict``.

    NOTE: unpickling can execute arbitrary code, so only load checkpoint
    files from a trusted source.
    """
    with open(pth, 'rb') as checkpoint_file:
        checkpoint = pickle.load(checkpoint_file)

    # The file is no longer needed once the checkpoint is in memory.
    config = HookedTransformerConfig.from_dict(checkpoint['cfg'])
    net = HookedTransformer(config)
    net.load_state_dict(checkpoint['model'])
    return net

# Load the checkpoint that every later cell inspects (the path is relative).
model = load_model('../models/model_1l_14.pkl')

In [7]:
def gen_viz(ex=None, layer=0):
    """Render the attention patterns of one layer for a single example.

    Args:
        ex: Example string to visualise. When ``None`` (the default) a fresh
            random example is drawn with ``gen_ex()``, so repeated calls show
            different inputs.
        layer: Index of the layer whose attention pattern is looked up in the
            activation cache. Defaults to 0.

    Returns:
        The object returned by ``cv.attention.attention_patterns``.
    """
    if ex is None:
        ex = gen_ex()

    token_ids = str_to_tokens(ex)
    # The model is run on a batch of one; the batch dimension is stripped
    # from the cached activations via remove_batch_dim.
    tokens = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0)

    # Only the cached activations are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        _, cache_model = model.run_with_cache(tokens, remove_batch_dim=True)

    # Printable label for every input position.
    tokens_input = [int2str[t] for t in token_ids]

    pattern = cache_model["pattern", layer, "attn"]
    return cv.attention.attention_patterns(tokens=tokens_input, attention=pattern)


NOTES
The two heads seem to implement two different algorithms. Head 0 assigns greater weight to the position token (the i in a\[i\]) and to the value token for every token in the output, except when the value token is the next one to be output; in that case weight is assigned only to the value token.

Head 1 simply assigns weight to the token from the original list before the modification.

In [9]:
# Each call draws a new random example, so the pattern differs between runs.
gen_viz()

In [13]:
# Another random example, to check whether the head behaviour is consistent.
gen_viz()

In [8]:
# One more random example for comparison.
gen_viz()

# generation

In [14]:
def gen_test_ex():
    """Build a prompt-only example for the model to complete.

    The result has three lines: three distinct digits, an assignment of the
    form ``[idx]=y`` where ``y`` is a digit not present in the list, and a
    bare ``>``. Unlike ``gen_ex``, the modified list is not appended; the
    model is expected to generate it.
    """
    a = gen_list(3)
    idx = random.randint(0, 2)

    # Rejection-sample a replacement digit that is not already in the list.
    y = random.randint(0, 9)
    while y in a:
        y = random.randint(0, 9)

    # The prompt stops right after '>', so the answer is never constructed.
    return '{}\n[{}]={}\n>'.format("".join(map(str, a)), idx, y)

In [16]:
# Draw one random prompt (nothing after '>') for the model to complete.
test_ex = gen_test_ex()

In [19]:
# Encode the prompt as token ids and add a leading batch dimension of size 1.
test_tokens = torch.tensor(str_to_tokens(test_ex), dtype=torch.long).unsqueeze(0)
test_tokens

tensor([[ 4,  0,  7, 11, 14,  2, 15, 13,  1, 11, 12]])

In [23]:
# Generate the three answer tokens. Decode greedily (do_sample=False) so the
# completion being analysed is the model's top prediction and stays the same
# across re-runs; generate() samples from the output distribution by default.
pred = model.generate(test_tokens, max_new_tokens=3, stop_at_eos=False, do_sample=False)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 537.20it/s]


In [27]:
# Re-run the model on the full sequence held in `pred` to capture its
# activations; the returned logits are not used by the cells below.
logits, cache_model = model.run_with_cache(pred.clone(), remove_batch_dim=True) 
# Printable label for every position, used by the attention visualisation.
tokens_input = [int2str[t] for t in pred.tolist()[0]]

In [28]:
# Newline tokens show up as the escaped two-character form because
# int2str[11] is overridden for display.
tokens_input

['4', '0', '7', '\\n', '[', '2', ']', '=', '1', '\\n', '>', '4', '0', '1']

In [29]:
# Look up the attention pattern at index 0 in the cache (the layer index in
# TransformerLens' cache lookup) and render it for the generated sequence.
p0 = cache_model["pattern", 0, "attn"]
cv.attention.attention_patterns(tokens=tokens_input, attention=p0)

In [31]:
# Presumably (n_heads, query_pos, key_pos): the notes describe two heads and
# the sequence has 14 tokens; confirm against the TransformerLens docs.
p0.shape

torch.Size([2, 14, 14])