In [31]:
import transformer_lens
from datasets import load_dataset
import pandas as pd
from sklearn.cluster import KMeans

In [2]:
model = transformer_lens.HookedTransformer.from_pretrained("tiny-stories-1M")

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

Loaded pretrained model tiny-stories-1M into HookedTransformer


In [13]:
print(f"{model.cfg.n_layers=}")
print(f"{model.cfg.n_heads=}")
print(f"{model.cfg.d_model=}")
print(f"{model.cfg.d_head=}")
print(f"{model.cfg.d_mlp=}")
print(f"{model.cfg.d_vocab=}")
print(f"{model.cfg.d_vocab_out=}")
print(f"{model.cfg.n_ctx=}")

model.cfg.n_layers=8
model.cfg.n_heads=16
model.cfg.d_model=64
model.cfg.d_head=4
model.cfg.d_mlp=256
model.cfg.d_vocab=50257
model.cfg.d_vocab_out=50257
model.cfg.n_ctx=2048


In [3]:
dataset = load_dataset('roneneldan/TinyStories', split='train')



In [4]:
inputs = dataset["text"]
print(inputs[0])

One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.

Lily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."

Together, they shared the needle and sewed the button on Lily's shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.


In [79]:
# Choose how many inputs to run
inputs_to_run = 10
for i in range(inputs_to_run):
    print(f"\n#### Input {i}:")
    print(inputs[i])


#### Input 0:
One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.

Lily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."

Together, they shared the needle and sewed the button on Lily's shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.

#### Input 1:
Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong.

One day, Beep was driving in the park when he saw a big tree. The tree had ma

In [93]:
# Tokenize inputs
tokenized_inputs = [model.tokenizer.tokenize(inputs[i]) for i in range(inputs_to_run)]

In [80]:
# Run the model and get logits and activations
logits, cache = model.run_with_cache(inputs[0:inputs_to_run])

In [134]:
# gather top prediction for each token and input
top_predictions = logits.argmax(dim=-1).reshape(-1).tolist()
top_predictions

[628,
 1110,
 11,
 257,
 1310,
 2576,
 3706,
 20037,
 373,
 257,
 1263,
 287,
 607,
 2119,
 13,
 1375,
 2227,
 340,
 373,
 257,
 284,
 466,
 351,
 607,
 13,
 340,
 373,
 845,
 290,
 1375,
 2227,
 284,
 1037,
 607,
 17598,
 11,
 607,
 2460,
 1820,
 475,
 673,
 1718,
 711,
 340,
 29654,
 13,
 262,
 1182,
 13,
 198,
 198,
 43,
 813,
 338,
 284,
 262,
 1995,
 290,
 531,
 11,
 366,
 29252,
 11,
 460,
 765,
 257,
 17598,
 2474,
 632,
 314,
 1037,
 340,
 351,
 502,
 1701,
 787,
 340,
 1182,
 1701,
 2332,
 1995,
 13541,
 290,
 531,
 11,
 366,
 5297,
 11,
 314,
 13,
 340,
 460,
 2648,
 340,
 17598,
 290,
 787,
 340,
 898,
 526,
 198,
 198,
 43,
 11,
 484,
 9859,
 262,
 4704,
 290,
 262,
 19103,
 262,
 17598,
 13,
 262,
 338,
 10147,
 13,
 1119,
 373,
 523,
 257,
 7471,
 607,
 13,
 340,
 547,
 3772,
 13,
 852,
 1123,
 584,
 13,
 20037,
 484,
 5201,
 11,
 20037,
 338,
 607,
 1995,
 329,
 852,
 607,
 17598,
 290,
 262,
 262,
 10147,
 13,
 3574,
 1111,
 550,
 3772,
 290,
 484,
 550,
 257,
 262,
 55

In [135]:
# preapre dataset for the classifier model
nodes = ["hook_embed", "hook_pos_embed"] + [f"blocks.{i}.hook_attn_out" for i in range(model.cfg.n_layers)] + [f"blocks.{i}.hook_mlp_out" for i in range(model.cfg.n_layers)]

# flatten batch and token positions, and apply mean for each activation vector
activations = { node: cache[node].reshape(-1, cache[node].shape[-1]).mean(dim=-1) for node in nodes}

# build a pandas dataframe. Each column is the mean activation for a node, and each row is the next token prediction for a part of an input
activations_df = pd.DataFrame(activations)

# assign index to the dataframe
max_input_tokens = cache["hook_embed"].shape[1]
activations_df.index = [f"{i}_{j}" for i in range(inputs_to_run) for j in range(max_input_tokens)]

In [82]:
activations_df

Unnamed: 0,hook_embed,hook_pos_embed,blocks.0.hook_attn_out,blocks.1.hook_attn_out,blocks.2.hook_attn_out,blocks.3.hook_attn_out,blocks.4.hook_attn_out,blocks.5.hook_attn_out,blocks.6.hook_attn_out,blocks.7.hook_attn_out,blocks.0.hook_mlp_out,blocks.1.hook_mlp_out,blocks.2.hook_mlp_out,blocks.3.hook_mlp_out,blocks.4.hook_mlp_out,blocks.5.hook_mlp_out,blocks.6.hook_mlp_out,blocks.7.hook_mlp_out
0_0,2.211891e-09,-3.492460e-10,-2.328306e-10,-2.328306e-09,-4.656613e-10,2.328306e-10,-4.511094e-10,-8.149073e-10,0.000000e+00,-2.328306e-10,-3.143214e-09,-9.313226e-10,-2.910383e-10,5.238689e-10,-1.280569e-09,-6.984919e-10,-1.629815e-09,1.047738e-09
0_1,1.396984e-09,-4.656613e-10,1.164153e-10,2.328306e-09,2.211891e-09,4.656613e-10,5.384209e-10,2.328306e-10,-2.328306e-10,-1.455192e-10,-5.587935e-09,5.820766e-10,-1.164153e-10,1.164153e-10,-7.566996e-10,-6.111804e-10,-1.280569e-09,2.095476e-09
0_2,1.862645e-09,3.492460e-10,5.820766e-11,1.862645e-09,-8.149073e-10,1.164153e-10,-5.820766e-11,1.164153e-10,3.492460e-10,-1.164153e-10,5.122274e-09,-1.396984e-09,-1.571607e-09,9.313226e-10,-4.656613e-10,8.731149e-10,-6.984919e-10,8.731149e-11
0_3,1.862645e-09,2.328306e-10,-6.984919e-10,1.396984e-09,-1.396984e-09,2.328306e-10,-1.600711e-10,-8.731149e-11,1.746230e-10,-1.164153e-10,2.793968e-09,3.492460e-10,3.492460e-10,9.895302e-10,-5.820766e-11,-5.238689e-10,-1.135049e-09,1.047738e-09
0_4,-3.026798e-09,8.149073e-10,1.455192e-10,-1.396984e-09,-2.328306e-10,-1.222361e-09,-2.910383e-11,-1.455192e-10,5.820766e-10,-1.164153e-10,3.026798e-09,5.820766e-11,0.000000e+00,8.149073e-10,-1.193257e-09,-1.105946e-09,5.820766e-11,-3.201421e-10
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9_208,2.211891e-09,2.619345e-10,-4.656613e-10,4.656613e-10,4.656613e-10,4.074536e-10,5.820766e-11,3.637979e-11,6.402843e-10,-2.037268e-10,4.656613e-10,-4.656613e-10,-1.047738e-09,-5.820766e-10,-4.365575e-10,1.164153e-10,8.440111e-10,1.862645e-09
9_209,2.211891e-09,5.820766e-11,-2.328306e-10,-1.280569e-09,-9.313226e-10,-1.746230e-10,-2.328306e-10,-1.673470e-10,2.328306e-10,0.000000e+00,-5.122274e-09,-9.313226e-10,4.656613e-10,1.396984e-09,-2.328306e-10,-4.656613e-10,1.455192e-10,3.492460e-10
9_210,2.211891e-09,8.731149e-11,-2.619345e-10,-3.492460e-10,0.000000e+00,-6.402843e-10,5.820766e-11,8.003553e-11,-1.164153e-10,-3.783498e-10,1.396984e-09,9.313226e-10,-1.047738e-09,1.280569e-09,-8.731149e-10,-1.047738e-09,1.746230e-10,-1.629815e-09
9_211,2.211891e-09,1.455192e-10,-2.037268e-10,-1.164153e-10,2.328306e-10,2.328306e-10,-5.820766e-11,-6.548362e-11,1.164153e-10,-2.328306e-10,-1.629815e-09,1.862645e-09,0.000000e+00,-5.238689e-10,-1.018634e-09,-2.328306e-10,3.783498e-10,-1.164153e-10


In [139]:
# use sklearn clustering to get groups of activations
n_labels = 50
kmeans = KMeans(n_clusters=n_labels, random_state=0).fit(activations_df)
labels = kmeans.labels_

In [118]:
def get_inputs_for_label(label, decoded=True):
    str_tokenized_inputs = [tokenized_inputs[int(i)][0:int(j)] for i, j in [index_str.split("_") for index_str in activations_df[labels == label].index.to_list()]]
    
    if decoded:
        return [model.tokenizer.decode(model.tokenizer.convert_tokens_to_ids(str_tokenized_input)) for str_tokenized_input in str_tokenized_inputs]
    else:
        return str_tokenized_inputs
      

In [140]:
# print number of inputs per label
for label in range(n_labels):
    print(f"Label {label}: {sum(labels == label)} inputs")

Label 0: 45 inputs
Label 1: 65 inputs
Label 2: 37 inputs
Label 3: 57 inputs
Label 4: 48 inputs
Label 5: 33 inputs
Label 6: 60 inputs
Label 7: 44 inputs
Label 8: 56 inputs
Label 9: 30 inputs
Label 10: 53 inputs
Label 11: 39 inputs
Label 12: 32 inputs
Label 13: 56 inputs
Label 14: 50 inputs
Label 15: 40 inputs
Label 16: 32 inputs
Label 17: 47 inputs
Label 18: 52 inputs
Label 19: 39 inputs
Label 20: 62 inputs
Label 21: 43 inputs
Label 22: 54 inputs
Label 23: 41 inputs
Label 24: 53 inputs
Label 25: 48 inputs
Label 26: 45 inputs
Label 27: 47 inputs
Label 28: 48 inputs
Label 29: 6 inputs
Label 30: 46 inputs
Label 31: 43 inputs
Label 32: 38 inputs
Label 33: 36 inputs
Label 34: 51 inputs
Label 35: 7 inputs
Label 36: 53 inputs
Label 37: 54 inputs
Label 38: 14 inputs
Label 39: 54 inputs
Label 40: 48 inputs
Label 41: 40 inputs
Label 42: 52 inputs
Label 43: 29 inputs
Label 44: 40 inputs
Label 45: 48 inputs
Label 46: 15 inputs
Label 47: 30 inputs
Label 48: 22 inputs
Label 49: 48 inputs


In [145]:
inspect_label = 0

for index, input in enumerate(sorted(get_inputs_for_label(inspect_label))):
    str_input = input.replace("\n", "\\n")
    token_prediction = top_predictions[index]
    str_prediction = model.tokenizer.decode([token_prediction]).replace("\n", "\\n")
    print(">> " + str_input + " --> " + str_prediction)

>> Once upon a --> \n\n
>> Once upon a -->  day
>> Once upon a --> ,
>> Once upon a -->  a
>> Once upon a -->  little
>> Once upon a -->  girl
>> Once upon a -->  named
>> Once upon a time, in a big lake, there was a brown kayak. The brown kayak liked to roll in the water all day long. It was very happy when it -->  Lily
>> Once upon a time, in a big lake, there was a brown kayak. The brown kayak liked to roll in the water all day long. It was very happy when it could roll and splash in the lake.\n\nOne day, a little boy named Tim came to play with the brown kayak. Tim and the brown kayak rolled in the water together. They laughed and had a lot of fun. The sun was shining, and the water was warm.\n\nAfter a while, it -->  was
>> Once upon a time, in a land full of trees, there was a -->  a
>> Once upon a time, in a land full of trees, there was a little cherry tree. The cherry tree was very sad because it did not have any friends. All the other trees were big and strong, but the cherry

In [146]:
activations_df[labels == inspect_label]

Unnamed: 0,hook_embed,hook_pos_embed,blocks.0.hook_attn_out,blocks.1.hook_attn_out,blocks.2.hook_attn_out,blocks.3.hook_attn_out,blocks.4.hook_attn_out,blocks.5.hook_attn_out,blocks.6.hook_attn_out,blocks.7.hook_attn_out,blocks.0.hook_mlp_out,blocks.1.hook_mlp_out,blocks.2.hook_mlp_out,blocks.3.hook_mlp_out,blocks.4.hook_mlp_out,blocks.5.hook_mlp_out,blocks.6.hook_mlp_out,blocks.7.hook_mlp_out
0_12,-2.328306e-09,-1.164153e-10,-4.656613e-10,-1.396984e-09,-2.852175e-09,-4.656613e-10,2.910383e-10,-8.731149e-11,0.0,4.365575e-10,1.396984e-09,2.793968e-09,1.164153e-10,4.656613e-10,-8.149073e-10,1.164153e-10,-2.328306e-10,2.037268e-10
0_21,-2.328306e-09,2.328306e-10,-4.074536e-10,1.164153e-09,-2.240995e-09,0.0,-8.731149e-11,-1.455192e-10,1.164153e-10,3.49246e-10,4.656613e-10,2.328306e-10,0.0,5.820766e-11,-9.022187e-10,1.74623e-10,3.49246e-10,2.386514e-09
0_94,-2.561137e-09,-2.328306e-10,-9.313226e-10,-5.820766e-10,-2.328306e-10,1.164153e-09,-1.455192e-11,-5.820766e-11,1.164153e-10,-5.820766e-11,4.656613e-10,3.49246e-10,-1.455192e-09,1.164153e-10,-4.074536e-10,-1.018634e-09,9.604264e-10,6.984919e-10
1_3,-3.026798e-09,2.328306e-10,-6.402843e-10,-9.313226e-10,-1.280569e-09,-1.280569e-09,4.802132e-10,-3.49246e-10,-4.656613e-10,3.49246e-10,-4.656613e-10,2.910383e-10,1.164153e-10,4.074536e-10,-1.164153e-09,-8.149073e-10,-8.731149e-11,2.037268e-09
1_23,-2.328306e-09,-6.984919e-10,-5.820766e-10,9.313226e-10,-3.434252e-09,8.149073e-10,1.74623e-10,5.820766e-11,3.49246e-10,4.074536e-10,1.280569e-09,1.396984e-09,-1.164153e-09,3.49246e-10,-5.820766e-11,-1.164153e-10,1.164153e-09,1.338776e-09
1_31,-1.396984e-09,-2.328306e-10,-6.984919e-10,-4.656613e-10,-2.793968e-09,1.396984e-09,-1.164153e-10,0.0,1.74623e-10,4.074536e-10,9.313226e-10,-5.238689e-10,3.49246e-10,3.49246e-10,-1.280569e-09,1.396984e-09,-5.820766e-11,1.629815e-09
1_33,-1.862645e-09,6.984919e-10,-3.49246e-10,-9.313226e-10,2.328306e-10,1.164153e-10,-2.619345e-10,2.910383e-10,1.164153e-10,4.074536e-10,6.984919e-10,6.984919e-10,9.313226e-10,9.313226e-10,-1.105946e-09,-1.164153e-09,-3.49246e-10,-5.820766e-11
1_59,-2.561137e-09,-4.074536e-10,-2.328306e-10,1.280569e-09,-1.426088e-09,-3.49246e-10,2.037268e-10,2.619345e-10,5.820766e-10,-2.328306e-10,0.0,1.280569e-09,1.396984e-09,-1.571607e-09,1.629815e-09,6.984919e-10,2.619345e-10,1.222361e-09
1_81,-2.561137e-09,0.0,-6.984919e-10,1.862645e-09,-1.280569e-09,0.0,5.820766e-11,-5.820766e-11,3.49246e-10,-2.910383e-11,9.313226e-10,1.164153e-09,2.910383e-09,-1.74623e-10,-1.280569e-09,-1.920853e-09,8.731149e-10,1.135049e-09
1_123,-2.793968e-09,-1.164153e-10,0.0,1.513399e-09,-1.862645e-09,-2.328306e-10,1.74623e-10,-1.455192e-11,2.328306e-10,1.455192e-10,4.656613e-10,-6.984919e-10,9.895302e-10,5.820766e-11,-9.313226e-10,-2.328306e-10,1.74623e-10,-4.656613e-10
