In [1]:
import torch

# 1. Load the Model

In [57]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2-xl")
model = GPT2LMHeadModel.from_pretrained("gpt2-xl")



# 1 . Factual Prompt

In [58]:
# The factual prompt
clean_prompt = "The Space Needle is located in the city of"

# Tokenize the input prompt
inputs = tokenizer(clean_prompt, return_tensors="pt")
# inputs

In [59]:
# # Token IDs
# input_ids = inputs['input_ids'][0]  # Get the token IDs

# # Convert token IDs to actual tokens
# tokens = tokenizer.convert_ids_to_tokens(input_ids)

# # Print the tokens
# print(tokens)

In [60]:
# Forward pass through the model to get the outputs
with torch.no_grad():
    outputs = model(**inputs)

# # Print the logits (output predictions)
# print(outputs.logits)

In [61]:
# Function to hook and capture only the hidden states (first element of the output tuple)
hidden_states_clean = []

# Hook function to capture clean hidden states
def hook_fn_clean(module, input, output):
    hidden_states_clean.append(output[0])

# Register hooks to capture hidden states for each layer
hooks_clean = []
for i in range(model.config.n_layer):
    hooks_clean.append(model.transformer.h[i].register_forward_hook(hook_fn_clean))

# Run the clean model pass
with torch.no_grad():
    outputs_clean = model(**inputs)

# Remove hooks after the clean run
for hook in hooks_clean:
    hook.remove()

# Now hidden_states contains activations for all layers
print(f"Number of layers: {len(hidden_states_clean)}")
print(f"Shape of hidden states from layer 1: {hidden_states_clean[0].shape}")


Number of layers: 48
Shape of hidden states from layer 1: torch.Size([1, 10, 1600])


In [62]:
# Set pad_token as eos_token
tokenizer.pad_token = tokenizer.eos_token

# Get the input IDs and attention mask for the clean prompt
inputs_with_attention = tokenizer(clean_prompt, return_tensors="pt", padding=True)

# Generate output for the clean run with attention mask
generated_outputs_clean = model.generate(
    inputs_with_attention.input_ids,
    attention_mask=inputs_with_attention.attention_mask,
    max_length=11,
    num_beams=1,
    no_repeat_ngram_size=2,
    early_stopping=True,
    pad_token_id=tokenizer.eos_token_id  # Explicitly set the pad token to eos token
)

# Decode the generated output
clean_text = tokenizer.decode(generated_outputs_clean[0], skip_special_tokens=True)
print(f"Clean prediction: {clean_text.split()[-1]}")




Clean prediction: Seattle


# 2 . Corrupted Prompt

In [63]:
# **Controlled corruption**: Replace "Space Needle" with "Eiffel Tower"
corrupted_prompt = "The Eiffel Tower is located in the city of"

# Tokenize the corrupted prompt
corrupted_inputs = tokenizer(corrupted_prompt, return_tensors="pt")

# Initialize list to store hidden states from the corrupted run
hidden_states_corrupted = []

# Hook function to capture corrupted hidden states
def hook_fn_corrupted(module, input, output):
    hidden_states_corrupted.append(output[0])

# Register hooks to capture hidden states for each layer during the corrupted run
hooks_corrupted = []
for i in range(model.config.n_layer):
    hooks_corrupted.append(model.transformer.h[i].register_forward_hook(hook_fn_corrupted))

# Run the corrupted model pass and collect activations
with torch.no_grad():
    corrupted_outputs = model(**corrupted_inputs)

# Remove hooks after the corrupted run
for hook in hooks_corrupted:
    hook.remove()


In [64]:
# Set pad_token as eos_token
tokenizer.pad_token = tokenizer.eos_token

# Get the input IDs and attention mask for the corrupt prompt
inputs_with_attention = tokenizer(corrupted_prompt, return_tensors="pt", padding=True)

# Generate output for the corrupt run with attention mask
generated_outputs_corrupted = model.generate(
    inputs_with_attention.input_ids,
    attention_mask=inputs_with_attention.attention_mask,
    max_length=12,
    num_beams=1,
    no_repeat_ngram_size=2,
    early_stopping=True,
    pad_token_id=tokenizer.eos_token_id
)

# Decode the generated output
corrupt_text = tokenizer.decode(generated_outputs_corrupted[0], skip_special_tokens=True)
print(f"Corrupted prediction: {corrupt_text.split()[-1]}")

Corrupted prediction: Paris


# 3 . Restoration

In [65]:
tokenized_input = tokenizer.decode(inputs_with_attention.input_ids[0], skip_special_tokens=False)
decoded_tokens = tokenizer.convert_ids_to_tokens(inputs_with_attention.input_ids[0])

# Print the tokenized input for reference
print(f"Decoded tokenized input: {decoded_tokens}")
print(f"The subject: {decoded_tokens[1:4]}" )

Decoded tokenized input: ['The', 'ĠE', 'iff', 'el', 'ĠTower', 'Ġis', 'Ġlocated', 'Ġin', 'Ġthe', 'Ġcity', 'Ġof']
The subject: ['ĠE', 'iff', 'el']


In [66]:
# Choose layers to restore hidden states from
layers_to_restore = range(0,48)

# Tokenize the corrupted prompt to get the number of tokens
num_tokens = inputs_with_attention.input_ids.shape[1]  # Get the number of tokens in the input


# Loop over each layer
for layer in layers_to_restore:  # Iterate over the selected layers
    print(f"Restoring hidden states for layer {layer} :")

    # Hook function to restore hidden states for all tokens except the last
    def hook_fn_restoration(module, input, output):
        restored_output = output[0].clone()

        # Restore the clean hidden states for all tokens except the last one
        for token_idx in range(1,4):
            clean_state = hidden_states_clean[layer][0, token_idx, :]  # Get the clean hidden state for each token
            restored_output[0, token_idx, :] = clean_state  # Restore clean state for each token

        return (restored_output, *output[1:])

    # Register the hook to restore clean activations at the specific layer for selected tokens
    hooks_restoration = []
    hooks_restoration.append(model.transformer.h[layer].register_forward_hook(hook_fn_restoration))

    # Run the corrupted model pass with the restoration active
    with torch.no_grad():
        # Generate the output for the restored model while the hook is active
        generated_outputs_restored = model.generate(
            inputs_with_attention.input_ids,
            attention_mask=inputs_with_attention.attention_mask,
            max_length=12,
            num_beams=1,
            no_repeat_ngram_size=2,
            early_stopping=True,
            pad_token_id=tokenizer.eos_token_id
        )

    # Remove the hooks after generating the output
    for hook in hooks_restoration:
        hook.remove()

    # Decode the generated output
    restored_text = tokenizer.decode(generated_outputs_restored[0], skip_special_tokens=True)
    print(f"Restored prediction for layer {layer}: {restored_text.split()[-1]}")

Restoring hidden states for layer 0 :
Restored prediction for layer 0: Seattle
Restoring hidden states for layer 1 :
Restored prediction for layer 1: Seattle
Restoring hidden states for layer 2 :
Restored prediction for layer 2: Seattle
Restoring hidden states for layer 3 :
Restored prediction for layer 3: Seattle
Restoring hidden states for layer 4 :
Restored prediction for layer 4: Seattle
Restoring hidden states for layer 5 :
Restored prediction for layer 5: Seattle
Restoring hidden states for layer 6 :
Restored prediction for layer 6: Seattle
Restoring hidden states for layer 7 :
Restored prediction for layer 7: Seattle
Restoring hidden states for layer 8 :
Restored prediction for layer 8: Seattle
Restoring hidden states for layer 9 :
Restored prediction for layer 9: Seattle
Restoring hidden states for layer 10 :
Restored prediction for layer 10: Seattle
Restoring hidden states for layer 11 :
Restored prediction for layer 11: Seattle
Restoring hidden states for layer 12 :
Restored 