In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
# Load the causal LM and its tokenizer from the same checkpoint id.
# NOTE(review): trust_remote_code=True presumably lets the checkpoint's own
# modeling code run locally -- confirm the source is trusted before re-running.
MODEL_NAME = "allenai/OLMo-1B"
load_kwargs = {"trust_remote_code": True}
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, **load_kwargs)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, **load_kwargs)

# Last expression of the cell: switches to eval mode and displays the module tree.
model.eval()

OLMoForCausalLM(
  (model): OLMo(
    (transformer): ModuleDict(
      (wte): Embedding(50304, 2048)
      (emb_drop): Dropout(p=0.0, inplace=False)
      (ln_f): LayerNorm()
      (blocks): ModuleList(
        (0-15): 16 x OLMoSequentialBlock(
          (dropout): Dropout(p=0.0, inplace=False)
          (act): SwiGLU()
          (attn_out): Linear(in_features=2048, out_features=2048, bias=False)
          (ff_out): Linear(in_features=8192, out_features=2048, bias=False)
          (rotary_emb): RotaryEmbedding()
          (att_proj): Linear(in_features=2048, out_features=6144, bias=False)
          (ff_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (attn_norm): LayerNorm()
          (ff_norm): LayerNorm()
        )
      )
    )
  )
)

In [26]:
def print_mlp_weights(model, layer_idx):
    """
    Display the feed-forward weight tensors of one transformer block.

    Looks up the block's ``ff_proj`` and ``ff_out`` weights in
    ``model.state_dict()`` and prints each tensor followed by its shape.

    Args:
        model: The OLMo model (must expose ``state_dict()`` containing the
            ``model.transformer.blocks.<layer_idx>.ff_*.weight`` keys)
        layer_idx: Index of the layer containing the MLP

    Raises:
        KeyError: If the state dict has no entry for the requested layer.
    """
    # (heading, state-dict key) for each section, in print order.
    sections = (
        ("\nFF Projection:", f'model.transformer.blocks.{layer_idx}.ff_proj.weight'),
        ("\nFF Output:", f'model.transformer.blocks.{layer_idx}.ff_out.weight'),
    )
    weights = model.state_dict()

    print(f"\nMLP weights for layer {layer_idx}:")
    for position, (heading, key) in enumerate(sections):
        if position > 0:
            # Extra blank lines between the two sections.
            print("\n")
        print(heading)
        tensor = weights[key]
        print(tensor)
        print(tensor.shape)


In [15]:

def swap_mlps(model, layer_idx_1, layer_idx_2):
    """
    Swap the MLP weights (``ff_proj`` and ``ff_out``) of two layers in the OLMo model.

    The weights are exchanged in place: values are copied into the existing
    parameters rather than rebinding ``.data``, so each parameter keeps its
    own tensor. Calling the function twice with the same indices restores the
    original weights.

    Args:
        model: The OLMo model; blocks are read from ``model.model.transformer.blocks``
        layer_idx_1: Index of the first layer (negative values count from the end)
        layer_idx_2: Index of the second layer (negative values count from the end)

    Returns:
        dict: ``{'swapped_mlps': {'layer_1': layer_idx_1, 'layer_2': layer_idx_2}}``
        with the indices exactly as passed in.

    Raises:
        ValueError: If either index lies outside ``[-num_layers, num_layers)``.
    """
    # Access the transformer blocks
    decoder_layers = model.model.transformer.blocks
    num_layers = len(decoder_layers)

    # Validate both bounds before announcing or touching anything, so an
    # invalid call never logs a swap and never surfaces as a bare IndexError.
    for layer_idx in (layer_idx_1, layer_idx_2):
        if not -num_layers <= layer_idx < num_layers:
            raise ValueError(f"Layer index out of range. Model has {num_layers} layers.")

    print(f"Swapping MLPs between layers {layer_idx_1} and {layer_idx_2}")

    layer_1 = decoder_layers[layer_idx_1]
    layer_2 = decoder_layers[layer_idx_2]

    # Both indices may resolve to the same block (e.g. 0 and -num_layers);
    # swapping a block with itself is a no-op, so skip the copies.
    if layer_1 is not layer_2:
        # no_grad: in-place writes to parameters must not be tracked by autograd.
        with torch.no_grad():
            for weight_name in ('ff_proj', 'ff_out'):
                weight_1 = getattr(layer_1, weight_name).weight
                weight_2 = getattr(layer_2, weight_name).weight
                # One temporary copy is enough for an in-place exchange.
                saved = weight_1.clone()
                weight_1.copy_(weight_2)
                weight_2.copy_(saved)

    return {
        'swapped_mlps': {
            'layer_1': layer_idx_1,
            'layer_2': layer_idx_2
        }
    }


In [16]:

def swap_multiple_mlps(model, layer_pairs):
    """
    Apply ``swap_mlps`` to each pair of layers, in the order given.

    Args:
        model: The OLMo model
        layer_pairs: Iterable of 2-item pairs, each holding two layer indices to swap

    Returns:
        list: The value returned by ``swap_mlps`` for every pair, in order.
    """
    completed = []
    for pair in layer_pairs:
        layer_1, layer_2 = pair
        completed.append(swap_mlps(model, layer_1, layer_2))
        print(f"Swapped MLPs between layers {layer_1} and {layer_2}")
    return completed

In [17]:
def generate_text(model, tokenizer, prompt, max_new_tokens=100,
                  temperature=0.6, do_sample=True, seed=None):
    """
    Generate text from a prompt.

    Args:
        model: Model exposing ``device`` and ``generate(...)``
        tokenizer: Tokenizer used to encode the prompt and decode the result;
            its ``eos_token_id`` is used both as EOS and as padding id
        prompt: Text to continue
        max_new_tokens: Upper bound on the number of generated tokens
        temperature: Sampling temperature; only forwarded when ``do_sample``
            is true
        do_sample: Whether to sample (True, the default) or let the model
            decode without sampling
        seed: Optional integer. When given, ``torch.manual_seed(seed)`` is
            called first so sampled output is repeatable; note this resets
            the global torch RNG.

    Returns:
        str: ``outputs[0]`` decoded with special tokens skipped.
    """
    if seed is not None:
        # Sampling draws from the global torch RNG; fixing it makes
        # generations comparable across calls.
        torch.manual_seed(seed)

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)

    generation_kwargs = {
        'attention_mask': inputs['attention_mask'],
        'max_new_tokens': max_new_tokens,
        'do_sample': do_sample,
        'eos_token_id': tokenizer.eos_token_id,
        'pad_token_id': tokenizer.eos_token_id,
    }
    if do_sample:
        # Temperature is a sampling parameter; leave it out otherwise.
        generation_kwargs['temperature'] = temperature

    outputs = model.generate(inputs['input_ids'], **generation_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

In [27]:
# Test generation before swapping
prompt = """
Continue this text in a natural and coherent way:

The quantum physicist stared at the readout in disbelief. The entanglement
patterns showed something unprecedented - a stable quantum state that lasted
"""

# generate_text samples (do_sample=True), so fix the torch RNG seed to make
# this baseline continuation reproducible when the cell is re-run.
torch.manual_seed(0)

print()
print("Original generation:")
print("-" * 80)
print()
original_output = generate_text(model, tokenizer, prompt)
print(original_output)


Original generation:
--------------------------------------------------------------------------------


Continue this text in a natural and coherent way:

The quantum physicist stared at the readout in disbelief. The entanglement
patterns showed something unprecedented - a stable quantum state that lasted
for millions of years.

(Source: https://www.wired.com/2016/11/quantum-entangled-molecules-time-travel/)

I think that this is a good example of how a narrative or story can be
used to expand the mind. I think that this is a good example of a way of
explaining quantum mechanics in a way that is easily understood.

I think that this is a good example of how a narrative or story


In [29]:
# Optional sanity check: show the MLP weights of the first two layers
# as they are before any swap is applied.
for layer_idx in (0, 1):
    print_mlp_weights(model, layer_idx)



MLP weights for layer 0:

FF Projection:
tensor([[-1.2955e-03, -2.2828e-03,  5.8700e-03,  ..., -4.1715e-03,
         -2.9691e-03,  3.5815e-03],
        [-7.4904e-03,  5.2179e-05, -3.0653e-03,  ..., -6.2542e-03,
         -6.7004e-03,  3.1544e-03],
        [-1.2344e-02, -9.9137e-03, -1.5885e-03,  ..., -5.1158e-03,
         -5.1641e-03,  1.0260e-02],
        ...,
        [-1.8434e-03,  1.0538e-02,  1.5211e-02,  ..., -1.5467e-02,
         -4.5518e-03, -1.6748e-03],
        [ 9.3506e-04, -2.2135e-03, -8.7500e-03,  ..., -1.9706e-03,
         -8.7817e-03, -4.2292e-03],
        [ 1.3079e-02, -6.1314e-03, -1.2436e-02,  ...,  3.9528e-03,
          7.4631e-03, -1.9468e-03]])
torch.Size([16384, 2048])



FF Output:
tensor([[-5.0323e-04,  5.9398e-03, -2.1946e-03,  ...,  1.4586e-02,
         -1.7366e-03,  4.7926e-03],
        [-1.1657e-02, -3.1330e-03, -8.0331e-04,  ...,  7.7663e-04,
          1.0677e-02, -4.8340e-03],
        [ 1.5075e-03,  4.4474e-04, -1.4024e-03,  ..., -4.0307e-03,
          7.4

In [36]:
# Perform MLP swaps on adjacent layer pairs: (0, 1), (2, 3), ..., (8, 9).
# NOTE: this mutates `model` in place; running the cell a second time swaps
# the same pairs back and restores the original weights.
layer_pairs = [(first, first + 1) for first in range(0, 10, 2)]  # Example layer pairs
swap_results = swap_multiple_mlps(model, layer_pairs)

Swapping MLPs between layers 0 and 1
Swapped MLPs between layers 0 and 1
Swapping MLPs between layers 2 and 3
Swapped MLPs between layers 2 and 3
Swapping MLPs between layers 4 and 5
Swapped MLPs between layers 4 and 5
Swapping MLPs between layers 6 and 7
Swapped MLPs between layers 6 and 7
Swapping MLPs between layers 8 and 9
Swapped MLPs between layers 8 and 9


In [34]:
# Optional sanity check: show the MLP weights of the first two layers
# after swapping; layer 0 should now display what layer 1 held before.
# NOTE(review): the output saved with this cell is identical to the pre-swap
# output and its execution count is lower than the swap cell's -- it looks
# stale; re-run this cell after the swap to confirm.
for layer_idx in (0, 1):
    print_mlp_weights(model, layer_idx)


MLP weights for layer 0:

FF Projection:
tensor([[-1.2955e-03, -2.2828e-03,  5.8700e-03,  ..., -4.1715e-03,
         -2.9691e-03,  3.5815e-03],
        [-7.4904e-03,  5.2179e-05, -3.0653e-03,  ..., -6.2542e-03,
         -6.7004e-03,  3.1544e-03],
        [-1.2344e-02, -9.9137e-03, -1.5885e-03,  ..., -5.1158e-03,
         -5.1641e-03,  1.0260e-02],
        ...,
        [-1.8434e-03,  1.0538e-02,  1.5211e-02,  ..., -1.5467e-02,
         -4.5518e-03, -1.6748e-03],
        [ 9.3506e-04, -2.2135e-03, -8.7500e-03,  ..., -1.9706e-03,
         -8.7817e-03, -4.2292e-03],
        [ 1.3079e-02, -6.1314e-03, -1.2436e-02,  ...,  3.9528e-03,
          7.4631e-03, -1.9468e-03]])
torch.Size([16384, 2048])



FF Output:
tensor([[-5.0323e-04,  5.9398e-03, -2.1946e-03,  ...,  1.4586e-02,
         -1.7366e-03,  4.7926e-03],
        [-1.1657e-02, -3.1330e-03, -8.0331e-04,  ...,  7.7663e-04,
          1.0677e-02, -4.8340e-03],
        [ 1.5075e-03,  4.4474e-04, -1.4024e-03,  ..., -4.0307e-03,
          7.4

In [37]:
# Test generation after swapping
# generate_text samples (do_sample=True), so fix the torch RNG seed to make
# this continuation reproducible rather than dependent on prior RNG state.
torch.manual_seed(0)

print("\nGeneration after swapping MLPs:")
print("-" * 80)
print()
modified_output = generate_text(model, tokenizer, prompt)
print(modified_output)


Generation after swapping MLPs:
--------------------------------------------------------------------------------


Continue this text in a natural and coherent way:

The quantum physicist stared at the readout in disbelief. The entanglement
patterns showed something unprecedented - a stable quantum state that lasted
ended).
and continued.
alalp).3 hours.8 hours.8.1.1.5.8).3 hours.8.5.8.8.4.8.8.8.8.8.8dand.T.im.T.T.T.T.T.T.T.T.T.T.T.T.T.T.T.T.T.T.m.T.T.


### Zero out MLP weights

In [None]:
def zero_mlp(model, layer_idx):
    """
    Zero out the MLP weights (``ff_proj`` and ``ff_out``) in one layer of the OLMo model.

    The weights are overwritten in place; the previous values are not kept.

    Args:
        model: The OLMo model; blocks are read from ``model.model.transformer.blocks``
        layer_idx: Index of the layer to zero out (negative values count from the end)

    Returns:
        dict: ``{'zeroed_mlp': {'layer': layer_idx}}`` with the index as passed in.

    Raises:
        ValueError: If ``layer_idx`` lies outside ``[-num_layers, num_layers)``.
    """
    # Access the transformer blocks
    decoder_layers = model.model.transformer.blocks
    num_layers = len(decoder_layers)

    # Validate both bounds before announcing or touching anything, so an
    # invalid call never logs a zeroing and never surfaces as a bare IndexError.
    if not -num_layers <= layer_idx < num_layers:
        raise ValueError(f"Layer index out of range. Model has {num_layers} layers.")

    print(f"Zeroing out MLP in layer {layer_idx}")

    layer = decoder_layers[layer_idx]

    # no_grad: in-place writes to parameters must not be tracked by autograd.
    with torch.no_grad():
        layer.ff_proj.weight.zero_()
        layer.ff_out.weight.zero_()

    return {
        'zeroed_mlp': {
            'layer': layer_idx
        }
    }


In [None]:
def zero_multiple_mlps(model, layer_indices):
    """
    Apply ``zero_mlp`` to each listed layer, in the order given.

    Args:
        model: The OLMo model
        layer_indices: Iterable of layer indices to zero out

    Returns:
        list: The value returned by ``zero_mlp`` for every index, in order.
    """
    results = []
    for layer_idx in layer_indices:
        results.append(zero_mlp(model, layer_idx))
        print(f"Zeroed MLP in layer {layer_idx}")
    return results