In [2]:
import torch
import numpy as np
import plotly.graph_objects as go
from collections import defaultdict
import json
import os
from typing import Dict, List, Tuple, Optional
# from tqdm.auto import tqdm
import datetime
import gc
from transformers import AutoModelForCausalLM, AutoTokenizer
import traceback

In [3]:
def get_device():
    """Return the torch device to run on.

    Returns ``cuda:0`` when CUDA is available, otherwise ``cpu``. On the CUDA
    path this also flips a few global torch backend switches (TF32 matmul,
    TF32 cuDNN, cuDNN autotuning) and releases cached allocator memory.
    """
    # No GPU: nothing to configure, just hand back the CPU device.
    if not torch.cuda.is_available():
        return torch.device("cpu")

    backends = torch.backends
    # TF32 trades a little precision for speed on GPUs that support it.
    backends.cuda.matmul.allow_tf32 = True
    backends.cudnn.allow_tf32 = True
    # Drop cached-but-unused allocator blocks before the heavy work starts.
    torch.cuda.empty_cache()
    # Let cuDNN benchmark kernels and keep the fastest one.
    backends.cudnn.benchmark = True
    return torch.device("cuda:0")

In [4]:
# Model loading is currently disabled; un-comment the block below (and the
# device/eval lines further down) to run a fresh routing computation.
# model = AutoModelForCausalLM.from_pretrained('deepseek-ai/deepseek-moe-16b-base',
#                                             trust_remote_code=True,
#                                             torch_dtype=torch.float16)
# NOTE: with `model = None`, get_expert_distribution() fails at `model.to(device)`
# and returns {} for every chunk, so only runs that hit an existing cache file
# in analyze_dataset() can produce a plot.
model = None
# The tokenizer is needed even for cached runs: the plotting helpers use it to
# count the tokens in dataset/<domain>.txt.
tokenizer = AutoTokenizer.from_pretrained('deepseek-ai/deepseek-moe-16b-base',
                                            trust_remote_code=True)

# device = get_device()
# model = model.to(device)
# model.eval()

In [5]:
def prepare_text_chunks(file_path: str, chunk_size: int = 2048, tokenizer = None) -> List[str]:
    """
    Split a UTF-8 text file into consecutive chunks of at most ``chunk_size`` units.

    With a ``tokenizer`` the unit is a token: lines are accumulated until the
    accumulated text encodes to at least ``chunk_size`` tokens, the first
    ``chunk_size`` tokens are decoded into a chunk, and the decoded remainder
    is carried over. Without a tokenizer the unit is a whitespace-separated
    word and chunks are re-joined with single spaces.

    Args:
        file_path: Path of the text file to read.
        chunk_size: Maximum number of tokens (or words) per chunk; must be > 0.
        tokenizer: Optional object exposing ``encode(str) -> list`` and
            ``decode(list) -> str``.

    Returns:
        The chunks in file order. An empty list is returned (and the error is
        printed) if the file cannot be read or the tokenizer raises.

    Raises:
        ValueError: If ``chunk_size`` is not positive. A non-positive size used
            to spin forever in the word-based path.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    chunks = []
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            if tokenizer is not None:
                # Text not yet emitted as a chunk: the carried-over remainder
                # (if any) followed by the lines read since the last flush.
                buffer = []

                for line in f:
                    buffer.append(line)
                    # The whole pending text is re-encoded so token boundaries
                    # across line breaks match what the tokenizer would produce.
                    tokens = tokenizer.encode("".join(buffer))

                    if len(tokens) >= chunk_size:
                        chunks.append(tokenizer.decode(tokens[:chunk_size]))
                        # Carry the overflow into the next chunk. An empty
                        # remainder must not be kept, otherwise it would be
                        # emitted as a bogus "" chunk at end of file.
                        leftover = tokenizer.decode(tokens[chunk_size:])
                        buffer = [leftover] if leftover else []

                # Flush whatever is left; it may still span several chunks if
                # the last line(s) were very long.
                final_text = "".join(buffer)
                if final_text:
                    final_tokens = tokenizer.encode(final_text)
                    if len(final_tokens) > chunk_size:
                        for start in range(0, len(final_tokens), chunk_size):
                            chunks.append(tokenizer.decode(final_tokens[start:start + chunk_size]))
                    else:
                        chunks.append(final_text)
            else:
                words = []
                for line in f:
                    words.extend(line.split())
                    while len(words) >= chunk_size:
                        chunks.append(' '.join(words[:chunk_size]))
                        words = words[chunk_size:]
                if words:
                    chunks.append(' '.join(words))

        print(f"Created {len(chunks)} chunks from the input file")
        return chunks

    except Exception as e:
        # Deliberately best-effort: callers treat an empty list as "no data".
        print(f"Error reading file: {e}")
        return []

In [6]:
@torch.no_grad()
def get_expert_distribution(model, input_text: str, tokenizer, device: torch.device,
                          num_shared_experts: int = 0, num_topk: int = 8) -> Dict[int, List[Tuple[str, str, float]]]:
    """
    Compute, per MoE layer, which experts each token of ``input_text`` is routed to.

    The text is tokenized (truncated to 2048 tokens), run through ``model``
    once with hidden states enabled, and for every layer in
    ``model.model.layers[1:]`` the layer's ``mlp.gate.weight`` is applied to
    ``hidden_states[layer_idx]`` to obtain softmax routing probabilities.

    Args:
        model: Causal LM exposing ``model.layers[i].mlp.gate.weight``.
        input_text: Text to analyze.
        tokenizer: Tokenizer matching ``model``.
        device: Device the model and inputs are moved to.
        num_shared_experts: Number of always-on experts to record per token.
        num_topk: Number of routed experts to record per token (capped at the
            number of experts the gate scores).

    Returns:
        ``{layer_idx: [(token, expert_label, prob), ...]}`` where
        ``expert_label`` is ``"shared_<i>"`` (prob fixed at 1.0) or
        ``"routed_<i>"``. Pad tokens and tokens of the form ``<...>`` are
        skipped. On any error the message is printed and ``{}`` is returned.
    """
    try:
        model = model.to(device)

        inputs = tokenizer(input_text,
                         return_tensors="pt",
                         padding=True,
                         truncation=True,
                         max_length=2048)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        outputs = model(**inputs,
                       output_attentions=False,
                       output_hidden_states=True,
                       return_dict=True)
        hidden_states = outputs.hidden_states

        # The token strings do not depend on the layer, so resolve and filter
        # them once instead of once per layer.
        tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
        kept_tokens = [
            (position, token.replace('Ġ', ''))
            for position, token in enumerate(tokens)
            if token != tokenizer.pad_token
            and not (token.startswith('<') and token.endswith('>'))
        ]
        shared_labels = [f"shared_{shared_idx}" for shared_idx in range(num_shared_experts)]

        layer_distributions = {}

        # Layer 0 is skipped: the loop treats every layer from index 1 on as MoE.
        for layer_idx, layer in enumerate(model.model.layers[1:], 1):
            # NOTE(review): hidden_states[layer_idx] is the tensor entering this
            # decoder layer. The model's own gate presumably runs after the
            # attention block and a layer norm, so these probabilities may only
            # approximate the real routing decisions - confirm this is intended.
            router_logits = torch.matmul(
                hidden_states[layer_idx].float(),
                layer.mlp.gate.weight.t().float()
            )

            # Top-k over routed experts only; shared experts are not scored.
            router_probs = torch.softmax(router_logits, dim=-1)
            top_probs, top_indices = torch.topk(
                router_probs, k=min(num_topk, router_probs.size(-1)), dim=-1
            )

            # One bulk transfer per layer instead of a .item() call per value.
            prob_rows = top_probs[0].tolist()
            index_rows = top_indices[0].tolist()

            expert_distributions = []
            for position, clean_token in kept_tokens:
                for label in shared_labels:
                    expert_distributions.append((clean_token, label, 1.0))
                for expert, prob in zip(index_rows[position], prob_rows[position]):
                    expert_distributions.append((clean_token, f"routed_{expert}", prob))

            layer_distributions[layer_idx] = expert_distributions

        return layer_distributions

    except Exception as e:
        print(f"Error in get_expert_distribution: {e}")
        return {}

In [7]:
def save_expert_stats(distributions: Dict[int, List[Tuple[str, int, float]]], 
                     domain: str,
                     output_dir: str = 'expert_stats'):
    """Write per-layer expert usage frequencies to ``<output_dir>/<domain>_expert_stats.json``.

    For every layer with at least one entry, each expert's share is the number
    of ``(token, expert, prob)`` entries naming it divided by the layer's total
    entry count. Layer and expert keys are stringified. Any failure is printed
    rather than raised.
    """
    try:
        os.makedirs(output_dir, exist_ok=True)
        target_file = os.path.join(output_dir, f'{domain}_expert_stats.json')

        per_layer_shares = {}
        for layer_key, entries in distributions.items():
            entry_total = len(entries)
            # Layers without entries are left out of the report entirely.
            if not entry_total:
                continue

            hits_by_expert = defaultdict(int)
            for _token, expert_key, _prob in entries:
                hits_by_expert[expert_key] += 1

            shares = {}
            for expert_key, hits in hits_by_expert.items():
                shares[str(expert_key)] = hits / entry_total
            per_layer_shares[str(layer_key)] = shares

        with open(target_file, 'w') as handle:
            json.dump(per_layer_shares, handle, indent=2)

    except Exception as e:
        print(f"Error saving expert stats: {e}")

In [8]:
def analyze_token_distributions(expert_distributions):
    """
    Tally, per layer, how many entries name each shared and each routed expert.

    Input maps layer id -> list of ``(token, expert_label, prob)`` where the
    label looks like ``"shared_<n>"`` or ``"routed_<n>"``. Output maps layer id
    -> ``{'shared': {n: count}, 'routed': {n: count}}``; any label that does
    not start with ``"shared_"`` is counted as routed.
    """
    token_counts = {}

    for layer_id, entries in expert_distributions.items():
        tallies = {'shared': defaultdict(int), 'routed': defaultdict(int)}

        for _token, label, _prob in entries:
            bucket = 'shared' if label.startswith('shared_') else 'routed'
            # The expert number is the second underscore-separated field.
            tallies[bucket][int(label.split('_')[1])] += 1

        token_counts[layer_id] = {bucket: dict(tally) for bucket, tally in tallies.items()}

    return token_counts

In [9]:
# NOTE: this name is defined again twice further down in the notebook; after a
# top-to-bottom run the last definition wins and this version is shadowed.
# NOTE: `tokenizer=tokenizer` captures the module-level `tokenizer` at the moment
# this cell is executed, so the tokenizer cell must have been run first.
def plot_layer_distribution(token_counts, layer_id: int, domain: str, dark_mode: bool = True, tokenizer=tokenizer):
    """
    Create a bar graph showing token distribution across experts for a specific layer.

    Args:
        token_counts: Mapping of layer id -> {expert id: count}; only the entry
            for ``layer_id`` is used.
        layer_id: Layer whose counts are plotted.
        domain: Dataset name; ``dataset/<domain>.txt`` is read and tokenized to
            obtain the token total used for the y-axis range and annotation.
        dark_mode: Selects the 'plotly_dark' template, else 'plotly_white'.
        tokenizer: Callable returning a mapping with an 'input_ids' entry.

    Returns:
        The plotly figure, or None if anything fails (the error is printed).
    """
    try:
        # Get expert counts for specified layer
        layer_counts = token_counts.get(layer_id, {})
        if not layer_counts:
            raise ValueError(f"No data found for layer {layer_id}")
            
        # Create lists for x and y values
        expert_ids = sorted(layer_counts.keys())
        token_counts_list = [layer_counts[expert_id] for expert_id in expert_ids]
        
        # Create labels for experts
        expert_labels = [f"{expert_id}" for expert_id in expert_ids]
        
        # Calculate totals
        total_routing_decisions = sum(token_counts_list)
        
        # Tokenize the whole dataset file just to count its tokens
        file_path = f"dataset/{domain}.txt"
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        tokens = tokenizer(text)
        actual_tokens = len(tokens['input_ids'])

        # Create hover text
        # (a zero total raises ZeroDivisionError here, which the except below turns into None)
        hover_text = [f"Expert {expert_id}<br>Tokens: {count}<br>Percentage: {(count/total_routing_decisions*100):.1f}%" 
                     for expert_id, count in zip(expert_ids, token_counts_list)]

        # Create bar graph
        fig = go.Figure(data=[
            go.Bar(
                x=expert_labels,
                y=token_counts_list,
                hovertext=hover_text,
                hoverinfo='text',
                marker_color='#636EFA'
            )
        ])
        
        # Update layout
        template = 'plotly_dark' if dark_mode else 'plotly_white'
        fig.update_layout(
            title=f'Token Distribution Across Experts in Layer {layer_id} ({domain})',
            xaxis_title='Expert',
            yaxis_title='Number of Routing Decisions',
            width=1200,
            height=600,
            template=template,
            showlegend=False,
            xaxis={'tickangle': -45},
            yaxis=dict(range=[0, actual_tokens])  # Set max on y-axis to actual tokens
        )
        
        # Add annotations for both total routing decisions and actual tokens
        fig.add_annotation(
            text=f'Total Routing Decisions: {total_routing_decisions:,}<br>Actual Tokens: {actual_tokens:,}',
            xref='paper', yref='paper',
            x=1, y=1,
            xanchor='right', yanchor='top',
            showarrow=False,
            font=dict(size=12)
        )
        
        return fig
        
    except Exception as e:
        # Best-effort: report the problem and let the caller deal with None
        print(f"Error creating bar plot: {e}")
        return None

In [19]:
def plot_token_heatmap(token_counts, domain, n_experts=64, n_layers=27, tokenizer=None):
    """
    Create a heatmap of routing percentages per expert per layer (routed experts only).

    Args:
        token_counts: Mapping of 1-based layer id -> {expert index: count}.
        domain: Dataset name, used in the title and to locate
            ``dataset/<domain>.txt`` for the token total.
        n_experts: Number of columns (expert indices ``0..n_experts-1``).
        n_layers: Number of rows (layers ``1..n_layers``).
        tokenizer: Optional callable returning a mapping with an 'input_ids'
            entry. When omitted, the "Actual Tokens" line of the annotation is
            left out instead of failing.

    Returns:
        ``(fig, count_matrix)`` where ``count_matrix`` has shape
        ``(n_layers, n_experts)`` and holds the raw counts.
    """
    count_matrix = np.zeros((n_layers, n_experts))
    percentage_matrix = np.zeros((n_layers, n_experts))

    # Entries outside the grid are ignored rather than indexed: a layer id of 0
    # would otherwise wrap around to the last row via index -1.
    ignored_entries = 0
    for layer_id, routed_counts in token_counts.items():
        row = layer_id - 1
        if not 0 <= row < n_layers:
            ignored_entries += len(routed_counts)
            continue

        for expert_id, count in routed_counts.items():
            if 0 <= expert_id < n_experts:
                count_matrix[row, expert_id] = count
            else:
                ignored_entries += 1

        # Percentages are normalised within the layer.
        layer_total = count_matrix[row].sum()
        if layer_total > 0:
            percentage_matrix[row] = (count_matrix[row] / layer_total) * 100

    if ignored_entries:
        print(f"Warning: ignored {ignored_entries} count entries outside the "
              f"{n_layers}x{n_experts} layer/expert grid")

    total_routing_decisions = int(count_matrix.sum())

    summary_text = f'Total Routing Decisions: {total_routing_decisions:,}'
    if tokenizer is not None:
        # Tokenize the whole dataset file once to report its size.
        file_path = f"dataset/{domain}.txt"
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        actual_tokens = len(tokenizer(text)['input_ids'])
        summary_text += f'<br>Actual Tokens: {actual_tokens:,}'

    fig = go.Figure(data=go.Heatmap(
        z=percentage_matrix,
        x=[f'{i}' for i in range(n_experts)],
        y=[f'{i+1}' for i in range(n_layers)],
        colorscale='Viridis',
        hoverongaps=False,
        hovertemplate='Layer %{y}<br>Expert %{x}<br>' +
                      'Percentage: %{z:.1f}%<br>' +
                      'Tokens: %{customdata:d}<extra></extra>',
        customdata=count_matrix.astype(int),
        # The title must live on the trace's own colorbar: layout.coloraxis is
        # only honoured by traces that opt into a shared coloraxis.
        colorbar=dict(title=dict(text="%age"))
    ))

    fig.update_layout(
        title=f'Token Distribution Heatmap for {domain} (Percentages per Layer)',
        xaxis_title=f'Expert Index (0-{n_experts - 1})',
        yaxis_title='Layer',
        width=1200,
        height=800,
        template='plotly_dark',
        margin=dict(t=100)
    )

    fig.add_annotation(
        text=summary_text,
        xref='paper', yref='paper',
        x=1.0,
        y=1.1,
        xanchor='right',
        yanchor='top',
        showarrow=False,
        font=dict(size=12)
    )

    return fig, count_matrix

In [11]:
def plot_expert_distribution(domain, device='cpu'):
    """Render a layer-by-expert percentage heatmap from a saved count matrix.

    The matrix is read from ``plots-deepseek/<domain>_count_matrix.npy`` and
    each row (layer) is normalised to percentages of that row's total.

    Args:
        domain: Dataset name used to build the file path and the plot title.
        device: Accepted for signature compatibility; not used by this function.

    Returns:
        ``(fig, count_matrix)``: the plotly figure and the raw counts as loaded.
    """
    matrix_path = f'plots-deepseek/{domain}_count_matrix.npy'
    count_matrix = np.load(matrix_path)

    # Row-wise normalisation; rows summing to zero stay at 0%.
    percentage_matrix = np.zeros_like(count_matrix, dtype=float)
    n_rows = count_matrix.shape[0]
    for row_idx in range(n_rows):
        row_sum = count_matrix[row_idx].sum()
        if row_sum > 0:
            percentage_matrix[row_idx] = (count_matrix[row_idx] / row_sum) * 100

    # Both axes are labelled starting from 1.
    expert_labels = [f'{col + 1}' for col in range(percentage_matrix.shape[1])]
    layer_labels = [f'{row + 1}' for row in range(percentage_matrix.shape[0])]

    heatmap = go.Heatmap(
        z=percentage_matrix,
        x=expert_labels,
        y=layer_labels,
        colorscale='Viridis',
        hoverongaps=False,
        hovertemplate='Layer %{y}<br>Expert %{x}<br>Percentage: %{z:.1f}%<extra></extra>'
    )
    fig = go.Figure(data=heatmap)

    fig.update_layout(
        title=f'Percentage Distribution of Tokens Across Experts and Layers ({domain})',
        xaxis_title='Expert Index',
        yaxis_title='Layer',
        width=1200,
        height=800,
        template='plotly_dark'
    )

    return fig, count_matrix
# NOTE: this redefines plot_layer_distribution from an earlier cell and is itself
# redefined in the next cell, so after a top-to-bottom run this version is shadowed.
def plot_layer_distribution(token_counts, layer_id: int, domain: str, dark_mode: bool = True, tokenizer=None):
    """Create a bar graph showing percentage distribution of tokens across experts for a specific layer.

    Args:
        token_counts: Mapping of layer id -> {expert id: count}; only the entry
            for ``layer_id`` is used.
        layer_id: Layer whose counts are plotted.
        domain: Dataset name; ``dataset/<domain>.txt`` is read and tokenized to
            obtain the token total shown in the annotation.
        dark_mode: Selects the 'plotly_dark' template, else 'plotly_white'.
        tokenizer: Callable returning a mapping with an 'input_ids' entry. The
            default of None fails when called, which the except below reports.

    Returns:
        The plotly figure, or None if anything fails (the error is printed).
    """
    try:
        # Get expert counts for specified layer
        layer_counts = token_counts.get(layer_id, {})
        if not layer_counts:
            raise ValueError(f"No data found for layer {layer_id}")
            
        # Create lists for x and y values
        expert_ids = sorted(layer_counts.keys())
        token_counts_list = [layer_counts[expert_id] for expert_id in expert_ids]
        
        # Calculate percentages
        # (a zero total raises ZeroDivisionError, which the except below turns into None)
        total_routing_decisions = sum(token_counts_list)
        percentages = [(count/total_routing_decisions * 100) for count in token_counts_list]
        
        # Create labels for experts
        expert_labels = [f"{expert_id}" for expert_id in expert_ids]

        # Get actual tokens for reference
        file_path = f"dataset/{domain}.txt"
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()
        tokens = tokenizer(text)
        actual_tokens = len(tokens['input_ids'])

        # Create hover text with both counts and percentages
        hover_text = [f"Expert {expert_id}<br>Tokens: {count}<br>Percentage: {percentage:.1f}%" 
                     for expert_id, count, percentage in zip(expert_ids, token_counts_list, percentages)]

        # Create bar graph
        fig = go.Figure(data=[
            go.Bar(
                x=expert_labels,
                y=percentages,  # Now plotting percentages instead of raw counts
                hovertext=hover_text,
                hoverinfo='text',
                marker_color='#636EFA'
            )
        ])
        
        # Update layout
        template = 'plotly_dark' if dark_mode else 'plotly_white'
        fig.update_layout(
            title=f'Token Distribution Across Experts in Layer {layer_id} ({domain})',
            xaxis_title='Expert',
            yaxis_title='Percentage of Total Tokens (%)',
            width=1200,
            height=600,
            template=template,
            showlegend=False,
            xaxis={'tickangle': -45},
            yaxis=dict(range=[0, 100])  # Set y-axis range to 0-100%
        )
        
        # Add annotations for totals
        fig.add_annotation(
            text=f'Total Routing Decisions: {total_routing_decisions:,}<br>Actual Tokens: {actual_tokens:,}',
            xref='paper', yref='paper',
            x=1, y=1,
            xanchor='right', yanchor='top',
            showarrow=False,
            font=dict(size=12)
        )
        
        return fig
        
    except Exception as e:
        # Best-effort: report the problem and let the caller deal with None
        print(f"Error creating bar plot: {e}")
        return None

In [12]:
def plot_layer_distribution(token_counts, layer_id: int, domain: str, dark_mode: bool = True, tokenizer=None,
                            n_experts: int = 64):
    """Create a bar graph of the percentage of routing decisions per routed expert for one layer.

    Args:
        token_counts: ``{expert index: count}`` for the layer being plotted
            (already narrowed to routed experts of that layer).
        layer_id: Layer number, used only for the title and error messages.
        domain: Dataset name, used in the title and to locate
            ``dataset/<domain>.txt`` for the token total.
        dark_mode: Selects the 'plotly_dark' template, else 'plotly_white'.
        tokenizer: Optional callable returning a mapping with an 'input_ids'
            entry. When omitted, the "Actual Tokens" line of the annotation is
            left out instead of making the whole plot fail.
        n_experts: Number of routed experts to show (indices ``0..n_experts-1``);
            experts missing from ``token_counts`` are plotted as 0.

    Returns:
        The plotly figure, or None if anything fails (the error and traceback
        are printed).
    """
    try:
        if not token_counts:
            raise ValueError(f"No data found for layer {layer_id}")

        # Dense list over every expert index so absent experts show up as 0.
        token_counts_list = [token_counts.get(i, 0) for i in range(n_experts)]

        total_routing_decisions = sum(token_counts_list)
        if total_routing_decisions == 0:
            raise ValueError(f"No routing decisions found for layer {layer_id}")
        percentages = [(count/total_routing_decisions * 100) for count in token_counts_list]

        expert_labels = [f"{i}" for i in range(n_experts)]

        summary_text = f'Total Routing Decisions: {total_routing_decisions:,}'
        if tokenizer is not None:
            # Tokenize the whole dataset file once to report its size.
            file_path = f"dataset/{domain}.txt"
            with open(file_path, 'r', encoding='utf-8') as f:
                text = f.read()
            actual_tokens = len(tokenizer(text)['input_ids'])
            summary_text += f'<br>Actual Tokens: {actual_tokens:,}'

        hover_text = [f"Expert {expert_id}<br>Tokens: {count}<br>Percentage: {percentage:.1f}%"
                     for expert_id, count, percentage in zip(expert_labels, token_counts_list, percentages)]

        fig = go.Figure(data=[
            go.Bar(
                x=expert_labels,
                y=percentages,
                hovertext=hover_text,
                hoverinfo='text',
                marker_color='#636EFA'
            )
        ])

        template = 'plotly_dark' if dark_mode else 'plotly_white'
        fig.update_layout(
            title=f'Token Distribution Across Experts in Layer {layer_id} ({domain})',
            xaxis_title=f'Expert Index (0-{n_experts - 1})',
            yaxis_title='Percentage of Total Tokens (%)',
            width=1200,
            height=600,
            template=template,
            showlegend=False,
            xaxis={'tickangle': -45},
            yaxis=dict(range=[0, 100])  # Percent scale
        )

        fig.add_annotation(
            text=summary_text,
            xref='paper', yref='paper',
            x=1, y=1,
            xanchor='right', yanchor='top',
            showarrow=False,
            font=dict(size=12)
        )

        return fig

    except Exception as e:
        print(f"Error creating bar plot: {e}")
        traceback.print_exc()  # Full traceback makes plotting failures diagnosable
        return None

In [21]:
# Grid size of the cached count matrix: 27 MoE layers x 64 routed experts.
_N_MOE_LAYERS = 27
_N_ROUTED_EXPERTS = 64


def _routed_counts_from_matrix(routed_matrix):
    """Rebuild {1-based layer: {expert: count}} from a count matrix, omitting zero cells."""
    return {
        row + 1: {
            col: int(routed_matrix[row, col])
            for col in range(routed_matrix.shape[1])
            if routed_matrix[row, col] > 0
        }
        for row in range(routed_matrix.shape[0])
    }


def _render_routing_plot(plot_type, routed_counts, layer_id, domain, tokenizer):
    """Build the all-layer heatmap, or the bar chart for `layer_id` when plot_type is not 'heatmap'."""
    if plot_type == 'heatmap':
        fig, _ = plot_token_heatmap(
            token_counts=routed_counts,
            domain=domain,
            n_experts=_N_ROUTED_EXPERTS,
            n_layers=_N_MOE_LAYERS,
            tokenizer=tokenizer
        )
        return fig
    return plot_layer_distribution(
        token_counts=routed_counts[layer_id],
        layer_id=layer_id,
        domain=domain,
        tokenizer=tokenizer
    )


def analyze_dataset(
    plot_type: str = 'heatmap',
    file_path: str = None,
    model = None,
    tokenizer = None,
    domain: str = 'test1',
    chunk_size: int = 2048,
    num_shared_experts: int = 2,  
    num_topk: int = 6,
    layer_id: int = 1,
    device: torch.device = None,
    force_recompute: bool = False
    ):
    """
    Compute (or load from cache) routed-expert usage for a dataset and plot it.

    Results are cached in
    ``plots-deepseek/<domain>_s<num_shared_experts>_k<num_topk>_count_matrix.npy``.
    When that file exists and ``force_recompute`` is False it is loaded and
    only the plot is rebuilt; otherwise the file at ``file_path`` is chunked,
    every chunk is routed through ``model``, and the resulting count matrix is
    saved.

    Args:
        plot_type: 'heatmap' for the all-layer heatmap; anything else produces
            the bar chart for ``layer_id``.
        file_path: Text file to analyze (only read on a fresh computation).
        model: Model passed to get_expert_distribution (fresh computation only).
        tokenizer: Tokenizer used for chunking, routing and the plots.
        domain: Dataset name used for cache/plot file names and titles.
        chunk_size: Tokens per chunk.
        num_shared_experts: Shared experts recorded per token.
        num_topk: Routed experts recorded per token.
        layer_id: Layer shown by the bar chart.
        device: Device to run on; chosen by get_device() when None.
        force_recompute: Ignore an existing cache file.

    Returns:
        ``(fig, routed_matrix, routed_counts)``. On any error the traceback is
        printed and ``(None, None, None)`` is returned.
    """
    try:
        # Cache key includes the routing configuration so different settings
        # never overwrite each other.
        cache_name = f'{domain}_s{num_shared_experts}_k{num_topk}'
        npy_path = f'plots-deepseek/{cache_name}_count_matrix.npy'

        if os.path.exists(npy_path) and not force_recompute:
            print(f"Loading existing results for {domain} with {num_shared_experts} shared experts and top-k={num_topk}")
            # SECURITY NOTE: the cache is a pickled dict, so loading it executes
            # pickle. Only load cache files produced by this notebook.
            data = np.load(npy_path, allow_pickle=True).item()
            routed_matrix = data['routed']

            routed_counts = _routed_counts_from_matrix(routed_matrix)
            fig = _render_routing_plot(plot_type, routed_counts, layer_id, domain, tokenizer)
            return fig, routed_matrix, routed_counts

        print(f"Computing new results for {domain} with {num_shared_experts} shared experts and top-k={num_topk}")

        if device is None:
            device = get_device()
        print(f"Using device: {device}")

        chunks = prepare_text_chunks(file_path, chunk_size, tokenizer)
        if not chunks:
            raise ValueError("No text chunks generated")

        all_distributions = defaultdict(list)
        total_chunks = len(chunks)

        for i, chunk in enumerate(chunks):
            torch.cuda.empty_cache()
            print(f"Processing chunk {i+1}/{total_chunks}")

            chunk_dist = get_expert_distribution(
                model,
                chunk,
                tokenizer,
                device,
                num_shared_experts=num_shared_experts,
                num_topk=num_topk
            )

            for layer, layer_data in chunk_dist.items():
                all_distributions[layer].extend(layer_data)

        # get_expert_distribution reports failures by returning {}. If every
        # chunk failed (e.g. model is None) there is nothing to count, and an
        # all-zero matrix must not be written to the cache.
        if not any(all_distributions.values()):
            raise ValueError("No expert routing data was collected; nothing to cache or plot")

        token_counts = analyze_token_distributions(all_distributions)

        # Needed by every plot type and by the return value, so build it
        # before branching.
        routed_counts = {
            layer: layer_data['routed'] for layer, layer_data in token_counts.items()
        }

        routed_matrix = np.zeros((_N_MOE_LAYERS, _N_ROUTED_EXPERTS))
        for row in range(_N_MOE_LAYERS):
            for expert, count in routed_counts.get(row + 1, {}).items():
                if expert < _N_ROUTED_EXPERTS:
                    routed_matrix[row, expert] = count

        os.makedirs('plots-deepseek', exist_ok=True)
        np.save(npy_path, {'routed': routed_matrix})

        fig = _render_routing_plot(plot_type, routed_counts, layer_id, domain, tokenizer)

        # Only the heatmap is exported. Export is best-effort: a failure here
        # (e.g. no static-image backend) should not discard the figure.
        if plot_type == 'heatmap':
            try:
                fig.write_html(f'plots-deepseek/{cache_name}_heatmap.html')
                fig.write_image(f'plots-deepseek/{cache_name}_heatmap.png')
            except Exception as export_error:
                print(f"Warning: could not export heatmap: {export_error}")

        return fig, routed_matrix, routed_counts

    except Exception as e:
        print(f"Error in analyze_dataset: {e}")
        traceback.print_exc()
        return None, None, None

### plotting

In [13]:
# HumanEval: compute (or load cached) routed-expert counts and show the heatmap.
fig, count_matrix, token_counts = analyze_dataset(
    plot_type='heatmap', # or 'bar'
    file_path='dataset/humaneval.txt',
    model= model,
    tokenizer= tokenizer,
    domain='humaneval',
    chunk_size=1024,
    num_shared_experts=1,
    num_topk=7,
    force_recompute=False
    # layer_id=1 # only for bar plot
)

# analyze_dataset returns (None, None, None) on failure; avoid a follow-up
# AttributeError that would bury the real error printed above.
if fig is not None:
    fig.show()
else:
    print("No figure produced for 'humaneval'; see the error output above.")

Computing new results for humaneval with 1 shared experts and top-k=7
Using device: cpu
Created 25 chunks from the input file
Processing chunk 1/25
Processing chunk 2/25
Processing chunk 3/25
Processing chunk 4/25
Processing chunk 5/25
Processing chunk 6/25
Processing chunk 7/25
Processing chunk 8/25
Processing chunk 9/25
Processing chunk 10/25
Processing chunk 11/25
Processing chunk 12/25
Processing chunk 13/25
Processing chunk 14/25
Processing chunk 15/25
Processing chunk 16/25
Processing chunk 17/25
Processing chunk 18/25
Processing chunk 19/25
Processing chunk 20/25
Processing chunk 21/25
Processing chunk 22/25
Processing chunk 23/25
Processing chunk 24/25
Processing chunk 25/25


Token indices sequence length is longer than the specified maximum sequence length for this model (24627 > 16384). Running this sequence through the model will result in indexing errors
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [14]:
# PIQA: compute (or load cached) routed-expert counts and show the heatmap.
fig, count_matrix, token_counts = analyze_dataset(
    plot_type='heatmap', # or 'bar'
    file_path='dataset/piqa.txt',
    model= model,
    tokenizer= tokenizer,
    domain='piqa',
    chunk_size=1024,
    num_shared_experts=1, 
    num_topk=7,
    force_recompute=False
    # layer_id=1 # only for bar plot
)

# analyze_dataset returns (None, None, None) on failure; avoid a follow-up
# AttributeError that would bury the real error printed above.
if fig is not None:
    fig.show()
else:
    print("No figure produced for 'piqa'; see the error output above.")

Computing new results for piqa with 1 shared experts and top-k=7
Using device: cpu
Created 1111 chunks from the input file
Processing chunk 1/1111
Processing chunk 2/1111
Processing chunk 3/1111
Processing chunk 4/1111
Processing chunk 5/1111
Processing chunk 6/1111
Processing chunk 7/1111
Processing chunk 8/1111
Processing chunk 9/1111
Processing chunk 10/1111
Processing chunk 11/1111
Processing chunk 12/1111
Processing chunk 13/1111
Processing chunk 14/1111
Processing chunk 15/1111
Processing chunk 16/1111
Processing chunk 17/1111
Processing chunk 18/1111
Processing chunk 19/1111
Processing chunk 20/1111
Processing chunk 21/1111
Processing chunk 22/1111
Processing chunk 23/1111
Processing chunk 24/1111
Processing chunk 25/1111
Processing chunk 26/1111
Processing chunk 27/1111
Processing chunk 28/1111
Processing chunk 29/1111
Processing chunk 30/1111
Processing chunk 31/1111
Processing chunk 32/1111
Processing chunk 33/1111
Processing chunk 34/1111
Processing chunk 35/1111
Processing 

In [15]:
# GSM8K: compute (or load cached) routed-expert counts and show the heatmap.
fig, count_matrix, token_counts = analyze_dataset(
    plot_type='heatmap', # or 'bar'
    file_path='dataset/gsm8k.txt',
    model= model,
    tokenizer= tokenizer,
    domain='gsm8k',
    chunk_size=1024,
    num_shared_experts=1, 
    num_topk=7,
    force_recompute=False
    # layer_id=1 # only for bar plot
)

# analyze_dataset returns (None, None, None) on failure; avoid a follow-up
# AttributeError that would bury the real error printed above.
if fig is not None:
    fig.show()
else:
    print("No figure produced for 'gsm8k'; see the error output above.")

Computing new results for gsm8k with 1 shared experts and top-k=7
Using device: cpu
Created 517 chunks from the input file
Processing chunk 1/517
Processing chunk 2/517
Processing chunk 3/517
Processing chunk 4/517
Processing chunk 5/517
Processing chunk 6/517
Processing chunk 7/517
Processing chunk 8/517
Processing chunk 9/517
Processing chunk 10/517
Processing chunk 11/517
Processing chunk 12/517
Processing chunk 13/517
Processing chunk 14/517
Processing chunk 15/517
Processing chunk 16/517
Processing chunk 17/517
Processing chunk 18/517
Processing chunk 19/517
Processing chunk 20/517
Processing chunk 21/517
Processing chunk 22/517
Processing chunk 23/517
Processing chunk 24/517
Processing chunk 25/517
Processing chunk 26/517
Processing chunk 27/517
Processing chunk 28/517
Processing chunk 29/517
Processing chunk 30/517
Processing chunk 31/517
Processing chunk 32/517
Processing chunk 33/517
Processing chunk 34/517
Processing chunk 35/517
Processing chunk 36/517
Processing chunk 37/51

In [17]:
# ARC-Easy: compute (or load cached) routed-expert counts and show the
# per-expert bar chart for a single layer.
fig, count_matrix, token_counts = analyze_dataset(
    plot_type='bar',
    file_path='dataset/arc_easy.txt',
    model= model,
    tokenizer= tokenizer,
    domain='arc_easy',
    chunk_size=1024,
    num_shared_experts=1, 
    num_topk=7,
    force_recompute=False,
    layer_id=1 # only for bar plot
)

# Both analyze_dataset and the bar-plot helper signal failure with None; avoid
# a follow-up AttributeError that would bury the real error printed above.
if fig is not None:
    fig.show()
else:
    print("No figure produced for 'arc_easy'; see the error output above.")

Loading existing results for arc_easy with 1 shared experts and top-k=7
