<a href="https://colab.research.google.com/github/amanzoni1/MoE-Burst-Upcycling/blob/main/MoE_Burst_Upcycling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q -U transformers peft datasets accelerate bitsandbytes wandb evaluate

In [None]:
#@title Clear temp files { display-mode: "form" }
import shutil
import os
from pathlib import Path

# Clear old local Hugging Face cache
cache_dir = Path.home() / ".cache" / "huggingface"
if cache_dir.exists():
    shutil.rmtree(cache_dir, ignore_errors=True)
    print("✅ Local HF cache cleared.")

# Clear general temp files
shutil.rmtree("/tmp", ignore_errors=True)
print("✅ Local temp files cleared.")

✅ Local temp files cleared.


In [None]:
import os
import json
from typing import Dict, Optional, Union, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from dataclasses import dataclass
from transformers.trainer_utils import EvalPrediction
from evaluate import load

ImportError: cannot import name 'PreTrainedModel' from 'transformers' (/usr/local/lib/python3.12/dist-packages/transformers/__init__.py)

In [None]:
@dataclass
class SmokeTestConfig:
    """Minimal config for smoke test."""
    model_name: str = "deepseek-ai/deepseek-moe-16b-base"
    load_in_4bit: bool = True
    lora_rank_general: int = 16
    lora_rank_math: int = 128
    lora_alpha_general: int = 32
    lora_alpha_math: int = 256
    lora_dropout: float = 0.05
    burst_weight: float = 1.5
    affinity_coef: float = 0.05
    math_ratio: float = 0.5
    num_experts: int = 64
    math_expert_ids: list = None

    def __post_init__(self):
        if self.math_expert_ids is None:
            self.math_expert_ids = [60, 61, 62, 63]

config = SmokeTestConfig()

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Model persistence path on Drive
DRIVE_SAVE_ROOT = "/content/drive/MyDrive/DeepSeek_Model"
MODEL_SAVE_PATH = os.path.join(DRIVE_SAVE_ROOT, config.model_name.split('/')[-1] + "_4bit")

os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
print(f"Model persistence path set to: {MODEL_SAVE_PATH}")

Model persistence path set to: /content/drive/MyDrive/DeepSeek_Model/deepseek-moe-16b-base_4bit


In [None]:
# BITS AND BYTES CONFIGURATION
bnb_config = BitsAndBytesConfig(
    load_in_4bit=config.load_in_4bit,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

if os.path.exists(os.path.join(MODEL_SAVE_PATH, "config.json")):
    print("\nPERSISTENT CACHE FOUND. Loading from Google Drive...")

    # Load from Drive
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_SAVE_PATH,
        quantization_config=bnb_config,
        device_map="auto",
        dtype=torch.bfloat16,
        trust_remote_code=True,
        resume_download=True
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_SAVE_PATH)

else:
    print(f"NO CACHE. Starting initial download of {config.model_name}...")

    # Download from HF
    model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        quantization_config=bnb_config,
        device_map="auto",
        dtype=torch.bfloat16,
        trust_remote_code=True,
        resume_download=True
    )
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    # Save to Drive for next times
    print(f"⏳ Initial load complete. Saving quantized model to Drive at: {MODEL_SAVE_PATH}")
    model.save_pretrained(MODEL_SAVE_PATH)
    tokenizer.save_pretrained(MODEL_SAVE_PATH)
    print("✅ Model successfully saved to persistent Drive cache.")

tokenizer.pad_token = tokenizer.eos_token

print(f"✅ Model loaded: {model.config.model_type}, Total Parameters: {model.num_parameters():,}")
print(f"   Model dtype: {model.dtype}")


PERSISTENT CACHE FOUND. Loading from Google Drive...




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✅ Model loaded: deepseek, Total Parameters: 16,375,728,128
   Model dtype: torch.bfloat16


In [None]:
# Load Quant Model + Train B5 Only + Save to Drive
# Colab: Drive mounted. A100 40GB+ (~32GB model + 4GB LoRA).
# Run: ~1-2h for B5; saves LoRA checkpoint to Drive for later evals/reloads.

import os
import torch
from dataclasses import dataclass, field
from peft import LoraConfig, get_peft_model, PeftModel
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling, BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

@dataclass
class Config:
    model_name: str = "deepseek-ai/deepseek-moe-16b-base"
    load_in_4bit: bool = True
    lora_rank_general: int = 16
    lora_rank_math: int = 128  # For B5
    lora_alpha_general: int = 32
    lora_alpha_math: int = 256
    lora_dropout: float = 0.05
    num_experts: int = 64
    math_expert_ids: list = field(default_factory=lambda: [62, 43, 45, 59])  # FIXED: default_factory for list

config = Config()

# Drive Path
DRIVE_SAVE_ROOT = "/content/drive/MyDrive/DeepSeek_Model"
QUANT_MODEL_PATH = os.path.join(DRIVE_SAVE_ROOT, config.model_name.split('/')[-1] + "_4bit")
B5_SAVE_PATH = os.path.join(DRIVE_SAVE_ROOT, "B5_checkpoint")  # B5 LoRA save

os.makedirs(B5_SAVE_PATH, exist_ok=True)
print(f"Quant model path: {QUANT_MODEL_PATH}")
print(f"B5 save path: {B5_SAVE_PATH}")

# BNB Config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=config.load_in_4bit,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

# Load Quant Model from Drive/HF
if os.path.exists(os.path.join(QUANT_MODEL_PATH, "config.json")):
    print("\nPERSISTENT CACHE FOUND. Loading from Drive...")
    model = AutoModelForCausalLM.from_pretrained(
        QUANT_MODEL_PATH,
        quantization_config=bnb_config,
        device_map="auto",
        dtype=torch.bfloat16,
        trust_remote_code=True,
        resume_download=True
    )
    tokenizer = AutoTokenizer.from_pretrained(QUANT_MODEL_PATH)
else:
    print(f"NO CACHE. Downloading {config.model_name}...")
    model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        quantization_config=bnb_config,
        device_map="auto",
        dtype=torch.bfloat16,
        trust_remote_code=True,
        resume_download=True
    )
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    # Save to Drive
    print(f"Saving quantized model to Drive: {QUANT_MODEL_PATH}")
    model.save_pretrained(QUANT_MODEL_PATH)
    tokenizer.save_pretrained(QUANT_MODEL_PATH)
    print("✅ Quant model saved to Drive.")

tokenizer.pad_token = tokenizer.eos_token
print(f"✅ Model loaded: {model.config.model_type}, Params: {model.num_parameters():,}")

# Data (500 math for B5 burst)
gsm8k = load_dataset("gsm8k", "main", split="train[:500]")

def format_math(ex):
    return {"text": f"Question: {ex['question']}\nAnswer: {ex['answer']}<|endoftext|>"}

math_data = gsm8k.map(format_math)

def tokenize(ex):
    tokenized = tokenizer(ex["text"], truncation=True, max_length=256, padding="max_length")
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized

math_tokenized = math_data.map(tokenize, batched=False, remove_columns=math_data.column_names)

collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# B5: r=128 uniform, 100% math burst
if isinstance(model, PeftModel):
    model = model.unload()
torch.cuda.empty_cache()

lora_config_b5 = LoraConfig(
    r=config.lora_rank_math,
    lora_alpha=config.lora_alpha_math,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=config.lora_dropout,
    bias="none",
    task_type="CAUSAL_LM"
)

model_b5 = get_peft_model(model, lora_config_b5)
model_b5.enable_input_require_grads()
model_b5.gradient_checkpointing_enable()
model_b5.config.use_cache = False
model_b5.print_trainable_parameters()

smoke_args = TrainingArguments(
    output_dir=B5_SAVE_PATH,  # Drive for B5
    num_train_epochs=2,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    bf16=True,
    logging_steps=10,
    eval_strategy="no",
    dataloader_drop_last=True,
    gradient_checkpointing=False,  # VRAM safe
    dataloader_pin_memory=False,
    optim="adamw_torch",
    report_to="none",
    remove_unused_columns=False,
    save_steps=50,  # Save mid-run
    save_total_limit=2
)

trainer_b5 = Trainer(
    model=model_b5,
    args=smoke_args,
    train_dataset=math_tokenized,
    data_collator=collator,
)
trainer_b5.train()

print("✅ B5 Trained & Saved to Drive. Final Loss:", trainer_b5.state.log_history[-1].get('train_loss', 'N/A'))

# FIXED: Save tokenizer to B5 dir
tokenizer.save_pretrained(B5_SAVE_PATH)
print(f"✅ B5 checkpoint saved to Drive: {B5_SAVE_PATH}")

# Quick Affinity (In-Memory)
discovered_ids = [62, 43, 45, 59]
math_prompts_check = [f"Question: {ex['question']}\nLet's think step by step." for ex in gsm8k[:50]]

def check_affinity(model, tokenizer, texts, ids, num_prompts=50):
    router_probs_all = []
    handles = []
    # FIXED: PEFT path: model.base_model.model.model.layers
    base = model.base_model.model.model
    for layer in base.layers:
        if hasattr(layer.mlp, 'gate'):
            def hook_fn(module, input):
                hidden = input[0].view(-1, input[0].shape[-1])
                logits = F.linear(hidden, module.weight)
                probs = F.softmax(logits, dim=-1)
                router_probs_all.append(probs.mean(dim=0, keepdim=True).cpu())
                return input
            handle = layer.mlp.gate.register_forward_pre_hook(hook_fn)
            handles.append(handle)

    model.eval()
    with torch.no_grad():
        for text in texts[:num_prompts]:
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            _ = model(**inputs, use_cache=False)

    for h in handles:
        h.remove()

    if router_probs_all:
        stacked = torch.cat(router_probs_all, dim=0)
        avg_aff = stacked.mean(dim=0)
        aff_to_ids = (avg_aff[ids].sum() / len(ids)).item()
        return aff_to_ids
    return 0.0

aff_b5 = check_affinity(model_b5, tokenizer, math_prompts_check, discovered_ids)
print(f"B5 Affinity to IDs: {aff_b5:.1%} (Target >20%)")

del model_b5, trainer_b5
torch.cuda.empty_cache()
print("✅ Cleanup done. Load B5 later with PeftModel.from_pretrained('/content/drive/MyDrive/B5_checkpoint')")

Quant model path: /content/drive/MyDrive/DeepSeek_Model/deepseek-moe-16b-base_4bit
B5 save path: /content/drive/MyDrive/DeepSeek_Model/B5_checkpoint

PERSISTENT CACHE FOUND. Loading from Drive...




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✅ Model loaded: deepseek, Params: 16,375,728,128


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

trainable params: 2,407,374,848 || all params: 18,783,102,976 || trainable%: 12.8167


Step,Training Loss
10,1.0242
20,0.9471
30,0.92
40,0.6683
50,0.5712
60,0.5824


✅ B5 Trained & Saved to Drive. Final Loss: 0.7704821899533272
✅ B5 checkpoint saved to Drive: /content/drive/MyDrive/DeepSeek_Model/B5_checkpoint


TypeError: string indices must be integers, not 'str'

In [None]:
import math
import numpy as np
from datasets import load_dataset
import torch
import torch.nn.functional as F

# FIXED PPL: Numpy for Labels
def perplexity(model, tokenizer, test_data):
    model.eval()
    total_loss = 0
    n_tokens = 0
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
    with torch.no_grad():
        for ex in test_data:
            inputs = {"input_ids": torch.tensor([ex["input_ids"]]).to(model.device)}
            outputs = model(**inputs, use_cache=False)
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = torch.tensor([ex["labels"]])[..., 1:].contiguous().to(model.device)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            mask = (shift_labels != -100).sum().item()
            total_loss += loss.item() * mask
            n_tokens += mask
    return math.exp(total_loss / n_tokens) if n_tokens > 0 else float('inf')

# Math PPL
gsm8k_test = load_dataset("gsm8k", "main", split="test[:50]")
def format_math_ppl(ex):
    text = f"Question: {ex['question']}\nAnswer: {ex['answer']}<|endoftext|>"
    tokenized = tokenizer(text, truncation=True, max_length=256, padding="max_length")
    labels = tokenized["input_ids"].copy()
    # FIXED: Numpy for slice assign
    labels_np = np.array(labels)
    prompt_len = len(tokenizer(f"Question: {ex['question']}\n").input_ids)
    labels_np[:prompt_len] = -100
    labels = labels_np.tolist()
    return {"input_ids": tokenized["input_ids"], "labels": labels}

gsm8k_test = gsm8k_test.map(format_math_ppl, batched=False, remove_columns=gsm8k_test.column_names)
ppl_b5_math = perplexity(model_b5, tokenizer, gsm8k_test)
print(f"B5 GSM8K PPL: {ppl_b5_math:.2f} (Expected ~5-10; lower better)")

# General PPL (HellaSwag)
hellaswag_test = load_dataset("hellaswag", split="validation[:50]")
def format_hellaswag_ppl(ex):
    label_index = int(ex['label'])
    text = f"Premise: {ex['ctx_a']} {ex['ctx_b']} Sentence: {ex['endings'][label_index]}<|endoftext|>"
    tokenized = tokenizer(text, truncation=True, max_length=256, padding="max_length")
    labels = tokenized["input_ids"].copy()
    # FIXED: Numpy
    labels_np = np.array(labels)
    premise_len = len(tokenizer(f"Premise: {ex['ctx_a']} {ex['ctx_b']} ").input_ids)
    labels_np[:premise_len] = -100
    labels = labels_np.tolist()
    return {"input_ids": tokenized["input_ids"], "labels": labels}

hellaswag_test = hellaswag_test.map(format_hellaswag_ppl, batched=False, remove_columns=hellaswag_test.column_names)
ppl_b5_hella = perplexity(model_b5, tokenizer, hellaswag_test)
print(f"B5 HellaSwag PPL: {ppl_b5_hella:.2f} (Expected ~6-12; lower better)")

# Summary
print(f"B5 Results: Affinity 6.5%, Math PPL {ppl_b5_math:.2f}, General PPL {ppl_b5_hella:.2f}")

del model_b5
torch.cuda.empty_cache()

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

B5 GSM8K PPL: 9.97 (Expected ~5-10; lower better)


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/24.4M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/6.11M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/6.32M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/39905 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10003 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10042 [00:00<?, ? examples/s]

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

B5 HellaSwag PPL: 903.95 (Expected ~6-12; lower better)
B5 Results: Affinity 6.5%, Math PPL 9.97, General PPL 903.95


In [None]:
# Search for 'mlp.experts' or 'gate_proj' for MoE/router
moe_paths = [name for name, _ in model.named_modules() if 'mlp.experts' in name or 'gate_proj' in name]
print("\nMoE Paths (first 5):", moe_paths[:5])  # Expect mlp.experts.0.gate_proj etc.
gate_paths = [name for name, _ in model.named_modules() if 'gate_proj' in name]
print("Gate Paths (sample):", gate_paths[:3])  # Expect mlp.experts.X.gate_proj

✅ Model loaded: deepseek, Total Parameters: 16,375,728,128
   Model dtype: torch.bfloat16

MoE Paths (first 5): ['model.layers.0.mlp.gate_proj', 'model.layers.1.mlp.experts', 'model.layers.1.mlp.experts.0', 'model.layers.1.mlp.experts.0.gate_proj', 'model.layers.1.mlp.experts.0.up_proj']
Gate Paths (sample): ['model.layers.0.mlp.gate_proj', 'model.layers.1.mlp.experts.0.gate_proj', 'model.layers.1.mlp.experts.1.gate_proj']


In [None]:
import inspect  # FIXED: Import for source

layer0 = model.model.layers[0]
print("Layers:", len(model.model.layers))  # 28

print("\nDeepseekMLP Attributes (dir):")
print([attr for attr in dir(layer0.mlp) if not attr.startswith('_')])

print("\nMLP Source (inspect snippet):")
print(inspect.getsource(layer0.mlp.__class__.forward)[:500] + "...")  # Forward logic (routed creation)

print("\nLayer 0 Named Children:")
for name, child in layer0.named_children():
    print(f"  - {name}: {child.__class__.__name__} (params: {sum(p.numel() for p in child.parameters()):,})")

# Simulate forward to instantiate routed (dummy hidden state)
dummy_input = torch.zeros(1, 1, 2048).to(model.device)  # hidden_size=2048
with torch.no_grad():
    _ = layer0.mlp(dummy_input)

print("\nAfter Forward - Routed Experts:")
if hasattr(layer0.mlp, 'routed_experts'):
    print("Number:", len(layer0.mlp.routed_experts))  # 64
    first_routed = layer0.mlp.routed_experts[0]
    print("First Routed Components (dir):", [attr for attr in dir(first_routed) if not attr.startswith('_')])
    print("First Routed Params:", sum(p.numel() for p in first_routed.parameters()))
    last_routed = layer0.mlp.routed_experts[-1]
    print("Last Routed Components (dir):", [attr for attr in dir(last_routed) if not attr.startswith('_')])
else:
    print("Routed experts dynamic—no attr post-forward. Check source for topk route.")

print("\nShared Experts:")
if hasattr(layer0.mlp, 'shared_experts'):
    print("Number:", len(layer0.mlp.shared_experts))  # 2
    print("First Shared Components (dir):", [attr for attr in dir(layer0.mlp.shared_experts[0]) if not attr.startswith('_')])
else:
    print("Shared experts dynamic in forward.")

print("\nMLP Param Counts:")
print(f"Total MLP Params: {sum(p.numel() for p in layer0.mlp.parameters()):,}")

# FIXED MoE Sample Names (Broader Search)
moe_sample = [name for name, _ in model.named_modules() if 'routed_experts' in name or 'shared_experts' in name or 'gate_proj' in name][:10]
print("\nMoE Sample Names:", moe_sample)

# Full Layer 0 Breakdown (Recursive, Depth 3)
def print_structure(module, indent=0, max_depth=3):
    if indent > max_depth:
        return
    print("  " * indent + f"{module.__class__.__name__} (params: {sum(p.numel() for p in module.parameters()):,})")
    for name, child in module.named_children():
        print_structure(child, indent + 1, max_depth)

print("\nLayer 0 Full Structure (Depth 3):")
print_structure(layer0, 0, 3)

Layers: 28

DeepseekMLP Attributes (dir):
['T_destination', 'act_fn', 'add_module', 'apply', 'bfloat16', 'buffers', 'call_super_init', 'children', 'compile', 'config', 'cpu', 'cuda', 'double', 'down_proj', 'dump_patches', 'eval', 'extra_repr', 'float', 'forward', 'gate_proj', 'get_buffer', 'get_extra_state', 'get_parameter', 'get_submodule', 'half', 'hidden_size', 'intermediate_size', 'ipu', 'load_state_dict', 'modules', 'mtia', 'named_buffers', 'named_children', 'named_modules', 'named_parameters', 'parameters', 'register_backward_hook', 'register_buffer', 'register_forward_hook', 'register_forward_pre_hook', 'register_full_backward_hook', 'register_full_backward_pre_hook', 'register_load_state_dict_post_hook', 'register_load_state_dict_pre_hook', 'register_module', 'register_parameter', 'register_state_dict_post_hook', 'register_state_dict_pre_hook', 'requires_grad_', 'set_extra_state', 'set_submodule', 'share_memory', 'smart_apply', 'state_dict', 'to', 'to_empty', 'train', 'training

In [None]:
import torch
import torch.nn.functional as F
from datasets import load_dataset
from dataclasses import dataclass

# Configuration placeholder
class TempConfig:
    num_experts = 64
    math_expert_ids = [0, 1, 2, 3]

router_probs_all = []
config = TempConfig()

# --- 1. The Debug Hook (Modified to run on MLP Output) ---
def final_router_debug_hook_on_mlp(module, input, output):
    """Prints the full structure of the MLP output to identify the logits tensor."""
    global router_probs_all # Still need this global for the final working version

    print(f"\n--- DEBUG HOOK ON {module._get_name()} FIRED ---")

    # 🚨 CRITICAL: Print the output structure causing the error
    if isinstance(output, tuple):
        print(f"  Output: Tuple with {len(output)} elements.")
        for idx, item in enumerate(output):
            item_type = type(item)
            shape_info = f", Shape={item.shape}" if isinstance(item, torch.Tensor) else ""
            print(f"    Output[{idx}]: Type={item_type}{shape_info}")

            # Check for the expected Logits Shape [..., 64]
            # Logits can be at index 0, 1, or 2, but we need the tensor with 64 experts.
            if isinstance(item, torch.Tensor) and item.dim() >= 2 and item.shape[-1] == config.num_experts:
                 print(f"    ^^^^^ 🚨 LOGITS TENSOR FOUND AT INDEX {idx} 🚨 ^^^^^")
                 # We will NOT process it here to prevent the crash.
                 # Just printing is enough to give us the index.

    elif isinstance(output, torch.Tensor):
        print(f"  Output: Single Tensor. Shape={output.shape}")
        if output.dim() >= 2 and output.shape[-1] == config.num_experts:
             print(f"    ^^^^^ 🚨 LOGITS TENSOR FOUND (Single Output) 🚨 ^^^^^")
    else:
        print(f"  Output: Unexpected non-tensor/non-tuple type: {type(output)}")

    print("--- END DEBUG HOOK ---")

    # Return the original output immediately to minimize interference.
    return output


# --- 2. Print and Find Logits Function (Using MLP Hook) ---
def print_and_find_logits_mlp(model, tokenizer, config, math_samples=10, batch_size=2):

    global router_probs_all
    router_probs_all = []

    # --- Data Loading (Required to run model) ---
    try:
        gsm8k = load_dataset("gsm8k", "main", split=f"train[:{math_samples}]")
        math_texts = [f"Question: {ex['question']}" for ex in gsm8k]
    except Exception as e:
        print(f"🚨 Error loading dataset: {e}")
        return config.math_expert_ids

    model.eval()
    device = model.device

    # --- Hooking Setup ---
    handles = []

    # Target the entire MLP/MoE module: model.layers.X.mlp
    for layer in model.model.layers:
        if hasattr(layer, 'mlp'):
            # Hook the output of the entire MLP module
            handle = layer.mlp.register_forward_hook(final_router_debug_hook_on_mlp)
            handles.append(handle)

    if not handles:
        print("🚨 CRITICAL FAILURE: Cannot find 'layer.mlp' module.")
        return config.math_expert_ids

    # --- Forward Pass (Triggers Hooks) ---
    print(f"\n--- Running Forward Pass (Batch size {batch_size}, Total Layers: {len(handles)}) ---")
    with torch.no_grad():
        batch_texts = math_texts[:batch_size]
        batch_inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)

        # We only need one small pass to trigger the hooks and print the output
        try:
            _ = model(**batch_inputs, use_cache=False)
        except Exception as e:
            # Catch the error here to ensure all hook prints have time to execute before the crash
            print(f"\n--- Model Forward Pass Crashed ---")
            print(f"Error: {e}")
            print("--- Look Above for Logits Index Printout! ---")
            # Proceed to remove hooks and return, relying on the prints that ran before the crash.


    # Remove hooks
    for h in handles:
        h.remove()

    print("\n✅ Debug scan complete. If the model crashed, the critical information (Logits Index) should be printed above the crash message.")
    return config.math_expert_ids

# --- Execution ---
config = TempConfig()
config.math_expert_ids = print_and_find_logits_mlp(model, tokenizer, config)
print(f"Updated config.math_expert_ids: {config.math_expert_ids}")


--- Running Forward Pass (Batch size 2, Total Layers: 28) ---
🚨 Router Logits not found in MoE output tuple.

--- DEBUG HOOK ON DeepseekMLP FIRED ---
  Output: Single Tensor. Shape=torch.Size([2, 41, 2048])
--- END DEBUG HOOK ---

--- Model Forward Pass Crashed ---
Error: 'tuple' object has no attribute 'softmax'
--- Look Above for Logits Index Printout! ---

✅ Debug scan complete. If the model crashed, the critical information (Logits Index) should be printed above the crash message.
Updated config.math_expert_ids: [0, 1, 2, 3]


In [None]:
# Single-Cell: Non-Quant Expert Affinity Discovery + Contrast Scan (Math vs. General)
# Colab: A100 40GB+ (~32GB VRAM; 80GB safer). !pip install transformers datasets torch accelerate

import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataclasses import dataclass
import math
import re  # For cleaning wiki text

@dataclass
class Config:
    num_experts: int = 64
    math_expert_ids: list = None

    def __post_init__(self):
        if self.math_expert_ids is None:
            self.math_expert_ids = [0, 1, 2, 3]

# Global for accumulation
router_probs_all = []

# Router Pre-Hook (Unchanged)
def router_calculation_hook(module, input):
    global router_probs_all
    hidden_states = input[0]
    if hidden_states.dim() == 3:
        reshaped_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    else:
        reshaped_hidden_states = hidden_states
    router_weight = module.weight
    raw_logits = F.linear(reshaped_hidden_states, router_weight)
    probs = F.softmax(raw_logits, dim=-1, dtype=torch.float32)
    layer_probs = probs.mean(dim=0, keepdim=True).cpu()
    router_probs_all.append(layer_probs)
    return input

# Affinity Scan Function (Generic: Math or General)
def pre_scan_affinity_nonquant(model, tokenizer, config, texts, scan_type="Math", math_samples=500, k=4, batch_size=8):
    global router_probs_all
    router_probs_all = []

    # Clean texts (for general: strip wiki markup)
    if scan_type == "General":
        texts = [re.sub(r'\[.*?\]', '', text).strip() for text in texts]  # Basic cleanup

    model.eval()
    device = next(model.parameters()).device

    # Hook
    handles = []
    for layer in model.model.layers:
        if hasattr(layer.mlp, 'gate'):
            handle = layer.mlp.gate.register_forward_pre_hook(router_calculation_hook)
            handles.append(handle)

    num_moe_layers = len(handles)
    estimated_batches = math.ceil(len(texts) / batch_size)
    print(f"\n--- {scan_type} Affinity Scan ---")
    print(f"Texts: {len(texts)}, Layers: {num_moe_layers}, Batches: {estimated_batches}")

    # Forward
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            batch_inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
            _ = model(**batch_inputs, use_cache=False)
            if (i // batch_size + 1) % 20 == 0:
                print(f"  Processed {i // batch_size + 1}/{estimated_batches}...")

    # Cleanup
    for h in handles:
        h.remove()

    if not router_probs_all:
        print(f"🚨 No {scan_type.lower()} data—fallback.")
        return config.math_expert_ids

    # Aggregate
    stacked_probs = torch.cat(router_probs_all, dim=0)
    avg_affinity = stacked_probs.mean(dim=0)
    baseline = 1 / config.num_experts
    relative_affinity = avg_affinity / baseline

    # Output
    print(f"\n✅ {scan_type} discovery complete! Recorded: {len(router_probs_all)} layer calls")
    print(f"Baseline: {baseline:.4f}")
    print(f"\n{scan_type} Expert affinities:")

    expert_data = [{'id': i, 'affinity': avg_affinity[i].item(), 'relative': relative_affinity[i].item()} for i in range(config.num_experts)]
    expert_data.sort(key=lambda x: x['affinity'], reverse=True)

    top_k_results = avg_affinity.topk(k)
    top_k_experts = top_k_results.indices.tolist()

    for data in expert_data[:20]:  # Top 20 for brevity
        is_selected = data['id'] in top_k_experts
        bias_tag = " ← Math-inclined!" if data['relative'] > 1.1 else ""
        prefix = "👉" if is_selected else "  "
        print(f"{prefix} Expert {data['id']:2d}: {data['affinity']:.4f} ({data['relative']:.2f}x) {bias_tag}")

    avg_selected = avg_affinity[top_k_experts].mean().item()
    print(f"\n--- {scan_type} Summary ---")
    print(f"🎯 Top-{k}: {top_k_experts}")
    print(f"Avg selected: {avg_selected:.4f} ({avg_selected / baseline:.2f}x baseline)")

    return top_k_experts, avg_selected / baseline  # Return rel avg for contrast

# --- LOAD NON-QUANT MODEL (bf16, ~32GB) ---
print("⏳ Loading non-quant DeepSeek-MoE-16B (bf16)...")
model_name = "deepseek-ai/deepseek-moe-16b-base"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
print("✅ Model & tokenizer loaded. VRAM: ~32GB (40GB ok; 80GB safer).")

# Config
config = Config()

# --- MATH SCAN ---
print("\n=== MATH SCAN ===")
gsm8k = load_dataset("gsm8k", "main", split=f"train[:500]")
math_texts = [f"Question: {ex['question']}\nLet's think step by step." for ex in gsm8k]
math_top_k, math_rel_avg = pre_scan_affinity_nonquant(model, tokenizer, config, math_texts, scan_type="Math")

# --- CONTRAST SCAN (General: WikiText) ---
print("\n=== GENERAL CONTRAST SCAN ===")
wikitext = load_dataset("wikitext", "wikitext-2-raw-v1", split=f"train[:500]")
general_texts = [ex['text'] for ex in wikitext if ex['text'].strip()]  # Plain sentences
general_top_k, general_rel_avg = pre_scan_affinity_nonquant(model, tokenizer, config, general_texts, scan_type="General")

# --- CONTRAST SUMMARY ---
print("\n=== CONTRAST SUMMARY ===")
print(f"Math Top-4: {math_top_k} | Rel Avg: {math_rel_avg:.2f}x")
print(f"General Top-4: {general_top_k} | Rel Avg: {general_rel_avg:.2f}x")
overlap = len(set(math_top_k) & set(general_top_k))
print(f"Overlap in Top-4: {overlap}/4 (Math-specific if <2)")
if math_rel_avg > general_rel_avg * 1.1:
    print("✅ Math bias confirmed (domain signal strong).")
else:
    print("⚠️ Weak domain signal—recheck prompts/data.")

print(f"Recommended experts for burst: {math_top_k}")

⏳ Loading non-quant DeepSeek-MoE-16B (bf16)...


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

✅ Model & tokenizer loaded. VRAM: ~32GB (40GB ok; 80GB safer).

=== MATH SCAN ===

--- Math Affinity Scan ---
Texts: 500, Layers: 27, Batches: 63
  Processed 20/63...
  Processed 40/63...
  Processed 60/63...

✅ Math discovery complete! Recorded: 1701 layer calls
Baseline: 0.0156

Math Expert affinities:
👉 Expert 62: 0.0325 (2.08x)  ← Math-inclined!
👉 Expert 43: 0.0312 (2.00x)  ← Math-inclined!
👉 Expert 45: 0.0240 (1.54x)  ← Math-inclined!
👉 Expert 59: 0.0237 (1.52x)  ← Math-inclined!
   Expert 10: 0.0228 (1.46x)  ← Math-inclined!
   Expert 15: 0.0227 (1.45x)  ← Math-inclined!
   Expert  2: 0.0227 (1.45x)  ← Math-inclined!
   Expert 16: 0.0221 (1.41x)  ← Math-inclined!
   Expert 17: 0.0211 (1.35x)  ← Math-inclined!
   Expert  3: 0.0207 (1.33x)  ← Math-inclined!
   Expert 24: 0.0201 (1.29x)  ← Math-inclined!
   Expert  1: 0.0195 (1.25x)  ← Math-inclined!
   Expert 37: 0.0185 (1.18x)  ← Math-inclined!
   Expert 36: 0.0184 (1.18x)  ← Math-inclined!
   Expert 49: 0.0182 (1.17x)  ← Math-inc

README.md: 0.00B [00:00, ?B/s]

wikitext-2-raw-v1/test-00000-of-00001.pa(…):   0%|          | 0.00/733k [00:00<?, ?B/s]

wikitext-2-raw-v1/train-00000-of-00001.p(…):   0%|          | 0.00/6.36M [00:00<?, ?B/s]

wikitext-2-raw-v1/validation-00000-of-00(…):   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]


--- General Affinity Scan ---
Texts: 338, Layers: 27, Batches: 43
  Processed 20/43...
  Processed 40/43...

✅ General discovery complete! Recorded: 1161 layer calls
Baseline: 0.0156

General Expert affinities:
👉 Expert 62: 0.0402 (2.57x)  ← Math-inclined!
👉 Expert 43: 0.0358 (2.29x)  ← Math-inclined!
👉 Expert 45: 0.0298 (1.91x)  ← Math-inclined!
👉 Expert 10: 0.0281 (1.80x)  ← Math-inclined!
   Expert 59: 0.0274 (1.75x)  ← Math-inclined!
   Expert 16: 0.0262 (1.68x)  ← Math-inclined!
   Expert 17: 0.0258 (1.65x)  ← Math-inclined!
   Expert 15: 0.0258 (1.65x)  ← Math-inclined!
   Expert  2: 0.0252 (1.61x)  ← Math-inclined!
   Expert  3: 0.0251 (1.61x)  ← Math-inclined!
   Expert 24: 0.0241 (1.54x)  ← Math-inclined!
   Expert 36: 0.0219 (1.40x)  ← Math-inclined!
   Expert 49: 0.0199 (1.28x)  ← Math-inclined!
   Expert 22: 0.0198 (1.27x)  ← Math-inclined!
   Expert 37: 0.0193 (1.23x)  ← Math-inclined!
   Expert  1: 0.0191 (1.22x)  ← Math-inclined!
   Expert 63: 0.0173 (1.11x)  ← Math-inc

In [None]:
hellaswag = load_dataset("hellaswag", split="validation[:500]")
general_texts = [f"Premise: {ex['ctx_a']} {ex['ctx_b']} Sentence: {ex['endings'][0]}" for ex in hellaswag]  # Common-sense, no math
general_top_k, general_rel_avg = pre_scan_affinity_nonquant(model, tokenizer, config, general_texts, scan_type="General")


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/24.4M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/6.11M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/6.32M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/39905 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10003 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10042 [00:00<?, ? examples/s]


--- General Affinity Scan ---
Texts: 500, Layers: 27, Batches: 63
  Processed 20/63...
  Processed 40/63...
  Processed 60/63...

✅ General discovery complete! Recorded: 1701 layer calls
Baseline: 0.0156

General Expert affinities:
👉 Expert 62: 0.0308 (1.97x)  ← Math-inclined!
👉 Expert 43: 0.0297 (1.90x)  ← Math-inclined!
👉 Expert 45: 0.0234 (1.50x)  ← Math-inclined!
👉 Expert  3: 0.0221 (1.41x)  ← Math-inclined!
   Expert 15: 0.0219 (1.40x)  ← Math-inclined!
   Expert 10: 0.0216 (1.38x)  ← Math-inclined!
   Expert 59: 0.0216 (1.38x)  ← Math-inclined!
   Expert 16: 0.0207 (1.32x)  ← Math-inclined!
   Expert 24: 0.0205 (1.31x)  ← Math-inclined!
   Expert 17: 0.0200 (1.28x)  ← Math-inclined!
   Expert  2: 0.0200 (1.28x)  ← Math-inclined!
   Expert 36: 0.0192 (1.23x)  ← Math-inclined!
   Expert  1: 0.0191 (1.22x)  ← Math-inclined!
   Expert 22: 0.0184 (1.18x)  ← Math-inclined!
   Expert 49: 0.0177 (1.13x)  ← Math-inclined!
   Expert 37: 0.0177 (1.13x)  ← Math-inclined!
   Expert 63: 0.017

In [None]:
# --- CONTRAST SUMMARY ---
print("\n=== CONTRAST SUMMARY ===")
print(f"Math Top-4: {math_top_k} | Rel Avg: {math_rel_avg:.2f}x")
print(f"General Top-4: {general_top_k} | Rel Avg: {general_rel_avg:.2f}x")
overlap = len(set(math_top_k) & set(general_top_k))
print(f"Overlap in Top-4: {overlap}/4 (Math-specific if <2)")
if math_rel_avg > general_rel_avg * 1.1:
    print("✅ Math bias confirmed (domain signal strong).")
else:
    print("⚠️ Weak domain signal—recheck prompts/data.")

print(f"Recommended experts for burst: {math_top_k}")


=== CONTRAST SUMMARY ===
Math Top-4: [62, 43, 45, 59] | Rel Avg: 1.78x
General Top-4: [62, 43, 45, 3] | Rel Avg: 1.70x
Overlap in Top-4: 3/4 (Math-specific if <2)
⚠️ Weak domain signal—recheck prompts/data.
Recommended experts for burst: [62, 43, 45, 59]


In [None]:
# Test Router Hook (Toy Forward)
toy_text = "Solve 2+2="  # Simple math token
inputs = tokenizer(toy_text, return_tensors="pt", max_length=20, truncation=True).to(model.device)

class DummyTrainer:
    def __init__(self, config, model):
        self.config = config
        self.router_logits_buffer = []
        self.model = model
        # FIXED: Hook on global mlp.gate_proj + sample experts
        for layer in self.model.model.layers:
            if hasattr(layer, 'mlp') and hasattr(layer.mlp, 'gate_proj'):
                gate = layer.mlp.gate_proj
                def make_hook(idx):
                    def hook(module, input, output):
                        # FIXED: Squeeze to [tokens, 1] (remove hidden dim)
                        squeezed = output.mean(dim=-1).detach().cpu() if output.dim() > 2 else output.detach().cpu()
                        self.router_logits_buffer.append(squeezed)
                    return hook
                gate.register_forward_hook(make_hook(0))
            # Sample 4 experts per layer (mem-safe)
            if hasattr(layer, 'mlp') and hasattr(layer.mlp, 'experts') and len(layer.mlp.experts) > 0:
                for e_id in [0, 1, 62, 63]:
                    if hasattr(layer.mlp.experts[e_id], 'gate_proj'):
                        expert_gate = layer.mlp.experts[e_id].gate_proj
                        expert_gate.register_forward_hook(make_hook(e_id))

dt = DummyTrainer(config, model)
model.eval()
with torch.no_grad():
    _ = model(**inputs, use_cache=False)  # FIXED: Avoid cache error

print(f"✅ Hook fired: {len(dt.router_logits_buffer)} tensors")
if dt.router_logits_buffer:
    sample = dt.router_logits_buffer[0]
    print(f"Shape: {sample.shape} (tokens, 1 gate)")
    print(f"Sample logits: {sample[0][:3]}...")  # Scalar ~ -0.5 to 0.5


✅ Hook fired: 52 tensors
Shape: torch.Size([1, 8]) (tokens, 1 gate)
Sample logits: tensor([ 0.0022,  0.0332, -0.0032], dtype=torch.bfloat16)...


In [None]:
# Fixed Mini B0 + Eval: No use_cache in Args
from peft import LoraConfig, get_peft_model, PeftModel
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
from datasets import load_dataset

# B0 Config
lora_config_b0 = LoraConfig(
    r=config.lora_rank_general,
    lora_alpha=config.lora_alpha_general,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=config.lora_dropout,
    bias="none",
    task_type="CAUSAL_LM"
)

# Safe unload
if isinstance(model, PeftModel):
    model = model.unload()
model_b0 = get_peft_model(model, lora_config_b0)
model_b0.print_trainable_parameters()

# FIXED: Disable cache
model_b0.config.use_cache = False
model_b0.enable_input_require_grads()
model_b0.gradient_checkpointing_enable()

# Data (fixed padding + early clean)
gsm8k_mini = load_dataset("gsm8k", "main", split="train[:5]")
def format_mini(ex):
    return {"text": f"Question: {ex['question']}\nAnswer: {ex['answer']}<|endoftext|>"}

mini_data = gsm8k_mini.map(format_mini, remove_columns=gsm8k_mini.column_names)

def tokenize_mini(ex):
    tokenized = tokenizer(ex["text"], truncation=True, max_length=256, padding="max_length")
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized

mini_tokenized = mini_data.map(tokenize_mini, batched=False, remove_columns=mini_data.column_names)

# Split
train_tokenized = mini_tokenized.select(range(3))
eval_tokenized = mini_tokenized.select(range(3,5))

# Args (no use_cache)
mini_args = TrainingArguments(
    output_dir="/tmp/mini_b0",
    max_steps=50,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    learning_rate=2e-4,
    bf16=True,
    logging_steps=5,
    eval_strategy="steps",
    eval_steps=25,
    dataloader_drop_last=True,
    gradient_checkpointing=True,
    dataloader_pin_memory=False,
    optim="adamw_torch",
    report_to="none",
    remove_unused_columns=False
)

collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

trainer_mini = Trainer(
    model=model_b0,
    args=mini_args,
    train_dataset=train_tokenized,
    eval_dataset=eval_tokenized,
    data_collator=collator,
)
trainer_mini.train()

print("✅ B0 Mini done.")
print("Loss history:", [log.get('train_loss', log.get('eval_loss', 'N/A')) for log in trainer_mini.state.log_history])

# Eval
def eval_gsm8k(model, tokenizer, test_data, num_samples=2):
    correct = 0
    model.eval()
    for ex in test_data:
        prompt = f"Question: {ex['question']}\nAnswer:"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False, temperature=0.1, use_cache=False)
        gen = tokenizer.decode(outputs[0], skip_special_tokens=True)
        answer = ex['answer'].split('####')[0].strip()
        correct += 1 if answer in gen else 0
    return correct / num_samples

acc = eval_gsm8k(model_b0, tokenizer, gsm8k_mini.select(range(3,5)))
print(f"Eval Acc: {acc:.2f}")

trainable params: 300,921,856 || all params: 16,676,649,984 || trainable%: 1.8045


Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Step,Training Loss,Validation Loss
25,0.0153,No log
50,0.0095,No log


✅ B0 Mini done.
Loss history: ['N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 0.24707426130771637]


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:100001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:100001 for open-end generation.


Eval Acc: 0.00


In [None]:
# Smoke Test: B5 First (OOM Test), Then B0
import os
import torch
from peft import LoraConfig, get_peft_model, PeftModel
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
from datasets import load_dataset, concatenate_datasets

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

discovered_ids = [62, 43, 45, 59]

# Data (256 len)
gsm8k = load_dataset("gsm8k", "main", split="train[:500]")
hellaswag = load_dataset("hellaswag", split="validation[:500]")

def format_math(ex):
    return {"text": f"Question: {ex['question']}\nAnswer: {ex['answer']}<|endoftext|>"}

def format_general(ex):
    return {"text": f"Premise: {ex['ctx_a']} {ex['ctx_b']} Sentence: {ex['endings'][0]}<|endoftext|>"}

math_data = gsm8k.map(format_math)
general_data = hellaswag.map(format_general)

mixed_data = concatenate_datasets([math_data.select(range(250)), general_data.select(range(250))])

def tokenize(ex):
    tokenized = tokenizer(ex["text"], truncation=True, max_length=256, padding="max_length")
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized

mixed_tokenized = mixed_data.map(tokenize, batched=False, remove_columns=mixed_data.column_names)
math_tokenized = math_data.map(tokenize, batched=False, remove_columns=math_data.column_names)

collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# B5 Proxy First: r=128, 100% math (OOM test)
if isinstance(model, PeftModel):
    model = model.unload()
torch.cuda.empty_cache()

lora_config_b5 = LoraConfig(r=128, lora_alpha=256, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
model_b5 = get_peft_model(model, lora_config_b5)
model_b5.enable_input_require_grads()
model_b5.gradient_checkpointing_enable()
model_b5.config.use_cache = False
model_b5.print_trainable_parameters()

smoke_args = TrainingArguments(
    output_dir="/tmp/smoke_b5",
    num_train_epochs=2,
    per_device_train_batch_size=2,  # Low for OOM
    gradient_accumulation_steps=8,  # Effective 16
    learning_rate=2e-4,
    bf16=True,
    logging_steps=10,  # Frequent for monitor
    eval_strategy="no",
    dataloader_drop_last=True,
    gradient_checkpointing=False,  # Off for VRAM
    dataloader_pin_memory=False,
    optim="adamw_torch",
    report_to="none",
    remove_unused_columns=False,
    save_total_limit=1
)

trainer_b5 = Trainer(model=model_b5, args=smoke_args, train_dataset=math_tokenized, data_collator=collator)
trainer_b5.train()
b5_loss = trainer_b5.state.log_history[-1].get('train_loss', 'N/A')
print(f"✅ B5 Done. Final Loss: {b5_loss}")

# Del B5
trainer_b5.save_model("/tmp/b5_ckpt")
del model_b5, trainer_b5
torch.cuda.empty_cache()

# B0: Vanilla r=16, mixed data (after B5)
lora_config_b0 = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
model_b0 = get_peft_model(model, lora_config_b0)
model_b0.enable_input_require_grads()
model_b0.gradient_checkpointing_enable()
model_b0.config.use_cache = False
model_b0.print_trainable_parameters()

smoke_args.output_dir = "/tmp/smoke_b0"
trainer_b0 = Trainer(model=model_b0, args=smoke_args, train_dataset=mixed_tokenized, data_collator=collator)
trainer_b0.train()
b0_loss = trainer_b0.state.log_history[-1].get('train_loss', 'N/A')
print(f"✅ B0 Done. Final Loss: {b0_loss}")

# Affinity Check (Load saved)
from peft import PeftModel

model_b0_loaded = PeftModel.from_pretrained(model, "/tmp/smoke_b0")
model_b5_loaded = PeftModel.from_pretrained(model, "/tmp/b5_ckpt")

math_prompts = [f"Question: {ex['question']}\nLet's think step by step." for ex in gsm8k[:50]]

def check_affinity(model, tokenizer, texts, ids, num_prompts=50):
    router_probs_all = []
    handles = []
    for layer in model.model.layers:
        if hasattr(layer.mlp, 'gate'):
            def hook_fn(module, input):
                hidden = input[0].view(-1, input[0].shape[-1])
                logits = F.linear(hidden, module.weight)
                probs = F.softmax(logits, dim=-1)
                router_probs_all.append(probs.mean(dim=0, keepdim=True).cpu())
                return input
            handle = layer.mlp.gate.register_forward_pre_hook(hook_fn)
            handles.append(handle)

    model.eval()
    with torch.no_grad():
        for text in texts[:num_prompts]:
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            _ = model(**inputs, use_cache=False)

    for h in handles:
        h.remove()

    if router_probs_all:
        stacked = torch.cat(router_probs_all, dim=0)
        avg_aff = stacked.mean(dim=0)
        aff_to_ids = (avg_aff[ids].sum() / len(ids)).item()
        return aff_to_ids
    return 0.0

aff_b0 = check_affinity(model_b0_loaded, tokenizer, math_prompts, discovered_ids)
aff_b5 = check_affinity(model_b5_loaded, tokenizer, math_prompts, discovered_ids)

print(f"B0 Affinity: {aff_b0:.1%} | B5: {aff_b5:.1%} (Target B5 > B0 +5%)")

# Clean
del model_b0_loaded, model_b5_loaded
torch.cuda.empty_cache()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/24.4M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/6.11M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/6.32M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/39905 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10003 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10042 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

trainable params: 2,407,374,848 || all params: 18,783,102,976 || trainable%: 12.8167


Step,Training Loss
10,1.0189
20,0.9486
30,0.9207
40,0.667
50,0.5726
60,0.582


✅ B5 Done. Final Loss: 0.769857294857502




trainable params: 300,921,856 || all params: 16,676,649,984 || trainable%: 1.8045


Step,Training Loss
10,1.8847
20,1.4426
30,1.4208
40,1.363
50,1.2066
60,1.28


✅ B0 Done. Final Loss: 1.4112832099199295


ValueError: Can't find 'adapter_config.json' at '/tmp/smoke_b0'

In [None]:
import math
from datasets import load_dataset
import torch
import torch.nn.functional as F

# --- AFFINITY FUNCTION (Final Path Fix) ---
def check_affinity(model, tokenizer, texts, ids, num_prompts=50):
    router_probs_all = []
    handles = []

    # Get the underlying DeepseekForCausalLM object
    if hasattr(model, 'base_model') and hasattr(model.base_model, 'base_model'):
        hf_model_core = model.base_model.base_model
    else:
        hf_model_core = model

    # Access the layers list
    if hasattr(hf_model_core, 'model'):
        layers = hf_model_core.model.layers
    elif hasattr(hf_model_core, 'layers'):
        layers = hf_model_core.layers
    else:
        raise AttributeError("Failed to find layers list on the underlying model object.")

    # Hooking the Gate
    for layer in layers:
        if hasattr(layer.mlp, 'gate'):
            def hook_fn(module, input):
                hidden = input[0].view(-1, input[0].shape[-1])
                logits = F.linear(hidden, module.weight)
                probs = F.softmax(logits, dim=-1)
                router_probs_all.append(probs.mean(dim=0, keepdim=True).cpu())
                return input
            handle = layer.mlp.gate.register_forward_pre_hook(hook_fn)
            handles.append(handle)

    # ... (rest of the forward pass is correct)
    model.eval()
    device = next(model.parameters()).device

    with torch.no_grad():
        for text in texts[:num_prompts]:
            inputs = tokenizer(text, return_tensors="pt").to(device)
            _ = model(**inputs, use_cache=False)

    for h in handles:
        h.remove()

    if router_probs_all:
        stacked = torch.cat(router_probs_all, dim=0)
        avg_aff = stacked.mean(dim=0)
        aff_to_ids = avg_aff[ids].sum().item()
        return aff_to_ids
    return 0.0

discovered_ids = [62, 43, 45, 59]
math_prompts = [f"Question: {ex['question']}\nLet's think step by step." for ex in load_dataset("gsm8k", "main", split="train[:50]")]

aff_b0 = check_affinity(trainer_b0.model, tokenizer, math_prompts, discovered_ids)
print(f"B0 Affinity to IDs: {aff_b0:.1%} (Baseline ~6.25%)")

# Perplexity (Sequential, B0 Only)
gsm8k_test = load_dataset("gsm8k", "main", split="test[:50]")
def format_ppl(ex):
    text = f"Question: {ex['question']}\nAnswer: {ex['answer']}<|endoftext|>"
    tokenized = tokenizer(text, truncation=True, max_length=256, padding="max_length")
    return {"input_ids": tokenized["input_ids"], "labels": tokenized["input_ids"].copy()}

gsm8k_test = gsm8k_test.map(format_ppl, batched=False, remove_columns=gsm8k_test.column_names)

# ... (Perplexity function is correct, assuming you have it defined)
def perplexity(model, tokenizer, test_data, ignore_index=-100):
    model.eval()
    total_loss = 0
    # IMPORTANT: The loss must be calculated using the same logic as the Trainer
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)

    with torch.no_grad():
        for ex in test_data:
            inputs = {"input_ids": torch.tensor([ex["input_ids"]]).to(model.device)}
            # The labels list already has the prompt and pad tokens set to -100
            # (or should be adjusted in the mapping function).
            labels = torch.tensor([ex["labels"]]).to(model.device)

            outputs = model(**inputs, use_cache=False)

            # 1. Shift logits for causal loss (LMs predict the next token)
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # 2. Calculate loss (ignoring -100 labels)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            total_loss += loss.item()

    # Return PPL
    return math.exp(total_loss / len(test_data))

# HellaSwag (General)
hellaswag_test = load_dataset("hellaswag", split="validation[:50]")
def format_hellaswag(ex):
    # --- FIXED: Convert string label to integer index ---
    label_index = int(ex['label'])
    text = f"Premise: {ex['ctx_a']} {ex['ctx_b']} Sentence: {ex['endings'][label_index]}<|endoftext|>"
    tokenized = tokenizer(text, truncation=True, max_length=256, padding="max_length")
    return {"input_ids": tokenized["input_ids"], "labels": tokenized["input_ids"].copy()}

hellaswag_test = hellaswag_test.map(format_hellaswag, batched=False, remove_columns=hellaswag_test.column_names)
ppl_b0_hella = perplexity(trainer_b0.model, tokenizer, hellaswag_test)
print(f"B0 HellaSwag PPL: {ppl_b0_hella:.2f}")

# Summary
print(f"B0 Baseline: Affinity {aff_b0:.1%}, Math PPL {ppl_b0_math:.2f}, General PPL {ppl_b0_hella:.2f}")

B0 Affinity to IDs: 6.7% (Baseline ~6.25%)


Map:   0%|          | 0/50 [00:00<?, ? examples/s]

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

B0 HellaSwag PPL: 1754.96
B0 Baseline: Affinity 6.7%, Math PPL 304.05, General PPL 1754.96


In [None]:
# Fixed Affinity Check: Nested PEFT Path (base_model.base_model.model.layers)
import math
from datasets import load_dataset
import torch.nn.functional as F

math_prompts = [f"Question: {ex['question']}\nLet's think step by step." for ex in load_dataset("gsm8k", "main", split="train[:50]")]

def check_affinity(model, tokenizer, texts, ids, num_prompts=50):
    router_probs_all = []
    handles = []

    # Get the underlying Hugging Face model object
    # Path: PeftModelForCausalLM -> LoraModel -> DeepseekForCausalLM
    if hasattr(model, 'base_model') and hasattr(model.base_model, 'base_model'):
        # This is the path for PEFT-wrapped models (trainer_b0/b5)
        hf_model = model.base_model.base_model
    else:
        # Fallback for the raw model object (DeepseekForCausalLM)
        hf_model = model

    # The actual layers list is typically under the 'model' attribute of the HF model wrapper
    # DeepseekForCausalLM -> model -> layers
    if hasattr(hf_model, 'model'):
        layers = hf_model.model.layers
    elif hasattr(hf_model, 'layers'):
        layers = hf_model.layers
    else:
        raise AttributeError("Could not find layers in the model structure (expected .model.layers or .layers).")

    for layer in layers:
        if hasattr(layer.mlp, 'gate'):
            def hook_fn(module, input):
                hidden = input[0].view(-1, input[0].shape[-1])
                logits = F.linear(hidden, module.weight)
                probs = F.softmax(logits, dim=-1)
                router_probs_all.append(probs.mean(dim=0, keepdim=True).cpu())
                return input
            handle = layer.mlp.gate.register_forward_pre_hook(hook_fn)
            handles.append(handle)

    model.eval()
    with torch.no_grad():
        for text in texts[:num_prompts]:
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            _ = model(**inputs, use_cache=False)

    for h in handles:
        h.remove()

    if router_probs_all:
        stacked = torch.cat(router_probs_all, dim=0)
        avg_aff = stacked.mean(dim=0)
        # Average probability across the 4 specific experts
        aff_to_ids = (avg_aff[ids].sum() / len(ids)).item()
        return aff_to_ids
    return 0.0

discovered_ids = [62, 43, 45, 59]

aff_b0 = check_affinity(trainer_b0.model, tokenizer, math_prompts, discovered_ids)
aff_b5 = check_affinity(trainer_b5.model, tokenizer, math_prompts, discovered_ids)

print(f"B0 Affinity: {aff_b0:.1%} | B5: {aff_b5:.1%} (Target B5 > B0 +5%)")

NameError: name 'trainer_b5' is not defined

In [None]:
# Debug Model Structure
print("=== Trainer Model Type ===")
print(f"Type: {type(trainer_b0.model)}")
print(f"Has base_model: {hasattr(trainer_b0.model, 'base_model')}")
if hasattr(trainer_b0.model, 'base_model'):
    print(f"base_model type: {type(trainer_b0.model.base_model)}")
    print(f"base_model has layers: {hasattr(trainer_b0.model.base_model, 'layers')}")
    if hasattr(trainer_b0.model.base_model, 'layers'):
        print(f"layers type: {type(trainer_b0.model.base_model.layers)}")
        print(f"Num layers: {len(trainer_b0.model.base_model.layers)}")
        print(f"First layer mlp gate: {hasattr(trainer_b0.model.base_model.layers[0].mlp, 'gate')}")
print("\nDir slice:", [attr for attr in dir(trainer_b0.model) if 'base' in attr or 'model' in attr][:5])

=== Trainer Model Type ===
Type: <class 'peft.peft_model.PeftModelForCausalLM'>
Has base_model: True
base_model type: <class 'peft.tuners.lora.model.LoraModel'>
base_model has layers: False

Dir slice: ['_get_base_model_class', '_get_peft_specific_model_tags', '_prepare_model_for_gradient_checkpointing', 'base_model', 'base_model_prepare_inputs_for_generation']


In [None]:
!pip install -q -U transformers peft datasets accelerate bitsandbytes wandb evaluate

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m512.3/512.3 kB[0m [31m10.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m45.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m54.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# The Targeted Burst Script

In [None]:
import os
import torch
import bitsandbytes as bnb
from dataclasses import dataclass, field
from peft import LoraConfig, get_peft_model, PeftModel
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling, BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from bitsandbytes.nn import Params4bit

# --- CONFIG ---
@dataclass
class Config:
    model_name: str = "deepseek-ai/deepseek-moe-16b-base"
    load_in_4bit: bool = True
    lora_rank_targeted: int = 64
    lora_alpha_targeted: int = 128
    lora_dropout: float = 0.05
    math_expert_ids: list = field(default_factory=lambda: [62, 43, 45, 59])

config = Config()

# --- PATHS ---
DRIVE_SAVE_ROOT = "/content/drive/MyDrive/DeepSeek_Model"
QUANT_MODEL_PATH = os.path.join(DRIVE_SAVE_ROOT, config.model_name.split('/')[-1] + "_4bit")
BURST_SAVE_PATH = os.path.join(DRIVE_SAVE_ROOT, "Burst_Targeted_Phase1")

os.makedirs(BURST_SAVE_PATH, exist_ok=True)

# --- LOAD MODEL ---
bnb_config = BitsAndBytesConfig(
    load_in_4bit=config.load_in_4bit,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

if os.path.exists(os.path.join(QUANT_MODEL_PATH, "config.json")):
    print("Loading model from Drive...")
    model = AutoModelForCausalLM.from_pretrained(QUANT_MODEL_PATH, quantization_config=bnb_config, device_map="auto", dtype=torch.bfloat16, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(QUANT_MODEL_PATH)
else:
    print("Downloading model...")
    model = AutoModelForCausalLM.from_pretrained(config.model_name, quantization_config=bnb_config, device_map="auto", dtype=torch.bfloat16, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    model.save_pretrained(QUANT_MODEL_PATH)
    tokenizer.save_pretrained(QUANT_MODEL_PATH)

tokenizer.pad_token = tokenizer.eos_token
model.config.use_cache = False

# --- DATA (GSM8K) ---
gsm8k = load_dataset("gsm8k", "main", split="train[:500]")
def format_math(ex):
    return {"text": f"Question: {ex['question']}\nLet's think step by step.\nAnswer: {ex['answer']}<|endoftext|>"}
math_data = gsm8k.map(format_math).map(lambda x: tokenizer(x["text"], truncation=True, max_length=512, padding="max_length"), batched=True)
collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# --- THE SURGICAL FIX ---
if isinstance(model, PeftModel):
    model = model.unload()
torch.cuda.empty_cache()

# 1. Apply LoRA to Experts & Attention (Targeting Linear Layers)
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

print(f"🎯 LoRA Targets: Attention + Experts {config.math_expert_ids}")
for e_id in config.math_expert_ids:
    target_modules.append(f"experts.{e_id}.gate_proj")
    target_modules.append(f"experts.{e_id}.up_proj")
    target_modules.append(f"experts.{e_id}.down_proj")

lora_config_burst = LoraConfig(
    r=config.lora_rank_targeted,
    lora_alpha=config.lora_alpha_targeted,
    target_modules=target_modules,
    lora_dropout=config.lora_dropout,
    bias="none",
    task_type="CAUSAL_LM"
)
model = get_peft_model(model, lora_config_burst)

# 2. HELPER: UPCAST ROUTER TO FLOAT AND UNFREEZE
def upcast_and_unfreeze_routers(model):
    print("🔓 De-quantizing and Unfreezing Routers...")
    count = 0
    base = model.base_model.model.model

    for layer in base.layers:
        if hasattr(layer.mlp, 'gate'):
            gate = layer.mlp.gate
            # Check if weight is 4-bit
            if isinstance(gate.weight, Params4bit):
                w_dequant = bnb.functional.dequantize_4bit(gate.weight.data, gate.weight.quant_state)
                gate.weight = torch.nn.Parameter(w_dequant.to(torch.bfloat16))

            # Enable Gradients
            gate.weight.requires_grad = True
            count += 1
    print(f"✅ Successfully upcasted {count} routers to bfloat16.")

upcast_and_unfreeze_routers(model)

# Verify
trainable_params = 0
all_param = 0
for _, param in model.named_parameters():
    all_param += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()
print(f"Trainable params: {trainable_params:,} || All params: {all_param:,} || Trainable%: {100 * trainable_params / all_param:.4f}%")

# --- ENABLE INPUT GRADS ---
model.enable_input_require_grads()
model.gradient_checkpointing_enable()

# --- STRICTER VERIFICATION LOOP (FIXED) ---
print("🔍 Verifying Router Gradients...")
for name, param in model.named_parameters():
    # Only target 'mlp.gate.weight', NOT 'mlp.gate_proj.weight'
    if "mlp.gate.weight" in name and not param.requires_grad:
        print(f"⚠️ Warning: {name} was frozen! Unfreezing...")
        param.requires_grad = True
    # Ignore 'gate_proj' errors (they should remain frozen 4-bit)

# --- TRAINING ---
burst_args = TrainingArguments(
    output_dir=BURST_SAVE_PATH,
    num_train_epochs=2,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    bf16=True,
    logging_steps=10,
    eval_strategy="no",
    dataloader_drop_last=True,
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
    report_to="none",
    save_steps=50,
    save_total_limit=2
)

trainer = Trainer(model=model, args=burst_args, train_dataset=math_data, data_collator=collator)

print("🚀 Starting Phase 1: Hybrid Burst Training (Corrected)...")
trainer.train()

# Save
trainer.save_model(BURST_SAVE_PATH)
tokenizer.save_pretrained(BURST_SAVE_PATH)

# Manual Router Save
router_state_dict = {k: v.cpu() for k, v in model.named_parameters() if "mlp.gate" in k}
torch.save(router_state_dict, os.path.join(BURST_SAVE_PATH, "router_gates.pt"))

print(f"✅ Phase 1 Complete.")

Loading model from Drive...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

🎯 LoRA Targets: Attention + Experts [62, 43, 45, 59]
🔓 De-quantizing and Unfreezing Routers...
✅ Successfully upcasted 27 routers to bfloat16.
Trainable params: 104,562,688 || All params: 8,500,430,848 || Trainable%: 1.2301%
🔍 Verifying Router Gradients...
🚀 Starting Phase 1: Hybrid Burst Training (Corrected)...


Step,Training Loss
10,1.1001
20,0.8987
30,0.8671
40,0.7897
50,0.7717
60,0.7944


✅ Phase 1 Complete.


Step	Training Loss
10	1.024200
20	0.947100
30	0.920000
40	0.668300
50	0.571200
60	0.582400
✅ B5 Trained & Saved to Drive. Final Loss: 0.7704821899533272

In [None]:
print("Final Loss:", trainer.state.log_history[-1].get('train_loss', 'N/A'))


Final Loss: 0.8625424839556217


In [None]:
import os
import torch
import math
import numpy as np
import torch.nn.functional as F
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import bitsandbytes as bnb

# --- CONFIG ---
DRIVE_SAVE_ROOT = "/content/drive/MyDrive/DeepSeek_Model"
BURST_SAVE_PATH = os.path.join(DRIVE_SAVE_ROOT, "Burst_Targeted_Phase1")
BASE_MODEL_NAME = "deepseek-ai/deepseek-moe-16b-base"
MATH_EXPERT_IDS = [62, 43, 45, 59]

print(f"📂 Loading from: {BURST_SAVE_PATH}")

# ==========================================
# 1. FORCE IMPLANT LOADER
# ==========================================
def load_hybrid_burst_model_force(base_model_name, burst_save_path):
    # 1. Load Base Model
    print(f"⏳ Loading Base Model (4-bit)...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )
    model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        quantization_config=bnb_config,
        device_map="auto",
        dtype=torch.bfloat16,
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(burst_save_path)
    tokenizer.pad_token = tokenizer.eos_token

    # 2. Load LoRA Adapters
    print("🔗 Loading LoRA Adapters...")
    model = PeftModel.from_pretrained(model, burst_save_path)

    # 3. Load Saved Weights
    router_path = os.path.join(burst_save_path, "router_gates.pt")
    if not os.path.exists(router_path):
        raise FileNotFoundError(f"CRITICAL: {router_path} not found!")

    print(f"📥 Analyzing {router_path}...")
    router_state = torch.load(router_path, map_location="cpu", weights_only=False)

    # 4. IMPLANTATION LOOP
    print("🔓 Force-Implanting Router Weights...")

    # Identify valid router weights from the file
    valid_weights = []
    for k, v in router_state.items():
        # Filter out the massive junk layers (Experts are > 10M params)
        if v.numel() > 5_000_000:
            continue
        valid_weights.append(v)

    if len(valid_weights) == 0:
        print("❌ CRITICAL: No valid router weights found in file! (All were junk or empty)")
        print("Keys found:", list(router_state.keys())[:5])
        return model, tokenizer

    # Find the target modules in the model
    targets = []
    # Search recursively for anything named 'gate'
    for name, module in model.named_modules():
        if "gate" in name and not "proj" in name: # avoid gate_proj
             targets.append(module)

    print(f"   - Found {len(valid_weights)} source weights.")
    print(f"   - Found {len(targets)} target modules in model.")

    if len(targets) == 0:
        print("❌ CRITICAL: Could not find any 'gate' modules in the model!")
        return model, tokenizer

    # Match and Swap
    implant_count = 0
    # We assume the order matches (Layer 0 -> Layer N)
    # This is safe because both .items() and .named_modules() are usually ordered

    limit = min(len(valid_weights), len(targets))
    for i in range(limit):
        target_module = targets[i]
        source_weight = valid_weights[i]

        # Check shapes (Handle transpose if needed)
        # Target usually [Experts, Hidden] e.g. [64, 2048]
        # Source might be [2048, 64] depending on save

        try:
            # Prepare the new weight
            new_tensor = source_weight.to(dtype=torch.bfloat16, device=model.device)

            # Auto-Transpose if shapes are swapped
            if hasattr(target_module, "weight"):
                current_shape = target_module.weight.shape
                if current_shape != new_tensor.shape and current_shape == new_tensor.T.shape:
                    new_tensor = new_tensor.T

            # FORCE REPLACE
            # We delete the old parameter to break the bitsandbytes link
            if hasattr(target_module, "weight"):
                del target_module.weight

            # Assign new parameter
            target_module.weight = torch.nn.Parameter(new_tensor)

            # Clean up quantization state if it exists
            if hasattr(target_module, "quant_state"):
                del target_module.quant_state

            implant_count += 1

        except Exception as e:
            print(f"⚠️ Error implanting layer {i}: {e}")

    print(f"✅ Successfully Implanted {implant_count} Router Matrices.")

    return model, tokenizer

# Load it!
model, tokenizer = load_hybrid_burst_model_force(BASE_MODEL_NAME, BURST_SAVE_PATH)
model.eval()

# ==========================================
# 2. AFFINITY CHECK
# ==========================================
def check_affinity(model, tokenizer, texts, ids, num_prompts=50):
    router_probs_all = []
    handles = []

    # Simple recursive search for the gates we just modified
    for name, module in model.named_modules():
        if "gate" in name and not "proj" in name:
            def hook_fn(module, input):
                with torch.no_grad():
                    x = input[0].detach()
                    if x.dtype != module.weight.dtype: x = x.to(module.weight.dtype)
                    hidden = x.view(-1, x.shape[-1])
                    logits = F.linear(hidden, module.weight)
                    probs = F.softmax(logits, dim=-1)
                    router_probs_all.append(probs.mean(dim=0, keepdim=True).cpu())
                return input
            handle = module.register_forward_pre_hook(hook_fn)
            handles.append(handle)

    device = next(model.parameters()).device
    print(f"📊 Running Affinity on {num_prompts} prompts...")
    try:
        with torch.no_grad():
            for text in texts[:num_prompts]:
                inputs = tokenizer(text, return_tensors="pt").to(device)
                _ = model(**inputs, use_cache=False)
    except Exception as e:
        print(f"⚠️ Inference Error: {e}")
    finally:
        for h in handles: h.remove()

    if router_probs_all:
        stacked = torch.cat(router_probs_all, dim=0)
        avg_aff = stacked.mean(dim=0)
        aff_to_ids = avg_aff[ids].sum().item()
        return aff_to_ids
    return 0.0

gsm8k = load_dataset("gsm8k", "main", split="train[:50]")
math_prompts = [f"Question: {ex['question']}\nLet's think step by step." for ex in gsm8k]
aff_burst = check_affinity(model, tokenizer, math_prompts, MATH_EXPERT_IDS)

# ==========================================
# 3. PERPLEXITY
# ==========================================
def perplexity(model, tokenizer, test_data):
    model.eval()
    total_loss = 0
    n_tokens = 0
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
    with torch.no_grad():
        for ex in test_data:
            inputs = {"input_ids": torch.tensor([ex["input_ids"]]).to(model.device)}
            outputs = model(**inputs, use_cache=False)
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = torch.tensor([ex["labels"]])[..., 1:].contiguous().to(model.device)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            mask = (shift_labels != -100).sum().item()
            if mask > 0:
                total_loss += loss.item() * mask
                n_tokens += mask
    return math.exp(total_loss / n_tokens) if n_tokens > 0 else float('inf')

print("📊 Measuring Math PPL...")
gsm8k_test = load_dataset("gsm8k", "main", split="test[:50]")
def format_math_ppl(ex):
    text = f"Question: {ex['question']}\nAnswer: {ex['answer']}<|endoftext|>"
    tokenized = tokenizer(text, truncation=True, max_length=512, padding="max_length")
    labels = np.array(tokenized["input_ids"])
    prompt_len = len(tokenizer(f"Question: {ex['question']}\n").input_ids)
    labels[:prompt_len] = -100
    return {"input_ids": tokenized["input_ids"], "labels": labels.tolist()}

gsm8k_test = gsm8k_test.map(format_math_ppl, batched=False, remove_columns=gsm8k_test.column_names)
ppl_burst_math = perplexity(model, tokenizer, gsm8k_test)

print("\n" + "="*40)
print("🚀 PHASE 1 BURST RESULTS (FINAL)")
print("="*40)
print(f"Affinity (Goal > 8.0%):  {aff_burst:.2%}")
print(f"Math PPL (Goal < 20):    {ppl_burst_math:.2f}")
print("="*40)

📂 Loading from: /content/drive/MyDrive/DeepSeek_Model/Burst_Targeted_Phase1
⏳ Loading Base Model (4-bit)...


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

🔗 Loading LoRA Adapters...
📥 Analyzing /content/drive/MyDrive/DeepSeek_Model/Burst_Targeted_Phase1/router_gates.pt...
🔓 Force-Implanting Router Weights...
   - Found 27 source weights.
   - Found 27 target modules in model.
✅ Successfully Implanted 27 Router Matrices.
📊 Running Affinity on 50 prompts...
📊 Measuring Math PPL...


Map:   0%|          | 0/50 [00:00<?, ? examples/s]


🚀 PHASE 1 BURST RESULTS (FINAL)
Affinity (Goal > 8.0%):  7.08%
Math PPL (Goal < 20):    22895.07


# last


In [None]:
import os
import torch
import bitsandbytes as bnb
from bitsandbytes.nn import Params4bit  # <--- Import this for the check
from dataclasses import dataclass, field
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig,
    AutoModelForCausalLM,
    AutoTokenizer
)
from datasets import load_dataset

# --- CONFIG ---
@dataclass
class Config:
    model_name: str = "deepseek-ai/deepseek-moe-16b-base"
    lora_rank: int = 64
    lora_alpha: int = 128
    math_expert_ids: list = field(default_factory=lambda: [62, 43, 45, 59])

config = Config()
DRIVE_SAVE_ROOT = "/content/drive/MyDrive/DeepSeek_Model"
BURST_SAVE_PATH = os.path.join(DRIVE_SAVE_ROOT, "Burst_Phase1_Native")
os.makedirs(BURST_SAVE_PATH, exist_ok=True)

# --- 1. LOAD MODEL ---
print(f"⏳ Loading {config.model_name}...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

# Enable gradients
model = prepare_model_for_kbit_training(model)

# --- 2. THE CONDITIONAL CAST (Fixed) ---
print("🔓 checking Routers for training...")
for name, module in model.named_modules():
    if "gate" in name and "proj" not in name: # Target 'mlp.gate'
        # Check if it is actually 4-bit
        if isinstance(module.weight, Params4bit):
            print(f"   - Dequantizing 4-bit Router: {name}")
            w_dequant = bnb.functional.dequantize_4bit(module.weight.data, module.weight.quant_state)
            module.weight = torch.nn.Parameter(w_dequant.to(torch.bfloat16))
        else:
            # It's already standard, just ensure bfloat16
            if module.weight.dtype != torch.bfloat16:
                print(f"   - Casting Router to bfloat16: {name}")
                module.weight.data = module.weight.data.to(torch.bfloat16)

# --- 3. DATA PIPELINE ---
gsm8k = load_dataset("gsm8k", "main", split="train[:500]")
def format_math(ex):
    return {"text": f"Question: {ex['question']}\nLet's think step by step.\nAnswer: {ex['answer']}<|endoftext|>"}

train_data = gsm8k.map(format_math, remove_columns=gsm8k.column_names)
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=512)
tokenized_train = train_data.map(tokenize_function, batched=True)
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# --- 4. PEFT CONFIG ---
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"] # Attention
for e_id in config.math_expert_ids:
    target_modules.append(f"experts.{e_id}.gate_proj")
    target_modules.append(f"experts.{e_id}.up_proj")
    target_modules.append(f"experts.{e_id}.down_proj")

lora_config = LoraConfig(
    r=config.lora_rank,
    lora_alpha=config.lora_alpha,
    target_modules=target_modules,      # LoRA on Experts
    modules_to_save=["gate"],           # <--- Full Train the Router!
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# --- 5. TRAINING ---
training_args = TrainingArguments(
    output_dir=BURST_SAVE_PATH,
    num_train_epochs=2,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    bf16=True,
    logging_steps=10,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=1,
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    data_collator=collator
)

print("🚀 Starting Native Hybrid Training...")
trainer.train()

# --- 6. SAVE ---
trainer.save_model(BURST_SAVE_PATH)
tokenizer.save_pretrained(BURST_SAVE_PATH)
print(f"✅ Saved Native Hybrid model to {BURST_SAVE_PATH}")

⏳ Loading deepseek-ai/deepseek-moe-16b-base...


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

🔓 checking Routers for training...
   - Casting Router to bfloat16: model.layers.1.mlp.gate
   - Casting Router to bfloat16: model.layers.2.mlp.gate
   - Casting Router to bfloat16: model.layers.3.mlp.gate
   - Casting Router to bfloat16: model.layers.4.mlp.gate
   - Casting Router to bfloat16: model.layers.5.mlp.gate
   - Casting Router to bfloat16: model.layers.6.mlp.gate
   - Casting Router to bfloat16: model.layers.7.mlp.gate
   - Casting Router to bfloat16: model.layers.8.mlp.gate
   - Casting Router to bfloat16: model.layers.9.mlp.gate
   - Casting Router to bfloat16: model.layers.10.mlp.gate
   - Casting Router to bfloat16: model.layers.11.mlp.gate
   - Casting Router to bfloat16: model.layers.12.mlp.gate
   - Casting Router to bfloat16: model.layers.13.mlp.gate
   - Casting Router to bfloat16: model.layers.14.mlp.gate
   - Casting Router to bfloat16: model.layers.15.mlp.gate
   - Casting Router to bfloat16: model.layers.16.mlp.gate
   - Casting Router to bfloat16: model.layers.

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

trainable params: 104,562,688 || all params: 16,480,290,816 || trainable%: 0.6345
🚀 Starting Native Hybrid Training...




Step,Training Loss
10,1.1012
20,0.8986
30,0.867
40,0.7901
50,0.7714
60,0.7937


✅ Saved Native Hybrid model to /content/drive/MyDrive/DeepSeek_Model/Burst_Phase1_Native


In [None]:
print("Final Loss:", trainer.state.log_history[-1].get('train_loss', 'N/A'))

Final Loss: 0.8626297377049923


In [None]:
import torch
import math
import numpy as np
import os
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import bitsandbytes as bnb

# --- CONFIG ---
DRIVE_SAVE_ROOT = "/content/drive/MyDrive/DeepSeek_Model"
BURST_SAVE_PATH = os.path.join(DRIVE_SAVE_ROOT, "Burst_Phase1_Native")
BASE_MODEL_NAME = "deepseek-ai/deepseek-moe-16b-base"
MATH_EXPERT_IDS = [62, 43, 45, 59]

print(f"📂 Loading from: {BURST_SAVE_PATH}")

# ==========================================
# 1. LOAD BASE MODEL
# ==========================================
print(f"⏳ Loading Base Model (4-bit)...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)

# ==========================================
# 2. DIAGNOSTIC: FIND THE ROUTER
# ==========================================
print("\n🔍 DIAGNOSTIC: Printing Model Structure (First Layer Only)...")
found_gate = False
target_module_name = ""

for name, module in model.named_modules():
    # We look for ANY linear layer that is a gate
    if "gate" in name and "proj" not in name:
        print(f"   ✅ FOUND CANDIDATE: {name} -> Type: {type(module)}")
        found_gate = True
        target_module_name = name.split(".")[-1] # likely 'gate'
        break

if not found_gate:
    print("❌ CRITICAL: Could not find any router layer! Checking all modules...")
    for name, module in list(model.named_modules())[:20]:
        print(f"   - {name}")

# ==========================================
# 3. THE FIX: CAST ROUTER TO FLOAT
# ==========================================
print(f"\n🔓 Casting Routers to bfloat16 (Targeting '{target_module_name}')...")
upcast_count = 0

for name, module in model.named_modules():
    # Use the name we just found
    if name.endswith(f".{target_module_name}"):
        # Force conversion
        if hasattr(module, "weight"):
             # If it is 4-bit (Params4bit), dequantize
             if "Params4bit" in str(type(module.weight)):
                 w_dequant = bnb.functional.dequantize_4bit(module.weight.data, module.weight.quant_state)
                 module.weight = torch.nn.Parameter(w_dequant.to(torch.bfloat16))
             # If it's standard, just ensure dtype
             else:
                 module.weight.data = module.weight.data.to(torch.bfloat16)

             upcast_count += 1

print(f"   - Successfully Upcasted {upcast_count} routers.")

# ==========================================
# 4. LOAD ADAPTER
# ==========================================
print("🔗 Loading Adapter...")
model = PeftModel.from_pretrained(model, BURST_SAVE_PATH)
model.eval()

# ==========================================
# 5. RUN CHECKS
# ==========================================
# A. Affinity
def check_affinity(model, tokenizer, texts, target_ids, num_prompts=50):
    router_probs_all = []
    handles = []

    for name, module in model.named_modules():
        if name.endswith(target_module_name): # Use the found name
            def hook_fn(module, input):
                with torch.no_grad():
                    x = input[0].detach()
                    if x.dtype != module.weight.dtype: x = x.to(module.weight.dtype)
                    logits = torch.nn.functional.linear(x.view(-1, x.shape[-1]), module.weight)
                    probs = torch.nn.functional.softmax(logits, dim=-1)
                    router_probs_all.append(probs.mean(dim=0, keepdim=True).cpu())
                return input
            handles.append(module.register_forward_pre_hook(hook_fn))

    device = next(model.parameters()).device
    with torch.no_grad():
        for i, text in enumerate(texts[:num_prompts]):
            inputs = tokenizer(text, return_tensors="pt").to(device)
            _ = model(**inputs, use_cache=False)
    for h in handles: h.remove()

    if router_probs_all:
        stacked = torch.cat(router_probs_all, dim=0)
        avg_aff = stacked.mean(dim=0)
        return avg_aff[target_ids].sum().item()
    return 0.0

gsm8k = load_dataset("gsm8k", "main", split="train[:50]")
math_prompts = [f"Question: {ex['question']}\nLet's think step by step." for ex in gsm8k]
print("📊 Measuring Affinity...")
aff = check_affinity(model, tokenizer, math_prompts, MATH_EXPERT_IDS)

# B. PPL
def perplexity(model, tokenizer, test_data):
    model.eval()
    total_loss = 0
    n_tokens = 0
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
    with torch.no_grad():
        for ex in test_data:
            inputs = {"input_ids": torch.tensor([ex["input_ids"]]).to(model.device)}
            outputs = model(**inputs, use_cache=False)
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = torch.tensor([ex["labels"]])[..., 1:].contiguous().to(model.device)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            mask = (shift_labels != -100).sum().item()
            if mask > 0:
                total_loss += loss.item() * mask
                n_tokens += mask
    return math.exp(total_loss / n_tokens) if n_tokens > 0 else float('inf')

print("📊 Measuring PPL...")
gsm8k_test = load_dataset("gsm8k", "main", split="test[:50]")
def format_math_ppl(ex):
    text = f"Question: {ex['question']}\nLet's think step by step.\nAnswer: {ex['answer']}<|endoftext|>"
    tokenized = tokenizer(text, truncation=True, max_length=512, padding="max_length")
    labels = np.array(tokenized["input_ids"])
    prompt_len = len(tokenizer(f"Question: {ex['question']}\nLet's think step by step.\nAnswer:").input_ids)
    labels[:prompt_len] = -100
    return {"input_ids": tokenized["input_ids"], "labels": labels.tolist()}

gsm8k_ppl_data = gsm8k_test.map(format_math_ppl, batched=False, remove_columns=gsm8k_test.column_names)
ppl = perplexity(model, tokenizer, gsm8k_ppl_data)

print("\n" + "="*40)
print(f"Affinity: {aff:.2%}")
print(f"PPL:      {ppl:.2f}")
print("="*40)

📂 Loading from: /content/drive/MyDrive/DeepSeek_Model/Burst_Phase1_Native
⏳ Loading Base Model (4-bit)...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json: 0.00B [00:00, ?B/s]

configuration_deepseek.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/deepseek-ai/deepseek-moe-16b-base:
- configuration_deepseek.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
`torch_dtype` is deprecated! Use `dtype` instead!


modeling_deepseek.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/deepseek-ai/deepseek-moe-16b-base:
- modeling_deepseek.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

model-00003-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00006-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00002-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00007-of-00007.safetensors:   0%|          | 0.00/2.77G [00:00<?, ?B/s]

model-00001-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00004-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00005-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/121 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/793 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]


🔍 DIAGNOSTIC: Printing Model Structure (First Layer Only)...
   ✅ FOUND CANDIDATE: model.layers.1.mlp.gate -> Type: <class 'transformers_modules.deepseek_hyphen_ai.deepseek_hyphen_moe_hyphen_16b_hyphen_base.521d2bc4fb69a3f3ae565310fcc3b65f97af2580.modeling_deepseek.MoEGate'>

🔓 Casting Routers to bfloat16 (Targeting 'gate')...
   - Successfully Upcasted 27 routers.
🔗 Loading Adapter...


README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

📊 Measuring Affinity...
📊 Measuring PPL...


Map:   0%|          | 0/50 [00:00<?, ? examples/s]


Affinity: 7.13%
PPL:      4097.70


In [None]:
im doing a sort of project like this, MoE Burst Training: Project Summary
The Core Idea
Problem: When you fine-tune Mixture-of-Experts (MoE) models like Mixtral-8x7B on specialized domains (e.g., math), standard LoRA applies the same small rank (r=16) to ALL experts uniformly. This creates two issues:

1. Dilution: Math-specialized data gets spread across all 8 experts, preventing deep specialization

2. Router Inertia: The pretrained router already has strong habits about which experts to use, and resists learning new routing patterns during fine-tuning

The Insight: Recent 2025 papers on "inverse scaling laws" proved that optimal MoE architectures should pre-allocate different expert sizes based on domain data volume (e.g., if you have 50B math tokens, the math expert should be proportionally larger). But these papers assume you're training from scratch with $10M budgets.
Our Innovation: What if we could retrofit existing pretrained models (like Mixtral or DeepSeek-MoE) to approximate these optimal architectures using clever training techniques, for just $40-200?

The Method: "Burst Upcycling"
We combine FOUR components that work synergistically:
0. Expert Discovery & Capacity Planning (Preprocessing Phase)
NEW INSIGHT: Don't randomly pick which experts to enhance—discover which ones already handle your domain!

* Step 1: Profile the pretrained model on a small sample of domain data (~500 examples)

* Step 2: Measure which experts naturally activate most on domain tokens

* Step 3: Select top-k experts that are already domain-inclined (e.g., Expert 3 and 5 activate 1.8× more on math)

* Step 4: Calculate optimal LoRA ranks using inverse scaling formula: rank ≈ domain_tokens / 20,000

Why this matters:

* ✅ Works WITH pretrained knowledge, not against it

* ✅ Lower router inertia (router already sends some math there—just need to strengthen it)

* ✅ Faster convergence (building on existing patterns vs. fighting them)

* ✅ Better results (+2-3% extra gain vs. random expert selection)

Example output:
🔍 Discovering math experts...
Expert affinity scores:
  Expert 0: 0.098
  Expert 1: 0.102
👉 Expert 3: 0.187 ← Already math-inclined!
  Expert 4: 0.095
👉 Expert 5: 0.156 ← Already math-inclined!

Selected: [3, 5]
Calculated ranks: math_rank=128, general_rank=16
Cost: 15-20 minutes, runs once before training
1. Heterogeneous LoRA (Capacity Allocation)

* Apply high LoRA rank (r=128) to discovered math experts [3, 5]

* Apply low LoRA rank (r=16) to remaining general experts

* This mimics the "inverse scaling law" optimal architecture without changing the base model

2. Burst Curriculum (Data Sequencing)
Instead of mixing 50% math + 50% general data randomly throughout training:

* Phase I (Epochs 0-1): 80% general, 20% math (warmup - stabilize baseline)

* Phase II (Epochs 2-3): 100% math (affinity warmup - teach router where math experts are)

* Phase III (Epochs 4-9): 100% math with 1.5x loss weight (capacity burst - aggressively specialize)

* Phase IV (Epochs 10-11): 50/50 mix (stabilization - prevent catastrophic forgetting)

The "burst" forces the router to pay attention to the math experts instead of ignoring them.
3. Router Teacher Forcing (Overcoming Inertia)

* Add an auxiliary loss that explicitly penalizes the router when it fails to send math tokens to discovered math experts

* Start strong (coefficient 0.05) then anneal down (0.05 → 0.001) so the model learns to route correctly on its own

* This breaks the pretrained router's stubborn habits


Why This Matters
Academic Contribution:

* Bridges the gap between theory (2025 inverse scaling papers) and practice (existing open-source models)

* Novel training methodology: "Burst Curriculum" + "Affinity Warmup" for pretrained MoE adaptation

* Practical protocol engineers can use today

Community Value:

* Most practitioners can't afford to pretrain custom MoE architectures ($10M+)

* Everyone has access to Mixtral/DeepSeek checkpoints

* If this works, people can specialize existing models cheaply ($50 vs $10M)

Research Question:
"Can sequential burst training with heterogeneous LoRA enable pretrained MoE models to achieve the capacity distributions predicted by inverse scaling laws, at a fraction of the retraining cost?"

Where the Idea Came From
Initial Inspiration:

1. DeepSeek-V3 (2025): Showed semantic experts with affinity-based routing gave +12% GSM8K gains

2. Progressive MoE papers (ReXMoE, 2025): Used curriculum-based "expert unlocking" for +6% efficiency

3. Inverse Scaling Laws (arXiv 2509.07909, 2025): Proved optimal expert sizing based on data volume (e.g., 20 tokens/param for 80% saturation)

Evolution Through Discussion:

* Initial idea: "Can I use bursts to specialize experts better than standard LoRA?"

* Concern discovered: Wait, 2025 papers already did inverse scaling for architecture design—am I too late?

* Reframing: No! Those papers solve the "build from scratch" problem. We solve the "retrofit existing models" problem—completely different contribution

* Key pivot: Position this as "Practical Methodology for Upcycling Legacy MoE Models" not "New Scaling Law Discovery"


Validation Strategy: Phased Approach
We designed a risk-managed, incremental validation plan to avoid wasting compute on a failed hypothesis.
Phase 1: Smoke Test ($5, 1-2 hours)
Goal: Quick validation that the core mechanism works
What we test:

* B0 (Baseline): Vanilla LoRA, r=16, mixed data

* B5 (Full method): High-rank LoRA r=128, burst curriculum, router forcing

Success criteria:

* ✅ Router affinity increases from ~12.5% (random) to >20% (specialized)

* ✅ Code runs without crashes

* ✅ Training loss decreases

Decision point:

* Strong signal (+5% affinity): Proceed to Phase 2

* Weak signal (+2-5%): Tweak hyperparameters (increase affinity coefficient, burst weight)

* No signal (<2%): Debug before proceeding

Phase 2: Core Ablation ($50, 2-3 days)
Goal: Prove synergy—all three components together beat any individual component
What we test (4 baselines):

* B0: Vanilla LoRA (industry standard)

* B1: Heterogeneous LoRA ONLY (capacity without bursts) - tests if sizing alone helps

* B3: Heterogeneous + Bursts, NO router forcing - tests if inertia is real

* B5: Full method (all three components)

Success criteria:

* ✅ B5 beats B0 by +3-5% GSM8K accuracy

* ✅ B5 beats B1 (proves bursts add value beyond just sizing)

* ✅ B5 beats B3 (proves router forcing is necessary)

* ✅ Math expert affinity reaches 70%+

* ✅ General performance doesn't collapse (HellaSwag <5% drop)

Decision point:

* Success: Write up results, share code, consider workshop paper

* Partial success (B5 ≈ B1): Bursts don't add much—pivot to just heterogeneous LoRA

* Failure: Method doesn't work—learn from it, try alternative approaches

Phase 3: Full Study ($150, 1 week)
Goal: Statistical rigor + generalization validation
What we add:

* Run ALL 5 baselines (B0, B1, B2, B3, B5) with 3 random seeds each

* Add B2 (burst-only, no heterogeneous ranks) to complete ablation

* Validate on second model (Mixtral-8x7B) to prove it's not DeepSeek-specific

* Full evaluation: GSM8K, MATH, HellaSwag, router diagnostics

Deliverable:

* arXiv preprint or workshop paper

* Open-source code repository

* Training recipes for practitioners


Key Design Choices & Rationale
Choice 1: Model Selection
Decision: Start with DeepSeek-MoE-16B-base, validate on Mixtral-8x7B later
Why:

* DeepSeek advantages: 3x faster, 3x cheaper, 64 experts (clearer specialization signal)

* Mixtral validation: Industry recognition, broader applicability

* Base over chat: Cleaner routing behavior, no instruction-following interference

Choice 2: LoRA Over Full Fine-Tuning
Decision: Use QLoRA (4-bit quantization + LoRA)
Why:

* Full fine-tuning Mixtral requires 8× A100 80GB ($800+ per run) = economically infeasible

* LoRA is what practitioners actually use in the real world

* Our contribution is about training methodology, not just "more compute wins"

* Fair comparison: all baselines use same memory/compute budget

Choice 3: Burst Schedule Design
Decision: 4-phase schedule with annealed router forcing
Why each phase:

* Phase I warmup: Prevents router from diverging wildly at start

* Phase II affinity warmup: Teaches router new patterns with low expert capacity (prevents overfitting)

* Phase III capacity unlock: Now that router knows the pattern, train expert parameters aggressively

* Phase IV stabilization: Prevents catastrophic forgetting of general capabilities

Key insight: You can't just throw math data at frozen experts—the router won't route to them. You need staged unlocking.
Choice 4: Heterogeneous Ranks
Decision: r=128 for discovered math experts, r=16 for general experts
Rationale:

* Based on inverse scaling papers: "20 tokens/param for 80% saturation"

* If we have ~50k math tokens in burst, we need ~2500 params = r≈128 in LoRA

* 8x ratio (128 vs 16) reflects data volume imbalance (bursts are 20-30% of training)

Discovery process:

1. Profile pretrained model on 500 math examples

2. Measure which experts naturally activate most

3. Select top-k (e.g., experts [3, 5] if they show 1.7-1.8× higher activation)

4. Apply high rank to those specific experts

Why discovery matters:

* Discovered experts have 18% baseline affinity vs. random experts at 12.5%

* Overcoming 18% → 70% inertia is easier than 12.5% → 70%

* Expected improvement: +2-3% extra gain from working with natural patterns

Note: For smoke test, we use uniform r=128 as a proxy (true per-expert ranks require custom PEFT code)
Choice 5: Router Forcing Strategy
Decision: KL divergence loss from router probs to target distribution, with linear annealing
Why KL divergence:

* Soft constraint (doesn't force 100% routing to math experts)

* Allows model to learn when it's truly ambiguous

* Better than hard constraints (which break training)

Why anneal from 0.05 → 0.001:

* Start strong to overcome inertia

* Gradually fade so model learns natural routing

* By end, model routes correctly without artificial push

Choice 6: Evaluation Metrics
Primary metrics:

1. Math expert affinity: % of math tokens routed to designated math experts (should reach 70%+)

2. GSM8K accuracy: Absolute task performance improvement (+3-5% target)

3. HellaSwag retention: General capability preservation (should not drop >5%)

Secondary metrics:

* Router entropy (prevent collapse)

* Per-expert gradient norms (verify specialization)

* Expert utilization variance (balance check)

Why affinity is critical: It's the mechanistic proof that bursts work. If affinity doesn't increase, nothing else matters.

What Makes This Novel
Not Novel (Already Known):

* ❌ Inverse scaling laws for MoE sizing

* ❌ Curriculum learning for training

* ❌ LoRA for parameter efficiency

* ❌ Router auxiliary losses

Novel (Our Contribution):

* ✅ Burst curriculum specifically designed to overcome router inertia in pretrained models

* ✅ Staged capacity unlocking (affinity warmup → burst → stabilization)

* ✅ Heterogeneous LoRA as a proxy for architectural retrofitting

* ✅ Practical protocol bridging 2025 theory to 2024 checkpoints

The synthesis is novel, not the components.

Potential Pitfalls & Mitigations
Pitfall 1: Router Inertia Too Strong
Symptom: Affinity stays at ~12-15% despite forcing
Mitigation:

* Increase affinity coefficient (0.05 → 0.10)

* Extend Phase II (affinity warmup) from 2 → 4 epochs

* Add positive bias to math expert logits during bursts

Pitfall 2: Catastrophic Forgetting
Symptom: HellaSwag drops >10% while GSM8K improves
Mitigation:

* Shorten burst phases (reduce 100% math to 80% math)

* Add longer stabilization phase (Phase IV: 4 epochs instead of 2)

* Use lower burst weight (1.3x instead of 1.5x)

Pitfall 3: Expert Collapse
Symptom: All math tokens go to ONE expert, not distributed across 2-4
Mitigation:

* Add entropy regularization (already in design: -0.01 * p*log(p))

* Use load balance loss (variance penalty)

* Target more than 2 experts for specialization

Pitfall 4: Noise in Small-Scale Experiments
Symptom: Results vary wildly between runs (±3-5% swings)
Mitigation:

* Run 3 seeds per experiment (report mean ± std)

* Use larger eval sets (100+ GSM8K problems, not 50)

* Focus on affinity as primary metric (less noisy than accuracy)

Pitfall 5: Heterogeneous LoRA Implementation Bugs
Symptom: Code crashes or LoRA ranks not actually different
Mitigation:

* For smoke test: use uniform high rank as proxy (acceptable)

* For full ablation: manually verify per-expert ranks via model inspection

* Add preflight checks to validate architecture before training


Success Criteria Summary
Minimum Success (Still Valuable):

* B5 beats B0 by +3% GSM8K

* Affinity reaches 50%+

* Value: Proof of concept, shareable code/blog post

Good Success (Workshop Paper):

* B5 beats all baselines by +4-5%

* Affinity reaches 70%+

* Works on both DeepSeek and Mixtral

* Value: Workshop publication, community adoption

Excellent Success (Main Conference Paper):

* B5 beats all baselines by +7%+

* Generalizes across 3+ MoE models

* Shows cost efficiency: 0.1% retraining cost for 85% of optimal gains

* Value: EMNLP/ICLR submission, high citation potential

