In [7]:
import torch
import torch.nn.functional as F
from sae_lens import SAE
from transformer_lens import HookedTransformer
from bidict import bidict
import numpy as np
import time
import os
from datasets import load_dataset
import lightgbm as lgb
from sklearn.model_selection import train_test_split

def get_saes(device, total_layers=26):
    """Loads a series of Sparse Autoencoders (SAEs) for specified layers"""
    saes = []
    print(f"Loading {total_layers} SAEs...")
    for layer in range(total_layers):
        sae, _, _ = SAE.from_pretrained(
            release="gemma-scope-2b-pt-res-canonical",
            sae_id=f"layer_{layer}/width_16k/canonical",
            device=device
        )
        saes.append(sae)
        print(f"Layer {layer} loaded.")
    return saes

def get_all_mmlu_questions(dataset):
    """Returns all questions from the MMLU dataset"""
    all_questions = []
    for i in range(len(dataset)):
        data_point = dataset[i]
        question = data_point['question']
        choices = data_point['choices']
        instruction = "The following is a multiple choice question. Output only a single token corresponding to the right answer (ie: A) \n"
        formatted_question = instruction + f" Question: {question}\n"
        choice_labels = ['A', 'B', 'C', 'D']
        for j, choice in enumerate(choices):
            formatted_question += f"{choice_labels[j]}) {choice}\n"
        formatted_question += "Answer: "
        all_questions.append({
            'text': formatted_question,
            'subject': data_point['subject'],
            'answer': data_point['answer'],
            'choices': data_point['choices'],
            'raw_question': question
        })
    return all_questions
        
def process_question(model, saes, question, feature_bidict, sample_idx=None):
    """
    Performs a forward pass, extracts features and loss for a single MMLU question
    """
    question_text = question['text']
    correct_answer_idx = question['answer']
    choice_labels = ['A', 'B', 'C', 'D']
    total_features = len(feature_bidict)
    feature_vector = np.zeros(total_features, dtype=np.byte)

    with torch.no_grad():
        input_tokens = model.to_tokens(question_text, prepend_bos=True).to(model.cfg.device)
        logits, cache = model.run_with_cache(input_tokens)
        all_layer_activations = []
        for layer_idx, sae in enumerate(saes):
            raw_activations = cache[sae.cfg.hook_name][0, -1, :]
            all_layer_activations.append(raw_activations.cpu().numpy())
        neuron_activations = np.concatenate(all_layer_activations)
        last_token_logits = logits[0, -1, :]
        choice_token_ids = [model.to_tokens(label, prepend_bos=False)[0, 0].item() for label in choice_labels]
        choice_logits = last_token_logits[choice_token_ids]
        output_logits = choice_logits.cpu().numpy()
        loss = F.cross_entropy(choice_logits.unsqueeze(0), torch.tensor([correct_answer_idx]).to(model.cfg.device)).item()
        predicted_choice_idx = torch.argmax(choice_logits).item()
        is_correct = predicted_choice_idx == correct_answer_idx

        print(f"\n--- Sample {sample_idx} ---")
        print(f"Subject: {question['subject']}")
        print(f"Question: {question['raw_question']}")
        print("Choices:")
        for i, choice in enumerate(question['choices']):
            print(f"  {choice_labels[i]}) {choice}")
        print(f"Correct Answer: {choice_labels[correct_answer_idx]} ({question['choices'][correct_answer_idx]})")
        print(f"Model's Answer: {choice_labels[predicted_choice_idx]} ({question['choices'][predicted_choice_idx]})")
        print(f"Result: {'✓ CORRECT' if is_correct else '✗ INCORRECT'}")
        print(f"Loss: {loss:.4f}")

        for layer_idx, sae in enumerate(saes):
            final_token_activations = cache[sae.cfg.hook_name][0, -1, :].unsqueeze(0)
            feature_acts = sae.encode(final_token_activations)
            active_indices = torch.where(feature_acts > 0)[1].cpu().tolist()
            for feature_idx in active_indices:
                global_feature_idx = feature_bidict.get((layer_idx, feature_idx))
                if global_feature_idx is not None:
                    feature_vector[global_feature_idx] = 1

    return neuron_activations, feature_vector, is_correct, loss, output_logits
        
def extract_features_and_correctness(model, saes, questions, feature_bidict, output_dir=None):
    """
    Processes questions to get features and correctness, with an option to save data
    """
    all_neuron_activations = []
    all_feature_vectors = []
    all_correctness_labels = []
    all_losses = []
    all_output_logits = []
    for i, question in enumerate(questions):
        neuron_activations, feature_vector, is_correct, loss, output_logits = process_question(
            model, saes, question, feature_bidict, sample_idx=i+1
        )
        all_neuron_activations.append(neuron_activations)
        all_feature_vectors.append(feature_vector)
        all_correctness_labels.append(is_correct)
        all_losses.append(loss)
        all_output_logits.append(output_logits)

    if output_dir:
        neuron_activations_np = np.array(all_neuron_activations, dtype=np.byte)
        features_np = np.array(all_feature_vectors, dtype=np.byte)
        correctness_np = np.array(all_correctness_labels)
        losses_np = np.array(all_losses)
        output_logits_np =  np.array(all_output_logits)
        neuron_activations_path =  os.path.join(output_dir, f"neuron_activations.npy")
        feature_path = os.path.join(output_dir, f"features.npy")
        correctness_path = os.path.join(output_dir, f"correctness.npy")
        loss_path = os.path.join(output_dir, f"losses.npy")
        output_logits_path = os.path.join(output_dir, f"output_logits.npy")
        np.save(neuron_activations_path, neuron_activations_np)
        np.save(feature_path, features_np)
        np.save(correctness_path, correctness_np)
        np.save(loss_path, losses_np)
        np.save(output_logits_path, output_logits_np)

    return np.array(all_neuron_activations, dtype=np.byte), np.array(all_feature_vectors, dtype=np.byte), np.array(all_correctness_labels), np.array(all_losses), np.array(all_output_logits)

def load_or_create_data(model, saes, feature_bidict, mmlu_dataset, output_dir):
    feature_path = os.path.join(output_dir, "features.npy")
    correctness_path = os.path.join(output_dir, "correctness.npy")
    loss_path = os.path.join(output_dir, "losses.npy")
    output_logits_path = os.path.join(output_dir, "output_logits.npy")
    neurons_activations_path = os.path.join(output_dir, "neuron_activations.npy")
    if os.path.exists(feature_path) and os.path.exists(correctness_path) and os.path.exists(loss_path): #and os.path.exists(output_logits_path):
        print(f"Found cached data. Loading from '{output_dir}'...")
        feature_vectors = np.load(feature_path)
        correctness_labels = np.load(correctness_path)
        losses = np.load(loss_path)
        try: 
            output_logits = np.load(output_logits_path)
            neurons_activations =  np.load(neurons_activations_path)
        except:
            print("output_logits/neuron_activations set to None")
            output_logits = None
            neurons_activations = None
    else:
        print(f"No cached data found in '{output_dir}'. Generating data from scratch...")
        questions = get_all_mmlu_questions(mmlu_dataset)
        neurons_activations, feature_vectors, correctness_labels, losses, output_logits = extract_features_and_correctness(
            model, saes, questions, feature_bidict,
            output_dir=output_dir)
    return neurons_activations, feature_vectors, correctness_labels, losses, output_logits


# Init all variables and load SAEs for all layers
# device = "cuda" if torch.cuda.is_available() else "cpu"
# TEST_OUTPUT_DIR = "cached_test_data_losses"#"cached_test_data_correctness_losses_logits"
# os.makedirs(TEST_OUTPUT_DIR, exist_ok=True)
# model = HookedTransformer.from_pretrained("gemma-2-2b", device=device)
# saes = get_saes(device, total_layers=26)
# feature_bidict = bidict()
# global_idx = 0
# for layer_idx, sae in enumerate(saes):
#     for feature_idx in range(sae.cfg.d_sae):
#         feature_bidict[(layer_idx, feature_idx)] = global_idx
#         global_idx += 1
# total_features = len(feature_bidict)
# print(f"Created mapping for {total_features} features across {len(saes)} layers")
# TEST_OUTPUT_DIR = "cached_test_data_losses"

# Build tree
MMLU_test_split = load_dataset("cais/mmlu", "all", split='test')
test_feature_vectors, test_correctness_labels, test_losses, test_output_logits = load_or_create_data(model, saes, feature_bidict, MMLU_test_split, TEST_OUTPUT_DIR)
X_test, X_stop_test, y_test, y_stop_test = train_test_split(
    test_feature_vectors, 
    test_losses, 
    test_size=0.1,
    random_state=42
)
test_data = lgb.Dataset(X_test, label=y_test)
stop_test_data = lgb.Dataset(X_stop_test, label=y_stop_test, reference=test_data)
print(f"\nExtracted {len(test_feature_vectors)} samples for tree construction. Building NAP tree...")
start_time = time.time()
params = {
    'objective': 'regression',
    'boosting_type': 'gbdt',
    'metric': 'rmse',
    'max_depth': -1,
    'num_leaves': 2048,
    'min_data_in_leaf': 3,
    'min_sum_hessian_in_leaf': 1e-4,
    'feature_fraction': 1.0,
    'feature_fraction_bynode': 1.0,
    'bagging_fraction': 1.0,
    'learning_rate': 0.01,
    'lambda_l1': 0.05,
    'lambda_l2': 0.05,
    'num_threads': -1,
    'force_row_wise': True,
    'verbosity': 1,
    'seed': 42,
    'deterministic': True
}
model_lgb = lgb.train(
    params, 
    test_data, 
    num_boost_round=2000,
    valid_sets=[stop_test_data],
    valid_names=['stop_test_data'],
    callbacks=[
        lgb.early_stopping(stopping_rounds=100),
        lgb.log_evaluation(period=200)
    ]
)
print(f"LightGBM model built in {time.time() - start_time:.2f} seconds")

def compute_sae_reconstruction_loss(neuron_activations, saes, device):
    """
    Compute reconstruction loss by passing neuron activations through SAEs
    
    Args:
        neuron_activations: numpy array of shape (n_samples, total_neurons)
        saes: list of SAE models for each layer
        device: torch device
    
    Returns:
        numpy array of reconstruction losses for each sample
    """
    n_samples = neuron_activations.shape[0]
    reconstruction_losses = np.zeros(n_samples)
    
    # Calculate neurons per layer (assuming equal distribution)
    neurons_per_layer = neuron_activations.shape[1] // len(saes)
    
    with torch.no_grad():
        for sample_idx in range(n_samples):
            sample_activations = neuron_activations[sample_idx]
            total_loss = 0.0
            
            # Process each layer's activations through its corresponding SAE
            for layer_idx, sae in enumerate(saes):
                start_idx = layer_idx * neurons_per_layer
                end_idx = (layer_idx + 1) * neurons_per_layer
                layer_activations = sample_activations[start_idx:end_idx]
                
                # Convert to tensor and add batch dimension
                layer_tensor = torch.tensor(layer_activations, dtype=torch.float32, device=device).unsqueeze(0)
                
                # Pass through SAE encoder-decoder
                encoded = sae.encode(layer_tensor)
                reconstructed = sae.decode(encoded)
                
                # Calculate L2 reconstruction loss for this layer
                layer_loss = F.mse_loss(reconstructed, layer_tensor).item()
                total_loss += layer_loss
            
            reconstruction_losses[sample_idx] = total_loss
    
    return reconstruction_losses


# Baseline: test using SAE loss
VALIDATION_OUTPUT_DIR = "cached_validation_data_activations_correctness_losses_logits"
os.makedirs(VALIDATION_OUTPUT_DIR, exist_ok=True)
MMLU_validation_split = load_dataset("cais/mmlu", "all", split='validation')
validation_neuron_activations, validation_feature_vectors, validation_correctness_labels, validation_losses, output_logits = load_or_create_data(model, saes, feature_bidict, MMLU_validation_split, VALIDATION_OUTPUT_DIR)
validation_loss_predictions = compute_sae_reconstruction_loss(
    validation_neuron_activations, saes, device
)
low_loss_threshold = 0.651
confident_mask = validation_loss_predictions < low_loss_threshold
confident_correctness_labels = validation_correctness_labels[confident_mask]
confident_accuracy = np.mean(confident_correctness_labels)

print(f"Coverage of confident predictions: {np.mean(confident_mask):.3f} ({np.sum(confident_mask)}/{len(validation_feature_vectors)} validation samples)")
print(f"Accuracy on confident predictions: {np.mean(confident_correctness_labels):.4f}")
print(f"Model's true validation accuracy: {np.sum(validation_correctness_labels) / len(validation_correctness_labels):.4f}")

Found cached data. Loading from 'cached_test_data_losses'...


KeyboardInterrupt: 

In [9]:
def compute_sae_reconstruction_loss(neuron_activations, saes, device):
    """
    Compute reconstruction loss by passing neuron activations through SAEs
    
    Args:
        neuron_activations: numpy array of shape (n_samples, total_neurons)
        saes: list of SAE models for each layer
        device: torch device
    
    Returns:
        numpy array of reconstruction losses for each sample
    """
    n_samples = neuron_activations.shape[0]
    reconstruction_losses = np.zeros(n_samples)
    
    # Calculate neurons per layer (assuming equal distribution)
    neurons_per_layer = neuron_activations.shape[1] // len(saes)
    
    with torch.no_grad():
        for sample_idx in range(n_samples):
            print(sample_idx)
            sample_activations = neuron_activations[sample_idx]
            total_loss = 0.0
            
            # Process each layer's activations through its corresponding SAE
            for layer_idx, sae in enumerate(saes):
                start_idx = layer_idx * neurons_per_layer
                end_idx = (layer_idx + 1) * neurons_per_layer
                layer_activations = sample_activations[start_idx:end_idx]
                
                # Convert to tensor and add batch dimension
                layer_tensor = torch.tensor(layer_activations, dtype=torch.float32, device=device).unsqueeze(0)
                
                # Pass through SAE encoder-decoder
                encoded = sae.encode(layer_tensor)
                reconstructed = sae.decode(encoded)
                
                # Calculate L2 reconstruction loss for this layer
                layer_loss = F.mse_loss(reconstructed, layer_tensor).item()
                total_loss += layer_loss
            
            reconstruction_losses[sample_idx] = total_loss
    
    return reconstruction_losses

VALIDATION_OUTPUT_DIR = "cached_validation_data_activations_correctness_losses_logits"
os.makedirs(VALIDATION_OUTPUT_DIR, exist_ok=True)
MMLU_validation_split = load_dataset("cais/mmlu", "all", split='validation')
validation_neuron_activations, validation_feature_vectors, validation_correctness_labels, validation_losses, output_logits = load_or_create_data(model, saes, feature_bidict, MMLU_validation_split, VALIDATION_OUTPUT_DIR)
validation_loss_predictions = compute_sae_reconstruction_loss(
    validation_neuron_activations, saes, device
)
low_loss_threshold = 0.651
confident_mask = validation_loss_predictions < low_loss_threshold
confident_correctness_labels = validation_correctness_labels[confident_mask]
confident_accuracy = np.mean(confident_correctness_labels)

print(f"Coverage of confident predictions: {np.mean(confident_mask):.3f} ({np.sum(confident_mask)}/{len(validation_feature_vectors)} validation samples)")
print(f"Accuracy on confident predictions: {np.mean(confident_correctness_labels):.4f}")
print(f"Model's true validation accuracy: {np.sum(validation_correctness_labels) / len(validation_correctness_labels):.4f}")

Found cached data. Loading from 'cached_validation_data_activations_correctness_losses_logits'...
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [17]:
low_loss_threshold = 90
confident_mask = validation_loss_predictions < low_loss_threshold
confident_correctness_labels = validation_correctness_labels[confident_mask]
confident_accuracy = np.mean(confident_correctness_labels)

print(f"Coverage of confident predictions: {np.mean(confident_mask):.3f} ({np.sum(confident_mask)}/{len(validation_feature_vectors)} validation samples)")
print(f"Accuracy on confident predictions: {np.mean(confident_correctness_labels):.4f}")
print(f"Model's true validation accuracy: {np.sum(validation_correctness_labels) / len(validation_correctness_labels):.4f}")

Coverage of confident predictions: 0.007 (10/1531 validation samples)
Accuracy on confident predictions: 0.2000
Model's true validation accuracy: 0.4161
