In [None]:
import pandas as pd
from datasets import load_dataset
import datasets
from tqdm import tqdm
from typing import Optional
import torch
import einops

from transformers import AutoTokenizer, AutoModelForCausalLM
from circuitsvis.activations import text_neuron_activations

import mypkg.whitebox_infra.dictionaries.batch_topk_sae as batch_topk_sae

In [None]:
# `device` is where inputs and the SAE are placed; the model itself is
# distributed by `device_map="auto"`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_name = "mistralai/Ministral-8B-Instruct-2410"
dtype = torch.bfloat16

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map="auto",
)

def get_submodule(model: "AutoModelForCausalLM", layer: int):
    """Return the decoder layer at index ``layer`` (the residual-stream submodule).

    The attribute path to the layer list differs between architecture families,
    so it is selected by a case-insensitive substring match on
    ``model.config._name_or_path``.

    Args:
        model: A loaded Hugging Face causal LM.
        layer: Index into the model's decoder-layer list.

    Returns:
        The layer module; hooking its output gives / patches the residual stream.

    Raises:
        ValueError: If the model does not belong to a known family.
    """
    # Lower-case so e.g. "Google/Gemma-..." or a local "Ministral-..." dir still match.
    name = model.config._name_or_path.lower()

    if "pythia" in name:
        return model.gpt_neox.layers[layer]
    # These families all keep their decoder layers at `model.model.layers`.
    if any(family in name for family in ("gemma", "mistral", "ministral", "llama", "qwen")):
        return model.model.layers[layer]
    raise ValueError(f"Please add submodule for model {model.config._name_or_path}")


# Layer whose residual stream is analysed. The SAE file and the hooked submodule
# are both derived from it so they cannot silently disagree.
chosen_layers = [18]
sae_layer = chosen_layers[0]

sae_repo = "adamkarvonen/ministral_saes"
# Path layout (inferred from the previously hard-coded path):
#   "<model_name with '/' -> '_'>_batch_top_k/resid_post_layer_<layer>/trainer_1/ae.pt"
sae_path = (
    f"{model_name.replace('/', '_')}_batch_top_k/"
    f"resid_post_layer_{sae_layer}/trainer_1/ae.pt"
)

sae = batch_topk_sae.load_dictionary_learning_batch_topk_sae(
    repo_id=sae_repo,
    filename=sae_path,
    model_name=model_name,
    device=device,
    dtype=dtype,
    layer=sae_layer,
    local_dir="downloaded_saes",
)

submodules = [get_submodule(model, sae_layer)]

In [None]:
class EarlyStopException(Exception):
    """Raised from a forward hook to abort the model's forward pass early."""


@torch.no_grad()
def collect_activations(model, submodule, inputs_BL, attention_mask_BL: Optional[torch.Tensor] = None):
    """Capture the output activations of ``submodule`` for ``inputs_BL``.

    A forward hook stores the submodule's output (the first element when the
    output is a tuple) and then raises ``EarlyStopException``. This aborts the
    rest of the forward pass, so layers after ``submodule`` are never computed.

    Args:
        model: Model whose forward accepts ``input_ids`` (plus ``attention_mask``
            when one is given) and which exposes ``.device``.
        submodule: Module inside ``model`` whose output is captured.
        inputs_BL: Token ids, shape (batch, seq_len).
        attention_mask_BL: Optional attention mask, e.g. for padded batches.

    Returns:
        The captured activation tensor.

    Raises:
        RuntimeError: If the hook never fired, i.e. ``submodule`` is not reached
            during ``model``'s forward pass.
    """
    activations_BLD = None

    def gather_target_act_hook(module, inputs, outputs):
        nonlocal activations_BLD
        # Depending on model / library version a decoder layer returns either a
        # bare hidden-state tensor or a tuple whose first item is the hidden state.
        activations_BLD = outputs[0] if isinstance(outputs, tuple) else outputs
        raise EarlyStopException("Early stopping after capturing activations")

    forward_kwargs = {"input_ids": inputs_BL.to(model.device)}
    if attention_mask_BL is not None:
        forward_kwargs["attention_mask"] = attention_mask_BL.to(model.device)

    handle = submodule.register_forward_hook(gather_target_act_hook)
    try:
        model(**forward_kwargs)
    except EarlyStopException:
        pass
    except Exception as e:
        print(f"Unexpected error during forward pass: {str(e)}")
        raise
    finally:
        # Always detach the hook so later forward passes are unaffected.
        handle.remove()

    if activations_BLD is None:
        # Previously this silently ran the full model and returned None.
        raise RuntimeError(
            "Forward hook never fired: submodule was not reached during the forward pass"
        )
    return activations_BLD

In [None]:
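# Mistral's own tokenizer (mistral_common) may tokenize chat prompts differently from the HF AutoTokenizer used above; unclear whether that matters here. Kept for reference only (note: uncommenting it would overwrite the global `model_name`).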

# # Import needed packages:
# from mistral_common.protocol.instruct.messages import (
#     UserMessage,
# )
# from mistral_common.protocol.instruct.request import ChatCompletionRequest
# from mistral_common.protocol.instruct.tool_calls import (
#     Function,
#     Tool,
# )
# from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

# # Load Mistral tokenizer

# model_name = "open-mixtral-8x22b"

# mistral_tokenizer = MistralTokenizer.from_model(model_name, strict=True)

# # Tokenize a list of messages
# tokenized = mistral_tokenizer.encode_chat_completion(
#     ChatCompletionRequest(
#         messages=[
#             UserMessage(content="How can I center a div?"),
#         ],
#         model=model_name,
#     )
# )
# tokens, text = tokenized.tokens, tokenized.text

# # Count the number of tokens
# print(len(tokens))
# print(type(tokens))
# tokens = torch.tensor(tokens).unsqueeze(0)
# print(tokens.shape)

# print(tokens, text)

# activations_BLD = collect_activations(model, submodules[0], tokens)


In [None]:
# Plain-text variant of the prompt (no chat markup); kept for comparison, not used below.
plain_test_input = "Can you continue this sentence? The scientist named the population, after their distinctive horn, Ovid’s Unicorn. These four-horned, silver-white unicorns were previously unknown to science"

# Prompt actually analysed, wrapped in [INST] ... [/INST] markup.
# NOTE(review): confirm the trailing "assistant" is intended; a prompt built with
# `tokenizer.apply_chat_template` may differ from this hand-written one.
test_input = "[INST]Can you continue this story? The scientist named the population, after their distinctive horn, Ovid’s Unicorn. These four-horned, silver-white unicorns were previously unknown to science[/INST]assistant"


# test_input = "[INST]How can I center a div?[/INST]assistant"

tokens = tokenizer(test_input, return_tensors="pt", add_special_tokens=True)["input_ids"].to(device)
# tokens = tokenizer(test_input, return_tensors="pt", add_special_tokens=False).to(device)

print(tokens.shape)

activations_BLD = collect_activations(model, submodules[0], tokens)

print(tokens)

In [None]:
# Per-token L2 norm of the captured residual stream. Upcast first: the model runs
# in bfloat16, which keeps only ~8 significant bits, so large norms would be
# reported with coarse rounding.
norms_BL = activations_BLD.float().norm(dim=-1)
print(norms_BL)
# NOTE(review): the first (BOS) position may have a much larger norm than the
# rest and dominate this mean; consider also reporting norms_BL[:, 1:].mean().
print(norms_BL.mean())


In [None]:
# NOTE: this mutates the shared `sae` object; the reconstruction cell further
# down runs with the same setting because of this line.
sae.use_threshold = True
encoded_BLF = sae.encode(activations_BLD)
decoded_BLD = sae.decode(encoded_BLF)

torch.set_printoptions(precision=8, sci_mode=False)

# L0: number of active (positive) SAE latents per token.
nonzero_BL = einops.reduce((encoded_BLF > 0).float(), "b l f -> b l", "sum")
print(nonzero_BL)
mean_nonzero = nonzero_BL.mean()
print(mean_nonzero, "\n\n")

# Per-token reconstruction MSE. Computed in float32 so the 8-digit print
# precision above shows real digits rather than bfloat16 rounding.
MSE_BL = (activations_BLD.float() - decoded_BLD.float()).pow(2).mean(dim=-1)
print(MSE_BL)
mean_MSE = MSE_BL.mean()
print(mean_MSE)

In [None]:

@torch.no_grad()
def reconstruct_activations(model, submodule, sae, inputs_BL):
    """Run ``model`` with ``submodule``'s output replaced by its SAE reconstruction.

    A forward hook passes the submodule's output through ``sae.encode`` and then
    ``sae.decode``, and substitutes the result for the rest of the forward pass.
    ``inputs_BL`` is also passed as ``labels`` so the returned output has a loss.

    Args:
        model: Causal LM accepting ``input_ids``/``labels`` and exposing ``.device``.
        submodule: Module inside ``model`` whose output is replaced.
        sae: Object providing ``encode`` and ``decode``.
        inputs_BL: Token ids, shape (batch, seq_len).

    Returns:
        The model's forward output (e.g. with ``.loss`` and ``.logits``).
    """

    def replace_with_reconstruction_hook(module, inputs, outputs):
        # A decoder layer may return a bare hidden-state tensor or a tuple whose
        # first item is the hidden state; handle both and return the same form.
        # (The old code did `outputs[1:]` unconditionally, which raised TypeError
        # for bare-tensor outputs.)
        is_tuple = isinstance(outputs, tuple)
        activations_BLD = outputs[0] if is_tuple else outputs

        encoded_BLF = sae.encode(activations_BLD)
        # Cast back so downstream layers receive the dtype they expect.
        decoded_BLD = sae.decode(encoded_BLF).to(activations_BLD.dtype)

        if is_tuple:
            return (decoded_BLD,) + outputs[1:]
        return decoded_BLD

    input_ids = inputs_BL.to(model.device)
    handle = submodule.register_forward_hook(replace_with_reconstruction_hook)
    try:
        outputs = model(input_ids=input_ids, labels=input_ids)
    except Exception as e:
        print(f"Unexpected error during forward pass: {str(e)}")
        raise
    finally:
        # Always detach the hook so later forward passes are unaffected.
        handle.remove()

    return outputs

# Baseline loss with the model's own activations. No gradients are needed, and
# tracking them would keep the model's full autograd graph alive in memory.
with torch.no_grad():
    original_loss = model(input_ids=tokens.to(model.device), labels=tokens.to(model.device)).loss

# Loss with the hooked layer's output replaced by its SAE reconstruction.
outputs = reconstruct_activations(model, submodules[0], sae, tokens)

print(outputs.loss)
print(original_loss)
# > 1 means splicing in the SAE reconstruction increased the loss.
ratio = outputs.loss / original_loss
print(ratio)