In [35]:
import torch
import torch.nn.functional as F
from sae_lens import SAE
from transformer_lens import HookedTransformer
from bidict import bidict
import numpy as np
import os
from datasets import load_dataset
import tqdm 

def get_saes(device, total_layers=26):
    """Loads a series of Sparse Autoencoders (SAEs) for specified layers"""
    saes = []
    print(f"Loading {total_layers} SAEs...")
    for layer in range(total_layers):
        sae, _, _ = SAE.from_pretrained(
            release="gemma-scope-2b-pt-res-canonical",
            sae_id=f"layer_{layer}/width_16k/canonical",
            device=device
        )
        saes.append(sae)
        print(f"Layer {layer} loaded.")
    return saes

def get_all_mmlu_questions(dataset):
    """Returns all questions from the MMLU dataset"""
    all_questions = []
    for i in range(len(dataset)):
        data_point = dataset[i]
        question = data_point['question']
        choices = data_point['choices']
        instruction = "The following is a multiple choice question. Output only a single token corresponding to the right answer (ie: A) \n"
        formatted_question = instruction + f" Question: {question}\n"
        choice_labels = ['A', 'B', 'C', 'D']
        for j, choice in enumerate(choices):
            formatted_question += f"{choice_labels[j]}) {choice}\n"
        formatted_question += "Answer: "
        all_questions.append({
            'text': formatted_question,
            'subject': data_point['subject'],
            'answer': data_point['answer'],
            'choices': data_point['choices'],
            'raw_question': question
        })
    return all_questions
        
def process_question(model, saes, question, feature_bidict, sample_idx=None):
    """
    Performs a forward pass, extracts features and loss for a single MMLU question
    """
    question_text = question['text']
    correct_answer_idx = question['answer']
    choice_labels = ['A', 'B', 'C', 'D']
    total_features = len(feature_bidict)
    feature_vector = np.zeros(total_features, dtype=np.byte)

    with torch.no_grad():
        input_tokens = model.to_tokens(question_text, prepend_bos=True).to(model.cfg.device)
        logits, cache = model.run_with_cache(input_tokens)
        last_token_logits = logits[0, -1, :]
        choice_token_ids = [model.to_tokens(label, prepend_bos=False)[0, 0].item() for label in choice_labels]
        choice_logits = last_token_logits[choice_token_ids]
        output_logits = choice_logits.cpu().numpy() 
        loss = F.cross_entropy(choice_logits.unsqueeze(0), torch.tensor([correct_answer_idx]).to(model.cfg.device)).item()
        predicted_choice_idx = torch.argmax(choice_logits).item()
        is_correct = predicted_choice_idx == correct_answer_idx

        print(f"\n--- Sample {sample_idx} ---")
        print(f"Subject: {question['subject']}")
        print(f"Question: {question['raw_question']}")
        print("Choices:")
        for i, choice in enumerate(question['choices']):
            print(f"  {choice_labels[i]}) {choice}")
        print(f"Correct Answer: {choice_labels[correct_answer_idx]} ({question['choices'][correct_answer_idx]})")
        print(f"Model's Answer: {choice_labels[predicted_choice_idx]} ({question['choices'][predicted_choice_idx]})")
        print(f"Result: {'✓ CORRECT' if is_correct else '✗ INCORRECT'}")
        print(f"Loss: {loss:.4f}")

        for layer_idx, sae in enumerate(saes):
            final_token_activations = cache[sae.cfg.hook_name][0, -1, :].unsqueeze(0)
            feature_acts = sae.encode(final_token_activations)
            active_indices = torch.where(feature_acts > 0)[1].cpu().tolist()
            for feature_idx in active_indices:
                global_feature_idx = feature_bidict.get((layer_idx, feature_idx))
                if global_feature_idx is not None:
                    feature_vector[global_feature_idx] = 1

    return feature_vector, is_correct, output_logits
        
def extract_features_and_correctness(model, saes, questions, feature_bidict, output_dir=None):
    """
    Processes questions to get features and correctness, with an option to save data
    """
    all_feature_vectors = []
    all_correctness_labels = []
    all_output_logits = []
    for i, question in enumerate(questions):
        feature_vector, is_correct, output_logits = process_question(
            model, saes, question, feature_bidict, sample_idx=i+1
        )
        all_feature_vectors.append(feature_vector)
        all_correctness_labels.append(is_correct)
        all_output_logits.append(output_logits)

    if output_dir:
        features_np = np.array(all_feature_vectors, dtype=np.byte)
        correctness_np = np.array(all_correctness_labels)
        feature_path = os.path.join(output_dir, f"features.npy")
        correctness_path = os.path.join(output_dir, f"correctness.npy")
        np.save(feature_path, features_np)
        np.save(correctness_path, correctness_np)

    return np.array(all_feature_vectors, dtype=np.byte), np.array(all_correctness_labels), np.array(all_output_logits, dtype=np.byte)

def load_or_create_data(model, saes, feature_bidict, mmlu_dataset, output_dir):
    feature_path = os.path.join(output_dir, "features.npy")
    correctness_path = os.path.join(output_dir, "correctness.npy")
    output_logits_path = os.path.join(output_dir, "output_logits.npy")
    if os.path.exists(feature_path) and os.path.exists(correctness_path) and os.path.exists(output_logits_path):
        print(f"Found cached data. Loading from '{output_dir}'...")
        feature_vectors = np.load(feature_path)
        correctness_labels = np.load(correctness_path)
        output_logits = np.load(output_logits_path)
    else:
        print(f"No cached data found in '{output_dir}'. Generating data from scratch...")
        questions = get_all_mmlu_questions(mmlu_dataset)
        feature_vectors, correctness_labels, output_logits = extract_features_and_correctness(
            model, saes, questions, feature_bidict,
            output_dir=output_dir)
    return feature_vectors, correctness_labels, output_logits

def compute_input_nap(activations):
    return (activations > 0).int()

def compute_baseline_nap(model, test_feature_vectors, test_class, delta):
    all_class_naps = []
    model.eval()
    with torch.no_grad():
        for test_feature_vector in test_feature_vectors:
            # Convert numpy array to torch tensor
            test_feature_tensor = torch.from_numpy(test_feature_vector).to(device)
            binary_naps = compute_input_nap(test_feature_tensor)
            # Add batch dimension to maintain 2D structure
            all_class_naps.append(binary_naps.unsqueeze(0).cpu())
            
    all_class_naps = torch.cat(all_class_naps, dim=0)
    num_samples, num_neurons = all_class_naps.shape
    
    # Apply statistical abstraction (A_tilde) as per equation (6)
    activation_counts = torch.sum(all_class_naps, dim=0)
    deactivation_counts = num_samples - activation_counts
    
    baseline_nap = torch.full((num_neurons,), 2, dtype=torch.int)  # 2 represents '*'
    
    # Apply delta threshold
    for i in range(num_neurons):
        activation_freq = activation_counts[i].item() / num_samples
        deactivation_freq = deactivation_counts[i].item() / num_samples
        
        if activation_freq >= delta:
            baseline_nap[i] = 1  # State '1'
        # elif deactivation_freq >= delta:
        #     baseline_nap[i] = 2  # State '0'
        # else:
        #     baseline_nap[i] = 
        # Otherwise remains '*' (state 2)
            
    num_binary_neurons = (baseline_nap != 2).sum().item()
    print(f"Baseline NAP computed. Size (number of binary neurons): {num_binary_neurons}/{num_neurons}")
    return baseline_nap.to(device)

# Init all variables and load SAEs for all layers
device = "cuda" if torch.cuda.is_available() else "cpu"
TEST_OUTPUT_DIR = "cached_test_data_correctness_losses_logits_2"
os.makedirs(TEST_OUTPUT_DIR, exist_ok=True)
model = HookedTransformer.from_pretrained("gemma-2-2b", device=device)
saes = get_saes(device, total_layers=26)
feature_bidict = bidict()
global_idx = 0
for layer_idx, sae in enumerate(saes):
    for feature_idx in range(sae.cfg.d_sae):
        feature_bidict[(layer_idx, feature_idx)] = global_idx
        global_idx += 1
total_features = len(feature_bidict)
print(f"Created mapping for {total_features} features across {len(saes)} layers")

# # Build tree
DELTA = 0.9
naps = []
MMLU_test_split = load_dataset("cais/mmlu", "all", split='test')
test_feature_vectors, test_correctness_labels, _ = load_or_create_data(model, saes, feature_bidict, MMLU_test_split, TEST_OUTPUT_DIR)
test_class_categories = [q['subject'] for q in get_all_mmlu_questions(MMLU_test_split)]
test_features_by_class = {}
for feature_vector, correctness, subject in zip(test_feature_vectors, test_correctness_labels, test_class_categories):
    if subject not in test_features_by_class:
        test_features_by_class[subject] = []
    if correctness:  # Only include correctly answered questions
        test_features_by_class[subject].append(feature_vector)

for test_class, correct_test_feature_vectors in test_features_by_class.items():
    print(f"\n--- Processing Class '{test_class}' ---")
    nap = compute_baseline_nap(model, correct_test_feature_vectors, test_class, delta=DELTA)
    print(nap)
    naps.append(nap)

Found cached data. Loading from 'cached_test_data_correctness_losses_logits_2'...

--- Processing Class 'abstract_algebra' ---
Baseline NAP computed. Size (number of binary neurons): 378/425984
tensor([2, 2, 2,  ..., 2, 2, 2], dtype=torch.int32)

--- Processing Class 'anatomy' ---
Baseline NAP computed. Size (number of binary neurons): 345/425984
tensor([2, 2, 2,  ..., 2, 2, 2], dtype=torch.int32)

--- Processing Class 'astronomy' ---
Baseline NAP computed. Size (number of binary neurons): 326/425984
tensor([2, 2, 2,  ..., 2, 2, 2], dtype=torch.int32)

--- Processing Class 'business_ethics' ---
Baseline NAP computed. Size (number of binary neurons): 248/425984
tensor([2, 2, 2,  ..., 2, 2, 2], dtype=torch.int32)

--- Processing Class 'clinical_knowledge' ---
Baseline NAP computed. Size (number of binary neurons): 284/425984
tensor([2, 2, 2,  ..., 2, 2, 2], dtype=torch.int32)

--- Processing Class 'college_biology' ---
Baseline NAP computed. Size (number of binary neurons): 332/425984
te

In [36]:
# Create NAP mapping by class name
naps_by_class = {}
for (test_class, _), nap in zip(test_features_by_class.items(), naps):
    naps_by_class[test_class] = nap

VALIDATION_OUTPUT_DIR = "cached_validation_data_correctness_losses_logits"
os.makedirs(VALIDATION_OUTPUT_DIR, exist_ok=True)
MMLU_validation_split = load_dataset("cais/mmlu", "all", split='validation')
validation_feature_vectors, validation_correctness_labels, _ = load_or_create_data(model, saes, feature_bidict, MMLU_validation_split, VALIDATION_OUTPUT_DIR)

# Get validation class categories (not test!)
validation_class_categories = [q['subject'] for q in get_all_mmlu_questions(MMLU_validation_split)]
results = {class_name: [0, 0, 0, 0] for class_name in set(validation_class_categories)}

model.eval()
with torch.no_grad():
    validation_features_by_class = {}
    for feature_vector, correctness, subject in zip(validation_feature_vectors, validation_correctness_labels, validation_class_categories):
        if subject not in validation_features_by_class:
            validation_features_by_class[subject] = []
        validation_features_by_class[subject].append([feature_vector, correctness])

    for validation_class, validation_data in validation_features_by_class.items():
        if validation_class not in naps_by_class:
            print(f"Warning: No NAP found for class {validation_class}, skipping...")
            continue
            
        validation_feature_vectors_class, correctness_labels = zip(*validation_data)
        print(f"\n--- Processing Class '{validation_class}' ---")
        print(f"Number of validation samples: {len(validation_feature_vectors_class)}")
        
        input_naps = compute_input_nap(torch.tensor(validation_feature_vectors_class))
        class_nap = naps_by_class[validation_class]
        
        for input_nap, is_correct in zip(input_naps, correctness_labels):
            is_covered = nap_subsumes(class_nap, input_nap)
            if is_correct:
                results[validation_class][0] += 1  # correct_total
                if is_covered:
                    results[validation_class][1] += 1  # correct_covered
            else:
                results[validation_class][2] += 1  # incorrect_total
                if is_covered:
                    results[validation_class][3] += 1  # incorrect_covered

print("\n--- Evaluation Results ---")
for class_name, data in results.items():
    correct_total, correct_covered, incorrect_total, incorrect_covered = data
    
    # Coverage over correct examples
    coverage_correct = (correct_covered / correct_total) * 100 if correct_total > 0 else 0
    # Coverage over incorrect examples
    coverage_incorrect = (incorrect_covered / incorrect_total) * 100 if incorrect_total > 0 else 0
    
    print(f"\n--- Class '{class_name}' ---")
    print(f"Coverage over correctly classified examples: {coverage_correct:.2f}% ({correct_covered}/{correct_total})")
    print(f"Coverage over misclassified examples: {coverage_incorrect:.2f}% ({incorrect_covered}/{incorrect_total})")

Found cached data. Loading from 'cached_validation_data_correctness_losses_logits'...

--- Processing Class 'abstract_algebra' ---
Number of validation samples: 11

--- Processing Class 'anatomy' ---
Number of validation samples: 14

--- Processing Class 'astronomy' ---
Number of validation samples: 16

--- Processing Class 'business_ethics' ---
Number of validation samples: 11

--- Processing Class 'clinical_knowledge' ---
Number of validation samples: 29

--- Processing Class 'college_biology' ---
Number of validation samples: 16

--- Processing Class 'college_chemistry' ---
Number of validation samples: 8

--- Processing Class 'college_computer_science' ---
Number of validation samples: 11

--- Processing Class 'college_mathematics' ---
Number of validation samples: 11

--- Processing Class 'college_medicine' ---
Number of validation samples: 22

--- Processing Class 'college_physics' ---
Number of validation samples: 11

--- Processing Class 'computer_security' ---
Number of valida

In [20]:
# def nap_subsumes(class_nap, input_nap):
#     return torch.all((class_nap == 2) | (class_nap == input_nap))

# def compute_input_nap(activations):
#     return (activations > 0).int()

# #TODO: fix the issue of the weird output below 
# VALIDATION_OUTPUT_DIR = "cached_validation_data_correctness_losses_logits"
# os.makedirs(VALIDATION_OUTPUT_DIR, exist_ok=True)
# MMLU_validation_split = load_dataset("cais/mmlu", "all", split='validation')
# validation_feature_vectors, validation_correctness_labels, _ = load_or_create_data(model, saes, feature_bidict, MMLU_validation_split, VALIDATION_OUTPUT_DIR)
# input_naps = compute_input_nap(torch.from_numpy(validation_feature_vectors).to(device))
# validation_predictions = nap_subsumes(nap, input_naps).cpu().numpy()
# confidence_threshold = 0.85
# confident_mask = (validation_predictions > confidence_threshold) | (validation_predictions < (1 - confidence_threshold))
# confident_validation_labels = validation_correctness_labels[confident_mask]
# confident_predictions = (validation_predictions[confident_mask] > 0.5).astype(int)
# confident_accuracy = np.mean(confident_predictions == confident_validation_labels)

# print(f"Coverage of confident predictions: {np.mean(confident_mask):.3f} ({np.sum(confident_mask)}/{len(validation_feature_vectors)} validation samples)")
# print(f"Accuracy on confident predictions: {np.mean(confident_validation_labels):.4f}")
# print(f"Model's true validation accuracy: {np.sum(validation_correctness_labels) / len(validation_correctness_labels):.4f}")

def nap_subsumes(class_nap, input_nap):
    return torch.all((class_nap == 2) | (class_nap == input_nap))


VALIDATION_OUTPUT_DIR = "cached_validation_data_correctness_losses_logits"
os.makedirs(VALIDATION_OUTPUT_DIR, exist_ok=True)
MMLU_validation_split = load_dataset("cais/mmlu", "all", split='validation')
validation_feature_vectors, validation_correctness_labels, _ = load_or_create_data(model, saes, feature_bidict, MMLU_validation_split, VALIDATION_OUTPUT_DIR)
results = {i: [0, 0, 0, 0] for i in range(len(test_class_categories))}

model.eval()
with torch.no_grad():
    validation_features_by_class = {}
    for feature_vector, correctness, subject in zip(validation_feature_vectors, validation_correctness_labels, test_class_categories):
        if subject not in validation_features_by_class:
            validation_features_by_class[subject] = []
        validation_features_by_class[subject].append([feature_vector, correctness])

    i = 0
    for validation_class, validation_feature_vectors in validation_features_by_class.items():
        validation_feature_vectors, is_correct = zip(*validation_feature_vectors)
        print(len(validation_feature_vectors))
        print(f"\n--- Processing Class '{validation_class}' ---")
        input_naps = compute_input_nap(torch.tensor(validation_feature_vectors))
        class_nap = naps[i]
        for input_nap in input_naps:
            is_covered = nap_subsumes(class_nap, input_nap)
            if is_correct:
                results[i][0] += 1
                if is_covered:
                    results[i][1] += 1
            else:
                results[i][2] += 1
                if is_covered:
                    results[i][3] += 1
        i += 1
    

print("\n--- Evaluation Results ---")
for class_id, data in results.items():
    correct_total, correct_covered, incorrect_total, incorrect_covered = data
    
    # Coverage over correct examples
    coverage_correct = (correct_covered / correct_total) * 100 if correct_total > 0 else 0
    # Coverage over incorrect examples
    coverage_incorrect = (incorrect_covered / incorrect_total) * 100 if incorrect_total > 0 else 0
    
    print(f"\n--- Class {class_id} ---")
    print(f"Coverage over correctly classified examples: {coverage_correct:.2f}% ({correct_covered}/{correct_total})")
    print(f"Coverage over misclassified examples: {coverage_incorrect:.2f}% ({incorrect_covered}/{incorrect_total})")

print("\n--------------------------")

Found cached data. Loading from 'cached_validation_data_correctness_losses_logits'...
100

--- Processing Class 'abstract_algebra' ---
135

--- Processing Class 'anatomy' ---
152

--- Processing Class 'astronomy' ---
100

--- Processing Class 'business_ethics' ---
265

--- Processing Class 'clinical_knowledge' ---
144

--- Processing Class 'college_biology' ---
100

--- Processing Class 'college_chemistry' ---
100

--- Processing Class 'college_computer_science' ---
100

--- Processing Class 'college_mathematics' ---
173

--- Processing Class 'college_medicine' ---
102

--- Processing Class 'college_physics' ---
60

--- Processing Class 'computer_security' ---

--- Evaluation Results ---

--- Class 0 ---
Coverage over correctly classified examples: 0.00% (0/100)
Coverage over misclassified examples: 0.00% (0/0)

--- Class 1 ---
Coverage over correctly classified examples: 0.00% (0/135)
Coverage over misclassified examples: 0.00% (0/0)

--- Class 2 ---
Coverage over correctly classified