In [1]:
import torch
import numpy as np
import plotly.graph_objects as go
from collections import defaultdict
import torch.nn.functional as F
import json
import os
from typing import Dict, Tuple, List, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import math
from pathlib import Path
import datetime
from pathlib import Path
import pickle

In [2]:
def get_device():
    """Pick the torch device to run on, preferring the first CUDA GPU.

    When CUDA is available this also flips process-wide backend switches
    (TF32 for matmul and cuDNN, cuDNN benchmarking) and empties the CUDA
    caching allocator before returning ``cuda:0``.  Without CUDA it returns
    the CPU device and leaves every global setting untouched.

    Returns:
        torch.device: ``cuda:0`` if CUDA is available, otherwise ``cpu``.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")

    # TF32 trades a little matmul/conv precision for speed on Ampere-class GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Release cached, currently-unused allocator blocks.
    torch.cuda.empty_cache()
    # Let cuDNN time candidate kernels and reuse the fastest one.
    torch.backends.cudnn.benchmark = True
    return torch.device("cuda:0")

In [3]:
# Load the DeepSeek-MoE 16B base checkpoint with half-precision weights.
# NOTE(review): trust_remote_code=True lets the checkpoint repository's own
# modeling code run locally -- only use it with a source you trust.
# NOTE(review): the model is not moved to a device in this cell, and
# get_device() is not called here -- confirm where placement is meant to happen.
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/deepseek-moe-16b-base",
    trust_remote_code=True,
    torch_dtype=torch.float16
)

# Tokenizer for the same checkpoint (the identifier must match the model's).
tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/deepseek-moe-16b-base",
    trust_remote_code=True
)

# Sanity-check the MoE layout reported by the loaded config: shared experts,
# routed experts, and how many routed experts each token is sent to.
print(f"Number of shared experts: {model.config.n_shared_experts}")
print(f"Number of routed experts: {model.config.n_routed_experts}")
print(f"Number of experts per token: {model.config.num_experts_per_tok}")

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

Number of shared experts: 2
Number of routed experts: 64
Number of experts per token: 6


----

In [67]:
class MOELens:
    """Logit-lens probe for a DeepSeek-style mixture-of-experts causal LM.

    Forward hooks are attached to every decoder layer whose ``mlp`` exposes a
    ``gate`` (router) and/or a ``shared_experts`` submodule.  While the model
    runs, the hooks push each selected routed expert's output, and each half
    of the fused shared-expert MLP, through ``model.lm_head`` and keep the
    ``top_k`` vocabulary entries per position.

    Notes:
        * ``probs`` are a softmax over the kept ``top_k`` logits only, not
          over the whole vocabulary.
        * Expert outputs are projected with ``lm_head`` directly; no final
          normalisation layer is applied first.
        * ``analyze_text`` reads batch element 0 only, so pass one sequence.
        * The shared-expert hook assumes the fused MLP holds exactly two
          equally sized experts (2816 = 2 * 1408 for the 16B config).

    Args:
        model: Causal LM exposing ``model.layers[i].mlp`` and ``lm_head``.
        tokenizer: Object providing ``decode(list_of_token_ids) -> str``.
        top_k: How many vocabulary entries to keep per expert and position.
    """

    def __init__(self, model, tokenizer, top_k: int = 5):
        self.model = model
        self.tokenizer = tokenizer
        self.top_k = top_k
        # layer key -> captured router / expert data for the latest forward pass
        self.activations = defaultdict(dict)
        # handles of the permanent hooks installed by setup_hooks()
        self.hook_handles = []
        self.setup_hooks()

    def _top_tokens(self, logits):
        """Keep the ``top_k`` strongest logits on the last axis and renormalise them."""
        k = min(self.top_k, logits.shape[-1])
        top = torch.topk(logits, k=k, dim=-1)
        return {
            'token_ids': top.indices,
            'probs': torch.softmax(top.values, dim=-1)
        }

    def _decode_top_tokens(self, token_ids, probs):
        """Pair every kept token id with its decoded text and probability."""
        return [
            (self.tokenizer.decode([tid.item()]), prob.item())
            for tid, prob in zip(token_ids, probs)
        ]

    def setup_hooks(self):
        """Install the router and shared-expert forward hooks on every MoE layer."""

        def get_gate_hook(layer_idx):
            def hook(module, inp, out):
                # Only routers that return (topk_idx, topk_weight, aux) are handled.
                if not isinstance(out, tuple):
                    return
                topk_idx, topk_weight, _ = out
                hidden_states = inp[0]

                expert_outputs = {}
                mlp = self.model.model.layers[layer_idx].mlp
                if hasattr(mlp, 'experts'):
                    flat_hidden = hidden_states.view(-1, hidden_states.shape[-1])
                    with torch.no_grad():
                        # Only run the experts the router actually selected.
                        for expert_idx in torch.unique(topk_idx).tolist():
                            expert_output = mlp.experts[expert_idx](flat_hidden)
                            expert_outputs[expert_idx] = self._top_tokens(
                                self.model.lm_head(expert_output)
                            )

                # Replaces the layer entry; the shared-expert hook is expected
                # to fire afterwards and add its own key to this dict.
                self.activations[f'layer_{layer_idx}'] = {
                    'router_weights': topk_weight.detach(),
                    'router_indices': topk_idx.detach(),
                    'expert_outputs': expert_outputs
                }
            return hook

        def get_shared_expert_hook(layer_idx):
            def hook(module, inp, out):
                x = inp[0]
                with torch.no_grad():
                    act = module.act_fn(module.gate_proj(x)) * module.up_proj(x)

                    # Split the fused activation into its two experts.  Each
                    # half is zero-padded back to full width so the shared
                    # down_proj can be reused for one expert at a time.
                    width = act.shape[-1]
                    half = width // 2
                    expert0_out = module.down_proj(
                        F.pad(act[..., :half], (0, width - half))
                    )
                    expert1_out = module.down_proj(
                        F.pad(act[..., half:], (half, 0))
                    )

                    expert0 = self._top_tokens(self.model.lm_head(expert0_out))
                    expert1 = self._top_tokens(self.model.lm_head(expert1_out))

                # Shared experts are always active, hence the fixed weight.
                expert0['weight'] = 1.0
                expert1['weight'] = 1.0
                self.activations[f'layer_{layer_idx}']['shared_experts'] = {
                    'expert0': expert0,
                    'expert1': expert1
                }
            return hook

        for i, layer in enumerate(self.model.model.layers):
            if hasattr(layer.mlp, 'gate'):
                handle = layer.mlp.gate.register_forward_hook(get_gate_hook(i))
                self.hook_handles.append(handle)
            if hasattr(layer.mlp, 'shared_experts'):
                handle = layer.mlp.shared_experts.register_forward_hook(get_shared_expert_hook(i))
                self.hook_handles.append(handle)

    def analyze_text(self, input_ids: torch.Tensor) -> dict:
        """Run one forward pass and collect per-layer, per-token expert predictions.

        Args:
            input_ids: Token ids of shape ``[1, seq_len]``.

        Returns:
            dict: ``layer_<i>`` entries (raw router tensors plus a ``tokens``
            mapping of ``token_<pos>`` -> expert predictions) and a
            ``final_predictions`` mapping of position -> top tokens of the
            model's own output logits.
        """
        self.activations.clear()
        final_logits = None

        def capture_lm_head_output(module, inp, out):
            # Also fires for the hooks' own lm_head calls; the value left
            # after the pass is whatever lm_head produced last.
            nonlocal final_logits
            final_logits = out

        # Temporary hook: removed again even if the forward pass fails, so
        # repeated calls do not accumulate hooks on lm_head.
        final_hook = self.model.lm_head.register_forward_hook(capture_lm_head_output)
        try:
            with torch.no_grad():
                self.model(input_ids)
        finally:
            final_hook.remove()

        seq_len = input_ids.shape[1]

        for layer_key, layer_data in self.activations.items():
            if not layer_key.startswith('layer_'):
                continue

            tokens_data = {}
            for pos in range(seq_len):
                expert_outputs = []

                # Routed experts, in the order the router returned them.
                if 'expert_outputs' in layer_data:
                    router_indices = layer_data['router_indices'][pos]
                    router_weights = layer_data['router_weights'][pos]

                    for slot, expert_idx in enumerate(router_indices.tolist()):
                        expert_data = layer_data['expert_outputs'].get(expert_idx)
                        if expert_data is None:
                            continue
                        expert_outputs.append({
                            "expert_id": expert_idx,
                            "weight": router_weights[slot].item(),
                            "top_tokens": self._decode_top_tokens(
                                expert_data['token_ids'][pos],
                                expert_data['probs'][pos]
                            ),
                            "expert_type": "routed"
                        })

                # Shared experts (tensors keep their batch axis, hence [0, pos]).
                if 'shared_experts' in layer_data:
                    for expert_key in ('expert0', 'expert1'):
                        expert_data = layer_data['shared_experts'][expert_key]
                        expert_outputs.append({
                            'expert_id': f'S{expert_key[-1]}',
                            'weight': expert_data['weight'],
                            'top_tokens': self._decode_top_tokens(
                                expert_data['token_ids'][0, pos],
                                expert_data['probs'][0, pos]
                            ),
                            'expert_type': 'shared'
                        })

                tokens_data[f"token_{pos}"] = {
                    "position": pos,
                    "expert_outputs": expert_outputs
                }

            layer_data["tokens"] = tokens_data

        if final_logits is not None:
            final_predictions = {}
            for pos in range(seq_len):
                top = self._top_tokens(final_logits[0, pos])
                final_predictions[pos] = {
                    'token_ids': top['token_ids'],
                    'probs': top['probs'],
                    'tokens': [self.tokenizer.decode([idx.item()]) for idx in top['token_ids']]
                }
            self.activations['final_predictions'] = final_predictions

        return dict(self.activations)

    def remove_hooks(self):
        """Detach every hook installed by ``setup_hooks``."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles = []

Fixed for both shared experts and routed experts.

In [69]:
def visualize_layer_analysis(tokenizer, results: Dict, token_position: int, input_text: str,
                             n_routed_experts: int = 6):
    """Scatter-plot the experts that handled one token in every MoE layer.

    Each layer contributes one marker per routed-expert slot (``R0``,
    ``R1``, ... in the order the experts appear in the token's
    ``expert_outputs``) and one per shared expert (``S0``/``S1``).  The
    hover text carries the real expert id, its router weight and its top
    token predictions.  If ``results`` has a ``final_hidden`` entry it is
    drawn as an extra ``H`` row.

    Args:
        tokenizer: Provides ``encode(text)`` and ``decode(list_of_ids)``.
        results: Mapping with ``layer_<i>`` entries whose ``tokens`` values
            hold ``position`` and ``expert_outputs``.
        token_position: Index of the token to visualise.
        input_text: Text that was analysed; used only to label the plot.
        n_routed_experts: Minimum number of routed slots drawn per layer;
            unused slots get a placeholder marker without hover text.

    Returns:
        plotly.graph_objects.Figure: The assembled scatter figure.

    NOTE(review): marker colours are fixed rgba strings, so the 'Weight'
    colorbar requested below has no numeric data behind it -- confirm
    whether the router weight was meant to drive the colour.
    """
    routed_color = 'rgba(253,231,37,0.8)'  # yellow: routed experts and hidden row
    shared_color = 'rgba(0,176,246,0.8)'   # blue: shared experts

    def format_top_tokens(top_tokens):
        return "<br>".join(f"{tok}: {prob:.3f}" for tok, prob in top_tokens[:5])

    layer_nums = []
    expert_ids = []
    hover_texts = []
    colors = []

    # Extract the token being visualised.
    tokens = tokenizer.encode(input_text)
    token = tokenizer.decode([tokens[token_position]])
    print(f"Visualizing token: {token}")

    # Visit every numbered layer present in the results, then the hidden row.
    prefix = 'layer_'
    numbered_layers = sorted(
        int(key[len(prefix):]) for key in results
        if isinstance(key, str) and key.startswith(prefix) and key[len(prefix):].isdigit()
    )
    layers = numbered_layers + ['hidden']

    n_slots = n_routed_experts  # widest routed row seen so far (for the x ticks)

    for layer_idx in layers:
        layer_data = results.get(f"layer_{layer_idx}", {})
        if not layer_data:
            if layer_idx == 'hidden':
                hidden_data = results.get("final_hidden", {})
                if hidden_data:
                    weight = hidden_data.get('weight', 0)
                    layer_nums.append('hidden')
                    expert_ids.append("H")
                    colors.append(routed_color)
                    hover_texts.append(
                        f"Final Hidden State<br>Weight: {weight:.3f}<br>Top tokens:<br>"
                        f"{format_top_tokens(hidden_data['top_tokens'])}"
                    )
            continue

        token_data = next(
            (data for data in layer_data.get("tokens", {}).values()
             if data["position"] == token_position),
            None
        )
        if not token_data:
            continue

        # Routed experts are plotted by slot.  Their ids range over all routed
        # experts, so indexing the columns by id would hide most of them.
        routed = [
            exp for exp in token_data["expert_outputs"]
            if exp.get("expert_type") == "routed"
        ]
        row_width = max(n_routed_experts, len(routed))
        n_slots = max(n_slots, row_width)

        for slot in range(row_width):
            layer_nums.append(layer_idx)
            expert_ids.append(f"R{slot}")
            colors.append(routed_color)
            if slot < len(routed):
                exp = routed[slot]
                hover_texts.append(
                    f"Layer: {layer_idx}<br>Routed Expert: {exp['expert_id']}<br>"
                    f"Weight: {exp['weight']:.3f}<br>Top tokens:<br>"
                    f"{format_top_tokens(exp['top_tokens'])}"
                )
            else:
                # Placeholder so every layer shows the same number of slots.
                hover_texts.append(None)

        # Shared experts
        for exp in token_data["expert_outputs"]:
            if exp.get("expert_type") != "shared":
                continue
            expert_id = exp['expert_id']  # S0 or S1
            layer_nums.append(layer_idx)
            expert_ids.append(expert_id)
            colors.append(shared_color)
            hover_texts.append(
                f"Layer: {layer_idx}<br>Shared Expert: {expert_id}<br>"
                f"Weight: {exp['weight']:.3f}<br>Top tokens:<br>"
                f"{format_top_tokens(exp['top_tokens'])}"
            )

    # Create plotly scatter plot
    fig = go.Figure(data=go.Scatter(
        x=expert_ids,
        y=layer_nums,
        mode='markers',
        marker=dict(
            size=9,
            color=colors,
            showscale=True,
            colorbar=dict(
                title='Weight',
                tickmode='linear',
                tick0=0,
                dtick=0.2
            ),
        ),
        text=hover_texts,
        hoverinfo='text',
        hovertemplate='%{text}<extra></extra>',
    ))

    # Update layout
    fig.update_layout(
        template='plotly_dark',
        title=f'Expert Activations for Token "{token}" at Position {token_position}',
        xaxis_title='Expert ID (R: Routed, S: Shared, H: Hidden)',
        yaxis_title='Layer',
        yaxis=dict(autorange='reversed'),
        width=1200,
        height=850,
        showlegend=False,
        plot_bgcolor='black',
        paper_bgcolor='black'
    )

    # Grid lines and tick labels.  Pinning the category order keeps the
    # index-based tickvals aligned with the labels whatever order the
    # categories first appear in the data.
    tick_labels = [f"R{i}" for i in range(n_slots)] + ["S0", "S1", "H"]
    fig.update_xaxes(
        showgrid=True,
        gridwidth=1,
        gridcolor='rgba(128, 128, 128, 0.2)',
        categoryorder='array',
        categoryarray=tick_labels,
        tickmode='array',
        ticktext=tick_labels,
        tickvals=list(range(len(tick_labels)))
    )
    fig.update_yaxes(
        showgrid=True,
        gridwidth=1,
        gridcolor='rgba(128, 128, 128, 0.2)'
    )

    return fig

def plot_enhanced_logit_lens(model_outputs, tokenizer, input_text, position):
    """Heatmap of expert-weighted token scores per layer for one position.

    For every ``layer_<i>`` entry that has data for ``position``, each
    expert's top-token probabilities are accumulated per token: routed
    experts contribute ``prob * weight`` and shared experts a flat
    ``prob * 0.5``.  Rows are layers in ascending order, columns are all
    tokens seen, sorted.

    Args:
        model_outputs: Mapping with ``layer_<i>`` entries holding ``tokens``.
        tokenizer: Not used by this function.
        input_text: Not used by this function.
        position: Token position to plot.

    Returns:
        plotly.graph_objects.Figure: The heatmap figure.
    """
    seen_tokens = set()
    scores = defaultdict(lambda: defaultdict(float))
    plotted_layers = []

    ordered_keys = [key for key in model_outputs.keys() if key.startswith('layer_')]
    ordered_keys.sort(key=lambda key: int(key.split('_')[1]))

    for layer_key in ordered_keys:
        layer_idx = int(layer_key.split('_')[1])
        per_token = model_outputs[layer_key].get("tokens", {})

        # First record stored for the requested position, if any.
        token_data = next(
            (record for record in per_token.values() if record["position"] == position),
            None
        )
        if not token_data:
            continue

        plotted_layers.append(layer_idx)

        for expert_output in token_data["expert_outputs"]:
            router_weight = expert_output["weight"]
            is_shared = expert_output["expert_type"] == "shared"

            for tok, prob in expert_output["top_tokens"]:
                seen_tokens.add(tok)
                # Shared experts get a fixed 0.5 factor instead of their weight.
                factor = 0.5 if is_shared else router_weight
                scores[layer_idx][tok] += prob * factor

    tokens = sorted(seen_tokens)
    prob_matrix = [
        [scores[layer_idx].get(tok, 0.0) for tok in tokens]
        for layer_idx in plotted_layers
    ]

    heatmap = go.Heatmap(
        z=prob_matrix,
        x=tokens,
        y=[f'Layer {idx}' for idx in plotted_layers],
        colorscale='Viridis',
        showscale=True
    )
    fig = go.Figure(data=heatmap)

    fig.update_layout(
        title=f'Token Probabilities Across Layers - Position {position}',
        xaxis_title='Predicted Tokens',
        yaxis_title='Layer',
        yaxis={'autorange': 'reversed'},
        width=1200,
        height=800,
        plot_bgcolor='black',
        paper_bgcolor='black',
        font=dict(color='white')
    )

    grid_style = dict(showgrid=True, gridwidth=1, gridcolor='rgba(128, 128, 128, 0.2)')
    fig.update_xaxes(**grid_style)
    fig.update_yaxes(**grid_style)

    return fig

In [70]:
def apply_moe_logit_lens_with_active_expert_outputs(model, inputs, tokenizer, num_active_experts=7):
    """Step through the decoder layers by hand and print each layer's strongest experts.

    NOTE(review): this relies on ``layer.self_attn`` accepting three
    positional tensors and returning a tensor, and on ``layer.mlp``
    returning a mapping with ``"gate_values"`` and ``"expert_outputs"``.
    Neither is established anywhere in this notebook -- confirm against the
    model class that is actually loaded.

    Args:
        model: Object exposing ``eval()``, ``model.embed_tokens``,
            ``model.layers`` and ``lm_head``.
        inputs: Passed straight to ``model.model.embed_tokens``.
        tokenizer: Provides ``decode`` for lists of token ids.
        num_active_experts: How many of the highest-gated experts to print
            per batch element.

    Returns:
        None. Everything is reported through ``print``.
    """
    model.eval()

    hidden = model.model.embed_tokens(inputs)

    for layer_number, layer in enumerate(model.model.layers, start=1):
        print(f"Layer {layer_number}")
        hidden = layer.input_layernorm(hidden)

        # Attention block; the residual is added to the normalised activations.
        hidden = layer.self_attn(hidden, hidden, hidden) + hidden
        hidden = layer.post_attention_layernorm(hidden)

        # Mixture-of-experts block
        moe_output = layer.mlp(hidden)
        gate_values = moe_output["gate_values"]
        expert_outputs = moe_output["expert_outputs"]

        for batch_idx, gates in enumerate(gate_values):
            # Experts ranked by gate value, strongest first.
            ranked_experts = gates.argsort(descending=True)[:num_active_experts]
            print(f"  Batch {batch_idx + 1}:")
            for expert_idx in ranked_experts:
                expert_weight = gates[expert_idx].item()
                best = expert_outputs[batch_idx, :, expert_idx].topk(5, dim=-1)
                decoded_tokens = tokenizer.decode(best.indices.tolist())
                print(f"    Expert {expert_idx}: Weight: {expert_weight:.4f}, Tokens: {decoded_tokens}, Scores: {best.values.tolist()}")

    # Project whatever came out of the last layer through the LM head.
    final_best = model.lm_head(hidden).topk(5, dim=-1)
    decoded_final_tokens = tokenizer.decode(final_best.indices.tolist())
    print(f"Final Layer: Tokens: {decoded_final_tokens}, Scores: {final_best.values.tolist()}")

In [71]:
def get_device():
    """Pick the torch device to run on, preferring the first CUDA GPU.

    With CUDA available, TF32 is enabled for matmul and cuDNN, the CUDA
    caching allocator is emptied, cuDNN benchmarking is switched on, and
    ``cuda:0`` is returned.  Otherwise the CPU device is returned and no
    global setting is changed.

    NOTE(review): an identical ``get_device`` is already defined in an
    earlier cell of this notebook; this definition replaces it when run.

    Returns:
        torch.device: ``cuda:0`` if CUDA is available, otherwise ``cpu``.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")

    # TF32 trades a little matmul/conv precision for speed on Ampere-class GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Release cached, currently-unused allocator blocks.
    torch.cuda.empty_cache()
    # Let cuDNN time candidate kernels and reuse the fastest one.
    torch.backends.cudnn.benchmark = True
    return torch.device("cuda:0")

In [72]:
def apply_moe_logit_lens_with_active_expert_outputs(model, inputs, tokenizer, num_active_experts=7):
    """Step through the decoder layers by hand and print each layer's strongest experts.

    NOTE(review): this relies on ``layer.self_attn`` accepting three
    positional tensors and returning a tensor, and on ``layer.mlp``
    returning a mapping with ``"gate_values"`` and ``"expert_outputs"``.
    Neither is established anywhere in this notebook -- confirm against the
    model class that is actually loaded.

    NOTE(review): the same function is already defined in an earlier cell
    of this notebook; this definition replaces it when run.

    Args:
        model: Object exposing ``eval()``, ``model.embed_tokens``,
            ``model.layers`` and ``lm_head``.
        inputs: Passed straight to ``model.model.embed_tokens``.
        tokenizer: Provides ``decode`` for lists of token ids.
        num_active_experts: How many of the highest-gated experts to print
            per batch element.

    Returns:
        None. Everything is reported through ``print``.
    """
    model.eval()

    hidden = model.model.embed_tokens(inputs)

    for layer_number, layer in enumerate(model.model.layers, start=1):
        print(f"Layer {layer_number}")
        hidden = layer.input_layernorm(hidden)

        # Attention block; the residual is added to the normalised activations.
        hidden = layer.self_attn(hidden, hidden, hidden) + hidden
        hidden = layer.post_attention_layernorm(hidden)

        # Mixture-of-experts block
        moe_output = layer.mlp(hidden)
        gate_values = moe_output["gate_values"]
        expert_outputs = moe_output["expert_outputs"]

        for batch_idx, gates in enumerate(gate_values):
            # Experts ranked by gate value, strongest first.
            ranked_experts = gates.argsort(descending=True)[:num_active_experts]
            print(f"  Batch {batch_idx + 1}:")
            for expert_idx in ranked_experts:
                expert_weight = gates[expert_idx].item()
                best = expert_outputs[batch_idx, :, expert_idx].topk(5, dim=-1)
                decoded_tokens = tokenizer.decode(best.indices.tolist())
                print(f"    Expert {expert_idx}: Weight: {expert_weight:.4f}, Tokens: {decoded_tokens}, Scores: {best.values.tolist()}")

    # Project whatever came out of the last layer through the LM head.
    final_best = model.lm_head(hidden).topk(5, dim=-1)
    decoded_final_tokens = tokenizer.decode(final_best.indices.tolist())
    print(f"Final Layer: Tokens: {decoded_final_tokens}, Scores: {final_best.values.tolist()}")

In [73]:
def logit_lens(model, input_tokens):
    """Collect the gate-weighted sum of all expert outputs per layer and token.

    The input of each layer is fed to every expert of that layer, one token
    at a time; the expert outputs are scaled by ``layer.mlp.gate`` applied
    to the same input and summed over experts.  Layers whose ``mlp`` has no
    ``experts`` attribute contribute nothing but are still applied to
    produce the next layer's input.

    NOTE(review): this assumes ``layer(x)`` returns a tensor and that
    ``layer.mlp.gate(x)`` returns one weight per expert with shape
    ``[batch, seq_len, num_experts]`` -- confirm for the model in use.

    Args:
        model: Object exposing ``embed_tokens`` and an iterable ``layers``.
        input_tokens (torch.Tensor): Input of shape ``[batch_size, seq_len]``
            that ``model.embed_tokens`` accepts.

    Returns:
        dict: ``"layer_{i}_token_{j}_activated_experts"`` -> tensor of shape
        ``[batch_size, feature_dim]``.
    """
    collected = {}

    hidden = model.embed_tokens(input_tokens)
    seq_len = input_tokens.shape[1]

    for layer_idx, layer in enumerate(model.layers):
        # Run the full layer first; its output feeds the next iteration.
        next_hidden = layer(hidden)

        if hasattr(layer.mlp, "experts"):
            gate_outputs = layer.mlp.gate(hidden)

            for token_idx in range(seq_len):
                token_input = hidden[:, token_idx, :]
                # [batch, num_experts, feature_dim]
                per_expert = torch.stack(
                    [expert(token_input) for expert in layer.mlp.experts],
                    dim=1,
                )
                # [batch, num_experts, 1] so it broadcasts over features
                gate_weights = gate_outputs[:, token_idx, :].unsqueeze(-1)
                key = f"layer_{layer_idx}_token_{token_idx}_activated_experts"
                collected[key] = torch.sum(per_expert * gate_weights, dim=1)

        hidden = next_hidden

    return collected

In [74]:
def save_prompt_cache(domain: str, data_dict: dict, input_text: str, cache_dir='plots-deepseek') -> int:
    """Store one result per domain, reusing the domain's prompt number.

    ``domain_cache_info.json`` inside ``cache_dir`` maps each domain to its
    prompt number, a timestamp and the prompt text.  A domain that is
    already listed keeps its number (so its pickle is overwritten); a new
    domain gets the highest existing number plus one.

    Args:
        domain: Key under which the result is indexed.
        data_dict: Object pickled to ``prompt_<n>.pkl``.
        input_text: Prompt text recorded in the index.
        cache_dir: Directory holding the index and pickles; created if absent.

    Returns:
        int: The prompt number the data was stored under.
    """
    os.makedirs(cache_dir, exist_ok=True)
    cache_root = Path(cache_dir)
    index_path = cache_root / 'domain_cache_info.json'

    # Start from the existing index when there is one.
    cache_info = {}
    if index_path.exists():
        with open(index_path, 'r') as index_file:
            cache_info = json.load(index_file)

    if domain not in cache_info:
        taken = (info['prompt_number'] for info in cache_info.values())
        prompt_number = max(taken, default=0) + 1
    else:
        prompt_number = cache_info[domain]['prompt_number']

    cache_info[domain] = dict(
        prompt_number=prompt_number,
        timestamp=datetime.datetime.now().isoformat(),
        input_text=input_text,
    )

    # The index is written first, then the payload.
    with open(index_path, 'w') as index_file:
        json.dump(cache_info, index_file, indent=2)

    with open(cache_root / f'prompt_{prompt_number}.pkl', 'wb') as payload_file:
        pickle.dump(data_dict, payload_file)

    return prompt_number

def load_prompt_cache(prompt_number, cache_dir='plots-deepseek'):
    """Load the matrix and metadata stored for one prompt number.

    Reads ``prompt_<n>_matrix.npy`` and ``prompt_<n>_meta.json`` from
    ``cache_dir``.

    NOTE(review): the ``save_prompt_cache`` definitions visible in this
    notebook write ``prompt_<n>.pkl`` plus ``domain_cache_info.json``, not
    these two files -- confirm what produces them.

    Args:
        prompt_number: Number embedded in the two file names.
        cache_dir: Directory to read from.

    Returns:
        tuple: ``(count_matrix, metadata)`` -- the loaded array and the
        parsed JSON metadata.

    Raises:
        ValueError: If either file is missing.
    """
    stem = f'{cache_dir}/prompt_{prompt_number}'
    matrix_path = f'{stem}_matrix.npy'
    meta_path = f'{stem}_meta.json'

    if not all(os.path.exists(path) for path in (matrix_path, meta_path)):
        raise ValueError(f"Prompt {prompt_number} not found in cache")

    count_matrix = np.load(matrix_path)
    with open(meta_path, 'r') as meta_file:
        metadata = json.load(meta_file)

    return count_matrix, metadata

def list_cached_prompts(cache_dir='plots-deepseek'):
    """Return the parsed ``*_meta.json`` files of ``cache_dir``, ordered by prompt number.

    Args:
        cache_dir: Directory to scan (must exist).

    Returns:
        list: One parsed JSON object per ``*_meta.json`` file, sorted by its
        ``prompt_number`` field.
    """
    def read_meta(file_name):
        with open(os.path.join(cache_dir, file_name), 'r') as meta_file:
            return json.load(meta_file)

    metas = [
        read_meta(file_name)
        for file_name in os.listdir(cache_dir)
        if file_name.endswith('_meta.json')
    ]
    metas.sort(key=lambda meta: meta['prompt_number'])
    return metas

def plot_cached_prompt(prompt_number, plot_type='heatmap', layer_id=None, tokenizer=None, cache_dir='plots-deepseek'):
    """Load a cached prompt and hand it to one of the plotting helpers.

    NOTE(review): the calls below pass ``domain=`` (and ``layer_id=``), but
    the ``visualize_layer_analysis`` and ``plot_enhanced_logit_lens``
    definitions visible in this notebook take
    ``(tokenizer, results, token_position, input_text)`` and
    ``(model_outputs, tokenizer, input_text, position)`` and accept neither
    keyword.  A later cell also redefines ``load_prompt_cache`` as
    ``(domain, input_text, cache_dir)``.  As written these calls look
    incompatible with those definitions -- confirm which helper versions
    this function is meant to run against.

    Args:
        prompt_number: Forwarded to ``load_prompt_cache``.
        plot_type: ``'heatmap'`` or ``'logit_lens'``.
        layer_id: Required when ``plot_type`` is ``'logit_lens'``.
        tokenizer: Forwarded to the plotting helper.
        cache_dir: Forwarded to ``load_prompt_cache``.

    Returns:
        Whatever the selected plotting helper returns.

    Raises:
        ValueError: For an unknown ``plot_type``, or a logit-lens request
            without ``layer_id``.
    """
    count_matrix, metadata = load_prompt_cache(prompt_number, cache_dir)
    domain = metadata['domain']

    if plot_type == 'heatmap':
        return visualize_layer_analysis(
            tokenizer,
            count_matrix,
            domain=domain
        )

    if plot_type != 'logit_lens':
        raise ValueError(f"Unknown plot type: {plot_type}")

    if layer_id is None:
        raise ValueError("layer_id is required for logit lens plots")

    return plot_enhanced_logit_lens(
        count_matrix,
        tokenizer,
        domain=domain,
        layer_id=layer_id
    )

In [75]:
def migrate_cache_structure(old_cache_info):
    """Convert a flat ``{domain: info}`` cache index to the nested format.

    The new format is ``{'domains': {domain: [entry, ...]}, 'next_id': n}``.
    Every dict-valued item of the old index becomes a one-element entry
    list; items named ``'domains'`` or ``'next_id'`` and non-dict values
    are ignored.

    Entries that already carry a ``prompt_number`` keep it.  Entries
    without one receive fresh numbers above every explicit number, so a
    generated number can never collide with a stored one.

    Args:
        old_cache_info: Mapping in the old flat layout.

    Returns:
        dict: The index in the new layout; ``next_id`` is one more than the
        largest prompt number in use (at least 1).
    """
    legacy_entries = [
        (domain, info)
        for domain, info in old_cache_info.items()
        if domain not in ('domains', 'next_id') and isinstance(info, dict)
    ]

    # First pass: find the numbers already taken.
    explicit_numbers = [
        info['prompt_number'] for _, info in legacy_entries if 'prompt_number' in info
    ]
    next_id = max(explicit_numbers + [0]) + 1

    # Second pass: build the entries, numbering the ones that had none.
    domains = {}
    for domain, info in legacy_entries:
        if 'prompt_number' in info:
            prompt_number = info['prompt_number']
        else:
            prompt_number = next_id
            next_id += 1

        domains[domain] = [{
            'prompt_number': prompt_number,
            'timestamp': info.get('timestamp', datetime.datetime.now().isoformat()),
            'input_text': info.get('input_text', 'No text stored')
        }]

    return {
        'domains': domains,
        'next_id': next_id
    }

In [76]:
def save_prompt_cache(domain: str, data: dict, cache_dir='plots-deepseek') -> int:
    """Append a result to the domain-indexed cache and return its prompt number.

    The payload is pickled to ``prompt_<n>.pkl`` and an entry (number,
    timestamp, ``data['input_text']``) is appended to the domain's list in
    ``domain_cache_info.json``.  An index in the old flat layout is
    migrated first.  An unreadable index is reported and replaced by an
    empty one.

    The payload is serialised before anything is written, and the index is
    updated only after the payload file exists, so a failed save leaves no
    dangling index entry.  The chosen number is always above any existing
    ``prompt_<n>.pkl``, so earlier payloads are never overwritten, even
    after the index had to be recreated.

    Args:
        domain: Key under which the entry is listed.
        data: Mapping to store; its optional ``'input_text'`` is copied
            into the index.
        cache_dir: Directory holding the index and pickles; created if absent.

    Returns:
        int: The prompt number assigned to this save.
    """
    os.makedirs(cache_dir, exist_ok=True)
    cache_root = Path(cache_dir)
    cache_info_path = cache_root / 'domain_cache_info.json'

    # Load the index, migrating the old layout; fall back to an empty one.
    cache_info = None
    if cache_info_path.exists():
        try:
            with open(cache_info_path, 'r') as f:
                cache_info = json.load(f)
            if 'domains' not in cache_info:
                cache_info = migrate_cache_structure(cache_info)
        except Exception as e:
            # Deliberately best-effort: a broken index must not block saving.
            print(f"Error reading cache file, creating new: {e}")
            cache_info = None
    if cache_info is None:
        cache_info = {'domains': {}, 'next_id': 1}

    # Never hand out a number that already has a payload on disk.
    highest_on_disk = 0
    for pkl_path in cache_root.glob('prompt_*.pkl'):
        suffix = pkl_path.stem[len('prompt_'):]
        if suffix.isdigit():
            highest_on_disk = max(highest_on_disk, int(suffix))
    prompt_number = max(cache_info.get('next_id', 1), highest_on_disk + 1)

    entry = {
        'prompt_number': prompt_number,
        'timestamp': datetime.datetime.now().isoformat(),
        'input_text': data.get('input_text', 'No text stored')
    }

    # Serialise up front: an unpicklable payload fails here, before any file
    # has been created or changed.
    payload = pickle.dumps(data)

    data_path = cache_root / f'prompt_{prompt_number}.pkl'
    with open(data_path, 'wb') as f:
        f.write(payload)

    # Only now record the entry in the index.
    cache_info['domains'].setdefault(domain, []).append(entry)
    cache_info['next_id'] = prompt_number + 1
    with open(cache_info_path, 'w') as f:
        json.dump(cache_info, f, indent=2)

    return prompt_number


In [77]:

def load_prompt_cache(domain: str, input_text: str, cache_dir='plots-deepseek'):
    """Load the most recent cached result for a domain and prompt text.

    The index ``domain_cache_info.json`` is read (and migrated from the old
    flat layout if necessary), the entries of ``domain`` whose stored text
    equals ``input_text`` are collected, and the payload of the one with
    the latest timestamp is unpickled.

    Args:
        domain: Domain whose entries are searched.
        input_text: Prompt text to match exactly.
        cache_dir: Directory holding the index and pickles.

    Returns:
        tuple: ``(data_dict, entry)`` -- the unpickled payload and the
        matching index entry.

    Raises:
        ValueError: If the index is missing or unreadable, the domain is
            unknown, no entry matches, or the payload file is missing.
    """
    cache_info_path = Path(cache_dir) / 'domain_cache_info.json'

    if not cache_info_path.exists():
        raise ValueError(f"No cache found in {cache_dir}")

    try:
        with open(cache_info_path, 'r') as f:
            cache_info = json.load(f)
        needs_migration = 'domains' not in cache_info
        if needs_migration:
            cache_info = migrate_cache_structure(cache_info)
    except Exception as e:
        raise ValueError(f"Error reading cache file: {e}") from e

    if needs_migration:
        # Persist the migrated index after the read handle is closed.  This
        # is best-effort: loading must still work from a read-only cache.
        try:
            migrated_text = json.dumps(cache_info, indent=2)
            with open(cache_info_path, 'w') as f:
                f.write(migrated_text)
        except (OSError, TypeError, ValueError) as e:
            print(f"Warning: could not save migrated cache index: {e}")

    if domain not in cache_info['domains']:
        raise ValueError(f"No cache found for domain '{domain}'")

    # Entries without stored text are treated like 'No text stored', the
    # placeholder that saving and migration write for them.
    matching_entries = [
        entry for entry in cache_info['domains'][domain]
        if entry.get('input_text', 'No text stored') == input_text
    ]

    if not matching_entries:
        raise ValueError(f"No matching cache entry found for text: {input_text}")

    # ISO-8601 timestamps order correctly as plain strings.
    latest_entry = max(matching_entries, key=lambda entry: entry.get('timestamp', ''))
    prompt_number = latest_entry['prompt_number']

    data_path = Path(cache_dir) / f'prompt_{prompt_number}.pkl'
    if not data_path.exists():
        raise ValueError(f"Cache data file missing for prompt #{prompt_number}")

    # SECURITY: unpickling can run arbitrary code -- only point cache_dir at
    # directories whose contents you created yourself.
    with open(data_path, 'rb') as f:
        data_dict = pickle.load(f)

    return data_dict, latest_entry


In [78]:
def list_cached_prompts(cache_dir='plots-deepseek'):
    """
    Return metadata for every cached prompt, ordered by prompt number.

    Reads ``domain_cache_info.json`` from ``cache_dir``. A missing or
    unreadable index yields an empty list. Each result is a dict with the
    keys ``'domain'``, ``'prompt_number'``, ``'timestamp'`` and
    ``'input_text'`` (``'No text stored'`` when the entry has none).
    """
    index_file = Path(cache_dir) / 'domain_cache_info.json'
    if not index_file.exists():
        return []

    try:
        with open(index_file, 'r') as handle:
            index = json.load(handle)
        # Older index layouts have no 'domains' key; convert in memory only.
        if 'domains' not in index:
            index = migrate_cache_structure(index)
    except Exception:
        # Best-effort listing: an unreadable index is treated as "nothing cached".
        return []

    listing = [
        {
            'domain': domain_name,
            'prompt_number': record['prompt_number'],
            'timestamp': record['timestamp'],
            'input_text': record.get('input_text', 'No text stored'),
        }
        for domain_name, domain_records in index.get('domains', {}).items()
        for record in domain_records
    ]
    listing.sort(key=lambda row: row['prompt_number'])
    return listing


In [79]:
def plot_enhanced_logit_lens(model_outputs, tokenizer, input_text, position):
    """
    Build a layer-by-token heatmap of weight-scaled expert token probabilities.

    For every ``'layer_<n>'`` entry that has token data at ``position``, each
    expert's ``top_tokens`` probabilities are multiplied by that expert's
    weight and summed per token. Rows are layers (ascending), columns are the
    sorted union of all tokens seen. ``tokenizer`` and ``input_text`` are
    accepted but not used.
    """
    def _layer_number(key):
        return int(key.split('_')[1])

    ordered_keys = [key for key in model_outputs.keys() if key.startswith('layer_')]
    ordered_keys.sort(key=_layer_number)

    seen_tokens = set()
    weighted_probs = {}
    plotted_layers = []

    for key in ordered_keys:
        per_layer = model_outputs[key]
        if 'tokens' not in per_layer:
            continue

        # First token record stored for the requested position, if any.
        selected = next(
            (record for record in per_layer['tokens'].values()
             if record['position'] == position),
            None,
        )
        if selected is None:
            continue

        number = _layer_number(key)
        plotted_layers.append(number)
        bucket = weighted_probs.setdefault(number, {})

        # Accumulate probability * expert weight per predicted token.
        for contribution in selected['expert_outputs']:
            scale = contribution['weight']
            for tok, tok_prob in contribution['top_tokens']:
                seen_tokens.add(tok)
                bucket[tok] = bucket.get(tok, 0.0) + tok_prob * scale

    # Fixed column order so every row lines up.
    token_axis = sorted(seen_tokens)
    heat_rows = [
        [weighted_probs[number].get(tok, 0.0) for tok in token_axis]
        for number in plotted_layers
    ]

    heatmap = go.Heatmap(
        z=heat_rows,
        x=token_axis,
        y=[f'Layer {number}' for number in plotted_layers],
        colorscale='Viridis',
        showscale=True
    )
    fig = go.Figure(data=heatmap)

    fig.update_layout(
        title=f'Token Probabilities Across Layers - Position {position}',
        xaxis_title='Predicted Tokens',
        yaxis_title='Layer',
        yaxis={'autorange': 'reversed'},  # lowest layer number at the top
        width=1200,
        height=800,
        plot_bgcolor='black',
        paper_bgcolor='black',
        font=dict(color='white')
    )

    # Faint grid on both axes.
    grid_style = dict(showgrid=True, gridwidth=1, gridcolor='rgba(128, 128, 128, 0.2)')
    fig.update_xaxes(**grid_style)
    fig.update_yaxes(**grid_style)

    return fig

In [119]:
def plot_expert_and_final_predictions(model_outputs, tokenizer, position=None, title=None):
    """
    Heatmap of per-layer expert weights plus an optional row of final predictions.

    Rows are the layers whose ``'tokens'`` hold an entry at ``position``;
    columns are expert ids (``R<n>`` for routed experts, the stored id for all
    others, ``Final_<k>`` for final predictions). Each populated cell is
    coloured by the expert's weight and labelled with its first top token.

    Args:
        model_outputs: Mapping with ``'layer_<n>'`` entries and, optionally,
            ``'final_predictions'`` keyed by position.
        tokenizer: Not used by this function; kept for existing callers.
        position: Token position to select; compared with ``==`` against each
            token record's ``'position'``.
        title: Figure title. When omitted a default is built, including the
            position whenever it is not ``None`` (so position 0 is shown).

    Returns:
        The ``go.Figure`` containing the heatmap.
    """
    layers = []
    expert_ids = []
    expert_tokens = {}

    # Process layer outputs in ascending layer order.
    layer_keys = sorted(
        (k for k in model_outputs.keys() if k.startswith('layer_')),
        key=lambda x: int(x.split('_')[1])
    )

    for layer_key in layer_keys:
        layer_num = int(layer_key.split('_')[1])
        layer_data = model_outputs[layer_key]

        # First token record stored for the requested position, if any.
        token_data = next(
            (data for data in layer_data.get("tokens", {}).values()
             if data["position"] == position),
            None,
        )
        if not token_data:
            continue

        layers.append(layer_num)

        for expert_output in token_data["expert_outputs"]:
            if expert_output["expert_type"] == "routed":
                expert_id = f"R{expert_output['expert_id']}"
            else:
                expert_id = expert_output["expert_id"]  # Already has 'S0' or 'S1' format

            expert_tokens[(layer_num, expert_id)] = {
                'weight': expert_output["weight"],
                'tokens': expert_output["top_tokens"]
            }
            if expert_id not in expert_ids:
                expert_ids.append(expert_id)

    # Add final predictions as one extra row below the last layer.
    final_layer = None
    if 'final_predictions' in model_outputs and position in model_outputs['final_predictions']:
        final_preds = model_outputs['final_predictions'][position]
        # default=-1 keeps this working when no layer matched the position.
        final_layer = max(layers, default=-1) + 1
        layers.append(final_layer)

        for i, (_, prob) in enumerate(zip(final_preds['token_ids'], final_preds['probs'])):
            expert_id = f'Final_{i+1}'
            if expert_id not in expert_ids:
                expert_ids.append(expert_id)
            token = final_preds['tokens'][i]
            # Accept tensor/numpy scalars (via .item()) as well as plain numbers.
            prob_value = prob.item() if hasattr(prob, 'item') else float(prob)
            expert_tokens[(final_layer, expert_id)] = {
                'weight': prob_value,
                'tokens': [(token, prob_value)]
            }

    # Sort expert IDs by their string form so mixed id types still order.
    expert_ids = sorted(set(expert_ids), key=str)

    # Create matrices for visualization
    weight_matrix = np.zeros((len(layers), len(expert_ids)))
    hover_text = [['' for _ in expert_ids] for _ in layers]
    text_matrix = [['' for _ in expert_ids] for _ in layers]

    # Fill matrices
    for i, layer in enumerate(layers):
        for j, expert_id in enumerate(expert_ids):
            info = expert_tokens.get((layer, expert_id), None)
            if info and info['tokens']:
                weight = info['weight']
                weight_matrix[i, j] = weight

                top_token, _ = info['tokens'][0]
                text_matrix[i][j] = f'{top_token}<br>{weight:.3f}'

                hover_lines = [
                    f"Layer: {'Final' if 'Final' in str(expert_id) else layer}",
                    f"Expert: {expert_id}",
                    f"Weight: {weight:.3f}",
                    "Top tokens:"
                ]
                hover_lines.extend(f"{token}: {prob:.3f}" for token, prob in info['tokens'])
                hover_text[i][j] = "<br>".join(hover_lines)

    # Only the appended final-prediction row is labelled "Final Layer"; without
    # final predictions every row keeps its real layer number.
    row_labels = [
        "Final Layer" if layer == final_layer else f"Layer {layer}"
        for layer in layers
    ]

    # Create figure with increased size and improved readability
    fig = go.Figure(data=go.Heatmap(
        z=weight_matrix,
        x=expert_ids,
        y=row_labels,
        colorscale=[
            [0, 'rgba(24, 21, 23, 0.9)'],   # empty cells (weight exactly 0)
            [0.0001, 'rgb(68,1,84)'],
            [1, 'rgb(242, 121, 53)']
        ],
        text=text_matrix,
        texttemplate="%{text}",
        textfont={"size": 16, "color": "white", "family": "Arial"},
        hoverongaps=False,
        hoverinfo="text",
        hovertext=hover_text,
        showscale=True,
    ))

    # `is not None` so that position 0 still appears in the default title.
    title = title or "Expert Contributions and Final Predictions" + (
        f" - Position {position}" if position is not None else ""
    )
    fig.update_layout(
        title=dict(
            text=title,
            font=dict(size=24)
        ),
        xaxis_title=dict(
            text="Experts",
            font=dict(size=20)
        ),
        yaxis_title=dict(
            text="Layers",
            font=dict(size=20)
        ),
        height=2000,  # Increased height
        width=4000,   # Increased width
        plot_bgcolor='black',
        paper_bgcolor='black',
        font=dict(
            color='white',
            size=16
        ),
        xaxis=dict(
            tickangle=45,
            tickfont=dict(size=16),
            showgrid=True,
            gridwidth=1,
            gridcolor='rgba(128, 128, 128, 0.2)',
            showline=True,
            linewidth=1,
            linecolor='rgba(128, 128, 128, 0.2)'
        ),
        yaxis=dict(
            autorange='reversed',
            tickfont=dict(size=16),
            showgrid=True,
            gridwidth=1,
            gridcolor='rgba(128, 128, 128, 0.2)',
            showline=True,
            linewidth=1,
            linecolor='rgba(128, 128, 128, 0.2)'
        ),
        margin=dict(l=120, r=120, t=160, b=120)  # Increased margins
    )

    return fig

In [82]:
def _render_analysis_figures(results, tokenizer, text, token_position,
                             plot_type, display_plot, figures):
    """Build the figures selected by ``plot_type`` into ``figures``, showing each if requested."""
    if plot_type in ['heatmap', 'both']:
        figures['heatmap'] = visualize_layer_analysis(
            tokenizer=tokenizer,
            results=results,
            token_position=token_position,
            input_text=text
        )
        if display_plot:
            figures['heatmap'].show()

    if plot_type in ['logit_lens', 'both']:
        figures['logit_lens'] = plot_enhanced_logit_lens(
            model_outputs=results,
            tokenizer=tokenizer,
            input_text=text,
            position=token_position
        )
        if display_plot:
            figures['logit_lens'].show()

    if plot_type in ['expert', 'both']:
        figures['expert'] = plot_expert_and_final_predictions(
            results,
            tokenizer=tokenizer,
            position=token_position,
            title=f"Expert Contributions and Predictions for: '{text}'"
        )
        if display_plot:
            figures['expert'].show()


def analyze_dataset(
    text: str,
    model,
    tokenizer,
    token_position: int = None,
    domain: str = 'default',
    plot_type: str = 'both',
    force_recompute: bool = False,
    cache_dir: str = 'plots-deepseek',
    display_plot: bool = True
):
    """
    Analyze text through the model and create visualizations.

    Cached results for ``(domain, text)`` are reused unless ``force_recompute``
    is set; otherwise the text is run through a fresh ``MOELens`` and the
    results are written to the cache.

    Args:
        text: Prompt to analyze.
        model: Model handed to ``MOELens``; its ``device`` receives the input ids.
        tokenizer: Tokenizer used to encode ``text`` and passed to the plotters.
        token_position: Position forwarded to the plotting functions.
        domain: Cache namespace.
        plot_type: ``'heatmap'``, ``'logit_lens'``, ``'expert'``, or ``'both'``
            (which selects all three plots).
        force_recompute: Skip the cache lookup when true.
        cache_dir: Directory used by the cache helpers.
        display_plot: Call ``.show()`` on each figure as it is created.

    Returns:
        ``(figures, results, None)`` on success, where ``figures`` maps plot
        name to figure; ``(None, None, error_message)`` if anything raised.
    """
    # Explicit sentinel so the cleanup below does not have to inspect locals().
    lens = None
    try:
        figures = {}

        # Check cache if not forcing recompute
        if not force_recompute:
            try:
                cached_data, cache_info = load_prompt_cache(domain, text, cache_dir)
                print(f"Using cached results from prompt #{cache_info['prompt_number']}")

                # Extract results from cached data
                results = cached_data['results']

                _render_analysis_figures(
                    results, tokenizer, text, token_position,
                    plot_type, display_plot, figures
                )
                return figures, results, None

            except ValueError as e:
                print(f"Cache miss: {str(e)}")
            except Exception as e:
                # NOTE(review): this also catches failures raised while
                # rendering cached data, which then fall through to a full
                # recompute below.
                print(f"Error loading cache: {str(e)}")

        # Initialize the lens analyzer
        lens = MOELens(model, tokenizer)

        # Process the input text
        input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
        results = lens.analyze_text(input_ids)

        _render_analysis_figures(
            results, tokenizer, text, token_position,
            plot_type, display_plot, figures
        )

        # Cache the results with metadata
        cache_data = {
            'results': results,
            'input_text': text
        }
        prompt_number = save_prompt_cache(domain, cache_data, cache_dir)
        print(f"Results cached as prompt #{prompt_number}")

        return figures, results, None

    except Exception as e:
        print(f"Error in analyze_dataset: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None, str(e)

    finally:
        # Clean up: detach the hooks MOELens registered, if one was created.
        if lens is not None:
            lens.remove_hooks()

----

In [123]:
# Analyze one prompt and display only the expert-contribution heatmap.
text = "capital of japan is"

# analyze_dataset returns (figures dict, results, error message or None);
# with plot_type='expert' the dict holds a single 'expert' entry.
fig, result, _ = analyze_dataset(
    text=text,
    model=model,
    tokenizer=tokenizer,
    token_position=5,  # Selects ONE token position to plot, not the full sequence
    domain="test1",
    force_recompute=False,
    plot_type='expert',
    display_plot=True,
)

Using cached results from prompt #1


In [46]:
# Scratch check: run a single forward pass and decode the top prediction at one position.
txt = "capital of usa is"

inputs = tokenizer(txt, return_tensors="pt")
# NOTE(review): the recorded output below shows a grad_fn, so autograd is
# tracking this pass; consider wrapping it in torch.no_grad() for inspection.
outputs = model.forward(**inputs.to(model.device))

print(outputs.logits.shape)


# Sequence position whose prediction is inspected.
i = 4

print(f"i: {i}")
x = outputs.logits[0, i]
# The model emits one logit per vocabulary entry at each position.
print(x) 
y = x.argmax()
# argmax gives the id of the highest-scoring token.
print(f"y: {y}") 

print(tokenizer.decode(y))




# Compare gate_proj weight shapes of layer 3's shared experts and one routed expert.
print(model.model.layers[3].mlp.shared_experts.gate_proj.weight.shape)
print(model.model.layers[3].mlp.experts[63].gate_proj.weight.shape)


torch.Size([1, 5, 102400])
i: 4
tensor([18.1094, 19.5781, 16.3906,  ...,  2.3262,  2.4473,  2.3262],
       grad_fn=<SelectBackward0>)
y: 8196
 Washington
torch.Size([2816, 2048])
torch.Size([1408, 2048])
