In [1]:
!pip install -q -U transformers peft datasets accelerate bitsandbytes wandb evaluate

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m132.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m511.6/511.6 kB[0m [31m43.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m44.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.9/22.9 MB[0m [31m108.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m47.7/47.7 MB[0m [31m54.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
#@title Clear temp files { display-mode: "form" }
import shutil
import os
from pathlib import Path

# Clear old local Hugging Face cache
cache_dir = Path.home() / ".cache" / "huggingface"
if cache_dir.exists():
    shutil.rmtree(cache_dir, ignore_errors=True)
    print("✅ Local HF cache cleared.")

# Clear general temp files
shutil.rmtree("/tmp", ignore_errors=True)
print("✅ Local temp files cleared.")

✅ Local temp files cleared.


In [3]:
import os
import json
from typing import Dict, Optional, Union, List
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from dataclasses import dataclass
from transformers.trainer_utils import EvalPrediction
from evaluate import load

In [4]:
@dataclass
class SmokeTestConfig:
    """Minimal config for smoke test."""
    model_name: str = "deepseek-ai/deepseek-moe-16b-base"
    load_in_4bit: bool = True
    lora_rank_general: int = 16
    lora_rank_math: int = 128
    lora_alpha_general: int = 32
    lora_alpha_math: int = 256
    lora_dropout: float = 0.05
    burst_weight: float = 1.5
    affinity_coef: float = 0.05
    math_ratio: float = 0.5
    num_experts: int = 64
    math_expert_ids: list = None

    def __post_init__(self):
        if self.math_expert_ids is None:
            self.math_expert_ids = [60, 61, 62, 63]

config = SmokeTestConfig()

In [5]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
# Model persistence path on Drive
DRIVE_SAVE_ROOT = "/content/drive/MyDrive/DeepSeek_Model"
MODEL_SAVE_PATH = os.path.join(DRIVE_SAVE_ROOT, config.model_name.split('/')[-1] + "_4bit")

os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
print(f"Model persistence path set to: {MODEL_SAVE_PATH}")

Model persistence path set to: /content/drive/MyDrive/DeepSeek_Model/deepseek-moe-16b-base_4bit


In [7]:
# BITS AND BYTES CONFIGURATION
bnb_config = BitsAndBytesConfig(
    load_in_4bit=config.load_in_4bit,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

if os.path.exists(os.path.join(MODEL_SAVE_PATH, "config.json")):
    print("\nPERSISTENT CACHE FOUND. Loading from Google Drive...")

    # Load from Drive
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_SAVE_PATH,
        quantization_config=bnb_config,
        device_map="auto",
        dtype=torch.bfloat16,
        trust_remote_code=True,
        resume_download=True
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_SAVE_PATH)

else:
    print(f"NO CACHE. Starting initial download of {config.model_name}...")

    # Download from HF
    model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        quantization_config=bnb_config,
        device_map="auto",
        dtype=torch.bfloat16,
        trust_remote_code=True,
        resume_download=True
    )
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)

    # Save to Drive for next times
    print(f"⏳ Initial load complete. Saving quantized model to Drive at: {MODEL_SAVE_PATH}")
    model.save_pretrained(MODEL_SAVE_PATH)
    tokenizer.save_pretrained(MODEL_SAVE_PATH)
    print("✅ Model successfully saved to persistent Drive cache.")

tokenizer.pad_token = tokenizer.eos_token

print(f"✅ Model loaded: {model.config.model_type}, Total Parameters: {model.num_parameters():,}")
print(f"   Model dtype: {model.dtype}")


PERSISTENT CACHE FOUND. Loading from Google Drive...




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✅ Model loaded: deepseek, Total Parameters: 16,375,728,128
   Model dtype: torch.bfloat16


In [9]:
# Search for 'mlp.experts' or 'gate_proj' for MoE/router
moe_paths = [name for name, _ in model.named_modules() if 'mlp.experts' in name or 'gate_proj' in name]
print("\nMoE Paths (first 5):", moe_paths[:5])  # Expect mlp.experts.0.gate_proj etc.
gate_paths = [name for name, _ in model.named_modules() if 'gate_proj' in name]
print("Gate Paths (sample):", gate_paths[:3])  # Expect mlp.experts.X.gate_proj

✅ Model loaded: deepseek, Total Parameters: 16,375,728,128
   Model dtype: torch.bfloat16

MoE Paths (first 5): ['model.layers.0.mlp.gate_proj', 'model.layers.1.mlp.experts', 'model.layers.1.mlp.experts.0', 'model.layers.1.mlp.experts.0.gate_proj', 'model.layers.1.mlp.experts.0.up_proj']
Gate Paths (sample): ['model.layers.0.mlp.gate_proj', 'model.layers.1.mlp.experts.0.gate_proj', 'model.layers.1.mlp.experts.1.gate_proj']


In [20]:
import inspect  # FIXED: Import for source

layer0 = model.model.layers[0]
print("Layers:", len(model.model.layers))  # 28

print("\nDeepseekMLP Attributes (dir):")
print([attr for attr in dir(layer0.mlp) if not attr.startswith('_')])

print("\nMLP Source (inspect snippet):")
print(inspect.getsource(layer0.mlp.__class__.forward)[:500] + "...")  # Forward logic (routed creation)

print("\nLayer 0 Named Children:")
for name, child in layer0.named_children():
    print(f"  - {name}: {child.__class__.__name__} (params: {sum(p.numel() for p in child.parameters()):,})")

# Simulate forward to instantiate routed (dummy hidden state)
dummy_input = torch.zeros(1, 1, 2048).to(model.device)  # hidden_size=2048
with torch.no_grad():
    _ = layer0.mlp(dummy_input)

print("\nAfter Forward - Routed Experts:")
if hasattr(layer0.mlp, 'routed_experts'):
    print("Number:", len(layer0.mlp.routed_experts))  # 64
    first_routed = layer0.mlp.routed_experts[0]
    print("First Routed Components (dir):", [attr for attr in dir(first_routed) if not attr.startswith('_')])
    print("First Routed Params:", sum(p.numel() for p in first_routed.parameters()))
    last_routed = layer0.mlp.routed_experts[-1]
    print("Last Routed Components (dir):", [attr for attr in dir(last_routed) if not attr.startswith('_')])
else:
    print("Routed experts dynamic—no attr post-forward. Check source for topk route.")

print("\nShared Experts:")
if hasattr(layer0.mlp, 'shared_experts'):
    print("Number:", len(layer0.mlp.shared_experts))  # 2
    print("First Shared Components (dir):", [attr for attr in dir(layer0.mlp.shared_experts[0]) if not attr.startswith('_')])
else:
    print("Shared experts dynamic in forward.")

print("\nMLP Param Counts:")
print(f"Total MLP Params: {sum(p.numel() for p in layer0.mlp.parameters()):,}")

# FIXED MoE Sample Names (Broader Search)
moe_sample = [name for name, _ in model.named_modules() if 'routed_experts' in name or 'shared_experts' in name or 'gate_proj' in name][:10]
print("\nMoE Sample Names:", moe_sample)

# Full Layer 0 Breakdown (Recursive, Depth 3)
def print_structure(module, indent=0, max_depth=3):
    if indent > max_depth:
        return
    print("  " * indent + f"{module.__class__.__name__} (params: {sum(p.numel() for p in module.parameters()):,})")
    for name, child in module.named_children():
        print_structure(child, indent + 1, max_depth)

print("\nLayer 0 Full Structure (Depth 3):")
print_structure(layer0, 0, 3)

Layers: 28

DeepseekMLP Attributes (dir):
['T_destination', 'act_fn', 'add_module', 'apply', 'bfloat16', 'buffers', 'call_super_init', 'children', 'compile', 'config', 'cpu', 'cuda', 'double', 'down_proj', 'dump_patches', 'eval', 'extra_repr', 'float', 'forward', 'gate_proj', 'get_buffer', 'get_extra_state', 'get_parameter', 'get_submodule', 'half', 'hidden_size', 'intermediate_size', 'ipu', 'load_state_dict', 'modules', 'mtia', 'named_buffers', 'named_children', 'named_modules', 'named_parameters', 'parameters', 'register_backward_hook', 'register_buffer', 'register_forward_hook', 'register_forward_pre_hook', 'register_full_backward_hook', 'register_full_backward_pre_hook', 'register_load_state_dict_post_hook', 'register_load_state_dict_pre_hook', 'register_module', 'register_parameter', 'register_state_dict_post_hook', 'register_state_dict_pre_hook', 'requires_grad_', 'set_extra_state', 'set_submodule', 'share_memory', 'smart_apply', 'state_dict', 'to', 'to_empty', 'train', 'training

In [41]:
import torch
import torch.nn.functional as F
from datasets import load_dataset
from dataclasses import dataclass

# Configuration placeholder
class TempConfig:
    num_experts = 64
    math_expert_ids = [0, 1, 2, 3]

router_probs_all = []
config = TempConfig()

# --- 1. The Debug Hook (Modified to run on MLP Output) ---
def final_router_debug_hook_on_mlp(module, input, output):
    """Prints the full structure of the MLP output to identify the logits tensor."""
    global router_probs_all # Still need this global for the final working version

    print(f"\n--- DEBUG HOOK ON {module._get_name()} FIRED ---")

    # 🚨 CRITICAL: Print the output structure causing the error
    if isinstance(output, tuple):
        print(f"  Output: Tuple with {len(output)} elements.")
        for idx, item in enumerate(output):
            item_type = type(item)
            shape_info = f", Shape={item.shape}" if isinstance(item, torch.Tensor) else ""
            print(f"    Output[{idx}]: Type={item_type}{shape_info}")

            # Check for the expected Logits Shape [..., 64]
            # Logits can be at index 0, 1, or 2, but we need the tensor with 64 experts.
            if isinstance(item, torch.Tensor) and item.dim() >= 2 and item.shape[-1] == config.num_experts:
                 print(f"    ^^^^^ 🚨 LOGITS TENSOR FOUND AT INDEX {idx} 🚨 ^^^^^")
                 # We will NOT process it here to prevent the crash.
                 # Just printing is enough to give us the index.

    elif isinstance(output, torch.Tensor):
        print(f"  Output: Single Tensor. Shape={output.shape}")
        if output.dim() >= 2 and output.shape[-1] == config.num_experts:
             print(f"    ^^^^^ 🚨 LOGITS TENSOR FOUND (Single Output) 🚨 ^^^^^")
    else:
        print(f"  Output: Unexpected non-tensor/non-tuple type: {type(output)}")

    print("--- END DEBUG HOOK ---")

    # Return the original output immediately to minimize interference.
    return output


# --- 2. Print and Find Logits Function (Using MLP Hook) ---
def print_and_find_logits_mlp(model, tokenizer, config, math_samples=10, batch_size=2):

    global router_probs_all
    router_probs_all = []

    # --- Data Loading (Required to run model) ---
    try:
        gsm8k = load_dataset("gsm8k", "main", split=f"train[:{math_samples}]")
        math_texts = [f"Question: {ex['question']}" for ex in gsm8k]
    except Exception as e:
        print(f"🚨 Error loading dataset: {e}")
        return config.math_expert_ids

    model.eval()
    device = model.device

    # --- Hooking Setup ---
    handles = []

    # Target the entire MLP/MoE module: model.layers.X.mlp
    for layer in model.model.layers:
        if hasattr(layer, 'mlp'):
            # Hook the output of the entire MLP module
            handle = layer.mlp.register_forward_hook(final_router_debug_hook_on_mlp)
            handles.append(handle)

    if not handles:
        print("🚨 CRITICAL FAILURE: Cannot find 'layer.mlp' module.")
        return config.math_expert_ids

    # --- Forward Pass (Triggers Hooks) ---
    print(f"\n--- Running Forward Pass (Batch size {batch_size}, Total Layers: {len(handles)}) ---")
    with torch.no_grad():
        batch_texts = math_texts[:batch_size]
        batch_inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)

        # We only need one small pass to trigger the hooks and print the output
        try:
            _ = model(**batch_inputs, use_cache=False)
        except Exception as e:
            # Catch the error here to ensure all hook prints have time to execute before the crash
            print(f"\n--- Model Forward Pass Crashed ---")
            print(f"Error: {e}")
            print("--- Look Above for Logits Index Printout! ---")
            # Proceed to remove hooks and return, relying on the prints that ran before the crash.


    # Remove hooks
    for h in handles:
        h.remove()

    print("\n✅ Debug scan complete. If the model crashed, the critical information (Logits Index) should be printed above the crash message.")
    return config.math_expert_ids

# --- Execution ---
config = TempConfig()
config.math_expert_ids = print_and_find_logits_mlp(model, tokenizer, config)
print(f"Updated config.math_expert_ids: {config.math_expert_ids}")


--- Running Forward Pass (Batch size 2, Total Layers: 28) ---
🚨 Router Logits not found in MoE output tuple.

--- DEBUG HOOK ON DeepseekMLP FIRED ---
  Output: Single Tensor. Shape=torch.Size([2, 41, 2048])
--- END DEBUG HOOK ---

--- Model Forward Pass Crashed ---
Error: 'tuple' object has no attribute 'softmax'
--- Look Above for Logits Index Printout! ---

✅ Debug scan complete. If the model crashed, the critical information (Logits Index) should be printed above the crash message.
Updated config.math_expert_ids: [0, 1, 2, 3]


In [2]:
# Single-Cell: Non-Quant Expert Affinity Discovery + Contrast Scan (Math vs. General)
# Colab: A100 40GB+ (~32GB VRAM; 80GB safer). !pip install transformers datasets torch accelerate

import torch
import torch.nn.functional as F
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataclasses import dataclass
import math
import re  # For cleaning wiki text

@dataclass
class Config:
    num_experts: int = 64
    math_expert_ids: list = None

    def __post_init__(self):
        if self.math_expert_ids is None:
            self.math_expert_ids = [0, 1, 2, 3]

# Global for accumulation
router_probs_all = []

# Router Pre-Hook (Unchanged)
def router_calculation_hook(module, input):
    global router_probs_all
    hidden_states = input[0]
    if hidden_states.dim() == 3:
        reshaped_hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    else:
        reshaped_hidden_states = hidden_states
    router_weight = module.weight
    raw_logits = F.linear(reshaped_hidden_states, router_weight)
    probs = F.softmax(raw_logits, dim=-1, dtype=torch.float32)
    layer_probs = probs.mean(dim=0, keepdim=True).cpu()
    router_probs_all.append(layer_probs)
    return input

# Affinity Scan Function (Generic: Math or General)
def pre_scan_affinity_nonquant(model, tokenizer, config, texts, scan_type="Math", math_samples=500, k=4, batch_size=8):
    global router_probs_all
    router_probs_all = []

    # Clean texts (for general: strip wiki markup)
    if scan_type == "General":
        texts = [re.sub(r'\[.*?\]', '', text).strip() for text in texts]  # Basic cleanup

    model.eval()
    device = next(model.parameters()).device

    # Hook
    handles = []
    for layer in model.model.layers:
        if hasattr(layer.mlp, 'gate'):
            handle = layer.mlp.gate.register_forward_pre_hook(router_calculation_hook)
            handles.append(handle)

    num_moe_layers = len(handles)
    estimated_batches = math.ceil(len(texts) / batch_size)
    print(f"\n--- {scan_type} Affinity Scan ---")
    print(f"Texts: {len(texts)}, Layers: {num_moe_layers}, Batches: {estimated_batches}")

    # Forward
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i + batch_size]
            batch_inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
            _ = model(**batch_inputs, use_cache=False)
            if (i // batch_size + 1) % 20 == 0:
                print(f"  Processed {i // batch_size + 1}/{estimated_batches}...")

    # Cleanup
    for h in handles:
        h.remove()

    if not router_probs_all:
        print(f"🚨 No {scan_type.lower()} data—fallback.")
        return config.math_expert_ids

    # Aggregate
    stacked_probs = torch.cat(router_probs_all, dim=0)
    avg_affinity = stacked_probs.mean(dim=0)
    baseline = 1 / config.num_experts
    relative_affinity = avg_affinity / baseline

    # Output
    print(f"\n✅ {scan_type} discovery complete! Recorded: {len(router_probs_all)} layer calls")
    print(f"Baseline: {baseline:.4f}")
    print(f"\n{scan_type} Expert affinities:")

    expert_data = [{'id': i, 'affinity': avg_affinity[i].item(), 'relative': relative_affinity[i].item()} for i in range(config.num_experts)]
    expert_data.sort(key=lambda x: x['affinity'], reverse=True)

    top_k_results = avg_affinity.topk(k)
    top_k_experts = top_k_results.indices.tolist()

    for data in expert_data[:20]:  # Top 20 for brevity
        is_selected = data['id'] in top_k_experts
        bias_tag = " ← Math-inclined!" if data['relative'] > 1.1 else ""
        prefix = "👉" if is_selected else "  "
        print(f"{prefix} Expert {data['id']:2d}: {data['affinity']:.4f} ({data['relative']:.2f}x) {bias_tag}")

    avg_selected = avg_affinity[top_k_experts].mean().item()
    print(f"\n--- {scan_type} Summary ---")
    print(f"🎯 Top-{k}: {top_k_experts}")
    print(f"Avg selected: {avg_selected:.4f} ({avg_selected / baseline:.2f}x baseline)")

    return top_k_experts, avg_selected / baseline  # Return rel avg for contrast

# --- LOAD NON-QUANT MODEL (bf16, ~32GB) ---
print("⏳ Loading non-quant DeepSeek-MoE-16B (bf16)...")
model_name = "deepseek-ai/deepseek-moe-16b-base"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
print("✅ Model & tokenizer loaded. VRAM: ~32GB (40GB ok; 80GB safer).")

# Config
config = Config()

# --- MATH SCAN ---
print("\n=== MATH SCAN ===")
gsm8k = load_dataset("gsm8k", "main", split=f"train[:500]")
math_texts = [f"Question: {ex['question']}\nLet's think step by step." for ex in gsm8k]
math_top_k, math_rel_avg = pre_scan_affinity_nonquant(model, tokenizer, config, math_texts, scan_type="Math")

# --- CONTRAST SCAN (General: WikiText) ---
print("\n=== GENERAL CONTRAST SCAN ===")
wikitext = load_dataset("wikitext", "wikitext-2-raw-v1", split=f"train[:500]")
general_texts = [ex['text'] for ex in wikitext if ex['text'].strip()]  # Plain sentences
general_top_k, general_rel_avg = pre_scan_affinity_nonquant(model, tokenizer, config, general_texts, scan_type="General")

# --- CONTRAST SUMMARY ---
print("\n=== CONTRAST SUMMARY ===")
print(f"Math Top-4: {math_top_k} | Rel Avg: {math_rel_avg:.2f}x")
print(f"General Top-4: {general_top_k} | Rel Avg: {general_rel_avg:.2f}x")
overlap = len(set(math_top_k) & set(general_top_k))
print(f"Overlap in Top-4: {overlap}/4 (Math-specific if <2)")
if math_rel_avg > general_rel_avg * 1.1:
    print("✅ Math bias confirmed (domain signal strong).")
else:
    print("⚠️ Weak domain signal—recheck prompts/data.")

print(f"Recommended experts for burst: {math_top_k}")

⏳ Loading non-quant DeepSeek-MoE-16B (bf16)...


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

✅ Model & tokenizer loaded. VRAM: ~32GB (40GB ok; 80GB safer).

=== MATH SCAN ===

--- Math Affinity Scan ---
Texts: 500, Layers: 27, Batches: 63
  Processed 20/63...
  Processed 40/63...
  Processed 60/63...

✅ Math discovery complete! Recorded: 1701 layer calls
Baseline: 0.0156

Math Expert affinities:
👉 Expert 62: 0.0325 (2.08x)  ← Math-inclined!
👉 Expert 43: 0.0312 (2.00x)  ← Math-inclined!
👉 Expert 45: 0.0240 (1.54x)  ← Math-inclined!
👉 Expert 59: 0.0237 (1.52x)  ← Math-inclined!
   Expert 10: 0.0228 (1.46x)  ← Math-inclined!
   Expert 15: 0.0227 (1.45x)  ← Math-inclined!
   Expert  2: 0.0227 (1.45x)  ← Math-inclined!
   Expert 16: 0.0221 (1.41x)  ← Math-inclined!
   Expert 17: 0.0211 (1.35x)  ← Math-inclined!
   Expert  3: 0.0207 (1.33x)  ← Math-inclined!
   Expert 24: 0.0201 (1.29x)  ← Math-inclined!
   Expert  1: 0.0195 (1.25x)  ← Math-inclined!
   Expert 37: 0.0185 (1.18x)  ← Math-inclined!
   Expert 36: 0.0184 (1.18x)  ← Math-inclined!
   Expert 49: 0.0182 (1.17x)  ← Math-inc

README.md: 0.00B [00:00, ?B/s]

wikitext-2-raw-v1/test-00000-of-00001.pa(…):   0%|          | 0.00/733k [00:00<?, ?B/s]

wikitext-2-raw-v1/train-00000-of-00001.p(…):   0%|          | 0.00/6.36M [00:00<?, ?B/s]

wikitext-2-raw-v1/validation-00000-of-00(…):   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]


--- General Affinity Scan ---
Texts: 338, Layers: 27, Batches: 43
  Processed 20/43...
  Processed 40/43...

✅ General discovery complete! Recorded: 1161 layer calls
Baseline: 0.0156

General Expert affinities:
👉 Expert 62: 0.0402 (2.57x)  ← Math-inclined!
👉 Expert 43: 0.0358 (2.29x)  ← Math-inclined!
👉 Expert 45: 0.0298 (1.91x)  ← Math-inclined!
👉 Expert 10: 0.0281 (1.80x)  ← Math-inclined!
   Expert 59: 0.0274 (1.75x)  ← Math-inclined!
   Expert 16: 0.0262 (1.68x)  ← Math-inclined!
   Expert 17: 0.0258 (1.65x)  ← Math-inclined!
   Expert 15: 0.0258 (1.65x)  ← Math-inclined!
   Expert  2: 0.0252 (1.61x)  ← Math-inclined!
   Expert  3: 0.0251 (1.61x)  ← Math-inclined!
   Expert 24: 0.0241 (1.54x)  ← Math-inclined!
   Expert 36: 0.0219 (1.40x)  ← Math-inclined!
   Expert 49: 0.0199 (1.28x)  ← Math-inclined!
   Expert 22: 0.0198 (1.27x)  ← Math-inclined!
   Expert 37: 0.0193 (1.23x)  ← Math-inclined!
   Expert  1: 0.0191 (1.22x)  ← Math-inclined!
   Expert 63: 0.0173 (1.11x)  ← Math-inc

In [3]:
hellaswag = load_dataset("hellaswag", split="validation[:500]")
general_texts = [f"Premise: {ex['ctx_a']} {ex['ctx_b']} Sentence: {ex['endings'][0]}" for ex in hellaswag]  # Common-sense, no math
general_top_k, general_rel_avg = pre_scan_affinity_nonquant(model, tokenizer, config, general_texts, scan_type="General")


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/24.4M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/6.11M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/6.32M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/39905 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10003 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10042 [00:00<?, ? examples/s]


--- General Affinity Scan ---
Texts: 500, Layers: 27, Batches: 63
  Processed 20/63...
  Processed 40/63...
  Processed 60/63...

✅ General discovery complete! Recorded: 1701 layer calls
Baseline: 0.0156

General Expert affinities:
👉 Expert 62: 0.0308 (1.97x)  ← Math-inclined!
👉 Expert 43: 0.0297 (1.90x)  ← Math-inclined!
👉 Expert 45: 0.0234 (1.50x)  ← Math-inclined!
👉 Expert  3: 0.0221 (1.41x)  ← Math-inclined!
   Expert 15: 0.0219 (1.40x)  ← Math-inclined!
   Expert 10: 0.0216 (1.38x)  ← Math-inclined!
   Expert 59: 0.0216 (1.38x)  ← Math-inclined!
   Expert 16: 0.0207 (1.32x)  ← Math-inclined!
   Expert 24: 0.0205 (1.31x)  ← Math-inclined!
   Expert 17: 0.0200 (1.28x)  ← Math-inclined!
   Expert  2: 0.0200 (1.28x)  ← Math-inclined!
   Expert 36: 0.0192 (1.23x)  ← Math-inclined!
   Expert  1: 0.0191 (1.22x)  ← Math-inclined!
   Expert 22: 0.0184 (1.18x)  ← Math-inclined!
   Expert 49: 0.0177 (1.13x)  ← Math-inclined!
   Expert 37: 0.0177 (1.13x)  ← Math-inclined!
   Expert 63: 0.017

In [4]:
# --- CONTRAST SUMMARY ---
print("\n=== CONTRAST SUMMARY ===")
print(f"Math Top-4: {math_top_k} | Rel Avg: {math_rel_avg:.2f}x")
print(f"General Top-4: {general_top_k} | Rel Avg: {general_rel_avg:.2f}x")
overlap = len(set(math_top_k) & set(general_top_k))
print(f"Overlap in Top-4: {overlap}/4 (Math-specific if <2)")
if math_rel_avg > general_rel_avg * 1.1:
    print("✅ Math bias confirmed (domain signal strong).")
else:
    print("⚠️ Weak domain signal—recheck prompts/data.")

print(f"Recommended experts for burst: {math_top_k}")


=== CONTRAST SUMMARY ===
Math Top-4: [62, 43, 45, 59] | Rel Avg: 1.78x
General Top-4: [62, 43, 45, 3] | Rel Avg: 1.70x
Overlap in Top-4: 3/4 (Math-specific if <2)
⚠️ Weak domain signal—recheck prompts/data.
Recommended experts for burst: [62, 43, 45, 59]


In [12]:
# Test Router Hook (Toy Forward)
toy_text = "Solve 2+2="  # Simple math token
inputs = tokenizer(toy_text, return_tensors="pt", max_length=20, truncation=True).to(model.device)

class DummyTrainer:
    def __init__(self, config, model):
        self.config = config
        self.router_logits_buffer = []
        self.model = model
        # FIXED: Hook on global mlp.gate_proj + sample experts
        for layer in self.model.model.layers:
            if hasattr(layer, 'mlp') and hasattr(layer.mlp, 'gate_proj'):
                gate = layer.mlp.gate_proj
                def make_hook(idx):
                    def hook(module, input, output):
                        # FIXED: Squeeze to [tokens, 1] (remove hidden dim)
                        squeezed = output.mean(dim=-1).detach().cpu() if output.dim() > 2 else output.detach().cpu()
                        self.router_logits_buffer.append(squeezed)
                    return hook
                gate.register_forward_hook(make_hook(0))
            # Sample 4 experts per layer (mem-safe)
            if hasattr(layer, 'mlp') and hasattr(layer.mlp, 'experts') and len(layer.mlp.experts) > 0:
                for e_id in [0, 1, 62, 63]:
                    if hasattr(layer.mlp.experts[e_id], 'gate_proj'):
                        expert_gate = layer.mlp.experts[e_id].gate_proj
                        expert_gate.register_forward_hook(make_hook(e_id))

dt = DummyTrainer(config, model)
model.eval()
with torch.no_grad():
    _ = model(**inputs, use_cache=False)  # FIXED: Avoid cache error

print(f"✅ Hook fired: {len(dt.router_logits_buffer)} tensors")
if dt.router_logits_buffer:
    sample = dt.router_logits_buffer[0]
    print(f"Shape: {sample.shape} (tokens, 1 gate)")
    print(f"Sample logits: {sample[0][:3]}...")  # Scalar ~ -0.5 to 0.5


✅ Hook fired: 52 tensors
Shape: torch.Size([1, 8]) (tokens, 1 gate)
Sample logits: tensor([ 0.0022,  0.0332, -0.0032], dtype=torch.bfloat16)...


In [20]:
# Fixed Mini B0 + Eval: No use_cache in Args
from peft import LoraConfig, get_peft_model, PeftModel
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
from datasets import load_dataset

# B0 Config
lora_config_b0 = LoraConfig(
    r=config.lora_rank_general,
    lora_alpha=config.lora_alpha_general,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=config.lora_dropout,
    bias="none",
    task_type="CAUSAL_LM"
)

# Safe unload
if isinstance(model, PeftModel):
    model = model.unload()
model_b0 = get_peft_model(model, lora_config_b0)
model_b0.print_trainable_parameters()

# FIXED: Disable cache
model_b0.config.use_cache = False
model_b0.enable_input_require_grads()
model_b0.gradient_checkpointing_enable()

# Data (fixed padding + early clean)
gsm8k_mini = load_dataset("gsm8k", "main", split="train[:5]")
def format_mini(ex):
    return {"text": f"Question: {ex['question']}\nAnswer: {ex['answer']}<|endoftext|>"}

mini_data = gsm8k_mini.map(format_mini, remove_columns=gsm8k_mini.column_names)

def tokenize_mini(ex):
    tokenized = tokenizer(ex["text"], truncation=True, max_length=256, padding="max_length")
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized

mini_tokenized = mini_data.map(tokenize_mini, batched=False, remove_columns=mini_data.column_names)

# Split
train_tokenized = mini_tokenized.select(range(3))
eval_tokenized = mini_tokenized.select(range(3,5))

# Args (no use_cache)
mini_args = TrainingArguments(
    output_dir="/tmp/mini_b0",
    max_steps=50,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    learning_rate=2e-4,
    bf16=True,
    logging_steps=5,
    eval_strategy="steps",
    eval_steps=25,
    dataloader_drop_last=True,
    gradient_checkpointing=True,
    dataloader_pin_memory=False,
    optim="adamw_torch",
    report_to="none",
    remove_unused_columns=False
)

collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

trainer_mini = Trainer(
    model=model_b0,
    args=mini_args,
    train_dataset=train_tokenized,
    eval_dataset=eval_tokenized,
    data_collator=collator,
)
trainer_mini.train()

print("✅ B0 Mini done.")
print("Loss history:", [log.get('train_loss', log.get('eval_loss', 'N/A')) for log in trainer_mini.state.log_history])

# Eval
def eval_gsm8k(model, tokenizer, test_data, num_samples=2):
    correct = 0
    model.eval()
    for ex in test_data:
        prompt = f"Question: {ex['question']}\nAnswer:"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False, temperature=0.1, use_cache=False)
        gen = tokenizer.decode(outputs[0], skip_special_tokens=True)
        answer = ex['answer'].split('####')[0].strip()
        correct += 1 if answer in gen else 0
    return correct / num_samples

acc = eval_gsm8k(model_b0, tokenizer, gsm8k_mini.select(range(3,5)))
print(f"Eval Acc: {acc:.2f}")

trainable params: 300,921,856 || all params: 16,676,649,984 || trainable%: 1.8045


Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Map:   0%|          | 0/5 [00:00<?, ? examples/s]

Step,Training Loss,Validation Loss
25,0.0153,No log
50,0.0095,No log


✅ B0 Mini done.
Loss history: ['N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 'N/A', 0.24707426130771637]


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Setting `pad_token_id` to `eos_token_id`:100001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:100001 for open-end generation.


Eval Acc: 0.00


In [8]:
# Smoke Test: B5 First (OOM Test), Then B0
import os
import torch
from peft import LoraConfig, get_peft_model, PeftModel
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
from datasets import load_dataset, concatenate_datasets

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

discovered_ids = [62, 43, 45, 59]

# Data (256 len)
gsm8k = load_dataset("gsm8k", "main", split="train[:500]")
hellaswag = load_dataset("hellaswag", split="validation[:500]")

def format_math(ex):
    return {"text": f"Question: {ex['question']}\nAnswer: {ex['answer']}<|endoftext|>"}

def format_general(ex):
    return {"text": f"Premise: {ex['ctx_a']} {ex['ctx_b']} Sentence: {ex['endings'][0]}<|endoftext|>"}

math_data = gsm8k.map(format_math)
general_data = hellaswag.map(format_general)

mixed_data = concatenate_datasets([math_data.select(range(250)), general_data.select(range(250))])

def tokenize(ex):
    tokenized = tokenizer(ex["text"], truncation=True, max_length=256, padding="max_length")
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized

mixed_tokenized = mixed_data.map(tokenize, batched=False, remove_columns=mixed_data.column_names)
math_tokenized = math_data.map(tokenize, batched=False, remove_columns=math_data.column_names)

collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# B5 Proxy First: r=128, 100% math (OOM test)
if isinstance(model, PeftModel):
    model = model.unload()
torch.cuda.empty_cache()

lora_config_b5 = LoraConfig(r=128, lora_alpha=256, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
model_b5 = get_peft_model(model, lora_config_b5)
model_b5.enable_input_require_grads()
model_b5.gradient_checkpointing_enable()
model_b5.config.use_cache = False
model_b5.print_trainable_parameters()

smoke_args = TrainingArguments(
    output_dir="/tmp/smoke_b5",
    num_train_epochs=2,
    per_device_train_batch_size=2,  # Low for OOM
    gradient_accumulation_steps=8,  # Effective 16
    learning_rate=2e-4,
    bf16=True,
    logging_steps=10,  # Frequent for monitor
    eval_strategy="no",
    dataloader_drop_last=True,
    gradient_checkpointing=False,  # Off for VRAM
    dataloader_pin_memory=False,
    optim="adamw_torch",
    report_to="none",
    remove_unused_columns=False,
    save_total_limit=1
)

trainer_b5 = Trainer(model=model_b5, args=smoke_args, train_dataset=math_tokenized, data_collator=collator)
trainer_b5.train()
b5_loss = trainer_b5.state.log_history[-1].get('train_loss', 'N/A')
print(f"✅ B5 Done. Final Loss: {b5_loss}")

# Del B5
trainer_b5.save_model("/tmp/b5_ckpt")
del model_b5, trainer_b5
torch.cuda.empty_cache()

# B0: Vanilla r=16, mixed data (after B5)
lora_config_b0 = LoraConfig(r=16, lora_alpha=32, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
model_b0 = get_peft_model(model, lora_config_b0)
model_b0.enable_input_require_grads()
model_b0.gradient_checkpointing_enable()
model_b0.config.use_cache = False
model_b0.print_trainable_parameters()

smoke_args.output_dir = "/tmp/smoke_b0"
trainer_b0 = Trainer(model=model_b0, args=smoke_args, train_dataset=mixed_tokenized, data_collator=collator)
trainer_b0.train()
b0_loss = trainer_b0.state.log_history[-1].get('train_loss', 'N/A')
print(f"✅ B0 Done. Final Loss: {b0_loss}")

# Affinity Check (Load saved)
from peft import PeftModel

model_b0_loaded = PeftModel.from_pretrained(model, "/tmp/smoke_b0")
model_b5_loaded = PeftModel.from_pretrained(model, "/tmp/b5_ckpt")

math_prompts = [f"Question: {ex['question']}\nLet's think step by step." for ex in gsm8k[:50]]

def check_affinity(model, tokenizer, texts, ids, num_prompts=50):
    router_probs_all = []
    handles = []
    for layer in model.model.layers:
        if hasattr(layer.mlp, 'gate'):
            def hook_fn(module, input):
                hidden = input[0].view(-1, input[0].shape[-1])
                logits = F.linear(hidden, module.weight)
                probs = F.softmax(logits, dim=-1)
                router_probs_all.append(probs.mean(dim=0, keepdim=True).cpu())
                return input
            handle = layer.mlp.gate.register_forward_pre_hook(hook_fn)
            handles.append(handle)

    model.eval()
    with torch.no_grad():
        for text in texts[:num_prompts]:
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            _ = model(**inputs, use_cache=False)

    for h in handles:
        h.remove()

    if router_probs_all:
        stacked = torch.cat(router_probs_all, dim=0)
        avg_aff = stacked.mean(dim=0)
        aff_to_ids = (avg_aff[ids].sum() / len(ids)).item()
        return aff_to_ids
    return 0.0

aff_b0 = check_affinity(model_b0_loaded, tokenizer, math_prompts, discovered_ids)
aff_b5 = check_affinity(model_b5_loaded, tokenizer, math_prompts, discovered_ids)

print(f"B0 Affinity: {aff_b0:.1%} | B5: {aff_b5:.1%} (Target B5 > B0 +5%)")

# Clean
del model_b0_loaded, model_b5_loaded
torch.cuda.empty_cache()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/24.4M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/6.11M [00:00<?, ?B/s]

data/validation-00000-of-00001.parquet:   0%|          | 0.00/6.32M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/39905 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/10003 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10042 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

trainable params: 2,407,374,848 || all params: 18,783,102,976 || trainable%: 12.8167


Step,Training Loss
10,1.0189
20,0.9486
30,0.9207
40,0.667
50,0.5726
60,0.582


✅ B5 Done. Final Loss: 0.769857294857502




trainable params: 300,921,856 || all params: 16,676,649,984 || trainable%: 1.8045


Step,Training Loss
10,1.8847
20,1.4426
30,1.4208
40,1.363
50,1.2066
60,1.28


✅ B0 Done. Final Loss: 1.4112832099199295


ValueError: Can't find 'adapter_config.json' at '/tmp/smoke_b0'

In [33]:
import math
from datasets import load_dataset
import torch
import torch.nn.functional as F

# --- AFFINITY FUNCTION (Final Path Fix) ---
def check_affinity(model, tokenizer, texts, ids, num_prompts=50):
    router_probs_all = []
    handles = []

    # Get the underlying DeepseekForCausalLM object
    if hasattr(model, 'base_model') and hasattr(model.base_model, 'base_model'):
        hf_model_core = model.base_model.base_model
    else:
        hf_model_core = model

    # Access the layers list
    if hasattr(hf_model_core, 'model'):
        layers = hf_model_core.model.layers
    elif hasattr(hf_model_core, 'layers'):
        layers = hf_model_core.layers
    else:
        raise AttributeError("Failed to find layers list on the underlying model object.")

    # Hooking the Gate
    for layer in layers:
        if hasattr(layer.mlp, 'gate'):
            def hook_fn(module, input):
                hidden = input[0].view(-1, input[0].shape[-1])
                logits = F.linear(hidden, module.weight)
                probs = F.softmax(logits, dim=-1)
                router_probs_all.append(probs.mean(dim=0, keepdim=True).cpu())
                return input
            handle = layer.mlp.gate.register_forward_pre_hook(hook_fn)
            handles.append(handle)

    # ... (rest of the forward pass is correct)
    model.eval()
    device = next(model.parameters()).device

    with torch.no_grad():
        for text in texts[:num_prompts]:
            inputs = tokenizer(text, return_tensors="pt").to(device)
            _ = model(**inputs, use_cache=False)

    for h in handles:
        h.remove()

    if router_probs_all:
        stacked = torch.cat(router_probs_all, dim=0)
        avg_aff = stacked.mean(dim=0)
        aff_to_ids = avg_aff[ids].sum().item()
        return aff_to_ids
    return 0.0

discovered_ids = [62, 43, 45, 59]
math_prompts = [f"Question: {ex['question']}\nLet's think step by step." for ex in load_dataset("gsm8k", "main", split="train[:50]")]

aff_b0 = check_affinity(trainer_b0.model, tokenizer, math_prompts, discovered_ids)
print(f"B0 Affinity to IDs: {aff_b0:.1%} (Baseline ~6.25%)")

# Perplexity (Sequential, B0 Only)
gsm8k_test = load_dataset("gsm8k", "main", split="test[:50]")
def format_ppl(ex):
    text = f"Question: {ex['question']}\nAnswer: {ex['answer']}<|endoftext|>"
    tokenized = tokenizer(text, truncation=True, max_length=256, padding="max_length")
    return {"input_ids": tokenized["input_ids"], "labels": tokenized["input_ids"].copy()}

gsm8k_test = gsm8k_test.map(format_ppl, batched=False, remove_columns=gsm8k_test.column_names)

# ... (Perplexity function is correct, assuming you have it defined)
def perplexity(model, tokenizer, test_data, ignore_index=-100):
    model.eval()
    total_loss = 0
    # IMPORTANT: The loss must be calculated using the same logic as the Trainer
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignore_index)

    with torch.no_grad():
        for ex in test_data:
            inputs = {"input_ids": torch.tensor([ex["input_ids"]]).to(model.device)}
            # The labels list already has the prompt and pad tokens set to -100
            # (or should be adjusted in the mapping function).
            labels = torch.tensor([ex["labels"]]).to(model.device)

            outputs = model(**inputs, use_cache=False)

            # 1. Shift logits for causal loss (LMs predict the next token)
            shift_logits = outputs.logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # 2. Calculate loss (ignoring -100 labels)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            total_loss += loss.item()

    # Return PPL
    return math.exp(total_loss / len(test_data))

# HellaSwag (General)
hellaswag_test = load_dataset("hellaswag", split="validation[:50]")
def format_hellaswag(ex):
    # --- FIXED: Convert string label to integer index ---
    label_index = int(ex['label'])
    text = f"Premise: {ex['ctx_a']} {ex['ctx_b']} Sentence: {ex['endings'][label_index]}<|endoftext|>"
    tokenized = tokenizer(text, truncation=True, max_length=256, padding="max_length")
    return {"input_ids": tokenized["input_ids"], "labels": tokenized["input_ids"].copy()}

hellaswag_test = hellaswag_test.map(format_hellaswag, batched=False, remove_columns=hellaswag_test.column_names)
ppl_b0_hella = perplexity(trainer_b0.model, tokenizer, hellaswag_test)
print(f"B0 HellaSwag PPL: {ppl_b0_hella:.2f}")

# Summary
print(f"B0 Baseline: Affinity {aff_b0:.1%}, Math PPL {ppl_b0_math:.2f}, General PPL {ppl_b0_hella:.2f}")

B0 Affinity to IDs: 6.7% (Baseline ~6.25%)


Map:   0%|          | 0/50 [00:00<?, ? examples/s]

Map:   0%|          | 0/50 [00:00<?, ? examples/s]

B0 HellaSwag PPL: 1754.96
B0 Baseline: Affinity 6.7%, Math PPL 304.05, General PPL 1754.96


In [27]:
# Fixed Affinity Check: Nested PEFT Path (base_model.base_model.model.layers)
import math
from datasets import load_dataset
import torch.nn.functional as F

math_prompts = [f"Question: {ex['question']}\nLet's think step by step." for ex in load_dataset("gsm8k", "main", split="train[:50]")]

def check_affinity(model, tokenizer, texts, ids, num_prompts=50):
    router_probs_all = []
    handles = []

    # Get the underlying Hugging Face model object
    # Path: PeftModelForCausalLM -> LoraModel -> DeepseekForCausalLM
    if hasattr(model, 'base_model') and hasattr(model.base_model, 'base_model'):
        # This is the path for PEFT-wrapped models (trainer_b0/b5)
        hf_model = model.base_model.base_model
    else:
        # Fallback for the raw model object (DeepseekForCausalLM)
        hf_model = model

    # The actual layers list is typically under the 'model' attribute of the HF model wrapper
    # DeepseekForCausalLM -> model -> layers
    if hasattr(hf_model, 'model'):
        layers = hf_model.model.layers
    elif hasattr(hf_model, 'layers'):
        layers = hf_model.layers
    else:
        raise AttributeError("Could not find layers in the model structure (expected .model.layers or .layers).")

    for layer in layers:
        if hasattr(layer.mlp, 'gate'):
            def hook_fn(module, input):
                hidden = input[0].view(-1, input[0].shape[-1])
                logits = F.linear(hidden, module.weight)
                probs = F.softmax(logits, dim=-1)
                router_probs_all.append(probs.mean(dim=0, keepdim=True).cpu())
                return input
            handle = layer.mlp.gate.register_forward_pre_hook(hook_fn)
            handles.append(handle)

    model.eval()
    with torch.no_grad():
        for text in texts[:num_prompts]:
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            _ = model(**inputs, use_cache=False)

    for h in handles:
        h.remove()

    if router_probs_all:
        stacked = torch.cat(router_probs_all, dim=0)
        avg_aff = stacked.mean(dim=0)
        # Average probability across the 4 specific experts
        aff_to_ids = (avg_aff[ids].sum() / len(ids)).item()
        return aff_to_ids
    return 0.0

discovered_ids = [62, 43, 45, 59]

aff_b0 = check_affinity(trainer_b0.model, tokenizer, math_prompts, discovered_ids)
aff_b5 = check_affinity(trainer_b5.model, tokenizer, math_prompts, discovered_ids)

print(f"B0 Affinity: {aff_b0:.1%} | B5: {aff_b5:.1%} (Target B5 > B0 +5%)")

NameError: name 'trainer_b5' is not defined

In [24]:
# Debug Model Structure
print("=== Trainer Model Type ===")
print(f"Type: {type(trainer_b0.model)}")
print(f"Has base_model: {hasattr(trainer_b0.model, 'base_model')}")
if hasattr(trainer_b0.model, 'base_model'):
    print(f"base_model type: {type(trainer_b0.model.base_model)}")
    print(f"base_model has layers: {hasattr(trainer_b0.model.base_model, 'layers')}")
    if hasattr(trainer_b0.model.base_model, 'layers'):
        print(f"layers type: {type(trainer_b0.model.base_model.layers)}")
        print(f"Num layers: {len(trainer_b0.model.base_model.layers)}")
        print(f"First layer mlp gate: {hasattr(trainer_b0.model.base_model.layers[0].mlp, 'gate')}")
print("\nDir slice:", [attr for attr in dir(trainer_b0.model) if 'base' in attr or 'model' in attr][:5])

=== Trainer Model Type ===
Type: <class 'peft.peft_model.PeftModelForCausalLM'>
Has base_model: True
base_model type: <class 'peft.tuners.lora.model.LoraModel'>
base_model has layers: False

Dir slice: ['_get_base_model_class', '_get_peft_specific_model_tags', '_prepare_model_for_gradient_checkpointing', 'base_model', 'base_model_prepare_inputs_for_generation']


In [None]:
# Full B0 Smoke (100 Steps, PPL)
from transformers import DataCollatorForLanguageModeling
from evaluate import load

def prepare_smoke_data(config):
    gsm8k = load_dataset("gsm8k", "main", split="train[:200]")
    hellaswag = load_dataset("hellaswag", split="train[:200]")
    def format_math(ex):
        return {"text": f"Question: {ex['question']}\nAnswer: {ex['answer']}", "domain_id": 1}
    def format_general(ex):
        return {"text": f"{ex['ctx_a']} {ex['endings'][0]}", "domain_id": 0}
    math_data = gsm8k.map(format_math)
    general_data = hellaswag.map(format_general)
    from datasets import concatenate_datasets
    train_data = concatenate_datasets([math_data, general_data]).shuffle(seed=42)
    eval_data = gsm8k.map(format_math, split="test[:50]")
    return train_data, eval_data

def tokenize_data(dataset, tokenizer, config):
    def tokenize(ex):
        out = tokenizer(ex["text"], truncation=True, max_length=config.max_seq_length, padding="max_length")
        out["domain_id"] = ex["domain_id"]
        out["labels"] = out["input_ids"].copy()
        return out
    return dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)

# Data
train_data, eval_data = prepare_smoke_data(config)
train_tokenized = tokenize_data(train_data, tokenizer, config)
eval_tokenized = tokenize_data(eval_data, tokenizer, config)

# B0 Full
args_b0 = TrainingArguments(
    output_dir="/tmp/b0_smoke",
    max_steps=config.max_steps,  # 100
    per_device_train_batch_size=config.batch_size,  # 2
    gradient_accumulation_steps=config.gradient_accumulation_steps,  # 4
    learning_rate=config.learning_rate,
    bf16=True,
    logging_steps=10,
    eval_steps=50,
    gradient_checkpointing=True,
    remove_unused_columns=False
)

class B0Trainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        domain_ids = inputs.pop("domain_id", None)
        return super().compute_loss(model, inputs, return_outputs)

trainer_b0 = B0Trainer(
    model=model_b0,
    args=args_b0,
    train_dataset=train_tokenized,
    eval_dataset=eval_tokenized,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    compute_metrics=lambda p: {"ppl": np.exp(p.metrics.get('eval_loss', 0.0))}
)
trainer_b0.train()
eval_b0 = trainer_b0.evaluate()
affinity_b0 = 12.5  # Baseline proxy (no bursts)
ppl_b0 = eval_b0['eval_perplexity']
print(f"✅ B0 Smoke: PPL {ppl_b0:.2f}, Affinity ~{affinity_b0}%")

In [None]:
# B5 Smoke (Proxy Hetero + Bursts)
# Proxy hetero: High rank uniform (Option A)
lora_config_b5 = LoraConfig(
    r=config.lora_rank_math,  # High rank proxy
    lora_alpha=config.lora_alpha_math,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "w1", "w2", "w3"],
    lora_dropout=config.lora_dropout,
    bias="none",
    task_type="CAUSAL_LM"
)
model_b5 = get_peft_model(model, lora_config_b5)
model_b5.print_trainable_parameters()

# B5 Trainer (with bursts/affinity)
class B5Trainer(B0Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.router_logits_buffer = []
        # Hook
        base_model = self.model.base_model.model if hasattr(self.model, 'base_model') else self.model.model
        for layer in base_model.layers:
            if hasattr(layer, 'block_sparse_moe'):
                gate = layer.block_sparse_moe.gate
                def make_hook(self=self):
                    def hook(module, input, output):
                        self.router_logits_buffer.append(output.detach().cpu())
                    return hook
                gate.register_forward_hook(make_hook())

    def compute_loss(self, model, inputs, return_outputs=False):
        domain_ids = inputs.pop("domain_id", None)
        self.router_logits_buffer = []  # Clear
        outputs = model(**inputs)
        loss = outputs.loss

        if self.router_logits_buffer:
            try:
                router_logits = torch.stack(self.router_logits_buffer).mean(0)
                batch_size, seq_len = domain_ids.shape
                router_logits = router_logits.view(batch_size, seq_len, -1)
                # Affinity
                expert_choices = router_logits.argmax(-1)
                is_math = (domain_ids == 1)
                went_to_math = torch.isin(expert_choices, torch.tensor(self.config.math_expert_ids))
                affinity = (went_to_math[is_math].float().mean().item() * 100)
                self.affinity_history = getattr(self, 'affinity_history', []) + [affinity]

                # Affinity loss
                is_math_flat = is_math.view(-1)
                if is_math_flat.any():
                    target = torch.zeros_like(router_logits.view(-1, router_logits.size(-1)))
                    for eid in self.config.math_expert_ids:
                        target[:, eid] = 1.0 / len(self.config.math_expert_ids)
                    router_probs = F.softmax(router_logits.view(-1, router_logits.size(-1))[is_math_flat], dim=-1)
                    target_probs = target[is_math_flat]
                    affinity_loss = F.kl_div(router_probs.log(), target_probs, reduction='batchmean')

                    # Masked burst
                    boost_mask = torch.where(is_math & went_to_math, self.config.burst_weight, 1.0)
                    masked_loss = loss * boost_mask.mean()

                    loss = masked_loss + self.config.affinity_coef * affinity_loss

                if self.state.global_step % 10 == 0:
                    self.log({'affinity': affinity, 'masked_loss': masked_loss.item() if 'masked_loss' in locals() else loss.item()})
            except Exception as e:
                print(f"Router skip: {e}")

        return (loss, outputs) if return_outputs else loss

trainer_b5 = B5Trainer(
    model=model_b5,
    args=args_b0,  # Reuse args
    train_dataset=train_tokenized,
    eval_dataset=eval_tokenized,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    compute_metrics=lambda p: {"ppl": np.exp(p.metrics.get('eval_loss', 0.0))}
)
trainer_b5.train()
eval_b5 = trainer_b5.evaluate()
affinity_b5 = trainer_b5.affinity_history[-1] if trainer_b5.affinity_history else 12.5
ppl_b5 = eval_b5['eval_perplexity']
print(f"✅ B5 Smoke: PPL {ppl_b5:.2f}, Affinity {affinity_b5:.1f}%")

In [None]:
# Compare Verdict
improvement = affinity_b5 - affinity_b0
ppl_drop = ppl_b0 - ppl_b5
print(f"\nVERDICT:")
if improvement > 2.0:
    print(f"✅ SIGNAL: Affinity +{improvement:.1f}%, PPL drop {ppl_drop:.2f} → Proceed!")
else:
    print(f"⚠️ WEAK: Affinity +{improvement:.1f}%, PPL drop {ppl_drop:.2f} → Tweak λ")