In [37]:
import torch
import numpy as np
import plotly.graph_objects as go
from collections import defaultdict
import torch.nn.functional as F
import json
import os
from typing import Dict, Tuple, List, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
import math
from pathlib import Path
import datetime
from pathlib import Path
import pickle

In [2]:
def get_device():
    """Return the device to run on: ``cuda:0`` when CUDA is usable, else the CPU.

    Side effects (CUDA path only): enables TF32 for CUDA matmul and cuDNN,
    empties the CUDA caching allocator, and turns on cuDNN benchmark mode.
    These are process-wide torch settings, not per-device ones.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")

    # TF32 is aimed at Ampere-class GPUs (A100, A6000, ...).
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Hand cached, currently unused GPU memory back before the model is loaded.
    torch.cuda.empty_cache()
    # Let cuDNN benchmark candidate kernels for the shapes it sees.
    torch.backends.cudnn.benchmark = True
    return torch.device("cuda:0")

In [3]:
# Load the tokenizer and the fp16 weights of the DeepSeek MoE base checkpoint.
# Both calls pass trust_remote_code=True, so the checkpoint's own modelling /
# tokenizer code is allowed to run.
tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/deepseek-moe-16b-base",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/deepseek-moe-16b-base",
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
# Inference mode; as the last expression of the cell this also echoes the module tree.
model.eval()

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

DeepseekForCausalLM(
  (model): DeepseekModel(
    (embed_tokens): Embedding(102400, 2048)
    (layers): ModuleList(
      (0): DeepseekDecoderLayer(
        (self_attn): DeepseekSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): DeepseekRotaryEmbedding()
        )
        (mlp): DeepseekMLP(
          (gate_proj): Linear(in_features=2048, out_features=10944, bias=False)
          (up_proj): Linear(in_features=2048, out_features=10944, bias=False)
          (down_proj): Linear(in_features=10944, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): DeepseekRMSNorm()
        (post_attention_layernorm): DeepseekRMSNorm()
      )
      (1-27): 27 x DeepseekDecod

----

In [5]:
class MOELens:
    """Record MoE routing decisions and per-expert token guesses through forward hooks.

    A forward hook is attached to ``layer.mlp.gate`` of every decoder layer that
    has a ``gate`` attribute. Each time a gate runs, the hook stores the gate's
    top-k expert indices and weights and, for every routed expert of that layer,
    the five highest-scoring vocabulary tokens obtained by feeding the gate's
    input through the expert and then straight through ``model.lm_head``.

    Hooks stay registered until ``remove_hooks`` is called, so callers should
    always call it when they are done with the lens.
    """

    def __init__(self, model, tokenizer):
        """Keep references to the model/tokenizer and register the gate hooks."""
        self.model = model
        self.tokenizer = tokenizer
        # layer name ("layer_<idx>") -> data captured by the most recent forward pass
        self.activations = defaultdict(dict)
        self.hook_handles = []
        self.setup_hooks()

    def setup_hooks(self):
        """Attach one forward hook per gated layer and remember the handles."""
        def get_gate_hook(layer_idx):
            def hook(module, inp, out):
                # Only tuple outputs are recorded; the tuple is unpacked as
                # (top-k expert indices, top-k weights, <ignored third item>).
                if not isinstance(out, tuple):
                    return
                topk_idx, topk_weight, _ = out

                # The gate's input is expected to be (batch, seq, hidden); the
                # three-way unpack fails loudly for any other rank.
                hidden_states = inp[0]
                _, _, hidden_dim = hidden_states.shape

                mlp = self.model.model.layers[layer_idx].mlp
                expert_outputs = {}
                if hasattr(mlp, 'experts'):
                    with torch.no_grad():
                        # Flatten once; every expert sees the same (tokens, hidden) matrix.
                        flat_hidden = hidden_states.view(-1, hidden_dim)
                        for expert_idx in range(module.n_routed_experts):
                            expert_output = mlp.experts[expert_idx](flat_hidden)
                            # Project the expert output to vocabulary space.
                            logits = self.model.lm_head(expert_output)
                            top_tokens = torch.topk(logits, k=5, dim=-1)
                            expert_outputs[expert_idx] = {
                                'token_ids': top_tokens.indices,
                                # Softmax over the 5 kept logits only, so these
                                # are relative scores within the top-5.
                                'probs': torch.softmax(top_tokens.values, dim=-1)
                            }

                self.activations[f'layer_{layer_idx}'] = {
                    'router_weights': topk_weight.detach(),
                    'router_indices': topk_idx.detach(),
                    'expert_outputs': expert_outputs
                }
            return hook

        for i, layer in enumerate(self.model.model.layers):
            if hasattr(layer.mlp, 'gate'):
                handle = layer.mlp.gate.register_forward_hook(get_gate_hook(i))
                self.hook_handles.append(handle)

    def analyze_text(self, input_ids: torch.Tensor) -> dict:
        """Run one forward pass and summarise routing for every input position.

        Args:
            input_ids: token ids of shape (batch, seq). Only row 0 is decoded,
                and captured tensors are indexed by position alone.
                NOTE(review): this presumably requires batch size 1 - confirm.

        Returns:
            ``{layer_name: {"tokens": {key: entry}}}`` where each entry holds
            ``"position"``, ``"token"`` (decoded text) and ``"expert_outputs"``
            (a list of ``{"expert_id", "weight", "top_tokens"}`` dicts).
            ``key`` is the decoded token text; when the same text occurs again
            at a later position the key becomes ``"<text> [pos <n>]"`` so that
            no position is overwritten. Consumers should match on
            ``entry["position"]`` rather than on the key.
        """
        self.activations.clear()

        # The hooks fill self.activations as a side effect; the model output is not needed.
        with torch.no_grad():
            self.model(input_ids)

        decoded_tokens = [self.tokenizer.decode([tid.item()]) for tid in input_ids[0]]

        results = {}
        for layer_name, acts in self.activations.items():
            token_entries = {}
            has_routing = 'router_indices' in acts and 'router_weights' in acts

            for pos, token in enumerate(decoded_tokens):
                expert_info = []

                if has_routing:
                    indices = acts['router_indices'][pos]
                    weights = acts['router_weights'][pos]
                    expert_outputs = acts['expert_outputs']

                    for idx, weight in zip(indices, weights):
                        expert_id = idx.item()
                        if expert_id in expert_outputs:
                            expert_output = expert_outputs[expert_id]
                            top_tokens = [
                                (self.tokenizer.decode([token_id.item()]), prob.item())
                                for token_id, prob in zip(
                                    expert_output['token_ids'][pos],
                                    expert_output['probs'][pos]
                                )
                            ]
                        else:
                            top_tokens = []

                        expert_info.append({
                            "expert_id": expert_id,
                            "weight": weight.item(),
                            "top_tokens": top_tokens
                        })

                # Keying purely by token text silently dropped earlier positions
                # whenever a token repeated; disambiguate repeated text instead.
                key = token if token not in token_entries else f"{token} [pos {pos}]"
                token_entries[key] = {
                    "position": pos,
                    "token": token,
                    "expert_outputs": expert_info
                }

            results[layer_name] = {"tokens": token_entries}

        return results

    def remove_hooks(self):
        """Detach every registered hook and forget the handles."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles = []

In [6]:
def visualize_layer_analysis(tokenizer, results: Dict, token_position: int, input_text: str,
                             total_experts: int = 64):
    """Scatter-plot the router weight of every expert in every layer for one token.

    Each (layer, expert) pair becomes one marker; experts the router did not
    select for the token get weight 0 and no hover text.

    Args:
        tokenizer: object with ``encode``/``decode``; used only to recover the
            text of the token at ``token_position`` for the title.
        results: mapping ``"layer_<n>" -> {"tokens": {key: entry}}`` where each
            entry has ``"position"`` and ``"expert_outputs"`` (a list of dicts
            with ``"expert_id"``, ``"weight"`` and ``"top_tokens"``).
        token_position: index of the token to visualise.
        input_text: the analysed text; re-tokenised here to look up the token.
        total_experts: number of expert columns to draw (defaults to 64).

    Returns:
        The plotly ``Figure``. It is also displayed via ``fig.show()``; end the
        cell with ``;`` to avoid a second rendering when the call is the last
        expression of a notebook cell.

    Layers are taken from the ``layer_<n>`` keys present in ``results`` (in
    numeric order) rather than from a fixed 1-27 range; layers with no entry
    for ``token_position`` are skipped.
    """
    token_ids = tokenizer.encode(input_text)
    token = tokenizer.decode([token_ids[token_position]])
    print(f"Visualizing token: {token}")

    # (numeric index, original key) pairs so that layer_10 sorts after layer_2.
    layer_keys = sorted(
        (int(key.split('_', 1)[1]), key)
        for key in results
        if key.startswith('layer_') and key.split('_', 1)[1].isdigit()
    )

    layer_nums = []
    expert_ids = []
    weights = []
    hover_texts = []

    for layer_idx, layer_key in layer_keys:
        token_data = next(
            (entry for entry in results[layer_key]["tokens"].values()
             if entry["position"] == token_position),
            None,
        )
        if token_data is None:
            continue

        # expert_id -> routing record for the experts selected in this layer
        expert_map = {exp["expert_id"]: exp for exp in token_data["expert_outputs"]}

        for expert_id in range(total_experts):
            layer_nums.append(layer_idx)
            expert_ids.append(expert_id)

            expert_data = expert_map.get(expert_id)
            if expert_data is None:
                # Not selected: zero weight, no hover text.
                weights.append(0)
                hover_texts.append(None)
                continue

            weight = expert_data["weight"]
            top_tokens_text = "<br>".join([
                f"{tok}: {prob:.3f}"
                for tok, prob in expert_data["top_tokens"][:5]
            ])
            hover_texts.append(
                f"Layer: {layer_idx}<br>Expert: {expert_id}<br>Weight: {weight:.3f}<br>Top tokens:<br>{top_tokens_text}"
            )
            weights.append(weight)

    fig = go.Figure(data=go.Scatter(
        x=expert_ids,
        y=layer_nums,
        mode='markers',
        marker=dict(
            size=9,
            color=weights,
            colorscale=[
                [0, 'rgba(24, 21, 23, 0.8)'],  # Very dark/transparent for zero weights
                [0.0001, 'rgb(68,1,84)'],      # Start of Viridis colorscale
                [1, 'rgb(253,231,37)']         # End of Viridis colorscale
            ],
            cmin=0,
            cmax=1,
            showscale=True,
            colorbar=dict(
                title='Weight',
                tickmode='linear',
                tick0=0,
                dtick=0.2
            ),
        ),
        text=hover_texts,
        hoverinfo='text',
        hovertemplate='%{text}<extra></extra>',
    ))

    fig.update_layout(
        template='plotly_dark',
        title=f'Expert Activations for Token "{token}" at Position {token_position}',
        xaxis_title='Expert ID',
        yaxis_title='Layer',
        yaxis=dict(autorange='reversed'),  # first layer at the top
        width=1200,
        height=800,
        showlegend=False,
        plot_bgcolor='black',
        paper_bgcolor='black'
    )

    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128, 128, 128, 0.2)',
                     range=[-1, total_experts])
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128, 128, 128, 0.2)')

    fig.show()
    # Callers in this notebook store the result as a figure, so hand it back.
    return fig

In [81]:
def plot_enhanced_logit_lens(model_outputs, tokenizer, input_text, position):
    """Heatmap of router-weighted expert token scores per layer for one position.

    For every layer that has an entry at ``position``, each selected expert's
    top-token scores are multiplied by that expert's router weight and summed
    per token text. Rows are layers, columns are all token texts seen in any
    layer (sorted alphabetically).

    Args:
        model_outputs: mapping ``"layer_<n>" -> {"tokens": {key: entry}}`` with
            entries carrying ``"position"`` and ``"expert_outputs"``.
        tokenizer: accepted for signature compatibility; not used here.
        input_text: accepted for signature compatibility; not used here.
        position: token position to plot.

    Returns:
        The plotly ``Figure`` (not shown automatically).

    Layers come from the ``layer_<n>`` keys actually present, in numeric
    order, instead of a fixed 1-27 range.
    """
    all_tokens = set()
    token_probs = defaultdict(lambda: defaultdict(float))
    layers = []

    # (numeric index, original key) pairs so that layer_10 sorts after layer_2.
    layer_keys = sorted(
        (int(key.split('_', 1)[1]), key)
        for key in model_outputs
        if key.startswith('layer_') and key.split('_', 1)[1].isdigit()
    )

    for layer_idx, layer_key in layer_keys:
        layer_data = model_outputs[layer_key]
        if 'tokens' not in layer_data:
            continue

        token_data = next(
            (entry for entry in layer_data['tokens'].values()
             if entry['position'] == position),
            None,
        )
        if token_data is None:
            continue

        layers.append(layer_idx)

        # Accumulate score * router weight per predicted token text.
        for expert_output in token_data['expert_outputs']:
            expert_weight = expert_output['weight']
            for token, prob in expert_output['top_tokens']:
                all_tokens.add(token)
                token_probs[layer_idx][token] += prob * expert_weight

    tokens = sorted(all_tokens)

    prob_matrix = [
        [token_probs[layer_idx].get(token, 0.0) for token in tokens]
        for layer_idx in layers
    ]

    fig = go.Figure(data=go.Heatmap(
        z=prob_matrix,
        x=tokens,
        y=[f'Layer {idx}' for idx in layers],
        colorscale='Viridis',
        showscale=True
    ))

    fig.update_layout(
        title=f'Token Probabilities Across Layers - Position {position}',
        xaxis_title='Predicted Tokens',
        yaxis_title='Layer',
        yaxis={'autorange': 'reversed'},  # first layer at the top
        width=1200,
        height=800,
        plot_bgcolor='black',
        paper_bgcolor='black',
        font=dict(color='white')
    )

    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128, 128, 128, 0.2)')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128, 128, 128, 0.2)')

    return fig

def plot_expert_contributions(model_outputs, position=None):
    """Heatmap of router weights: one row per layer, one column per expert seen.

    Args:
        model_outputs: mapping ``"layer_<n>" -> {"tokens": {key: entry}}`` with
            entries carrying ``"position"`` and ``"expert_outputs"`` (dicts
            with ``"expert_id"`` and ``"weight"``). Keys that do not start
            with ``"layer_"`` are ignored.
        position: token position to plot. When ``None`` the first token entry
            of each layer is used.

    Returns:
        The plotly ``Figure`` (not shown automatically).

    Notes:
        * Layers are ordered numerically (``layer_2`` before ``layer_10``).
        * Each weight is written into the column of the expert it belongs to.
        * Columns appear in order of first selection across layers.
        * Layers with no entry for ``position`` are skipped.
    """
    import plotly.graph_objects as go
    import numpy as np

    layer_keys = sorted(
        (key for key in model_outputs if key.startswith('layer_')),
        key=lambda key: int(key.split('_')[1]),
    )

    layer_labels = []
    layer_weight_maps = []   # one {expert_id: weight} dict per plotted layer
    column_of = {}           # expert_id -> column index, in first-seen order

    for layer_key in layer_keys:
        token_entries = model_outputs[layer_key]["tokens"].values()
        if position is not None:
            token_data = next(
                (entry for entry in token_entries if entry["position"] == position),
                None,
            )
        else:
            token_data = next(iter(token_entries), None)
        if token_data is None:
            continue

        layer_labels.append(f"Layer {int(layer_key.split('_')[1])}")

        weight_by_expert = {}
        for exp_output in token_data["expert_outputs"]:
            expert_id = exp_output["expert_id"]
            weight_by_expert[expert_id] = exp_output["weight"]
            column_of.setdefault(expert_id, len(column_of))
        layer_weight_maps.append(weight_by_expert)

    expert_ids = list(column_of)

    weight_matrix = np.zeros((len(layer_labels), len(expert_ids)))
    for row, weight_by_expert in enumerate(layer_weight_maps):
        for expert_id, weight in weight_by_expert.items():
            weight_matrix[row, column_of[expert_id]] = weight

    fig = go.Figure(data=go.Heatmap(
        z=weight_matrix,
        x=[f"Expert {expert_id}" for expert_id in expert_ids],
        y=layer_labels,
        colorscale='Viridis',
        text=[[f"{val:.3f}" if val > 0 else "" for val in row] for row in weight_matrix],
        texttemplate="%{text}",
        textfont={"size": 10},
        showscale=True,
    ))

    # `is not None` so that position 0 is still named in the title.
    title_suffix = f" - Position {position}" if position is not None else ""
    fig.update_layout(
        title="Expert Contributions by Layer" + title_suffix,
        xaxis_title="Experts",
        yaxis_title="Layers",
        height=800,
        width=1200,
    )

    return fig

In [8]:
def apply_moe_logit_lens_with_active_expert_outputs(model, inputs, tokenizer, num_active_experts=7):
    """Step through the decoder layers by hand and print the strongest experts' top tokens.

    For every layer this prints a ``Layer <n>`` header, then for every item of
    the layer's ``gate_values`` a ``Batch <m>:`` header followed by one line per
    selected expert (highest gate value first, at most ``num_active_experts``)
    showing its gate weight and the 5 largest entries of its output slice.
    Finally the 5 largest ``lm_head`` logits of the last hidden state are printed.
    Nothing is returned.

    Interface this function relies on:
      * ``layer.self_attn(h, h, h)`` returns something that can be added to ``h``;
      * ``layer.mlp(h)`` returns a mapping with ``"gate_values"`` and
        ``"expert_outputs"``, the latter indexable as ``[batch, :, expert]``.
    NOTE(review): whether the loaded checkpoint's layers satisfy this interface
    is not visible from this cell - confirm before use.
    NOTE(review): the residual add uses the layer-normed tensor (not the
    pre-norm input), and the MoE output is only inspected, never fed to the
    next layer. Left as found; verify this is intended.
    """
    model.eval()

    hidden = model.model.embed_tokens(inputs)

    for layer_number, layer in enumerate(model.model.layers, start=1):
        print(f"Layer {layer_number}")

        normed = layer.input_layernorm(hidden)
        hidden = layer.self_attn(normed, normed, normed) + normed
        hidden = layer.post_attention_layernorm(hidden)

        moe_output = layer.mlp(hidden)
        gate_values = moe_output["gate_values"]
        expert_outputs = moe_output["expert_outputs"]

        for batch_idx, gates in enumerate(gate_values):
            # Expert indices ordered by decreasing gate value, truncated to the budget.
            strongest = gates.argsort(descending=True)[:num_active_experts]
            print(f"  Batch {batch_idx + 1}:")
            for expert_idx in strongest:
                expert_weight = gates[expert_idx].item()
                top_scores, top_indices = expert_outputs[batch_idx, :, expert_idx].topk(5, dim=-1)
                decoded_tokens = tokenizer.decode(top_indices.tolist())
                print(f"    Expert {expert_idx}: Weight: {expert_weight:.4f}, Tokens: {decoded_tokens}, Scores: {top_scores.tolist()}")

    final_top = model.lm_head(hidden).topk(5, dim=-1)
    decoded_final_tokens = tokenizer.decode(final_top.indices.tolist())
    print(f"Final Layer: Tokens: {decoded_final_tokens}, Scores: {final_top.values.tolist()}")

In [9]:
def get_device():
    """Choose ``cuda:0`` if CUDA is available, otherwise ``cpu``, and return it.

    On the CUDA path the function also enables TF32 (matmul and cuDNN), clears
    the CUDA allocator cache and switches cuDNN benchmark mode on.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    if use_cuda:
        # TF32 is aimed at Ampere-class GPUs (A100, A6000, ...).
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        # Release cached GPU memory held by the allocator.
        torch.cuda.empty_cache()
        # Allow cuDNN to benchmark kernels for the encountered shapes.
        torch.backends.cudnn.benchmark = True

    return device

In [10]:
def apply_moe_logit_lens_with_active_expert_outputs(model, inputs, tokenizer, num_active_experts=7):
    """Manually step through the decoder layers and print each layer's strongest experts.

    NOTE(review): this cell is a verbatim re-definition of the function of the
    same name defined two cells earlier; one of the two can be dropped.

    Args:
        model: object exposing ``model.embed_tokens``, ``model.layers`` and ``lm_head``.
        inputs: passed straight to ``model.model.embed_tokens``.
        tokenizer: object with ``decode``; receives lists of top-k indices.
        num_active_experts: how many of the highest-gated experts to print per batch item.

    Returns:
        None. All results are printed.
    """
    model.eval()

    # Initial embedding
    x = model.model.embed_tokens(inputs)

    # Process each layer
    for layer_idx, layer in enumerate(model.model.layers):
        print(f"Layer {layer_idx + 1}")
        # x is overwritten with the normalised tensor here, so the "residual"
        # below adds the normalised input, not the pre-norm one.
        # NOTE(review): confirm this is intended.
        x = layer.input_layernorm(x)

        # Self-attention output
        # NOTE(review): the attention module is called with (x, x, x) and its
        # result is added to x, so it must return a tensor-like value - verify
        # against the actual attention module's signature and return type.
        attn_output = layer.self_attn(x, x, x)
        x = attn_output + x  # Residual connection
        x = layer.post_attention_layernorm(x)

        # MoE Layer
        # The MLP result is subscripted with string keys, i.e. it is expected to
        # be a mapping. NOTE(review): confirm the model's MoE block returns one.
        # The MoE result is only inspected below; x is not updated with it.
        moe_output = layer.mlp(x)
        gate_values = moe_output["gate_values"]
        expert_outputs = moe_output["expert_outputs"]

        # Extract active experts
        for batch_idx, gates in enumerate(gate_values):
            # Indices of the experts with the largest gate values, best first.
            active_experts = gates.argsort(descending=True)[:num_active_experts]
            print(f"  Batch {batch_idx + 1}:")
            for expert_idx in active_experts:
                expert_weight = gates[expert_idx].item()
                # Slice for this batch item and expert; its 5 largest entries
                # are decoded as token ids.
                expert_output = expert_outputs[batch_idx, :, expert_idx]
                top_tokens = expert_output.topk(5, dim=-1)
                top_indices = top_tokens.indices
                top_scores = top_tokens.values
                decoded_tokens = tokenizer.decode(top_indices.tolist())
                print(f"    Expert {expert_idx}: Weight: {expert_weight:.4f}, Tokens: {decoded_tokens}, Scores: {top_scores.tolist()}")

    # Final logits
    # lm_head is applied to x directly; no further normalisation happens here.
    logits = model.lm_head(x)
    top_tokens = logits.topk(5, dim=-1)
    decoded_final_tokens = tokenizer.decode(top_tokens.indices.tolist())
    print(f"Final Layer: Tokens: {decoded_final_tokens}, Scores: {top_tokens.values.tolist()}")

In [11]:
def logit_lens(model, input_tokens):
    """Collect the gate-weighted sum of expert outputs per token for every MoE layer.

    Args:
        model: object exposing ``embed_tokens`` and an iterable ``layers``; each
            layer must be callable on the running hidden state and have an
            ``mlp`` attribute. Layers whose ``mlp`` has an ``experts`` attribute
            must also provide ``mlp.gate``.
        input_tokens (torch.Tensor): token ids, shape [batch_size, seq_len].

    Returns:
        dict: ``"layer_{i}_token_{t}_activated_experts"`` -> tensor of shape
        [batch_size, feature_dim], for each MoE layer ``i`` and position ``t``.
        Layers without experts contribute no entries.

    Note:
        Gate and experts are evaluated on the layer's *input*, and the gate
        result is indexed as [batch, token, expert].
        NOTE(review): confirm the real gate returns a dense tensor of that shape.
    """
    collected = {}

    hidden = model.embed_tokens(input_tokens)
    seq_len = input_tokens.shape[1]

    for layer_idx, layer in enumerate(model.layers):
        next_hidden = layer(hidden)

        if hasattr(layer.mlp, "experts"):
            gate_scores = layer.mlp.gate(hidden)

            for token_idx in range(seq_len):
                # Every expert sees the same token slice: [batch_size, embedding_dim].
                per_expert = torch.stack(
                    [expert(hidden[:, token_idx, :]) for expert in layer.mlp.experts],
                    dim=1,
                )  # [batch_size, num_experts, feature_dim]
                gate_column = gate_scores[:, token_idx, :].unsqueeze(-1)  # [batch_size, num_experts, 1]
                collected[f"layer_{layer_idx}_token_{token_idx}_activated_experts"] = torch.sum(
                    per_expert * gate_column, dim=1
                )

        # The layer's own output feeds the next layer.
        hidden = next_hidden

    return collected

In [118]:
def save_prompt_cache(domain: str, data_dict: dict, input_text: str, cache_dir='plots-deepseek') -> int:
    """Persist one analysis result under a per-domain prompt number.

    The index file ``domain_cache_info.json`` maps each domain to a single
    record (``prompt_number``, ``timestamp``, ``input_text``). A domain that is
    already indexed keeps its number, so its pickle is overwritten; a new
    domain gets ``max(existing numbers) + 1`` (1 for an empty index).

    Args:
        domain: index key for this result.
        data_dict: payload pickled to ``prompt_<number>.pkl``.
        input_text: prompt text stored in the index record.
        cache_dir: directory holding the index and the pickles (created if missing).

    Returns:
        The prompt number used for this save.
    """
    os.makedirs(cache_dir, exist_ok=True)
    cache_root = Path(cache_dir)
    index_file = cache_root / 'domain_cache_info.json'

    # Start from the stored index when there is one.
    index = {}
    if index_file.exists():
        with open(index_file, 'r') as fh:
            index = json.load(fh)

    # Reuse the domain's number, or allocate the next free one.
    if domain not in index:
        prompt_number = 1 + max((record['prompt_number'] for record in index.values()), default=0)
    else:
        prompt_number = index[domain]['prompt_number']

    index[domain] = dict(
        prompt_number=prompt_number,
        timestamp=datetime.datetime.now().isoformat(),
        input_text=input_text,
    )

    with open(index_file, 'w') as fh:
        json.dump(index, fh, indent=2)

    with open(cache_root / f'prompt_{prompt_number}.pkl', 'wb') as fh:
        pickle.dump(data_dict, fh)

    return prompt_number

def load_prompt_cache(prompt_number, cache_dir='plots-deepseek'):
    """Load a specific prompt's cached results by number.

    Two on-disk layouts are understood:

    1. Legacy pair ``prompt_<n>_matrix.npy`` + ``prompt_<n>_meta.json`` -
       returned as ``(numpy array, metadata dict)``.
    2. ``prompt_<n>.pkl`` as written by ``save_prompt_cache`` - returned as
       ``(unpickled payload, metadata dict)``. Metadata always contains
       ``prompt_number``; when ``domain_cache_info.json`` has a record for
       this number, that record's fields plus ``domain`` are merged in. Both
       the flat ``{domain: record}`` index and the nested
       ``{"domains": {domain: [records]}, "next_id": n}`` index are handled.

    Args:
        prompt_number: number of the cached prompt.
        cache_dir: directory that holds the cache files.

    Returns:
        ``(data, metadata)`` as described above.

    Raises:
        ValueError: if neither layout has files for ``prompt_number``.
    """
    cache_root = Path(cache_dir)

    matrix_path = cache_root / f'prompt_{prompt_number}_matrix.npy'
    meta_path = cache_root / f'prompt_{prompt_number}_meta.json'
    if matrix_path.exists() and meta_path.exists():
        count_matrix = np.load(matrix_path)
        with open(meta_path, 'r') as f:
            metadata = json.load(f)
        return count_matrix, metadata

    data_path = cache_root / f'prompt_{prompt_number}.pkl'
    if not data_path.exists():
        raise ValueError(f"Prompt {prompt_number} not found in cache")

    # SECURITY: pickle.load can execute arbitrary code; only point cache_dir at
    # directories whose contents this notebook wrote itself.
    with open(data_path, 'rb') as f:
        data = pickle.load(f)

    metadata = {'prompt_number': prompt_number}
    index_path = cache_root / 'domain_cache_info.json'
    if index_path.exists():
        with open(index_path, 'r') as f:
            index = json.load(f)

        if isinstance(index.get('domains'), dict) and 'next_id' in index:
            # Nested layout: several records per domain.
            candidates = [
                (domain, record)
                for domain, records in index['domains'].items()
                for record in records
            ]
        else:
            # Flat layout: one record per domain.
            candidates = list(index.items())

        for domain, record in candidates:
            if isinstance(record, dict) and record.get('prompt_number') == prompt_number:
                metadata = {'domain': domain, **record}
                break

    return data, metadata

def list_cached_prompts(cache_dir='plots-deepseek'):
    """List all cached prompts and their metadata, sorted by prompt number.

    Sources, in this order:

    * every ``*_meta.json`` file in ``cache_dir`` (legacy layout), returned as stored;
    * records in ``domain_cache_info.json`` whose prompt number is not already
      listed, returned as ``{"domain": <domain>, **record}``. Both the flat
      ``{domain: record}`` index and the nested
      ``{"domains": {domain: [records]}, "next_id": n}`` index are understood.

    Args:
        cache_dir: directory that holds the cache files.

    Returns:
        A list of metadata dicts; empty when ``cache_dir`` does not exist.
    """
    if not os.path.isdir(cache_dir):
        return []

    prompts = []
    for name in sorted(os.listdir(cache_dir)):
        if name.endswith('_meta.json'):
            with open(os.path.join(cache_dir, name), 'r') as meta_file:
                prompts.append(json.load(meta_file))

    index_path = os.path.join(cache_dir, 'domain_cache_info.json')
    if os.path.exists(index_path):
        with open(index_path, 'r') as index_file:
            index = json.load(index_file)

        if isinstance(index.get('domains'), dict) and 'next_id' in index:
            candidates = [
                (domain, record)
                for domain, records in index['domains'].items()
                for record in records
            ]
        else:
            candidates = list(index.items())

        seen = {meta['prompt_number'] for meta in prompts}
        for domain, record in candidates:
            if not isinstance(record, dict) or 'prompt_number' not in record:
                continue
            if record['prompt_number'] in seen:
                continue
            seen.add(record['prompt_number'])
            prompts.append({'domain': domain, **record})

    return sorted(prompts, key=lambda x: x['prompt_number'])

def plot_cached_prompt(prompt_number, plot_type='heatmap', layer_id=None, tokenizer=None, cache_dir='plots-deepseek'):
    """Plot results from a cached prompt.

    Loads the cached payload and its metadata with
    ``load_prompt_cache(prompt_number, cache_dir)``, reads ``metadata['domain']``
    and dispatches on ``plot_type``.

    Args:
        prompt_number: number of the cached prompt to load.
        plot_type: ``'heatmap'`` or ``'logit_lens'``.
        layer_id: must be given for ``'logit_lens'``.
        tokenizer: forwarded to the plotting function.
        cache_dir: directory that holds the cache files.

    Returns:
        Whatever the selected plotting function returns.

    Raises:
        ValueError: for an unknown ``plot_type``, or for ``'logit_lens'``
            without ``layer_id``.
        KeyError: if the loaded metadata has no ``'domain'`` entry.

    NOTE(review): the keyword arguments used below do not match the plotting
    functions defined in this notebook. ``visualize_layer_analysis`` takes
    ``(tokenizer, results, token_position, input_text)`` and
    ``plot_enhanced_logit_lens`` takes
    ``(model_outputs, tokenizer, input_text, position)``. Neither accepts
    ``domain`` or ``layer_id``, so both calls raise TypeError as written.
    NOTE(review): a later cell redefines ``load_prompt_cache`` as
    ``(domain, input_text, cache_dir)``. After that cell runs, the positional
    call here binds ``prompt_number`` to ``domain`` and ``cache_dir`` to
    ``input_text``. Decide which cache API this helper should target.
    """
    count_matrix, metadata = load_prompt_cache(prompt_number, cache_dir)
    domain = metadata['domain']
    
    if plot_type == 'heatmap':
        # NOTE(review): no token position / input text is supplied here - see docstring.
        fig = visualize_layer_analysis(
            tokenizer,
            count_matrix,
            domain=domain
        )
    elif plot_type == 'logit_lens':
        if layer_id is None:
            raise ValueError("layer_id is required for logit lens plots")
        # NOTE(review): keyword names differ from the plotting function's parameters - see docstring.
        fig = plot_enhanced_logit_lens(
            count_matrix,
            tokenizer,
            domain=domain,
            layer_id=layer_id
        )
    else:
        raise ValueError(f"Unknown plot type: {plot_type}")
        
    return fig

In [None]:
def migrate_cache_structure(old_cache_info):
    """Convert a flat ``{domain: record}`` cache index to the nested format.

    The nested format is ``{"domains": {domain: [record, ...]}, "next_id": n}``.
    Every dict-valued entry of the old index becomes a one-element list under
    its domain. Missing fields are filled in: ``prompt_number`` with the
    current ``next_id``, ``timestamp`` with the current time, ``input_text``
    with ``'No text stored'``.

    Args:
        old_cache_info: the flat index as loaded from JSON. Keys named
            ``'domains'`` / ``'next_id'`` and non-dict values are ignored.

    Returns:
        A new dict in the nested format; ``next_id`` is one more than the
        largest migrated prompt number (1 when nothing was migrated).
    """
    new_cache = {
        'domains': {},
        'next_id': 1
    }

    for domain, info in old_cache_info.items():
        if domain in ('domains', 'next_id') or not isinstance(info, dict):
            continue

        # Resolve the number once so that the stored record and the next_id
        # bookkeeping agree, even when the old record had no prompt_number.
        prompt_number = info.get('prompt_number', new_cache['next_id'])
        new_cache['domains'][domain] = [{
            'prompt_number': prompt_number,
            'timestamp': info.get('timestamp', datetime.datetime.now().isoformat()),
            'input_text': info.get('input_text', 'No text stored')
        }]
        new_cache['next_id'] = max(new_cache['next_id'], prompt_number + 1)

    return new_cache

In [None]:
def save_prompt_cache(domain: str, data: dict, cache_dir='plots-deepseek') -> int:
    """Append one analysis result to the nested, domain-based cache.

    The index ``domain_cache_info.json`` has the shape
    ``{"domains": {domain: [record, ...]}, "next_id": n}``. Each call takes the
    current ``next_id`` as prompt number, appends a record (``prompt_number``,
    ``timestamp``, ``input_text`` taken from ``data``) to the domain's list,
    bumps ``next_id`` and pickles ``data`` to ``prompt_<number>.pkl``.

    An index in the older flat format is converted with
    ``migrate_cache_structure``. If the index cannot be read at all, a message
    is printed and a fresh, empty index is used.

    Args:
        domain: domain the record is filed under.
        data: payload to pickle; ``data.get('input_text')`` is copied to the index.
        cache_dir: cache directory (created if missing).

    Returns:
        The prompt number assigned to this save.
    """
    def _empty_index():
        return {'domains': {}, 'next_id': 1}

    os.makedirs(cache_dir, exist_ok=True)
    cache_root = Path(cache_dir)
    index_file = cache_root / 'domain_cache_info.json'

    if not index_file.exists():
        index = _empty_index()
    else:
        try:
            with open(index_file, 'r') as fh:
                index = json.load(fh)
                # Older flat indexes are upgraded on the fly.
                if 'domains' not in index:
                    index = migrate_cache_structure(index)
        except Exception as e:
            print(f"Error reading cache file, creating new: {e}")
            index = _empty_index()

    prompt_number = index['next_id']

    domain_records = index['domains'].setdefault(domain, [])
    domain_records.append({
        'prompt_number': prompt_number,
        'timestamp': datetime.datetime.now().isoformat(),
        'input_text': data.get('input_text', 'No text stored')
    })
    index['next_id'] = prompt_number + 1

    with open(index_file, 'w') as fh:
        json.dump(index, fh, indent=2)

    with open(cache_root / f'prompt_{prompt_number}.pkl', 'wb') as fh:
        pickle.dump(data, fh)

    return prompt_number


In [None]:

def load_prompt_cache(domain: str, input_text: str, cache_dir='plots-deepseek'):
    """Load the most recent cached result for a domain and exact input text.

    Reads ``domain_cache_info.json`` (converting and rewriting it when it is
    still in the flat format), picks the records of ``domain`` whose
    ``input_text`` equals ``input_text``, and unpickles the payload of the one
    with the greatest ``timestamp``. Records without an ``input_text`` never
    match; a record without a ``timestamp`` sorts before all others.

    Args:
        domain: domain to look in.
        input_text: prompt text that must match exactly.
        cache_dir: cache directory.

    Returns:
        ``(data_dict, entry)`` - the unpickled payload and its index record.

    Raises:
        ValueError: if the index is missing or unreadable, the domain or text
            is not cached, or the payload file is missing.
    """
    cache_info_path = Path(cache_dir) / 'domain_cache_info.json'

    if not cache_info_path.exists():
        raise ValueError(f"No cache found in {cache_dir}")

    try:
        with open(cache_info_path, 'r') as f:
            cache_info = json.load(f)
        # The read handle is closed before the file is rewritten.
        if 'domains' not in cache_info:
            cache_info = migrate_cache_structure(cache_info)
            with open(cache_info_path, 'w') as f:
                json.dump(cache_info, f, indent=2)
    except Exception as e:
        raise ValueError(f"Error reading cache file: {e}") from e

    if domain not in cache_info['domains']:
        raise ValueError(f"No cache found for domain '{domain}'")

    matching_entries = [
        entry for entry in cache_info['domains'][domain]
        if entry.get('input_text') == input_text
    ]

    if not matching_entries:
        raise ValueError(f"No matching cache entry found for text: {input_text}")

    # ISO-8601 timestamps compare chronologically as plain strings.
    latest_entry = max(matching_entries, key=lambda x: x.get('timestamp', ''))
    prompt_number = latest_entry['prompt_number']

    data_path = Path(cache_dir) / f'prompt_{prompt_number}.pkl'
    if not data_path.exists():
        raise ValueError(f"Cache data file missing for prompt #{prompt_number}")

    # SECURITY: pickle.load can execute arbitrary code; only use cache
    # directories whose contents this notebook wrote itself.
    with open(data_path, 'rb') as f:
        data_dict = pickle.load(f)

    return data_dict, latest_entry


In [None]:

def list_cached_prompts(cache_dir='plots-deepseek'):
    """Return one summary dict per cached prompt, ordered by prompt number.

    Reads ``domain_cache_info.json`` from ``cache_dir`` (a flat-format index is
    converted in memory with ``migrate_cache_structure``; the file itself is
    not rewritten). Each summary has ``domain``, ``prompt_number``,
    ``timestamp`` and ``input_text`` (``'No text stored'`` when absent).

    Returns an empty list when the index file does not exist or when reading /
    converting it raises any exception.
    """
    index_file = Path(cache_dir) / 'domain_cache_info.json'
    if not index_file.exists():
        return []

    try:
        with open(index_file, 'r') as fh:
            index = json.load(fh)
            if 'domains' not in index:
                index = migrate_cache_structure(index)
    except Exception:
        # Best effort: an unreadable index is treated like an empty cache.
        return []

    summaries = [
        {
            'domain': domain,
            'prompt_number': record['prompt_number'],
            'timestamp': record['timestamp'],
            'input_text': record.get('input_text', 'No text stored')
        }
        for domain, records in index.get('domains', {}).items()
        for record in records
    ]
    summaries.sort(key=lambda summary: summary['prompt_number'])
    return summaries


In [None]:
def _build_analysis_figures(results, tokenizer, text, token_position, plot_type):
    """Create the figures selected by ``plot_type`` from one analysis result dict."""
    figures = {}

    if plot_type in ['heatmap', 'both']:
        figures['heatmap'] = visualize_layer_analysis(
            tokenizer=tokenizer,
            results=results,
            token_position=token_position,
            input_text=text
        )

    if plot_type in ['logit_lens', 'both']:
        figures['logit_lens'] = plot_enhanced_logit_lens(
            model_outputs=results,
            tokenizer=tokenizer,
            input_text=text,
            position=token_position
        )

    if plot_type in ['expert', 'both']:
        figures['expert'] = plot_expert_contributions(
            results,
            position=token_position
        )

    return figures


def analyze_dataset(
    text: str,
    model,
    tokenizer,
    token_position: Optional[int] = None,
    domain: str = 'default',
    plot_type: str = 'both',
    force_recompute: bool = False,
    cache_dir: str = 'plots-deepseek'
):
    """
    Analyze text through the model and create visualizations.

    Cached results for ``(domain, text)`` are reused unless ``force_recompute``
    is set. Otherwise the text is run through a ``MOELens``, the figures are
    built and the results are cached with ``save_prompt_cache``.

    Args:
        text: prompt to analyse.
        model: model handed to ``MOELens``; its ``device`` receives the input ids.
        tokenizer: tokenizer for the model.
        token_position: token index to visualise. The heatmap indexes the token
            list with it, so an integer is needed for ``'heatmap'`` / ``'both'``.
        domain: cache domain.
        plot_type: ``'heatmap'``, ``'logit_lens'``, ``'expert'`` or ``'both'``
            (``'both'`` produces all three figures).
        force_recompute: skip the cache lookup.
        cache_dir: cache directory.

    Returns:
        ``(figures, results, None)`` on success, ``(None, None, message)`` when
        anything raised; the traceback is printed in that case.

    Only the cache *lookup* is treated as best-effort. If the cached data
    loads but plotting it fails, the error is reported instead of falling
    through to a full recompute.
    """
    lens = None
    try:
        if not force_recompute:
            cached_results = None
            try:
                cached_data, cache_info = load_prompt_cache(domain, text, cache_dir)
                print(f"Using cached results from prompt #{cache_info['prompt_number']}")
                cached_results = cached_data['results']
            except ValueError as e:
                print(f"Cache miss: {str(e)}")
            except Exception as e:
                print(f"Error loading cache: {str(e)}")
            else:
                figures = _build_analysis_figures(
                    cached_results, tokenizer, text, token_position, plot_type
                )
                return figures, cached_results, None

        # Registering the lens attaches forward hooks; they are removed in `finally`.
        lens = MOELens(model, tokenizer)

        input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
        results = lens.analyze_text(input_ids)

        figures = _build_analysis_figures(results, tokenizer, text, token_position, plot_type)

        cache_data = {
            'results': results,
            'input_text': text
        }
        prompt_number = save_prompt_cache(domain, cache_data, cache_dir)
        print(f"Results cached as prompt #{prompt_number}")

        return figures, results, None

    except Exception as e:
        print(f"Error in analyze_dataset: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None, str(e)

    finally:
        if lens is not None:
            lens.remove_hooks()

In [None]:
def plot_enhanced_logit_lens(model_outputs, tokenizer, input_text, position):
    """
    Build a layer-by-token heatmap of expert-weighted token scores for one position.

    For every ``'layer_<n>'`` entry of ``model_outputs`` (visited in ascending
    numeric order) the token record whose ``'position'`` equals ``position`` is
    looked up. Each expert output of that record contributes
    ``prob * weight`` to the score of every token in its ``'top_tokens'``.
    Layers without a ``'tokens'`` key, or without a record at ``position``,
    are left out of the plot.

    Args:
        model_outputs: Mapping with ``'layer_<n>'`` keys; other keys are ignored.
        tokenizer: Accepted for interface compatibility; not used in this function.
        input_text: Accepted for interface compatibility; not used in this function.
        position: Token position to plot.

    Returns:
        A plotly ``Figure`` holding one heatmap (rows = layers, columns = tokens).
    """
    def _layer_number(key):
        # 'layer_12' -> 12
        return int(key.split('_')[1])

    # Visit layers in numeric rather than lexicographic order
    ordered_keys = [key for key in model_outputs.keys() if key.startswith('layer_')]
    ordered_keys.sort(key=_layer_number)

    weighted_scores = {}  # layer number -> {token: accumulated prob * weight}
    seen_tokens = set()
    kept_layers = []

    for key in ordered_keys:
        layer_entry = model_outputs[key]
        if 'tokens' not in layer_entry:
            continue

        # First token record located at the requested position, if any
        match = next(
            (entry for entry in layer_entry['tokens'].values()
             if entry['position'] == position),
            None,
        )
        if match is None:
            continue

        layer_no = _layer_number(key)
        kept_layers.append(layer_no)

        # Accumulate each expert's predictions, scaled by that expert's weight
        for expert in match['expert_outputs']:
            gate = expert['weight']
            for tok, tok_prob in expert['top_tokens']:
                seen_tokens.add(tok)
                bucket = weighted_scores.setdefault(layer_no, {})
                bucket[tok] = bucket.get(tok, 0.0) + tok_prob * gate

    # Sorted token axis keeps the column order deterministic
    token_axis = sorted(seen_tokens)

    # One row per kept layer; tokens a layer never predicted score 0.0
    score_rows = [
        [weighted_scores.get(layer_no, {}).get(tok, 0.0) for tok in token_axis]
        for layer_no in kept_layers
    ]

    heatmap = go.Heatmap(
        z=score_rows,
        x=token_axis,
        y=[f'Layer {layer_no}' for layer_no in kept_layers],
        colorscale='Viridis',
        showscale=True
    )
    fig = go.Figure(data=heatmap)

    fig.update_layout(
        title=f'Token Probabilities Across Layers - Position {position}',
        xaxis_title='Predicted Tokens',
        yaxis_title='Layer',
        yaxis={'autorange': 'reversed'},  # reversed axis: intended to show the earliest layer at the top
        width=1200,
        height=800,
        plot_bgcolor='black',
        paper_bgcolor='black',
        font={'color': 'white'}
    )

    # Same faint grid on both axes
    grid_style = {
        'showgrid': True,
        'gridwidth': 1,
        'gridcolor': 'rgba(128, 128, 128, 0.2)',
    }
    fig.update_xaxes(**grid_style)
    fig.update_yaxes(**grid_style)

    return fig

In [130]:
def plot_expert_contributions(model_outputs, position=None):
    """
    Creates detailed visualization of expert contributions with token predictions at each layer.
    Shows top token and expert weight in cells, all tokens on hover.

    Args:
        model_outputs: Mapping with ``'layer_<n>'`` keys (other keys are ignored).
            Each layer entry is read through ``["tokens"]``, whose values carry
            ``"position"`` and ``"expert_outputs"``; every expert output supplies
            ``"weight"``, ``"expert_id"`` and ``"top_tokens"`` (a sequence of
            ``(token, probability)`` pairs).
        position: Token position to visualize. When ``None`` the first token
            record of each layer is used.

    Returns:
        A plotly ``Figure`` holding one heatmap (rows = layers, columns = experts).

    Layers that have no ``"tokens"`` entry, or no token record for the requested
    position, are skipped instead of raising, matching the behaviour of
    ``plot_enhanced_logit_lens``.
    """
    cell_size = 50  # Base cell size in pixels

    layers = []
    seen_expert_ids = set()
    expert_tokens = {}  # (layer number, expert id) -> {'weight', 'tokens'}

    # Sort layer keys numerically
    layer_keys = sorted(
        (k for k in model_outputs.keys() if k.startswith('layer_')),
        key=lambda x: int(x.split('_')[1]),
    )

    for layer_key in layer_keys:
        layer_num = int(layer_key.split('_')[1])
        layer_data = model_outputs[layer_key]
        if "tokens" not in layer_data:
            continue

        token_records = layer_data["tokens"].values()
        if position is not None:
            token_data = next(
                (data for data in token_records if data["position"] == position),
                None,
            )
        else:
            token_data = next(iter(token_records), None)

        # Nothing recorded for this layer at the requested position
        if token_data is None:
            continue

        layers.append(layer_num)

        # Record weight and token predictions for every expert of this layer
        for exp_output in token_data["expert_outputs"]:
            expert_id = exp_output["expert_id"]
            seen_expert_ids.add(expert_id)
            expert_tokens[(layer_num, expert_id)] = {
                'weight': exp_output["weight"],
                'tokens': exp_output["top_tokens"]
            }

    expert_ids = sorted(seen_expert_ids)

    # Create matrices for weights, hover text, and display text
    weight_matrix = np.zeros((len(layers), len(expert_ids)))
    hover_text = [['' for _ in expert_ids] for _ in layers]
    text_matrix = [['' for _ in expert_ids] for _ in layers]

    for i, layer in enumerate(layers):
        for j, expert_id in enumerate(expert_ids):
            info = expert_tokens.get((layer, expert_id))
            # Cells stay empty (weight 0) when the expert did not appear in
            # this layer or reported no token predictions
            if not info or not info['tokens']:
                continue

            weight = info['weight']
            weight_matrix[i, j] = weight

            # Get the top token
            top_token, _ = info['tokens'][0]

            # Create cell text: show token first, centered, then weight below
            # Using HTML for better text control
            text_matrix[i][j] = f'<span style="text-align: center">{top_token}<br>{weight:.3f}</span>'

            # Create hover text with all tokens and their probabilities
            hover_lines = [
                f"Layer: {layer}",
                f"Expert: {expert_id}",
                f"Weight: {weight:.3f}",
                "Top tokens:"
            ]
            hover_lines.extend(f"{token}: {prob:.3f}" for token, prob in info['tokens'])
            hover_text[i][j] = "<br>".join(hover_lines)

    # Calculate dimensions to make cells square. Keep at least one cell per
    # side so an empty result still produces non-zero figure dimensions.
    width = max(cell_size * len(expert_ids), cell_size)
    height = max(cell_size * len(layers), cell_size)

    expert_labels = [f"Expert {expert_id}" for expert_id in expert_ids]

    # Create figure
    fig = go.Figure(data=go.Heatmap(
        z=weight_matrix,
        x=expert_labels,
        y=[f"Layer {layer}" for layer in layers],
        colorscale=[
            [0, 'rgba(24, 21, 23, 0.8)'],
            [0.0001, 'rgb(68,1,84)'],
            [1, 'rgb(253,231,37)']
        ],
        text=text_matrix,
        texttemplate="%{text}",
        textfont={"size": 9, "color": "white", "family": "monospace"},
        hoverongaps=False,
        hoverinfo="text",
        hovertext=hover_text,
        showscale=True,
    ))

    # `is not None` so that position 0 is still mentioned in the title
    title = "Expert Contributions by Layer"
    if position is not None:
        title += f" - Position {position}"

    # Update layout with equal width and height
    fig.update_layout(
        title=title,
        xaxis_title="Experts",
        yaxis_title="Layers",
        height=height,
        width=width,
        plot_bgcolor='black',
        paper_bgcolor='black',
        font=dict(color='white'),
        xaxis=dict(
            tickangle=45,  # Angle the expert labels
            tickfont=dict(size=9),
            tickmode='array',
            ticktext=expert_labels,
            tickvals=list(range(len(expert_ids))),
            dtick=1,
            showgrid=True,
            gridwidth=1,
            gridcolor='rgba(128, 128, 128, 0.2)',
            showline=True,
            linewidth=1,
            linecolor='rgba(128, 128, 128, 0.2)',
            scaleanchor="y",  # This makes the x and y scales equal
            scaleratio=1
        ),
        yaxis=dict(
            autorange='reversed',  # Keep layer 1 at top
            tickfont=dict(size=9),
            dtick=1,
            showgrid=True,
            gridwidth=1,
            gridcolor='rgba(128, 128, 128, 0.2)',
            showline=True,
            linewidth=1,
            linecolor='rgba(128, 128, 128, 0.2)',
            scaleanchor="x",  # This makes the x and y scales equal
            scaleratio=1
        ),
        margin=dict(l=50, r=50, t=100, b=100)
    )

    return fig

----

In [136]:
# First run - analyze the prompt; the run's output below reports the results being cached
text = "the quick brown fox"
# The third element returned is the error message slot (None on success); it is discarded here
figures, results, _ = analyze_dataset(
    text=text,
    model=model,
    tokenizer=tokenizer,
    token_position=4,  # meant to visualize "quick" - NOTE(review): confirm the index against the tokenized input
    domain="test1",
    force_recompute=False,
    plot_type='expert'
)

# Show whichever figures were produced, always in the same order
# ('heatmap' is kinda useless but kept for the time being)
if figures:
    for figure_kind in ('heatmap', 'logit_lens', 'expert'):
        if figure_kind in figures:
            figures[figure_kind].show()

Cache miss: No cache found in plots-deepseek
Results cached as prompt #1
