In [1]:
import einops
import torch
import plotly.express as px
from datasets import load_from_disk
from data import create_prompt
from copy import deepcopy

from nnsight import LanguageModel

In [2]:
model = LanguageModel("mistralai/Mistral-7B-v0.1", device_map="auto", dispatch=True)
# model = LanguageModel("openai-community/gpt2", device_map="cuda:0")



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
ds = load_from_disk("datasets/value_at_index/v0.0")
ds = ds.filter(lambda e: (e['query_array_length'] == 10) and (e['num_few_shots'] == 1) and (e['query'] == 9))
ds = ds.map(create_prompt)

In [4]:
def corrupt_fn(e):
    e = deepcopy(e)
    e['query'] -= 1
    e['target'] = e['array'][e['query']]
    return e

# def corrupt_fn(e):
#     '''
#     Change the value
#     '''
#     e = deepcopy(e)
#     idx = e['query']
#     e['array'][idx] += 7 
#     e['array'][idx] %= 10
#     e['target'] = e['array'][idx]
#     return e

prompts = []
answers = []
for e in [ds[1], ds[7], ds[2], ds[4]]:
    corrupted_e = corrupt_fn(e)
    prompts.append(create_prompt(e)['prompt'])
    prompts.append(create_prompt(corrupted_e)['prompt'])
    answers.append([str(e['target']), str(corrupted_e['target'])])
    answers.append([str(corrupted_e['target']), str(e['target'])])

print(prompts)
print(answers)

['Output the value of an array at a given index.\na=[6, 7, 9, 7, 8]. a[1]=7\na=[6, 9, 3, 0, 5, 0, 8, 2, 1, 0]. a[9]=', 'Output the value of an array at a given index.\na=[6, 7, 9, 7, 8]. a[1]=7\na=[6, 9, 3, 0, 5, 0, 8, 2, 1, 0]. a[8]=', 'Output the value of an array at a given index.\na=[6, 7, 9, 7, 8]. a[1]=7\na=[8, 0, 1, 6, 7, 3, 5, 1, 8, 9]. a[9]=', 'Output the value of an array at a given index.\na=[6, 7, 9, 7, 8]. a[1]=7\na=[8, 0, 1, 6, 7, 3, 5, 1, 8, 9]. a[8]=', 'Output the value of an array at a given index.\na=[6, 7, 9, 7, 8]. a[1]=7\na=[2, 9, 7, 6, 2, 4, 1, 8, 3, 6]. a[9]=', 'Output the value of an array at a given index.\na=[6, 7, 9, 7, 8]. a[1]=7\na=[2, 9, 7, 6, 2, 4, 1, 8, 3, 6]. a[8]=', 'Output the value of an array at a given index.\na=[6, 7, 9, 7, 8]. a[1]=7\na=[5, 2, 8, 6, 2, 5, 5, 1, 6, 2]. a[9]=', 'Output the value of an array at a given index.\na=[6, 7, 9, 7, 8]. a[1]=7\na=[5, 2, 8, 6, 2, 5, 5, 1, 6, 2]. a[8]=']
[['0', '1'], ['1', '0'], ['9', '8'], ['8', '9'], ['6', 

In [5]:
# prompts = [
#     "When John and Mary went to the shops, John gave the bag to",
#     "When John and Mary went to the shops, Mary gave the bag to",
#     "When Tom and James went to the park, James gave the ball to",
#     "When Tom and James went to the park, Tom gave the ball to",
#     "When Dan and Sid went to the shops, Sid gave an apple to",
#     "When Dan and Sid went to the shops, Dan gave an apple to",
#     "After Martin and Amy went to the park, Amy gave a drink to",
#     "After Martin and Amy went to the park, Martin gave a drink to",
# ]

# answers = [
#     (" Mary", " John"),
#     (" John", " Mary"),
#     (" Tom", " James"),
#     (" James", " Tom"),
#     (" Dan", " Sid"),
#     (" Sid", " Dan"),
#     (" Martin", " Amy"),
#     (" Amy", " Martin"),
# ]

clean_tokens = model.tokenizer(prompts, return_tensors="pt")["input_ids"]

corrupted_tokens = clean_tokens[
    [(i + 1 if i % 2 == 0 else i - 1) for i in range(len(clean_tokens))]
]

answer_token_indices = torch.tensor(
    [
        [model.tokenizer(answers[i][j])["input_ids"][0] for j in range(2)]
        for i in range(len(answers))
    ]
)

In [6]:
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    if len(logits.shape) == 3:
        logits = logits[:, -1, :]
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()

clean_logits = model.trace(clean_tokens, trace=False).logits.cpu()
corrupted_logits = model.trace(corrupted_tokens, trace=False).logits.cpu()

CLEAN_BASELINE = get_logit_diff(clean_logits, answer_token_indices).item()
print(f"Clean logit diff: {CLEAN_BASELINE:.4f}")

CORRUPTED_BASELINE = get_logit_diff(corrupted_logits, answer_token_indices).item()
print(f"Corrupted logit diff: {CORRUPTED_BASELINE:.4f}")

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Clean logit diff: 0.0000
Corrupted logit diff: 0.0000


In [7]:
def ioi_metric(
    logits,
    answer_token_indices=answer_token_indices,
):
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (
        CLEAN_BASELINE - CORRUPTED_BASELINE
    )

print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")

Clean Baseline is 1: nan
Corrupted Baseline is 0: nan


In [8]:
clean_out = []
corrupted_out = []
corrupted_grads = []

with model.trace() as tracer:

    # with tracer.invoke(clean_tokens) as invoker_clean:

    #     for layer in model.transformer.h:
    #         attn_out = layer.attn.c_proj.input[0][0]
    #         clean_out.append(attn_out.save())

    # with tracer.invoke(corrupted_tokens) as invoker_corrupted:

    #     for layer in model.transformer.h:
    #         attn_out = layer.attn.c_proj.input[0][0]
    #         corrupted_out.append(attn_out.save())
    #         corrupted_grads.append(attn_out.grad.save())

    #     logits = model.lm_head.output.save()
    #     # Our metric uses tensors saved on cpu, so we
    #     # need to move the logits to cpu.
    #     value = ioi_metric(logits.cpu())
    #     value.backward()

    with tracer.invoke(clean_tokens) as invoker_clean:

        for layer in model.model.layers:
            attn_out = layer.self_attn.dense.input[0][0]
            clean_out.append(attn_out.save())

    with tracer.invoke(corrupted_tokens) as invoker_corrupted:

        for layer in model.model.layers:
            attn_out = layer.self_attn.dense.input[0][0]
            corrupted_out.append(attn_out.save())
            corrupted_grads.append(attn_out.grad.save())

        logits = model.lm_head.output.save()
        # Our metric uses tensors saved on cpu, so we
        # need to move the logits to cpu.
        value = ioi_metric(logits.cpu())
        value.backward()

AttributeError: 'MistralSdpaAttention' object has no attribute 'dense'

In [156]:
patching_results = []

for corrupted_grad, corrupted, clean, layer in zip(
    corrupted_grads, corrupted_out, clean_out, range(len(clean_out))
):
    residual_attr = einops.reduce(
        corrupted_grad.value[:,-1,:] * (clean.value[:,-1,:] - corrupted.value[:,-1,:]),
        "batch (head dim) -> head",
        "sum",
        # head = 12,
        # dim = 64,
        head = 32,
        dim = 80,
    )
    patching_results.append(
        residual_attr.detach().cpu().numpy()
    )

In [157]:
fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Patching Over Attention Heads"
)

fig.update_layout(
    xaxis_title="Head",
    yaxis_title="Layer"
)

fig.show()

In [11]:
patching_results = []

for corrupted_grad, corrupted, clean, layer in zip(
    corrupted_grads, corrupted_out, clean_out, range(len(clean_out))
):

    residual_attr = einops.reduce(
        corrupted_grad.value * (clean.value - corrupted.value),
        "batch pos dim -> pos",
        "sum",
    )

    patching_results.append(
        residual_attr.detach().cpu().numpy()
    )

In [12]:
sent = "Output the value of an array at a given index.\na=[4, 0, 9, 2, 4, 0, 1, 0, 4, 0]. a[1]="
new_clean_tokens = model.tokenizer(sent)['input_ids']
new_clean_tokens = model.tokenizer.batch_decode(new_clean_tokens)
token_labels = [f"{token}_{index}" for index, token in enumerate(new_clean_tokens)]

fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Patching Over Position",
    x=token_labels,
)

fig.update_layout(
    xaxis_title="Pos",
    yaxis_title="Layer"
)

fig.show()

ValueError: The length of the x vector must match the length of the second dimension of the img matrix.