In [None]:
import os
from dotenv import load_dotenv
import torch
from transformers import AutoTokenizer
import transformer_lens
# Allocate new tensors on the GPU when one is present; otherwise stay on the
# CPU, mirroring the CUDA-availability fallback used for `device` and
# `last_device` further down in this notebook.
torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")
import pandas as pd
from src.utils import get_w_vo
from src.maps import MAPS
from tabulate import tabulate

# Pandas display settings: a 1000-character display width, with the column
# count and per-cell column width limits removed (None).
for _option_name, _option_value in (
    ("display.width", 1000),
    ("display.max_columns", None),
    ("display.max_colwidth", None),
):
    pd.set_option(_option_name, _option_value)



In [2]:
# Load the model through TransformerLens' `from_pretrained_no_processing`
# constructor, together with the matching Hugging Face tokenizer.
model_name = r"gpt2-xl"
model = transformer_lens.HookedTransformer.from_pretrained_no_processing(model_name, device_map="auto")

# The notebook only inspects the model: disable gradients on every parameter
# and switch the module to evaluation mode.
for param in model.parameters():
    param.requires_grad = False
model.eval()

state_dict = model.state_dict()
cfg = model.cfg
# True when the config defines `n_key_value_heads` (presumably meaning the
# model uses grouped-query attention -- TODO confirm). Compared with
# `is not None`, the idiomatic identity test for None.
is_gqa = cfg.n_key_value_heads is not None
tokenizer = AutoTokenizer.from_pretrained(model_name)

# `device` is the first CUDA device and `last_device` the highest-indexed one;
# both fall back to the CPU when CUDA is unavailable.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_gpus = torch.cuda.device_count()
last_device = torch.device(f"cuda:{num_gpus-1}" if torch.cuda.is_available() else "cpu")
def first_mlp(x):
    """Run ``x`` through block 0's ``ln2`` followed by block 0's ``mlp``.

    Both submodules are passed through ``.to(device)`` on every call before
    they are applied; ``model`` and ``device`` are notebook-level globals.
    """
    first_block = model.blocks[0]
    mlp = first_block.mlp.to(device)
    normalized = first_block.ln2.to(device)(x)
    return mlp(normalized)

Loaded pretrained model gpt2-xl into HookedTransformer


In [3]:
# Attention head under inspection, addressed by layer index and head index.
layer, head = 26, 2

# Arguments forwarded to `maps.get_salient_operations` in the next cell:
# how many salient tokens to request, how many mappings to request per token,
# and the flag named `apply_first_mlp`.
k_salient_tokens = 15
k_mappings = 5
apply_first_mlp = True

# Project helper wrapping the loaded model and tokenizer.
maps = MAPS(model, tokenizer)

In [4]:
# Ask MAPS for the salient tokens of the selected head and, for each of them,
# its list of mapped strings.
salient_tokens_decoded, salient_mappings_decoded = maps.get_salient_operations(
    layer, head, k_salient_tokens, k_mappings, apply_first_mlp
)

# One row per salient token: the token followed by its mappings joined with ", ".
table = [
    (token, ", ".join(mappings))
    for token, mappings in zip(salient_tokens_decoded, salient_mappings_decoded)
]
print(tabulate(table, headers=["Token", "Mappings"], tablefmt="plain"))

Token          Mappings
' Jedi'        ' lightsaber', ' Jedi', ' Kenobi', ' droid', ' Skywalker'
' lightsaber'  ' lightsaber', ' Jedi', ' Kenobi', ' Skywalker', ' Sith'
' galactic'    ' Galactic', ' galactic', ' starship', ' galaxy', ' droid'
' Starfleet'   ' galactic', ' Starfleet', ' starship', ' Galactic', ' interstellar'
' Klingon'     ' starship', ' Starfleet', ' Klingon', ' Trek', ' Starship'
' starship'    ' starship', ' Galactic', ' galactic', ' interstellar', ' Planetary'
' Skyrim'      ' Skyrim', ' Magicka', ' Bethesda', ' Elven', ' Hearth'
' Darth'       ' Jedi', ' lightsaber', ' Kenobi', ' Darth', ' Sith'
' galaxy'      ' Galactic', ' galactic', ' starship', ' galaxy', ' droid'
' Fairy'       ' Fairy', ' Magical', ' fairy', ' Pokémon', ' Cinderella'
' droid'       ' droid', ' Kenobi', ' Galactic', ' lightsaber', ' Jedi'
'Pokémon'      ' Pokémon', 'Pokémon', ' Pikachu', ' Poké', ' Pokemon'
' Sith'        ' Sith', ' Jedi', ' lightsaber', ' Kenobi', ' Mandal'
' Elven'       ' 

In [5]:
def get_mapping_str(salient_tokens_decoded, salient_mappings_decoded):
    """Render token-to-mappings pairs as newline-separated text.

    Each line has the form ``"<token>: <m1>,<m2>,..."``: the token, a colon and
    a space, then that token's mappings joined by bare commas. Lines are
    separated by ``"\\n"`` with no trailing newline.

    Args:
        salient_tokens_decoded: Sequence of token strings.
        salient_mappings_decoded: Sequence, indexed in parallel with the
            tokens, whose items are iterables of strings.

    Returns:
        The formatted string; ``""`` when there are no tokens.
    """
    # Mappings are looked up by position, so a mappings list shorter than the
    # token list raises IndexError and surplus mappings are ignored.
    return "\n".join(
        f"{token}: " + ",".join(salient_mappings_decoded[position])
        for position, token in enumerate(salient_tokens_decoded)
    )

def plug_in_prompt(salient_tokens_decoded, salient_mappings_decoded):
    """Build the instruction prompt describing a set of token mappings.

    The fixed instruction text is followed by the input strings (the Python
    ``str()`` rendering of ``salient_tokens_decoded``) and by the mappings as
    formatted by ``get_mapping_str``.

    Args:
        salient_tokens_decoded: Sequence of input token strings.
        salient_mappings_decoded: Per-token sequences of mapped strings,
            parallel to ``salient_tokens_decoded``.

    Returns:
        The complete prompt as a single string.
    """
    # NOTE(review): the prompt text hard-codes "a list of 5 strings" and
    # contains the spelling "vaild"; both are kept verbatim so the prompt
    # stays byte-identical.
    mappings_block = get_mapping_str(salient_tokens_decoded, salient_mappings_decoded)
    return f"""Below you are given a list of input strings, and a list of mappings: each mapping is between an input string and a list of 5 strings. 
Mappings are provided in the format "s: t1, t2, t3, t4, t5" where each of s, t1, t2, t3, t4, t5 is a short string, typically corresponding to a single word or a sub-word.
Your goal is to describe shortly and simply the inputs and the function that produces these mappings. To perform the task, look for semantic and textual patterns. 
For example, input tokens 'water','ice','freeze' are water-related, and a mapping ('fire':'f') is from a word to its first letter.
As a final response, suggest the most clear patterns observed or indicate that no clear pattern is visible (write only the word "Unclear").
Your response should be a vaild json, with the following keys: 
"Reasoning": your reasoning.
"Input strings": One sentence describing the input strings (or "Unclear").
"Observed pattern": One sentence describing the most clear patterns observed (or "Unclear").

The input strings are:
{salient_tokens_decoded}

The mappings are: 
{mappings_block}
"""

In [6]:
# Drop the first and last character of every decoded string (presumably the
# surrounding quote characters visible in the printed table -- TODO confirm)
# before inserting the strings into the prompt.
salient_tokens_decoded_ = [token[1:-1] for token in salient_tokens_decoded]
salient_mappings_decoded_ = [
    [target[1:-1] for target in targets]
    for targets in salient_mappings_decoded
]

prompt = plug_in_prompt(salient_tokens_decoded_, salient_mappings_decoded_)
print(prompt)

Below you are given a list of input strings, and a list of mappings: each mapping is between an input string and a list of 5 strings. 
Mappings are provided in the format "s: t1, t2, t3, t4, t5" where each of s, t1, t2, t3, t4, t5 is a short string, typically corresponding to a single word or a sub-word.
Your goal is to describe shortly and simply the inputs and the function that produces these mappings. To perform the task, look for semantic and textual patterns. 
For example, input tokens 'water','ice','freeze' are water-related, and a mapping ('fire':'f') is from a word to its first letter.
As a final response, suggest the most clear patterns observed or indicate that no clear pattern is visible (write only the word "Unclear").
Your response should be a vaild json, with the following keys: 
"Reasoning": your reasoning.
"Input strings": One sentence describing the input strings (or "Unclear").
"Observed pattern": One sentence describing the most clear patterns observed (or "Unclear")