In [1]:
import os

os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ["PYTORCH_TRANSFORMERS_SDP_BACKEND"] = "flash"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import json
import pandas as pd
from collections import defaultdict
import plotly.graph_objects as go

In [3]:
def load_model(model_name="allenai/OLMoE-1B-7B-0924", device_map="cpu"):
    """Load a causal language model and its tokenizer from a checkpoint name.

    Args:
        model_name: Checkpoint identifier handed to ``from_pretrained`` for
            both the model and the tokenizer.
        device_map: Device placement forwarded to
            ``AutoModelForCausalLM.from_pretrained``. Defaults to ``"cpu"``,
            the value that used to be hard-coded here, so existing calls
            behave as before. Other placements (e.g. ``"auto"``) can be
            requested by the caller.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map=device_map,  # Weights stay on the CPU unless the caller asks otherwise
        offload_folder="./offload",  # Temporary storage for offloaded layers
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

# Load the default checkpoint; the cells below read `model` and `tokenizer` as globals.
model, tokenizer = load_model()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Test the model with a prompt
prompt = (""" 
Continue the poem naturally and coherently, maintaining consistency with the rhyme scheme, diction and imagery. Match the poem's tone and style precisely.

we measure rainfall in memories now
count droplets like endangered species
my grandmother's garden is underwater
but the roses still bloom, phosphorescent
in depths where submarines chart
the coordinates of lost cities, while above                  
""")

# Sampling is stochastic; fix the seed so re-running this cell reproduces the same text
torch.manual_seed(0)

# Tokenize the prompt and move the tensors to the device the model lives on
inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
# Number of prompt tokens, used below to separate the continuation from the echoed prompt
prompt_length = inputs['input_ids'].shape[1]

# Generate output (since it's a causal LM, we need to generate text from input)
outputs = model.generate(
    inputs['input_ids'],
    attention_mask=inputs['attention_mask'],  # Add attention mask to not attend to padding tokens
    max_new_tokens=156,    # Upper bound on the number of newly generated tokens
    temperature=0.6,       # Control randomness
    # top_k=100,  # Use top-k sampling
    do_sample=True,        # Use sampling instead of greedy decoding
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id  # Set padding token
)

# Full decoded sequence (prompt followed by the continuation)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# The returned sequence starts with the prompt tokens; slice them off so the
# "Generated response" line shows only what the model produced
response_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

# Print the original prompt and generated response
print("Prompt:", prompt)
print("\nGenerated response :", response_text)

In [5]:
def swap_experts(model, expert_idx, target_layer_idx, source_layer_idx=0, source_expert_idx=0):
    """
    Swap the projection weights of two experts in the OLMoE model.

    The experts are located at ``model.model.layers[i].mlp.experts[j]``. Their
    ``gate_proj``, ``up_proj`` and ``down_proj`` weights are exchanged in
    place, so calling this function a second time with the same arguments
    restores the original assignment. The two experts may live in the same
    layer or in different layers.

    Args:
        model: The OLMoE model
        expert_idx: Index of the expert in target layer to swap with
        target_layer_idx: Index of the layer containing the expert to swap with
        source_layer_idx: Index of the source layer (default 0)
        source_expert_idx: Index of the source expert (default 0)

    Returns:
        dict: ``{'swapped_experts': {'source': {'layer', 'expert'},
        'target': {'layer', 'expert'}}}`` echoing the indices that were used.

    Raises:
        ValueError: If a layer index is >= the number of layers, or an expert
            index is >= the number of experts in its own layer. Only upper
            bounds are checked; negative indices are not rejected.
    """
    # Access the decoder layers
    decoder_layers = model.model.layers

    # Verify layer indices are valid
    num_layers = len(decoder_layers)
    if target_layer_idx >= num_layers or source_layer_idx >= num_layers:
        raise ValueError(f"Layer index out of range. Model has {num_layers} layers.")

    # Get the MoE blocks from both layers
    source_moe = decoder_layers[source_layer_idx].mlp
    target_moe = decoder_layers[target_layer_idx].mlp

    # Verify each expert index against the expert count of its own layer, so a
    # bad target index is reported as ValueError rather than a later IndexError
    num_source_experts = len(source_moe.experts)
    num_target_experts = len(target_moe.experts)
    if source_expert_idx >= num_source_experts or expert_idx >= num_target_experts:
        raise ValueError(
            "Expert index out of range. "
            f"Source layer has {num_source_experts} experts, "
            f"target layer has {num_target_experts} experts."
        )

    source_expert = source_moe.experts[source_expert_idx]
    target_expert = target_moe.experts[expert_idx]

    # Exchange the weight of every projection; the right-hand side is evaluated
    # before either attribute is rebound, so no temporary is needed
    for proj_name in ("gate_proj", "up_proj", "down_proj"):
        source_proj = getattr(source_expert, proj_name)
        target_proj = getattr(target_expert, proj_name)
        source_proj.weight, target_proj.weight = target_proj.weight, source_proj.weight

    return {
        'swapped_experts': {
            'source': {
                'layer': source_layer_idx,
                'expert': source_expert_idx
            },
            'target': {
                'layer': target_layer_idx,
                'expert': expert_idx
            }
        }
    }

In [None]:
# Exchange experts 49 and 34 within layer 0. The swap mutates `model` in place,
# so running this cell a second time swaps the two experts back.
swap_experts(model, expert_idx=34, target_layer_idx=0, source_layer_idx=0, source_expert_idx=49)

In [None]:
# Test the model with a prompt
# Raw string: the prompt is LaTeX, and in a normal string literal "\t" (\title),
# "\b" (\begin) and "\r" (\rho) would turn into TAB, BACKSPACE and CR characters
prompt = (r"""
\title{Quantum Error Mitigation in NISQ Devices}
\begin{abstract}
We present a novel approach to error mitigation in noisy intermediate-scale quantum (NISQ) devices. 
Our method introduces a scaling framework for quantum channels that preserves gate fidelity while reducing environmental noise.
\end{abstract}
\section{Introduction}
Recent advances in NISQ devices have demonstrated both promise and limitations in quantum computation. 
The primary challenge remains decoherence, which introduces errors in quantum operations. We propose a channel scaling approach 
$\mathcal{N}(\rho) = e^{-\lambda t}\rho$ 
that provides a systematic way to ...
"""
)

# Sampling is stochastic; fix the seed so re-running this cell reproduces the same text
torch.manual_seed(0)

# Tokenize the prompt and move the tensors to the device the model lives on
inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(model.device)
# Number of prompt tokens, used below to separate the continuation from the echoed prompt
prompt_length = inputs['input_ids'].shape[1]

# Generate output (since it's a causal LM, we need to generate text from input)
outputs = model.generate(
    inputs['input_ids'],
    attention_mask=inputs['attention_mask'],  # Add attention mask to not attend to padding tokens
    max_new_tokens=100,    # Generate at most 100 new tokens
    temperature=0.7,       # Control randomness
    do_sample=True,        # Use sampling instead of greedy decoding
    pad_token_id=tokenizer.eos_token_id  # Set padding token
)

# Full decoded sequence (prompt followed by the continuation)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
# The returned sequence starts with the prompt tokens; slice them off so the
# "Generated response" line shows only what the model produced
response_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

# Print the original prompt and generated response
print("Prompt:", prompt)
print("\nGenerated response:", response_text)

In [7]:
# for key in model.state_dict().keys():
#     print(key)