In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os, re, json
import torch, numpy as np

import sys
sys.path.append('..')
torch.set_grad_enabled(False)

from src.utils.extract_utils import get_mean_head_activations, compute_universal_function_vector
from src.utils.intervention_utils import fv_intervention_natural_text, function_vector_intervention
from src.utils.model_utils import load_gpt_model_and_tokenizer
from src.utils.prompt_utils import load_dataset, word_pairs_to_prompt_data, create_prompt
from src.utils.eval_utils import decode_to_vocab, sentence_eval

## Load model & tokenizer

In [3]:
model_name = 'EleutherAI/gpt-j-6b'
model, tokenizer, model_config = load_gpt_model_and_tokenizer(model_name)
EDIT_LAYER = 9

Loading:  EleutherAI/gpt-j-6b


pytorch_model.bin:   0%|          | 0.00/24.2G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/24.2G [00:00<?, ?B/s]

Some weights of the model checkpoint at EleutherAI/gpt-j-6b were not used when initializing GPTJForCausalLM: ['transformer.h.0.attn.bias', 'transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.10.attn.bias', 'transformer.h.10.attn.masked_bias', 'transformer.h.11.attn.bias', 'transformer.h.11.attn.masked_bias', 'transformer.h.12.attn.bias', 'transformer.h.12.attn.masked_bias', 'transformer.h.13.attn.bias', 'transformer.h.13.attn.masked_bias', 'transformer.h.14.attn.bias', 'transformer.h.14.attn.masked_bias', 'transformer.h.15.attn.bias', 'transformer.h.15.attn.masked_bias', 'transformer.h.16.attn.bias', 'transformer.h.16.attn.masked_bias', 'transformer.h.17.attn.bias', 'transformer.h.17.attn.masked_bias', 'transformer.h.18.attn.bias', 'transformer.h.18.attn.masked_bias', 'transformer.h.19.attn.bias', 'transformer.h.19.attn.masked_bias', 'transformer.h.2.attn.bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.20.attn.bi

## Load dataset and Compute task-conditioned mean activations

In [4]:
dataset = load_dataset('antonym', seed=0)
mean_activations = get_mean_head_activations(dataset, model, model_config, tokenizer)

## Compute function vector (FV)

In [5]:
FV, top_heads = compute_universal_function_vector(mean_activations, model, model_config, n_top_heads=10)

## Prompt Creation - ICL, Shuffled-Label, Zero-Shot, and Natural Text

In [12]:
# Sample ICL example pairs, and a test word
dataset = load_dataset('antonym')
word_pairs = dataset['train'][:5]
test_pair = dataset['test'][21]

prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=test_pair, prepend_bos_token=True)
sentence = create_prompt(prompt_data)
print("ICL prompt:\n", repr(sentence), '\n\n')

shuffled_prompt_data = word_pairs_to_prompt_data(word_pairs, query_target_pair=test_pair, prepend_bos_token=True, shuffle_labels=True)
shuffled_sentence = create_prompt(shuffled_prompt_data)
print("Shuffled ICL Prompt:\n", repr(shuffled_sentence), '\n\n')

zeroshot_prompt_data = word_pairs_to_prompt_data({'input':[], 'output':[]}, query_target_pair=test_pair, prepend_bos_token=True, shuffle_labels=True)
zeroshot_sentence = create_prompt(zeroshot_prompt_data)
print("Zero-Shot Prompt:\n", repr(zeroshot_sentence))

ICL prompt:
 '<|endoftext|>Q: hardware\nA: software\n\nQ: fascism\nA: democracy\n\nQ: incompatible\nA: compatible\n\nQ: illness\nA: health\n\nQ: notice\nA: ignore\n\nQ: increase\nA:' 


Shuffled ICL Prompt:
 '<|endoftext|>Q: hardware\nA: democracy\n\nQ: fascism\nA: health\n\nQ: incompatible\nA: compatible\n\nQ: illness\nA: ignore\n\nQ: notice\nA: software\n\nQ: increase\nA:' 


Zero-Shot Prompt:
 '<|endoftext|>Q: increase\nA:'


## Evaluation

### Clean ICL Prompt

In [14]:
print(sentence)

<|endoftext|>Q: hardware
A: software

Q: fascism
A: democracy

Q: incompatible
A: compatible

Q: illness
A: health

Q: notice
A: ignore

Q: increase
A:


In [13]:
# Check model's ICL answer
clean_logits = sentence_eval(sentence, [test_pair['output']], model, tokenizer, compute_nll=False)

print("Input Sentence:", repr(sentence), '\n')
print(f"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\n")
print("ICL Prompt Top K Vocab Probs:\n", decode_to_vocab(clean_logits, tokenizer, k=5), '\n')

Input Sentence: '<|endoftext|>Q: hardware\nA: software\n\nQ: fascism\nA: democracy\n\nQ: incompatible\nA: compatible\n\nQ: illness\nA: health\n\nQ: notice\nA: ignore\n\nQ: increase\nA:' 

Input Query: 'increase', Target: 'decrease'

ICL Prompt Top K Vocab Probs:
 [(' decrease', 0.73675), (' reduce', 0.07769), (' increase', 0.03435), (' decline', 0.01574), (' decreased', 0.01037)] 



### Corrupted ICL Prompt

In [15]:
print(shuffled_sentence)

<|endoftext|>Q: hardware
A: democracy

Q: fascism
A: health

Q: incompatible
A: compatible

Q: illness
A: ignore

Q: notice
A: software

Q: increase
A:


In [8]:
# Perform an intervention on the shuffled setting
clean_logits, interv_logits = function_vector_intervention(shuffled_sentence, [test_pair['output']], EDIT_LAYER, FV, model, model_config, tokenizer)

print("Input Sentence:", repr(shuffled_sentence), '\n')
print(f"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\n")
print("Few-Shot-Shuffled Prompt Top K Vocab Probs:\n", decode_to_vocab(clean_logits, tokenizer, k=5), '\n')
print("Shuffled Prompt+FV Top K Vocab Probs:\n", decode_to_vocab(interv_logits, tokenizer, k=5))

Input Sentence: '<|endoftext|>Q: hardware\nA: ignore\n\nQ: fascism\nA: health\n\nQ: incompatible\nA: software\n\nQ: illness\nA: compatible\n\nQ: notice\nA: democracy\n\nQ: increase\nA:' 

Input Query: 'increase', Target: 'decrease'

Few-Shot-Shuffled Prompt Top K Vocab Probs:
 [(' decrease', 0.10689), (' increase', 0.01887), (' notice', 0.01787), (' reduce', 0.01295), (' democracy', 0.01078)] 

Shuffled Prompt+FV Top K Vocab Probs:
 [(' decrease', 0.70746), (' reduce', 0.04102), (' decline', 0.02279), (' increase', 0.00984), (' reduction', 0.00724)]


### Zero-Shot Prompt

In [9]:
# Intervention on the zero-shot prompt
clean_logits, interv_logits = function_vector_intervention(zeroshot_sentence, [test_pair['output']], EDIT_LAYER, FV, model, model_config, tokenizer)

print("Input Sentence:", repr(zeroshot_sentence), '\n')
print(f"Input Query: {repr(test_pair['input'])}, Target: {repr(test_pair['output'])}\n")
print("Zero-Shot Top K Vocab Probs:\n", decode_to_vocab(clean_logits, tokenizer, k=5), '\n')
print("Zero-Shot+FV Vocab Top K Vocab Probs:\n", decode_to_vocab(interv_logits, tokenizer, k=5))

Input Sentence: '<|endoftext|>Q: increase\nA:' 

Input Query: 'increase', Target: 'decrease'

Zero-Shot Top K Vocab Probs:
 [(' increase', 0.14925), (' yes', 0.02272), (' I', 0.02189), (' the', 0.0212), (' 1', 0.01418)] 

Zero-Shot+FV Vocab Top K Vocab Probs:
 [(' decrease', 0.27255), (' increase', 0.17896), (' reduce', 0.03427), (' improve', 0.00906), ('\n', 0.00579)]


### Natural Text Prompt

In [10]:
sentence = f"The word \"{test_pair['input']}\" means"
co, io = fv_intervention_natural_text(sentence, EDIT_LAYER, FV, model, model_config, tokenizer, max_new_tokens=10)


print("Input Sentence: ", repr(sentence))
print("GPT-J:" , repr(tokenizer.decode(co.squeeze())))
print("GPT-J+FV:", repr(tokenizer.decode(io.squeeze())), '\n')

Input Sentence:  'The word "increase" means'
GPT-J: 'The word "increase" means "to make larger, to enlarge, to expand'
GPT-J+FV: 'The word "increase" means "decrease" in the Bible.\n' 



### Jonny Additions

### Zero-Shot Application to New Q/A Pair

This is cool and all that it remembers to say decrease for this example, but is this an actual "antonym" function vector (this is how I had originally interpreted the paper, but the examples don't show this off). I test with a couple pairs below:

In [19]:
new_test_pair = {'input': 'happy', 'output': 'sad'}
sentence = f"The word \"{new_test_pair['input']}\" means"
co, io = fv_intervention_natural_text(sentence, EDIT_LAYER, FV, model, model_config, tokenizer, max_new_tokens=10)

print("Input Sentence: ", repr(sentence))
print("GPT-J:" , repr(tokenizer.decode(co.squeeze())))
print("GPT-J+FV:", repr(tokenizer.decode(io.squeeze())), '\n')

Input Sentence:  'The word "happy" means'
GPT-J: 'The word "happy" means different things to different people. For some, it'
GPT-J+FV: 'The word "happy" means "sad" in the dictionary.\n\n' 



In [21]:
new_test_pair = {'input': 'yin', 'output': 'yang'}
sentence = f"The word \"{new_test_pair['input']}\" means"
co, io = fv_intervention_natural_text(sentence, EDIT_LAYER, FV, model, model_config, tokenizer, max_new_tokens=10)

print("Input Sentence: ", repr(sentence))
print("GPT-J:" , repr(tokenizer.decode(co.squeeze())))
print("GPT-J+FV:", repr(tokenizer.decode(io.squeeze())), '\n')

Input Sentence:  'The word "yin" means'
GPT-J: 'The word "yin" means "darkness" or "black" in Chinese'
GPT-J+FV: 'The word "yin" means "yang." The word "yin" means' 



### Questions

1. What does the embedding of the original token look like if it does not have as much in-context learning effecting it?

    - If you increase/decrease the number of samples, is there a convergence to a single vector that represents opposites? (I feel like the embedding should start pretty spread out if you ablate certain samples that are part of the context and then converge as you get to a more complete set)

2. Is it the embedding of the colon that is changing or is it the tokens around it that are changing in terms of their position?