In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from sklearn.decomposition import PCA
import plotly.express as px
import plotly.graph_objects as go
import json
import numpy as np
import matplotlib.pyplot as plt
import os
from collections import defaultdict

In [2]:
def load_model(model_name):
    """Load a causal language model in half precision together with its tokenizer.

    args:
        model_name: model identifier or local path, forwarded to both `from_pretrained` calls
    returns:
        (model, tokenizer) tuple
    """
    # trust_remote_code=True allows transformers to run modelling code shipped with the checkpoint
    model_kwargs = dict(torch_dtype=torch.float16, trust_remote_code=True)
    return (
        AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs),
        AutoTokenizer.from_pretrained(model_name),
    )

# Load the checkpoint once for the whole notebook; `model` and `tokenizer` are
# read as globals by helper functions defined in later cells.
model, tokenizer = load_model("deepseek-ai/deepseek-moe-16b-base")

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [3]:
def generate_text(model, tokenizer, input_text, max_length=70):
    """Run one forward pass on `input_text`, then generate and print a continuation.

    args:
        model: causal LM that is callable on token ids and exposes `.generate(...)`
        tokenizer: tokenizer providing `encode`, `decode` and `pad_token_id`
        input_text: prompt string
        max_length: maximum total length (prompt + generated tokens) passed to `generate`
    returns:
        the output of the plain forward pass `model(input_ids)` (NOT the generated ids)
    """
    input_ids = tokenizer.encode(input_text, return_tensors="pt")

    # Tokenizers without a dedicated pad token report None; fall back to EOS so that
    # generate() receives a usable pad id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = getattr(tokenizer, "eos_token_id", None)

    # Inference only: skip building the autograd graph for the forward pass.
    with torch.no_grad():
        outputs = model(input_ids)
        output = model.generate(
            input_ids,
            max_length=max_length,
            use_cache=True,
            pad_token_id=pad_token_id,
        )

    print("generated text :")
    # Decode the whole sequence in one call: decoding token by token can lose the
    # whitespace that sub-word tokenizers attach to token boundaries.
    print(tokenizer.decode(output[0], skip_special_tokens=True))

    return outputs

In [5]:
def get_router_logits(model, input_ids):
    """Get the router logits of every MoE layer for one forward pass.

    The tensor that each router (`layer.mlp.gate`) actually receives is captured with a
    forward pre-hook and multiplied by the gate weight. `hidden_states[layer_idx]` from
    `output_hidden_states=True` is NOT used: it is the residual stream entering the
    decoder layer, whereas the gate is called from inside `layer.mlp`, i.e. after whatever
    the layer applies before its MLP (attention + layer norm in DeepSeek-MoE).

    args:
        model: model exposing `model.model.layers`; a layer counts as MoE when
            `layer.mlp` has a `gate` attribute with a `weight` of shape [num_experts, hidden_dim]
        input_ids: token ids, shape [batch, seq_len]
    returns:
        float32 CPU tensor of shape [num_moe_layers, batch * seq_len, num_experts]
    raises:
        ValueError: if the model has no MoE layer
        RuntimeError: if a gate was not called during the forward pass
    """
    moe_gates = [layer.mlp.gate for layer in model.model.layers if hasattr(layer.mlp, 'gate')]
    if not moe_gates:
        raise ValueError("model has no MoE layers (no `layer.mlp.gate` attribute found)")

    gate_inputs = {}

    def _make_capture_hook(slot):
        # One closure per gate so each hook writes into its own slot.
        def _capture(module, args):
            gate_inputs[slot] = args[0].detach()
        return _capture

    hook_handles = [
        gate.register_forward_pre_hook(_make_capture_hook(slot))
        for slot, gate in enumerate(moe_gates)
    ]
    try:
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            model(input_ids)
    finally:
        # Always detach the hooks so later forward passes / generation are unaffected.
        for handle in hook_handles:
            handle.remove()

    per_layer_logits = []
    with torch.no_grad():
        for slot, gate in enumerate(moe_gates):
            if slot not in gate_inputs:
                raise RuntimeError(f"MoE gate {slot} was not called during the forward pass")
            gate_input = gate_inputs[slot]
            # Flatten [batch, seq_len, hidden_dim] -> [batch * seq_len, hidden_dim]
            flat_hidden = gate_input.reshape(-1, gate_input.shape[-1])
            logits = torch.nn.functional.linear(flat_hidden, gate.weight)
            per_layer_logits.append(logits.to(device="cpu", dtype=torch.float32))

    router_logits = torch.stack(per_layer_logits)
    print(f"router logits shape : {router_logits.shape}")
    return router_logits

def get_last_token_router_probs(router_logits, layer_idx):
    """Return the softmax routing distribution of the final sequence position in one layer.

    `router_logits[layer_idx]` is indexed as [sequence_length, num_experts]; its last row
    is normalised over the expert dimension.
    """
    final_position_logits = router_logits[layer_idx][-1]
    return final_position_logits.softmax(dim=-1)

def topk(router_probs, k):
    """Zero out all components except the top k router probabilities.

    Works along the last dimension, so both a single probability vector ([num_experts])
    and a batch ([..., num_experts]) are supported.

    args:
        router_probs: tensor of routing probabilities
        k: number of largest entries to keep per vector (must not exceed the last dimension)
    returns:
        tensor of the same shape and dtype with only the top-k entries kept
    """
    values, indices = torch.topk(router_probs, k, dim=-1)
    zeroed_probs = torch.zeros_like(router_probs)
    # scatter along the expert axis; for 1-D input this equals zeroed_probs[indices] = values
    zeroed_probs.scatter_(-1, indices, values)
    return zeroed_probs

In [3]:
def get_router_probs_matrix(model, prompts, k=8):
    """Get router logits for all tokens and top-k last-token router probs for several prompts.

    The last-token probabilities are taken from the UNPADDED logits of each prompt; padding
    is applied afterwards, only to make the per-prompt logits stackable. (Reading the last
    row after padding would return the all-zero padding row, i.e. a uniform distribution,
    for every prompt shorter than the longest one.)

    Uses the notebook-global `tokenizer` to encode the prompts.

    args:
        model: model passed through to `get_router_logits`
        prompts: non-empty sequence of prompt strings
        k: number of top probabilities kept per (layer, prompt); the rest are zeroed
    returns:
        (last_token_prob_matrix, all_router_logits) where
        last_token_prob_matrix : [num_moe_layers, num_prompts, num_experts]
        all_router_logits      : [num_prompts, num_moe_layers, max_seq_len, num_experts],
                                 zero-padded at the end of the sequence axis
    raises:
        ValueError: if `prompts` is empty
    """
    if len(prompts) == 0:
        raise ValueError("prompts must contain at least one prompt")

    per_prompt_logits = []
    per_prompt_last_token = []

    # Single pass: each prompt is tokenised and run through the model exactly once.
    for prompt in prompts:
        input_ids = tokenizer.encode(prompt, return_tensors="pt")
        router_logits = get_router_logits(model, input_ids)
        per_prompt_logits.append(router_logits)

        # Layer count and expert count come from the logits themselves.
        per_prompt_last_token.append(torch.stack([
            topk(get_last_token_router_probs(router_logits, layer_idx), k=k)
            for layer_idx in range(router_logits.shape[0])
        ]))

    # [num_prompts x (layers, experts)] -> [layers, num_prompts, experts]
    last_token_prob_matrix = torch.stack(per_prompt_last_token, dim=1)

    # Pad every prompt's logits to the longest sequence so they can be stacked.
    max_seq_len = max(logits.shape[1] for logits in per_prompt_logits)
    padded_logits = []
    for logits in per_prompt_logits:
        missing = max_seq_len - logits.shape[1]
        if missing > 0:
            padding = torch.zeros((logits.shape[0], missing, logits.shape[2]), dtype=logits.dtype)
            logits = torch.cat([logits, padding], dim=1)
        padded_logits.append(logits)

    all_router_logits = torch.stack(padded_logits)

    return last_token_prob_matrix, all_router_logits


def get_last_token(prompt):
    """Return the decoded text of the final token id the global tokenizer produces for `prompt`."""
    final_token_id = tokenizer.encode(prompt)[-1]
    return tokenizer.decode([final_token_id])

In [7]:
def pca_visualize(prompts, layer_id=1):
    """Perform a 3D PCA visualization of last-token router probabilities for a list of prompts.

    `get_router_probs_matrix` returns a 3-D matrix [num_moe_layers, num_prompts, num_experts],
    while PCA needs a 2-D [samples, features] array, so one layer has to be selected.

    Uses the notebook-global `model` (and, via the helpers, `tokenizer`).

    args:
        prompts: list of prompt strings (at least 3 are needed for a 3-component PCA)
        layer_id: 1-based index into the router-probs matrix, i.e. row `layer_id - 1`
            is visualized (same convention as `pca_visualize_all_domains`)
    raises:
        ValueError: if `layer_id` is outside 1..num_moe_layers
    """
    # k=64 keeps the full distribution (no experts are zeroed for a 64-expert router)
    last_token_prob_matrix, _ = get_router_probs_matrix(model, prompts, k=64)

    num_layers = last_token_prob_matrix.shape[0]
    if not 1 <= layer_id <= num_layers:
        raise ValueError(f"layer_id must be in 1..{num_layers}, got {layer_id}")

    # [num_prompts, num_experts] slice for the selected layer
    router_prob_matrix_np = last_token_prob_matrix[layer_id - 1].detach().numpy()

    # Perform PCA
    pca = PCA(n_components=3)
    pca_result = pca.fit_transform(router_prob_matrix_np)

    print("\nPCA results:")
    print(f"explained variance ratio: {pca.explained_variance_ratio_}")
    print(f"cumulative explained variance: {pca.explained_variance_ratio_.sum():.3f}")
    print("\nPCA transformed data shape:", pca_result.shape)

    # Label every point with the last token of its prompt
    last_tokens = [get_last_token(prompt) for prompt in prompts]

    # Create 3D scatter plot
    fig = go.Figure(data=[go.Scatter3d(
        x=pca_result[:, 0],
        y=pca_result[:, 1],
        z=pca_result[:, 2],
        mode='markers+text',
        text=last_tokens,
        textposition="top center",
        marker=dict(
            size=10,
            opacity=0.8
        )
    )])

    # Update layout for 3D
    fig.update_layout(
        title='3D PCA of Router Probabilities',
        scene=dict(
            xaxis_title='first principal component',
            yaxis_title='second principal component',
            zaxis_title='third principal component',
            xaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
            yaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
            zaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray')
        ),
        width=1000,
        height=800,
        showlegend=False
    )

    fig.show()

In [8]:
def compute_cosine_similarity(router_prob_matrix, idx1, idx2):
    """Cosine similarity between rows `idx1` and `idx2` of `router_prob_matrix`.

    Rows may be torch tensors (used as-is) or numpy arrays (converted to float32 tensors).
    Returns a plain Python float.
    """
    def _as_tensor(row):
        # tensors pass through untouched; anything else is assumed to be a numpy array
        return row if isinstance(row, torch.Tensor) else torch.from_numpy(row).float()

    first, second = (_as_tensor(router_prob_matrix[idx]) for idx in (idx1, idx2))

    # add a leading batch axis so the similarity is taken over the vector dimension
    similarity = torch.nn.functional.cosine_similarity(first[None], second[None])
    return similarity.item()

# print(f"Cosine similarity between tokens 2 and 3: {compute_cosine_similarity(router_prob_matrix, 2, 3):.4f}")
# print(f"Cosine similarity between tokens 0 and 5: {compute_cosine_similarity(router_prob_matrix, 0, 5):.4f}")
# print(f"Cosine similarity between tokens 10 and 15: {compute_cosine_similarity(router_prob_matrix, 10, 15):.4f}")


In [4]:
def prepare_prompts_from_txt(txt_file_path, domain='english', output_path='english.json'):
    """Read one prompt per non-blank line from a txt file and save them as {domain: [prompts]} json.

    Lines are stripped of surrounding whitespace; blank lines are dropped.
    Returns the list of prompts.
    """
    with open(txt_file_path, 'r', encoding='utf-8') as source:
        stripped_lines = (raw_line.strip() for raw_line in source)
        prompts = [text for text in stripped_lines if text]

    with open(output_path, 'w', encoding='utf-8') as sink:
        json.dump({f"{domain}": prompts}, sink, indent=4)

    print(f"{domain} prompts saved to {output_path}")
    return prompts

def parse_code_blocks(txt_file_path, output_path='code.json', domain='code'):
    """Split a text file into blocks delimited by ``` lines and save them as {domain: [blocks]} json.

    Every line starting with ``` (after stripping) closes the block collected so far, if it
    has any lines, and opens a new one; text before the first ``` line is ignored. Lines are
    right-stripped and joined with newlines. Returns the list of blocks.
    """
    code_blocks = []
    pending_lines = None  # None until the first ``` marker has been seen

    with open(txt_file_path, 'r', encoding='utf-8') as source:
        for raw_line in source:
            if raw_line.strip().startswith('```'):
                # a marker both finishes the running block and starts the next one
                if pending_lines:
                    code_blocks.append('\n'.join(pending_lines))
                pending_lines = []
            elif pending_lines is not None:
                pending_lines.append(raw_line.rstrip())

    # flush whatever was collected after the last marker
    if pending_lines:
        code_blocks.append('\n'.join(pending_lines))

    with open(output_path, 'w', encoding='utf-8') as sink:
        json.dump({domain: code_blocks}, sink, indent=4)

    print(f"code blocks saved to {output_path}")
    return code_blocks

In [5]:
def load_prompts_from_json(json_file_path):
    """Load a {"domain": [prompts]} json file and return the prompts of its first entry."""
    with open(json_file_path, 'r', encoding='utf-8') as source:
        domain_to_prompts = json.load(source)
    # the file is expected to hold a single domain, so the first value is the prompt list
    return [*domain_to_prompts.values()][0]

def prepare_multi_domain_prompts(domain_files, output_path='all_domain_prompts.json'):
    """
    prepare a json file containing prompts from multiple domains.

    args:
        domain_files: Dict mapping domain names to lists of tuples (file_path, source),
            where `source` is one of
              - a callable parser: called as source(file_path), must return a list of prompts
              - a list/tuple of already-parsed prompts (file_path is then informational only)
              - anything else: treated as the path of a {"domain": [prompts]} json file and
                read with load_prompts_from_json
        output_path: where the combined {domain: [prompts]} json is written

    returns:
        dict mapping each domain name to the concatenated list of its prompts

    example:
        domain_files = {
            'code': [('code.txt', parse_code_blocks)],
            'english': [('english.txt', prepare_prompts_from_txt)]
        }
    """
    all_prompts = {}

    for domain, file_list in domain_files.items():
        domain_prompts = []
        for file_path, source in file_list:
            if callable(source):
                # parser function, as documented: run it on the associated file
                prompts = source(file_path)
            elif isinstance(source, (list, tuple)):
                # prompts were parsed by the caller already
                prompts = source
            else:
                # otherwise interpret it as a json file path
                prompts = load_prompts_from_json(source)
            domain_prompts.extend(prompts)

        all_prompts[domain] = domain_prompts

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(all_prompts, f, indent=4)

    print(f"all domain prompts saved to {output_path}")
    return all_prompts

In [6]:
def convert_all_to_list(all_prompts):
    """
    turn a {domain: prompts} mapping into [[domain1_prompts], [domain2_prompts], ...]
    (same order as the mapping) and print a per-domain count summary.
    """
    combined_prompts = list(all_prompts.values())

    print(f"total prompts: {sum(map(len, combined_prompts))}")
    print("prompts per domain:")
    for domain, prompts in all_prompts.items():
        print(f"  {domain}: {len(prompts)}")

    return combined_prompts

In [10]:
def _plot_domain_pca(layer_probs, last_tokens, point_colors, layer_number, results_header):
    """Fit a 3-component PCA on one layer's router probs and build the coloured 3D scatter figure."""
    pca = PCA(n_components=3)
    pca_result = pca.fit_transform(layer_probs.detach().numpy())

    print(results_header)
    print(f"explained variance ratio: {pca.explained_variance_ratio_}")
    print(f"cumulative explained variance: {pca.explained_variance_ratio_.sum():.3f}")
    print("\nPCA transformed data shape:", pca_result.shape)

    # 3D scatter plot, one point per prompt, labelled with the prompt's last token
    fig = go.Figure(data=[go.Scatter3d(
        x=pca_result[:, 0],
        y=pca_result[:, 1],
        z=pca_result[:, 2],
        mode='markers+text',
        text=last_tokens,
        textposition="top center",
        marker=dict(
            size=10,
            opacity=0.8,
            color=point_colors
        )
    )])

    fig.update_layout(
        title=f'PCA visualization of router probabilities across domains for layer {layer_number}',
        scene=dict(
            xaxis_title='PC1',
            yaxis_title='PC2',
            zaxis_title='PC3'
        ),
        width=1000,
        height=800,
        showlegend=False
    )
    return fig


def pca_visualize_all_domains(router_prob_matrix,combined_prompts, layer_id=None):
    """
    perform PCA visualization on router probability matrix for all domains together.
    colors points based on which domain list they came from.

    args:
        router_prob_matrix: tensor indexed as [layer, prompt, expert]; the prompt axis must
                         follow the flattened order of `combined_prompts`
        combined_prompts: List of lists of prompts, where each inner list represents a domain
                         [code_prompts, english_prompts, french_prompts, ...]
        layer_id: if None, every layer is visualized, saved to pca/pca_visualization_layer_<n>.html
                         and shown. otherwise a 1-based layer number: row `layer_id - 1` of the
                         matrix is visualized and shown (not saved)

    returns:
        the figure when `layer_id` is given, None when all layers are visualized

    raises:
        ValueError: if the number of prompts does not match the matrix, or `layer_id` is
                         outside 1..num_layers
    """
    # Flatten prompts list
    all_prompts = [prompt for domain_prompts in combined_prompts for prompt in domain_prompts]

    num_layers = router_prob_matrix.shape[0]
    if len(all_prompts) != router_prob_matrix.shape[1]:
        raise ValueError(
            f"got {len(all_prompts)} prompts but router_prob_matrix has "
            f"{router_prob_matrix.shape[1]} rows per layer"
        )

    # One distinct colour per domain, repeated for every prompt of that domain
    palette = plt.cm.rainbow(np.linspace(0, 1, len(combined_prompts)))
    point_colors = [
        f'rgb({int(c[0]*255)},{int(c[1]*255)},{int(c[2]*255)})'
        for domain_idx, domain_prompts in enumerate(combined_prompts)
        for c in [palette[domain_idx]] * len(domain_prompts)
    ]

    # Point labels do not depend on the layer, so tokenize the prompts only once
    last_tokens = [get_last_token(prompt) for prompt in all_prompts]

    if layer_id is None:
        os.makedirs('pca', exist_ok=True)
        for layer_idx in range(1, num_layers + 1):
            layer_probs = router_prob_matrix[layer_idx-1]
            print(f"layer {layer_idx} router probs shape : {layer_probs.shape}")
            fig = _plot_domain_pca(
                layer_probs, last_tokens, point_colors, layer_idx,
                results_header=f"\nPCA results for layer {layer_idx}:",
            )
            # Save figure as HTML and show plot
            fig.write_html(f'pca/pca_visualization_layer_{layer_idx}.html')
            fig.show()
        return None

    # layer_id is 1-based; 0 or negative values would silently wrap around to the last layers
    if not 1 <= layer_id <= num_layers:
        raise ValueError(f"layer_id must be in 1..{num_layers}, got {layer_id}")

    fig = _plot_domain_pca(
        router_prob_matrix[layer_id-1], last_tokens, point_colors, layer_id,
        results_header="\nPCA results:",
    )
    fig.show()
    return fig

In [None]:
def bar_graph_visualize(router_probs_matrix, layer_id=0, domain=0, num_domains=3,
                        domain_names=('code', 'English', 'French'), top_k=7):
    """Bar chart of how often each expert is among the top-k experts for one domain's tokens.

    The token axis of `router_probs_matrix` is assumed to hold `num_domains` equally sized,
    contiguous groups (domain 0 first, then domain 1, ...).

    args:
        router_probs_matrix: tensor indexed as [layer, token, expert]
        layer_id: row of the matrix to use (0-based, indexes the matrix directly)
        domain: 0-based index of the domain group to visualize
        num_domains: number of equally sized domain groups along the token axis
        domain_names: display names per domain index; used in the title
        top_k: number of highest-probability experts counted per token
    returns:
        plotly Figure (not shown; caller decides)
    raises:
        ValueError: if `num_domains` < 1, `domain` is outside 0..num_domains-1, or there are
            fewer tokens than domains
    """
    if num_domains < 1:
        raise ValueError(f"num_domains must be >= 1, got {num_domains}")
    if not 0 <= domain < num_domains:
        raise ValueError(f"domain must be in 0..{num_domains - 1}, got {domain}")

    total_tokens = router_probs_matrix.shape[1]
    tokens_per_domain = total_tokens // num_domains
    num_experts = router_probs_matrix.shape[2]
    if tokens_per_domain == 0:
        raise ValueError(f"{total_tokens} tokens cannot be split into {num_domains} domains")

    # Slice out this domain's tokens for the requested layer
    start_idx = domain * tokens_per_domain
    end_idx = start_idx + tokens_per_domain
    domain_probs = router_probs_matrix[layer_id, start_idx:end_idx, :]

    # Top-k expert ids for every token at once: [tokens_per_domain, top_k]
    top_experts = torch.topk(domain_probs, k=top_k, dim=-1).indices
    expert_counts = [0] * num_experts
    for expert in top_experts.flatten().tolist():
        expert_counts[expert] += 1

    # Percentage of domain tokens that were routed to each expert
    expert_percentages = [count / tokens_per_domain * 100 for count in expert_counts]

    # Create bar plot using plotly
    fig = go.Figure()

    fig.add_trace(go.Bar(
        x=list(range(num_experts)),
        y=expert_percentages,
        marker_color='red'
    ))

    domain_name = domain_names[domain] if domain < len(domain_names) else f'domain {domain}'

    # Update layout
    fig.update_layout(
        title=f'Percentage of total tokens from {domain_name} domain routed to each expert (top {top_k} per token) for layer {layer_id}',
        xaxis_title='Expert',
        yaxis_title='% of domain tokens',
        yaxis=dict(range=[0, 100]),  # Set y-axis range from 0 to 100%
        xaxis_tickangle=-45,
        bargap=0.2,
        plot_bgcolor='white',
        showlegend=False
    )

    # Add grid
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128, 128, 128, 0.2)')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128, 128, 128, 0.2)')

    return fig

In [8]:
# Step 1: convert the raw text files into per-domain json files
# (english.json, french.json, code.json in the working directory).
prepare_prompts_from_txt('interp-data/engl-lit.txt', domain='english', output_path='english.json')
prepare_prompts_from_txt('interp-data/french.txt', domain='french', output_path='french.json')
parse_code_blocks('interp-data/code.txt', 'code.json', domain='code')

# Step 2: read two of the json files back to report their prompt counts.
code_prompts = load_prompts_from_json(json_file_path='code.json')
print(f"total code prompts : {len(code_prompts)}")
english_prompts = load_prompts_from_json(json_file_path='english.json')
print(f'total english prompts : {len(english_prompts)}')


# Step 3: build the {domain: [(source_path, prompts)]} mapping.
# NOTE(review): the parser calls below run at dict-construction time, so every source
# file is parsed and every json file is written a second time (hence the duplicated
# "saved to" lines in this cell's output). The second tuple element is the parsed
# prompt LIST, not a parser function.
domain_files = {
    'code': [('interp-data/code.txt', parse_code_blocks(txt_file_path='interp-data/code.txt', output_path='code.json', domain='code'))],
    'english': [('interp-data/engl-lit.txt', prepare_prompts_from_txt(txt_file_path='interp-data/engl-lit.txt', output_path='english.json', domain='english'))],
    'french': [('interp-data/french.txt', prepare_prompts_from_txt(txt_file_path='interp-data/french.txt', output_path='french.json', domain='french'))]
}
# Step 4: merge all domains into one json file and keep the mapping in memory.
all_prompts = prepare_multi_domain_prompts(domain_files, output_path='all_prompts.json')
print(f"total domains : {len(all_prompts)}")

# Step 5: convert the mapping to a list of per-domain prompt lists, in the
# insertion order of domain_files (code, english, french).
combined_prompts = convert_all_to_list(all_prompts)

# # # Create test version with only first 5 prompts per domain
# test_combined_prompts = [
#     domain_prompts[:2] for domain_prompts in combined_prompts
# ]

# print(test_combined_prompts[0]) # Print first 5 prompts from first domain
# print(f' total prompts in first domain : {len(test_combined_prompts[0])}')

english prompts saved to english.json
french prompts saved to french.json
code blocks saved to code.json
total code prompts : 200
total english prompts : 200
code blocks saved to code.json
english prompts saved to english.json
french prompts saved to french.json
all domain prompts saved to all_prompts.json
total domains : 3
total prompts: 600
prompts per domain:
  code: 200
  english: 200
  french: 200


### Note: layers are 0-indexed in the PCA visualizations but 1-indexed in the code, because of how the router probs matrix is laid out!

In [18]:
# Router probabilities for every prompt: the per-domain lists are flattened in
# domain order, so the prompt axis of the results follows combined_prompts.
k = 64
last_token_prob_matrix, all_router_logits = get_router_probs_matrix(
    model,
    prompts=[prompt for domain_prompts in combined_prompts for prompt in domain_prompts],
    k=k,
)

In [None]:
# input_text = "what is principal component analysis ?"
# outputs = generate_text(model, tokenizer, input_text)

# test_prompts = [
#     "The quick brown fox",
#     "1+1=",
#     "the grey cat",
#     "the grey elephant",
#     "2*8",
#     "def hello_world() : \n    print('hello world')",
#     "what is principal component analysis",
#     'what is capital of india',
#     'sqrt 16',
#     'void bubbleSort(int arr[], int n) {',
#     'def is_prime(n):',
#     'if n <= 1:',
#     'return False',
#     'for i in range(2, int(n**0.5) + 1):',
#     'if n % i == 0:',
#     'return False',
#     'return True',
#     "china",
#     "the united states of america",
#     "london",
#     "tokyo",
#     'paris'
# ]

# layer_idx = 0
# k = 64
# router_logits = get_router_logits(model, input_ids=tokenizer.encode(input_text, return_tensors="pt"))
# probs = get_last_token_router_probs(router_logits, layer_idx=layer_idx)
# print(f"router probs shape: {probs.shape}, sum: {probs.sum():.2f}")
# print(f'router probs : {probs}')
# top_probs = topk(probs, k=k) 
# print(f"top {k} probs : {top_probs}")
# print(f"top {k} probs sum : {top_probs.sum():.2f}")


# router_prob_matrix = get_router_probs_matrix(model, prompts = test_prompts,k=k)
# print(f"router probability matrix shape : {router_prob_matrix.shape}")
# print("\nrouter probs matrix :")
# print(router_prob_matrix[:])
# print(f"\nverify each row sums to 1 : {router_prob_matrix.sum(dim=1)}")


# last_tokens = [get_last_token(prompt) for prompt in test_prompts]
# print(last_tokens)