In [1]:
import sys
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from jaxtyping import Int, Float
import functools
from tqdm import tqdm
from IPython.display import display
from transformer_lens.hook_points import HookPoint
from transformer_lens import (
    utils,
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)
import circuitsvis as cv

from optim_hunter.plotly_utils import imshow, hist, plot_comp_scores, plot_logit_attribution, plot_loss_difference, line
from optim_hunter.utils import prepare_prompt, slice_dataset
from optim_hunter.sklearn_regressors import linear_regression, knn_regression, random_forest, baseline_average, baseline_last, baseline_random
from optim_hunter.datasets import get_dataset_friedman_2
import logging

# Configure logging (WARNING level, so the logger.info/debug calls below stay quiet by default)
logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

# Saves computation time, since we don't need it for the contents of this notebook
t.set_grad_enabled(False)

# Default device for tensors built in this notebook (first GPU if available, else CPU).
device = t.device("cuda:0" if t.cuda.is_available() else "cpu")

MAIN = __name__ == "__main__"

In [None]:
# Load a Hugging Face checkpoint from a local directory and wrap it in a HookedTransformer.
# Load directly from model path https://github.com/TransformerLensOrg/TransformerLens/issues/691
#
# NOTE(review): MODEL_TYPE names Llama 3 while MODEL_PATH points at Llama 3.1 weights.
# Presumably MODEL_TYPE only picks the TransformerLens config and the weights come from
# `hf_model` — confirm the two architectures/configs actually match.
# NOTE(review): MODEL_PATH is an absolute, user-specific path; consider making it configurable.
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer

MODEL_TYPE = "meta-llama/Meta-Llama-3-8B-Instruct"
MODEL_PATH = "/home/freiza/optim_hunter/.models/Llama-3.1-8B-Instruct/"

# NOTE(review): if MODEL_PATH is falsy nothing below runs and `model` is never defined,
# so every later cell that uses `model` fails.
if MODEL_PATH:
    # Tokenizer and HF weights both come from the local checkpoint directory.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    hf_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True,
                                                     # Alternative loading options, currently disabled:
                                                     #quantization_config=BitsAndBytesConfig(load_in_4bit=True),
                                                     #torch_dtype = t.float32,
                                                     #device_map = "cuda:0"
                                                     )

    # Pad on the left, reusing EOS as the pad token.
    tokenizer.padding_side = 'left'
    tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): the device is hard-coded to "cuda" split over 2 devices; the `device`
    # variable defined in the setup cell is not used here.
    model = HookedTransformer.from_pretrained(
        MODEL_TYPE,
        hf_model=hf_model,
        device="cuda",
        n_devices=2,
        fold_ln=True,
        # fold_value_biases=False,
        center_writing_weights=True,
        # refactor_factored_attn_matrices=True,
        center_unembed=True,
        # dtype=t.bfloat16,
        dtype=t.float16,
        default_padding_side='left',
        tokenizer=tokenizer
    )

    # Optional device move / generation smoke test, currently disabled:
    #model = model.to("cuda" if t.cuda.is_available() else "cpu")
    #model.generate("The capital of Germany is", max_new_tokens=20, temperature=0)

In [None]:
# Sanity check on a single Friedman #2 split (random_state=11): show what each sklearn
# baseline predicts for the first test row next to the gold target, then let
# utils.test_prompt report how the model ranks the gold answer after the ICL prompt.
seq_len = 5
x_train, y_train, x_test, y_test = get_dataset_friedman_2(random_state=11)
# x_train, y_train, x_test, y_test =  slice_dataset(x_train, y_train, x_test, y_test, seq_len)

# Each helper is called as f(x_train, x_test, y_train, y_test) and its 'y_predict'
# entry is kept; the call order below is significant (baseline_random runs last).
(
    linear_regression_prediction,
    knn_regression_prediction,
    random_forest_prediction,
    baseline_avg_prediction,
    baseline_last_prediction,
    baseline_random_prediction,
) = (
    fit(x_train, x_test, y_train, y_test)['y_predict']
    for fit in (
        linear_regression,
        knn_regression,
        random_forest,
        baseline_average,
        baseline_last,
        baseline_random,
    )
)
gold = y_test.values[0]

print("Linear Regression Prediction:", linear_regression_prediction[0])
print("Random Forest Prediction:", random_forest_prediction[0])
print("KNN Regression Prediction:", knn_regression_prediction[0])
print("Baseline Average Prediction:", baseline_avg_prediction[0])
print("Baseline Last Prediction:", baseline_last_prediction[0])
print("Baseline Random Prediction:", baseline_random_prediction[0])
print("Gold:", gold)

# The trailing "" is a placeholder for appending an extra suffix to the prompt.
prompt = prepare_prompt(x_train, y_train, x_test) + ""

# Candidate answers to probe; the gold value is the one currently in use.
# example_answer = f"{baseline_last_prediction[0]}"
# example_answer = f"{random_forest_prediction[0]}"
example_answer = f"{gold}"
utils.test_prompt(prompt, example_answer, model, prepend_bos=True)

In [4]:
def create_comparison_data(model, dataset_func, regressors, random_state=1, seq_len=None):
    """
    Creates a structured comparison dataset for analyzing different regression models against gold values.

    Args:
        model (HookedTransformer): The transformer model used for tokenization
            (only ``model.to_tokens`` is called).
        dataset_func (callable): Called as ``dataset_func(random_state=random_state)``;
            must return (x_train, y_train, x_test, y_test). The gold value is read as
            ``y_test.values[0]``.
        regressors (list): Regression functions to compare. Each is called exactly once as
            ``regressor(x_train, x_test, y_train, y_test)`` and must return a dict with a
            ``'model_name'`` entry and a ``'y_predict'`` entry (first element is used).
        random_state (int, optional): Random seed for dataset generation. Defaults to 1.
        seq_len (int, optional): When truthy, the data is passed through
            ``slice_dataset(x_train, y_train, x_test, y_test, seq_len)`` before the prompt
            is built. Defaults to None (no slicing).

    Returns:
        dict: A structured dictionary containing:
        {
            'dataset_name': str,  # Name of the dataset function
            'prompt': str,        # Generated prompt text for the model
            'predictions': {      # Dictionary of predictions from each model
                'gold': float,    # True value
                'model_name1': float,  # Prediction from first model
                'model_name2': float,  # Prediction from second model
                ...
            },
            'comparison_names': [  # List of comparison descriptors
                'gold vs model1',
                'gold vs model2',
                'model1 vs model2',
                ...
            ],
            'token_pairs': tensor  # Shape: [num_comparisons, 1, 2]
                                  # Each pair contains the first tokens of two predictions
                                  # being compared
        }

    Raises:
        ValueError: If ``regressors`` is empty (there would be nothing to compare).

    Note:
        - Comparisons are unique combinations (not permutations) of 'gold' followed by
          the regressors in the order given.
        - Only the first token of each prediction (tokenized without BOS) is stored.
        - Token pairs are in the same order as comparison_names.
    """
    import itertools

    if not regressors:
        raise ValueError("create_comparison_data needs at least one regressor to compare against gold")

    # Get dataset
    x_train, y_train, x_test, y_test = dataset_func(random_state=random_state)
    if seq_len:
        x_train, y_train, x_test, y_test = slice_dataset(x_train, y_train, x_test, y_test, seq_len)

    # Get prompt
    prompt = prepare_prompt(x_train, y_train, x_test)

    # Fit every regressor once, recording its prediction and its name (in call order).
    predictions = {'gold': y_test.values[0]}
    all_predictors = ['gold']
    for regressor in regressors:
        result = regressor(x_train, x_test, y_train, y_test)
        predictions[result['model_name']] = result['y_predict'][0]
        all_predictors.append(result['model_name'])

    # Tokenize each prediction once; only its first token is compared.
    first_tokens = {
        name: model.to_tokens(str(predictions[name]), prepend_bos=False)[0, 0]
        for name in all_predictors
    }

    # combinations() yields (0,1), (0,2), ..., (1,2), ... — every unordered pair once.
    comparison_names = []
    token_pairs = []
    for pred1, pred2 in itertools.combinations(all_predictors, 2):
        comparison_names.append(f"{pred1} vs {pred2}")
        # Shape [1, 2]: (first token of pred1, first token of pred2)
        token_pairs.append(t.stack([first_tokens[pred1], first_tokens[pred2]]).unsqueeze(0))

    assert len(comparison_names) == len(token_pairs)

    return {
        'dataset_name': dataset_func.__name__,
        'prompt': prompt,
        'predictions': predictions,
        'comparison_names': comparison_names,
        'token_pairs': t.stack(token_pairs),  # Shape: [num_comparisons, 1, 2]
    }

# # Create the data store
# datasets = [ get_dataset_friedman_2 ]
# regressors = [ linear_regression, knn_regression, random_forest, baseline_average, baseline_last, baseline_random ]
# data_store = {}
# for dataset_func in datasets:
#     data_store[dataset_func.__name__] = create_comparison_data(model, dataset_func, regressors)
# # Print out the token pairs and comparison names
# for dataset_name, dataset_info in data_store.items():
#     print(f"\nDataset: {dataset_name}")
#     print("Token pairs:")
#     print(dataset_info['token_pairs'])
#     print("\nComparison names:")
#     print(dataset_info['comparison_names'])

# # Get the first token pair from the first dataset
# first_dataset_name = next(iter(data_store))
# first_token_pair = data_store[first_dataset_name]['token_pairs'][0]  # Shape: [1, 2]
# print("\nFirst token pair:")
# print(first_token_pair)

In [5]:
# Regression baselines compared against the gold value in every generated prompt.
regressors = [
    linear_regression,
    knn_regression,
    random_forest,
    baseline_average,
    baseline_last,
    baseline_random,
]

def generate_linreg_tokens(
    model: HookedTransformer,
    dataset,
    seq_len = 5,
    batch: int = 1
) -> tuple[Int[Tensor, "batch full_seq_len"], list[dict]]:
    '''
    Generates a batch of linear-regression ICL prompts and tokenizes them.

    Row ``i`` is built by ``create_comparison_data`` with ``random_state=i`` and the
    module-level ``regressors`` list, so every row is a different draw of the dataset.

    Args:
        model: HookedTransformer used for tokenization.
        dataset: Dataset function, called as ``dataset(random_state=...)``. If None,
            ``get_dataset_friedman_2`` is used.
        seq_len: Forwarded to ``create_comparison_data`` to slice the dataset.
        batch: Number of prompts to generate (must be >= 1).

    Outputs are:
        linreg_tokens: [batch, 1+linreg] on ``device``; column 0 is the single BOS token.
        data_store: list with one ``create_comparison_data`` dict per row.

    Rows that tokenize shorter than the longest row are lengthened by inserting '0'
    tokens just before their last 3 tokens (presumably to keep the trailing query
    suffix aligned across rows — TODO confirm this matches prepare_prompt's format).

    Raises:
        ValueError: if ``batch`` < 1.
    '''
    if batch < 1:
        raise ValueError(f"batch must be >= 1, got {batch}")

    dataset_func = get_dataset_friedman_2 if dataset is None else dataset

    # Taking the last token makes this independent of whether BOS is prepended.
    zero_token = model.to_tokens('0', truncate=True)[0][-1]

    batch_tokens = []
    data_store = []
    for i in range(batch):
        data = create_comparison_data(model, dataset_func, regressors, random_state=i, seq_len=seq_len)
        # prepend_bos=False: the single BOS column is added explicitly below. Relying on
        # to_tokens' default (BOS prepended unless cfg.default_prepend_bos is False)
        # would give each row two leading BOS tokens.
        tokens = model.to_tokens(data['prompt'], prepend_bos=False, truncate=True)
        batch_tokens.append(tokens[0])
        data_store.append(data)

    # Equalize lengths by inserting all missing '0' tokens in one go.
    max_len = max(len(row) for row in batch_tokens)
    for i, row in enumerate(batch_tokens):
        n_missing = max_len - len(row)
        if n_missing == 0:
            continue
        print(f"Found mismatch in token length for batch {i}!\nLargest length: {max_len}\nBatch {i} length: {len(row)}\nApplying padding...")
        print(f"\nBefore Zero Token Padding:\n#####\n{model.to_string(row[-50:])}\n#####")
        insert_at = max(len(row) - 3, 0)
        batch_tokens[i] = t.cat([
            row[:insert_at],
            zero_token.repeat(n_missing),  # 0-d token -> 1-d run of n_missing zeros
            row[insert_at:],
        ])
        print(f"\nAfter Zero Token Padding:\n#####\n{model.to_string(batch_tokens[i][-50:])}\n#####")

    # Prepend one BOS column to the stacked [batch, seq] tokens.
    prefix = t.full((batch, 1), model.tokenizer.bos_token_id, dtype=t.long, device=device)
    linreg_tokens = t.cat([prefix, t.stack(batch_tokens).to(device)], dim=-1)
    return linreg_tokens, data_store

def run_and_cache_model_linreg_tokens(model: HookedTransformer, seq_len: int, batch: int = 1) -> tuple[Tensor, Tensor, ActivationCache, list[dict]]:
    '''
    Generates a batch of linear-regression ICL prompts with `generate_linreg_tokens`
    (using the Friedman #2 dataset), runs the model on them, and caches activations.

    Args:
        model: HookedTransformer to tokenize with and run.
        seq_len: Forwarded to `generate_linreg_tokens` (dataset slice length).
        batch: Number of prompts (one random_state per row).

    Outputs are:
        linreg_tokens: [batch, 1+linreg]
        linreg_logits: [batch, 1+linreg, d_vocab]
        linreg_cache: The cache of the model run on linreg_tokens
        linreg_data_store: list with one `create_comparison_data` dict per row
            (prompt, predictions, comparison_names, token_pairs)
    '''
    linreg_tokens, linreg_data_store = generate_linreg_tokens(model, get_dataset_friedman_2, seq_len, batch)
    linreg_logits, linreg_cache = model.run_with_cache(linreg_tokens)
    return linreg_tokens, linreg_logits, linreg_cache, linreg_data_store

In [None]:
# Build a batch of Friedman #2 ICL prompts (row i uses random_state=i), run the
# model once over the whole batch, and keep the activation cache for analysis.
model.clear_contexts()

seq_len, batch = 10, 4
linreg_tokens, linreg_logits, linreg_cache, linreg_data_store = run_and_cache_model_linreg_tokens(
    model,
    seq_len=seq_len,
    batch=batch,
)

In [None]:
model.clear_contexts()
linreg_tokens = linreg_tokens.to('cpu')
linreg_logits = linreg_logits.to('cpu')
# Upcast the cache once, up front, so every LayerNorm-scaled projection below runs in float32.
linreg_cache = linreg_cache.to('cpu').to(t.float32)

# Verify all datasets have the same comparison names
base_comparison_names = linreg_data_store[0]["comparison_names"]
all_match = all(dataset["comparison_names"] == base_comparison_names for dataset in linreg_data_store[1:])
assert all_match, "Mismatch in comparison names across datasets."

# Extract comparison names from the first dataset
token_pairs_names = base_comparison_names.copy()

# For comparison i, gather its token pair from EVERY prompt in the batch. Each
# dataset["token_pairs"][i] has shape [1, 2], so concatenating along dim 0 gives
# [batch, 2]: row b holds (first token of pred1, first token of pred2) for prompt b.
# The prompts come from different random states, so these rows generally differ.
token_pairs = [
    t.cat([dataset["token_pairs"][i] for dataset in linreg_data_store], dim=0).to('cpu')
    for i in range(len(token_pairs_names))
]
n_prompts = linreg_logits.shape[0]
assert all(pair.shape == (n_prompts, 2) for pair in token_pairs), \
    "Each comparison needs one (correct, incorrect) token pair per prompt."

logger.info(f"Number of comparisons: {len(token_pairs_names)}")
logger.info(f"Number of token_pairs: {len(token_pairs)}")

# These residual-stream quantities do not depend on the comparison, so compute them once.
final_residual_stream: Float[Tensor, "batch seq d_model"] = linreg_cache["resid_post", -1]
logger.debug(f"Final residual stream shape: {final_residual_stream.shape}")
final_token_residual_stream: Float[Tensor, "batch d_model"] = final_residual_stream[:, -1, :]
# accumulated_residual has shape (component, batch, d_model)
accumulated_residual, labels = linreg_cache.accumulated_resid(layer=-1, incl_mid=True, pos_slice=-1, return_labels=True)

# Interactive HTML copies of every plot go here; create the directory so write_html cannot fail.
PLOT_DIR = Path("../public")
PLOT_DIR.mkdir(parents=True, exist_ok=True)

# Iterate over token pairs and generate plots
for i, token_pair in enumerate(token_pairs):
    logger.info(f"Processing comparison {i}: {token_pairs_names[i]}")

    def logits_to_ave_logit_diff(
        logits: Float[Tensor, "batch seq d_vocab"],
        answer_tokens: Int[Tensor, "batch 2"] = token_pair,
        per_prompt: bool = False
    ) -> Float[Tensor, "*batch"]:
        '''
        Returns the logit difference, at the final sequence position, between the first
        ("correct") and second ("incorrect") token of each row of answer_tokens.

        If per_prompt=True, return the [batch] array of differences rather than the average.
        '''
        final_logits = logits[:, -1, :]  # Shape: (batch, d_vocab)
        rows = t.arange(final_logits.size(0))
        correct_logits = final_logits[rows, answer_tokens[:, 0]]  # Shape: (batch,)
        incorrect_logits = final_logits[rows, answer_tokens[:, 1]]  # Shape: (batch,)
        logit_diff = correct_logits - incorrect_logits  # Shape: (batch,)
        return logit_diff if per_prompt else logit_diff.mean()

    original_per_prompt_diff = logits_to_ave_logit_diff(linreg_logits, token_pair, per_prompt=True)
    logger.debug(f"Per prompt logit difference for comparison '{token_pairs_names[i]}': {original_per_prompt_diff}")
    original_average_logit_diff = logits_to_ave_logit_diff(linreg_logits, token_pair)
    logger.debug(f"Average logit difference for comparison '{token_pairs_names[i]}': {original_average_logit_diff}")

    # Unembedding directions of each prompt's (correct, incorrect) tokens.
    pair_residual_directions = model.tokens_to_residual_directions(token_pair)  # [batch 2 d_model]
    logger.debug(f"Answer residual directions shape: {pair_residual_directions.shape}")

    correct_residual_directions, incorrect_residual_directions = pair_residual_directions.unbind(dim=1)
    logit_diff_directions = correct_residual_directions - incorrect_residual_directions  # [batch d_model]
    logger.debug(f"Logit difference directions shape: {logit_diff_directions.shape}")

    def residual_stack_to_logit_diff(
        residual_stack: Float[Tensor, "... batch d_model"],
        cache: ActivationCache,
        logit_diff_directions: Float[Tensor, "batch d_model"] = logit_diff_directions,
    ) -> Float[Tensor, "..."]:
        '''
        Gets the avg logit difference between the correct and incorrect answer for a given
        stack of components in the residual stream (each prompt projected onto its own
        logit-diff direction, then averaged over the batch).
        '''
        # Apply LayerNorm scaling (to just the final sequence position)
        scaled_residual_stream = cache.apply_ln_to_stack(residual_stack, layer=-1, pos_slice=-1)
        logit_diff_directions = logit_diff_directions.to(device='cpu', dtype=scaled_residual_stream.dtype)

        logger.debug(f"Scaled residual stream shape: {scaled_residual_stream.shape}")
        logger.debug(f"Logit diff directions shape: {logit_diff_directions.shape}")

        # Projection, averaged over the batch dimension
        batch_size = residual_stack.size(-2)
        avg_logit_diff = einops.einsum(
            scaled_residual_stream,
            logit_diff_directions,
            "... batch d_model, batch d_model -> ..."
        ) / batch_size
        return avg_logit_diff

    # The projected final residual stream must reproduce the directly-measured logit diff.
    t.testing.assert_close(
        residual_stack_to_logit_diff(final_token_residual_stream.to(t.float32), linreg_cache),
        original_average_logit_diff.to(t.float32),
        rtol=5e-3,  # Increased tolerance
        atol=5e-3
    )

    logit_lens_logit_diffs: Float[Tensor, "component"] = residual_stack_to_logit_diff(accumulated_residual, linreg_cache)
    # Convert to half precision
    logit_lens_logit_diffs = logit_lens_logit_diffs.half()

    plot_kwargs = dict(
        hovermode="x unified",
        title=f"Logit Difference From Accumulated Residual Stream for {token_pairs_names[i]}",
        labels={"x": "Layer", "y": "Logit Diff"},
        xaxis_tickvals=labels,
        width=800,
    )
    # Display inline, then build the figure again to save the logit lens plot
    line(logit_lens_logit_diffs, **plot_kwargs)
    fig = line(logit_lens_logit_diffs, return_fig=True, **plot_kwargs)
    fig.write_html(str(PLOT_DIR / f"logit_lens_{token_pairs_names[i].replace(' ', '_')}.html"))
