In [10]:
import torch
import torch.nn.functional as F
from sae_lens import SAE
from transformer_lens import HookedTransformer
from bidict import bidict
import numpy as np
import time
import os
from datasets import load_dataset
import lightgbm as lgb
from sklearn.model_selection import train_test_split

def get_saes(device, total_layers=26):
    """Loads a series of Sparse Autoencoders (SAEs) for specified layers"""
    saes = []
    print(f"Loading {total_layers} SAEs...")
    for layer in range(total_layers):
        sae, _, _ = SAE.from_pretrained(
            release="gemma-scope-2b-pt-res-canonical",
            sae_id=f"layer_{layer}/width_16k/canonical",
            device=device
        )
        saes.append(sae)
        print(f"Layer {layer} loaded.")
    return saes

def get_all_mmlu_questions(dataset):
    """Returns all questions from the MMLU dataset"""
    all_questions = []
    for i in range(len(dataset)):
        data_point = dataset[i]
        question = data_point['question']
        choices = data_point['choices']
        instruction = "The following is a multiple choice question. Think through it very briefly, then output only a single token corresponding to the right answer (A, B, C, or D).\n"
        formatted_question = instruction + f" Question: {question}\n"
        choice_labels = ['A', 'B', 'C', 'D']
        for j, choice in enumerate(choices):
            formatted_question += f"{choice_labels[j]}) {choice}\n"
        formatted_question += "Answer: "
        all_questions.append({
            'text': formatted_question,
            'subject': data_point['subject'],
            'answer': data_point['answer'],
            'choices': data_point['choices'],
            'raw_question': question
        })
    return all_questions
        
def process_question(model, saes, question, feature_bidict, sample_idx=None):
    """
    Performs a forward pass, extracts features and loss for a single MMLU question
    """
    question_text = question['text']
    correct_answer_idx = question['answer']
    choice_labels = ['A', 'B', 'C', 'D']
    total_features = len(feature_bidict)
    feature_vector = np.zeros(total_features, dtype=np.byte)

    with torch.no_grad():
        # Stage 1: Generate reasoning
        reasoning_prompt = question_text.replace("Answer: ", "Let me think through this very briefly (you have 50 tokens to think):\n")
        reasoning_tokens = model.generate(
            model.to_tokens(reasoning_prompt, prepend_bos=True),
            max_new_tokens=50,
            do_sample=False
        )

        # Stage 2: Generate final answer
        reasoning_text = model.to_string(reasoning_tokens[0])
        final_prompt = reasoning_text + "\n\n I will answer with either: A, B, C, or D. Based on my brief reasoning above, the answer is:"
        final_tokens = model.generate(
            model.to_tokens(final_prompt, prepend_bos=True),
            max_new_tokens=3,  # Just enough for the answer
            do_sample=False
        )
        final_output = model.to_string(final_tokens[0])
        # Run with cache on the FINAL tokens (not empty!)
        logits, cache = model.run_with_cache(final_tokens)

        choice_tokens = [model.to_tokens(label, prepend_bos=False)[0, 0].item() for label in ['A', 'B', 'C', 'D']]
        answer_position = -1  # Default to last token
        import re
        reasoning_text = model.to_string(final_tokens[0])

        # Just look for the last occurrence of A), B), C), or D)
        last_answer = None
        for choice in ['A)', 'B)', 'C)', 'D)']:
            if choice in reasoning_text:
                last_pos = reasoning_text.rfind(choice)  # Find LAST occurrence
                if last_answer is None or last_pos > last_answer[1]:
                    last_answer = (choice[0], last_pos)

        if last_answer:
            predicted_choice_idx = ['A', 'B', 'C', 'D'].index(last_answer[0])
        else:
            predicted_choice_idx = torch.argmax(last_token_logits).item()

        last_token_logits = logits[0, answer_position, :]

        choice_token_ids = [model.to_tokens(label, prepend_bos=False)[0, 0].item() for label in choice_labels]
        choice_logits = last_token_logits[choice_token_ids]
        output_logits = choice_logits.cpu().numpy()
        loss = F.cross_entropy(choice_logits.unsqueeze(0), torch.tensor([correct_answer_idx]).to(model.cfg.device)).item()
        #predicted_choice_idx = torch.argmax(choice_logits).item()
        is_correct = predicted_choice_idx == correct_answer_idx

        print(f"\n--- Sample {sample_idx} ---")
        print(f"Subject: {question['subject']}")
        print(f"Question: {question['raw_question']}")
        print("Choices:")
        for i, choice in enumerate(question['choices']):
            print(f"  {choice_labels[i]}) {choice}")
        print(f"Correct Answer: {choice_labels[correct_answer_idx]} ({question['choices'][correct_answer_idx]})")
        print(f"Model's Answer: {choice_labels[predicted_choice_idx]} ({question['choices'][predicted_choice_idx]})")
        print(f"Result: {'✓ CORRECT' if is_correct else '✗ INCORRECT'}")
        print(f"Loss: {loss:.4f}")
        print(f"Model's Reasoning:\n{final_output}")

        for layer_idx, sae in enumerate(saes):
            final_token_activations = cache[sae.cfg.hook_name][0, -1, :].unsqueeze(0)
            feature_acts = sae.encode(final_token_activations)
            active_indices = torch.where(feature_acts > 0)[1].cpu().tolist()
            for feature_idx in active_indices:
                global_feature_idx = feature_bidict.get((layer_idx, feature_idx))
                if global_feature_idx is not None:
                    feature_vector[global_feature_idx] = 1

    return feature_vector, is_correct, loss, output_logits
        
def extract_features_and_correctness(model, saes, questions, feature_bidict, output_dir=None):
    """
    Processes questions to get features and correctness, with an option to save data
    """
    all_feature_vectors = []
    all_correctness_labels = []
    all_losses = []
    all_output_logits = []
    for i, question in enumerate(questions):
        feature_vector, is_correct, loss, output_logits = process_question(
            model, saes, question, feature_bidict, sample_idx=i+1
        )
        all_feature_vectors.append(feature_vector)
        all_correctness_labels.append(is_correct)
        all_losses.append(loss)
        all_output_logits.append(output_logits)

    if output_dir:
        features_np = np.array(all_feature_vectors, dtype=np.byte)
        correctness_np = np.array(all_correctness_labels)
        losses_np = np.array(all_losses)
        output_logits_np =  np.array(all_output_logits)
        feature_path = os.path.join(output_dir, f"features.npy")
        correctness_path = os.path.join(output_dir, f"correctness.npy")
        loss_path = os.path.join(output_dir, f"losses.npy")
        output_logits_path = os.path.join(output_dir, f"output_logits.npy")
        np.save(feature_path, features_np)
        np.save(correctness_path, correctness_np)
        np.save(loss_path, losses_np)
        np.save(output_logits_path, output_logits_np)

    return np.array(all_feature_vectors, dtype=np.byte), np.array(all_correctness_labels), np.array(all_losses), np.array(all_output_logits)

def load_or_create_data(model, saes, feature_bidict, mmlu_dataset, output_dir):
    feature_path = os.path.join(output_dir, "features.npy")
    correctness_path = os.path.join(output_dir, "correctness.npy")
    loss_path = os.path.join(output_dir, "losses.npy")
    output_logits_path = os.path.join(output_dir, "output_logits.npy")
    if os.path.exists(feature_path) and os.path.exists(correctness_path) and os.path.exists(loss_path): #and os.path.exists(output_logits_path):
        print(f"Found cached data. Loading from '{output_dir}'...")
        feature_vectors = np.load(feature_path)
        correctness_labels = np.load(correctness_path)
        losses = np.load(loss_path)
        try: 
            output_logits = np.load(output_logits_path)
        except:
            print("output_logits set to None")
            output_logits = None
    else:
        print(f"No cached data found in '{output_dir}'. Generating data from scratch...")
        questions = get_all_mmlu_questions(mmlu_dataset)
        feature_vectors, correctness_labels, losses, output_logits = extract_features_and_correctness(
            model, saes, questions, feature_bidict,
            output_dir=output_dir)
    return feature_vectors, correctness_labels, losses, output_logits


# Init all variables and load SAEs for all layers
# device = "cuda" if torch.cuda.is_available() else "cpu"
# TEST_OUTPUT_DIR = "cached_test_data_thinking_correctness_losses_logits"
# os.makedirs(TEST_OUTPUT_DIR, exist_ok=True)
# model = HookedTransformer.from_pretrained("gemma-2-2b", device=device)
# saes = get_saes(device, total_layers=26)
# feature_bidict = bidict()
# global_idx = 0
# for layer_idx, sae in enumerate(saes):
#     for feature_idx in range(sae.cfg.d_sae):
#         feature_bidict[(layer_idx, feature_idx)] = global_idx
#         global_idx += 1
# total_features = len(feature_bidict)
# print(f"Created mapping for {total_features} features across {len(saes)} layers")

# Build tree
MMLU_test_split = load_dataset("cais/mmlu", "all", split='test')
test_feature_vectors, test_correctness_labels, test_losses, test_output_logits = load_or_create_data(model, saes, feature_bidict, MMLU_test_split, TEST_OUTPUT_DIR)
X_test, X_stop_test, y_test, y_stop_test = train_test_split(
    test_feature_vectors, 
    test_losses, 
    test_size=0.1,
    random_state=42
)
test_data = lgb.Dataset(X_test, label=y_test)
stop_test_data = lgb.Dataset(X_stop_test, label=y_stop_test, reference=test_data)
print(f"\nExtracted {len(test_feature_vectors)} samples for tree construction. Building NAP tree...")
start_time = time.time()
params = {
    'objective': 'regression',
    'boosting_type': 'gbdt',
    'metric': 'rmse',
    'max_depth': -1,
    'num_leaves': 2048,
    'min_data_in_leaf': 3,
    'min_sum_hessian_in_leaf': 1e-4,
    'feature_fraction': 1.0,
    'feature_fraction_bynode': 1.0,
    'bagging_fraction': 1.0,
    'learning_rate': 0.01,
    'lambda_l1': 0.05,
    'lambda_l2': 0.05,
    'num_threads': -1,
    'force_row_wise': True,
    'verbosity': 1,
    'seed': 42,
    'deterministic': True
}
model_lgb = lgb.train(
    params, 
    test_data, 
    num_boost_round=2000,
    valid_sets=[stop_test_data],
    valid_names=['stop_test_data'],
    callbacks=[
        lgb.early_stopping(stopping_rounds=100),
        lgb.log_evaluation(period=200)
    ]
)
print(f"LightGBM model built in {time.time() - start_time:.2f} seconds")

No cached data found in 'cached_test_data_thinking_correctness_losses_logits'. Generating data from scratch...


100%|██████████| 50/50 [00:32<00:00,  1.55it/s]
100%|██████████| 3/3 [00:03<00:00,  1.10s/it]



--- Sample 1 ---
Subject: abstract_algebra
Question: Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.
Choices:
  A) 0
  B) 4
  C) 2
  D) 6
Correct Answer: B (4)
Model's Answer: A (0)
Result: ✗ INCORRECT
Loss: 2.5464
Model's Reasoning:
<bos><bos>The following is a multiple choice question. Think through it very briefly, then output only a single token corresponding to the right answer (A, B, C, or D).
 Question: Find the degree for the given field extension Q(sqrt(2), sqrt(3), sqrt(18)) over Q.
A) 0
B) 4
C) 2
D) 6
Let me think through this very briefly (you have 50 tokens to think):
1. The degree of the extension is the degree of the minimal polynomial of the element.
2. The minimal polynomial of sqrt(2) is x^2 - 2.
3. The minimal polynomial of sqrt(3) is x

 I will answer with either: A, B, C, or D. Based on my brief reasoning above, the answer is:
A)


100%|██████████| 50/50 [00:31<00:00,  1.60it/s]
100%|██████████| 3/3 [00:02<00:00,  1.15it/s]



--- Sample 2 ---
Subject: abstract_algebra
Question: Let p = (1, 2, 5, 4)(2, 3) in S_5 . Find the index of <p> in S_5.
Choices:
  A) 8
  B) 2
  C) 24
  D) 120
Correct Answer: C (24)
Model's Answer: A (8)
Result: ✗ INCORRECT
Loss: 2.3323
Model's Reasoning:
<bos><bos>The following is a multiple choice question. Think through it very briefly, then output only a single token corresponding to the right answer (A, B, C, or D).
 Question: Let p = (1, 2, 5, 4)(2, 3) in S_5 . Find the index of <p> in S_5.
A) 8
B) 2
C) 24
D) 120
Let me think through this very briefly (you have 50 tokens to think):
1. p is a permutation of S_5.
2. p is a product of disjoint cycles.
3. p is a product of disjoint cycles of length 2.
4. p is a product of disjoint cycles of length 

 I will answer with either: A, B, C, or D. Based on my brief reasoning above, the answer is:
A)


100%|██████████| 50/50 [00:29<00:00,  1.72it/s]
100%|██████████| 3/3 [00:02<00:00,  1.18it/s]



--- Sample 3 ---
Subject: abstract_algebra
Question: Find all zeros in the indicated finite field of the given polynomial with coefficients in that field. x^5 + 3x^3 + x^2 + 2x in Z_5
Choices:
  A) 0
  B) 1
  C) 0,1
  D) 0,4
Correct Answer: D (0,4)
Model's Answer: A (0)
Result: ✗ INCORRECT
Loss: 0.5742
Model's Reasoning:
<bos><bos>The following is a multiple choice question. Think through it very briefly, then output only a single token corresponding to the right answer (A, B, C, or D).
 Question: Find all zeros in the indicated finite field of the given polynomial with coefficients in that field. x^5 + 3x^3 + x^2 + 2x in Z_5
A) 0
B) 1
C) 0,1
D) 0,4
Let me think through this very briefly (you have 50 tokens to think):
1. The polynomial is x^5 + 3x^3 + x^2 + 2x.
2. The coefficients are in Z_5.
3. The polynomial has 5 zeros.
4. The polynomial has

 I will answer with either: A, B, C, or D. Based on my brief reasoning above, the answer is:
A)


 90%|█████████ | 45/50 [00:26<00:02,  1.69it/s]


KeyboardInterrupt: 

In [None]:
VALIDATION_OUTPUT_DIR = "cached_validation_data_thinking_correctness_losses_logits"
os.makedirs(VALIDATION_OUTPUT_DIR, exist_ok=True)
MMLU_validation_split = load_dataset("cais/mmlu", "all", split='validation')
validation_feature_vectors, validation_correctness_labels, validation_losses, output_logits = load_or_create_data(model, saes, feature_bidict, MMLU_validation_split, VALIDATION_OUTPUT_DIR)
validation_loss_predictions = model_lgb.predict(validation_feature_vectors)
low_loss_threshold = 0.651
confident_mask = validation_loss_predictions < low_loss_threshold
confident_correctness_labels = validation_correctness_labels[confident_mask]
confident_accuracy = np.mean(confident_correctness_labels)

print(f"Coverage of confident predictions: {np.mean(confident_mask):.3f} ({np.sum(confident_mask)}/{len(validation_feature_vectors)} validation samples)")
print(f"Accuracy on confident predictions: {np.mean(confident_correctness_labels):.4f}")
print(f"Model's true validation accuracy: {np.sum(validation_correctness_labels) / len(validation_correctness_labels):.4f}")