In [1]:
import torch
import numpy as np
import plotly.graph_objects as go
from collections import defaultdict
import json
import os
from typing import Dict, List, Tuple, Optional
# from tqdm.auto import tqdm
import datetime
import gc
from transformers import AutoModelForCausalLM, AutoTokenizer
import traceback

In [2]:
def get_device():
    """Return the best available torch device for this notebook.

    When CUDA is available, returns ``cuda:0`` and, as side effects, enables
    TF32 for matmuls and cuDNN, empties the CUDA caching allocator, and turns
    on cuDNN benchmarking. Otherwise returns the CPU device with no side effects.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")

    cuda_device = torch.device("cuda:0")
    # TF32 trades a little float32 precision for much faster kernels on Ampere+ GPUs
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Hand cached-but-unused blocks back to the driver
    torch.cuda.empty_cache()
    # Let cuDNN autotune its algorithm choice
    torch.backends.cudnn.benchmark = True
    return cuda_device

In [3]:
def load_model(model_name: str = 'deepseek-ai/deepseek-moe-16b-base',
               torch_dtype: torch.dtype = torch.float16):
    """
    Load a causal LM and its tokenizer, and move the model to the best device.

    Args:
        model_name: Hugging Face hub id or local path of the checkpoint.
            Defaults to DeepSeek-MoE-16B base.
        torch_dtype: dtype for the model weights (float16 by default).

    Returns:
        (model, tokenizer)

    Note: ``trust_remote_code=True`` executes Python code shipped with the
    checkpoint; only point ``model_name`` at repositories you trust.
    """
    device = get_device()
    model = AutoModelForCausalLM.from_pretrained(model_name,
                                                 trust_remote_code=True,
                                                 torch_dtype=torch_dtype)
    model = model.to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
    return model, tokenizer

model, tokenizer = load_model()

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [4]:
def prepare_text_chunks(file_path: str, chunk_size: int = 2048, tokenizer = None) -> List[str]:
    """
    Split a text file into chunks of at most ``chunk_size`` tokens (or words).

    The file is read line by line, so it never has to be held in memory as a
    whole.

    Args:
        file_path: path to a UTF-8 text file.
        chunk_size: maximum tokens per chunk when ``tokenizer`` is given,
            otherwise maximum whitespace-separated words per chunk. Must be > 0.
        tokenizer: optional tokenizer exposing
            ``encode(text, add_special_tokens=...)`` and ``decode(ids)``
            (e.g. a Hugging Face tokenizer).

    Returns:
        List of chunk strings. Empty if the file is empty, or if it cannot be
        read (the error is printed).

    Raises:
        ValueError: if ``chunk_size`` is not positive (it would loop forever).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    chunks: List[str] = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            if tokenizer is not None:
                pending_text = ""
                # Defined up-front so an empty file doesn't raise NameError below
                tokens: List[int] = []
                for line in f:
                    pending_text += line
                    # No BOS/EOS here: the model call adds its own special tokens,
                    # and they must neither leak into the chunk text nor count
                    # against the chunk budget.
                    tokens = tokenizer.encode(pending_text, add_special_tokens=False)

                    while len(tokens) >= chunk_size:
                        chunks.append(tokenizer.decode(tokens[:chunk_size]))
                        # Carry the leftover tokens (as text) into the next round
                        tokens = tokens[chunk_size:]
                        pending_text = tokenizer.decode(tokens)

                # Flush the final partial chunk
                if tokens:
                    chunks.append(tokenizer.decode(tokens))
            else:
                words: List[str] = []
                for line in f:
                    words.extend(line.split())
                    while len(words) >= chunk_size:
                        chunks.append(' '.join(words[:chunk_size]))
                        words = words[chunk_size:]
                if words:
                    chunks.append(' '.join(words))
    except Exception as e:
        print(f"Error reading file: {e}")
        return []

    return chunks

In [5]:
@torch.no_grad()
def get_expert_distribution(model, input_text: str, tokenizer, device: torch.device, 
                          num_shared_experts: int = 0, num_topk: int = 8) -> Dict[int, List[Tuple[str, int, float]]]:
    """
    Record which experts each MoE layer routes every real token of ``input_text`` to.

    ``outputs.hidden_states[i]`` is the residual stream *entering* decoder layer
    ``i``. The router gate (``layer.mlp.gate``) runs later, after attention, on
    the tensor passed to the MLP. Recomputing router logits from
    ``hidden_states`` therefore attributes tokens to the wrong experts.
    Instead, a forward pre-hook on each gate captures exactly the tensor the
    gate receives. Gate scores are then recomputed as
    softmax(x @ gate.weight.T), in float32.

    Args:
        model: causal LM exposing ``model.model.layers[i].mlp.gate`` with a
            ``weight`` of shape (n_routed_experts, hidden).
        input_text: text to run through the model (truncated to 2048 tokens).
        tokenizer: matching tokenizer.
        device: device to run on.
        num_shared_experts: shared (always-active) experts per layer. They are
            emitted first with ids ``0..num_shared_experts-1`` and weight 1.0.
        num_topk: routed experts to record per token.

    Returns:
        ``{layer_idx: [(token, expert_id, weight), ...]}`` for
        ``layer_idx`` in ``1..len(layers)-1``. Layer 0 is skipped. Routed expert
        ids are the router index offset by ``num_shared_experts``. Padding and
        ``<...>`` special tokens are skipped, and ``Ġ`` is stripped from token
        strings. Returns ``{}`` on any error (the traceback is printed).
    """
    hook_handles = []
    try:
        model = model.to(device)

        inputs = tokenizer(input_text,
                           return_tensors="pt",
                           padding=True,
                           truncation=True,
                           max_length=2048)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Layer 0 is skipped, as in the original analysis (presumably the dense layer)
        moe_layers = list(enumerate(model.model.layers[1:], 1))

        gate_inputs = {}

        def _make_capture(idx):
            def _capture(module, args):
                gate_inputs[idx] = args[0].detach()
            return _capture

        for layer_idx, layer in moe_layers:
            hook_handles.append(
                layer.mlp.gate.register_forward_pre_hook(_make_capture(layer_idx))
            )

        # Only the hooks' side effect is needed; skip the KV cache and extra outputs
        model(**inputs,
              output_attentions=False,
              output_hidden_states=False,
              use_cache=False,
              return_dict=True)

        # Token strings are identical for every layer: convert and filter once
        tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
        kept = [
            (pos, tok.replace('Ġ', ''))
            for pos, tok in enumerate(tokens)
            # skip padding and special tokens such as begin/end-of-sentence
            if tok != tokenizer.pad_token and not (tok.startswith('<') and tok.endswith('>'))
        ]

        layer_distributions = {}
        for layer_idx, layer in moe_layers:
            if layer_idx not in gate_inputs:
                raise RuntimeError(
                    f"router gate of layer {layer_idx} was not called during the forward pass"
                )
            gate_in = gate_inputs[layer_idx]
            # Flatten (batch=1, seq, hidden) -> (seq, hidden); also accepts 2-D input
            hidden = gate_in.reshape(-1, gate_in.shape[-1]).float()
            router_logits = torch.matmul(hidden, layer.mlp.gate.weight.t().float())
            router_probs = torch.softmax(router_logits, dim=-1)
            top_probs, top_indices = torch.topk(router_probs, k=num_topk, dim=-1)
            # One device->host copy per layer instead of one .item() per entry
            top_probs = top_probs.cpu().tolist()
            top_indices = top_indices.cpu().tolist()

            entries = []
            for pos, token in kept:
                # Shared experts process every token with weight 1.0
                entries.extend((token, shared_idx, 1.0) for shared_idx in range(num_shared_experts))
                # Routed experts: offset ids past the shared-expert range
                entries.extend(
                    (token, int(expert) + num_shared_experts, float(prob))
                    for expert, prob in zip(top_indices[pos], top_probs[pos])
                )
            layer_distributions[layer_idx] = entries

        return layer_distributions

    except Exception as e:
        print(f"Error in get_expert_distribution: {e}")
        traceback.print_exc()
        return {}
    finally:
        # Always detach the hooks so they don't accumulate across calls
        for handle in hook_handles:
            handle.remove()

In [6]:
def save_expert_stats(distributions: Dict[int, List[Tuple[str, int, float]]], 
                     domain: str,
                     output_dir: str = 'expert_stats'):
    """
    Write per-layer expert usage fractions to ``<output_dir>/<domain>_expert_stats.json``.

    For every layer with at least one record, each expert's value is the number
    of (token, expert, weight) records naming it, divided by the layer's total
    record count. Layer and expert ids are stringified for JSON. Layers with no
    records are omitted. Any error is printed rather than raised.
    """
    try:
        os.makedirs(output_dir, exist_ok=True)
        stats_file = os.path.join(output_dir, f'{domain}_expert_stats.json')

        layer_stats = {}
        for layer_key, records in distributions.items():
            n_records = len(records)
            if not n_records:
                continue

            tally = defaultdict(int)
            for _token, expert, _weight in records:
                tally[expert] += 1

            layer_stats[str(layer_key)] = {
                str(expert): hits / n_records for expert, hits in tally.items()
            }

        with open(stats_file, 'w') as handle:
            json.dump(layer_stats, handle, indent=2)

    except Exception as err:
        print(f"Error saving expert stats: {err}")

In [7]:
def plot_layer_distribution(token_counts, layer_id: int, domain: str, dark_mode: bool = True):
    """
    Build a Plotly bar chart of token counts per expert for one layer.

    Args:
        token_counts: mapping of layer id -> {expert id: token count}, as
            produced by ``analyze_token_distributions``.
        layer_id: layer whose counts are drawn.
        domain: label included in the chart title.
        dark_mode: use the ``plotly_dark`` template if True, ``plotly_white`` otherwise.

    Returns:
        A ``go.Figure`` with experts in ascending id order and a total-count
        annotation. Returns None if the layer has no data or plotting fails
        (the error is printed).
    """
    try:
        per_expert = token_counts.get(layer_id, {})
        if not per_expert:
            raise ValueError(f"No data found for layer {layer_id}")

        ordered_ids = sorted(per_expert)
        counts = [per_expert[eid] for eid in ordered_ids]
        grand_total = sum(counts)

        labels = []
        hover = []
        for eid, cnt in zip(ordered_ids, counts):
            pct = cnt / grand_total * 100
            labels.append(f"Expert {eid}")
            hover.append(f"Expert {eid}<br>Tokens: {cnt}<br>Percentage: {pct:.1f}%")

        bars = go.Bar(
            x=labels,
            y=counts,
            hovertext=hover,
            hoverinfo='text',
            marker_color='#636EFA'
        )
        fig = go.Figure(data=[bars])

        fig.update_layout(
            title=f'Token Distribution Across Experts for Layer {layer_id} ({domain})',
            xaxis_title='Expert',
            yaxis_title='Number of Tokens',
            width=1200,
            height=600,
            template='plotly_dark' if dark_mode else 'plotly_white',
            showlegend=False,
            xaxis={'tickangle': -45}
        )

        # Total sits in the top-right corner of the plotting area
        fig.add_annotation(
            text=f'Total Tokens: {grand_total}',
            xref='paper', yref='paper',
            x=1, y=1,
            xanchor='right', yanchor='top',
            showarrow=False
        )
        return fig

    except Exception as e:
        print(f"Error creating bar plot: {e}")
        return None

In [8]:
def analyze_token_distributions(expert_distributions):
    """
    Count how many (token, expert) assignments each expert received per layer.

    Args:
        expert_distributions: mapping of layer id -> list of
            (token, expert_id, probability) tuples.

    Returns:
        Mapping of layer id -> {expert_id: count}. Every input layer gets an
        entry (possibly empty). Only records whose expert_id is an ``int``
        are counted; any other id type is ignored.
    """
    token_counts = {}
    for layer_id, records in expert_distributions.items():
        per_expert = {}
        for _token, expert_id, _prob in records:
            if not isinstance(expert_id, int):
                continue
            per_expert[expert_id] = per_expert.get(expert_id, 0) + 1
        token_counts[layer_id] = per_expert
    return token_counts


In [9]:
def plot_token_heatmap(token_counts, domain, n_experts=64, n_layers=27):
    """
    Build a layers x experts heatmap of token counts.

    Args:
        token_counts: mapping of layer id -> {expert id: token count}. Layer
            ids are 1-based: layer ``k`` is drawn on row ``k - 1``.
        domain: domain name used in the heatmap title.
        n_experts: number of expert columns. Expert ids outside
            ``[0, n_experts)`` cannot be drawn.
        n_layers: number of layer rows. Layer ids outside ``[1, n_layers]``
            cannot be drawn.

    Returns:
        (fig, count_matrix), where count_matrix is a float ndarray of shape
        (n_layers, n_experts). If any counts fall outside the matrix, a
        warning with their total is printed.
    """
    count_matrix = np.zeros((n_layers, n_experts))
    dropped = 0  # total count that could not be placed in the matrix

    for layer_id, expert_counts in token_counts.items():
        row = layer_id - 1
        # Without this check layer 0 would index row -1 (silently the last row)
        # and layers beyond n_layers would raise IndexError.
        if not 0 <= row < n_layers:
            dropped += sum(expert_counts.values())
            continue
        for expert_id, count in expert_counts.items():
            if 0 <= expert_id < n_experts:
                count_matrix[row, expert_id] = count
            else:
                dropped += count

    if dropped:
        print(f"Warning: {int(dropped)} token assignments fall outside the "
              f"{n_layers}x{n_experts} heatmap and were not plotted")

    fig = go.Figure(data=go.Heatmap(
        z=count_matrix,
        x=[f'Expert {i}' for i in range(n_experts)],
        y=[f'Layer {i}' for i in range(1, n_layers+1)],
        colorscale='Viridis',
        hoverongaps=False,
        # Axis labels already carry the "Layer"/"Expert" prefix
        hovertemplate='%{y}<br>%{x}<br>Tokens: %{z:d}<extra></extra>'
    ))

    fig.update_layout(
        title=f'token distribution heatmap for {domain}',
        xaxis_title='Expert Index',
        yaxis_title='Layer',
        width=1200,
        height=800,
        template='plotly_dark'
    )

    return fig, count_matrix

In [10]:
def analyze_dataset(
    plot_type: str = 'heatmap',
    file_path: str = None,
    model = None,
    tokenizer = None,
    domain: str = 'test1',
    chunk_size: int = 2048,
    num_shared_experts: int = 2,  
    num_topk: int = 6,
    layer_id: int = 1,
    device: torch.device = None
    ):
    """
    Run the expert-routing analysis on one text file, then plot and save the results.

    Args:
        plot_type: 'heatmap' (all layers x experts) or 'bar' (a single layer).
        file_path: UTF-8 text file to analyze.
        model: MoE causal LM (see ``get_expert_distribution`` for the layout it needs).
        tokenizer: matching tokenizer.
        domain: label used in plot titles and output file names.
        chunk_size: tokens per chunk fed to the model. Lower it if you run out
            of memory. Note that ``get_expert_distribution`` truncates input
            at 2048 tokens.
        num_shared_experts: always-active shared experts per MoE layer.
        num_topk: routed experts recorded per token.
        layer_id: layer to plot when ``plot_type`` is 'bar'.
        device: torch device (or device string). Defaults to ``get_device()``.

    Returns:
        (fig, count_matrix, token_counts). count_matrix is None for bar plots.
        All three are None if any step fails (the traceback is printed).
        Outputs are written to ``plots-deepseek/``. PNG export is best-effort,
        since it needs the optional kaleido package.
    """
    try:
        if plot_type not in ('heatmap', 'bar'):
            raise ValueError("plot_type must be either 'heatmap' or 'bar'")

        device = get_device() if device is None else torch.device(device)
        print(f"Using device: {device}")
        model = model.to(device)

        # Honour the caller's chunk size (previously a hard-coded 1024 was used)
        chunks = prepare_text_chunks(file_path, chunk_size, tokenizer)
        if not chunks:
            raise ValueError("No text chunks generated")

        all_distributions = defaultdict(list)
        total_chunks = len(chunks)
        for i, chunk in enumerate(chunks):
            if device.type == 'cuda':
                # Free cached blocks between chunks to reduce OOM risk
                torch.cuda.empty_cache()
            print(f"Processing chunk {i+1}/{total_chunks}")

            chunk_dist = get_expert_distribution(
                model,
                chunk,
                tokenizer,
                device,
                num_shared_experts=num_shared_experts,
                num_topk=num_topk
            )
            for layer, layer_data in chunk_dist.items():
                all_distributions[layer].extend(
                    (t, e, float(p)) for t, e, p in layer_data
                )

        if not all_distributions:
            raise ValueError("No expert routing data collected")

        token_counts = analyze_token_distributions(all_distributions)

        print("Generating visualization...")
        if plot_type == 'heatmap':
            # Routed expert ids are offset by num_shared_experts, so the heatmap
            # needs n_routed + num_shared columns or the highest ids are dropped.
            # DeepSeek configs expose n_routed_experts; fall back to 64.
            n_routed = getattr(getattr(model, 'config', None), 'n_routed_experts', None) or 64
            n_layers = max(27, max(token_counts, default=0))
            fig, count_matrix = plot_token_heatmap(
                token_counts,
                n_experts=n_routed + num_shared_experts,
                n_layers=n_layers,
                domain=domain
            )
        else:
            fig = plot_layer_distribution(
                token_counts=token_counts,
                layer_id=layer_id,
                domain=domain
            )
            if fig is None:
                raise ValueError(f"Could not create bar plot for layer {layer_id}")
            count_matrix = None  # bar plots have no matrix

        out_dir = 'plots-deepseek'
        os.makedirs(out_dir, exist_ok=True)
        if plot_type == 'heatmap':
            np.save(os.path.join(out_dir, f'{domain}_count_matrix.npy'), count_matrix)
            stem = f'{domain}_heatmap'
        else:
            stem = f'{domain}_layer{layer_id}_bargraph'
        fig.write_html(os.path.join(out_dir, f'{stem}.html'))
        try:
            fig.write_image(os.path.join(out_dir, f'{stem}.png'))
        except Exception as img_err:
            # Static export needs kaleido; don't throw away the analysis over it
            print(f"Could not write PNG image ({img_err}); HTML was saved.")

        print("Analysis complete!")
        return fig, count_matrix, token_counts

    except Exception as e:
        print(f"Error in analyze_dataset: {e}")
        traceback.print_exc()
        return None, None, None


In [11]:
# Heatmap of expert usage across all MoE layers.
# DeepSeek-MoE-16B routes every token to 2 shared experts plus the top-6 of
# 64 routed experts (model.config.n_shared_experts / num_experts_per_tok).
fig, count_matrix, token_counts = analyze_dataset(
    plot_type='heatmap',  # or 'bar'
    file_path='dataset/test1.txt',
    model=model,
    tokenizer=tokenizer,
    domain='test1',
    chunk_size=1024,
    num_shared_experts=2,
    num_topk=6,
)

# analyze_dataset returns (None, None, None) on failure
if fig is not None:
    fig.show()

Using device: cpu
Processing chunk 1/1
Generating visualization...


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Analysis complete!


In [12]:
# Bar chart of expert usage in a single MoE layer.
# DeepSeek-MoE-16B routes every token to 2 shared experts plus the top-6 of
# 64 routed experts (model.config.n_shared_experts / num_experts_per_tok).
fig, count_matrix, token_counts = analyze_dataset(
    plot_type='bar',  # or 'heatmap'
    file_path='dataset/test1.txt',
    model=model,
    tokenizer=tokenizer,
    domain='test1',
    chunk_size=2048,
    num_shared_experts=2,
    num_topk=6,
    layer_id=1,  # only used for the bar plot
)

# analyze_dataset returns (None, None, None) on failure
if fig is not None:
    fig.show()

Using device: cpu
Processing chunk 1/1
Generating visualization...
Analysis complete!
