In [7]:
import torch

# 0. Load the Model

In [52]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel

# Load the tokenizer and the LM-head model from the same pretrained checkpoint
checkpoint_name = "gpt2-xl"
tokenizer = GPT2Tokenizer.from_pretrained(checkpoint_name)
model = GPT2LMHeadModel.from_pretrained(checkpoint_name)

# 1. Factual Prompt

In [53]:
# Clean (factual) prompt whose next token the model is asked to predict
prompt = "The Space Needle is located in the city of"

# Encode the prompt as PyTorch tensors that can be passed straight to the model
inputs = tokenizer(text=prompt, return_tensors="pt")

In [54]:
# # Token IDs
# input_ids = inputs['input_ids'][0]  # Get the token IDs

# # Convert token IDs to actual tokens
# tokens = tokenizer.convert_ids_to_tokens(input_ids)

# # Print the tokens
# print(tokens)

In [55]:
# Baseline inference on the clean prompt; gradients are not needed, so autograd is disabled
with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )

# To inspect the raw output predictions, uncomment:
# print(outputs.logits)

In [56]:
# Capture the hidden states emitted by every transformer block during a clean forward pass.
# After the pass, hidden_states_clean[i] is the hidden-state tensor produced by block i.
hidden_states_clean = []

def hook_fn_clean(module, input, output):
    """Forward hook that records a block's hidden states.

    Appends output[0] (the first element of the block's output tuple) to
    hidden_states_clean. Returns None, so the block's output is left unchanged.
    """
    hidden_states_clean.append(output[0])

# Register the capture hook on every block and run the clean pass.
# The hooks are removed in `finally` so that an exception during the forward pass
# cannot leave them attached to the model, where they would keep appending to
# hidden_states_clean on every later forward pass in the notebook.
hooks_clean = []
try:
    for block in model.transformer.h:
        hooks_clean.append(block.register_forward_hook(hook_fn_clean))

    with torch.no_grad():
        outputs_clean = model(**inputs)
finally:
    for hook in hooks_clean:
        hook.remove()

# hidden_states_clean now contains one entry per hooked block
print(f"Number of layers: {len(hidden_states_clean)}")
print(f"Shape of hidden states from layer 1: {hidden_states_clean[0].shape}")


Number of layers: 48
Shape of hidden states from layer 1: torch.Size([1, 10, 1600])


In [80]:
# Reuse the EOS token as the padding token so the tokenizer call below can use padding=True
tokenizer.pad_token = tokenizer.eos_token

# Input IDs and attention mask for the clean prompt
inputs_with_attention = tokenizer(prompt, return_tensors="pt", padding=True)

# Greedily generate exactly one token after the clean prompt.
# max_new_tokens=1 replaces the former max_length=11, which only meant "one new token"
# because this particular prompt happens to be 10 tokens long.
# early_stopping is omitted: it is a beam-search option and has no effect with num_beams=1.
generated_outputs_clean = model.generate(
    inputs_with_attention.input_ids,
    attention_mask=inputs_with_attention.attention_mask,
    max_new_tokens=1,
    num_beams=1,
    no_repeat_ngram_size=2,
    pad_token_id=tokenizer.eos_token_id  # Explicitly set the pad token to eos token
)

# Decode prompt + generated token and report the last whitespace-separated word
clean_text = tokenizer.decode(generated_outputs_clean[0], skip_special_tokens=True)
print(f"Clean prediction: {clean_text.split()[-1]}")


Clean prediction: Seattle


# 2. Corrupted Prompt

In [73]:
# **Controlled corruption**: Replace "Space Needle" with "Eiffel Tower"
corrupted_prompt = "The Eiffel Tower is located in the city of"

# Tokenize the corrupted prompt
# NOTE(review): the two subjects may tokenize to a different number of tokens, in which case
# token positions after the subject do not line up between the clean and corrupted runs — confirm.
corrupted_inputs = tokenizer(corrupted_prompt, return_tensors="pt")

# After the pass, hidden_states_corrupted[i] is the hidden-state tensor produced by block i
hidden_states_corrupted = []

def hook_fn_corrupted(module, input, output):
    """Forward hook that records a block's hidden states for the corrupted run.

    Appends output[0] (the first element of the block's output tuple) to
    hidden_states_corrupted. Returns None, so the block's output is left unchanged.
    """
    hidden_states_corrupted.append(output[0])

# Register the capture hook on every block and run the corrupted pass.
# The hooks are removed in `finally` so that an exception during the forward pass
# cannot leave them attached to the model for later cells.
hooks_corrupted = []
try:
    for block in model.transformer.h:
        hooks_corrupted.append(block.register_forward_hook(hook_fn_corrupted))

    with torch.no_grad():
        corrupted_outputs = model(**corrupted_inputs)
finally:
    for hook in hooks_corrupted:
        hook.remove()


In [79]:
# Reuse the EOS token as the padding token so the tokenizer call below can use padding=True
tokenizer.pad_token = tokenizer.eos_token

# Input IDs and attention mask for the corrupted prompt
inputs_with_attention = tokenizer(corrupted_prompt, return_tensors="pt", padding=True)

# Greedily generate exactly one token after the corrupted prompt.
# max_new_tokens=1 replaces the hard-coded max_length=12, which depended on the prompt's
# token count. early_stopping is omitted: it is a beam-search option and has no effect
# with num_beams=1.
generated_outputs_corrupted = model.generate(
    inputs_with_attention.input_ids,
    attention_mask=inputs_with_attention.attention_mask,
    max_new_tokens=1,
    num_beams=1,
    no_repeat_ngram_size=2,
    pad_token_id=tokenizer.eos_token_id
)

# Decode prompt + generated token and report the last whitespace-separated word.
# The label now says "Corrupted": this is the prediction for corrupted_prompt, not the clean one.
corrupt_text = tokenizer.decode(generated_outputs_corrupted[0], skip_special_tokens=True)
print(f"Corrupted prediction: {corrupt_text.split()[-1]}")

Clean prediction: Paris


# 3. Restoration

In [86]:
# Decide which layer and token position to restore
layer_to_restore = 12  # index into model.transformer.h (and into hidden_states_clean)
token_to_restore = 1   # sequence position to patch; NOTE(review): presumably the first subject token — confirm against both tokenizations

# Patched hidden states, one entry per forward pass through the hooked block
hidden_states_restored = []

def hook_fn_restoration(module, input, output):
    """Forward hook that patches one position of the hooked block's hidden states.

    Copies the clean run's hidden state for (layer_to_restore, token_to_restore) into a
    clone of this block's hidden states, records the patched tensor in
    hidden_states_restored, and returns a new output tuple so that all downstream
    blocks see the patched activations.
    """
    clean_state = hidden_states_clean[layer_to_restore][0, token_to_restore, :]
    restored_output = output[0].clone()  # clone so the original tensor is not modified in place
    restored_output[0, token_to_restore, :] = clean_state
    hidden_states_restored.append(restored_output)
    return (restored_output, *output[1:])

# Run the corrupted prompt with the restoration hook attached.
# The hook is removed in `finally` so a failing forward pass cannot leave it on the model.
hooks_restoration = []
try:
    hooks_restoration.append(
        model.transformer.h[layer_to_restore].register_forward_hook(hook_fn_restoration)
    )
    with torch.no_grad():
        restored_outputs = model(**corrupted_inputs)
finally:
    for hook in hooks_restoration:
        hook.remove()

# Take the prediction from the *patched* forward pass. Previously model.generate() was
# called after the hook had been removed, so the printed prediction was always the
# plain corrupted prediction and the restoration had no visible effect.
next_token_id = restored_outputs.logits[0, -1].argmax()  # greedy choice at the last position

# Prompt ids followed by the predicted token, shaped like the output of a
# one-token greedy generate() call
generated_outputs_restored = torch.cat(
    [corrupted_inputs["input_ids"], next_token_id.reshape(1, 1)], dim=-1
)

# Decode and report the last whitespace-separated word
restored_text = tokenizer.decode(generated_outputs_restored[0], skip_special_tokens=True)
print(f"Restored prediction: {restored_text.split()[-1]}")

Clean prediction: Paris


In [84]:
# Reuse the EOS token as the padding token so the tokenizer call below can use padding=True
tokenizer.pad_token = tokenizer.eos_token

# Input IDs and attention mask for the corrupted prompt
inputs_with_attention = tokenizer(corrupted_prompt, return_tensors="pt", padding=True)

# NOTE: this cell registers no hooks, and the cells above remove theirs after each pass,
# so this generation runs on the unmodified model. Despite the `restored` variable names
# it yields the plain corrupted-prompt prediction, not a restored one.
# max_new_tokens=1 replaces the hard-coded max_length=12, which depended on the prompt's
# token count. early_stopping is omitted: it is a beam-search option and has no effect
# with num_beams=1.
generated_outputs_restored = model.generate(
    inputs_with_attention.input_ids,
    attention_mask=inputs_with_attention.attention_mask,
    max_new_tokens=1,
    num_beams=1,
    no_repeat_ngram_size=2,
    pad_token_id=tokenizer.eos_token_id
)

# Decode prompt + generated token and report the last whitespace-separated word.
# The label is corrected: the input is corrupted_prompt, so this is not a "Clean" prediction.
restored_text = tokenizer.decode(generated_outputs_restored[0], skip_special_tokens=True)
print(f"Corrupted prediction: {restored_text.split()[-1]}")

Clean prediction: Paris


# 4. Loop over Layers and Token Positions

In [87]:
# Try different layers and token indices for restoration
for layer in range(10, 20):  # layers 10 through 19 (the upper bound is exclusive)
    # NOTE(review): positions 1 and 2 are presumably the first two subject tokens; the subject
    # may span more tokens than that — confirm with tokenizer.convert_ids_to_tokens.
    for token_idx in [1, 2]:
        print(f"Testing layer: {layer}, token index: {token_idx}")

        def hook_fn_restoration(module, input, output):
            """Overwrite position token_idx of this block's hidden states with the clean run's value.

            Reads `layer` and `token_idx` from the enclosing loop; the hook is only
            attached for the duration of the current iteration.
            """
            clean_state = hidden_states_clean[layer][0, token_idx, :]
            restored_output = output[0].clone()  # clone so the original tensor is not modified in place
            restored_output[0, token_idx, :] = clean_state
            return (restored_output, *output[1:])

        # Run the corrupted prompt with the restoration hook attached.
        # The hook is removed in `finally` so a failing pass cannot leave it on the model
        # and contaminate the following iterations.
        hooks_restoration = []
        try:
            hooks_restoration.append(model.transformer.h[layer].register_forward_hook(hook_fn_restoration))
            with torch.no_grad():
                restored_outputs = model(**corrupted_inputs)
        finally:
            for hook in hooks_restoration:
                hook.remove()

        # Take the prediction from the *patched* forward pass. Previously model.generate()
        # was called after the hook had been removed, so every iteration reported the
        # unpatched corrupted prediction regardless of layer and token index.
        next_token_id = restored_outputs.logits[0, -1].argmax()  # greedy choice at the last position

        # Prompt ids followed by the predicted token, shaped like the output of a
        # one-token greedy generate() call
        generated_outputs_restored = torch.cat(
            [corrupted_inputs["input_ids"], next_token_id.reshape(1, 1)], dim=-1
        )

        # Decode and report the last whitespace-separated word
        restored_text = tokenizer.decode(generated_outputs_restored[0], skip_special_tokens=True)
        print(f"Restored prediction: {restored_text.split()[-1]}")


Testing layer: 10, token index: 1
Restored prediction: Paris
Testing layer: 10, token index: 2
Restored prediction: Paris
Testing layer: 11, token index: 1
Restored prediction: Paris
Testing layer: 11, token index: 2
Restored prediction: Paris
Testing layer: 12, token index: 1
Restored prediction: Paris
Testing layer: 12, token index: 2
Restored prediction: Paris
Testing layer: 13, token index: 1
Restored prediction: Paris
Testing layer: 13, token index: 2
Restored prediction: Paris
Testing layer: 14, token index: 1
Restored prediction: Paris
Testing layer: 14, token index: 2
Restored prediction: Paris
Testing layer: 15, token index: 1
Restored prediction: Paris
Testing layer: 15, token index: 2
Restored prediction: Paris
Testing layer: 16, token index: 1
Restored prediction: Paris
Testing layer: 16, token index: 2
Restored prediction: Paris
Testing layer: 17, token index: 1
Restored prediction: Paris
Testing layer: 17, token index: 2
Restored prediction: Paris
Testing layer: 18, token