This notebook analyzes next-token-prediction perplexity on the datasets in `data/` and how it changes as the number of active routed experts per token varies ($\text{top-k} = 1 \rightarrow 6$). The goal is to check whether routing each token only to its single most-activated expert keeps loss and prediction quality comparable to using the model's default of 6 routed experts.

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.nn import CrossEntropyLoss
import math
from tqdm import tqdm
import csv
import pandas as pd
import plotly.express as px
import numpy as np


In [2]:
# Run on the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
def load_model(model_name, torch_dtype=torch.float16):
    """
    Load a causal language model and its tokenizer via Hugging Face ``from_pretrained``.

    Args:
        model_name: Hub repository id or local directory,
            e.g. ``"deepseek-ai/deepseek-moe-16b-base"``.
        torch_dtype: dtype for the model weights (default ``torch.float16``).

    Returns:
        ``(model, tokenizer)``. The caller is responsible for moving the model to a
        device and putting it in eval mode.

    Note:
        ``trust_remote_code=True`` lets the checkpoint execute its own modelling code;
        only load repositories you trust.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

MODEL_NAME = "deepseek-ai/deepseek-moe-16b-base"

model, tokenizer = load_model(MODEL_NAME)
# eval() and to() both return the module itself, so this chain switches to inference
# mode, moves the weights to `device`, and displays the architecture as cell output.
model.eval().to(device)

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

DeepseekForCausalLM(
  (model): DeepseekModel(
    (embed_tokens): Embedding(102400, 2048)
    (layers): ModuleList(
      (0): DeepseekDecoderLayer(
        (self_attn): DeepseekSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): DeepseekRotaryEmbedding()
        )
        (mlp): DeepseekMLP(
          (gate_proj): Linear(in_features=2048, out_features=10944, bias=False)
          (up_proj): Linear(in_features=2048, out_features=10944, bias=False)
          (down_proj): Linear(in_features=10944, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): DeepseekRMSNorm()
        (post_attention_layernorm): DeepseekRMSNorm()
      )
      (1-27): 27 x DeepseekDecod

In [15]:
def _read_samples(txt_file_path):
    """Split the input file into non-empty, stripped samples (code blocks for github.txt, else one per line)."""
    import re

    with open(txt_file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    if 'github.txt' in txt_file_path:
        # A code block starts at a line ending in a source-file name (e.g. ``src/app.py``) and
        # runs until the next such line or the end of the file; text before the first header is ignored.
        header = re.compile(r'.*\b\w+\.(js|py|c|cpp|java|ts|rb|go|rs|cs|swift|kt|php)$', re.MULTILINE)
        starts = [match.start() for match in header.finditer(content)]
        ends = starts[1:] + [len(content)]
        samples = [content[start:end].strip() for start, end in zip(starts, ends)]
    else:
        samples = [line.strip() for line in content.split('\n')]
    return [sample for sample in samples if sample]


def _set_moe_top_k(model, top_k):
    """Route every MoE layer of ``model`` to ``top_k`` experts; return records for ``_restore_attrs``."""
    saved = []

    def _set(obj, attr):
        saved.append((obj, attr, hasattr(obj, attr), getattr(obj, attr, None)))
        setattr(obj, attr, top_k)

    for layer in model.model.layers:
        mlp = layer.mlp
        if hasattr(mlp, 'experts'):  # only MoE layers have experts; dense layers are left alone
            # Both the MoE block and its gate carry a top-k value; keep them consistent.
            _set(mlp, 'num_experts_per_tok')
            if hasattr(mlp, 'gate'):
                _set(mlp.gate, 'top_k')
    return saved


def _restore_attrs(saved):
    """Undo the attribute changes recorded by ``_set_moe_top_k`` (in reverse order)."""
    for obj, attr, had_attr, old_value in reversed(saved):
        if had_attr:
            setattr(obj, attr, old_value)
        else:
            delattr(obj, attr)


def _token_nll(model, input_ids):
    """Return the per-token next-token NLL of a ``(1, T)`` id tensor as a 1-D CPU tensor of length T - 1."""
    loss_fct = CrossEntropyLoss(reduction='none')
    with torch.no_grad():
        # No KV cache is needed for a single scoring pass.
        logits = model(input_ids, use_cache=False).logits
        # logits[t] predicts token t + 1: drop the last prediction and the first label.
        # Upcast so the log-softmax over the whole vocabulary is not evaluated in fp16.
        shift_logits = logits[..., :-1, :].float()
        shift_targets = input_ids[..., 1:]
        nll = loss_fct(shift_logits.reshape(-1, shift_logits.size(-1)),
                       shift_targets.reshape(-1)).cpu()
    del logits, shift_logits
    if input_ids.is_cuda:
        # Free cached GPU memory between forward passes of varying length.
        torch.cuda.empty_cache()
    return nll


def calculate_perplexity(model, tokenizer, txt_file_path, device=device, domain="code", top_k=2,
                         chunk_size=2048, output_dir="."):
    """
    Compute per-sample next-token perplexity with a configurable number of experts per token.

    The MoE layers of ``model`` are switched to route each token to ``top_k`` experts for the
    duration of the call; the previous setting is restored before returning (also on error).

    Samples come from ``txt_file_path``: for a path containing ``github.txt`` the file is split
    into code blocks, each starting at a line that ends in a source-file name; otherwise every
    non-empty line is one sample.

    Samples longer than ``chunk_size`` tokens are split into consecutive, non-overlapping
    chunks scored independently (no context is carried across chunks, so the first token of
    each chunk is not predicted). The per-token losses of all chunks are pooled before
    exponentiating. Samples with fewer than two tokens have nothing to predict and are skipped.

    One row per scored sample is written to ``{output_dir}/{domain}_perplexity_top{top_k}.csv``
    with columns ``{domain}_num, perplexity, chunks``.

    Args:
        model: The DeepSeek MoE model.
        tokenizer: The matching tokenizer.
        txt_file_path: Path to the input text file containing samples.
        device: Device the input ids are moved to (should match the model's device).
        domain: Domain name used in the output file name and first CSV column.
        top_k: Number of experts to select per token (1-6).
        chunk_size: Maximum number of tokens per forward pass (default 2048).
        output_dir: Directory for the CSV file (default: current directory; created if missing).

    Returns:
        List of ``(sample_number, perplexity)`` tuples; sample numbers are 1-based.

    Raises:
        ValueError: if ``top_k`` is not in 1..6 or ``chunk_size`` is smaller than 2.
    """
    import os

    if not 1 <= top_k <= 6:
        raise ValueError("top_k must be between 1 and 6")
    if chunk_size < 2:
        raise ValueError("chunk_size must be at least 2")

    samples = _read_samples(txt_file_path)
    # Tokenize once: the lengths size the progress bar and the ids are reused for scoring.
    encoded = [tokenizer(sample, return_tensors='pt').input_ids for sample in samples]
    total_chunks = sum(max(1, (ids.size(1) + chunk_size - 1) // chunk_size) for ids in encoded)

    os.makedirs(output_dir, exist_ok=True)
    file_name = os.path.join(output_dir, f"{domain}_perplexity_top{top_k}.csv")

    perplexities = []
    saved = _set_moe_top_k(model, top_k)
    try:
        with open(file_name, 'w', newline='') as csvfile, \
                tqdm(total=total_chunks, desc=f"Processing {domain} samples (top-k={top_k})") as progress_bar:
            writer = csv.writer(csvfile)
            writer.writerow([f'{domain}_num', 'perplexity', 'chunks'])

            for sample_num, ids in enumerate(encoded, start=1):
                seq_len = ids.size(1)
                if seq_len < 2:
                    # Perplexity is undefined without at least one next-token target.
                    progress_bar.update(1)
                    continue

                ids = ids.to(device)
                chunk_losses = []
                for start in range(0, seq_len, chunk_size):
                    chunk_losses.append(_token_nll(model, ids[:, start:start + chunk_size]))
                    progress_bar.update(1)

                n_chunks = len(chunk_losses)
                ppl = torch.exp(torch.cat(chunk_losses).mean()).item()

                progress_bar.set_postfix({'Perplexity': f'{ppl:.2f}', 'Chunks': n_chunks})
                perplexities.append((sample_num, ppl))
                writer.writerow([sample_num, ppl, n_chunks])
    finally:
        _restore_attrs(saved)

    return perplexities

In [None]:
# Score GSM8K with 6, 5, ..., 1 routed experts per token; each run writes gsm8k_perplexity_top{k}.csv.
DOMAIN = "gsm8k"
text_file_path = f"data-ext/{DOMAIN}.txt"

for k in range(6, 0, -1):
    sample_ppls = calculate_perplexity(model=model,
                                       tokenizer=tokenizer,
                                       txt_file_path=text_file_path,
                                       domain=DOMAIN,
                                       top_k=k)
    if not sample_ppls:
        print(f"Perplexity (top-{k}): no samples scored")
        continue
    # Summarise with the geometric mean over samples (= exp of the mean log-perplexity)
    # instead of dumping every (sample, perplexity) tuple.
    mean_log_ppl = sum(math.log(p) for _, p in sample_ppls) / len(sample_ppls)
    print(f"Perplexity (top-{k}): {math.exp(mean_log_ppl):.3f} "
          f"(geometric mean over {len(sample_ppls)} samples)")


In [None]:
# Per-sample perplexity for one domain / top-k run, as written by `calculate_perplexity`.
# pandas and plotly.express are imported in the first code cell.
PLOT_DOMAIN = "code"
PLOT_TOP_K = 6

per_sample_df = pd.read_csv(f'{PLOT_DOMAIN}_perplexity_top{PLOT_TOP_K}.csv')
sample_col = f'{PLOT_DOMAIN}_num'  # the first CSV column is named after the domain
assert {sample_col, 'perplexity'} <= set(per_sample_df.columns), \
    f"unexpected columns: {per_sample_df.columns.tolist()}"

fig = px.line(per_sample_df, x=sample_col, y='perplexity',
              markers=True,
              title=f'Perplexity by {PLOT_DOMAIN.title()} Number (top-{PLOT_TOP_K})')

fig.update_layout(
    xaxis_title=f"{PLOT_DOMAIN.title()} Number",
    yaxis_title="Perplexity",
    xaxis=dict(showgrid=True),
    yaxis=dict(showgrid=True)
)

fig.show()


In [None]:
# Mean log-perplexity of the `code` domain for each number of active experts (top-k = 1..6).
# pandas, numpy and plotly.express are imported in the first code cell.
TOP_KS = range(1, 7)
labels = [f'Top {k}' for k in TOP_KS]
log_means = []

for k in TOP_KS:
    per_sample_df = pd.read_csv(f'code_perplexity_top{k}.csv')
    # Averaging log-perplexity over samples gives the log of the geometric-mean perplexity;
    # every sample counts equally regardless of its token count.
    log_mean = np.log(per_sample_df['perplexity']).mean()
    log_means.append(log_mean)

    print(f"{k}:")
    print(f"  Mean log perplexity: {log_mean:.3f}")
    print(f"  Equivalent perplexity: {np.exp(log_mean):.3f}")

perplexities = [np.exp(log_mean) for log_mean in log_means]

results_df = pd.DataFrame({
    'Selection': labels,
    'Log_Perplexity': log_means,
    'Perplexity': perplexities
})
results_df.to_csv('mean_perplexities.csv', index=False)

fig = px.line(results_df,
              x='Selection',
              y='Log_Perplexity',
              markers=True,
              text=[f'{v:.3f}' for v in log_means],
              title='Mean Log Perplexity by Top-k Selection')
fig.update_layout(
    xaxis_title="Selection Method",
    yaxis_title="Mean Log Perplexity",
    xaxis=dict(showgrid=True)
)
fig.update_traces(textposition="middle right")

fig.show()

In [12]:
def analyze_perplexity_by_domain(domains, top_ks=range(1, 7), pplx_dir='data/pplx',
                                 summary_dir='data/pplx-data'):
    """
    Summarise per-sample perplexity across top-k settings for several domains and plot them together.

    For each domain and each k in ``top_ks`` this reads
    ``{pplx_dir}/{domain}_perplexity_top{k}.csv`` and reduces its ``perplexity`` column to the
    mean log-perplexity over samples (the log of the geometric-mean perplexity; each sample is
    weighted equally regardless of its token count).

    Side effects:
        * prints the summary for every (domain, k);
        * writes ``{summary_dir}/{domain}_mean_perplexities.csv`` per domain (dir created if missing);
        * writes ``{pplx_dir}/all_domains_mean_perplexities.csv`` for all domains;
        * shows a plotly line chart of log-perplexity min-max normalised to [0, 1] within each
          domain (a domain whose values are all equal is drawn at 0.5).

    Args:
        domains: List of domain names to analyze.
        top_ks: Top-k values to include (default 1..6).
        pplx_dir: Directory holding the per-sample CSVs; the combined summary is written here too.
        summary_dir: Directory for the per-domain summary CSVs.

    Raises:
        ValueError: if ``domains`` is empty.
    """
    import os

    if not domains:
        raise ValueError("domains must not be empty")

    top_ks = list(top_ks)
    os.makedirs(summary_dir, exist_ok=True)
    all_results = []

    for domain in domains:
        log_means = []
        for k in top_ks:
            # k comes from the loop, not from parsing the file name, so paths containing "top" are safe.
            per_sample = pd.read_csv(os.path.join(pplx_dir, f'{domain}_perplexity_top{k}.csv'))
            log_mean = np.log(per_sample['perplexity']).mean()
            log_means.append(log_mean)

            print(f"{domain} - {k}:")
            print(f"  Mean log perplexity: {log_mean:.3f}")
            print(f"  Equivalent perplexity: {np.exp(log_mean):.3f}")

        results_df = pd.DataFrame({
            'Selection': [f'Top {k}' for k in top_ks],
            'Log_Perplexity': log_means,
            'Perplexity': np.exp(log_means),
            'Domain': domain
        })
        all_results.append(results_df)
        results_df.to_csv(os.path.join(summary_dir, f'{domain}_mean_perplexities.csv'), index=False)

    # ignore_index gives the combined frame a unique index (each per-domain frame starts at 0).
    combined_results = pd.concat(all_results, ignore_index=True)
    combined_results.to_csv(os.path.join(pplx_dir, 'all_domains_mean_perplexities.csv'), index=False)

    def _min_max(values):
        """Scale a domain's log-perplexities to [0, 1]; constant (or all-NaN) input maps to 0.5."""
        lo, hi = values.min(), values.max()
        if hi > lo:
            return (values - lo) / (hi - lo)
        return pd.Series(0.5, index=values.index)

    # Normalise within each domain so domains with very different perplexity scales share one axis.
    normalized_results = combined_results.assign(
        Normalized_Log_Perplexity=combined_results.groupby('Domain')['Log_Perplexity'].transform(_min_max)
    )

    fig = px.line(normalized_results,
                  x='Selection',
                  y='Normalized_Log_Perplexity',
                  color='Domain',
                  markers=True,
                  text='Normalized_Log_Perplexity',
                  title='Normalized Log Perplexity by Top-k Selection Across Domains (0-1 Scale)')

    fig.update_layout(
        xaxis_title="Selection Method",
        yaxis_title="Normalized Log Perplexity (0-1 Scale)",
        xaxis=dict(showgrid=True),
        yaxis=dict(range=[-0.1, 1.1])
    )
    fig.update_traces(textposition="middle right", texttemplate='%{text:.3f}')

    fig.show()


In [13]:
# Domains whose per-sample perplexity CSVs exist for every top-k setting.
DOMAINS = ['french-qa', 'english', 'arxiv', 'aime-math', 'chinese', 'github', 'gsm8k']
analyze_perplexity_by_domain(DOMAINS)

french-qa - 1:
  Mean log perplexity: 2.535
  Equivalent perplexity: 12.612
french-qa - 2:
  Mean log perplexity: 1.918
  Equivalent perplexity: 6.810
french-qa - 3:
  Mean log perplexity: 1.814
  Equivalent perplexity: 6.133
french-qa - 4:
  Mean log perplexity: 1.782
  Equivalent perplexity: 5.944
french-qa - 5:
  Mean log perplexity: 1.768
  Equivalent perplexity: 5.860
french-qa - 6:
  Mean log perplexity: 1.763
  Equivalent perplexity: 5.829
english - 1:
  Mean log perplexity: 1.522
  Equivalent perplexity: 4.579
english - 2:
  Mean log perplexity: 1.286
  Equivalent perplexity: 3.620
english - 3:
  Mean log perplexity: 1.231
  Equivalent perplexity: 3.426
english - 4:
  Mean log perplexity: 1.211
  Equivalent perplexity: 3.356
english - 5:
  Mean log perplexity: 1.202
  Equivalent perplexity: 3.328
english - 6:
  Mean log perplexity: 1.199
  Equivalent perplexity: 3.316
arxiv - 1:
  Mean log perplexity: 3.000
  Equivalent perplexity: 20.085
arxiv - 2:
  Mean log perplexity: 2.657