In [1]:
import einops
import plotly.express as px
import torch
from nnsight import LanguageModel, util
from nnsight.tracing.Proxy import Proxy
from tqdm import trange

## Single-Logit Attribution Patching

In [99]:
# Load GPT-2 through nnsight. Prefer the first CUDA device, but fall back to
# the CPU so the notebook still runs on machines without a GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
model = LanguageModel("gpt2", device_map=DEVICE, dispatch=True)


`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.



In [100]:
# Prompts come in consecutive pairs (rows 0/1, 2/3, ...). The two prompts of a
# pair differ only in the conjunction ("but" vs. "and"), which flips the
# sentiment expected for the next token.
prompts = [
    "Restaurant review: I loved the pasta but the service was",
    "Restaurant review: I loved the pasta and the service was",
    "Restaurant review: I hated the pasta but the service was",
    "Restaurant review: I hated the pasta and the service was",
    "Restaurant review: I loved the salad but my waiter was",
    "Restaurant review: I loved the salad and my waiter was",
    "Restaurant review: I hated the salad but my waiter was",
    "Restaurant review: I hated the salad and my waiter was",
]

# One (correct, incorrect) completion per prompt, in that order.
# Example: "... I loved the pasta but the service was" -> " terrible" is the
# correct completion and " wonderful" the incorrect one.
answers = [
    (" terrible", " wonderful"),
    (" wonderful", " terrible"),
    (" wonderful", " terrible"),
    (" terrible", " wonderful"),
    (" rude", " kind"),
    (" kind", " rude"),
    (" kind", " rude"),
    (" rude", " kind"),
]

# The pairing logic below silently breaks if these do not hold.
assert len(prompts) == len(answers), "need exactly one answer pair per prompt"
assert len(prompts) % 2 == 0, "prompts must come in clean/corrupted pairs"

clean_tokens = model.tokenizer(prompts, return_tensors="pt", padding=True)["input_ids"]

# The corrupted input for each prompt is its pair partner: swap rows
# 0<->1, 2<->3, ... (flipping the lowest bit maps even i to i + 1 and odd i to i - 1).
pair_swap_index = [i ^ 1 for i in range(len(clean_tokens))]
corrupted_tokens = clean_tokens[pair_swap_index]

# Token ids of the (correct, incorrect) answers, one row per prompt.
# Only the first token id of each answer string is kept.
answer_token_indices = torch.tensor(
    [
        [model.tokenizer(answer)["input_ids"][0] for answer in answer_pair]
        for answer_pair in answers
    ]
)

In [101]:
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    """Mean logit gap between the correct and the incorrect answer token.

    If ``logits`` has three dimensions, only the last index along dimension 1
    is scored; otherwise ``logits`` is used as given. Column 0 of
    ``answer_token_indices`` selects the correct token for each row and
    column 1 the incorrect one. Returns the mean of the per-row differences.
    """
    final_logits = logits[:, -1, :] if len(logits.shape) == 3 else logits
    correct_ids = answer_token_indices[:, 0].unsqueeze(1)
    incorrect_ids = answer_token_indices[:, 1].unsqueeze(1)
    gap = final_logits.gather(1, correct_ids) - final_logits.gather(1, incorrect_ids)
    return gap.mean()

# Plain forward pass (no tracing) on each token batch; the logits are moved
# to the CPU because the answer token indices were created there.
clean_logits, corrupted_logits = (
    model.trace(token_batch, trace=False).logits.cpu()
    for token_batch in (clean_tokens, corrupted_tokens)
)

# Reference logit differences used later to normalise the metric.
CLEAN_BASELINE = get_logit_diff(clean_logits, answer_token_indices).item()
CORRUPTED_BASELINE = get_logit_diff(corrupted_logits, answer_token_indices).item()

print(f"Clean logit diff: {CLEAN_BASELINE:.4f}")
print(f"Corrupted logit diff: {CORRUPTED_BASELINE:.4f}")

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Clean logit diff: 2.3721
Corrupted logit diff: -2.3721


In [102]:
def ioi_metric(
    logits,
    answer_token_indices=answer_token_indices,
):
    """Logit difference rescaled so that the corrupted baseline maps to 0 and
    the clean baseline maps to 1."""
    baseline_gap = CLEAN_BASELINE - CORRUPTED_BASELINE
    shifted_diff = get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE
    return shifted_diff / baseline_gap

# Sanity check: the normalised metric must hit its two anchor values.
clean_score = ioi_metric(clean_logits).item()
corrupted_score = ioi_metric(corrupted_logits).item()
print(f"Clean Baseline is 1: {clean_score:.4f}")
print(f"Corrupted Baseline is 0: {corrupted_score:.4f}")

Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000


In [105]:
# Per-layer saved values, one entry per transformer block:
#   clean_out        output[0] of each block on the clean run
#   corrupted_out    output[0] of each block on the corrupted run
#   corrupted_grads  gradient of the metric w.r.t. the corrupted output[0]
clean_out = []
corrupted_out = []
corrupted_grads = []

with model.trace() as tracer:
    # Clean run: only the activations are needed.
    with tracer.invoke(clean_tokens):
        for layer in model.transformer.h:
            # To attribute the attention output instead, use
            # layer.attn.c_proj.input[0][0] here.
            clean_out.append(layer.output[0].save())

    # Corrupted run: activations plus their gradients under the metric.
    with tracer.invoke(corrupted_tokens):
        for layer in model.transformer.h:
            residual_out = layer.output[0]
            corrupted_out.append(residual_out.save())
            corrupted_grads.append(residual_out.grad.save())

        # Use a dedicated name for the proxy so the CPU tensor `corrupted_logits`
        # computed earlier is not overwritten (cells that use it stay re-runnable).
        corrupted_logits_proxy = model.lm_head.output.save()

        # Our metric uses tensors saved on cpu, so we
        # need to move the logits to cpu.
        value = ioi_metric(corrupted_logits_proxy.cpu())
        value.backward()

In [106]:
# Attribution per layer and position: gradient at the corrupted activation
# times (clean - corrupted) activation, summed over the batch and hidden axes.
patching_results = [
    einops.reduce(
        grad.value * (clean.value - corrupted.value),
        "batch pos dim -> pos",
        "sum",
    )
    .detach()
    .cpu()
    .numpy()
    for grad, corrupted, clean in zip(corrupted_grads, corrupted_out, clean_out)
]

# Column labels: "<position>. <token>" taken from the first clean prompt.
token_labels = [
    f"{position}. {model.tokenizer.decode(token_id)}"
    for position, token_id in enumerate(clean_tokens[0])
]

fig = px.imshow(
    patching_results,
    x=token_labels,  # replace x axis with tokens at position
    title="Patching Over Position",
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
)

layout_options = {
    "xaxis_tickangle": -45,  # rotate x axis labels
    "xaxis_title": "Position",
    "yaxis_title": "Layer",
}
fig.update_layout(**layout_options)

fig.show()

## Single-Logit Activation Patching

In [8]:
# NOTE: this cell calls the single-argument ioi_metric(logits). The notebook
# later defines a four-argument function with the same name, so this cell must
# run before that redefinition.
N_LAYERS = len(model.transformer.h)

# Cache output[0] of every transformer block on the clean run; the patches
# below are read from this cache.
with torch.no_grad():
    with model.trace() as tracer:
        with tracer.invoke(clean_tokens):
            clean_hs_proxies = [
                model.transformer.h[layer_idx].output[0].save()
                for layer_idx in range(N_LAYERS)
            ]

# fetch actual values from the proxy objects
clean_hs = util.apply(clean_hs_proxies, lambda x: x.value, Proxy)

# patching_results[layer][token] holds the metric after patching that single
# (layer, token position) of the corrupted run with the clean activation.
patching_results = []
for layer_idx in trange(N_LAYERS, desc="Layer loop"):
    _patching_results = []
    for token_idx in range(clean_tokens.shape[1]):
        # Patching corrupted run at given layer and token
        with torch.no_grad():
            with model.trace() as tracer:
                with tracer.invoke(corrupted_tokens):
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    model.transformer.h[layer_idx].output[0].t[token_idx] = clean_hs[layer_idx][..., token_idx, :]

                    patched_logits = model.lm_head.output.cpu().save()
                    patched_result = ioi_metric(patched_logits).item().save()
                # Still a proxy at this point; it is resolved to a value in the plotting cell.
                _patching_results.append(patched_result)
    patching_results.append(_patching_results)

Layer loop: 100%|██████████| 12/12 [01:25<00:00,  7.17s/it]


In [9]:
# Replace every saved proxy in the nested result list by its value so the
# grid can be plotted. Entries that are not proxies are left untouched.
patching_results = util.apply(patching_results, lambda x: x.value, Proxy)

# Column labels: "<position>. <token>" taken from the first clean prompt.
token_labels = [f"{i}. {model.tokenizer.decode(clean_tokens[0][i])}" for i in range(len(clean_tokens[0]))]

fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Patching Over Position",
    # replace x axis with tokens at position
    x=token_labels
)

fig.update_layout(
    # rotate x axis labels
    xaxis_tickangle=-45,
    xaxis_title="Position",
    yaxis_title="Layer"
)

fig.show()

## Generation Attribution Patching

In [155]:
# Same paired prompts as in the single-logit sections: consecutive rows differ
# only in the conjunction ("but" vs. "and").
prompts = [
    "Restaurant review: I loved the pasta but the service was",
    "Restaurant review: I loved the pasta and the service was",
    "Restaurant review: I hated the pasta but the service was",
    "Restaurant review: I hated the pasta and the service was",
    "Restaurant review: I loved the salad but my waiter was",
    "Restaurant review: I loved the salad and my waiter was",
    "Restaurant review: I hated the salad but my waiter was",
    "Restaurant review: I hated the salad and my waiter was",
]

# (correct, incorrect) completion for each prompt.
answers = [
    (" terrible", " wonderful"),
    (" wonderful", " terrible"),
    (" wonderful", " terrible"),
    (" terrible", " wonderful"),
    (" rude", " kind"),
    (" kind", " rude"),
    (" kind", " rude"),
    (" rude", " kind"),
]

# Full sequences: each prompt followed by its correct / incorrect completion.
correct_texts = [prompt + answer[0] for prompt, answer in zip(prompts, answers)]
wrong_texts = [prompt + answer[1] for prompt, answer in zip(prompts, answers)]

clean_correct_tokens = model.tokenizer(
    correct_texts, return_tensors="pt", padding=True
)["input_ids"]
clean_wrong_tokens = model.tokenizer(
    wrong_texts, return_tensors="pt", padding=True
)["input_ids"]

# Corrupted batches are the clean batches with the rows of every pair swapped
# (0<->1, 2<->3, ...).
correct_partner_rows = [
    (row + 1 if row % 2 == 0 else row - 1) for row in range(len(clean_correct_tokens))
]
wrong_partner_rows = [
    (row + 1 if row % 2 == 0 else row - 1) for row in range(len(clean_wrong_tokens))
]

corrupted_correct_tokens = clean_correct_tokens[correct_partner_rows]
corrupted_wrong_tokens = clean_wrong_tokens[wrong_partner_rows]

In [156]:
# Show the first row of each token batch, decoded back to text.
for heading, correct_batch, wrong_batch in (
    ('Factual (clean)', clean_correct_tokens, clean_wrong_tokens),
    ('Counterfactual (corrupted)', corrupted_correct_tokens, corrupted_wrong_tokens),
):
    print(heading)
    print('  Correct:\t', model.tokenizer.decode(correct_batch[0]))
    print('  Incorrect:\t', model.tokenizer.decode(wrong_batch[0]))

Factual (clean)
  Correct:	 Restaurant review: I loved the pasta but the service was terrible
  Incorrect:	 Restaurant review: I loved the pasta but the service was wonderful
Counterfactual (corrupted)
  Correct:	 Restaurant review: I loved the pasta and the service was wonderful
  Incorrect:	 Restaurant review: I loved the pasta and the service was terrible


In [157]:
def get_ce_diff_from_logits(correct_logits, correct_labels, wrong_logits, wrong_labels):
    """
    Next-token cross-entropy of the wrong sequence minus that of the correct sequence.

    For each (logits, labels) pair the logits at position t are scored against
    the label at position t + 1, averaged over all scored positions. A larger
    return value means the correct sequence has the lower loss.
    """

    def _next_token_ce(logits, labels):
        # Drop the last logit row and the first label so position t predicts t + 1.
        shifted_logits = logits[..., :-1, :]
        shifted_labels = labels[..., 1:]
        vocab_size = shifted_logits.size(-1)
        return torch.nn.functional.cross_entropy(
            shifted_logits.reshape(-1, vocab_size),
            shifted_labels.reshape(-1),
        )

    correct_loss = _next_token_ce(correct_logits, correct_labels)
    wrong_loss = _next_token_ce(wrong_logits, wrong_labels)
    # we want to minimize the correct loss and maximize the wrong loss
    return wrong_loss - correct_loss

def get_mean_logit_diff(correct_logits, correct_labels, wrong_logits, wrong_labels):
    """
    Mean logit assigned to the correct labels minus mean logit assigned to the wrong labels.

    For each (logits, labels) pair the logit at position t is read at the
    label of position t + 1 (the last logit row and the first label are
    dropped), and the selected logits are averaged over all positions and rows.
    """

    def _mean_label_logit(logits, labels):
        next_token_logits = logits[..., :-1, :].contiguous()
        next_token_labels = labels[..., 1:].contiguous()
        # Pick, at every position, the logit of the token that actually follows.
        picked = next_token_logits.gather(2, next_token_labels.unsqueeze(2)).squeeze(2)
        return picked.mean()

    correct_mean = _mean_label_logit(correct_logits, correct_labels)
    wrong_mean = _mean_label_logit(wrong_logits, wrong_labels)
    return correct_mean - wrong_mean

def get_ce_diff_from_model(model, correct_labels, wrong_labels):
    """
    Loss the model reports for the wrong sequence minus its loss for the correct sequence.

    Each sequence is passed to ``model.trace`` both as the input and as
    ``labels`` (with ``trace=False``), and the ``.loss`` of the result is used.
    """

    def _sequence_loss(token_ids):
        # The same ids serve as input and as target.
        return model.trace(token_ids, labels=token_ids, trace=False).loss

    correct_loss = _sequence_loss(correct_labels)
    wrong_loss = _sequence_loss(wrong_labels)
    return wrong_loss - correct_loss

# One plain forward pass (no tracing) per token batch, logits moved to the CPU.
(
    clean_correct_logits,
    clean_wrong_logits,
    corrupted_correct_logits,
    corrupted_wrong_logits,
) = (
    model.trace(token_batch, trace=False).logits.cpu()
    for token_batch in (
        clean_correct_tokens,
        clean_wrong_tokens,
        corrupted_correct_tokens,
        corrupted_wrong_tokens,
    )
)

# Alternative metric based on the model's own loss:
# CLEAN_BASELINE = get_ce_diff_from_model(model, clean_correct_tokens, clean_wrong_tokens).item()
CLEAN_BASELINE = get_mean_logit_diff(
    correct_logits=clean_correct_logits,
    correct_labels=clean_correct_tokens,
    wrong_logits=clean_wrong_logits,
    wrong_labels=clean_wrong_tokens,
).item()

# For the corrupted baseline the "wrong" batch is passed in the correct slot
# and the "correct" batch in the wrong slot.
CORRUPTED_BASELINE = get_mean_logit_diff(
    correct_logits=corrupted_wrong_logits,
    correct_labels=corrupted_wrong_tokens,
    wrong_logits=corrupted_correct_logits,
    wrong_labels=corrupted_correct_tokens,
).item()

print(f"Clean logit diff: {CLEAN_BASELINE:.4f}")
print(f"Corrupted logit diff: {CORRUPTED_BASELINE:.4f}")

Clean logit diff: 0.1825
Corrupted logit diff: -0.1825


In [158]:
def ioi_metric(
    correct_logits,
    correct_labels,
    wrong_logits,
    wrong_labels,
):
    """Mean logit difference rescaled so that CORRUPTED_BASELINE maps to 0 and
    CLEAN_BASELINE maps to 1.

    Note: this shares its name with the single-logit ``ioi_metric(logits, ...)``
    defined earlier in the notebook and replaces it once this cell has run.
    """
    baseline_gap = CLEAN_BASELINE - CORRUPTED_BASELINE
    raw_diff = get_mean_logit_diff(correct_logits, correct_labels, wrong_logits, wrong_labels)
    return (raw_diff - CORRUPTED_BASELINE) / baseline_gap

# Sanity check: the normalised metric must hit its two anchor values.
clean_score = ioi_metric(
    clean_correct_logits, clean_correct_tokens, clean_wrong_logits, clean_wrong_tokens
).item()
corrupted_score = ioi_metric(
    corrupted_wrong_logits, corrupted_wrong_tokens, corrupted_correct_logits, corrupted_correct_tokens
).item()

print("Clean Baseline is 1:", f"{clean_score:.4f}")
print("Corrupted Baseline is 0:", f"{corrupted_score:.4f}")

Clean Baseline is 1: 1.0000
Corrupted Baseline is 0: 0.0000


In [159]:
# Per-layer saved values (one entry per transformer block) for the four runs.
# Gradients are only collected for the two corrupted runs.
clean_correct_out = []
clean_wrong_out = []
corrupted_correct_out = []
corrupted_correct_grads = []
corrupted_wrong_out = []
corrupted_wrong_grads = []


def _save_residuals(activations, grads=None):
    """Append the saved output[0] of every block (and, if `grads` is given,
    its saved gradient) for the invoke that is currently active.

    Must be called inside a `tracer.invoke(...)` context.
    """
    for layer in model.transformer.h:
        # To attribute the attention output instead, use
        # layer.attn.c_proj.input[0][0] here.
        residual_out = layer.output[0]
        activations.append(residual_out.save())
        if grads is not None:
            grads.append(residual_out.grad.save())


with model.trace() as tracer:
    with tracer.invoke(clean_correct_tokens):
        _save_residuals(clean_correct_out)

    with tracer.invoke(clean_wrong_tokens):
        _save_residuals(clean_wrong_out)

    # The logits proxies get their own names so that the CPU tensors
    # `corrupted_correct_logits` / `corrupted_wrong_logits` computed earlier
    # are not overwritten (cells that use them stay re-runnable).
    with tracer.invoke(corrupted_correct_tokens):
        _save_residuals(corrupted_correct_out, corrupted_correct_grads)
        corrupted_correct_logits_proxy = model.lm_head.output.save()

    with tracer.invoke(corrupted_wrong_tokens):
        _save_residuals(corrupted_wrong_out, corrupted_wrong_grads)
        corrupted_wrong_logits_proxy = model.lm_head.output.save()

    # Our metric uses tensors saved on cpu, so we
    # need to move the logits to cpu.
    # Only the last two positions are passed in: get_mean_logit_diff drops the
    # last logit row and the first label, so a two-position window scores
    # exactly the final token as predicted from the position before it.
    value = get_mean_logit_diff(
        corrupted_wrong_logits_proxy.cpu()[..., -2:, :],
        corrupted_wrong_tokens[..., -2:],
        corrupted_correct_logits_proxy.cpu()[..., -2:, :],
        corrupted_correct_tokens[..., -2:],
    )
    value.backward()

In [160]:
# Attribution for the sequences ending in the "wrong" answer, one array per
# layer: gradient at the corrupted activation times (clean - corrupted)
# activation, summed over the batch and hidden axes.
patching_results_wrong = []

# Only the gradients of the corrupted-wrong run are needed here; the
# corrupted-correct gradients are used in their own cell.
for corrupted_wrong_grad, corrupted, clean in zip(
    corrupted_wrong_grads, corrupted_wrong_out, clean_wrong_out
):
    residual_attr = einops.reduce(
        corrupted_wrong_grad.value * (clean.value - corrupted.value),
        "batch pos dim -> pos",
        "sum",
    )

    patching_results_wrong.append(
        residual_attr.detach().cpu().numpy()
    )

In [161]:
# Column labels: "<position>. <token>" taken from the first clean-correct sequence.
position_labels = [
    f"{position}. {model.tokenizer.decode(token_id)}"
    for position, token_id in enumerate(clean_correct_tokens[0])
]

fig = px.imshow(
    patching_results_wrong,
    x=position_labels,  # replace x axis with tokens at position
    title="Patching Over Position",
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
)

fig.update_layout(
    xaxis_title="Position",
    yaxis_title="Layer",
    xaxis_tickangle=-45,  # rotate x axis labels
)

fig.show()

In [162]:
# Attribution for the sequences ending in the "correct" answer, one array per
# layer: gradient times (clean - corrupted) activation, summed over the batch
# and hidden axes.
patching_results_correct = [
    einops.reduce(
        grad.value * (clean.value - corrupted.value),
        "batch pos dim -> pos",
        "sum",
    )
    .detach()
    .cpu()
    .numpy()
    for grad, corrupted, clean in zip(
        corrupted_correct_grads, corrupted_correct_out, clean_correct_out
    )
]

In [163]:
# Column labels: "<position>. <token>" taken from the first clean-correct sequence.
first_sequence = clean_correct_tokens[0]
position_labels = []
for position in range(len(first_sequence)):
    position_labels.append(f"{position}. {model.tokenizer.decode(first_sequence[position])}")

heatmap_options = {
    "color_continuous_scale": "RdBu",
    "color_continuous_midpoint": 0.0,
    "title": "Patching Over Position",
    # replace x axis with tokens at position
    "x": position_labels,
}
fig = px.imshow(patching_results_correct, **heatmap_options)

fig.update_layout(
    xaxis_tickangle=-45,  # rotate x axis labels
    xaxis_title="Position",
    yaxis_title="Layer",
)

fig.show()

In [164]:
# Open question from the original analysis: average over gradients or over
# patching results? Here the per-layer patching results are averaged.
patching_results = []
for correct_attr, wrong_attr in zip(patching_results_correct, patching_results_wrong):
    patching_results.append((correct_attr + wrong_attr) / 2.)

# Column labels: "<position>. <token>" taken from the first clean-correct sequence.
position_labels = [
    f"{position}. {model.tokenizer.decode(token_id)}"
    for position, token_id in enumerate(clean_correct_tokens[0])
]

fig = px.imshow(
    patching_results,
    title="Patching Over Position",
    x=position_labels,  # replace x axis with tokens at position
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
)

axis_options = {
    "xaxis_tickangle": -45,  # rotate x axis labels
    "xaxis_title": "Position",
    "yaxis_title": "Layer",
}
fig.update_layout(**axis_options)

fig.show()

## Generation Activation Patching

In [141]:
# NOTE: this cell uses the four-argument ioi_metric defined in the
# "Generation Attribution Patching" section.
N_LAYERS = len(model.transformer.h)

# Cache output[0] of every transformer block on the clean run.
# The clean-correct and clean-wrong sequences are the same prompt followed by
# a different answer, so the clean-correct batch is used as the source of
# clean states for both patched runs.
with torch.no_grad():
    with model.trace() as tracer:
        with tracer.invoke(clean_correct_tokens):
            clean_hs_proxies = [
                model.transformer.h[layer_idx].output[0].save()
                for layer_idx in range(N_LAYERS)
            ]

# fetch actual values from the proxy objects
clean_hs = util.apply(clean_hs_proxies, lambda x: x.value, Proxy)

# patching_results[layer][token] holds the metric after patching that single
# (layer, token position) in both corrupted runs with the clean activation.
patching_results = []
for layer_idx in trange(N_LAYERS, desc="Layer loop"):
    _patching_results = []
    for token_idx in range(clean_correct_tokens.shape[1]):
        with torch.no_grad():
            # Patching corrupted run at given layer and token
            with model.trace() as tracer:
                with tracer.invoke(corrupted_correct_tokens):
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    model.transformer.h[layer_idx].output[0].t[token_idx] = clean_hs[layer_idx][..., token_idx, :]

                    patched_correct_logits = model.lm_head.output.cpu().save()
                with tracer.invoke(corrupted_wrong_tokens):
                    # Apply the patch from the clean hidden states to the corrupted hidden states.
                    model.transformer.h[layer_idx].output[0].t[token_idx] = clean_hs[layer_idx][..., token_idx, :]

                    patched_wrong_logits = model.lm_head.output.cpu().save()

                # Same argument order as CORRUPTED_BASELINE: the "wrong" batch
                # goes in the correct slot and the "correct" batch in the wrong slot.
                patched_result = ioi_metric(
                    patched_wrong_logits,
                    corrupted_wrong_tokens,
                    patched_correct_logits,
                    corrupted_correct_tokens
                ).item().save()
                # Still a proxy at this point; it is resolved to a value in the plotting cell.
                _patching_results.append(patched_result)
    patching_results.append(_patching_results)

Layer loop: 100%|██████████| 12/12 [07:01<00:00, 35.16s/it]


In [143]:
# Replace every saved proxy in the nested result list by its value so the
# grid can be plotted. Entries that are not proxies are left untouched.
patching_results = util.apply(patching_results, lambda x: x.value, Proxy)

# Column labels: "<position>. <token>" taken from the first clean-correct
# sequence. The last position holds the answer token, which differs between
# the runs, so it is shown as "...".
token_labels = [f"{i}. {model.tokenizer.decode(clean_correct_tokens[0][i])}" for i in range(len(clean_correct_tokens[0]))]
token_labels[-1] = "..."

fig = px.imshow(
    patching_results,
    color_continuous_scale="RdBu",
    color_continuous_midpoint=0.0,
    title="Patching Over Position",
    # replace x axis with tokens at position
    x=token_labels
)

fig.update_layout(
    # rotate x axis labels
    xaxis_tickangle=-45,
    xaxis_title="Position",
    yaxis_title="Layer"
)

fig.show()