In [1]:
# generic imports
import torch
torch.set_grad_enabled(False)

# Utilities
from general_utils import (
  ModelAndTokenizer,
  make_inputs,
  decode_tokens,
)
from patchscopes_utils import *

In [2]:
# loading the model
model_to_hook = {
    "EleutherAI/pythia-12b": set_hs_patch_hooks_neox,
    "meta-llama/Llama-2-13b-chat-hf": set_hs_patch_hooks_llama,
    "./stable-vicuna-13b": set_hs_patch_hooks_llama,
    "EleutherAI/gpt-j-6b": set_hs_patch_hooks_gptj
}

CURRENT_LLM = "meta-llama/Llama-2-13b-chat-hf"

model_name = CURRENT_LLM

if "13b" in model_name or "12b" in model_name:
    torch_dtype = torch.float16
else:
    torch_dtype = None

my_device = torch.device("cuda:0")

mt = ModelAndTokenizer(
    model_name,
    low_cpu_mem_usage=False,
    torch_dtype=torch_dtype,
    device=my_device,
)

mt.set_hs_patch_hooks = model_to_hook[model_name]
mt.model.eval()

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 5120)
    (layers): ModuleList(
      (0-39): 40 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (k_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (v_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (up_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (down_proj): Linear(in_features=13824, out_features=5120, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_

In [3]:
def generate_and_extract_hidden_states(prompt):
    """
    Running the model and generating intermediate representations in each layer for the hidden states.
    """
    input_ids = make_inputs(mt.tokenizer, [prompt], device=mt.device)
    
    generated = mt.model(**input_ids, output_hidden_states=True)

    hs_cache_ = [
        generated["hidden_states"][layer + 1][0] for layer in range(mt.num_layers)
    ]

    return hs_cache_

def patchscope_interpret(vec, target_layer=0):
    """Interpretation of vectors using Patchscopes technique."""
    target_prompt = "Syria: Country in the Middle East, Leonardo DiCaprio: American actor, Samsung: South Korean multinational major appliance and consumer electronics corporation, x"
    
    # last token within target prompt
    target_idx = -1

    patch_config = {
        target_layer: [(target_idx, vec)]
    }

    patch_hooks = mt.set_hs_patch_hooks(
        mt.model, patch_config, module="hs", patch_input=False, generation_mode=True,
    )

    inp = make_inputs(mt.tokenizer, [target_prompt], device=mt.device)

    seq_len = len(inp["input_ids"][0])
    max_token_to_produce = 10
    output_toks = mt.model.generate(
        inp["input_ids"],
        max_length=seq_len + max_token_to_produce,
        pad_token_id=mt.model.generation_config.eos_token_id,
    )
    
    remove_hooks(patch_hooks)

    generations_patched =  mt.tokenizer.decode(output_toks[0][len(inp["input_ids"][0]):])
    
    return generations_patched

def find_token_id(prompt, token):
    """finding the offset of a specific token within a prompt."""
    inp = make_inputs(mt.tokenizer, [prompt], device=mt.device)
    decoded = decode_tokens(mt.tokenizer, inp['input_ids'])[0]

    return decoded.index(token)


# Hidden States Interpretation Experiment

This experiment comes to show how some hidden states become interpretable when they are amplified.
For each example in this experiment we are going to show the following:
1. Locate the first layer to become interpretable as the meaning of the sentence.
2. Use Superscopes amplification to interpret a layer prior.

We show that this methodology works in a lot of scenarios.

In [10]:
# test number 1 - "Alexander the Great" - "Great" token hidden state interpretation
hs_cache = generate_and_extract_hidden_states("Alexander the Great")

source_position = find_token_id("Alexander the Great", "Great")

for layer in range(4, 8):
    print()
    print(f"Layer {layer}")
    print()
    print(f"Hidden State result: {patchscope_interpret(hs_cache[layer][source_position])}")
    print()


Layer 4

Hidden State result: Barack Obama: 44th President


Layer 5

Hidden State result: Barack Obama: American politician, 4


Layer 6

Hidden State result: : Ancient Greek king of Macedon,


Layer 7

Hidden State result: Wall Street: Street in Lower Manhattan, New



### Hidden States results explanation
As one can simply see, hidden state of "Great" is contextualized as Alexander the Great in Layer 6.

Next we will try to amplify layers beforehand and obtain their meaning.

In [16]:
for layer in range(4, 6):
    print()
    print(f"Layer {layer}")
    for amp in range(2, 7, 1):
        print()
        print(f"Hidden State result (amp={amp}): {patchscope_interpret(hs_cache[layer][source_position] * amp)}")
        print()


Layer 4

Hidden State result (amp=2): Barack Obama: 44th President


Hidden State result (amp=3): : Ancient Greek king, The Great Wall of


Hidden State result (amp=4): Wall Street: Street in New York City's


Hidden State result (amp=5): Wall Street: Street in Lower Manhattan, New


Hidden State result (amp=6): Wall Street: Financial district in Lower Manh


Layer 5

Hidden State result (amp=2): Wall Street: Financial district in Lower Manh


Hidden State result (amp=3): : The Macedonian king who conquered a


Hidden State result (amp=4): : Which of these is NOT a country?



Hidden State result (amp=5): : These are just a few examples of what?


Hidden State result (amp=6): : What do these three things have in common?



### Hidden States Amplification - Results Analysis
We successfully interpreted the meaning "Alexander the Great" from layers 4 and 5, obtaining the contextualized meaning from 2 layers prior.

In [42]:
# test number 3 - "Red Hot Chili Peppers" - "ppers" token hidden state interpretation
hs_cache = generate_and_extract_hidden_states("Red Hot Chili Peppers")

source_position = find_token_id("Red Hot Chili Peppers", "ppers")

for layer in range(2, 5):
    print()
    print(f"Layer {layer}")
    print()
    print(f"Hidden State result: {patchscope_interpret(hs_cache[layer][source_position])}")
    print()


Layer 2

Hidden State result: : type of vegetable, Yellowstone National


Layer 3

Hidden State result: : A type of pepper, Gmail:


Layer 4

Hidden State result: : American rock band, Tesla: American



### Hidden States results explanation
The hidden state of "ppers" is contextualized as the American rock band starting from Layer 4.

Next we will try to amplify the layer beforehand and obtain its meaning.

In [48]:
for layer in range(3, 4):
    print()
    print(f"Layer {layer}")
    for amp in [1.3, 1.5, 1.7, 1.9]:
        print()
        print(f"Hidden State result (amp={amp}): {patchscope_interpret(hs_cache[layer][source_position] * amp)}")
        print()


Layer 3

Hidden State result (amp=1.3): : A type of vegetable, Dracula


Hidden State result (amp=1.5): : The band, Titanic: The ship


Hidden State result (amp=1.7): : Plant of the capsicum genus, Ebol


Hidden State result (amp=1.9): : What do these three things have in common?



In [49]:
for layer in range(3, 4):
    print()
    print(f"Layer {layer}")
    for amp in [1.3, 1.5, 1.7]:
        print()
        print(f"Hidden State result (amp={amp}): {patchscope_interpret(hs_cache[layer][source_position] * amp)}")
        print()


Layer 3

Hidden State result (amp=1.3): : Type of vegetable, The Beatles:


Hidden State result (amp=1.5): : Rock band, and more.

S


Hidden State result (amp=1.7): : They all have something in common. What is



### Hidden States Amplification - Results Analysis
We successfully interpreted the meaning "Rock band" from layer 3, obtaining the contextualized meaning.

In [9]:
# test number 2 - "Florence and the Machine" - "Machine" token hidden state interpretation
hs_cache = generate_and_extract_hidden_states("Florence and the Machine")

source_position = find_token_id("Florence and the Machine", "Machine")

for layer in range(9, 12):
    print()
    print(f"Layer {layer}")
    print()
    print(f"Hidden State result: {patchscope_interpret(hs_cache[layer][source_position])}")
    print()


Layer 9

Hidden State result: :

Syria: A country in


Layer 10

Hidden State result: : British rock band, Honey Boo Bo


Layer 11

Hidden State result: : These are just a few examples of the many



### Hidden States results explanation
The hidden state of "Machine" is contextualized as a British band starting from Layer 10.

Next we will try to amplify the layers beforehand and obtain their meaning.

In [12]:
for layer in range(9, 10):
    print()
    print(f"Layer {layer}")
    for amp in range(3, 24, 3):
        print()
        print(f"Hidden State result (amp={amp}): {patchscope_interpret(hs_cache[layer][source_position] * amp)}")
        print()


Layer 9

Hidden State result (amp=3): : These are just a few examples of words that


Hidden State result (amp=6): .

Syria is a country in


Hidden State result (amp=9): ’s new album, and more
Good morning


Hidden State result (amp=12): is a British rock band, and The Reven


Hidden State result (amp=15): , and The Great Gatsby: Novel


Hidden State result (amp=18): to name a few.

The word "


Hidden State result (amp=21): 's new album "Hair, Hunt



### Hidden States Amplification - Results Analysis
We successfully interpreted the meaning "British rock band" from layer 9, obtaining the contextualized meaning.

In [13]:
for layer in range(7, 9):
    print()
    print(f"Layer {layer}")
    for amp in range(3, 24, 3):
        print()
        print(f"Hidden State result (amp={amp}): {patchscope_interpret(hs_cache[layer][source_position] * amp)}")
        print()


Layer 7

Hidden State result (amp=3): :

Syria: A country in


Hidden State result (amp=6): : These are just a few examples of things that


Hidden State result (amp=9): toilet paper: Product used for cleaning


Hidden State result (amp=12): .

In the following list, the third


Hidden State result (amp=15): song "Someone Like You" by Ade


Hidden State result (amp=18): to name a few.

These are


Hidden State result (amp=21): ___: British singer-songwriter, and


Layer 8

Hidden State result (amp=3): : These are just a few examples of the many


Hidden State result (amp=6): : These three things may not seem related, but


Hidden State result (amp=9): , Pink Floyd: English rock band,


Hidden State result (amp=12): to name a few.

Answer:



Hidden State result (amp=15): , and Fear and Loathing in Las


Hidden State result (amp=18): , and the FIFA World Cup: International soccer


Hidden State result (amp=21): 's new album "Everything Now" is



### Hidden States Amplification - Results Analysis
When looking at layers prior to 9, we can see music and British culture contextualization which was not obtained before hand.

We even see the token at layer 7 interpreted as a "British singer-songwriter" when multiplied by a big multiplier, which is almost the correct meaning of the entire sentence.