# Unpacking Function Vectors

**Jonathan Keane | November 16, 2025**

To understand more of the nuances behind what function vectors are and how they are computed, I wanted to do some extra logging and documenting of how this works in practice before proceeding on to more experiments.

In [5]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [6]:
import os, re, json
import torch, numpy as np

import sys
sys.path.append('..')
torch.set_grad_enabled(False)

from src.utils.extract_utils import get_mean_head_activations, compute_universal_function_vector
from src.utils.intervention_utils import fv_intervention_natural_text, function_vector_intervention
from src.utils.model_utils import load_gpt_model_and_tokenizer
from src.utils.prompt_utils import load_dataset, word_pairs_to_prompt_data, create_prompt
from src.utils.eval_utils import decode_to_vocab, sentence_eval

## Load model & tokenizer

In [7]:
model_name = 'EleutherAI/gpt-j-6b'
model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name)
EDIT_LAYER = 9

Loading:  EleutherAI/gpt-j-6b


Some weights of the model checkpoint at EleutherAI/gpt-j-6b were not used when initializing GPTJForCausalLM: ['transformer.h.0.attn.bias', 'transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.10.attn.bias', 'transformer.h.10.attn.masked_bias', 'transformer.h.11.attn.bias', 'transformer.h.11.attn.masked_bias', 'transformer.h.12.attn.bias', 'transformer.h.12.attn.masked_bias', 'transformer.h.13.attn.bias', 'transformer.h.13.attn.masked_bias', 'transformer.h.14.attn.bias', 'transformer.h.14.attn.masked_bias', 'transformer.h.15.attn.bias', 'transformer.h.15.attn.masked_bias', 'transformer.h.16.attn.bias', 'transformer.h.16.attn.masked_bias', 'transformer.h.17.attn.bias', 'transformer.h.17.attn.masked_bias', 'transformer.h.18.attn.bias', 'transformer.h.18.attn.masked_bias', 'transformer.h.19.attn.bias', 'transformer.h.19.attn.masked_bias', 'transformer.h.2.attn.bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.20.attn.bi

## Dataset

For a series of different tasks, there are just pairs of input text and their corresponding output text. In the example below, we are working with antonym pairs.

In [9]:
dataset = load_dataset('antonym', seed=0)
print(json.dumps(dataset['train'][:4], indent=2))

{
  "input": [
    "limitless",
    "wake",
    "elevate",
    "push"
  ],
  "output": [
    "limited",
    "sleep",
    "depress",
    "pull"
  ]
}


# Compute Task-Conditioned Mean Activations

To create a function vector, you need to have activations of tokens from places where in-context learning (ICL) was occurring across ***all*** attention heads in the transformer (although only some will be used in practice). To do this, for a series of trials, 

In [None]:
mean_activations = get_mean_head_activations(dataset, model, model_config, tokenizer, n_icl_examples=10, N_TRIALS=100)

In [None]:
from src.utils.prompt_utils import get_dummy_token_labels
from src.utils.extract_utils import gather_attn_activations
from IPython.display import display, Markdown

def get_mean_head_activations_unpacked(dataset, model, model_config, tokenizer, n_icl_examples = 10, N_TRIALS = 10, shuffle_labels=False, prefixes=None, separators=None, filter_set=None):
    def split_activations_by_head(activations, model_config):
        new_shape = activations.size()[:-1] + (model_config['n_heads'], model_config['resid_dim']//model_config['n_heads']) # split by head: + (n_attn_heads, hidden_size/n_attn_heads)
        activations = activations.view(*new_shape)  # (batch_size, n_tokens, n_heads, head_hidden_dim)
        return activations

    n_test_examples = 1
    if prefixes is not None and separators is not None:
        dummy_labels = get_dummy_token_labels(n_icl_examples, tokenizer=tokenizer, prefixes=prefixes, separators=separators, model_config=model_config)
    else:
        dummy_labels = get_dummy_token_labels(n_icl_examples, tokenizer=tokenizer, model_config=model_config)
    print(len(dummy_labels))
    activation_storage = torch.zeros(N_TRIALS, model_config['n_layers'], model_config['n_heads'], len(dummy_labels), model_config['resid_dim']//model_config['n_heads'])

    if filter_set is None:
        filter_set = np.arange(len(dataset['valid']))

    # If the model already prepends a bos token by default, we don't want to add one
    prepend_bos =  False if model_config['prepend_bos'] else True

    word_pair_info = []
    for n in range(N_TRIALS):
        word_pairs = dataset['train'][np.random.choice(len(dataset['train']),n_icl_examples, replace=False)]
        word_pairs_test = dataset['valid'][np.random.choice(filter_set,n_test_examples, replace=False)]
        if prefixes is not None and separators is not None:
            prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=prepend_bos, 
                                                    shuffle_labels=shuffle_labels, prefixes=prefixes, separators=separators)
        else:
            prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=prepend_bos, shuffle_labels=shuffle_labels)
        word_pair_info.append({'word_pairs': word_pairs, 'word_pairs_test': word_pairs_test})
        # print(prompt_data)
        activations_td,idx_map,idx_avg = gather_attn_activations(prompt_data=prompt_data, 
                                                            layers = model_config['attn_hook_names'], 
                                                            dummy_labels=dummy_labels, 
                                                            model=model, 
                                                            tokenizer=tokenizer, 
                                                            model_config=model_config)
        # print(activations_td)
        # print(idx_map)
        # print(idx_avg)
        # print(activations_td[layer].input.shape for layer in model_config['attn_hook_names'])
        stack_initial = torch.vstack([split_activations_by_head(activations_td[layer].input, model_config) for layer in model_config['attn_hook_names']]).permute(0,2,1,3)
        stack_filtered = stack_initial[:,:,list(idx_map.keys())]
        for (i,j) in idx_avg.values():
            stack_filtered[:,:,idx_map[i]] = stack_initial[:,:,i:j+1].mean(axis=2) # Average activations of multi-token words across all its tokens
        
        activation_storage[n] = stack_filtered

    print(activation_storage.shape)
    mean_activations = activation_storage.mean(dim=0) # averaging over trials (not tokens => 97 represents n_tokens)
    print(mean_activations.shape)
    return mean_activations, word_pair_info

mean_activations, word_pair_info = get_mean_head_activations_unpacked(dataset, model, model_config, tokenizer)

def tabulate(data, headers):
    table = "| " + " | ".join(headers) + " |\n"
    table += "| " + " | ".join(["---"] * len(headers)) + " |\n"
    for row in data[1:]:
        table += "| " + " | ".join(row) + " |\n"
    return table

n_samples_display = 2
table_data_headers = []
table_data = []
for sample_idx in range(n_samples_display):
    sample_data = word_pair_info[sample_idx]
    sample_table_data = [
        [
            sample_data["word_pairs"]["input"][i], 
            sample_data["word_pairs"]["output"][i]
        ] 
        for i in range(len(sample_data["word_pairs"]["input"]))
    ]
    sample_table_data.append([
        f'**Query**: {sample_data["word_pairs_test"]["input"][0]}', 
        f'**Target**: {sample_data["word_pairs_test"]["output"][0]}',
    ])
    if len(table_data) == 0:
        table_data = sample_table_data
    else:
        for i in range(len(sample_table_data)):
            table_data[i].extend(sample_table_data[i])
    table_data_headers.extend([f"Trial {sample_idx+1}: Input", f"Trial {sample_idx+1}: Output"])
    
display(Markdown("### In-Context Examples"))
display(Markdown(tabulate(table_data, table_data_headers)))

97
torch.Size([10, 28, 16, 97, 256])
torch.Size([28, 16, 97, 256])


### In-Context Examples

| Trial 1: Input | Trial 1: Output | Trial 2: Input | Trial 2: Output |
| --- | --- | --- | --- |
| unusual | usual | relational | isolated |
| soil | sky | worried | relaxed |
| sturdy | fragile | artificial | natural |
| mainland | island | unauthorized | authorized |
| able | unable | wee | large |
| daylight | nighttime | civil | uncivilized |
| glad | sad | capable | incapable |
| resent | cherish | federal | state |
| pleased | displeased | integration | differentiation |
| **Query**: transmitter | **Target**: receiver | **Query**: regain | **Target**: lose |


## Compute function vector (FV)

In [6]:
FV, top_heads = compute_universal_function_vector(mean_activations, model, model_config, n_top_heads=10)

## Prompt Creation - ICL, Shuffled-Label, Zero-Shot, and Natural Text

In [7]:
# Sample ICL example pairs, and a test word
dataset = load_dataset('antonym')
word_pairs = dataset['train'][:5]
test_pair = dataset['test'][21]

prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=test_pair, prepend_bos_token=True)
sentence = create_prompt(prompt_data)
print("ICL prompt:\n", repr(sentence), '\n\n')

shuffled_prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=test_pair, prepend_bos_token=True, shuffle_labels=True)
shuffled_sentence = create_prompt(shuffled_prompt_data)
print("Shuffled ICL Prompt:\n", repr(shuffled_sentence), '\n\n')

zeroshot_prompt_data = word_pairs_to_prompt_data({'input':[], 'output':[]}, query_target_pair=test_pair, prepend_bos_token=True, shuffle_labels=True)
zeroshot_sentence = create_prompt(zeroshot_prompt_data)
print("Zero-Shot Prompt:\n", repr(zeroshot_sentence))

ICL prompt:
 '<|endoftext|>Q: hardware\nA: software\n\nQ: fascism\nA: democracy\n\nQ: incompatible\nA: compatible\n\nQ: illness\nA: health\n\nQ: notice\nA: ignore\n\nQ: increase\nA:' 


Shuffled ICL Prompt:
 '<|endoftext|>Q: hardware\nA: health\n\nQ: fascism\nA: ignore\n\nQ: incompatible\nA: compatible\n\nQ: illness\nA: software\n\nQ: notice\nA: democracy\n\nQ: increase\nA:' 


Zero-Shot Prompt:
 '<|endoftext|>Q: increase\nA:'


## Evaluation

### Clean ICL Prompt

In [8]:
print(sentence)

<|endoftext|>Q: hardware
A: software

Q: fascism
A: democracy

Q: incompatible
A: compatible

Q: illness
A: health

Q: notice
A: ignore

Q: increase
A:


In [9]:
# Check model's ICL answer
clean_logits = sentence_eval(sentence, [test_pair['output']], model, tokenizer, compute_nll=False)

print("Input Sentence:", repr(sentence), '\n')
print(f"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\n")
print("ICL Prompt Top K Vocab Probs:\n", decode_to_vocab(clean_logits, tokenizer, k=5), '\n')

Input Sentence: '<|endoftext|>Q: hardware\nA: software\n\nQ: fascism\nA: democracy\n\nQ: incompatible\nA: compatible\n\nQ: illness\nA: health\n\nQ: notice\nA: ignore\n\nQ: increase\nA:' 

Input Query: 'increase', Target: 'decrease'

ICL Prompt Top K Vocab Probs:
 [(' decrease', 0.73675), (' reduce', 0.07769), (' increase', 0.03435), (' decline', 0.01574), (' decreased', 0.01037)] 



### Corrupted ICL Prompt

In [10]:
print(shuffled_sentence)

<|endoftext|>Q: hardware
A: health

Q: fascism
A: ignore

Q: incompatible
A: compatible

Q: illness
A: software

Q: notice
A: democracy

Q: increase
A:


In [11]:
# Perform an intervention on the shuffled setting
clean_logits, interv_logits = function_vector_intervention(shuffled_sentence, [test_pair['output']], EDIT_LAYER, FV, model, model_config, tokenizer)

print("Input Sentence:", repr(shuffled_sentence), '\n')
print(f"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\n")
print("Few-Shot-Shuffled Prompt Top K Vocab Probs:\n", decode_to_vocab(clean_logits, tokenizer, k=5), '\n')
print("Shuffled Prompt+FV Top K Vocab Probs:\n", decode_to_vocab(interv_logits, tokenizer, k=5))

Input Sentence: '<|endoftext|>Q: hardware\nA: health\n\nQ: fascism\nA: ignore\n\nQ: incompatible\nA: compatible\n\nQ: illness\nA: software\n\nQ: notice\nA: democracy\n\nQ: increase\nA:' 

Input Query: 'increase', Target: 'decrease'

Few-Shot-Shuffled Prompt Top K Vocab Probs:
 [(' decrease', 0.26971), (' reduce', 0.02771), (' increase', 0.02549), (' decline', 0.01609), (' growth', 0.00921)] 

Shuffled Prompt+FV Top K Vocab Probs:
 [(' decrease', 0.75362), (' reduce', 0.03991), (' decline', 0.02322), (' increase', 0.00935), (' reduction', 0.00837)]


### Zero-Shot Prompt

In [12]:
# Intervention on the zero-shot prompt
clean_logits, interv_logits = function_vector_intervention(zeroshot_sentence, [test_pair['output']], EDIT_LAYER, FV, model, model_config, tokenizer)

print("Input Sentence:", repr(zeroshot_sentence), '\n')
print(f"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\n")
print("Zero-Shot Top K Vocab Probs:\n", decode_to_vocab(clean_logits, tokenizer, k=5), '\n')
print("Zero-Shot+FV Vocab Top K Vocab Probs:\n", decode_to_vocab(interv_logits, tokenizer, k=5))

Input Sentence: '<|endoftext|>Q: increase\nA:' 

Input Query: 'increase', Target: 'decrease'

Zero-Shot Top K Vocab Probs:
 [(' increase', 0.14925), (' yes', 0.02272), (' I', 0.02189), (' the', 0.0212), (' 1', 0.01418)] 

Zero-Shot+FV Vocab Top K Vocab Probs:
 [(' decrease', 0.25627), (' increase', 0.1799), (' reduce', 0.03497), (' improve', 0.00987), ('\n', 0.00582)]


### Natural Text Prompt

In [13]:
sentence = f"The word \"{test_pair['input']}\" means"
co, io = fv_intervention_natural_text(sentence, EDIT_LAYER, FV, model, model_config, tokenizer, max_new_tokens=10)


print("Input Sentence: ", repr(sentence))
print("GPT-J:" , repr(tokenizer.decode(co.squeeze())))
print("GPT-J+FV:", repr(tokenizer.decode(io.squeeze())), '\n')

Input Sentence:  'The word "increase" means'
GPT-J: 'The word "increase" means "to make larger, to enlarge, to expand'
GPT-J+FV: 'The word "increase" means "decrease" in the Bible.\n' 



### Jonny Additions

### Zero-Shot Application to New Q/A Pair

This is cool and all that it remembers to say decrease for this example, but is this an actual "antonym" function vector (this is how I had originally interpreted the paper, but the examples don't show this off). I test with a couple pairs below:

In [14]:
new_test_pair = {'input': 'happy', 'output': 'sad'}
sentence = f"The word \"{new_test_pair['input']}\" means"
co, io = fv_intervention_natural_text(sentence, EDIT_LAYER, FV, model, model_config, tokenizer, max_new_tokens=10)

print("Input Sentence: ", repr(sentence))
print("GPT-J:" , repr(tokenizer.decode(co.squeeze())))
print("GPT-J+FV:", repr(tokenizer.decode(io.squeeze())), '\n')

Input Sentence:  'The word "happy" means'
GPT-J: 'The word "happy" means different things to different people. For some, it'
GPT-J+FV: 'The word "happy" means "sad" to me.\n\nI' 



In [15]:
new_test_pair = {'input': 'yin', 'output': 'yang'}
sentence = f"The word \"{new_test_pair['input']}\" means"
co, io = fv_intervention_natural_text(sentence, EDIT_LAYER, FV, model, model_config, tokenizer, max_new_tokens=10)

print("Input Sentence: ", repr(sentence))
print("GPT-J:" , repr(tokenizer.decode(co.squeeze())))
print("GPT-J+FV:", repr(tokenizer.decode(io.squeeze())), '\n')

Input Sentence:  'The word "yin" means'
GPT-J: 'The word "yin" means "darkness" or "black" in Chinese'
GPT-J+FV: 'The word "yin" means "yang." The word "yin" means' 



### Questions

1. What does the embedding of the original token look like if it does not have as much in-context learning effecting it?

    - If you increase/decrease the number of samples, is there a convergence to a single vector that represents opposites? (I feel like the embedding should start pretty spread out if you ablate certain samples that are part of the context and then converge as you get to a more complete set)

2. Is it the embedding of the colon that is changing or is it the tokens around it that are changing in terms of their position?

In [16]:
# Check model's ICL answer
clean_logits = sentence_eval(sentence, [test_pair['output']], model, tokenizer, compute_nll=False)
clean_logits.shape

torch.Size([1, 50400])

In [17]:
print(model)

GPTJForCausalLM(
  (transformer): GPTJModel(
    (wte): Embedding(50400, 4096)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-27): 28 x GPTJBlock(
        (ln_1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (attn): GPTJAttention(
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (out_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): GPTJMLP(
          (fc_in): Linear(in_features=4096, out_features=16384, bias=True)
          (fc_out): Linear(in_features=16384, out_features=4096, bias=True)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f)

In [18]:
# Perform an intervention on the shuffled setting
clean_logits, interv_logits = function_vector_intervention(shuffled_sentence, [test_pair['output']], EDIT_LAYER, FV, model, model_config, tokenizer)

print(f"function_vector_intervention Inputs:\n"
      f"\tsentence='str:len={len(shuffled_sentence)}',\n"
      f"\target_outputs={[test_pair['output']]},\n"
      f"\tEDIT_LAYER={EDIT_LAYER},\n"
      f"\tFV={FV.shape},\n"
      f"\tmodel=model,\n"
      f"\tmodel_config=model_config,\n"
      f"\ttokenizer=tokenizer")
print(f"Outputs: clean_logits={clean_logits.shape}, interv_logits={interv_logits.shape}")

function_vector_intervention Inputs:
	sentence='str:len=151',
	arget_outputs=['decrease'],
	EDIT_LAYER=9,
	FV=torch.Size([1, 4096]),
	model=model,
	model_config=model_config,
	tokenizer=tokenizer
Outputs: clean_logits=torch.Size([1, 50400]), interv_logits=torch.Size([1, 50400])


In [72]:
FV, top_heads = compute_universal_function_vector(mean_activations, model, model_config, n_top_heads=10)
print("Computed FV shape:", FV.shape)
print("Top Heads:", top_heads)

Computed FV shape: torch.Size([1, 4096])
Top Heads: [(15, 5, 0.0587), (9, 14, 0.0584), (12, 10, 0.0526), (8, 1, 0.0445), (11, 0, 0.0445), (13, 13, 0.019), (8, 0, 0.0184), (14, 9, 0.016), (9, 2, 0.0127), (24, 6, 0.0113)]


In [None]:
print(mean_activations.shape) # (n_layers, n_heads, n_tokens, head_hidden_dim)

torch.Size([28, 16, 97, 256])


In [37]:
print(model_config["n_heads"], "*", 256, "=", model_config["n_heads"] * 256)

16 * 256 = 4096


In [None]:
dataset = load_dataset('antonym', seed=0)
mean_activations = get_mean_head_activations(dataset, model, model_config, tokenizer)

In [None]:
from src.utils.prompt_utils import get_dummy_token_labels
from src.utils.extract_utils import gather_attn_activations

def get_mean_head_activations_unpacked(dataset, model, model_config, tokenizer, n_icl_examples = 10, N_TRIALS = 1, shuffle_labels=False, prefixes=None, separators=None, filter_set=None):
    def split_activations_by_head(activations, model_config):
        new_shape = activations.size()[:-1] + (model_config['n_heads'], model_config['resid_dim']//model_config['n_heads']) # split by head: + (n_attn_heads, hidden_size/n_attn_heads)
        activations = activations.view(*new_shape)  # (batch_size, n_tokens, n_heads, head_hidden_dim)
        return activations

    n_test_examples = 1
    if prefixes is not None and separators is not None:
        dummy_labels = get_dummy_token_labels(n_icl_examples, tokenizer=tokenizer, prefixes=prefixes, separators=separators, model_config=model_config)
    else:
        dummy_labels = get_dummy_token_labels(n_icl_examples, tokenizer=tokenizer, model_config=model_config)
    print(len(dummy_labels))
    activation_storage = torch.zeros(N_TRIALS, model_config['n_layers'], model_config['n_heads'], len(dummy_labels), model_config['resid_dim']//model_config['n_heads'])

    if filter_set is None:
        filter_set = np.arange(len(dataset['valid']))

    # If the model already prepends a bos token by default, we don't want to add one
    prepend_bos =  False if model_config['prepend_bos'] else True

    all_prompt_data = []
    for n in range(N_TRIALS):
        word_pairs = dataset['train'][np.random.choice(len(dataset['train']),n_icl_examples, replace=False)]
        word_pairs_test = dataset['valid'][np.random.choice(filter_set,n_test_examples, replace=False)]
        if prefixes is not None and separators is not None:
            prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=prepend_bos, 
                                                    shuffle_labels=shuffle_labels, prefixes=prefixes, separators=separators)
        else:
            prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=word_pairs_test, prepend_bos_token=prepend_bos, shuffle_labels=shuffle_labels)
        all_prompt_data.append(prompt_data)
        # print(prompt_data)
        activations_td,idx_map,idx_avg = gather_attn_activations(prompt_data=prompt_data, 
                                                            layers = model_config['attn_hook_names'], 
                                                            dummy_labels=dummy_labels, 
                                                            model=model, 
                                                            tokenizer=tokenizer, 
                                                            model_config=model_config)
        print(activations_td)
        print(idx_map)
        print(idx_avg)
        # print(activations_td[layer].input.shape for layer in model_config['attn_hook_names'])
        stack_initial = torch.vstack([split_activations_by_head(activations_td[layer].input, model_config) for layer in model_config['attn_hook_names']]).permute(0,2,1,3)
        stack_filtered = stack_initial[:,:,list(idx_map.keys())]
        for (i,j) in idx_avg.values():
            stack_filtered[:,:,idx_map[i]] = stack_initial[:,:,i:j+1].mean(axis=2) # Average activations of multi-token words across all its tokens
        
        activation_storage[n] = stack_filtered

    print(activation_storage.shape)
    mean_activations = activation_storage.mean(dim=0) # averaging over trials (not tokens => 97 represents n_tokens)
    print(mean_activations.shape)
    return mean_activations, all_prompt_data

mean_activations, all_prompt_data = get_mean_head_activations_unpacked(dataset, model, model_config, tokenizer)

97
TraceDict([('transformer.h.0.attn.out_proj', <baukit.nethook.Trace object at 0x73baf63c6410>), ('transformer.h.1.attn.out_proj', <baukit.nethook.Trace object at 0x73baf63c6800>), ('transformer.h.2.attn.out_proj', <baukit.nethook.Trace object at 0x73baf63c6290>), ('transformer.h.3.attn.out_proj', <baukit.nethook.Trace object at 0x73baf63c6260>), ('transformer.h.4.attn.out_proj', <baukit.nethook.Trace object at 0x73baf63c4370>), ('transformer.h.5.attn.out_proj', <baukit.nethook.Trace object at 0x73baf63c5c30>), ('transformer.h.6.attn.out_proj', <baukit.nethook.Trace object at 0x73baf7772e90>), ('transformer.h.7.attn.out_proj', <baukit.nethook.Trace object at 0x73baf6425270>), ('transformer.h.8.attn.out_proj', <baukit.nethook.Trace object at 0x73baf64255a0>), ('transformer.h.9.attn.out_proj', <baukit.nethook.Trace object at 0x73baf6425bd0>), ('transformer.h.10.attn.out_proj', <baukit.nethook.Trace object at 0x73baf6424e80>), ('transformer.h.11.attn.out_proj', <baukit.nethook.Trace obje

In [62]:
from src.utils.prompt_utils import get_token_meta_labels

query = all_prompt_data[6]['query_target']['input']
token_labels, prompt_string = get_token_meta_labels(prompt_data, tokenizer, query, prepend_bos=model_config['prepend_bos'])
print(prompt_string)

<|endoftext|>Q: hardware
A: software

Q: fascism
A: democracy

Q: incompatible
A: compatible

Q: illness
A: health

Q: notice
A: ignore

Q: lavish
A:


In [69]:
print(model_config["name_or_path"])
print(dataset["train"])
print(json.dumps(dataset["train"][:5], indent=2))

EleutherAI/gpt-j-6b
ICLDataset({
	features: ['input', 'output'],
	num_rows: 1678
})
{
  "input": [
    "hardware",
    "fascism",
    "incompatible",
    "illness",
    "notice"
  ],
  "output": [
    "software",
    "democracy",
    "compatible",
    "health",
    "ignore"
  ]
}
