In [22]:
import torch
import torch.nn.functional as F
from sae_lens import SAE
from transformer_lens import HookedTransformer
from bidict import bidict
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import os
from datasets import load_dataset
import torch.nn as nn
from sklearn.feature_selection import SelectKBest, f_regression

def get_saes(device, total_layers=26):
    """Loads a series of Sparse Autoencoders (SAEs) for specified layers"""
    saes = []
    print(f"Loading {total_layers} SAEs...")
    for layer in range(total_layers):
        sae, _, _ = SAE.from_pretrained(
            release="gemma-scope-2b-pt-res-canonical",
            sae_id=f"layer_{layer}/width_16k/canonical",
            device=device
        )
        saes.append(sae)
        print(f"Layer {layer} loaded.")
    return saes

def get_all_mmlu_questions(dataset):
    """Returns all questions from the MMLU dataset"""
    all_questions = []
    for i in range(len(dataset)):
        data_point = dataset[i]
        question = data_point['question']
        choices = data_point['choices']
        instruction = "The following is a multiple choice question. Output only a single token corresponding to the right answer (ie: A) \n"
        formatted_question = instruction + f" Question: {question}\n"
        choice_labels = ['A', 'B', 'C', 'D']
        for j, choice in enumerate(choices):
            formatted_question += f"{choice_labels[j]}) {choice}\n"
        formatted_question += "Answer: "
        all_questions.append({
            'text': formatted_question,
            'subject': data_point['subject'],
            'answer': data_point['answer'],
            'choices': data_point['choices'],
            'raw_question': question
        })
    return all_questions
        
def process_question(model, saes, question, feature_bidict, sample_idx=None):
    """
    Performs a forward pass, extracts features and loss for a single MMLU question
    """
    question_text = question['text']
    correct_answer_idx = question['answer']
    choice_labels = ['A', 'B', 'C', 'D']
    total_features = len(feature_bidict)
    feature_vector = np.zeros(total_features, dtype=np.byte)

    with torch.no_grad():
        input_tokens = model.to_tokens(question_text, prepend_bos=True).to(model.cfg.device)
        logits, cache = model.run_with_cache(input_tokens)
        last_token_logits = logits[0, -1, :]
        choice_token_ids = [model.to_tokens(label, prepend_bos=False)[0, 0].item() for label in choice_labels]
        choice_logits = last_token_logits[choice_token_ids]
        output_logits = choice_logits.cpu().numpy()
        loss = F.cross_entropy(choice_logits.unsqueeze(0), torch.tensor([correct_answer_idx]).to(model.cfg.device)).item()
        predicted_choice_idx = torch.argmax(choice_logits).item()
        is_correct = predicted_choice_idx == correct_answer_idx

        print(f"\n--- Sample {sample_idx} ---")
        print(f"Subject: {question['subject']}")
        print(f"Question: {question['raw_question']}")
        print("Choices:")
        for i, choice in enumerate(question['choices']):
            print(f"  {choice_labels[i]}) {choice}")
        print(f"Correct Answer: {choice_labels[correct_answer_idx]} ({question['choices'][correct_answer_idx]})")
        print(f"Model's Answer: {choice_labels[predicted_choice_idx]} ({question['choices'][predicted_choice_idx]})")
        print(f"Result: {'✓ CORRECT' if is_correct else '✗ INCORRECT'}")
        print(f"Loss: {loss:.4f}")

        for layer_idx, sae in enumerate(saes):
            final_token_activations = cache[sae.cfg.hook_name][0, -1, :].unsqueeze(0)
            feature_acts = sae.encode(final_token_activations)
            active_indices = torch.where(feature_acts > 0)[1].cpu().tolist()
            for feature_idx in active_indices:
                global_feature_idx = feature_bidict.get((layer_idx, feature_idx))
                if global_feature_idx is not None:
                    feature_vector[global_feature_idx] = 1

    return feature_vector, is_correct, loss, output_logits
        
def extract_features_and_correctness(model, saes, questions, feature_bidict, output_dir=None):
    """
    Processes questions to get features and correctness, with an option to save data
    """
    all_feature_vectors = []
    all_correctness_labels = []
    all_losses = []
    all_output_logits = []
    for i, question in enumerate(questions):
        feature_vector, is_correct, loss, output_logits = process_question(
            model, saes, question, feature_bidict, sample_idx=i+1
        )
        all_feature_vectors.append(feature_vector)
        all_correctness_labels.append(is_correct)
        all_losses.append(loss)
        all_output_logits.append(output_logits)

    if output_dir:
        features_np = np.array(all_feature_vectors, dtype=np.byte)
        correctness_np = np.array(all_correctness_labels)
        losses_np = np.array(all_losses)
        output_logits_np =  np.array(all_output_logits)
        feature_path = os.path.join(output_dir, f"features.npy")
        correctness_path = os.path.join(output_dir, f"correctness.npy")
        loss_path = os.path.join(output_dir, f"losses.npy")
        output_logits_path = os.path.join(output_dir, f"output_logits.npy")
        np.save(feature_path, features_np)
        np.save(correctness_path, correctness_np)
        np.save(loss_path, losses_np)
        np.save(output_logits_path, output_logits_np)

    return np.array(all_feature_vectors, dtype=np.byte), np.array(all_correctness_labels), np.array(all_losses), np.array(all_output_logits)

def load_or_create_data(model, saes, feature_bidict, mmlu_dataset, output_dir):
    feature_path = os.path.join(output_dir, "features.npy")
    correctness_path = os.path.join(output_dir, "correctness.npy")
    loss_path = os.path.join(output_dir, "losses.npy")
    output_logits_path = os.path.join(output_dir, "output_logits.npy")
    if os.path.exists(feature_path) and os.path.exists(correctness_path) and os.path.exists(loss_path): #and os.path.exists(output_logits_path):
        print(f"Found cached data. Loading from '{output_dir}'...")
        feature_vectors = np.load(feature_path)
        correctness_labels = np.load(correctness_path)
        losses = np.load(loss_path)
        try: 
            output_logits = np.load(output_logits_path)
        except:
            print("output_logits set to None")
            output_logits = None
    else:
        print(f"No cached data found in '{output_dir}'. Generating data from scratch...")
        questions = get_all_mmlu_questions(mmlu_dataset)
        feature_vectors, correctness_labels, losses, output_logits = extract_features_and_correctness(
            model, saes, questions, feature_bidict,
            output_dir=output_dir)
    return feature_vectors, correctness_labels, losses, output_logits

class SupervisorNN(nn.Module):
    def __init__(self, activation_len):
        super(SupervisorNN, self).__init__()
        self.activation_len = activation_len
        self.fc1 = nn.Linear(activation_len, activation_len)
        self.fc2 = nn.Linear(activation_len, activation_len)
        self.fc3 = nn.Linear(activation_len, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
def train_supervisor(supervisor, test_feature_vectors, test_losses, epochs=10, batch_size=256):
    X = torch.FloatTensor(test_feature_vectors)
    y = torch.FloatTensor(test_losses).unsqueeze(1)  # Add dimension for regression

    dataset = TensorDataset(X, y)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    supervisor.to(device)
    supervisor.train()
    optimizer = torch.optim.Adam(supervisor.parameters(), lr=0.001)
    criterion = nn.MSELoss()  # Use MSE loss for regression
    
    for epoch in range(epochs):
        total_loss = 0
        for batch_x, batch_y in data_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            output = supervisor(batch_x)
            loss = criterion(output, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f'Epoch {epoch+1}/{epochs}, Average Loss: {total_loss/len(data_loader):.4f}')


# Init all variables and load SAEs for all layers
# device = "cuda" if torch.cuda.is_available() else "cpu"
# TEST_OUTPUT_DIR = "cached_test_data_losses"#"cached_test_data_correctness_losses_logits"
# os.makedirs(TEST_OUTPUT_DIR, exist_ok=True)
# model = HookedTransformer.from_pretrained("gemma-2-2b", device=device)
# saes = get_saes(device, total_layers=26)
# feature_bidict = bidict()
# global_idx = 0
# for layer_idx, sae in enumerate(saes):
#     for feature_idx in range(sae.cfg.d_sae):
#         feature_bidict[(layer_idx, feature_idx)] = global_idx
#         global_idx += 1
# total_features = len(feature_bidict)
# print(f"Created mapping for {total_features} features across {len(saes)} layers")

# Cut down to 50000 top features and train NN 
MMLU_test_split = load_dataset("cais/mmlu", "all", split='test')
test_feature_vectors, test_correctness_labels, test_losses, test_output_logits = load_or_create_data(model, saes, feature_bidict, MMLU_test_split, TEST_OUTPUT_DIR)
selector = SelectKBest(score_func=f_regression, k=15000)
test_feature_vectors = selector.fit_transform(test_feature_vectors, test_losses)
supervisor = SupervisorNN(test_feature_vectors.shape[1])
train_supervisor(supervisor, test_feature_vectors, test_losses, epochs=100)

Found cached data. Loading from 'cached_test_data_losses'...
output_logits set to None
Epoch 1/100, Average Loss: 640.7592
Epoch 2/100, Average Loss: 1.3137
Epoch 3/100, Average Loss: 1.1803
Epoch 4/100, Average Loss: 1.0212
Epoch 5/100, Average Loss: 0.8593
Epoch 6/100, Average Loss: 0.7336
Epoch 7/100, Average Loss: 0.6015
Epoch 8/100, Average Loss: 0.5211
Epoch 9/100, Average Loss: 0.5133
Epoch 10/100, Average Loss: 0.3972
Epoch 11/100, Average Loss: 0.3186
Epoch 12/100, Average Loss: 0.2774
Epoch 13/100, Average Loss: 0.2565
Epoch 14/100, Average Loss: 0.2548
Epoch 15/100, Average Loss: 0.2061
Epoch 16/100, Average Loss: 0.1709
Epoch 17/100, Average Loss: 0.1527
Epoch 18/100, Average Loss: 0.1399
Epoch 19/100, Average Loss: 0.1258
Epoch 20/100, Average Loss: 0.1187
Epoch 21/100, Average Loss: 0.1160
Epoch 22/100, Average Loss: 0.1039


KeyboardInterrupt: 

In [None]:
VALIDATION_OUTPUT_DIR = "cached_validation_data_correctness_losses_logits"
os.makedirs(VALIDATION_OUTPUT_DIR, exist_ok=True)
MMLU_validation_split = load_dataset("cais/mmlu", "all", split='validation')
validation_feature_vectors, validation_correctness_labels, validation_losses, output_logits = load_or_create_data(model, saes, feature_bidict, MMLU_validation_split, VALIDATION_OUTPUT_DIR)
validation_feature_vectors = selector.transform(validation_feature_vectors)

supervisor.eval()
with torch.no_grad():
    validation_features_tensor = torch.FloatTensor(validation_feature_vectors).to(device)
    validation_loss_predictions = supervisor(validation_features_tensor).cpu().numpy().flatten()

low_loss_threshold = 0.
confident_mask = validation_loss_predictions < low_loss_threshold
confident_correctness_labels = validation_correctness_labels[confident_mask]
confident_accuracy = np.mean(confident_correctness_labels)

print(f"Coverage of confident predictions: {np.mean(confident_mask):.3f} ({np.sum(confident_mask)}/{len(validation_feature_vectors)} validation samples)")
print(f"Accuracy on confident predictions: {np.mean(confident_correctness_labels):.4f}")
print(f"Model's true validation accuracy: {np.sum(validation_correctness_labels) / len(validation_correctness_labels):.4f}")

Using the latest cached version of the dataset since cais/mmlu couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'all' at /Users/noahschwartz/.cache/huggingface/datasets/cais___mmlu/all/0.0.0/c30699e8356da336a370243923dbaf21066bb9fe (last modified on Mon Jul 14 09:52:53 2025).


Found cached data. Loading from 'cached_validation_data_correctness_losses_logits'...
Coverage of confident predictions: 0.036 (55/1531 validation samples)
Accuracy on confident predictions: 0.8182
Model's true validation accuracy: 0.4161


In [38]:
low_loss_threshold = 0.0000001
confident_mask = validation_loss_predictions < low_loss_threshold
confident_correctness_labels = validation_correctness_labels[confident_mask]
confident_accuracy = np.mean(confident_correctness_labels)

print(f"Coverage of confident predictions: {np.mean(confident_mask):.3f} ({np.sum(confident_mask)}/{len(validation_feature_vectors)} validation samples)")
print(f"Accuracy on confident predictions: {np.mean(confident_correctness_labels):.4f}")
print(f"Model's true validation accuracy: {np.sum(validation_correctness_labels) / len(validation_correctness_labels):.4f}")

Coverage of confident predictions: 0.005 (7/1531 validation samples)
Accuracy on confident predictions: 0.8571
Model's true validation accuracy: 0.4161


In [24]:
def softmax_max_probability(logits):
    probs = F.softmax(torch.tensor(logits), dim=-1)
    return torch.max(probs).item()

VALIDATION_OUTPUT_DIR = "cached_validation_data_correctness_losses_logits"
os.makedirs(VALIDATION_OUTPUT_DIR, exist_ok=True)
MMLU_validation_split = load_dataset("cais/mmlu", "all", split='validation')
validation_feature_vectors, validation_correctness_labels, validation_losses, output_logits = load_or_create_data(model, saes, feature_bidict, MMLU_validation_split, VALIDATION_OUTPUT_DIR)
confident_validation_outputs_mask = np.array([softmax_max_probability(output_logit) > 0. for output_logit in output_logits])
confident_outputs_validation_labels = validation_correctness_labels[confident_validation_outputs_mask]
validation_outputs_coverage = np.mean(confident_validation_outputs_mask)
print(f"Baseline coverage of confident NAPs: {validation_outputs_coverage:.3f}")
print(f"Baseline Accuracy on confident predictions (how often model gets answer right): {np.sum(confident_outputs_validation_labels) / len(confident_outputs_validation_labels):.4f}")
print(f"Baseline Model's true test accuracy (how often model gets answer right): {np.sum(validation_correctness_labels) / len(validation_correctness_labels):.4f}")

Found cached data. Loading from 'cached_validation_data_correctness_losses_logits'...
Baseline coverage of confident NAPs: 1.000
Baseline Accuracy on confident predictions (how often model gets answer right): 0.4161
Baseline Model's true test accuracy (how often model gets answer right): 0.4161
