In [2]:
from optim_hunter.llama_model import load_llama_model
from optim_hunter.model_utils import get_numerical_tokens
import torch
from einops import einsum
import seaborn as sns
import matplotlib.pyplot as plt

# Device for analysis tensors: the first CUDA GPU when available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [None]:
# Load the model once; every later cell passes this global `model` to the analysis code,
# which relies on run_with_cache, W_U, cfg.n_layers and ln_final being available on it.
# NOTE(review): dtype and device placement are decided inside load_llama_model and are not
# visible here — confirm them before interpreting precision-sensitive results.
model = load_llama_model()

In [None]:
def analyze_number_probabilities(model, prompt, numerical_tokens):
    """Logit-lens probabilities over numerical tokens at every residual-stream checkpoint.

    Runs ``prompt`` through ``model`` with caching, takes the accumulated residual
    stream at each checkpoint (``cache.accumulated_resid(apply_ln=True)``), projects it
    through the unembedding ``W_U``, softmaxes over the full vocabulary and keeps only
    the columns of the requested number tokens, averaged over batch and position.

    Args:
        model: TransformerLens-style model exposing ``run_with_cache`` and ``W_U``.
        prompt: Input forwarded unchanged to ``model.run_with_cache``.
        numerical_tokens: Either a dict mapping number -> token id (presumably the
            output of ``get_numerical_tokens``; TODO confirm) or a sequence of token ids.

    Returns:
        Tuple ``(number_probs, labels, token_to_number, token_indices)``:
        ``number_probs`` is a float32 tensor ``[n_checkpoints, n_tokens]``;
        ``labels`` are the checkpoint names returned by ``accumulated_resid``;
        ``token_to_number`` maps token id -> x-axis label (the number for dict input,
        the token id as a string otherwise); ``token_indices`` is a 1-D long tensor
        of token ids in column order.

    Raises:
        ValueError: If no numerical tokens are supplied.
    """
    W_U = model.W_U
    # Compute on the device that already holds W_U: moving the small activation cache
    # is far cheaper than copying the d_model x vocab matrix on every call, and it
    # avoids a hard-coded "cuda:0" that breaks CPU-only runs.
    target_device = W_U.device

    if isinstance(numerical_tokens, dict):
        # Invert number -> token id, then order the columns by numeric value.
        token_to_number = {v: k for k, v in numerical_tokens.items()}
        sorted_tokens = sorted(token_to_number.items(), key=lambda item: float(item[1]))
        token_to_number = dict(sorted_tokens)
        token_ids = [idx for idx, _ in sorted_tokens]
    else:
        token_ids = [int(idx) for idx in numerical_tokens]
        # No numeric values were supplied; label each column with its token id so the
        # returned mapping (used for plot tick labels) always exists.
        token_to_number = {idx: str(idx) for idx in token_ids}
    if not token_ids:
        raise ValueError("numerical_tokens is empty; nothing to analyze")
    token_indices = torch.tensor(token_ids, dtype=torch.long, device=target_device)

    # Inference only: without no_grad the result carries an autograd graph whenever the
    # model's parameters require grad, and a later .numpy() on it would fail.
    with torch.no_grad():
        _, cache = model.run_with_cache(prompt)
        cache.to(target_device)
        accumulated_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)
        # Match W_U's dtype instead of forcing .half(), which mismatches float32/bfloat16 models.
        accumulated_resid = accumulated_resid.to(device=target_device, dtype=W_U.dtype)
        # [layer, batch, pos, d_model] @ [d_model, vocab] -> [layer, batch, pos, vocab]
        layer_logits = accumulated_resid @ W_U
        # Softmax in float32: a half-precision softmax over a large vocabulary loses
        # precision and may be unsupported on CPU in some torch versions.
        layer_probs = torch.softmax(layer_logits.float(), dim=-1)
        # Keep only the number columns, then average over batch and position -> [layer, n_tokens].
        number_probs = layer_probs[..., token_indices].mean(dim=(1, 2))

    return number_probs, labels, token_to_number, token_indices

# Example usage: logit lens over every number token for a simple arithmetic prompt.
prompt = "2 + 2 = "
numerical_tokens = get_numerical_tokens(model)

probs, layer_labels, token_to_number, token_indices = analyze_number_probabilities(model, prompt, numerical_tokens)

# Plotting. detach()/float() make .numpy() safe for grad-tracking or bfloat16 tensors,
# neither of which numpy can convert directly.
fig, ax = plt.subplots(figsize=(12, 8))
sns.heatmap(probs.detach().float().cpu().numpy(),
            xticklabels=[token_to_number[idx] for idx in token_indices.tolist()],
            yticklabels=layer_labels,
            cmap='viridis',
            ax=ax)
ax.set_xlabel('Numbers')
ax.set_ylabel('Layer')
ax.set_title('Average Probability Distribution over Numbers Across Layers')
plt.setp(ax.get_xticklabels(), rotation=45)  # many number labels overlap unless rotated
fig.tight_layout()  # keep rotated labels from being clipped
plt.show()

In [None]:
def analyze_number_probabilities(model, prompt, numerical_tokens, min_value=0, max_value=10):
    """Logit-lens probabilities over small-integer tokens at every residual-stream checkpoint.

    Same computation as the unfiltered variant: project the LayerNorm-scaled
    accumulated residual stream (``cache.accumulated_resid(apply_ln=True)``) through
    ``W_U``, softmax over the full vocabulary, keep the number columns and average
    over batch and position. For dict input only tokens whose text is a plain
    non-negative integer in ``[min_value, max_value]`` are kept.

    Args:
        model: TransformerLens-style model exposing ``run_with_cache`` and ``W_U``.
        prompt: Input forwarded unchanged to ``model.run_with_cache``.
        numerical_tokens: Either a dict mapping number -> token id (presumably the
            output of ``get_numerical_tokens``; TODO confirm) or a sequence of token ids
            (used unfiltered).
        min_value: Inclusive lower bound of the numbers kept (default 0).
        max_value: Inclusive upper bound of the numbers kept (default 10).

    Returns:
        Tuple ``(number_probs, labels, token_to_number, token_indices)``:
        ``number_probs`` is a float32 tensor ``[n_checkpoints, n_tokens]``;
        ``labels`` are the checkpoint names from ``accumulated_resid``;
        ``token_to_number`` maps token id -> x-axis label; ``token_indices`` is a 1-D
        long tensor of token ids in column order.

    Raises:
        ValueError: If no token survives the selection.
    """
    W_U = model.W_U
    # Compute on the device that already holds W_U instead of a hard-coded "cuda:0"
    # and instead of copying the large unembedding matrix on every call.
    target_device = W_U.device

    if isinstance(numerical_tokens, dict):
        token_to_number = {v: k for k, v in numerical_tokens.items()}
        # isdecimal() rather than isdigit(): isdigit() also accepts characters such as
        # "²" that float() cannot parse, which would crash the range check.
        filtered_tokens = {
            idx: num for idx, num in token_to_number.items()
            if str(num).strip().isdecimal() and min_value <= float(num) <= max_value
        }
        sorted_tokens = sorted(filtered_tokens.items(), key=lambda item: float(item[1]))
        token_to_number = dict(sorted_tokens)
        token_ids = [idx for idx, _ in sorted_tokens]
    else:
        token_ids = [int(idx) for idx in numerical_tokens]
        # No numeric values were supplied; label each column with its token id.
        token_to_number = {idx: str(idx) for idx in token_ids}
    if not token_ids:
        raise ValueError(
            f"no numerical tokens selected (range [{min_value}, {max_value}])")
    token_indices = torch.tensor(token_ids, dtype=torch.long, device=target_device)

    # Inference only: keeps the result free of autograd history so .numpy() works.
    with torch.no_grad():
        _, cache = model.run_with_cache(prompt)
        cache.to(target_device)
        accumulated_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)
        # Match W_U's dtype instead of forcing .half() (fails for float32/bfloat16 models).
        accumulated_resid = accumulated_resid.to(device=target_device, dtype=W_U.dtype)
        # [layer, batch, pos, d_model] @ [d_model, vocab] -> [layer, batch, pos, vocab]
        layer_logits = accumulated_resid @ W_U
        # Float32 softmax for precision over the large vocabulary.
        layer_probs = torch.softmax(layer_logits.float(), dim=-1)
        number_probs = layer_probs[..., token_indices].mean(dim=(1, 2))

    return number_probs, labels, token_to_number, token_indices

# Example usage: logit lens restricted to the numbers 0-10.
prompt = "2 + 2 = "
numerical_tokens = get_numerical_tokens(model)

probs, layer_labels, token_to_number, token_indices = analyze_number_probabilities(model, prompt, numerical_tokens)

# Plotting. detach()/float() make .numpy() safe for grad-tracking or bfloat16 tensors.
fig, ax = plt.subplots(figsize=(12, 8))
sns.heatmap(probs.detach().float().cpu().numpy(),
            xticklabels=[token_to_number[idx] for idx in token_indices.tolist()],
            yticklabels=layer_labels,
            cmap='viridis',
            ax=ax)
ax.set_xlabel('Numbers')
ax.set_ylabel('Layer')
ax.set_title('Average Probability Distribution over Numbers Across Layers')
plt.setp(ax.get_xticklabels(), rotation=45)  # Rotate x-axis labels for better readability
fig.tight_layout()  # Adjust layout to prevent label cutoff
plt.show()

In [None]:
def analyze_number_probabilities(model, prompt, numerical_tokens, min_value=0, max_value=10):
    """Per-layer logit-lens probabilities over small-integer tokens using ``resid_post``.

    Stacks the post-block residual stream of every layer, normalizes it with the
    model's own ``ln_final``, projects through ``W_U``, softmaxes over the full
    vocabulary, keeps the number columns and averages over batch and position.
    For dict input only tokens whose text is a plain non-negative integer in
    ``[min_value, max_value]`` are kept.

    Args:
        model: TransformerLens-style model exposing ``run_with_cache``, ``W_U``,
            ``ln_final`` and ``cfg.n_layers``.
        prompt: Input forwarded unchanged to ``model.run_with_cache``.
        numerical_tokens: Either a dict mapping number -> token id (presumably the
            output of ``get_numerical_tokens``; TODO confirm) or a sequence of token ids
            (used unfiltered).
        min_value: Inclusive lower bound of the numbers kept (default 0).
        max_value: Inclusive upper bound of the numbers kept (default 10).

    Returns:
        Tuple ``(number_probs, layer_labels, token_to_number, token_indices)``:
        ``number_probs`` is a float32 tensor ``[n_layers, n_tokens]``;
        ``layer_labels`` are ``"Layer i"`` strings; ``token_to_number`` maps token
        id -> x-axis label; ``token_indices`` is a 1-D long tensor of token ids in
        column order.

    Raises:
        ValueError: If no token survives the selection.
    """
    W_U = model.W_U
    # Compute on the device that already holds W_U instead of a hard-coded "cuda:0"
    # and instead of copying the large unembedding matrix on every call.
    target_device = W_U.device

    if isinstance(numerical_tokens, dict):
        token_to_number = {v: k for k, v in numerical_tokens.items()}
        # isdecimal() rather than isdigit(): isdigit() also accepts characters such as
        # "²" that float() cannot parse, which would crash the range check.
        filtered_tokens = {
            idx: num for idx, num in token_to_number.items()
            if str(num).strip().isdecimal() and min_value <= float(num) <= max_value
        }
        sorted_tokens = sorted(filtered_tokens.items(), key=lambda item: float(item[1]))
        token_to_number = dict(sorted_tokens)
        token_ids = [idx for idx, _ in sorted_tokens]
    else:
        token_ids = [int(idx) for idx in numerical_tokens]
        # No numeric values were supplied; label each column with its token id.
        token_to_number = {idx: str(idx) for idx in token_ids}
    if not token_ids:
        raise ValueError(
            f"no numerical tokens selected (range [{min_value}, {max_value}])")
    token_indices = torch.tensor(token_ids, dtype=torch.long, device=target_device)

    n_layers = model.cfg.n_layers
    # Inference only: keeps the result free of autograd history so .numpy() works.
    with torch.no_grad():
        _, cache = model.run_with_cache(prompt)
        cache.to(target_device)
        # Residual stream after each block, stacked to [layer, batch, pos, d_model].
        resid_post = torch.stack([cache["resid_post", layer] for layer in range(n_layers)])
        # Normalize every layer with the model's final norm, then align dtype/device with W_U.
        resid_post = model.ln_final(resid_post).to(device=target_device, dtype=W_U.dtype)
        # [layer, batch, pos, d_model] @ [d_model, vocab] -> [layer, batch, pos, vocab]
        layer_logits = resid_post @ W_U
        # Float32 softmax for precision over the large vocabulary.
        layer_probs = torch.softmax(layer_logits.float(), dim=-1)
        number_probs = layer_probs[..., token_indices].mean(dim=(1, 2))

    layer_labels = [f"Layer {i}" for i in range(n_layers)]

    return number_probs, layer_labels, token_to_number, token_indices

# Example usage: per-layer (resid_post) logit lens restricted to the numbers 0-10.
prompt = "2 + 2 = "
numerical_tokens = get_numerical_tokens(model)

probs, layer_labels, token_to_number, token_indices = analyze_number_probabilities(model, prompt, numerical_tokens)

# Plotting. detach()/float() make .numpy() safe for grad-tracking or bfloat16 tensors.
fig, ax = plt.subplots(figsize=(12, 8))
sns.heatmap(probs.detach().float().cpu().numpy(),
            xticklabels=[token_to_number[idx] for idx in token_indices.tolist()],
            yticklabels=layer_labels,
            cmap='viridis',
            ax=ax)
ax.set_xlabel('Numbers')
ax.set_ylabel('Layer')
ax.set_title('Average Probability Distribution over Numbers Across Layers')
plt.setp(ax.get_xticklabels(), rotation=45)
fig.tight_layout()
plt.show()

In [None]:
def plot_layer_group(probs, layer_labels, token_to_number, token_indices, start_layer, end_layer, title_suffix=""):
    """Draw a heatmap of number-token probabilities for rows ``start_layer:end_layer``.

    Args:
        probs: ``[n_rows, n_tokens]`` tensor as returned by ``analyze_number_probabilities``.
        layer_labels: Row labels aligned with ``probs``.
        token_to_number: Mapping token id -> label used for the x-axis ticks.
        token_indices: Token ids in the same column order as ``probs``.
        start_layer: First row to plot (inclusive).
        end_layer: Row to stop at (exclusive); clamped to the number of rows available.
        title_suffix: Extra text appended to the title.

    Raises:
        ValueError: If the requested range contains no rows.
    """
    end_layer = min(end_layer, probs.shape[0])
    # detach()/float() make .numpy() safe for grad-tracking or bfloat16 tensors.
    rows = probs[start_layer:end_layer].detach().float().cpu().numpy()
    if rows.shape[0] == 0:
        raise ValueError(f"no layers in range [{start_layer}, {end_layer})")

    # Roughly half an inch per row, but never shorter than a 2-inch strip.
    fig, ax = plt.subplots(figsize=(12, max(2, 0.5 * rows.shape[0])))
    sns.heatmap(rows,
                xticklabels=[token_to_number[idx] for idx in token_indices.tolist()],
                yticklabels=layer_labels[start_layer:end_layer],
                cmap='viridis',
                ax=ax)
    ax.set_xlabel('Numbers')
    ax.set_ylabel('Layer')
    ax.set_title(
        f'Probability Distribution over Numbers - Layers {start_layer}-{end_layer-1} {title_suffix}'.rstrip())
    plt.setp(ax.get_xticklabels(), rotation=45)
    fig.tight_layout()
    plt.show()
    # Release the figure explicitly; this helper is called once per layer in a loop.
    plt.close(fig)

# Get the probabilities as before
prompt = "2 + 2 = "
numerical_tokens = get_numerical_tokens(model)
probs, layer_labels, token_to_number, token_indices = analyze_number_probabilities(model, prompt, numerical_tokens)

# One heatmap per group of LAYERS_PER_PLOT consecutive layers. The previous comment said
# "groups of 8" while the loop stepped by 1; 1 keeps the existing one-layer-per-figure output.
LAYERS_PER_PLOT = 1
# Use the actual number of rows instead of a hard-coded 32 so every layer is plotted.
n_plot_rows = probs.shape[0]
for start in range(0, n_plot_rows, LAYERS_PER_PLOT):
    plot_layer_group(probs, layer_labels, token_to_number, token_indices,
                     start, min(start + LAYERS_PER_PLOT, n_plot_rows))