In [1]:
import torch
import numpy as np
import plotly.graph_objects as go
from collections import defaultdict
import json
import os
from typing import Dict, List, Tuple, Optional
from tqdm.auto import tqdm
import gc
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
def get_device():
    """Pick the torch device to run on: the first CUDA GPU when available, else CPU.

    Side effect: when CUDA is available, TF32 is switched on for both CUDA
    matmuls and cuDNN before the device is returned.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")

    # TF32 trades a little precision for throughput on GPUs that support it.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    return torch.device("cuda:0")

In [3]:
def load_model(model_name: str = 'deepseek-ai/deepseek-moe-16b-base'):
    """Load a causal-LM checkpoint and its tokenizer from the Hugging Face Hub.

    Args:
        model_name: Hub repository id (or local path) passed to both
            ``from_pretrained`` calls. Defaults to the DeepSeek-MoE 16B base
            checkpoint this notebook analyses.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # NOTE: trust_remote_code=True executes Python shipped in the model
    # repository; only point model_name at sources you trust.
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    return model, tokenizer

# Load the model and tokenizer once for the whole notebook, then put the model
# in evaluation mode. `model.eval()` is the cell's last expression, so the
# notebook displays the module tree it returns.
model, tokenizer = load_model()
model.eval()


Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

DeepseekForCausalLM(
  (model): DeepseekModel(
    (embed_tokens): Embedding(102400, 2048)
    (layers): ModuleList(
      (0): DeepseekDecoderLayer(
        (self_attn): DeepseekSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): DeepseekRotaryEmbedding()
        )
        (mlp): DeepseekMLP(
          (gate_proj): Linear(in_features=2048, out_features=10944, bias=False)
          (up_proj): Linear(in_features=2048, out_features=10944, bias=False)
          (down_proj): Linear(in_features=10944, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): DeepseekRMSNorm()
        (post_attention_layernorm): DeepseekRMSNorm()
      )
      (1-27): 27 x DeepseekDecod

In [4]:
def prepare_text_chunks(file_path: str, chunk_size: int = 2048, tokenizer = None) -> List[str]:
    """
    Split a UTF-8 text file into chunks of at most ``chunk_size`` units.

    With a tokenizer, the unit is a token: text is accumulated line by line,
    encoded, and every full run of ``chunk_size`` tokens is decoded back to a
    string chunk. Without a tokenizer, the unit is a whitespace-separated word
    and chunks are re-joined with single spaces.

    Args:
        file_path: Path of the text file to read.
        chunk_size: Maximum number of tokens (or words) per chunk; must be > 0.
        tokenizer: Optional object exposing ``encode(str) -> list`` and
            ``decode(list) -> str``.

    Returns:
        The list of chunk strings, in file order. The final chunk may be
        shorter than ``chunk_size``. Returns ``[]`` (after printing the error)
        if reading or tokenizing fails.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        # A non-positive size never shrinks the buffers below, so the
        # `while` loops would spin forever.
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    chunks = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            # Compare against None explicitly: objects that define __len__
            # would otherwise be judged by their length.
            if tokenizer is not None:
                # Process file line by line to avoid loading the entire file
                chunk_text = ""
                # Pre-bind so an empty file falls through to "no chunks"
                # instead of raising UnboundLocalError after the loop.
                tokens = []
                for line in f:
                    chunk_text += line
                    # NOTE(review): the pending text is re-encoded on every
                    # line, and if `encode` adds special tokens (e.g. BOS)
                    # they end up in the decoded chunks - confirm against the
                    # tokenizer in use.
                    tokens = tokenizer.encode(chunk_text)

                    while len(tokens) >= chunk_size:
                        # Emit one full chunk
                        chunks.append(tokenizer.decode(tokens[:chunk_size]))

                        # Carry the remainder over as the new pending text
                        tokens = tokens[chunk_size:]
                        chunk_text = tokenizer.decode(tokens)

                # Add remaining text if any
                if tokens:
                    chunks.append(tokenizer.decode(tokens))
            else:
                words = []
                for line in f:
                    words.extend(line.split())
                    while len(words) >= chunk_size:
                        chunks.append(' '.join(words[:chunk_size]))
                        words = words[chunk_size:]
                if words:
                    chunks.append(' '.join(words))
    except Exception as e:
        # Best-effort contract kept from the original: report and return no chunks.
        print(f"Error reading file: {e}")
        return []

    return chunks

In [5]:
@torch.no_grad()  # Disable gradient computation
def get_expert_distribution(model, input_text: str, tokenizer, device: torch.device) -> Dict[int, List[Tuple[str, int, float]]]:
    """
    Record, for each MoE layer, the top-ranked routed expert of every token.

    The text is tokenized (truncated to 2048 tokens) and run through the model
    once. A forward pre-hook on each layer's router captures the tensor the
    router actually receives, and the router weight is applied to it to obtain
    softmax routing probabilities.

    Assumes MoE blocks expose their router as ``layer.mlp.gate`` with a
    ``weight`` of shape (n_experts, hidden), as in DeepSeek-MoE's remote
    modeling code - confirm against the loaded model. Layers whose ``mlp`` has
    no such ``gate`` (dense layers) are skipped.

    Args:
        model: Causal LM whose decoder layers live in ``model.model.layers``.
        input_text: A single text sequence.
        tokenizer: Tokenizer callable returning tensors and providing
            ``convert_ids_to_tokens`` and ``pad_token``.
        device: Device the model and inputs are moved to.

    Returns:
        ``{layer_index: [(token, expert_index, probability), ...]}`` with one
        entry per non-pad token, in sequence order. Returns ``{}`` (after
        printing the error) if anything fails.
    """
    hook_handles = []
    try:
        # Move model to specified device if not already there
        model = model.to(device)

        # A single sequence needs no padding; requesting it would only fail on
        # tokenizers that define no pad token.
        inputs = tokenizer(input_text,
                           return_tensors="pt",
                           truncation=True,
                           max_length=2048)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Token strings are the same for every layer, so build them once.
        tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
        kept_positions = [i for i, tok in enumerate(tokens) if tok != tokenizer.pad_token]
        kept_labels = [tokens[i].replace('Ġ', '') for i in kept_positions]

        top1_by_layer = {}

        def _make_hook(layer_idx):
            def _record_top1(module, args):
                # args[0] is the hidden state handed to the router; flatten any
                # leading batch/sequence dims to (tokens, hidden).
                router_input = args[0]
                flat = router_input.reshape(-1, router_input.shape[-1])
                # Score in float32 regardless of the model's storage dtype.
                logits = torch.nn.functional.linear(flat.float(), module.weight.float())
                probs = torch.softmax(logits, dim=-1)
                top_probs, top_indices = torch.max(probs, dim=-1)
                # Keep only the per-token winners so no large tensors linger.
                top1_by_layer[layer_idx] = (top_indices.tolist(), top_probs.tolist())
            return _record_top1

        for layer_idx, layer in enumerate(model.model.layers):
            gate = getattr(layer.mlp, 'gate', None)
            if gate is None or not hasattr(gate, 'weight'):
                continue  # dense (non-MoE) layer
            hook_handles.append(gate.register_forward_pre_hook(_make_hook(layer_idx)))

        # Only the hooks' side effects are needed, so skip hidden-state and
        # KV-cache outputs to save memory.
        model(**inputs,
              output_attentions=False,
              output_hidden_states=False,
              return_dict=True,
              use_cache=False)

        layer_distributions = {}
        for layer_idx in sorted(top1_by_layer):
            expert_ids, expert_probs = top1_by_layer[layer_idx]
            # Batch size is 1, so flattened row i is token position i.
            layer_distributions[layer_idx] = [
                (label, expert_ids[pos], expert_probs[pos])
                for pos, label in zip(kept_positions, kept_labels)
            ]

        return layer_distributions

    except Exception as e:
        print(f"Error in get_expert_distribution: {e}")
        return {}

    finally:
        # Always detach the hooks so later forward passes are unaffected.
        for handle in hook_handles:
            handle.remove()

In [6]:
def save_expert_stats(distributions: Dict[int, List[Tuple[str, int, float]]],
                     domain: str,
                     output_dir: str = 'expert_stats'):
    """Write per-layer expert selection frequencies to ``<output_dir>/<domain>_expert_stats.json``.

    For each layer, the fraction of records that picked each expert is
    computed. Layers with no records are left out. The JSON maps
    ``str(layer) -> {str(expert): fraction}``. Any failure is printed rather
    than raised.
    """
    try:
        os.makedirs(output_dir, exist_ok=True)
        target = os.path.join(output_dir, f'{domain}_expert_stats.json')

        summary = {}
        for layer_key, records in distributions.items():
            n_records = len(records)
            if n_records == 0:
                continue

            # Tally how often each expert was chosen (first-seen order is kept).
            tally = {}
            for _token, chosen, _prob in records:
                tally[chosen] = tally.get(chosen, 0) + 1

            fractions = {}
            for chosen, hits in tally.items():
                fractions[str(chosen)] = hits / n_records
            summary[str(layer_key)] = fractions

        with open(target, 'w') as handle:
            json.dump(summary, handle, indent=2)

    except Exception as err:
        print(f"Error saving expert stats: {err}")

In [7]:
def plot_expert_heatmap(domain: str,
                       output_dir: str = 'expert_stats',
                       n_experts: int = 64,
                       n_layers: int = 27,
                       plot_dir: str = 'plots-deepseek'):
    """Build a layer-by-expert utilisation heatmap from saved routing statistics.

    Reads ``<output_dir>/<domain>_expert_stats.json`` and fills an
    ``(n_layers, n_experts)`` matrix from layer keys ``"1".."n_layers"`` and
    expert keys ``"0".."n_experts-1"``; missing entries count as 0. Colours use
    the matrix scaled by its maximum, while the hover text shows the unscaled
    usage fraction.

    Args:
        domain: Name used in the stats file name, figure title and output files.
        output_dir: Directory holding the stats JSON.
        n_experts: Number of expert columns.
        n_layers: Number of layer rows.
        plot_dir: Directory the HTML and PNG exports are written to; it is
            created if it does not exist.

    Returns:
        The plotly figure, or ``None`` (after printing the error) if the stats
        cannot be read or the figure cannot be built. A failed PNG export only
        prints a warning.
    """
    try:
        stats_path = os.path.join(output_dir, f'{domain}_expert_stats.json')
        with open(stats_path) as f:
            stats = json.load(f)

        usage = np.zeros((n_layers, n_experts))

        for layer in range(1, n_layers+1):
            layer_stats = stats.get(str(layer), {})
            for expert in range(n_experts):
                usage[layer-1, expert] = layer_stats.get(str(expert), 0)

        # Scale colours relative to the busiest cell, but keep the raw
        # fractions so the hover text reports real usage.
        max_prob = np.max(usage)
        scaled = usage / max_prob if max_prob > 0 else usage

        fig = go.Figure(data=go.Heatmap(
            z=scaled,
            customdata=usage,
            x=[f'Expert {i}' for i in range(n_experts)],
            y=[f'Layer {i}' for i in range(1, n_layers+1)],
            colorscale='Viridis',
            hoverongaps=False,
            # The y labels already read "Layer N", so no extra prefix here.
            hovertemplate='%{y}<br>%{x}<br>Usage: %{customdata:.2%}<extra></extra>'
        ))

        fig.update_layout(
            title=f'Expert Utilization Heatmap for {domain} Domain',
            xaxis_title='Expert Index',
            yaxis_title='Layer',
            width=1000,
            height=800,
            template='plotly_dark'  # Better for visualization
        )

        # Save the figure; make sure the target directory exists first.
        os.makedirs(plot_dir, exist_ok=True)
        fig.write_html(os.path.join(plot_dir, f'{domain}_expert_heatmap.html'))
        try:
            # Static export needs an optional image backend; its absence
            # should not discard an otherwise valid figure.
            fig.write_image(os.path.join(plot_dir, f'{domain}_expert_heatmap.png'))
        except Exception as e:
            print(f"Warning: could not export heatmap image: {e}")
        return fig

    except Exception as e:
        print(f"Error creating heatmap: {e}")
        return None

In [8]:
def analyze_dataset(file_path: str,
                   model,
                   tokenizer,
                   domain: str,
                   chunk_size: int = 2048,
                   output_dir: str = 'expert_stats',
                   batch_size: int = 1):
    """
    Run the full routing analysis for one text file.

    The file is split into chunks, expert routing is collected for every
    chunk, the per-layer statistics are saved, and the heatmap is built.

    Args:
        file_path: Text file to analyse.
        model: Model passed through to ``get_expert_distribution``.
        tokenizer: Tokenizer used for both chunking and routing analysis.
        domain: Label used for the saved statistics and plot names.
        chunk_size: Maximum tokens per chunk.
        output_dir: Directory for the statistics JSON.
        batch_size: Currently unused; kept for interface compatibility.

    Returns:
        The heatmap figure, or ``None`` (after printing the error) if any
        stage fails. Callers must check for ``None`` before using the result.
    """
    try:
        device = get_device()
        print(f"Using device: {device}")

        # Split into chunks. prepare_text_chunks takes
        # (file_path, chunk_size, tokenizer); passing chunk_size a second time
        # raised a TypeError that was swallowed below.
        chunks = prepare_text_chunks(file_path, chunk_size, tokenizer)
        if not chunks:
            raise ValueError("No text chunks generated")

        # Process chunks with progress bar
        all_distributions = defaultdict(list)
        for chunk in tqdm(chunks, desc="Processing chunks"):
            chunk_dist = get_expert_distribution(model, chunk, tokenizer, device)

            for layer, layer_data in chunk_dist.items():
                all_distributions[layer].extend(layer_data)

            # Clear memory between chunks: collect Python garbage first so
            # the freed tensors can actually be released from the CUDA cache.
            if device.type == 'cuda':
                gc.collect()
                torch.cuda.empty_cache()

        # Save statistics
        save_expert_stats(all_distributions, domain, output_dir)

        # Create visualization
        fig = plot_expert_heatmap(domain, output_dir)

        return fig

    except Exception as e:
        print(f"Error in analyze_dataset: {e}")
        return None

In [10]:
# Analyse a single domain and display its heatmap.
fig = analyze_dataset(
    file_path='dataset/test1.txt',
    model=model,
    tokenizer=tokenizer,
    domain='test1',
    chunk_size=2048
)
# analyze_dataset returns None when any stage fails (the error is printed),
# so guard the display instead of raising AttributeError on None.
if fig is not None:
    fig.show()

Using device: cpu


Processing chunks:   0%|          | 0/1 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [11]:
# Compare multiple domains
domains = ['test1', 'test2']
figs = []
for domain in domains:
    fig = analyze_dataset(
        file_path=f'dataset/{domain}.txt',
        model=model,
        tokenizer=tokenizer,
        domain=domain
    )
    # Keep one entry per domain (None marks a failed analysis) so `figs`
    # stays aligned with `domains`.
    figs.append(fig)
    if fig is not None:
        fig.show()

Using device: cpu


Processing chunks:   0%|          | 0/1 [00:00<?, ?it/s]

Using device: cpu


Processing chunks:   0%|          | 0/1 [00:00<?, ?it/s]