In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
model = AutoModelForCausalLM.from_pretrained("allenai/OLMo-1B", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-1B", trust_remote_code=True)

model.eval()

OLMoForCausalLM(
  (model): OLMo(
    (transformer): ModuleDict(
      (wte): Embedding(50304, 2048)
      (emb_drop): Dropout(p=0.0, inplace=False)
      (ln_f): LayerNorm()
      (blocks): ModuleList(
        (0-15): 16 x OLMoSequentialBlock(
          (dropout): Dropout(p=0.0, inplace=False)
          (act): SwiGLU()
          (attn_out): Linear(in_features=2048, out_features=2048, bias=False)
          (ff_out): Linear(in_features=8192, out_features=2048, bias=False)
          (rotary_emb): RotaryEmbedding()
          (att_proj): Linear(in_features=2048, out_features=6144, bias=False)
          (ff_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (attn_norm): LayerNorm()
          (ff_norm): LayerNorm()
        )
      )
    )
  )
)

In [3]:
def print_mlp_weights(model, layer_idx):
    """
    Print the weights of an MLP at a given layer.
    
    Args:
        model: The OLMo model
        layer_idx: Index of the layer containing the MLP
    """
    ff_proj = f'model.transformer.blocks.{layer_idx}.ff_proj.weight'
    ff_out = f'model.transformer.blocks.{layer_idx}.ff_out.weight'
    
    state_dict = model.state_dict()
    
    print(f"\nMLP weights for layer {layer_idx}:")
    print("\nFF Projection:")
    print(state_dict[ff_proj])
    print(state_dict[ff_proj].shape)
    
    print("\n")
    print("\nFF Output:") 
    print(state_dict[ff_out])
    print(state_dict[ff_out].shape)


In [4]:
def zero_mlp(model, layer_idx):
    """
    Zero out MLP weights in a specific layer of the OLMo model.
    
    Args:
        model: The OLMo model
        layer_idx: Index of the layer to zero out
    """
    # Access the transformer blocks
    decoder_layers = model.model.transformer.blocks
    print(f"Zeroing out MLP in layer {layer_idx}")
    
    # Verify index is valid
    num_layers = len(decoder_layers)
    if layer_idx >= num_layers:
        raise ValueError(f"Layer index out of range. Model has {num_layers} layers.")
    
    # Get the layer
    layer = decoder_layers[layer_idx]
    
    # Zero out ff projection weights
    layer.ff_proj.weight.data.zero_()
    
    # Zero out ff output weights
    layer.ff_out.weight.data.zero_()
    
    return {
        'zeroed_mlp': {
            'layer': layer_idx
        }
    }


In [5]:
def zero_multiple_mlps(model, layer_indices):
    """
    Zero out MLPs in multiple layers of the OLMo model.
    
    Args:
        model: The OLMo model
        layer_indices: List of layer indices to zero out
    """
    zeroed = []
    for layer_idx in layer_indices:
        zero_info = zero_mlp(model, layer_idx)
        zeroed.append(zero_info)
        print(f"Zeroed MLP in layer {layer_idx}")
    return zeroed

In [6]:
def generate_text(model, tokenizer, prompt, max_new_tokens=100):
    """Generate text from a prompt."""
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
    outputs = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_new_tokens=max_new_tokens,
        temperature=0.6,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.eos_token_id
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [7]:
# Test generation before swapping
prompt = """
Continue this text in a natural and coherent way:

The quantum physicist stared at the readout in disbelief. The entanglement
patterns showed something unprecedented - a stable quantum state that lasted
"""

print()
print("Original generation:")
print("-" * 80)
print()
original_output = generate_text(model, tokenizer, prompt)
print(original_output)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.



Original generation:
--------------------------------------------------------------------------------


Continue this text in a natural and coherent way:

The quantum physicist stared at the readout in disbelief. The entanglement
patterns showed something unprecedented - a stable quantum state that lasted
for a long time.

The physicist thought about the possibilities, but he was too tired to
think about them further.

He was a senior scientist at a university. He had been working on this
project for several months. The first step was to create a quantum
entanglement. The second step was to observe it. The third step was to
measure the entanglement. The fourth step was to calculate the entanglement
of the quantum state.

The fifth step was


In [8]:
# Print some weights before swapping (optional)
print_mlp_weights(model, 0)
print_mlp_weights(model, 1)



MLP weights for layer 0:

FF Projection:
tensor([[-1.2955e-03, -2.2828e-03,  5.8700e-03,  ..., -4.1715e-03,
         -2.9691e-03,  3.5815e-03],
        [-7.4904e-03,  5.2179e-05, -3.0653e-03,  ..., -6.2542e-03,
         -6.7004e-03,  3.1544e-03],
        [-1.2344e-02, -9.9137e-03, -1.5885e-03,  ..., -5.1158e-03,
         -5.1641e-03,  1.0260e-02],
        ...,
        [-1.8434e-03,  1.0538e-02,  1.5211e-02,  ..., -1.5467e-02,
         -4.5518e-03, -1.6748e-03],
        [ 9.3506e-04, -2.2135e-03, -8.7500e-03,  ..., -1.9706e-03,
         -8.7817e-03, -4.2292e-03],
        [ 1.3079e-02, -6.1314e-03, -1.2436e-02,  ...,  3.9528e-03,
          7.4631e-03, -1.9468e-03]])
torch.Size([16384, 2048])



FF Output:
tensor([[-5.0323e-04,  5.9398e-03, -2.1946e-03,  ...,  1.4586e-02,
         -1.7366e-03,  4.7926e-03],
        [-1.1657e-02, -3.1330e-03, -8.0331e-04,  ...,  7.7663e-04,
          1.0677e-02, -4.8340e-03],
        [ 1.5075e-03,  4.4474e-04, -1.4024e-03,  ..., -4.0307e-03,
          7.4

In [9]:
layer_indices = [0] 
zero_multiple_mlps(model, layer_indices)

Zeroing out MLP in layer 0
Zeroed MLP in layer 0


[{'zeroed_mlp': {'layer': 0}}]

In [10]:
# Print some weights after swapping (optional)
print_mlp_weights(model, 0)
print_mlp_weights(model, 1)


MLP weights for layer 0:

FF Projection:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
torch.Size([16384, 2048])



FF Output:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
torch.Size([2048, 8192])

MLP weights for layer 1:

FF Projection:
tensor([[ 7.7826e-03,  5.9513e-04,  3.0032e-03,  ..., -2.7562e-05,
         -3.3554e-04, -1.5863e-02],
        [ 1.5511e-02,  1.6985e-02, -3.0576e-03,  ...,  9.6881e-03,
          5.3660e-03,  9.6499e-04],
        [-2.2951e-04,  4.6680e-03, -4.6625e-03,  ...,  2.2134e-03,
          5.5331e-03, -8.3459e-03],
        ...,
        [ 7.8161e-03,

In [12]:
# Test generation after swapping
print(f"\nGeneration after zeroing out MLPs in layer {layer_indices}:")
print("-" * 80)
print()
modified_output = generate_text(model, tokenizer, prompt)
print(modified_output)


Generation after zeroing out MLPs in layer [0]:
--------------------------------------------------------------------------------


Continue this text in a natural and coherent way:

The quantum physicist stared at the readout in disbelief. The entanglement
patterns showed something unprecedented - a stable quantum state that lasted
weeks, and was not be be used with the same same results of the
Quantum state that lasted for so many years.
The entanglement was so stable and so that this was not.
The entanglement was so stable and so that this was not.
The entanglement was so stable.
The entanglement was so stable that it did not be used with the
The entanglement was so stable.
The entanglement was so stable.
The entanglement was so that it was so stable.
The entanglement
