In [None]:
import plotly.express as px
import torch
from nnsight import LanguageModel, util
from nnsight.models import NNsightModel
from nnsight.tracing.Proxy import Proxy
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs (more slowly) on machines without a GPU instead of failing at load time.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
model = LanguageModel('gpt2', device_map=DEVICE)


torch.utils._pytree._register_pytree_node is deprecated. Please use torch.utils._pytree.register_pytree_node instead.



In [None]:
# Generate exactly one new token for the prompt and keep the LM head's output
# around after the run by calling .save() on it.
with model.generate(max_new_tokens=1) as generator, generator.invoke(
    "The famous computer scientist Alan"
) as invoker:
    logits = model.lm_head.output.save()

# First batch element, last position: the scores used to pick the next token.
first_sequence_logits = logits.value[0]
next_token_logits = first_sequence_logits[-1]

# Highest-scoring token id, decoded back to text.
next_token_prediction = next_token_logits.argmax()
next_word_prediction = model.tokenizer.decode(next_token_prediction)
print(next_word_prediction)

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


 Turing


In [None]:
# Display the model's repr to inspect its module hierarchy; the module names
# shown here (transformer.h, lm_head, ...) are the ones accessed in later cells.
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [None]:
prompt = 'The famous computer scientist Alan'

# Trace a single forward pass over the prompt. For every transformer block,
# collect a pair of proxies: (block output element 1, block output element 0).
with model.forward() as runner, runner.invoke(prompt) as invoker:
    clean_tokens = invoker.input['input_ids'][0]
    layer_indices = range(len(model.transformer.h))
    clean_hs = [
        (model.transformer.h[i].output[1], model.transformer.h[i].output[0])
        for i in layer_indices
    ]

# Print the first proxy of each pair next to the shape of the second one.
for pair in clean_hs:
    print(pair[0], pair[1].shape)


<nnsight.intervention.InterventionProxy object at 0x7fb9033178d0> torch.Size([1, 5, 768])
<nnsight.intervention.InterventionProxy object at 0x7fb9047932d0> torch.Size([1, 5, 768])
<nnsight.intervention.InterventionProxy object at 0x7fb910b12b50> torch.Size([1, 5, 768])
<nnsight.intervention.InterventionProxy object at 0x7fb900b0ca50> torch.Size([1, 5, 768])
<nnsight.intervention.InterventionProxy object at 0x7fb900bd7750> torch.Size([1, 5, 768])
<nnsight.intervention.InterventionProxy object at 0x7fb903de1cd0> torch.Size([1, 5, 768])
<nnsight.intervention.InterventionProxy object at 0x7fb900bd5f90> torch.Size([1, 5, 768])
<nnsight.intervention.InterventionProxy object at 0x7fb900bd6190> torch.Size([1, 5, 768])
<nnsight.intervention.InterventionProxy object at 0x7fb900bd6e90> torch.Size([1, 5, 768])
<nnsight.intervention.InterventionProxy object at 0x7fb900bdbc50> torch.Size([1, 5, 768])
<nnsight.intervention.InterventionProxy object at 0x7fb904788050> torch.Size([1, 5, 768])
<nnsight.i

In [None]:
# Activation patching experiment on an indirect-object-identification (IOI) style prompt pair.
# The two prompts are identical except for the subject of the second clause ("Mary" vs "John").
# For every (layer, token position) pair, the corrupted run is re-executed with that one
# hidden state replaced by the value from the clean run, and the effect on the
# " John" - " Mary" logit difference at the last position is recorded.
clean_prompt = "John and Mary went to the store, Mary gave a bottle to"
corrupted_prompt = "John and Mary went to the store, John gave a bottle to"


# First token id produced by the tokenizer for each candidate answer (note the leading space).
# NOTE(review): this assumes " John" and " Mary" each encode to a single token — confirm.
correct_index = model.tokenizer(" John")["input_ids"][0]
incorrect_index = model.tokenizer(" Mary")["input_ids"][0]


# Enter nnsight tracing context
# All invokes below (clean, corrupted, and one per patched layer/position) are
# declared inside this single `model.forward()` context.
with model.forward() as runner:

    # Clean run
    with runner.invoke(clean_prompt) as invoker:
        # Token ids of the clean prompt (first batch element); used below for the
        # number of positions to patch and, in the next cell, for the plot labels.
        clean_tokens = invoker.input["input_ids"][0]

        # Get hidden states of all layers in the network.
        # We index the output at 0 because it's a tuple where the first index is the hidden state.
        # No need to call .save() as we don't need the values after the run, just within the experiment run.
        clean_hs = [
            model.transformer.h[layer_idx].output[0]
            for layer_idx in range(len(model.transformer.h))
        ]

        # Get logits from the lm_head.
        clean_logits = model.lm_head.output

        # Calculate the difference between the correct answer and incorrect answer for the clean run and save it.
        # Indexing [0, -1, idx] selects batch element 0, last position, vocabulary entry idx.
        clean_logit_diff = (
            clean_logits[0, -1, correct_index] - clean_logits[0, -1, incorrect_index]
        ).save()

    # Corrupted run
    with runner.invoke(corrupted_prompt) as invoker:
        corrupted_logits = model.lm_head.output

        # Calculate the difference between the correct answer and incorrect answer for the corrupted run and save it.
        corrupted_logit_diff = (
            corrupted_logits[0, -1, correct_index]
            - corrupted_logits[0, -1, incorrect_index]
        ).save()

    # Nested result list: one inner list per layer, one saved proxy per token position.
    ioi_patching_results = []

    # Iterate through all the layers
    for layer_idx in range(len(model.transformer.h)):
        _ioi_patching_results = []

        # Iterate through all tokens
        # NOTE(review): positions come from the clean prompt's length; this assumes the
        # corrupted prompt tokenizes to the same number of tokens — confirm.
        for token_idx in range(len(clean_tokens)):

            # Patching corrupted run at given layer and token
            with runner.invoke(corrupted_prompt) as invoker:

                # Apply the patch from the clean hidden states to the corrupted hidden states.
                # NOTE(review): `.t[token_idx]` is presumably nnsight's token-position
                # indexer on the proxy — confirm against the nnsight documentation.
                model.transformer.h[layer_idx].output[0].t[token_idx] = clean_hs[
                    layer_idx
                ].t[token_idx]

                patched_logits = model.lm_head.output

                patched_logit_diff = (
                    patched_logits[0, -1, correct_index]
                    - patched_logits[0, -1, incorrect_index]
                )

                # Calculate the improvement in the correct token after patching.
                # Normalised so that 0 means "same as the corrupted run" and 1 means
                # "same as the clean run".
                patched_result = (patched_logit_diff - corrupted_logit_diff) / (
                    clean_logit_diff - corrupted_logit_diff
                )

                # .save() keeps the value available after the tracing context exits.
                _ioi_patching_results.append(patched_result.save())

        ioi_patching_results.append(_ioi_patching_results)


You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [None]:
# Report the baseline logit differences saved during the patching experiment.
print(f"Clean logit difference: {clean_logit_diff.value:.3f}")
print(f"Corrupted logit difference: {corrupted_logit_diff.value:.3f}")

# Resolve every saved proxy in the nested (layer x position) list to a plain Python
# number. The result is bound to a NEW name so the original proxies stay intact
# and this cell can be re-run without depending on its own previous execution.
ioi_patching_values = util.apply(
    ioi_patching_results, lambda x: x.value.item(), Proxy
)

# Decode each clean-prompt token id into text for the x-axis labels. The token ids
# in `clean_tokens` are deliberately left untouched: overwriting them with strings
# would make a second run of this cell try to decode already-decoded text.
clean_token_strs = [model.tokenizer.decode(token) for token in clean_tokens]
# Suffix each label with its position so repeated tokens remain distinguishable.
token_labels = [f"{token}_{index}" for index, token in enumerate(clean_token_strs)]

# Heatmap of the normalized logit difference: rows are layers, columns are token positions.
fig = px.imshow(
    ioi_patching_values,
    color_continuous_midpoint=0.0,
    color_continuous_scale="RdBu",
    labels={"x": "Position", "y": "Layer"},
    x=token_labels,
    title="Normalized Logit Difference After Patching Residual Stream on the IOI Task",
)

fig.show()

Clean logit difference: -0.495
Corrupted logit difference: -4.687
