In [1]:
import torch
import numpy as np
import plotly.graph_objects as go
from collections import defaultdict
import torch.nn.functional as F
from typing import Dict, Tuple, List, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
from pathlib import Path
from pathlib import Path

In [2]:
def get_device():
    """Get the optimal available device"""
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        # Enable TF32 for better performance on Ampere GPUs (A100, A6000, etc)
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        # Set memory allocation settings
        torch.cuda.empty_cache()
        # Enable CUDNN benchmarking for better performance
        torch.backends.cudnn.benchmark = True
    else:
        device = torch.device("cpu")
    return device

In [3]:
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/deepseek-moe-16b-base",
    trust_remote_code=True,
    torch_dtype=torch.float16
)

tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/deepseek-moe-16b-base",
    trust_remote_code=True
)

# Verify the configuration
print(f"Number of shared experts: {model.config.n_shared_experts}")
print(f"Number of routed experts: {model.config.n_routed_experts}")
print(f"Number of experts per token: {model.config.num_experts_per_tok}")

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

Number of shared experts: 2
Number of routed experts: 64
Number of experts per token: 6


----

In [4]:

class MOELens:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.activations = defaultdict(dict)
        self.hook_handles = []
        self.setup_hooks()

    def setup_hooks(self):
        def get_gate_hook(layer_idx):
            def hook(module, inp, out):
                if isinstance(out, tuple):
                    topk_idx, topk_weight, _ = out
                    
                    # Get hidden states from input
                    hidden_states = inp[0]
                    batch_size, seq_len, hidden_dim = hidden_states.shape
                    
                    # Project to vocab space to get token predictions
                    with torch.no_grad():
                        expert_outputs = {}
                        # Only process the expert indices that were actually selected
                        unique_experts = torch.unique(topk_idx)
                        for expert_idx in unique_experts:
                            expert_idx = expert_idx.item()
                            if hasattr(self.model.model.layers[layer_idx].mlp, 'experts'):
                                expert = self.model.model.layers[layer_idx].mlp.experts[expert_idx]
                                expert_output = expert(hidden_states.view(-1, hidden_dim))
                                logits = self.model.lm_head(expert_output)
                                top_tokens = torch.topk(logits, k=5, dim=-1)
                                expert_outputs[expert_idx] = {
                                    'token_ids': top_tokens.indices,
                                    'probs': torch.softmax(top_tokens.values, dim=-1)
                                }

                    self.activations[f'layer_{layer_idx}'] = {
                        'router_weights': topk_weight.detach(),
                        'router_indices': topk_idx.detach(),
                        'expert_outputs': expert_outputs
                    }
            return hook

        def get_shared_expert_hook(layer_idx):
            def hook(module, inp, out):
                x = inp[0]
                batch_size, seq_len, hidden_dim = x.shape

                # Compute gate and up projections
                gate_proj = module.gate_proj(x)  # [batch, seq_len, 2816]
                up_proj = module.up_proj(x)      # [batch, seq_len, 2816]
                act = module.act_fn(gate_proj) * up_proj

                # Split into two experts
                expert0_act = act[..., :1408]
                expert1_act = act[..., 1408:]

                # Project to vocabulary
                with torch.no_grad():
                    # Process first expert
                    expert0_out = module.down_proj(
                        F.pad(expert0_act, (0, 1408))  # Pad to match full width
                    )
                    logits0 = self.model.lm_head(expert0_out)
                    top_tokens0 = torch.topk(logits0, k=5, dim=-1)

                    # Process second expert
                    expert1_out = module.down_proj(
                        F.pad(expert1_act, (1408, 0))  # Pad to match full width
                    )
                    logits1 = self.model.lm_head(expert1_out)
                    top_tokens1 = torch.topk(logits1, k=5, dim=-1)

                self.activations[f'layer_{layer_idx}']['shared_experts'] = {
                    'expert0': {
                        'token_ids': top_tokens0.indices,
                        'probs': torch.softmax(top_tokens0.values, dim=-1),
                        'weight': 1.0
                    },
                    'expert1': {
                        'token_ids': top_tokens1.indices,
                        'probs': torch.softmax(top_tokens1.values, dim=-1),
                        'weight': 1.0
                    }
                }
            return hook

        # Add hooks for gates and shared experts
        for i, layer in enumerate(self.model.model.layers):
            if hasattr(layer.mlp, 'gate'):
                handle = layer.mlp.gate.register_forward_hook(get_gate_hook(i))
                self.hook_handles.append(handle)
            if hasattr(layer.mlp, 'shared_experts'):
                handle = layer.mlp.shared_experts.register_forward_hook(get_shared_expert_hook(i))
                self.hook_handles.append(handle)

        # Add hook to capture layer outputs and compute logits
        for layer_idx, layer in enumerate(self.model.model.layers):
            def make_layer_hook(layer_idx):
                def hook(module, inputs, outputs):
                    hidden_states = outputs[0] if isinstance(outputs, tuple) else outputs
                    
                    # Apply layer normalization
                    if hasattr(self.model.model, 'norm'):
                        hidden_states = self.model.model.norm(hidden_states)
                    
                    # Compute logits and top predictions for this layer
                    with torch.no_grad():
                        logits = self.model.lm_head(hidden_states)
                        top_values, top_indices = torch.topk(logits, k=5, dim=-1)
                        probs = torch.softmax(top_values, dim=-1)
                        
                        layer_predictions = {}
                        for pos in range(logits.shape[1]):
                            pos_indices = top_indices[0, pos]  # Take first batch
                            pos_probs = probs[0, pos]         # Take first batch
                            
                            layer_predictions[pos] = {
                                'token_ids': [idx.item() for idx in pos_indices],
                                'probs': [prob.item() for prob in pos_probs],
                                'tokens': [self.tokenizer.decode([idx.item()]) for idx in pos_indices]
                            }
                            
                        self.activations[f'layer_{layer_idx}']['layer_predictions'] = layer_predictions
                        
                return hook
            
            handle = layer.register_forward_hook(make_layer_hook(layer_idx))
            self.hook_handles.append(handle)

    def analyze_text(self, input_ids: torch.Tensor) -> dict:
        """
        Create a combined analysis showing expert contributions and predictions at each layer.
        """
        self.activations.clear()
        final_logits = None
        
        def get_final_output_hook(module, inp, out):
            nonlocal final_logits
            final_logits = out

        # Add hook for final output
        final_hook = self.model.lm_head.register_forward_hook(get_final_output_hook)
        self.hook_handles.append(final_hook)

        # Forward pass
        with torch.no_grad():
            outputs = self.model(input_ids)

        # Process the outputs for each layer
        for layer_key, layer_data in self.activations.items():
            if not layer_key.startswith('layer_'):
                continue
                
            tokens_data = {}
            for pos in range(input_ids.shape[1]):
                expert_outputs = []
                
                # Process routed experts
                if 'expert_outputs' in layer_data:
                    router_indices = layer_data['router_indices'][pos]
                    router_weights = layer_data['router_weights'][pos]
                    
                    for idx, expert_idx in enumerate(router_indices):
                        expert_idx = expert_idx.item()
                        expert_data = layer_data['expert_outputs'].get(expert_idx)
                        if expert_data is not None:
                            weight = router_weights[idx].item()
                            token_ids = expert_data['token_ids'][pos]
                            probs = expert_data['probs'][pos]
                            
                            expert_outputs.append({
                                "expert_id": expert_idx,
                                "weight": weight,
                                "top_tokens": [
                                    (self.tokenizer.decode([tid.item()]), prob.item())
                                    for tid, prob in zip(token_ids, probs)
                                ],
                                "expert_type": "routed"
                            })
                
                # Process shared experts
                if 'shared_experts' in layer_data:
                    for expert_key in ['expert0', 'expert1']:
                        expert_data = layer_data['shared_experts'][expert_key]
                        token_ids = expert_data['token_ids'][0, pos]  # Take first batch
                        probs = expert_data['probs'][0, pos]         # Take first batch
                        
                        expert_outputs.append({
                            'expert_id': f'S{expert_key[-1]}',
                            'weight': expert_data['weight'],
                            'top_tokens': [
                                (self.tokenizer.decode([tid.item()]), prob.item())
                                for tid, prob in zip(token_ids, probs)
                            ],
                            'expert_type': 'shared'
                        })
                
                tokens_data[f"token_{pos}"] = {
                    "position": pos,
                    "expert_outputs": expert_outputs,
                    "layer_predictions": layer_data.get('layer_predictions', {}).get(pos, {})
                }
                
            self.activations[layer_key]["tokens"] = tokens_data

        # Add final predictions
        if final_logits is not None:
            final_predictions = {}
            for pos in range(input_ids.shape[1]):
                logits = final_logits[0, pos]  # Take first batch
                top_values, top_indices = torch.topk(logits, k=5)
                probs = torch.softmax(top_values, dim=-1)
                
                final_predictions[pos] = {
                    'token_ids': [idx.item() for idx in top_indices],
                    'probs': [prob.item() for prob in probs],
                    'tokens': [self.tokenizer.decode([idx.item()]) for idx in top_indices]
                }
            self.activations['final_predictions'] = final_predictions

        return dict(self.activations)

    def remove_hooks(self):
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles = []

In [5]:
def plot_expert_and_final_predictions(model_outputs, tokenizer, position=None, title=None):
    """
    Enhanced visualization with exact column alignment matching the original layout.
    """
    # Define column groups
    prediction_keys = [f'Prediction-{i+1}' for i in range(5)]  # 5 prediction columns
    blank_column = [' ']  
    blank_row = [' ']
    routed_expert_ids = [f'R{i}' for i in range(64)]          # 64 routed expert columns
    shared_expert_ids = ['S0', 'S1']                          # 2 shared expert columns
    
    # Original positioning with blank column after predictions
    all_expert_ids = prediction_keys + blank_column + routed_expert_ids + shared_expert_ids
    expert_tokens = defaultdict(dict)
    
    # Process layers in order
    layer_keys = sorted([k for k in model_outputs.keys() if k.startswith('layer_')],
                       key=lambda x: int(x.split('_')[1]))
    layers = []

    # Process data (same as before)
    for layer_key in layer_keys:
        layer_num = int(layer_key.split('_')[1])
        layer_data = model_outputs[layer_key]
        
        token_data = None
        for data in layer_data.get("tokens", {}).values():
            if data["position"] == position:
                token_data = data
                break
                
        if not token_data:
            continue
            
        layers.append(layer_num)

        # Process experts and predictions
        for expert_id in all_expert_ids:
            if expert_id == 'blank':
                continue  # Skip processing for blank column
            elif expert_id.startswith('R'):
                expert_num = int(expert_id[1:])
                for expert_output in token_data.get("expert_outputs", []):
                    if expert_output["expert_type"] == "routed" and expert_output["expert_id"] == expert_num:
                        expert_tokens[(layer_num, expert_id)] = {
                            'weight': expert_output["weight"],
                            'tokens': expert_output["top_tokens"]
                        }
            elif expert_id.startswith('S'):
                for expert_output in token_data.get("expert_outputs", []):
                    if expert_output["expert_type"] == "shared" and expert_output["expert_id"] == expert_id:
                        expert_tokens[(layer_num, expert_id)] = {
                            'weight': expert_output["weight"],
                            'tokens': expert_output["top_tokens"]
                        }
            elif expert_id.startswith('Prediction'):
                if 'layer_predictions' in token_data:
                    pred_idx = int(expert_id.split('-')[1]) - 1
                    if pred_idx < len(token_data['layer_predictions'].get('tokens', [])):
                        token = token_data['layer_predictions']['tokens'][pred_idx]
                        prob = token_data['layer_predictions']['probs'][pred_idx]
                        expert_tokens[(layer_num, expert_id)] = {
                            'weight': prob,
                            'tokens': [(token, prob)]
                        }

    # Add final predictions
    if 'final_predictions' in model_outputs and position in model_outputs['final_predictions']:
        final_preds = model_outputs['final_predictions'][position]
        final_layer = max(layers) + 1
        layers.append(final_layer)
        
        for i, (token, prob) in enumerate(zip(final_preds['tokens'], final_preds['probs'])):
            pred_key = f'Prediction-{i+1}'
            expert_tokens[(final_layer, pred_key)] = {
                'weight': prob,
                'tokens': [(token, prob)]
            }

    # Create visualization matrices
    layers = sorted(layers)
    weight_matrix = np.zeros((len(layers), len(all_expert_ids)))
    hover_text = [['' for _ in range(len(all_expert_ids))] for _ in range(len(layers))]
    text_matrix = [['' for _ in range(len(all_expert_ids))] for _ in range(len(layers))]

    # Fill matrices with data
    for i, layer in enumerate(layers):
        for j, expert_id in enumerate(all_expert_ids):
            if expert_id == 'blank':
                continue  # Skip blank column
            info = expert_tokens.get((layer, expert_id))
            if info and info['tokens']:
                weight = info['weight']
                weight_matrix[i, j] = weight
                
                top_token, top_prob = info['tokens'][0]
                text_matrix[i][j] = f'{top_token}<br>{top_prob:.3f}'
                
                hover_lines = [
                    f"Layer: {'Final Prediction' if layer == max(layers) else layer}",
                    f"{'Prediction' if 'Prediction' in expert_id else 'Expert'}: {expert_id}",
                    f"Weight: {weight:.3f}",
                    "Top tokens:"
                ]
                hover_lines.extend(f"{token}: {prob:.3f}" for token, prob in info['tokens'])
                hover_text[i][j] = "<br>".join(hover_lines)

    # Create figure with exact positioning
    fig = go.Figure(data=go.Heatmap(
        z=weight_matrix,
        x=all_expert_ids,
        y=[f"{'Final Prediction' if layer == max(layers) else f'Layer {layer}'}" for layer in layers],
        colorscale=[
            [0, 'rgba(0,0,0,0)'],
            [0.0001, 'rgb(68,1,84)'],
            [1, 'rgb(242, 121, 53)']
        ],
        text=text_matrix,
        texttemplate="%{text}",
        textfont={"size": 14, "color": "white"},
        hoverongaps=False,
        hoverinfo="text",
        hovertext=hover_text,
        showscale=True
    ))

    # Update layout with exact separator positions
    title = title or f"Expert Contributions and Predictions for Position {position}"
    fig.update_layout(
        title=dict(
            text=title,
            font=dict(size=24)
        ),
        xaxis_title=dict(
            text="Predictions | Routed Experts | Shared Experts",
            font=dict(size=20)
        ),
        yaxis_title=dict(
            text="Layers",
            font=dict(size=20)
        ),
        height=2000,
        width=5000,
        plot_bgcolor='black',
        paper_bgcolor='black',
        font=dict(
            color='white',
            size=16
        ),
        xaxis=dict(
            tickangle=45,
            tickfont=dict(size=16),
            showgrid=True,
            gridwidth=1,
            gridcolor='rgba(128, 128, 128, 0.2)',
            showline=True,
            linewidth=1,
            linecolor='rgba(128, 128, 128, 0.2)',
        ),
        yaxis=dict(
            autorange='reversed',
            tickfont=dict(size=16),
            showgrid=True,
            gridwidth=1,
            gridcolor='rgba(128, 128, 128, 0.2)',
            showline=True,
            linewidth=1,
            linecolor='rgba(128, 128, 128, 0.2)'
        ),
        margin=dict(l=120, r=120, t=160, b=120)
    )

    # Add visual separators at exact positions
    prediction_end = len(prediction_keys) - 0.5
    blank_end = prediction_end + 1
    routed_end = blank_end + len(routed_expert_ids)

    # Add separator lines
    fig.add_vline(x=prediction_end, line_width=2, line_color="white", line_dash="dash")
    fig.add_vline(x=routed_end, line_width=2, line_color="white", line_dash="dash")

    return fig

In [6]:
def analyze_dataset(
    text: str,
    model,
    tokenizer,
    token_position: int = None,
    display_plot: bool = True
):
    """
    Analyze text through the model and create visualizations.
    """
    try:
        figures = {}
        
        # Initialize the lens analyzer
        lens = MOELens(model, tokenizer)
        
        # Process the input text
        input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
        results = lens.analyze_text(input_ids)
        
        # Create expert visualization
        figures['expert'] = plot_expert_and_final_predictions(
            results,
            tokenizer=tokenizer,
            position=token_position,
            title=f"Expert Contributions and Predictions for: '{text}'"
        )
        if display_plot:
            figures['expert'].show()
            
        return figures, results, None
        
    except Exception as e:
        print(f"Error in analyze_dataset: {str(e)}")
        import traceback
        traceback.print_exc()
        return None, None, str(e)

    finally:
        # Clean up
        if 'lens' in locals():
            lens.remove_hooks()

----

In [7]:
text = "The tide came in, washing away the footprints left behind in the"
x = len(tokenizer.encode(text))

fig, result, _ = analyze_dataset(
    text=text,
    model=model,
    tokenizer=tokenizer,
    token_position=x-1,
    display_plot=True,
)

mostly all i notice is some related tokens in the shared experts thats it. very less in the routed experts

In [8]:
txt = "A gentle breeze carried the scent of flowers and the faintest sound of a"

inputs = tokenizer(txt, return_tensors="pt")
outputs = model.forward(**inputs.to(model.device))

print(outputs.logits.shape)



for i in range(20):
    print(f"i: {i}")
    x = outputs.logits[0, i]
    # so basically over hee gpt/these langauge models put a prob over the whole vocab
    print(x) 
    y = x.argmax()
    # to get the highest prob token
    print(f"y: {y}") 

    print(tokenizer.decode(y))




print(model.model.layers[3].mlp.shared_experts.gate_proj.weight.shape)
print(model.model.layers[3].mlp.experts[63].gate_proj.weight.shape)


torch.Size([1, 15, 102400])
i: 0
tensor([ 6.9414,  9.6016, 12.0547,  ..., -5.5430, -5.7617, -5.3477],
       grad_fn=<SelectBackward0>)
y: 549
The
i: 1
tensor([ 9.2812, 10.0859,  8.0391,  ..., -2.4512, -2.1680, -3.0547],
       grad_fn=<SelectBackward0>)
y: 761
 new
i: 2
tensor([ 5.4336,  7.2422,  2.7422,  ..., -3.1406, -2.7793, -3.3203],
       grad_fn=<SelectBackward0>)
y: 29780
 reminder
i: 3
tensor([11.4531,  6.2773,  1.0264,  ..., -7.6445, -7.6875, -7.7695],
       grad_fn=<SelectBackward0>)
y: 33515
 blew
i: 4
tensor([11.6797, 10.1328,  5.5938,  ..., -1.5410, -1.7988, -1.9258],
       grad_fn=<SelectBackward0>)
y: 254
 the
i: 5
tensor([ 1.5850,  0.2852, -3.8477,  ..., -6.0820, -6.6914, -6.7461],
       grad_fn=<SelectBackward0>)
y: 31420
 scent
i: 6
tensor([18.8594, 16.0156, 11.1797,  ...,  1.3691,  1.0322,  0.8413],
       grad_fn=<SelectBackward0>)
y: 280
 of
i: 7
tensor([ 2.4883,  2.2871, -1.9854,  ..., -5.4648, -5.5391, -5.6758],
       grad_fn=<SelectBackward0>)
y: 254
 the


IndexError: index 15 is out of bounds for dimension 1 with size 15