In [None]:
import torch
import numpy as np
import plotly.graph_objects as go
from collections import defaultdict
import torch.nn.functional as F
import json
import os
from typing import Dict, List, Tuple
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
def get_device():
    """Pick the compute device for the analyzer.

    Returns ``cuda:0`` when a CUDA GPU is visible, otherwise the CPU. On the GPU
    path it also enables TF32 for matmuls and cuDNN, releases cached allocator
    blocks, and turns on cuDNN autotuning. These are process-wide side effects.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")

    # TF32 trades a little matmul/convolution precision for speed on Ampere-or-newer GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Hand cached-but-unused blocks back to the driver (this does not change allocator settings).
    torch.cuda.empty_cache()
    # Let cuDNN benchmark candidate algorithms and keep the fastest one.
    torch.backends.cudnn.benchmark = True
    return torch.device("cuda:0")

In [None]:
class MOEExpertLens:
    """Logit-lens style inspector for the routed experts of a Mixture-of-Experts causal LM.

    Forward hooks record two things for every MoE layer:
    - which experts the router selected for each token;
    - the output each of those experts produced for that token.

    Each captured expert output is then projected onto the vocabulary with the input
    embedding matrix.

    Public records (rebuilt by every ``analyze_text`` call):
        expert_outputs: ``(layer_idx, token_idx, expert_idx, output)`` tuples. ``output`` is
            that expert's output row for that token, kept on the CPU.
        selected_experts: ``(layer_idx, token_idx, [expert_indices])`` tuples.
    """

    def __init__(self, model_name: str, device=None, top_k=None, **from_pretrained_kwargs):
        """Load the model and tokenizer, then install the capture hooks.

        Args:
            model_name: Name/path of the MoE model (e.g. a DeepSeek MoE checkpoint).
            device: Optional device to use. If None, ``get_device()`` picks one.
            top_k: Experts per token to keep when a router returns raw logits. If None,
                the model config's ``num_experts_per_tok`` is used when present, else 7.
                Ignored when the router already returns expert indices.
            **from_pretrained_kwargs: Forwarded to ``AutoModelForCausalLM.from_pretrained``
                (e.g. ``trust_remote_code=True`` or ``torch_dtype=...``).
                ``trust_remote_code`` is also forwarded to the tokenizer.
        """
        # Normalise strings/ints so `.type` is always available later.
        self.device = torch.device(device) if device is not None else get_device()

        # Load model and move to device
        self.model = AutoModelForCausalLM.from_pretrained(model_name, **from_pretrained_kwargs)
        self.model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            trust_remote_code=from_pretrained_kwargs.get("trust_remote_code", False),
        )
        self.model.eval()

        # NOTE(review): `num_experts_per_tok` is the field some MoE configs use for the
        # routing fan-out; confirm it exists for the target model.
        config = getattr(self.model, "config", None)
        self.top_k = top_k or getattr(config, "num_experts_per_tok", None) or 7

        # Embedding matrix used for the reverse projection onto the vocabulary.
        self.embed_matrix = self.model.get_input_embeddings().weight

        self.expert_outputs = []  # (layer_idx, token_idx, expert_idx, output_row)
        self.selected_experts = []  # (layer_idx, token_idx, [expert_indices])
        self.register_hooks()

    def register_hooks(self):
        """Install forward hooks on each MoE layer's block, router and routed experts.

        A layer is skipped when its ``mlp`` has no ``gate`` or no ``experts`` attribute
        (e.g. a dense MLP layer). Any hooks installed earlier are removed first, so calling
        this twice never duplicates records. Hooks only record while ``analyze_text`` runs.
        """
        self.remove_hooks()
        self._capturing = False
        self._moe_inputs = {}
        self._expert_calls = []

        def make_moe_input_hook(layer_idx):
            def moe_input_hook(module, args):
                # Remember the rows entering the MoE block. Expert inputs are matched
                # against these rows later to recover each row's token position.
                if self._capturing and args and torch.is_tensor(args[0]):
                    hidden = args[0].detach()
                    self._moe_inputs[layer_idx] = (
                        hidden.reshape(-1, hidden.shape[-1]).float().cpu()
                    )
            return moe_input_hook

        def make_router_hook(layer_idx):
            def router_hook(module, args, output):
                if not self._capturing:
                    return
                # NOTE(review): some routers return a tuple whose first element holds the
                # chosen expert indices; others return raw logits. Integer tensors are
                # treated as indices and floating tensors as logits.
                selection = output[0] if isinstance(output, (tuple, list)) else output
                selection = selection.detach()
                if selection.is_floating_point():
                    k = min(self.top_k, selection.shape[-1])
                    selection = torch.topk(selection, k=k, dim=-1).indices
                # Flatten any leading (batch, seq) dims: for a single sequence, row i is token i.
                rows = selection.reshape(-1, selection.shape[-1]).cpu().tolist()
                for token_idx, experts in enumerate(rows):
                    self.selected_experts.append((layer_idx, token_idx, experts))
            return router_hook

        def make_expert_hook(layer_idx, expert_idx):
            def expert_hook(module, args, output):
                if not self._capturing or not args or not torch.is_tensor(args[0]):
                    return
                if not torch.is_tensor(output):
                    return
                expert_in = args[0].detach()
                expert_out = output.detach()
                self._expert_calls.append((
                    layer_idx,
                    expert_idx,
                    expert_in.reshape(-1, expert_in.shape[-1]).float().cpu(),
                    # Move to CPU to save GPU memory.
                    expert_out.reshape(-1, expert_out.shape[-1]).cpu(),
                ))
            return expert_hook

        handles = []
        for layer_idx, layer in enumerate(self.model.model.layers):
            mlp = getattr(layer, "mlp", None)
            gate = getattr(mlp, "gate", None)
            experts = getattr(mlp, "experts", None)
            if gate is None or experts is None:
                continue  # No router here, so nothing is routed in this layer.
            handles.append(mlp.register_forward_pre_hook(make_moe_input_hook(layer_idx)))
            handles.append(gate.register_forward_hook(make_router_hook(layer_idx)))
            # Every entry of `experts` is hooked. This assumes shared (always-on) experts,
            # if the architecture has them, live outside this list — TODO confirm.
            for expert_idx, expert in enumerate(experts):
                if expert is not None:
                    handles.append(
                        expert.register_forward_hook(make_expert_hook(layer_idx, expert_idx))
                    )
        self._hook_handles = handles

    def remove_hooks(self):
        """Detach every hook installed by ``register_hooks``. Safe to call repeatedly."""
        for handle in getattr(self, "_hook_handles", []):
            handle.remove()
        self._hook_handles = []

    def _project_to_vocab(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Project hidden states onto the vocabulary with the input embedding matrix.

        The input is moved to the analyzer's device and cast to the embedding dtype. This
        avoids a dtype mismatch when outputs were captured under autocast.
        NOTE(review): this uses the *input* embedding, not the LM head. If the model does
        not tie its embeddings, ``get_output_embeddings()`` may be the more faithful choice.
        """
        hidden_states = hidden_states.to(device=self.device, dtype=self.embed_matrix.dtype)
        return F.linear(hidden_states, self.embed_matrix)

    def get_top_k_tokens(self, logits: torch.Tensor, k: int = 5) -> List[Tuple[str, float]]:
        """Return the ``k`` most probable tokens for a 1-D logit vector.

        Args:
            logits: Vocabulary logits of shape ``(vocab,)``.
            k: Number of tokens to return (clamped to the vocabulary size).

        Returns:
            ``(decoded_token, probability)`` pairs in descending probability.
        """
        probs = F.softmax(logits.float(), dim=-1)  # softmax in fp32 for stable probabilities
        k = min(k, probs.shape[-1])
        top_values, top_indices = torch.topk(probs, k, dim=-1)
        return [
            (self.tokenizer.decode([idx]), val)
            for val, idx in zip(top_values.cpu().tolist(), top_indices.cpu().tolist())
        ]

    def analyze_text(self, text: str, k: int = 5) -> Dict:
        """Perform expert analysis on input text.

        A single forward pass captures router decisions and expert outputs for every MoE
        layer. Each selected expert's output for each token is then projected to the
        vocabulary.

        Args:
            text: Input text to analyze.
            k: Number of top tokens to return for each expert.

        Returns:
            ``{"layers": {"layer_<i>": {"tokens": {<token text>: entry}}}}``, where
            ``entry`` has the keys ``"token"``, ``"position"``, ``"selected_experts"`` and
            ``"expert_interpretations"``. Each interpretation is
            ``{"expert_id", "top_tokens": [(token, prob), ...]}``.

            A token text that repeats is keyed ``"<token>@<position>"`` so entries are not
            overwritten. Layers without a router get an empty ``"tokens"`` dict.
        """
        import contextlib

        self.expert_outputs = []
        self.selected_experts = []
        self._moe_inputs = {}
        self._expert_calls = []

        encoded = self.tokenizer(text, return_tensors="pt")
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        device_type = torch.device(self.device).type
        amp = (
            torch.autocast(device_type="cuda")
            if device_type == "cuda"
            else contextlib.nullcontext()
        )
        # One forward pass covers every layer and token. Each hook knows its own layer index.
        self._capturing = True
        try:
            with torch.no_grad(), amp:
                self.model(**encoded)
        finally:
            self._capturing = False

        self._assign_expert_rows()
        with torch.no_grad():
            results = self._build_results(encoded["input_ids"][0].tolist(), k)

        # Drop the bulky intermediate captures; the public records remain.
        self._moe_inputs = {}
        self._expert_calls = []
        if device_type == "cuda":
            torch.cuda.empty_cache()
        return results

    def _assign_expert_rows(self):
        """Attribute every captured expert-output row to the token position it came from.

        Each row an expert received is matched to the nearest row of that layer's MoE-block
        input. Candidates are restricted to tokens the router sent to that expert, and each
        candidate is claimed at most once, so tokens with identical hidden states still map
        one-to-one.
        """
        routed = defaultdict(set)  # (layer, expert) -> token positions routed to that expert
        for layer_idx, token_idx, experts in self.selected_experts:
            for expert_idx in experts:
                routed[(layer_idx, expert_idx)].add(token_idx)

        available_by_key = {}
        for layer_idx, expert_idx, expert_in, expert_out in self._expert_calls:
            moe_in = self._moe_inputs.get(layer_idx)
            if (
                moe_in is None
                or expert_in.shape[0] == 0
                or expert_in.shape[0] != expert_out.shape[0]
                or expert_in.shape[-1] != moe_in.shape[-1]
            ):
                continue
            key = (layer_idx, expert_idx)
            num_tokens = moe_in.shape[0]
            available = available_by_key.get(key)
            if available is None:
                available = torch.zeros(num_tokens, dtype=torch.bool)
                routed_tokens = sorted(t for t in routed.get(key, ()) if t < num_tokens)
                if routed_tokens:
                    available[torch.tensor(routed_tokens)] = True
                available_by_key[key] = available

            # Exact (non-matmul) distances so identical rows compare exactly.
            distances = torch.cdist(
                expert_in, moe_in, compute_mode="donot_use_mm_for_euclid_dist"
            )
            for row in range(expert_in.shape[0]):
                row_dist = distances[row]
                if available.any():
                    row_dist = row_dist.masked_fill(~available, float("inf"))
                token_idx = int(row_dist.argmin())  # ties resolve to the lowest position
                available[token_idx] = False
                self.expert_outputs.append((layer_idx, token_idx, expert_idx, expert_out[row]))

    def _build_results(self, token_ids: List[int], k: int) -> Dict:
        """Assemble the per-layer, per-token report from the captured records."""
        selections = {(layer, tok): experts for layer, tok, experts in self.selected_experts}
        outputs = {(layer, tok, exp): out for layer, tok, exp, out in self.expert_outputs}
        token_texts = [self.tokenizer.decode([token_id]) for token_id in token_ids]

        results = {"layers": {}}
        for layer_idx in range(len(self.model.model.layers)):
            tokens = {}
            for token_idx, token in enumerate(token_texts):
                selected = selections.get((layer_idx, token_idx))
                if selected is None:
                    continue  # No router decision was recorded for this layer/token.
                interpretations = []
                for expert_idx in selected:
                    output = outputs.get((layer_idx, token_idx, expert_idx))
                    if output is None:
                        continue
                    logits = self._project_to_vocab(output)
                    interpretations.append({
                        "expert_id": expert_idx,
                        "top_tokens": self.get_top_k_tokens(logits, k),
                    })
                # A repeated token text (e.g. "the" twice) must not overwrite the earlier entry.
                entry_key = token if token not in tokens else f"{token}@{token_idx}"
                tokens[entry_key] = {
                    "token": token,
                    "position": token_idx,
                    "selected_experts": selected,
                    "expert_interpretations": interpretations,
                }
            results["layers"][f"layer_{layer_idx}"] = {"tokens": tokens}
        return results

def visualize_expert_analysis(results: Dict, layer_idx: int, token_position: int):
    """Plot each selected expert's top vocabulary projections for one token.

    Args:
        results: Output of ``MOEExpertLens.analyze_text``.
        layer_idx: Layer to inspect, looked up as ``results["layers"][f"layer_{layer_idx}"]``.
        token_position: Position of the token within the analysed input.

    Raises:
        KeyError: if ``layer_idx`` is not present in ``results``.
        IndexError: if no token at ``token_position`` was recorded for that layer
            (e.g. the layer has no router, or the position is out of range).
    """
    layer_data = results["layers"][f"layer_{layer_idx}"]
    token_data = next(
        (data for data in layer_data["tokens"].values() if data["position"] == token_position),
        None,
    )
    if token_data is None:
        recorded = sorted(data["position"] for data in layer_data["tokens"].values())
        raise IndexError(
            f"No routed token at token_position={token_position} in layer {layer_idx}; "
            f"recorded positions: {recorded}"
        )

    # Imported lazily so the analysis code itself does not require a plotting library.
    import matplotlib.pyplot as plt

    header = f"Layer {layer_idx}, token position {token_position}"
    interpretations = token_data["expert_interpretations"]
    if not interpretations:
        fig, ax = plt.subplots(figsize=(15, 3))
        ax.axis("off")
        ax.text(0.5, 0.5, "No expert outputs were captured for this token",
                ha="center", va="center")
        ax.set_title(header)
        plt.show()
        return

    n_experts = len(interpretations)
    fig, axes = plt.subplots(n_experts, 1, figsize=(15, max(10, 2.5 * n_experts)),
                             squeeze=False)
    for ax, expert_data in zip(axes[:, 0], interpretations):
        ax.set_title(f"Expert {expert_data['expert_id']} Interpretations")
        if not expert_data["top_tokens"]:
            continue
        tokens, probs = zip(*expert_data["top_tokens"])
        positions = list(range(len(tokens)))
        # Bars are placed by position, so identical decoded strings are not merged into one bar.
        ax.bar(positions, probs)
        ax.set_xticks(positions)
        # repr() makes whitespace/newline tokens visible on the axis.
        ax.set_xticklabels([repr(tok) for tok in tokens], rotation=45)
        ax.set_ylabel("probability")
    fig.suptitle(header)
    fig.tight_layout()
    plt.show()

In [None]:
# NOTE(review): this checkpoint may ship custom modeling code (needing trust_remote_code),
# and its name suggests ~16B parameters, which may not fit on one device in default fp32.
# Confirm the loading options before running.
analyzer = MOEExpertLens("deepseek-ai/deepseek-moe-16b-base")

# Analyze text
text = "The quick brown fox"
results = analyzer.analyze_text(text)

# A layer without a router records no routed tokens. Plot the first layer that actually
# has routed data instead of assuming layer 0 is an MoE layer.
moe_layer_idx = next(
    (idx for idx in range(len(results["layers"]))
     if results["layers"][f"layer_{idx}"]["tokens"]),
    None,
)
assert moe_layer_idx is not None, "no layer recorded any routed tokens"

# Visualize expert interpretations for the chosen layer and token
visualize_expert_analysis(results, layer_idx=moe_layer_idx, token_position=1)