<a href="https://colab.research.google.com/github/wlg100/numseqcont_circuit_expms/blob/main/notebook_templates/vec_arith_template.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" align="left"/></a>&nbsp;or in a local notebook.

# Setup

In [None]:
"""
This is a demo of the findings in the paper
"Language Models Implement Simple Word2Vec-Style Vector Arithmetic".
It runs through the main result: while solving a task such as recalling
the capital city of a country, a language model (GPT2-medium in this case)
predicts the city by transforming the country name into the capital with
an additive update:
    Poland + o_city = Warsaw
Moreover, the exact vector the model infers to do this (o_city) implements
the same function on other examples when we add it to the residual stream
of the forward pass, e.g., China + o_city = Beijing.
"""

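The additive update described above can be illustrated with a toy sketch in plain Python. The 3-dimensional vectors are made up for illustration (they are not GPT2-medium embeddings): the offset that maps Poland to Warsaw is computed once, reused on China, and the result is decoded as the most similar vocabulary vector under cosine similarity, excluding the input word as in the usual word2vec analogy setup. In the notebook, o_city is instead read from an MLP output in the residual stream.

```python
import math

# Hypothetical toy vectors: dim 0 ~ "Poland-ness", dim 1 ~ "China-ness", dim 2 ~ "is a capital city"
VOCAB = {
    "Poland":  [1.0, 0.0, 0.0],
    "Warsaw":  [1.0, 0.0, 1.0],
    "China":   [0.0, 1.0, 0.0],
    "Beijing": [0.0, 1.0, 1.0],
    "Polish":  [0.9, 0.0, -0.2],
}

def add(u, v):
    return [a + b for a, b in zip(u, v)]

def sub(u, v):
    return [a - b for a, b in zip(u, v)]

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def nearest(vec, exclude=()):
    # Decode a vector as the most similar vocabulary item, skipping the input word(s)
    return max((w for w in VOCAB if w not in exclude), key=lambda w: cosine(vec, VOCAB[w]))

# Infer the additive update from one example: Poland + o_city = Warsaw
o_city = sub(VOCAB["Warsaw"], VOCAB["Poland"])

# Reuse the same vector on a new example: China + o_city = ?
print(nearest(add(VOCAB["China"], o_city), exclude={"China"}))  # Beijing
```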

In [None]:
!git clone https://github.com/jmerullo/lm_vector_arithmetic.git

Cloning into 'lm_vector_arithmetic'...
remote: Enumerating objects: 35, done.
remote: Counting objects: 100% (35/35), done.
remote: Compressing objects: 100% (30/30), done.
remote: Total 35 (delta 10), reused 21 (delta 4), pack-reused 0
Unpacking objects: 100% (35/35), 236.05 KiB | 1.02 MiB/s, done.


In [None]:
%cd lm_vector_arithmetic

/content/lm_vector_arithmetic


In [None]:
%%capture
!pip install -r requirements.txt

In [None]:
from modeling import *
import matplotlib.pyplot as plt

In [None]:
model, tokenizer = load_gpt2('gpt2-medium')
model = model.float()
wrapper = GPT2Wrapper(model, tokenizer)


In [None]:
def tokenize(text):
    # Return the input-ID tensor and the list of decoded string tokens
    inp_ids = wrapper.tokenize(text)
    str_toks = wrapper.list_decode(inp_ids[0])
    return inp_ids, str_toks

# Functions

## Get o_update from initial input

Find the layer to modify by locating the boundary where the top decoded token switches from the copied input token to the correct answer. Choose the first layer that shows this change as 'layer_to_subtract' (layers are counted from 0):

In [None]:
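def get_tokIDS(test_text):
    # Tokenize the prompt and print the top decoded tokens at every layer,
    # so you can spot the layer where the prediction switches
    test_ids, _ = tokenize(test_text)
    logits = wrapper.get_layers(test_ids)
    wrapper.print_top(logits[1:])  # skip the embedding layer
    return test_ids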
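Choosing 'layer_to_subtract' by eye from the printed per-layer tokens can also be automated. A minimal sketch, assuming you have the top-1 token of each layer as a list of strings (for example, the kind of list get_decoded(..., k=1) returns); the function name and the token list below are hypothetical. It returns the first layer of the final run of layers that predict the answer, i.e. the boundary. With GPT-2's BPE tokenizer, compare against the token exactly as decoded, usually including its leading space.

```python
from typing import List, Optional

def find_boundary_layer(top_tokens: List[str], answer: str) -> Optional[int]:
    """Return the first layer of the final run of layers whose top token is `answer`.

    top_tokens[i] is the top-1 decoded token after layer i (0-indexed, embedding layer excluded).
    Returns None if the last layer does not predict `answer`.
    """
    if not top_tokens or top_tokens[-1] != answer:
        return None
    layer = len(top_tokens) - 1
    # Walk backwards while the previous layer still predicts the answer
    while layer > 0 and top_tokens[layer - 1] == answer:
        layer -= 1
    return layer

# Hypothetical tokens for a 24-layer model: copied input up to layer 18, answer from layer 19
tokens = ["the"] * 15 + ["Poland"] * 4 + ["Warsaw"] * 5
print(find_boundary_layer(tokens, "Warsaw"))  # 19
```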

In [None]:
def get_o_update(token_ids, layer_to_subtract):
    # Register hooks that record each MLP's output in wrapper.model.activations_
    wrapper.add_hooks()
    with torch.no_grad():
        out = wrapper.model(input_ids=token_ids, output_hidden_states=True)
    # Skip the embedding output so index i corresponds to transformer block i
    hidden_states = list(out.hidden_states)[1:]

    # The MLP output written into the residual stream at this layer is the candidate update
    o_city = wrapper.model.activations_['mlp_' + str(layer_to_subtract)]

    layer_logits = torch.stack(wrapper.layer_decode(hidden_states)).squeeze(-1)
    print(f"Original top tokens at layer {layer_to_subtract}")
    wrapper.print_top(layer_logits[layer_to_subtract].unsqueeze(0))

    # Remove the MLP's contribution and decode the same layer again
    hidden_states[layer_to_subtract] = hidden_states[layer_to_subtract] - o_city

    layer_logits = torch.stack(wrapper.layer_decode(hidden_states)).squeeze(-1)
    print(f"After subtracting mlp_{layer_to_subtract}")
    wrapper.print_top(layer_logits[layer_to_subtract].unsqueeze(0))

    return o_city

Original top tokens at layer 19
0  Warsaw Poland Polish Budapest Prague Moscow Berlin Kiev � Frankfurt
After subtracting mlp_19
0  Poland Warsaw Polish Budapest Prague Poles � Moscow Berlin Kiev
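The check in get_o_update relies on the residual stream being a sum: in a GPT-2 block, the attention output and then the MLP output are each added to the incoming hidden state, so subtracting the MLP output leaves the stream as it was just before that MLP wrote to it. A toy sketch of that decode step, with a made-up 3-token unembedding and made-up vectors (not GPT-2 values), and without the final LayerNorm that a logit-lens decode of GPT-2 typically applies first:

```python
# Hypothetical toy unembedding: one row per vocabulary token
UNEMBED = {
    "Poland": [1.0, 0.0, 0.0],
    "Warsaw": [0.6, 0.0, 1.0],
    "Berlin": [0.0, 0.8, 0.6],
}

def top_token(hidden):
    # Logit-lens style decode: dot product with each unembedding row, then argmax
    logits = {tok: sum(h * w for h, w in zip(hidden, row)) for tok, row in UNEMBED.items()}
    return max(logits, key=logits.get)

resid_before = [1.0, 0.0, 0.0]  # residual stream before the MLP writes (encodes "Poland")
mlp_out = [0.0, 0.0, 0.9]       # this layer's MLP output (the candidate o_city)
resid_after = [a + b for a, b in zip(resid_before, mlp_out)]

print(top_token(resid_after))                                    # Warsaw
print(top_token([a - b for a, b in zip(resid_after, mlp_out)]))  # Poland
```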


## Apply linear update to layers

In [None]:
def create_layer_list(X, Y):
    layer_list = []
    for layer in range(X, Y+1):
        layer_list.append("Layer {}".format(layer))
    return layer_list

from tabulate import tabulate
def create_table(list1, list2, list3):
    table_data = zip(list1, list2, list3)
    headers = ["Layer", "Orig Top Token", "New Top Token"]
    table = tabulate(table_data, headers, tablefmt="grid")
    return table

import torch.nn.functional as F
def get_decoded(logits, k=10):
    # For each layer, decode its k most probable token IDs into a single string
    output_list = []
    for layer in logits:
        top_ids = F.softmax(layer, dim=-1).argsort(descending=True)[:k]
        output_list.append(wrapper.tokenizer.decode(top_ids))
    return output_list

In [None]:
def linear_update_layers(tok_ids, o_update, layer_to_subtract, prevL_scale, nextL_scale, lastL_scale):
    with torch.no_grad():
        out = wrapper.model(input_ids=tok_ids, output_hidden_states=True)
    hidden_states = list(out.hidden_states)[1:]  # skip the embedding output
    last = len(hidden_states) - 1
    start = layer_to_subtract - 4  # also show the 4 layers before the boundary

    orig_layer_logits = torch.stack(wrapper.layer_decode(hidden_states)).squeeze(-1)
    orig_decoded = get_decoded(orig_layer_logits[start:], k=1)

    # The FFN output is added to the residual stream, so we can also add it at other layers.
    # Before the boundary: add the update so these layers already give the correct answer.
    for i in range(start, layer_to_subtract):
        hidden_states[i] = hidden_states[i] + (i - layer_to_subtract + prevL_scale) * o_update

    # From the boundary on: subtract it so these layers revert to the copied input token.
    for i in range(layer_to_subtract, last):
        hidden_states[i] = hidden_states[i] - (i - layer_to_subtract + nextL_scale) * o_update
    # The last layer can behave differently (likely because the final Hugging Face GPT-2
    # hidden state already has ln_f applied), so tune lastL_scale by hand.
    hidden_states[last] = hidden_states[last] - lastL_scale * o_update

    layer_logits = torch.stack(wrapper.layer_decode(hidden_states)).squeeze(-1)
    new_decoded = get_decoded(layer_logits[start:], k=1)

    table_output = create_table(create_layer_list(start, last), orig_decoded, new_decoded)
    print(table_output)
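The multiples of o_update applied by linear_update_layers change linearly with the layer index: before the boundary they grow toward it, and from the boundary on they grow with distance from it. A small sketch that reproduces the same arithmetic (positive means o_update is added, negative means it is subtracted), useful for checking what a given choice of prevL_scale, nextL_scale and lastL_scale actually does. The function name is hypothetical and not part of the repo.

```python
from typing import Dict

def update_coefficients(layer_to_subtract: int, n_layers: int,
                        prevL_scale: float, nextL_scale: float, lastL_scale: float) -> Dict[int, float]:
    """Mirror the loops in linear_update_layers: layer index -> signed multiple of o_update."""
    coefs = {}
    # The 4 layers before the boundary: add (i - layer_to_subtract + prevL_scale) * o_update
    for i in range(layer_to_subtract - 4, layer_to_subtract):
        coefs[i] = i - layer_to_subtract + prevL_scale
    # Boundary up to the second-to-last layer: subtract (i - layer_to_subtract + nextL_scale) * o_update
    for i in range(layer_to_subtract, n_layers - 1):
        coefs[i] = -(i - layer_to_subtract + nextL_scale)
    # Last layer: subtract a separately tuned multiple
    coefs[n_layers - 1] = -lastL_scale
    return coefs

# Same settings as the "Bob is short. Mary is" call in this notebook (GPT2-medium has 24 layers)
print(update_coefficients(22, 24, prevL_scale=6, nextL_scale=1, lastL_scale=0.0283))
# {18: 2, 19: 3, 20: 4, 21: 5, 22: -1, 23: -0.0283}
```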

## Metric to measure how many layers change

This metric counts how often the expected token appears in the modified per-layer predictions. The layers before the boundary should now decode to the correct output, while the layers from the boundary on should decode to the copied input token. The score is the fraction of layers that changed as expected, so a higher score means the update worked better.

In [None]:
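def compare_lists(list1, list2, expected_1, expected_2):
    # list1: new top tokens before the boundary; they should now equal expected_2 (the answer)
    # list2: new top tokens from the boundary on; they should now equal expected_1 (the copied input)
    # Unused alternative, element-wise similarity (more similar means a weaker change):
    # count = sum(1 for x, y in zip(list1, list2) if x == y)
    return (list1.count(expected_2) + list2.count(expected_1)) / (len(list1) + len(list2))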
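A worked example of the score computed by compare_lists, using hypothetical tokens in the style of the Poland example: before the boundary we want the answer (Warsaw), and from the boundary on we want the copied input (Poland). The standalone function below re-implements the same formula under a hypothetical name so it runs without the model.

```python
def change_score(before, after, copied, answer):
    # Fraction of layers that flipped as intended: answer before the boundary, copied input after it
    return (before.count(answer) + after.count(copied)) / (len(before) + len(after))

before = ["Warsaw", "Warsaw", "Poland", "Warsaw"]  # hypothetical new top tokens, 4 layers before the boundary
after = ["Poland", "Poland"]                       # hypothetical new top tokens from the boundary on
print(change_score(before, after, copied="Poland", answer="Warsaw"))  # 0.8333333333333334
```

Here 5 of the 6 layers flipped as intended; the one layer before the boundary that still says "Poland" costs 1/6.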

In [None]:
def measure_change_layers(tok_ids, o_update, layer_to_subtract, prevL_scale, nextL_scale, lastL_scale):
    with torch.no_grad():
        out = wrapper.model(input_ids=tok_ids, output_hidden_states=True)
    hidden_states = list(out.hidden_states)[1:]  # skip the embedding output
    last = len(hidden_states) - 1
    start = layer_to_subtract - 4

    orig_layer_logits = torch.stack(wrapper.layer_decode(hidden_states)).squeeze(-1)
    orig_decoded = get_decoded(orig_layer_logits[start:], k=1)

    # Same update schedule as linear_update_layers:
    # add before the boundary, subtract from the boundary on, separate scale for the last layer
    for i in range(start, layer_to_subtract):
        hidden_states[i] = hidden_states[i] + (i - layer_to_subtract + prevL_scale) * o_update
    for i in range(layer_to_subtract, last):
        hidden_states[i] = hidden_states[i] - (i - layer_to_subtract + nextL_scale) * o_update
    hidden_states[last] = hidden_states[last] - lastL_scale * o_update

    layer_logits = torch.stack(wrapper.layer_decode(hidden_states)).squeeze(-1)
    new_decoded = get_decoded(layer_logits[start:], k=1)

    expected_1 = orig_decoded[0]   # original top token before the boundary (copied input)
    expected_2 = orig_decoded[-1]  # original top token at the last layer (the model's answer)
    return compare_lists(new_decoded[:4], new_decoded[4:], expected_1, expected_2)

## Measure effects of multiple random vectors

In [None]:
from torch.nn.functional import cosine_similarity

def gen_rand_tensor(o_upd):
    # Uniform [0, 1) noise with the same shape, dtype, and device as o_upd
    return torch.rand(o_upd.shape, dtype=o_upd.dtype, device=o_upd.device)

def get_cosine_sim(random_tensor, o_upd):
    # Flatten both tensors; cosine_similarity normalizes them internally
    similarity = cosine_similarity(random_tensor.flatten(), o_upd.flatten(), dim=0)
    print(similarity.item())
    return similarity.item()

In [None]:
import matplotlib.pyplot as plt

def upd_hist(tok_ids, o_city, layer_pt, prevL_scale=5, nextL_scale=5, lastL_scale=0.9, n_samples=100):
    # Score n_samples random vectors with the same metric used for o_city
    outputs = []
    for _ in range(n_samples):
        random_tensor = gen_rand_tensor(o_city)
        out_val = measure_change_layers(tok_ids, random_tensor, layer_pt, prevL_scale, nextL_scale, lastL_scale)
        outputs.append(out_val)

    # Plot a histogram of the scores
    plt.hist(outputs, bins=10, edgecolor='black')
    plt.xlabel('Change score')
    plt.ylabel('Frequency')
    plt.title('Change scores for random update vectors')
    plt.show()

Random tensors usually push the predictions away from the expected tokens. The rare cases where a random tensor does nudge a layer toward the correct answer are statistical anomalies, where it most likely landed close by chance.
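One way to quantify how far a random vector is from o_city is their cosine similarity, which get_cosine_sim prints. As a reference point, independent Gaussian vectors in high dimension are nearly orthogonal, with cosine similarities spread roughly 1/sqrt(d) around 0. Note that torch.rand samples uniformly from [0, 1), so all of its vectors lie in the positive orthant rather than pointing in uniformly random directions; torch.randn is the Gaussian alternative. A pure-Python sketch of the Gaussian baseline, with d = 1024 chosen to match GPT2-medium's hidden size and a random stand-in for the flattened o_city (an assumption, not real activations):

```python
import math
import random

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def random_baseline(target, n_samples=100, seed=0):
    # Cosine similarity between `target` and n_samples standard-Gaussian random vectors
    rng = random.Random(seed)
    return [cosine([rng.gauss(0.0, 1.0) for _ in target], target) for _ in range(n_samples)]

d = 1024
rng = random.Random(1)
target = [rng.gauss(0.0, 1.0) for _ in range(d)]  # stand-in for a flattened o_city
sims = random_baseline(target)
print(sum(abs(s) for s in sims) / len(sims))  # mean |cosine|, close to 0 for d = 1024
```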

# Test Variations Template

- Add o_update from different inputs of varying similarity

- Vary scaling factors
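Varying the scaling factors by hand can be replaced with a small grid search. A sketch, assuming a scoring function that takes the same keyword arguments as measure_change_layers; the hypothetical toy_score below stands in for it so the example runs without the model. In the notebook you could pass something like `lambda **kw: measure_change_layers(test_ids, o_upd, impt_layer, **kw)` instead.

```python
from itertools import product

def sweep_scales(score_fn, prev_values, next_values, last_values):
    # Evaluate every combination and return (best_score, best_kwargs); ties keep the first one found
    best = None
    for prevL, nextL, lastL in product(prev_values, next_values, last_values):
        kwargs = dict(prevL_scale=prevL, nextL_scale=nextL, lastL_scale=lastL)
        score = score_fn(**kwargs)
        if best is None or score > best[0]:
            best = (score, kwargs)
    return best

def toy_score(prevL_scale, nextL_scale, lastL_scale):
    # Hypothetical stand-in for measure_change_layers, peaking at (5, 5, 0.9)
    return -abs(prevL_scale - 5) - abs(nextL_scale - 5) - abs(lastL_scale - 0.9)

print(sweep_scales(toy_score, [1, 3, 5, 7], [1, 3, 5, 7], [0.1, 0.5, 0.9]))
# (0.0, {'prevL_scale': 5, 'nextL_scale': 5, 'lastL_scale': 0.9})
```

Since the real score is a fraction of a handful of layers, many settings can tie; keeping the first best is an arbitrary choice.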

In [None]:
### Enter a custom prompt and variable names here (so you don't need to rename each variable in the cells below)
cust_name_text = """Bob is short. Mary is"""
cust_name_ids = get_tokIDS(cust_name_text)
###

test_text = cust_name_text
test_ids = cust_name_ids

In [None]:
impt_layer = 22

In [None]:
### Enter custom variable names here
o_cust_name = get_o_update(test_ids, impt_layer)
###
o_upd = o_cust_name

Original top tokens at layer 22
0  short tall shorter skinny thin small a slim slender not
After subtracting mlp_22
0  tall short skinny shorter thin slender slim small taller lean


In [None]:
linear_update_layers(test_ids, o_upd, impt_layer, prevL_scale=6, nextL_scale=1, lastL_scale=0.0283)

+----------+------------------+-----------------+
| Layer    | Orig Top Token   | New Top Token   |
+==========+==================+=================+
| Layer 18 | tall             | short           |
+----------+------------------+-----------------+
| Layer 19 | tall             | short           |
+----------+------------------+-----------------+
| Layer 20 | tall             | short           |
+----------+------------------+-----------------+
| Layer 21 | tall             | short           |
+----------+------------------+-----------------+
| Layer 22 | short            | tall            |
+----------+------------------+-----------------+
| Layer 23 | short            | skinny          |
+----------+------------------+-----------------+


In [None]:
measure_change_layers(test_ids, o_upd, impt_layer, prevL_scale=5, nextL_scale=5, lastL_scale=0.9)

0.8888888888888888

Next, create token IDs for a different prompt and add the update from the previous prompt to it, to see whether the same vector performs the analogous transformation there.
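The transfer test amounts to checking whether one fixed update vector maps several inputs to their respective targets. A toy sketch with made-up 4-dimensional vectors (not model activations), where the update is inferred from a single pair and then scored on other pairs by nearest-neighbour decoding; the helper name and vectors are hypothetical.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def transfer_accuracy(o, pairs, vocab):
    # For each (source, target) pair, decode source + o as the nearest vocab word (excluding source)
    hits = 0
    for source, target in pairs:
        shifted = [a + b for a, b in zip(vocab[source], o)]
        guess = max((w for w in vocab if w != source), key=lambda w: cosine(shifted, vocab[w]))
        hits += guess == target
    return hits / len(pairs)

# Hypothetical toy vectors; the last dimension marks "is a capital"
vocab = {
    "Poland": [1, 0, 0, 0], "Warsaw": [1, 0, 0, 1],
    "China": [0, 1, 0, 0], "Beijing": [0, 1, 0, 1],
    "France": [0, 0, 1, 0], "Paris": [0, 0, 1, 1],
}
o_city = [w - p for w, p in zip(vocab["Warsaw"], vocab["Poland"])]  # inferred from one pair only
print(transfer_accuracy(o_city, [("China", "Beijing"), ("France", "Paris")], vocab))  # 1.0
```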

Optionally, run the random-tensor comparisons:

In [None]:
# random_tensor = gen_rand_tensor(o_upd)
# measure_change_layers(test_ids, random_tensor, impt_layer, prevL_scale=1, nextL_scale=1, lastL_scale=0.1)

0.0

In [None]:
# upd_hist(test_ids, o_upd, impt_layer, prevL_scale=5, nextL_scale=5, lastL_scale=0.9)