In [1]:
import sys
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from jaxtyping import Int, Float
import functools
from tqdm import tqdm
from IPython.display import display
from transformer_lens.hook_points import HookPoint
from transformer_lens import (
    utils,
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)
import circuitsvis as cv

from optim_hunter.plotly_utils import imshow, hist, plot_comp_scores, plot_logit_attribution, plot_loss_difference, line
from optim_hunter.utils import prepare_prompt, slice_dataset
from optim_hunter.sklearn_regressors import linear_regression, knn_regression, random_forest, baseline_average, baseline_last, baseline_random
from optim_hunter.datasets import get_dataset_friedman_2
from optim_hunter.data_model import create_comparison_data
from optim_hunter.model_utils import get_numerical_tokens, generate_linreg_tokens, run_and_cache_model_linreg_tokens_batched, run_and_cache_model_linreg_tokens
from optim_hunter.llama_model import load_llama_model
import logging
from typing import List, Tuple


# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Saves computation time, since we don't need it for the contents of this notebook
t.set_grad_enabled(False)

#device = t.device("cuda:0,1" if t.cuda.is_available() else "cpu")
device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
# device = t.device("cpu")

MAIN = __name__ == "__main__"

In [2]:
model = load_llama_model()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model meta-llama/Meta-Llama-3-8B-Instruct into HookedTransformer


In [3]:
from optim_hunter.model_utils import get_numerical_tokens

numerical_tokens = get_numerical_tokens(model)

print("Found numerical tokens:", numerical_tokens)
print("Number of numerical tokens:", len(numerical_tokens))


# Check if digit 0 is in numerical tokens
has_zero = '0' in numerical_tokens
print("Zero token present:", has_zero)
if has_zero:
    print("Token ID for zero:", numerical_tokens['29'])


Found numerical tokens: {'399': 18572, '276': 16660, '192': 5926, '12': 717, '033': 13103, '83': 6069, '880': 19272, '861': 24963, '550': 13506, '736': 23969, '714': 23193, '87': 4044, '070': 17819, '698': 25169, '304': 12166, '010': 7755, '309': 15500, '808': 11770, '375': 12935, '271': 15828, '607': 21996, '947': 26511, '19': 777, '263': 15666, '856': 25505, '996': 23031, '32': 843, '848': 24951, '647': 22644, '94': 6281, '431': 19852, '072': 23439, '275': 14417, '214': 11584, '22': 1313, '979': 25476, '170': 8258, '114': 8011, '695': 24394, '138': 10350, '26': 1627, '364': 15951, '155': 9992, '656': 20744, '500': 2636, '664': 23888, '852': 24571, '458': 21209, '437': 18318, '452': 21098, '322': 15805, '030': 14649, '860': 18670, '287': 17897, '357': 18520, '831': 25009, '238': 13895, '969': 24792, '909': 21278, '887': 26058, '606': 20213, '682': 25178, '207': 12060, '844': 24344, '109': 7743, '735': 24939, '218': 13302, '893': 26088, '174': 11771, '733': 24865, '491': 21824, '391': 

In [4]:
if t.cuda.is_available():
    t.cuda.empty_cache()
# Example usage
(linreg_tokens, linreg_logits, linreg_caches, linreg_data_store) = run_and_cache_model_linreg_tokens_batched(
    model,
    seq_len=25,
    total_batch=25
)

model.clear_contexts()

# Move all tokens and logits to CPU
linreg_tokens = [tokens.to('cpu') for tokens in linreg_tokens]
linreg_logits = [logits.to('cpu') for logits in linreg_logits]

# Verify all datasets have the same comparison names
base_comparison_names = linreg_data_store[0]["comparison_names"]
all_match = all(dataset["comparison_names"] == base_comparison_names for dataset in linreg_data_store[1:])
assert all_match, "Mismatch in comparison names across datasets."

# Extract comparison names from the first dataset
token_pairs_names = base_comparison_names.copy()

# Extract token pairs across all datasets for each comparison
token_pairs = [
    t.stack([dataset["token_pairs"][i] for dataset in linreg_data_store])[0]
    for i in range(len(token_pairs_names))
]

logger.info(f"Number of comparisons: {len(token_pairs_names)}")
logger.info(f"Number of token_pairs: {len(token_pairs)}")

# Iterate over token pairs and generate plots
for i, token_pair in enumerate(token_pairs):
    logger.info(f"Processing comparison {i}: {token_pairs_names[i]}")
    token_pair = token_pair.to('cpu')

    print(f"Token pair shape: {token_pair.shape}")

    def logits_to_ave_logit_diff(
        logits_list: List[Float[Tensor, "batch seq d_vocab"]],
        answer_tokens: Float[Tensor, "batch 2"] = token_pair,
        per_prompt: bool = False
    ) -> Float[Tensor, "*batch"]:
        # Process each batch separately
        all_logit_diffs = []
        
        for logits in logits_list:
            final_logits = logits[:, -1, :]  # Take final position from each batch
            
            correct = answer_tokens[:, 0]
            incorrect = answer_tokens[:, 1]

            correct_logits = final_logits[t.arange(final_logits.size(0)), correct]
            incorrect_logits = final_logits[t.arange(final_logits.size(0)), incorrect]

            logit_diff = correct_logits - incorrect_logits
            all_logit_diffs.append(logit_diff)
        
        # Combine results
        combined_logit_diffs = t.cat(all_logit_diffs)
        
        if per_prompt:
            return combined_logit_diffs
        else:
            return combined_logit_diffs.mean()

    original_per_prompt_diff = logits_to_ave_logit_diff(linreg_logits, token_pair, per_prompt=True)
    original_average_logit_diff = logits_to_ave_logit_diff(linreg_logits, token_pair)

    # Initialize lists to store results from all caches
    all_logit_lens_diffs = []
    all_per_layer_diffs = []
    all_per_head_diffs = []

    # Process each cache
    for cache_idx, current_cache in enumerate(linreg_caches):
        logger.info(f"Processing cache {cache_idx}")
        
        # Move cache to GPU
        current_cache = current_cache

        # Retrieve final residual stream
        final_residual_stream = current_cache["resid_post", -1]
        final_token_residual_stream = final_residual_stream[:, -1, :]

        # Compute residual directions
        pair_residual_directions = model.tokens_to_residual_directions(token_pair.to('cpu'))
        correct_residual_directions, incorrect_residual_directions = pair_residual_directions.unbind(dim=1)
        logit_diff_directions = correct_residual_directions - incorrect_residual_directions

        def residual_stack_to_logit_diff(
            residual_stack: Float[Tensor, "... batch d_model"],
            cache: ActivationCache,
            logit_diff_directions: Float[Tensor, "batch d_model"] = logit_diff_directions,
        ) -> Float[Tensor, "..."]:
            scaled_residual_stream = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
            logit_diff_directions = logit_diff_directions.to(dtype=scaled_residual_stream.dtype).to('cpu')

            batch_size = residual_stack.size(-2)
            avg_logit_diff = einops.einsum(
                scaled_residual_stream,
                logit_diff_directions,
                "... batch d_model, batch d_model -> ..."
            ) / batch_size
            return avg_logit_diff

        # Accumulate residuals
        accumulated_residual, labels = current_cache.accumulated_resid(
            layer=-1, incl_mid=True, pos_slice=-1, return_labels=True
        )
        logit_lens_logit_diffs = residual_stack_to_logit_diff(accumulated_residual, current_cache).half()
        all_logit_lens_diffs.append(logit_lens_logit_diffs)

        # Per layer analysis
        per_layer_residual, _ = current_cache.decompose_resid(layer=-1, pos_slice=-1, return_labels=True)
        per_layer_logit_diffs = residual_stack_to_logit_diff(per_layer_residual, current_cache)
        all_per_layer_diffs.append(per_layer_logit_diffs)


        # model = model.to('cpu')
        # if t.cuda.is_available():
        #     t.cuda.empty_cache()

        # # Per head analysis
        # model = model.to("cpu")
        # current_cache = current_cache.to("cuda:1")
        # per_head_residual, _ = current_cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)
        # per_head_residual = per_head_residual.to("cuda")
        # per_head_residual = einops.rearrange(
        #     per_head_residual,
        #     "(layer head) ... -> layer head ...",
        #     layer=model.cfg.n_layers
        # )
        # per_head_logit_diffs = residual_stack_to_logit_diff(per_head_residual, current_cache)
        # all_per_head_diffs.append(per_head_logit_diffs)

        # Clear GPU memory
        current_cache = current_cache.to('cpu')
        if t.cuda.is_available():
            t.cuda.empty_cache()

    # Average results across all caches
    avg_logit_lens_diffs = t.stack(all_logit_lens_diffs).mean(dim=0)
    avg_per_layer_diffs = t.stack(all_per_layer_diffs).mean(dim=0)
    # avg_per_head_diffs = t.stack(all_per_head_diffs).mean(dim=0)

    # Generate plots with averaged results
    line(
        avg_logit_lens_diffs,
        hovermode="x unified",
        title=f"Average Logit Difference From Accumulated Residual Stream for {token_pairs_names[i]}",
        labels={"x": "Layer", "y": "Logit Diff"},
        xaxis_tickvals=labels,
        width=800
    )

    line(
        avg_per_layer_diffs,
        hovermode="x unified",
        title=f"Average Per Layer Logit Difference for {token_pairs_names[i]}",
        labels={"x": "Layer", "y": "Logit Diff"},
        xaxis_tickvals=labels,
        width=800
    )

    # imshow(
    #     avg_per_head_diffs,
    #     labels={"x": "Head", "y": "Layer"},
    #     title=f"Average Logit Difference From Each Head for {token_pairs_names[i]}",
    #     width=600
    # )

: 

In [None]:
from optim_hunter.model_utils import get_numerical_tokens

numerical_tokens = get_numerical_tokens(model)

model.clear_contexts()

# Extract predictions across all datasets and convert to tensors
predictions_store = {
    predictor: t.tensor([
        dataset["predictions"][predictor] 
        for dataset in linreg_data_store
    ], dtype=t.float32)
    for predictor in linreg_data_store[0]["predictions"].keys()
}

# Print predictions
print("\nPredictions:")
print("-" * 50)
for predictor_name, predictor_tensor in predictions_store.items():
    print(f"\n{predictor_name}:")
    for i, pred in enumerate(predictor_tensor):
        print(f"  Sample {i}: {pred:.2f}")

# Process each predictor
for predictor_name, predictor_tensor in predictions_store.items():
    print(f"\nPredictor: {predictor_name}")
    print(f"Shape: {predictor_tensor.shape}")
    print(f"Values: {predictor_tensor}")

    numerical_token_ids = t.tensor(list(numerical_tokens.values()), device="cpu").expand(predictor_tensor.shape[0], -1)
    print(f"Numerical token IDs shape: {numerical_token_ids.shape}")


    def logits_to_numeric_mse(
        logits_list: List[Float[Tensor, "batch seq d_vocab"]],
        predicted_values: Float[Tensor, "batch"],
        numerical_tokens: dict,
        per_prompt: bool = False
    ) -> Float[Tensor, "*batch"]:
        '''
        Handles logits with potentially different sequence lengths
        '''
        # Get final token logits from each batch
        final_logits_list = [logits[:, -1, :] for logits in logits_list]
        # Combine final logits
        combined_final_logits = t.cat(final_logits_list, dim=0)
        
        numeric_ids = []
        numeric_values = []
        for digit, token_id in numerical_tokens.items():
            try:
                value = float(digit)
                numeric_ids.append(token_id)
                numeric_values.append(value)
            except ValueError:
                continue
                
        numeric_ids = t.tensor(numeric_ids, device=combined_final_logits.device)
        numeric_values = t.tensor(numeric_values, device=combined_final_logits.device)
        
        numeric_logits = combined_final_logits[:, numeric_ids]
        numeric_probs = t.softmax(numeric_logits, dim=-1)
        
        expected_values = (numeric_probs * numeric_values.unsqueeze(0)).sum(dim=-1)
        
        mse = (expected_values - predicted_values) ** 2
        
        if per_prompt:
            return mse
        else:
            return mse.mean()


    original_per_prompt_mse = logits_to_numeric_mse(linreg_logits, predictor_tensor, numerical_tokens, per_prompt=True)
    original_average_mse = logits_to_numeric_mse(linreg_logits, predictor_tensor, numerical_tokens)

    # Initialize lists to store results from all caches
    all_layerwise_mse = []
    all_expected_values = []

    # Process each cache
    for cache_idx, current_cache in enumerate(linreg_caches):
        logger.info(f"Processing cache {cache_idx}")
        
        # Move cache to GPU
        # current_cache = current_cache.to('cuda')

        # Compute residual directions for numeric MSE
        numeric_residual_directions = model.tokens_to_residual_directions(numerical_token_ids).to("cpu")

        def get_numeric_logits_per_layer(
            residual_stack: Float[Tensor, "layers batch d_model"],
            cache: ActivationCache,
            numeric_residual_directions: Float[Tensor, "batch num_tokens d_model"]
        ) -> Float[Tensor, "layers batch num_tokens"]:
            scaled_residual_stream = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
            scaled_residual_stream = scaled_residual_stream.to("cpu")
            numeric_residual_directions = numeric_residual_directions.to(dtype=scaled_residual_stream.dtype)

            numeric_logits = einops.einsum(
                scaled_residual_stream,
                numeric_residual_directions,
                "layer b d, b t d -> layer b t"
            )
            return numeric_logits

        # Get accumulated residuals
        accumulated_residual, labels = current_cache.accumulated_resid(
            layer=-1, 
            incl_mid=True,
            pos_slice=-1,
            return_labels=True
        )

        numeric_logits_per_layer = get_numeric_logits_per_layer(
            accumulated_residual,
            current_cache,
            numeric_residual_directions
        )

        numeric_probs_per_layer = t.softmax(numeric_logits_per_layer, dim=-1)

        # Build numeric values tensor
        numeric_values_list = []
        for digit in numerical_tokens.keys():
            try:
                numeric_values_list.append(float(digit))
            except ValueError:
                pass
        numeric_values_tensor = t.tensor(numeric_values_list, device=numeric_probs_per_layer.device)
        numeric_values_reshaped = numeric_values_tensor.view(1, 1, -1)

        # Calculate expected values
        expected_values_per_layer = (numeric_probs_per_layer * numeric_values_reshaped).sum(dim=-1)
        all_expected_values.append(expected_values_per_layer)

        # Calculate MSE
        gold_values = predictor_tensor
        gold_values_expanded = gold_values.unsqueeze(0)
        layerwise_mse = (expected_values_per_layer - gold_values_expanded) ** 2
        layerwise_mse_mean = layerwise_mse.mean(dim=-1)
        all_layerwise_mse.append(layerwise_mse_mean)

        # Clear GPU memory
        # current_cache = current_cache.to('cpu')
        if t.cuda.is_available():
            t.cuda.empty_cache()

    # Average results across all caches
    avg_layerwise_mse = t.stack(all_layerwise_mse).mean(dim=0)
    avg_expected_values = t.stack(all_expected_values).mean(dim=0)

    # Generate plots with averaged results
    line(
        avg_layerwise_mse,
        hovermode="x unified",
        title=f"Average MSE vs. Gold Across Layers for {predictor_name}",
        labels={"x": "Layer", "y": "MSE"},
        xaxis_tickvals=labels,
        width=800
    )

    line(
        avg_expected_values.mean(dim=-1),  # Average across batch dimension
        hovermode="x unified",
        title=f"Average Expected Numeric Prediction Across Layers for {predictor_name}",
        labels={"x": "Layer", "y": "Predicted Value"},
        xaxis_tickvals=labels,
        width=800
    )

In [None]:
from optim_hunter.model_utils import get_numerical_tokens
import plotly.express as px
import pandas as pd

if t.cuda.is_available():
    t.cuda.empty_cache()
# Example usage
(linreg_tokens, linreg_logits, linreg_caches, linreg_data_store) = run_and_cache_model_linreg_tokens_batched(
    model,
    seq_len=25,
    total_batch=1
)


numerical_tokens = get_numerical_tokens(model)

model.clear_contexts()

# Extract predictions across all datasets and convert to tensors
predictions_store = {
    predictor: t.tensor([
        dataset["predictions"][predictor] 
        for dataset in linreg_data_store
    ], dtype=t.float32)
    for predictor in linreg_data_store[0]["predictions"].keys()
}

numeric_ids = []
numeric_values = []
numeric_labels = []
for digit, token_id in numerical_tokens.items():
    try:
        value = float(digit)
        numeric_ids.append(token_id)
        numeric_values.append(value)
        numeric_labels.append(str(value))
    except ValueError:
        continue

# Print predictions
print("\nPredictions:")
print("-" * 50)
for predictor_name, predictor_tensor in predictions_store.items():
    print(f"\n{predictor_name}:")
    for i, pred in enumerate(predictor_tensor):
        print(f"  Sample {i}: {pred:.2f}")

# Process each predictor
for predictor_name, predictor_tensor in predictions_store.items():
    print(f"\nPredictor: {predictor_name}")
    print(f"Shape: {predictor_tensor.shape}")
    print(f"Values: {predictor_tensor}")

    numerical_token_ids = t.tensor(list(numerical_tokens.values()), device="cpu").expand(predictor_tensor.shape[0], -1)
    print(f"Numerical token IDs shape: {numerical_token_ids.shape}")


    def logits_to_numeric_mse(
        logits_list: List[Float[Tensor, "batch seq d_vocab"]],
        predicted_values: Float[Tensor, "batch"],
        numerical_tokens: dict,
        per_prompt: bool = False
    ) -> Float[Tensor, "*batch"]:
        '''
        Handles logits with potentially different sequence lengths
        '''
        # Get final token logits from each batch
        final_logits_list = [logits[:, -1, :] for logits in logits_list]
        # Combine final logits
        combined_final_logits = t.cat(final_logits_list, dim=0)
        
        numeric_ids = []
        numeric_values = []
        numeric_labels = []
        for digit, token_id in numerical_tokens.items():
            try:
                value = float(digit)
                numeric_ids.append(token_id)
                numeric_values.append(value)
                numeric_labels.append(str(value))
            except ValueError:
                continue
                
        numeric_ids = t.tensor(numeric_ids, device=combined_final_logits.device)
        numeric_values = t.tensor(numeric_values, device=combined_final_logits.device)
        
        numeric_logits = combined_final_logits[:, numeric_ids]
        numeric_probs = t.softmax(numeric_logits, dim=-1)
        
        expected_values = (numeric_probs * numeric_values.unsqueeze(0)).sum(dim=-1)
        
        mse = (expected_values - predicted_values) ** 2
        
        if per_prompt:
            return mse
        else:
            return mse.mean()


    original_per_prompt_mse = logits_to_numeric_mse(linreg_logits, predictor_tensor, numerical_tokens, per_prompt=True)
    original_average_mse = logits_to_numeric_mse(linreg_logits, predictor_tensor, numerical_tokens)

    # Initialize lists to store results from all caches
    all_layerwise_mse = []
    all_expected_values = []
    all_probs_per_layer = []  # New list to store probabilities

    # Process each cache
    for cache_idx, current_cache in enumerate(linreg_caches):
        logger.info(f"Processing cache {cache_idx}")
        
        # Move cache to GPU
        # current_cache = current_cache.to('cuda')

        # Compute residual directions for numeric MSE
        numeric_residual_directions = model.tokens_to_residual_directions(numerical_token_ids).to("cpu")

        # Get accumulated residuals
        accumulated_residual, labels = current_cache.accumulated_resid(
            layer=-1, 
            incl_mid=True,
            pos_slice=-1,
            return_labels=True
        )

        numeric_logits_per_layer = get_numeric_logits_per_layer(
            accumulated_residual,
            current_cache,
            numeric_residual_directions
        )

        numeric_probs_per_layer = t.softmax(numeric_logits_per_layer, dim=-1)
        all_probs_per_layer.append(numeric_probs_per_layer)

        def get_numeric_logits_per_layer(
            residual_stack: Float[Tensor, "layers batch d_model"],
            cache: ActivationCache,
            numeric_residual_directions: Float[Tensor, "batch num_tokens d_model"]
        ) -> Float[Tensor, "layers batch num_tokens"]:
            scaled_residual_stream = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
            scaled_residual_stream = scaled_residual_stream.to("cpu")
            numeric_residual_directions = numeric_residual_directions.to(dtype=scaled_residual_stream.dtype)

            numeric_logits = einops.einsum(
                scaled_residual_stream,
                numeric_residual_directions,
                "layer b d, b t d -> layer b t"
            )
            return numeric_logits

        # Get accumulated residuals
        accumulated_residual, labels = current_cache.accumulated_resid(
            layer=-1, 
            incl_mid=True,
            pos_slice=-1,
            return_labels=True
        )

        numeric_logits_per_layer = get_numeric_logits_per_layer(
            accumulated_residual,
            current_cache,
            numeric_residual_directions
        )

        numeric_probs_per_layer = t.softmax(numeric_logits_per_layer, dim=-1)

        # Build numeric values tensor
        numeric_values_list = []
        for digit in numerical_tokens.keys():
            try:
                numeric_values_list.append(float(digit))
            except ValueError:
                pass
        numeric_values_tensor = t.tensor(numeric_values_list, device=numeric_probs_per_layer.device)
        numeric_values_reshaped = numeric_values_tensor.view(1, 1, -1)

        # Calculate expected values
        expected_values_per_layer = (numeric_probs_per_layer * numeric_values_reshaped).sum(dim=-1)
        all_expected_values.append(expected_values_per_layer)

        # Calculate MSE
        gold_values = predictor_tensor
        gold_values_expanded = gold_values.unsqueeze(0)
        layerwise_mse = (expected_values_per_layer - gold_values_expanded) ** 2
        layerwise_mse_mean = layerwise_mse.mean(dim=-1)
        all_layerwise_mse.append(layerwise_mse_mean)

        # Clear GPU memory
        # current_cache = current_cache.to('cpu')
        if t.cuda.is_available():
            t.cuda.empty_cache()

    # Average results across all caches
    avg_layerwise_mse = t.stack(all_layerwise_mse).mean(dim=0)
    avg_expected_values = t.stack(all_expected_values).mean(dim=0)
    avg_probs_per_layer = t.stack(all_probs_per_layer).mean(dim=0)

# Create probability distribution heatmap
    avg_probs = avg_probs_per_layer.mean(dim=1)  # Average across batches
    df = pd.DataFrame(
        avg_probs.cpu().numpy(),
        columns=numeric_labels,
        index=labels
    )

    # Heatmap visualization
    fig = px.imshow(
        df,
        title=f"Token Probability Distribution Across Layers for {predictor_name}",
        labels=dict(x="Numeric Token", y="Layer", color="Probability"),
        aspect="auto",
        color_continuous_scale="viridis"
    )

    # Update layout
    fig.update_layout(
        width=1000,
        height=600,
        xaxis_tickangle=-45,
    )

    # Add target value marker
    target_value = predictor_tensor.mean().item()
    closest_token_idx = min(range(len(numeric_values_list)), 
                          key=lambda i: abs(numeric_values_list[i] - target_value))
    
    fig.add_annotation(
        x=closest_token_idx,
        y=-0.5,
        text=f"Target ≈ {target_value:.2f}",
        showarrow=True,
        arrowhead=1,
        yanchor="bottom"
    )

    fig.show()

# Calculate and plot entropy
    entropy = -(avg_probs_per_layer * t.log(avg_probs_per_layer + 1e-10)).sum(dim=-1)
    avg_entropy = entropy.mean(dim=-1)

    line(
        avg_entropy,
        hovermode="x unified",
        title=f"Distribution Entropy Across Layers for {predictor_name}",
        labels={"x": "Layer", "y": "Entropy"},
        xaxis_tickvals=labels,
        width=800
    )

    # Print top-k predictions per layer
    k = 5
    for layer_idx, layer_name in enumerate(labels):
        probs = avg_probs[layer_idx]
        values = t.tensor(numeric_values_list)
        
        top_probs, top_indices = t.topk(probs, k)
        
        print(f"\nLayer {layer_name} top {k} predictions:")
        for prob, idx in zip(top_probs, top_indices):
            print(f"Value: {numeric_values_list[idx]:.2f}, Probability: {prob:.3f}")

    # Original MSE and expected value plots
    line(
        avg_layerwise_mse,
        hovermode="x unified",
        title=f"Average MSE vs. Gold Across Layers for {predictor_name}",
        labels={"x": "Layer", "y": "MSE"},
        xaxis_tickvals=labels,
        width=800
    )

    line(
        avg_expected_values.mean(dim=-1),
        hovermode="x unified",
        title=f"Average Expected Numeric Prediction Across Layers for {predictor_name}",
        labels={"x": "Layer", "y": "Predicted Value"},
        xaxis_tickvals=labels,
        width=800
    )

In [None]:
# Move the probability distribution visualization outside the predictor loop
# First collect all target values
target_values = {}
for predictor_name, predictor_tensor in predictions_store.items():
    target_values[predictor_name] = predictor_tensor.mean().item()

# Sort numeric values and labels
sorted_indices = np.argsort(numeric_values_list)
sorted_numeric_values = [numeric_values_list[i] for i in sorted_indices]
sorted_numeric_labels = [numeric_labels[i] for i in sorted_indices]

# After the predictor loop ends, create the consolidated visualization
avg_probs = avg_probs_per_layer.mean(dim=1)  # Average across batches
# Reorder the columns according to sorted indices
avg_probs = avg_probs[:, sorted_indices]

df = pd.DataFrame(
    avg_probs.cpu().numpy(),
    columns=sorted_numeric_labels,
    index=labels
)

# Create heatmap
fig = px.imshow(
    df,
    title="Token Probability Distribution Across Layers",
    labels=dict(x="Numeric Token", y="Layer", color="Probability"),
    aspect="auto",
    color_continuous_scale="viridis"
)

# Update layout
fig.update_layout(
    width=1000,
    height=600,
    xaxis_tickangle=-45,
)

# Add target value markers for all predictors
colors = px.colors.qualitative.Set1  # Different colors for different predictors
for i, (predictor_name, target_value) in enumerate(target_values.items()):
    closest_token_idx = min(range(len(sorted_numeric_values)), 
                          key=lambda i: abs(sorted_numeric_values[i] - target_value))
    
    fig.add_annotation(
        x=closest_token_idx,
        y=-0.5 - (i * 0.5),  # Stack annotations vertically
        text=f"{predictor_name} ≈ {target_value:.2f}",
        showarrow=True,
        arrowhead=1,
        yanchor="bottom",
        arrowcolor=colors[i % len(colors)],
        font=dict(color=colors[i % len(colors)])
    )

fig.show()

In [None]:
# Move the probability distribution visualization outside the predictor loop
# First collect all target values
target_values = {}
for predictor_name, predictor_tensor in predictions_store.items():
    target_values[predictor_name] = predictor_tensor.mean().item()

# After the predictor loop ends, create the consolidated visualization
avg_probs = avg_probs_per_layer.mean(dim=1)  # Average across batches
# Select only layers 0 to 15
avg_probs = avg_probs[:16]  # Since indexing is 0-based
labels = labels[:16]  # Adjust labels accordingly

df = pd.DataFrame(
    avg_probs.cpu().numpy(),
    columns=numeric_labels,
    index=labels
)

# Create heatmap
fig = px.imshow(
    df,
    title="Token Probability Distribution Across Layers",
    labels=dict(x="Numeric Token", y="Layer", color="Probability"),
    aspect="auto",
    color_continuous_scale="viridis"
)

# Update layout
fig.update_layout(
    width=1000,
    height=600,
    xaxis_tickangle=-45,
)

# Add target value markers for all predictors
colors = px.colors.qualitative.Set1  # Different colors for different predictors
for i, (predictor_name, target_value) in enumerate(target_values.items()):
    closest_token_idx = min(range(len(numeric_values_list)), 
                          key=lambda i: abs(numeric_values_list[i] - target_value))
    
    fig.add_annotation(
        x=closest_token_idx,
        y=-0.5 - (i * 0.5),  # Stack annotations vertically
        text=f"{predictor_name} ≈ {target_value:.2f}",
        showarrow=True,
        arrowhead=1,
        yanchor="bottom",
        arrowcolor=colors[i % len(colors)],
        font=dict(color=colors[i % len(colors)])
    )

fig.show()

# Move the probability distribution visualization outside the predictor loop
# First collect all target values
target_values = {}
for predictor_name, predictor_tensor in predictions_store.items():
    target_values[predictor_name] = predictor_tensor.mean().item()

# After the predictor loop ends, create the consolidated visualization
avg_probs = avg_probs_per_layer.mean(dim=1)  # Average across batches
# Select only layers 0 to 15
avg_probs = avg_probs[16:]  # Since indexing is 0-based
labels = labels[16:]  # Adjust labels accordingly

df = pd.DataFrame(
    avg_probs.cpu().numpy(),
    columns=numeric_labels,
    index=labels
)

# Create heatmap
fig = px.imshow(
    df,
    title="Token Probability Distribution Across Layers",
    labels=dict(x="Numeric Token", y="Layer", color="Probability"),
    aspect="auto",
    color_continuous_scale="viridis"
)

# Update layout
fig.update_layout(
    width=1000,
    height=600,
    xaxis_tickangle=-45,
)

# Add target value markers for all predictors
colors = px.colors.qualitative.Set1  # Different colors for different predictors
for i, (predictor_name, target_value) in enumerate(target_values.items()):
    closest_token_idx = min(range(len(numeric_values_list)), 
                          key=lambda i: abs(numeric_values_list[i] - target_value))
    
    fig.add_annotation(
        x=closest_token_idx,
        y=-0.5 - (i * 0.5),  # Stack annotations vertically
        text=f"{predictor_name} ≈ {target_value:.2f}",
        showarrow=True,
        arrowhead=1,
        yanchor="bottom",
        arrowcolor=colors[i % len(colors)],
        font=dict(color=colors[i % len(colors)])
    )

fig.show()

In [None]:
# Move the probability distribution visualization outside the predictor loop
# First collect all target values
target_values = {}
for predictor_name, predictor_tensor in predictions_store.items():
    target_values[predictor_name] = predictor_tensor.mean().item()

# After the predictor loop ends, create the consolidated visualization
avg_probs = avg_probs_per_layer.mean(dim=1)  # Average across batches
# Select only layers 0 to 15
avg_probs = avg_probs[:]  # Since indexing is 0-based
labels = labels[:]  # Adjust labels accordingly

# Sort numeric values and create sorted indices
numeric_values_array = np.array([float(label) for label in numeric_labels])
sorted_indices = np.argsort(numeric_values_array)
numeric_labels_sorted = [numeric_labels[i] for i in sorted_indices]

# Reorder the probabilities according to sorted tokens
avg_probs_sorted = avg_probs[:, sorted_indices]

df = pd.DataFrame(
    avg_probs_sorted.cpu().numpy(),
    columns=numeric_labels_sorted,
    index=labels
)

# Create heatmap
fig = px.imshow(
    df,
    title="Token Probability Distribution Across Layers",
    labels=dict(x="Numeric Token", y="Layer", color="Probability"),
    aspect="auto",
    color_continuous_scale="viridis"
)

# Update layout
fig.update_layout(
    width=1000,
    height=600,
    xaxis_tickangle=-45,
)

# Add target value markers for all predictors
colors = px.colors.qualitative.Set1  # Different colors for different predictors
for i, (predictor_name, target_value) in enumerate(target_values.items()):
    closest_token_idx = min(range(len(numeric_values_array)), 
                          key=lambda i: abs(float(numeric_labels_sorted[i]) - target_value))
    
    fig.add_annotation(
        x=closest_token_idx,
        y=-0.5 - (i * 0.5),  # Stack annotations vertically
        text=f"{predictor_name} ≈ {target_value:.2f}",
        showarrow=True,
        arrowhead=1,
        yanchor="bottom",
        arrowcolor=colors[i % len(colors)],
        font=dict(color=colors[i % len(colors)])
    )

fig.show()

In [39]:
from transformer_lens import HookedTransformer
import torch
from einops import einsum
import matplotlib.pyplot as plt
import seaborn as sns

# Load model
# model = HookedTransformer.from_pretrained("gpt2-small")

# Create input with numerical tokens
prompt = "The number is 42. The next number is"
logits, cache = model.run_with_cache(prompt)

# Get accumulated residual streams for each layer
accum_resid, labels = cache.accumulated_resid(return_labels=True, apply_ln=True)

# Get the last token position
last_pos = -1
last_token_accum = accum_resid[:, 0, last_pos, :]  # [layer, d_model]

# Project into logit space
W_U = model.W_U
logit_by_layer = einsum(
    last_token_accum,
    W_U,
    "layer d_model, d_model d_vocab -> layer d_vocab"
)

# Get probabilities using softmax
probs_by_layer = torch.softmax(logit_by_layer, dim=-1)

# Get indices of numerical tokens (0-9)
number_tokens = [model.to_single_token(str(i)) for i in range(10)]

# Extract probabilities for numerical tokens
number_probs = probs_by_layer[:, number_tokens]

# Plot heatmap
plt.figure(figsize=(12, 8))
sns.heatmap(
    number_probs.cpu().numpy(),
    xticklabels=range(10),
    yticklabels=labels,
    cmap='viridis'
)
plt.xlabel('Digit')
plt.ylabel('Layer')
plt.title('Probability Distribution over Numerical Tokens by Layer')
plt.show()

# Print top predictions per layer
for layer, label in enumerate(labels):
    top_numbers = torch.topk(number_probs[layer], k=3)
    print(f"\nLayer {label}:")
    for prob, idx in zip(top_numbers.values, top_numbers.indices):
        print(f"Number {idx.item()}: {prob.item():.3f}")

ModuleNotFoundError: No module named 'matplotlib'