In [13]:
import sys
import os

# Make the directory two levels above the current working directory
# importable, appending it only if it is not already on sys.path.
module_path = os.path.abspath(os.path.join(os.pardir, os.pardir))
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

# Exported as an environment variable before any tokenizer is created.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [14]:
import einops
import torch
import plotly.express as px
from profiler import Profiler
from nnsight import LanguageModel

In [15]:
profiler = Profiler()

# List what the profiler reports as available: one line per entry, showing
# the entry's key and the value stored under it.
# NOTE(review): assumes available_devices() returns a dict-like mapping of
# display name -> device string (it is only used via .items()) — confirm.
devices = profiler.available_devices()
for device_name, device_str in devices.items():
    print(f"Device: {device_name}, String: {device_str}")

Device: cpu, String: cpu
Device: MPS, String: mps


In [16]:
# Wrap the Hugging Face GPT-2 checkpoint in an nnsight LanguageModel.
# NOTE(review): device_map="auto" / dispatch=True presumably place and load
# the weights right away — confirm against the installed nnsight version.
model = LanguageModel(
    "openai-community/gpt2",
    device_map="auto",
    dispatch=True,
)


`resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.



In [17]:
# Sanity check: show the device the loaded model reports.
print(model.device)

mps:0


In [18]:
# Prompts come in adjacent pairs (0/1, 2/3, ...) that differ only in which
# of the two names is repeated as the subject of the second clause.
prompts = [
    "When John and Mary went to the shops, John gave the bag to",
    "When John and Mary went to the shops, Mary gave the bag to",
    "When Tom and James went to the park, James gave the ball to",
    "When Tom and James went to the park, Tom gave the ball to",
    "When Dan and Sid went to the shops, Sid gave an apple to",
    "When Dan and Sid went to the shops, Dan gave an apple to",
    "After Martin and Amy went to the park, Amy gave a drink to",
    "After Martin and Amy went to the park, Martin gave a drink to",
]

# One (correct, incorrect) completion pair per prompt, in prompt order.
answers = [
    (" Mary", " John"),
    (" John", " Mary"),
    (" Tom", " James"),
    (" James", " Tom"),
    (" Dan", " Sid"),
    (" Sid", " Dan"),
    (" Martin", " Amy"),
    (" Amy", " Martin"),
]

# No padding is requested, so this relies on every prompt tokenizing to the
# same length.
clean_tokens = model.tokenizer(prompts, return_tensors="pt")["input_ids"]

# Swap the rows of each adjacent pair (0<->1, 2<->3, ...): flipping the
# lowest bit of the row index maps even i to i + 1 and odd i to i - 1.
corrupted_tokens = clean_tokens[[i ^ 1 for i in range(len(clean_tokens))]]

# Only the first token id of each answer string is kept.
answer_token_indices = torch.tensor(
    [
        [model.tokenizer(candidate)["input_ids"][0] for candidate in pair[:2]]
        for pair in answers
    ]
)

In [19]:
@profiler.log_profile
def get_logit_diff(logits, answer_token_indices=answer_token_indices):
    """Return the mean logit gap between the two answer columns.

    Args:
        logits: Either a 2-D object indexed as (row, vocab) or a 3-D one, in
            which case only the last index along the middle axis is used.
        answer_token_indices: 2-D index tensor; column 0 selects the logit
            that is subtracted from, column 1 the logit that is subtracted.
            The default is the notebook-level tensor as it existed when this
            function was defined.

    Returns:
        The mean over rows of ``logit[col 0] - logit[col 1]``.
    """
    # Reduce a 3-D input to its final position along axis 1.
    if len(logits.shape) == 3:
        logits = logits[:, -1, :]

    # gather() needs index tensors with a trailing singleton axis.
    first_column = answer_token_indices[:, 0].unsqueeze(1)
    picked_first = logits.gather(1, first_column)
    second_column = answer_token_indices[:, 1].unsqueeze(1)
    picked_second = logits.gather(1, second_column)

    gap = picked_first - picked_second
    return gap.mean()

@profiler.log_profile
def trace_model(tokens):
    """Call ``model.trace`` on ``tokens`` with ``trace=False`` and return the
    ``.logits`` of the result after moving them to the CPU."""
    model_output = model.trace(tokens, trace=False)
    return model_output.logits.cpu()

@profiler.log_profile
def compute_baselines():
    """Compute and print the logit-diff baselines for both prompt sets.

    Runs ``trace_model`` on ``clean_tokens`` and then ``corrupted_tokens``,
    evaluates ``get_logit_diff`` on each result against
    ``answer_token_indices``, and prints both values.

    Returns:
        tuple: ``(clean_baseline, corrupted_baseline, clean_logits,
        corrupted_logits)`` — the baselines are Python floats obtained via
        ``.item()``; the logits are whatever ``trace_model`` returned.
    """
    # Both forward passes happen before either metric is evaluated.
    logits_clean, logits_corrupted = (
        trace_model(batch) for batch in (clean_tokens, corrupted_tokens)
    )

    baseline_clean = get_logit_diff(logits_clean, answer_token_indices).item()
    print(f"Clean logit diff: {baseline_clean:.4f}")

    baseline_corrupted = get_logit_diff(logits_corrupted, answer_token_indices).item()
    print(f"Corrupted logit diff: {baseline_corrupted:.4f}")

    return baseline_clean, baseline_corrupted, logits_clean, logits_corrupted


clean_baseline, corrupted_baseline, clean_logits, corrupted_logits = compute_baselines()

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Memory used by trace_model: 17.78 MB
Peak memory usage during trace_model: 37.47 MB
Time taken by trace_model: 0.5664 seconds
Memory used by trace_model: -11.05 MB
Peak memory usage during trace_model: 0.94 MB
Time taken by trace_model: 0.7596 seconds
Memory used by get_logit_diff: -3.14 MB
Peak memory usage during get_logit_diff: 0.72 MB
Time taken by get_logit_diff: 1.2195 seconds
Clean logit diff: 2.8138
Memory used by get_logit_diff: 0.23 MB
Peak memory usage during get_logit_diff: 0.23 MB
Time taken by get_logit_diff: 1.1589 seconds
Corrupted logit diff: -2.8138
Memory used by compute_baselines: 8.89 MB
Peak memory usage during compute_baselines: 37.67 MB
Memory used by trace_model: 7.02 MB
Peak memory usage during trace_model: 7.02 MB
Time taken by trace_model: 0.7521 seconds
Memory used by trace_model: 1.06 MB
Peak memory usage during trace_model: 1.06 MB
Time taken by trace_model: 0.7303 seconds
Memory used by get_logit_diff: 0.00 MB
Peak memory usage during get_logit_diff: 0.0

In [20]:
@profiler.log_profile
def ioi_metric(logits, answer_token_indices=answer_token_indices):
    """Rescale the logit difference so the two baselines map to 0 and 1.

    Computes ``(diff - corrupted_baseline) / (clean_baseline -
    corrupted_baseline)``, where ``diff`` is ``get_logit_diff`` of the
    arguments. A logit difference equal to the corrupted baseline yields 0
    and one equal to the clean baseline yields 1. Both baselines are read
    from the notebook globals each time the function is called.
    """
    baseline_span = clean_baseline - corrupted_baseline
    shifted = get_logit_diff(logits, answer_token_indices) - corrupted_baseline
    return shifted / baseline_span


# Sanity check: the metric should reproduce its own anchor points.
print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")


Memory used by get_logit_diff: 0.00 MB
Peak memory usage during get_logit_diff: 0.00 MB
Time taken by get_logit_diff: 1.1552 seconds
Memory used by ioi_metric: 0.06 MB
Peak memory usage during ioi_metric: 0.06 MB
Memory used by get_logit_diff: 0.00 MB
Peak memory usage during get_logit_diff: 0.00 MB
Time taken by get_logit_diff: 1.1683 seconds
Time taken by ioi_metric: 2.5281 seconds
Clean Baseline is 1: 1.0000
Memory used by get_logit_diff: 0.00 MB
Peak memory usage during get_logit_diff: 0.00 MB
Time taken by get_logit_diff: 1.1896 seconds
Memory used by ioi_metric: 0.00 MB
Peak memory usage during ioi_metric: 0.00 MB
Memory used by get_logit_diff: 0.38 MB
Peak memory usage during get_logit_diff: 0.38 MB
Time taken by get_logit_diff: 1.1886 seconds
Time taken by ioi_metric: 2.5721 seconds
Corrupted Baseline is 0: 0.0000


# Over Components

In [21]:
@profiler.log_profile
def trace_attention():
    """Collect per-layer attention activations and gradients in one trace.

    Opens a single ``model.trace()`` with two invocations, the first on
    ``clean_tokens`` and the second on ``corrupted_tokens``. For every block
    in ``model.transformer.h`` it saves ``layer.attn.c_proj.input[0][0]``;
    in the corrupted invocation it also saves that value's ``.grad``. The
    gradients come from calling ``backward()`` on ``ioi_metric`` evaluated
    on a CPU copy of the corrupted run's ``model.lm_head.output``.

    Returns:
        tuple: ``(clean_out, corrupted_out, corrupted_grads)`` — three lists
        holding one saved entry per block, in ``model.transformer.h`` order.

    NOTE(review): ``input[0][0]`` is presumably the first positional argument
    of ``c_proj`` (the concatenated per-head outputs before projection) —
    confirm against the installed nnsight version.
    NOTE(review): the recorded output shows this function's inner log lines
    twice, which suggests ``log_profile`` executes the body more than once —
    confirm, since that would double the cost of this trace.
    """
    clean_out = []
    corrupted_out = []
    corrupted_grads = []

    with model.trace() as tracer:

        # Clean run: only the activations are saved.
        # (`invoker_clean` is bound but never used in this function.)
        with tracer.invoke(clean_tokens) as invoker_clean:
            for layer in model.transformer.h:
                attn_out = layer.attn.c_proj.input[0][0]
                clean_out.append(attn_out.save())

        # Corrupted run: save the activations and their gradients.
        # (`invoker_corrupted` is bound but never used in this function.)
        with tracer.invoke(corrupted_tokens) as invoker_corrupted:
            for layer in model.transformer.h:
                attn_out = layer.attn.c_proj.input[0][0]
                corrupted_out.append(attn_out.save())
                corrupted_grads.append(attn_out.grad.save())

            # The metric is evaluated on a CPU copy of the logits
            # (presumably because answer_token_indices lives on the CPU —
            # TODO confirm); backward() is what produces the saved .grad
            # values above.
            logits = model.lm_head.output.save()
            value = ioi_metric(logits.cpu())
            value.backward()

    return clean_out, corrupted_out, corrupted_grads

clean_out, corrupted_out, corrupted_grads = trace_attention()


Memory used by get_logit_diff: 0.45 MB
Peak memory usage during get_logit_diff: 0.45 MB
Time taken by get_logit_diff: 0.5507 seconds
Memory used by ioi_metric: 0.59 MB
Peak memory usage during ioi_metric: 0.59 MB
Memory used by get_logit_diff: 0.12 MB
Peak memory usage during get_logit_diff: 0.12 MB
Time taken by get_logit_diff: 0.5698 seconds
Time taken by ioi_metric: 1.3142 seconds
Memory used by trace_attention: 76.56 MB
Peak memory usage during trace_attention: 76.56 MB
Memory used by get_logit_diff: 4.08 MB
Peak memory usage during get_logit_diff: 4.08 MB
Time taken by get_logit_diff: 0.5740 seconds
Memory used by ioi_metric: 8.31 MB
Peak memory usage during ioi_metric: 8.31 MB
Memory used by get_logit_diff: 2.11 MB
Peak memory usage during get_logit_diff: 2.11 MB
Time taken by get_logit_diff: 0.5951 seconds
Time taken by ioi_metric: 1.3889 seconds
Time taken by trace_attention: 3.8859 seconds


In [22]:
@profiler.log_profile
def patch_attention_heads(n_heads=12):
    """Plot an attribution score for every attention head of every layer.

    For each layer, the saved corrupted-run gradient is multiplied
    elementwise by the (clean - corrupted) activation difference, using only
    the last index along axis 1. The product is summed over the batch axis
    and over each head's slice of the feature axis, giving one score per
    head. The per-layer scores are shown as a layer x head heatmap.

    Reads the notebook globals ``corrupted_grads``, ``corrupted_out`` and
    ``clean_out`` and accesses each entry through ``.value``.

    Args:
        n_heads: Number of equal-width groups the feature axis is split
            into. Defaults to 12, the value that used to be hard-coded; the
            per-head width is inferred by einops from the feature size, so
            the feature size must be divisible by ``n_heads``.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    patching_results = []

    for corrupted_grad, corrupted, clean in zip(
        corrupted_grads, corrupted_out, clean_out
    ):
        # Activation difference at the final position only.
        last_pos_delta = clean.value[:, -1, :] - corrupted.value[:, -1, :]
        head_attr = einops.reduce(
            corrupted_grad.value[:, -1, :] * last_pos_delta,
            "batch (head dim) -> head",
            "sum",
            head=n_heads,  # `dim` is inferred from the feature-axis size
        )
        patching_results.append(head_attr.detach().cpu().numpy())

    fig = px.imshow(
        patching_results,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,
        title="Patching Over Attention Heads"
    )
    fig.update_layout(
        xaxis_title="Head",
        yaxis_title="Layer"
    )
    fig.show()

patch_attention_heads()

Memory used by patch_attention_heads: -44.05 MB
Peak memory usage during patch_attention_heads: -44.05 MB


Time taken by patch_attention_heads: 0.5163 seconds


# Over Position

In [23]:
@profiler.log_profile
def patch_positions():
    """Plot an attribution score for every sequence position of every layer.

    For each layer, the saved corrupted-run gradient is multiplied
    elementwise by the (clean - corrupted) activation difference and summed
    over the first and last axes, leaving one score per index of the middle
    axis. The per-layer scores are shown as a layer x position heatmap.

    Reads the notebook globals ``corrupted_grads``, ``corrupted_out`` and
    ``clean_out`` and accesses each entry through ``.value``. Returns None;
    the figure is displayed with ``show()``.
    """
    per_layer_scores = [
        einops.reduce(
            grad.value * (clean_act.value - corrupted_act.value),
            "batch pos dim -> pos",
            "sum",
        )
        .detach()
        .cpu()
        .numpy()
        for grad, corrupted_act, clean_act in zip(
            corrupted_grads, corrupted_out, clean_out
        )
    ]

    heatmap = px.imshow(
        per_layer_scores,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=0.0,
        title="Patching Over Position",
    )
    heatmap.update_layout(xaxis_title="Position", yaxis_title="Layer")
    heatmap.show()


patch_positions()

Memory used by patch_positions: -28.77 MB
Peak memory usage during patch_positions: 1.31 MB


Time taken by patch_positions: 0.4708 seconds
