In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
model = AutoModelForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924")
tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
def print_expert_weights(model, layer_idx, expert_idx):
    """
    Print the weights of a specific expert MLP at a given layer.
    
    Args:
        model: The OLMoE model
        layer_idx: Index of the layer containing the expert
        expert_idx: Index of the expert within the layer
    """
    gate_proj = f'model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight'
    up_proj = f'model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight'
    down_proj = f'model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight'
    
    print("\nGate Projection:")
    print(model.state_dict()[gate_proj])
    print(model.state_dict()[gate_proj].shape)
    print("\nUp Projection:") 
    print(model.state_dict()[up_proj])
    print(model.state_dict()[up_proj].shape)
    print("\nDown Projection:")
    print(model.state_dict()[down_proj])
    print(model.state_dict()[down_proj].shape)

In [4]:
def zero_expert(model, layer_idx, expert_idx):
    """
    Zero out a specific expert in a specific layer of the OLMoE model.
    
    Args:
        model: The OLMoE model
        layer_idx: Index of the layer containing the expert
        expert_idx: Index of the expert to zero out
    """
    # Access the decoder layers
    decoder_layers = model.model.layers
    
    # Verify indices are valid
    num_layers = len(decoder_layers)
    if layer_idx >= num_layers:
        raise ValueError(f"Layer index out of range. Model has {num_layers} layers.")
    
    # Get the MoE block
    moe = decoder_layers[layer_idx].mlp
    
    # Verify expert index is valid
    num_experts = len(moe.experts)
    if expert_idx >= num_experts:
        raise ValueError(f"Expert index out of range. Layer has {num_experts} experts.")
        
    # Get the expert
    expert = moe.experts[expert_idx]
    
    # Zero out all weights in the expert
    expert.gate_proj.weight.data.zero_()
    expert.up_proj.weight.data.zero_()
    expert.down_proj.weight.data.zero_()
    
    return {
        'zeroed_expert': {
            'layer': layer_idx,
            'expert': expert_idx
        }
    }


In [5]:
def zero_multiple_experts(model, expert_indices_per_layer):
    """
    Zero out multiple experts across different layers.
    
    Args:
        model: The OLMoE model
        expert_indices_per_layer: Dict mapping layer indices to lists of expert indices to zero
                                e.g. {0: [0,1], 1: [2,3]} zeros experts 0,1 in layer 0 and 2,3 in layer 1
    """
    results = []
    
    for layer_idx, expert_indices in expert_indices_per_layer.items():
        for expert_idx in expert_indices:
            result = zero_expert(model, layer_idx, expert_idx)
            results.append(result)
            print(f"Zeroed out expert {expert_idx} in layer {layer_idx}")
            
    return results


### zeroing out experts

In [6]:
prompt = ("""    
Continue this text in a natural and coherent way, maintaining consistency with the style, 
terminology, and logical flow of the preceding text.
\\title{Quantum Error Mitigation in NISQ Devices}
\\begin{abstract}
We present a novel approach to error mitigation in noisy intermediate-scale quantum (NISQ) devices. 
Our method introduces a scaling framework for quantum channels that preserves gate fidelity while reducing environmental noise.
\end{abstract}
\section{Introduction}
Recent advances in NISQ devices have demonstrated both promise and limitations in quantum computation. 
The primary challenge remains decoherence, which introduces errors in quantum operations. We propose a channel scaling approach 
$\mathcal{N}(\\rho) = e^{-\lambda t}\\rho$ 
that provides a systematic way to
""")

In [7]:
# Convert the prompt to inputs and run a forward pass
inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
# Generate output (since it's a causal LM, we need to generate text from input)
outputs = model.generate(
    inputs['input_ids'],  # Only provide input_ids to generate
    attention_mask=inputs['attention_mask'],  # Add attention mask to not attend to padding tokens
    max_new_tokens=156,    # Generate 100 new tokens
    temperature=0.6,       # Control randomness
    do_sample=True,        # Use sampling instead of greedy decoding
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id  # Set padding token
)

# Decode the generated output
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Print the original prompt and generated response
print("Prompt:", prompt)
print("\nGenerated response (pre zero out):", generated_text)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Prompt:     
Continue this text in a natural and coherent way, maintaining consistency with the style, 
terminology, and logical flow of the preceding text.
\title{Quantum Error Mitigation in NISQ Devices}
\begin{abstract}
We present a novel approach to error mitigation in noisy intermediate-scale quantum (NISQ) devices. 
Our method introduces a scaling framework for quantum channels that preserves gate fidelity while reducing environmental noise.
\end{abstract}
\section{Introduction}
Recent advances in NISQ devices have demonstrated both promise and limitations in quantum computation. 
The primary challenge remains decoherence, which introduces errors in quantum operations. We propose a channel scaling approach 
$\mathcal{N}(\rho) = e^{-\lambda t}\rho$ 
that provides a systematic way to


Generated response (pre zero out):     
Continue this text in a natural and coherent way, maintaining consistency with the style, 
terminology, and logical flow of the preceding text.
\title{Quantum 

In [8]:
# pre zero out

print_expert_weights(model, layer_idx=0, expert_idx=0)
print_expert_weights(model, layer_idx=1, expert_idx=47)

print("\n" + "="*80 + "\n") 



Gate Projection:
tensor([[-0.0118,  0.0150,  0.0190,  ..., -0.0022, -0.0043,  0.0098],
        [-0.0156,  0.0110,  0.0079,  ..., -0.0136, -0.0059,  0.0064],
        [-0.0004, -0.0051, -0.0077,  ...,  0.0092,  0.0153, -0.0466],
        ...,
        [-0.0208, -0.0040,  0.0033,  ...,  0.0048, -0.0037,  0.0115],
        [-0.0066,  0.0152,  0.0057,  ..., -0.0050, -0.0275, -0.0078],
        [ 0.0004, -0.0062,  0.0060,  ...,  0.0011, -0.0302,  0.0057]])
torch.Size([1024, 2048])

Up Projection:
tensor([[ 0.0038, -0.0003,  0.0136,  ..., -0.0016,  0.0109, -0.0172],
        [ 0.0078, -0.0044,  0.0027,  ...,  0.0194,  0.0127,  0.0271],
        [ 0.0020,  0.0182,  0.0090,  ...,  0.0281, -0.0231, -0.0129],
        ...,
        [ 0.0015,  0.0172,  0.0036,  ...,  0.0123, -0.0040, -0.0101],
        [-0.0208,  0.0119,  0.0157,  ..., -0.0388,  0.0206, -0.0027],
        [-0.0113,  0.0016, -0.0238,  ..., -0.0243, -0.0132,  0.0069]])
torch.Size([1024, 2048])

Down Projection:
tensor([[-0.0053,  0.0125, -0.

#### top 8 and bottom 8 experts for each layer for `inputs.txt`
Layer `0`
First 8 keys: `[0, 6, 36, 41, 52, 10, 49, 21]`
Last 8 keys: `[43, 20, 63, 7, 58, 34, 39, 9]`

Layer `1`
First 8 keys: `[47, 18, 61, 25, 5, 16, 27, 7]`
Last 8 keys: `[3, 6, 33, 52, 56, 10, 43, 57]`

Layer `2`
First 8 keys: `[60, 10, 61, 45, 40, 30, 15, 26]`
Last 8 keys: `[53, 4, 55, 41, 2, 8, 51, 59]`

Layer `3`
First 8 keys: `[35, 9, 6, 62, 15, 51, 19, 43]`
Last 8 keys: `[1, 44, 24, 25, 60, 26, 16, 17]`

Layer `4`
First 8 keys: `[17, 21, 6, 27, 2, 14, 25, 55]`
Last 8 keys: `[36, 13, 54, 12, 57, 18, 15, 32]`

Layer `5`
First 8 keys: `[0, 60, 31, 57, 37, 17, 21, 2]`
Last 8 keys: `[32, 47, 51, 24, 49, 63, 15, 26]`

Layer `6`
First 8 keys: `[57, 62, 18, 36, 52, 40, 4, 31]`
Last 8 keys: `[25, 21, 45, 34, 14, 39, 11, 56]`

Layer `7`
First 8 keys: `[17, 58, 35, 2, 4, 59, 21, 61]`
Last 8 keys: `[43, 47, 41, 27, 53, 6, 51, 8]`

Layer `8`
First 8 keys: `[16, 54, 14, 18, 32, 3, 37, 6]`
Last 8 keys: `[9, 25, 33, 4, 47, 62, 19, 11]`

Layer `9`
First 8 keys: `[5, 8, 6, 4, 28, 14, 7, 20]`
Last 8 keys: `[61, 10, 0, 60, 35, 58, 56, 37]`

Layer `10`
First 8 keys: `[56, 43, 11, 59, 22, 60, 13, 28]`
Last 8 keys: `[53, 63, 57, 33, 1, 31, 10, 54]`

Layer `11`
First 8 keys: `[47, 23, 27, 51, 54, 33, 52, 25]`
Last 8 keys: `[55, 29, 49, 14, 53, 19, 17, 59]`

Layer `12`
First 8 keys: `[43, 55, 59, 38, 31, 58, 47, 44]`
Last 8 keys: `[18, 19, 27, 61, 45, 3, 37, 0]`

Layer `13`
First 8 keys: `[2, 32, 5, 20, 25, 22, 55, 61]`
Last 8 keys: `[56, 17, 40, 1, 48, 52, 21, 36]`

Layer `14`
First 8 keys: `[9, 58, 6, 4, 24, 52, 11, 17]`
Last 8 keys: `[41, 3, 26, 13, 43, 25, 27, 55]`

Layer `15`
First 8 keys: `[17, 1, 34, 44, 50, 45, 30, 54]`
Last 8 keys: `[18, 15, 37, 26, 7, 38, 21, 51]`



In [13]:
# Zero out a single expert
# zero_expert(model, layer_idx=0, expert_idx=0)

# # Zero out multiple experts across different layers
expert_config = {
    1: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
    2: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
    3: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
    4: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
    5: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
    6: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
    7: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
    8: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
    9: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
    10: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
    11: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
    12: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
    13: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
    14: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63],
    15: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]
    # 1: [47, 18, 61, 25],
    # 2: [60, 10, 61, 45],
    # 3: [35, 9, 6, 62],
    # 4: [17, 21, 6, 27],
    # 5: [0, 61, 31, 57],
    # 6: [57, 62, 18, 36],
    # 7: [17, 58, 35, 2],
    # 8: [16, 54, 14, 18],
    # 9: [5, 8, 6, 4],
    # 10: [56, 43, 11, 59],
    # 11: [47, 23, 27, 51],
    # 12: [43, 55, 59, 38],
    # 13: [2, 32, 5, 20],
    # 14: [9, 58, 6, 4],
    # 15: [17, 1, 34, 44]
}
results = zero_multiple_experts(model, expert_config)

Zeroed out expert 0 in layer 1
Zeroed out expert 1 in layer 1
Zeroed out expert 2 in layer 1
Zeroed out expert 3 in layer 1
Zeroed out expert 4 in layer 1
Zeroed out expert 5 in layer 1
Zeroed out expert 6 in layer 1
Zeroed out expert 7 in layer 1
Zeroed out expert 8 in layer 1
Zeroed out expert 9 in layer 1
Zeroed out expert 10 in layer 1
Zeroed out expert 11 in layer 1
Zeroed out expert 12 in layer 1
Zeroed out expert 13 in layer 1
Zeroed out expert 14 in layer 1
Zeroed out expert 15 in layer 1
Zeroed out expert 16 in layer 1
Zeroed out expert 17 in layer 1
Zeroed out expert 18 in layer 1
Zeroed out expert 19 in layer 1
Zeroed out expert 20 in layer 1
Zeroed out expert 21 in layer 1
Zeroed out expert 22 in layer 1
Zeroed out expert 23 in layer 1
Zeroed out expert 24 in layer 1
Zeroed out expert 25 in layer 1
Zeroed out expert 26 in layer 1
Zeroed out expert 27 in layer 1
Zeroed out expert 28 in layer 1
Zeroed out expert 29 in layer 1
Zeroed out expert 30 in layer 1
Zeroed out expert 

In [14]:
# post zero out

print_expert_weights(model, layer_idx=15, expert_idx=0)
print_expert_weights(model, layer_idx=1, expert_idx=47)

print("\n" + "="*80 + "\n")  


Gate Projection:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
torch.Size([1024, 2048])

Up Projection:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
torch.Size([1024, 2048])

Down Projection:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])
torch.Size([2048, 1024])

Gate Projection:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 

In [15]:
# Convert the prompt to inputs and run a forward pass
inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
# Generate output (since it's a causal LM, we need to generate text from input)
outputs = model.generate(
    inputs['input_ids'],  # Only provide input_ids to generate
    attention_mask=inputs['attention_mask'],  # Add attention mask to not attend to padding tokens
    max_new_tokens=156,    # Generate 100 new tokens
    temperature=0.6,       # Control randomness
    do_sample=True,        # Use sampling instead of greedy decoding
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id  # Set padding token
)

# Decode the generated output
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Print the original prompt and generated response
print("Prompt:", prompt)
print("\nGenerated response (post zero out):", generated_text)

Prompt:     
Continue this text in a natural and coherent way, maintaining consistency with the style, 
terminology, and logical flow of the preceding text.
\title{Quantum Error Mitigation in NISQ Devices}
\begin{abstract}
We present a novel approach to error mitigation in noisy intermediate-scale quantum (NISQ) devices. 
Our method introduces a scaling framework for quantum channels that preserves gate fidelity while reducing environmental noise.
\end{abstract}
\section{Introduction}
Recent advances in NISQ devices have demonstrated both promise and limitations in quantum computation. 
The primary challenge remains decoherence, which introduces errors in quantum operations. We propose a channel scaling approach 
$\mathcal{N}(\rho) = e^{-\lambda t}\rho$ 
that provides a systematic way to


Generated response (post zero out):     
Continue this text in a natural and coherent way, maintaining consistency with the style, 
terminology, and logical flow of the preceding text.
\title{Quantum