In [1]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, OlmoeConfig

In [2]:
model = AutoModelForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")

# Set the model to eval mode
model.eval()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

OlmoeForCausalLM(
  (model): OlmoeModel(
    (embed_tokens): Embedding(50304, 2048, padding_idx=1)
    (layers): ModuleList(
      (0-15): 16 x OlmoeDecoderLayer(
        (self_attn): OlmoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (q_norm): OlmoeRMSNorm((2048,), eps=1e-05)
          (k_norm): OlmoeRMSNorm((2048,), eps=1e-05)
        )
        (mlp): OlmoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=64, bias=False)
          (experts): ModuleList(
            (0-63): 64 x OlmoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=1024, bias=False)
              (up_proj): Linear(in_features=2048, out_features=1024, bias=False)
              (down_proj): Linear(in

In [10]:
def get_attention_tensor(input_text, model, tokenizer, layer_num):
    # Tokenize the input
    inputs = tokenizer(input_text, return_tensors="pt")

    # Print the tokenized input
    print("Tokenized input:")
    for token_id in inputs.input_ids[0]:
        token = tokenizer.decode([token_id])
        print(f"Token: '{token}', ID: {token_id.item()}")

    print(f"\nInput shape: {inputs.input_ids.shape}")

    # Forward pass with output_attentions=True
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    # Get the attention tensor from the specified layer
    attention_tensor = outputs.attentions[layer_num]
    
    print(f"\nAttention tensor shape: {attention_tensor.shape}")
    
    # format of attnetion tensor?
    
    # Print which layer the attention tensor is from
    print(f"Attention tensor is from layer {layer_num} of the model")

    # Initialize a list to store attention weights for each token
    all_attention_weights = []

    # Print attention weights for each token
    for token_idx, token_id in enumerate(inputs.input_ids[0]):
        token = tokenizer.decode([token_id])
        attention_weights = attention_tensor[:, :, token_idx, :].squeeze(0)
        
        # Append the attention weights to the list
        all_attention_weights.append(attention_weights.tolist())

    return attention_tensor

In [12]:
# for testing this function
input_text = "The capital city of France is"
layer_num = 0

x =  get_attention_tensor(input_text, model, tokenizer, layer_num)
print(x)

Tokenized input:
Token: 'The', ID: 510
Token: ' capital', ID: 5347
Token: ' city', ID: 2846
Token: ' of', ID: 273
Token: ' France', ID: 6181
Token: ' is', ID: 310

Input shape: torch.Size([1, 6])

Attention tensor shape: torch.Size([1, 16, 6, 6])
Attention tensor is from layer 0 of the model
tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00],
          [5.8542e-01, 4.1458e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00],
          [4.2129e-01, 2.9937e-01, 2.7934e-01, 0.0000e+00, 0.0000e+00,
           0.0000e+00],
          [2.7586e-01, 2.1433e-01, 2.0015e-01, 3.0966e-01, 0.0000e+00,
           0.0000e+00],
          [2.4056e-01, 1.6639e-01, 1.5372e-01, 2.7955e-01, 1.5978e-01,
           0.0000e+00],
          [1.8133e-01, 1.4626e-01, 1.3953e-01, 1.9880e-01, 1.4343e-01,
           1.9064e-01]],

         [[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
           0.0000e+00],
          [8.1084e-01, 1.8916e-01, 0.000

In [13]:
# for testing this function
input_text = "The capital city of India United States is"
layer_num = 0

x =  get_attention_tensor(input_text, model, tokenizer, layer_num)
print(x)

Tokenized input:
Token: 'The', ID: 510
Token: ' capital', ID: 5347
Token: ' city', ID: 2846
Token: ' of', ID: 273
Token: ' India', ID: 5427
Token: ' United', ID: 1986
Token: ' States', ID: 2077
Token: ' is', ID: 310

Input shape: torch.Size([1, 8])

Attention tensor shape: torch.Size([1, 16, 8, 8])
Attention tensor is from layer 0 of the model
tensor([[[[1.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [5.8542e-01, 4.1458e-01, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [4.2129e-01, 2.9937e-01, 2.7934e-01,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          ...,
          [2.1696e-01, 1.3829e-01, 1.2574e-01,  ..., 1.2326e-01,
           0.0000e+00, 0.0000e+00],
          [1.7966e-01, 1.2470e-01, 1.1679e-01,  ..., 1.1387e-01,
           1.3353e-01, 0.0000e+00],
          [1.4032e-01, 1.1320e-01, 1.0813e-01,  ..., 1.0715e-01,
           1.1795e-01, 1.4754e-01]],

         [[1.0000e+00, 0.0000e+

In [8]:
def patch_attention(model, tokenizer, input_1: str, input_2: str, max_new_tokens: int = 20):
    """
    Run attention patching experiment with pre-initialized model.
    """
    device = next(model.parameters()).device
    
    print(f"\nAttention Implementation: {model.config._attn_implementation}")
    
    tokens_1 = tokenizer(input_1, return_tensors="pt").to(device)
    tokens_2 = tokenizer(input_2, return_tensors="pt").to(device)
    
    print(f"\nInput 1 sequence length: {tokens_1['input_ids'].shape[1]}")
    print(f"Input 2 sequence length: {tokens_2['input_ids'].shape[1]}")

    stored_outputs = []
    
    def store_attention(module, input, output):
        stored_outputs.append(output)  # Store complete output tuple
        if len(stored_outputs) == 1:
            if isinstance(output, tuple):
                print(f"\nOutput is a tuple with {len(output)} elements")
                for i, o in enumerate(output):
                    if isinstance(o, torch.Tensor):
                        print(f"Output[{i}] shape: {o.shape}")
            else:
                print(f"\nOutput shape: {output.shape}")
        return output
    
    print("\nRegistering storage hooks...")
    hooks = []
    for i, layer in enumerate(model.model.layers):
        hooks.append(layer.self_attn.register_forward_hook(store_attention))

    print(f"\nGenerating with input_1: '{input_1}'")
    with torch.no_grad():
        output_1 = model.generate(
            **tokens_1,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id
        )
    
    print(f"\nStored attention outputs from {len(stored_outputs)} layers")
    
    for hook in hooks:
        hook.remove()
    
    print(f"\nGenerating with input_2: '{input_2}'")
    with torch.no_grad():
        output_2 = model.generate(
            **tokens_2,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id
        )
    
    print("\nRegistering patching hooks...")
    hooks = []
    
    def patch_attention(layer_idx):
        def hook(module, input, output):
            # if layer_idx == 0:
            #     print(f"\nLayer {layer_idx} patching:")
            #     if isinstance(output, tuple):
            #         print(f"Output is a tuple with {len(output)} elements")
            #         for i, o in enumerate(output):
            #             if isinstance(o, torch.Tensor):
            #                 print(f"Output[{i}] shape: {o.shape}")
            return stored_outputs[layer_idx]  # Return complete stored output tuple
        return hook
    
    for i, layer in enumerate(model.model.layers):
        hooks.append(layer.self_attn.register_forward_hook(patch_attention(i)))

    print(f"\nGenerating with patched attention...")
    with torch.no_grad():
        output_patched = model.generate(
            **tokens_2,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.pad_token_id
        )

    for hook in hooks:
        hook.remove()

    output_1_text = tokenizer.decode(output_1[0], skip_special_tokens=True)
    output_2_text = tokenizer.decode(output_2[0], skip_special_tokens=True)
    output_patched_text = tokenizer.decode(output_patched[0], skip_special_tokens=True)

    print("\nFinal Outputs:")
    print(f"Input 1 ({input_1}) Output: {output_1_text}")
    print(f"Input 2 ({input_2}) Output: {output_2_text}")
    print(f"Patched Output: {output_patched_text}")

    return output_1_text, output_2_text, output_patched_text

In [9]:
output_1, output_2, output_patched = patch_attention(
    model,
    tokenizer, 
    "The capital city of France is",
    "The capital city of India is"
)


Attention Implementation: sdpa

Input 1 sequence length: 6
Input 2 sequence length: 6

Registering storage hooks...

Generating with input_1: 'The capital city of France is'

Output is a tuple with 3 elements
Output[0] shape: torch.Size([1, 6, 2048])

Stored attention outputs from 320 layers

Generating with input_2: 'The capital city of India is'

Registering patching hooks...

Generating with patched attention...

Final Outputs:
Input 1 (The capital city of France is) Output: The capital city of France is Paris.

The capital city of Germany is Berlin.

The capital city of Italy is
Input 2 (The capital city of India is) Output: The capital city of India is New Delhi.

The capital city of India is New Delhi.

The capital city of
Patched Output: The capital city of India is Paris, Paris, Paris, Paris, Paris, Paris, Paris, Paris, Paris, Paris,
