In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import json
import os
import pandas as pd
from collections import defaultdict
import plotly.graph_objects as go

In [3]:
def load_model(model_name="allenai/OLMoE-1B-7B-0924"):
    """Load a pretrained causal language model together with its tokenizer.

    The model is moved to the GPU when CUDA is available, otherwise to the CPU.

    Args:
        model_name: Checkpoint identifier handed to ``from_pretrained`` for
            both the model and the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    language_model = AutoModelForCausalLM.from_pretrained(model_name)
    language_model = language_model.to(target_device)
    text_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return language_model, text_tokenizer

# Load the default checkpoint once; the later cells reuse the module-level
# `model` and `tokenizer` created here.
model, tokenizer = load_model()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
def swap_experts(model, expert_idx, target_layer_idx, source_layer_idx=0, source_expert_idx=0):
    """
    Swap the weights of two experts in the OLMoE model, in place.

    The ``gate_proj``, ``up_proj`` and ``down_proj`` weight parameters of the
    source expert and the target expert are exchanged. Only these three
    parameters are touched; nothing else in either MoE block is modified.
    Because the model is mutated, calling the function a second time with the
    same arguments swaps the two experts back.

    Args:
        model: The OLMoE model
        expert_idx: Index of the expert in target layer to swap with
        target_layer_idx: Index of the layer containing the expert to swap with
        source_layer_idx: Index of the source layer (default 0)
        source_expert_idx: Index of the source expert (default 0)

    Returns:
        A dict of the form
        ``{'swapped_experts': {'source': {'layer', 'expert'},
        'target': {'layer', 'expert'}}}`` echoing the indices that were passed in.

    Raises:
        ValueError: If a layer index is outside the model's layers, or an
            expert index is outside the experts of its own layer. Negative
            indices are accepted with the usual Python meaning.
    """
    # Access the decoder layers
    decoder_layers = model.model.layers

    # Verify layer indices are valid (negative indices count from the end)
    num_layers = len(decoder_layers)
    for layer_idx in (target_layer_idx, source_layer_idx):
        if not -num_layers <= layer_idx < num_layers:
            raise ValueError(f"Layer index out of range. Model has {num_layers} layers.")

    # Get the MoE blocks from both layers
    source_moe = decoder_layers[source_layer_idx].mlp
    target_moe = decoder_layers[target_layer_idx].mlp

    # Verify each expert index against the expert count of its *own* layer,
    # so layers with different numbers of experts are handled correctly.
    for moe_block, idx in ((target_moe, expert_idx), (source_moe, source_expert_idx)):
        num_experts = len(moe_block.experts)
        if not -num_experts <= idx < num_experts:
            raise ValueError(f"Expert index out of range. Each layer has {num_experts} experts.")

    source_expert = source_moe.experts[source_expert_idx]
    target_expert = target_moe.experts[expert_idx]

    # Swap the gate, up and down projection weights. The right-hand side is
    # evaluated before either assignment, so no temporary is needed.
    for proj_name in ("gate_proj", "up_proj", "down_proj"):
        source_proj = getattr(source_expert, proj_name)
        target_proj = getattr(target_expert, proj_name)
        source_proj.weight, target_proj.weight = target_proj.weight, source_proj.weight

    return {
        'swapped_experts': {
            'source': {
                'layer': source_layer_idx,
                'expert': source_expert_idx
            },
            'target': {
                'layer': target_layer_idx,
                'expert': expert_idx
            }
        }
    }

In [9]:
# Swap expert 0 with expert 30 inside layer 0 (source and target layer are both 0).
# The swap mutates `model` in place, so running this cell a second time swaps
# the two experts back again.
# NOTE(review): the saved output below reports target expert 57, not 30 — it
# appears to be stale from an earlier run with different arguments; re-run to refresh.
swap_experts(model, expert_idx=30, target_layer_idx=0, source_layer_idx=0, source_expert_idx=0)

{'swapped_experts': {'source': {'layer': 0, 'expert': 0},
  'target': {'layer': 0, 'expert': 57}}}

In [None]:
# Test the model with a prompt
# The prompt is a raw string because it contains literal LaTeX-style commands:
# in a normal string "\t" (in \title) and "\a" (in \abstract) would be turned
# into TAB and BEL characters, and "\k" (in \keywords) is an invalid escape.
prompt = r"""\title{On the Convergence Properties of Gradient Descent in Deep Neural Networks}
\abstract{We prove that for neural networks with ReLU activation functions and width at least 4n, where n is the input dimension, gradient descent converges to the global minimum in O(log(1/ϵ)) iterations with probability 0.997. Our proof relies on three key lemmas showing: (1) linear separability of hidden layer features, (2) strict monotonicity of the loss function outside a compact set, and (3) existence of descent directions in the separating hyperplane. The optimal learning rate is shown to be η = 0.01.}
\keywords{deep learning, optimization theory, gradient descent, convergence analysis}"""
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Sampling is stochastic; seed the generator so re-running the cell reproduces
# the same continuation.
torch.manual_seed(0)

outputs = model.generate(
    **inputs,
    max_new_tokens=50,
    temperature=0.7,
    do_sample=True,
    pad_token_id=tokenizer.eos_token_id
)

# generate() returns the prompt tokens followed by the new tokens; decode only
# the continuation so the "Generated response" line does not repeat the prompt.
prompt_length = inputs["input_ids"].shape[1]
generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

print("Prompt:", prompt)
print("\nGenerated response:", generated_text)


In [7]:
# for key in model.state_dict().keys():
#     print(key)