In [None]:
import torch
import torch.nn.functional as F
from sae_lens import SAE
from transformer_lens import HookedTransformer
from bidict import bidict
import numpy as np
import time
import os
from datasets import load_dataset
import lightgbm as lgb
from sklearn.model_selection import train_test_split

def get_saes(device, total_layers=26):
    """Load one pretrained SAE per layer and return them as a list.

    For each layer index in ``range(total_layers)`` the SAE with id
    ``layer_<n>/width_16k/canonical`` is fetched from the
    ``gemma-scope-2b-pt-res-canonical`` release and placed on ``device``.
    A progress line is printed after each layer. The returned list is
    ordered by layer index.
    """
    release_name = "gemma-scope-2b-pt-res-canonical"

    def load_layer(layer):
        # from_pretrained returns a 3-tuple here; only the SAE itself is kept.
        loaded_sae, _, _ = SAE.from_pretrained(
            release=release_name,
            sae_id=f"layer_{layer}/width_16k/canonical",
            device=device,
        )
        print(f"Layer {layer} loaded.")
        return loaded_sae

    print(f"Loading {total_layers} SAEs...")
    return [load_layer(layer) for layer in range(total_layers)]

def get_all_mmlu_questions(dataset):
    """Format every MMLU record in ``dataset`` as a multiple-choice prompt.

    Args:
        dataset: Iterable of records, each a mapping with the keys
            ``'question'``, ``'choices'``, ``'subject'`` and ``'answer'``.
            The iterable is consumed once, in order; neither ``len()`` nor
            integer indexing is required, so lists and generators both work.

    Returns:
        A list with one dict per record, containing ``'text'`` (the full
        prompt: instruction, question, lettered choices and a trailing
        ``"Answer: "``), ``'subject'``, ``'answer'``, ``'choices'`` and
        ``'raw_question'`` (the unformatted question).

    Raises:
        KeyError: If a record lacks one of the required keys.
        IndexError: If a record has more than four choices, because only the
            labels A-D are defined.
    """
    # Loop invariants: identical for every record, so build them once.
    instruction = "The following is a multiple choice question. Output only a single token corresponding to the right answer (ie: A) \n"
    choice_labels = ['A', 'B', 'C', 'D']

    all_questions = []
    for data_point in dataset:
        question = data_point['question']
        choices = data_point['choices']
        # Indexing (rather than zip) keeps the IndexError for >4 choices
        # instead of silently dropping the extra options.
        choice_lines = [
            f"{choice_labels[j]}) {choice}\n" for j, choice in enumerate(choices)
        ]
        formatted_question = "".join(
            [instruction, f" Question: {question}\n", *choice_lines, "Answer: "]
        )
        all_questions.append({
            'text': formatted_question,
            'subject': data_point['subject'],
            'answer': data_point['answer'],
            'choices': choices,
            'raw_question': question,
        })
    return all_questions
        
def process_question(model, saes, question, feature_bidict, sample_idx=None):
    """
    Run one MMLU prompt through the model, score the answer and record which
    SAE features are active at the final prompt position.

    The answer is chosen by comparing the final-position logits of the four
    answer-letter tokens (A-D). A per-sample report, including the
    cross-entropy of those four logits against ``question['answer']``, is
    printed. For each SAE (enumerated in list order), the cached activation
    stored under that SAE's hook name is encoded at the final position, and
    every feature with a strictly positive activation that has an entry in
    ``feature_bidict`` under ``(layer_idx, feature_idx)`` is flagged.

    Returns ``(feature_vector, is_correct)``: a 0/1 ``np.byte`` vector of
    length ``len(feature_bidict)`` and whether the top-scoring letter matches
    the correct answer index.
    """
    letters = ['A', 'B', 'C', 'D']
    prompt = question['text']
    gold_idx = question['answer']
    active_mask = np.zeros(len(feature_bidict), dtype=np.byte)

    with torch.no_grad():
        tokens = model.to_tokens(prompt, prepend_bos=True).to(model.cfg.device)
        logits, cache = model.run_with_cache(tokens)

        # Restrict the final-position logits to the four answer-letter tokens.
        letter_token_ids = []
        for letter in letters:
            letter_token_ids.append(model.to_tokens(letter, prepend_bos=False)[0, 0].item())
        final_logits = logits[0, -1, :]
        letter_logits = final_logits[letter_token_ids]

        target = torch.tensor([gold_idx]).to(model.cfg.device)
        loss = F.cross_entropy(letter_logits.unsqueeze(0), target).item()
        predicted_idx = torch.argmax(letter_logits).item()
        is_correct = predicted_idx == gold_idx

        # Per-sample report.
        options = question['choices']
        print(f"\n--- Sample {sample_idx} ---")
        print(f"Subject: {question['subject']}")
        print(f"Question: {question['raw_question']}")
        print("Choices:")
        for position, option in enumerate(options):
            print(f"  {letters[position]}) {option}")
        print(f"Correct Answer: {letters[gold_idx]} ({options[gold_idx]})")
        print(f"Model's Answer: {letters[predicted_idx]} ({options[predicted_idx]})")
        print(f"Result: {'✓ CORRECT' if is_correct else '✗ INCORRECT'}")
        print(f"Loss: {loss:.4f}")

        # Flag every positively-activated SAE feature at the final position.
        for layer_idx, sae in enumerate(saes):
            hook_output = cache[sae.cfg.hook_name]
            last_position = hook_output[0, -1, :].unsqueeze(0)
            encoded = sae.encode(last_position)
            # torch.where on the (1, n) mask yields (rows, cols); cols are the feature ids.
            fired = torch.where(encoded > 0)[1].cpu().tolist()
            for feature_idx in fired:
                slot = feature_bidict.get((layer_idx, feature_idx))
                if slot is None:
                    continue
                active_mask[slot] = 1

    return active_mask, is_correct
        
def extract_features_and_correctness(model, saes, questions, feature_bidict, output_dir=None):
    """
    Processes questions to get features and correctness, with an option to save data

    Args:
        model, saes, feature_bidict: Passed through unchanged to
            ``process_question`` for every question.
        questions: Iterable of question dicts. Samples are numbered from 1 in
            iteration order for the per-sample report.
        output_dir: Optional directory. When truthy, it is created if it does
            not exist and the two arrays are written there as
            ``features.npy`` and ``correctness.npy``.

    Returns:
        ``(features, correctness)``: an ``np.byte`` array stacking the
        per-question feature vectors in question order, and an array of the
        per-question correctness flags.
    """
    all_feature_vectors = []
    all_correctness_labels = []
    for sample_number, question in enumerate(questions, start=1):
        feature_vector, is_correct = process_question(
            model, saes, question, feature_bidict, sample_idx=sample_number
        )
        all_feature_vectors.append(feature_vector)
        all_correctness_labels.append(is_correct)

    # Build each array exactly once and reuse it for both saving and
    # returning; the feature matrix can be large, so a second copy is costly.
    features_np = np.array(all_feature_vectors, dtype=np.byte)
    correctness_np = np.array(all_correctness_labels)

    if output_dir:
        # np.save does not create parent directories, so make sure it exists.
        os.makedirs(output_dir, exist_ok=True)
        np.save(os.path.join(output_dir, "features.npy"), features_np)
        np.save(os.path.join(output_dir, "correctness.npy"), correctness_np)

    return features_np, correctness_np

def load_or_create_data(model, saes, feature_bidict, mmlu_dataset, output_dir):
    """Return ``(feature_vectors, correctness_labels)`` for ``mmlu_dataset``.

    If both ``features.npy`` and ``correctness.npy`` exist in ``output_dir``
    they are loaded and returned without touching the model. Otherwise the
    questions are formatted, run through the model, and the resulting arrays
    are written to ``output_dir`` before being returned.
    """
    features_file = os.path.join(output_dir, "features.npy")
    labels_file = os.path.join(output_dir, "correctness.npy")
    cache_complete = all(os.path.exists(path) for path in (features_file, labels_file))

    if not cache_complete:
        print(f"No cached data found in '{output_dir}'. Generating data from scratch...")
        prompts = get_all_mmlu_questions(mmlu_dataset)
        features, labels = extract_features_and_correctness(
            model, saes, prompts, feature_bidict, output_dir=output_dir
        )
        return features, labels

    print(f"Found cached data. Loading from '{output_dir}'...")
    features = np.load(features_file)
    labels = np.load(labels_file)
    return features, labels


# Init all variables and load SAEs for all layers
# Use the GPU when torch reports CUDA is available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Directory where features.npy / correctness.npy for the test split are cached.
TEST_OUTPUT_DIR = "cached_test_data"
os.makedirs(TEST_OUTPUT_DIR, exist_ok=True)
model = HookedTransformer.from_pretrained("gemma-2-2b", device=device)
# NOTE(review): 26 is hard-coded; presumably it matches the layer count of
# gemma-2-2b — confirm against model.cfg before changing the model.
saes = get_saes(device, total_layers=26)
# Map every (layer index, per-SAE feature index) pair to a unique global index,
# numbered consecutively layer by layer. The global index is the position of
# that feature in the vectors produced by process_question.
feature_bidict = bidict()
global_idx = 0
for layer_idx, sae in enumerate(saes):
    for feature_idx in range(sae.cfg.d_sae):
        feature_bidict[(layer_idx, feature_idx)] = global_idx
        global_idx += 1
total_features = len(feature_bidict)
print(f"Created mapping for {total_features} features across {len(saes)} layers")

# Load the MMLU test split ("all" config) and obtain its feature vectors and
# correctness labels, reusing the cache in TEST_OUTPUT_DIR when it is present.
# NOTE(review): this section was previously labelled "Build tree", but no tree
# is built in the visible code — presumably that happens further down; confirm.
MMLU_test_split = load_dataset("cais/mmlu", "all", split='test')
test_feature_vectors, test_correctness_labels = load_or_create_data(model, saes, feature_bidict, MMLU_test_split, TEST_OUTPUT_DIR)