In [3]:
!pip install torch transformers numpy



In [52]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Load the tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Set model in evaluation mode
model.eval()

# Example input text
input_text = "What is the capital of France? A) London B) Paris C) Berlin D) Rome"

# Tokenize input
input_ids = tokenizer.encode(input_text, return_tensors='pt')




In [53]:
# Get model output (logits)
with torch.no_grad():
    outputs = model(input_ids)
logits = outputs.logits

# Get the logits for the last token (this is the "answer")
final_logits = logits[:, -1, :]

In [54]:
final_logits

tensor([[-78.3749, -80.8012, -83.3722,  ..., -94.4622, -91.0560, -77.0115]])

In [55]:
# list(model.named_modules())

In [57]:
activations = {}

# hook function to accept the layer/module name
def create_hook_fn(name):
    def hook_fn(module, input, output):
        # Store the activations using the passed 'name'
        activations[name] = output
    return hook_fn

# Register hooks for attention and MLP layers with the layer name
for name, module in model.named_modules():
    if 'attn' in name or 'mlp' in name:
        module.register_forward_hook(create_hook_fn(name))

# Run the forward pass again with hooks activated
with torch.no_grad():
    outputs = model(input_ids)

In [58]:
len(activations)

240

In [59]:
# # Print the full model structure to identify attention and MLP layers
# for name, module in model.named_modules():
#     print(name)

In [62]:
# activations

In [64]:
# print(activations.keys())

In [66]:
# Calculate direct contribution of a component
def compute_direct_contribution(model, input_ids, activations, component_name):
    # Backup the original activation
    original_activation = activations[component_name].clone()

    # Set activation to zero
    activations[component_name].zero_()

    # Forward pass with zeroed-out component
    with torch.no_grad():
        outputs = model(input_ids)
        logits_zeroed = outputs.logits[:, -1, :]

    # Compute the difference in logits (direct effect)
    direct_effect = logits - logits_zeroed

    # Restore the original activation
    activations[component_name] = original_activation

    return direct_effect


In [70]:
# New input for patching
patch_input_text = "What is the capital of Germany? A) London B) Berlin C) Paris D) Rome"
patch_input_ids = tokenizer.encode(patch_input_text, return_tensors='pt')

# Get patch activations
with torch.no_grad():
    patch_outputs = model(patch_input_ids)

def activation_patching(model, input_ids, activations, patch_activations, component_name):
    # Check if activation is a tuple and get the first element
    original_activation = activations[component_name]
    if isinstance(original_activation, tuple):
        original_activation = original_activation[0]  # Extract the first element (tensor)

    # Clone the activation
    original_activation = original_activation.clone()

    # Replace activation with the one from a different prompt (also handle tuple case)
    patched_activation = patch_activations[component_name]
    if isinstance(patched_activation, tuple):
        patched_activation = patched_activation[0]  # Extract the tensor

    # Patch the activation
    activations[component_name] = patched_activation.clone()

    # Run the forward pass with patched activations
    with torch.no_grad():
        outputs = model(input_ids)
        patched_logits = outputs.logits[:, -1, :]

    # Restore original activation
    activations[component_name] = original_activation

    return patched_logits


# Compute patched logits for each component
patched_logits = activation_patching(model, input_ids, activations, patch_outputs, 'transformer.h.0.attn')


KeyError: 'transformer.h.0.attn'

In [11]:
# Print all named modules to find the correct name
for name, module in model.named_modules():
    print(name)


transformer
transformer.wte
transformer.wpe
transformer.drop
transformer.h
transformer.h.0
transformer.h.0.ln_1
transformer.h.0.attn
transformer.h.0.attn.c_attn
transformer.h.0.attn.c_proj
transformer.h.0.attn.attn_dropout
transformer.h.0.attn.resid_dropout
transformer.h.0.ln_2
transformer.h.0.mlp
transformer.h.0.mlp.c_fc
transformer.h.0.mlp.c_proj
transformer.h.0.mlp.act
transformer.h.0.mlp.dropout
transformer.h.1
transformer.h.1.ln_1
transformer.h.1.attn
transformer.h.1.attn.c_attn
transformer.h.1.attn.c_proj
transformer.h.1.attn.attn_dropout
transformer.h.1.attn.resid_dropout
transformer.h.1.ln_2
transformer.h.1.mlp
transformer.h.1.mlp.c_fc
transformer.h.1.mlp.c_proj
transformer.h.1.mlp.act
transformer.h.1.mlp.dropout
transformer.h.2
transformer.h.2.ln_1
transformer.h.2.attn
transformer.h.2.attn.c_attn
transformer.h.2.attn.c_proj
transformer.h.2.attn.attn_dropout
transformer.h.2.attn.resid_dropout
transformer.h.2.ln_2
transformer.h.2.mlp
transformer.h.2.mlp.c_fc
transformer.h.2.mlp