# Probe training

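Load the activations and labels from Hugging Face (HF), mean-aggregate them over tokens when using the mean probe, and split them into train, validation, and test datasets. (sklearn-based probes don't need the validation split, but the torch probes used here report a validation loss during training.)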

In [1]:
import probe_gen.probes as probes
from sklearn.metrics import classification_report
from probe_gen.config import ConfigDict

# Experiment configuration: which model layer's activations to probe, and which probe to train.
layer = 12
# Supported probe types:
#   "mean_torch"      - activations are mean-aggregated over tokens, then a linear probe is trained
#   "attention_torch" - an attention probe is trained on the full per-token activations
probe_type = "attention_torch"
if probe_type not in ("mean_torch", "attention_torch"):
    # Fail here rather than later with a confusing error when the probe is constructed.
    raise ValueError(f"Unknown probe_type {probe_type!r}; expected 'mean_torch' or 'attention_torch'")

# Load per-token activations (printed shape: samples x tokens x hidden) and labels for this layer.
activations_tensor, attention_mask, labels_tensor = probes.load_hf_activations_and_labels_at_layer(
    "refusal_llama_3b_5k", layer=layer, verbose=True
)
# The linear probe needs one vector per sample, so collapse the token axis using the attention mask.
if probe_type == "mean_torch":
    activations_tensor = probes.MeanAggregation()(activations_tensor, attention_mask)
# NOTE(review): attention_mask is not passed to create_activation_datasets; confirm the attention
# probe handles padded token positions on its own.

# 70/10/20 train/val/test split with class balancing (see the printed sample/positive counts).
train_dataset, val_dataset, test_dataset = probes.create_activation_datasets(
    activations_tensor, labels_tensor, val_size=0.1, test_size=0.2, balance=True, verbose=True
)

loaded labels
loaded activations with shape torch.Size([5000, 337, 3072])
calculated attention mask with shape torch.Size([5000, 337])
Train: 3500 samples, 1750.0 positives
Val:   500 samples, 250.0 positives
Test:  1000 samples, 500.0 positives


Create the probe selected by `probe_type`, fit it on the training split, and report ROC AUC on the validation split.

In [2]:
# Initialise the probe that matches `probe_type` and fit it on the training split; the validation
# split is used for the per-epoch validation loss printed during training.
# These hyperparameters match the best params found by the hyperparameter search further below.
cfg = ConfigDict(use_bias=True, normalize=True, lr=0.0001, weight_decay=0.0)
if probe_type == "mean_torch":
    probe = probes.TorchLinearProbe(cfg)
elif probe_type == "attention_torch":
    probe = probes.TorchAttentionProbe(cfg)
else:
    # Without this, an unrecognised probe_type leaves `probe` undefined on a fresh kernel, or
    # silently refits a probe of the wrong type left over from an earlier run of this cell.
    raise ValueError(f"Unknown probe_type: {probe_type!r}")
probe.fit(train_dataset, val_dataset)

# Print val results
eval_dict, y_pred, y_pred_proba = probe.eval(val_dataset)
print('\nroc_auc:', eval_dict['roc_auc'])


Epoch 10/100, Train Loss: 0.3058, Val Loss: 0.3383
Epoch 20/100, Train Loss: 0.2619, Val Loss: 0.3300

roc_auc: 0.9354560000000001


Evaluate the probe on the held-out test split.

In [3]:
# Score the fitted probe on the test split carved out of refusal_llama_3b_5k above:
# first the summary metrics dict, then sklearn's per-class precision/recall table.
eval_dict, y_pred, y_pred_proba = probe.eval(test_dataset)
print(eval_dict)
test_report = classification_report(test_dataset['y'], y_pred)
print(test_report)

{'accuracy': 0.876, 'roc_auc': 0.934068, 'tpr_at_1_fpr': np.float64(0.3)}
              precision    recall  f1-score   support

         0.0       0.88      0.87      0.88       500
         1.0       0.87      0.88      0.88       500

    accuracy                           0.88      1000
   macro avg       0.88      0.88      0.88      1000
weighted avg       0.88      0.88      0.88      1000



In [4]:
# Load a separate test dataset (refusal_llama_3b_1k) and evaluate the already-trained probe on it.
# val_size=0.0 / test_size=1.0 puts every sample in the test split (class-balanced).
# A distinct name is used so the test split of the training data (`test_dataset`) is not
# overwritten; otherwise re-running the test-set evaluation cell would silently score this
# dataset instead.
activations_tensor, attention_mask, labels_tensor = probes.load_hf_activations_and_labels_at_layer(
    "refusal_llama_3b_1k", layer=layer, verbose=True
)
if probe_type == "mean_torch":
    activations_tensor = probes.MeanAggregation()(activations_tensor, attention_mask)
_, _, heldout_test_dataset = probes.create_activation_datasets(
    activations_tensor, labels_tensor, val_size=0.0, test_size=1.0, balance=True, verbose=True
)

# Evaluate the model
eval_dict, y_pred, y_pred_proba = probe.eval(heldout_test_dataset)
print(eval_dict)
print(classification_report(heldout_test_dataset['y'], y_pred))

loaded labels


loaded activations with shape torch.Size([1000, 249, 3072])
calculated attention mask with shape torch.Size([1000, 249])
Train: 0 samples, 0.0 positives
Val:   0 samples, 0.0 positives
Test:  1000 samples, 500.0 positives
{'accuracy': 0.858, 'roc_auc': 0.91186, 'tpr_at_1_fpr': np.float64(0.104)}
              precision    recall  f1-score   support

         0.0       0.88      0.83      0.85       500
         1.0       0.84      0.88      0.86       500

    accuracy                           0.86      1000
   macro avg       0.86      0.86      0.86      1000
weighted avg       0.86      0.86      0.86      1000



In [5]:
# Load a separate test dataset (refusal_llama_3b_prompted_1k) and evaluate the already-trained
# probe on it. val_size=0.0 / test_size=1.0 puts every sample in the test split (class-balanced).
# A distinct name is used so the test split of the training data (`test_dataset`) is not
# overwritten; otherwise re-running the test-set evaluation cell would silently score this
# dataset instead.
activations_tensor, attention_mask, labels_tensor = probes.load_hf_activations_and_labels_at_layer(
    "refusal_llama_3b_prompted_1k", layer=layer, verbose=True
)
if probe_type == "mean_torch":
    activations_tensor = probes.MeanAggregation()(activations_tensor, attention_mask)
_, _, heldout_test_dataset = probes.create_activation_datasets(
    activations_tensor, labels_tensor, val_size=0.0, test_size=1.0, balance=True, verbose=True
)

# Evaluate the model
eval_dict, y_pred, y_pred_proba = probe.eval(heldout_test_dataset)
print(eval_dict)
print(classification_report(heldout_test_dataset['y'], y_pred))

loaded labels


loaded activations with shape torch.Size([1000, 249, 3072])
calculated attention mask with shape torch.Size([1000, 249])
Train: 0 samples, 0.0 positives
Val:   0 samples, 0.0 positives
Test:  1000 samples, 500.0 positives
{'accuracy': 0.857, 'roc_auc': 0.90416, 'tpr_at_1_fpr': np.float64(0.058)}
              precision    recall  f1-score   support

         0.0       0.90      0.80      0.85       500
         1.0       0.82      0.91      0.86       500

    accuracy                           0.86      1000
   macro avg       0.86      0.86      0.86      1000
weighted avg       0.86      0.86      0.86      1000



In [6]:
# Load a separate test dataset (refusal_ministral_8b_1k) and evaluate the already-trained probe on it.
# val_size=0.0 / test_size=1.0 puts every sample in the test split (class-balanced).
# NOTE(review): the printed activation width (3072) matches the Llama 3B data, so these are
# presumably Llama 3B activations on Ministral-8B-generated text; confirm.
# A distinct name is used so the test split of the training data (`test_dataset`) is not
# overwritten; otherwise re-running the test-set evaluation cell would silently score this
# dataset instead.
activations_tensor, attention_mask, labels_tensor = probes.load_hf_activations_and_labels_at_layer(
    "refusal_ministral_8b_1k", layer=layer, verbose=True
)
if probe_type == "mean_torch":
    activations_tensor = probes.MeanAggregation()(activations_tensor, attention_mask)
_, _, heldout_test_dataset = probes.create_activation_datasets(
    activations_tensor, labels_tensor, val_size=0.0, test_size=1.0, balance=True, verbose=True
)

# Evaluate the model
eval_dict, y_pred, y_pred_proba = probe.eval(heldout_test_dataset)
print(eval_dict)
print(classification_report(heldout_test_dataset['y'], y_pred))

loaded labels
loaded activations with shape torch.Size([1000, 248, 3072])
calculated attention mask with shape torch.Size([1000, 248])
Train: 0 samples, 0.0 positives
Val:   0 samples, 0.0 positives
Test:  1000 samples, 500.0 positives
{'accuracy': 0.851, 'roc_auc': 0.912976, 'tpr_at_1_fpr': np.float64(0.112)}
              precision    recall  f1-score   support

         0.0       0.87      0.83      0.85       500
         1.0       0.83      0.88      0.85       500

    accuracy                           0.85      1000
   macro avg       0.85      0.85      0.85      1000
weighted avg       0.85      0.85      0.85      1000



# Hyperparameter Search

In [1]:
from probe_gen.standard_experiments.hyperparameter_search import run_full_hyp_search_on_layers

# This call searches all listed layers at once. If that is not feasible, split the layer list
# across several calls (e.g. [6, 9, 12] and then [15, 18, 21]).
# wandb cannot detect this notebook's name (see the error in the output); set the
# WANDB_NOTEBOOK_NAME environment variable if code saving is wanted.
run_full_hyp_search_on_layers(
    'attention_torch',        # probe type
    'refusal_llama_3b_5k',    # activations/labels dataset
    'llama_3b',               # presumably the model identifier
    [6, 9, 12, 15, 18, 21],   # layers to search
)


######################### Evaluating layer 6 #############################

Epoch 10/100, Train Loss: 0.3491, Val Loss: 0.3851
Epoch 20/100, Train Loss: 0.3133, Val Loss: 0.3706
Epoch 30/100, Train Loss: 0.2890, Val Loss: 0.3668
Epoch 40/100, Train Loss: 0.2738, Val Loss: 0.3644
Epoch 50/100, Train Loss: 0.2606, Val Loss: 0.3648


[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.



Epoch 10/100, Train Loss: 0.3491, Val Loss: 0.3851
Epoch 20/100, Train Loss: 0.3133, Val Loss: 0.3706
Epoch 30/100, Train Loss: 0.2890, Val Loss: 0.3668
Epoch 40/100, Train Loss: 0.2738, Val Loss: 0.3644
Epoch 50/100, Train Loss: 0.2606, Val Loss: 0.3648

Epoch 10/100, Train Loss: 0.3491, Val Loss: 0.3851
Epoch 20/100, Train Loss: 0.3133, Val Loss: 0.3706
Epoch 30/100, Train Loss: 0.2890, Val Loss: 0.3668
Epoch 40/100, Train Loss: 0.2738, Val Loss: 0.3644
Epoch 50/100, Train Loss: 0.2606, Val Loss: 0.3648

Epoch 10/100, Train Loss: 0.3491, Val Loss: 0.3851
Epoch 20/100, Train Loss: 0.3133, Val Loss: 0.3706
Epoch 30/100, Train Loss: 0.2890, Val Loss: 0.3668
Epoch 40/100, Train Loss: 0.2739, Val Loss: 0.3644
Epoch 50/100, Train Loss: 0.2607, Val Loss: 0.3648

Epoch 10/100, Train Loss: 0.3491, Val Loss: 0.3851
Epoch 20/100, Train Loss: 0.3136, Val Loss: 0.3706
Epoch 30/100, Train Loss: 0.2895, Val Loss: 0.3668
Epoch 40/100, Train Loss: 0.2746, Val Loss: 0.3644
Epoch 50/100, Train Loss: 0

In [3]:
from probe_gen.standard_experiments.hyperparameter_search import load_best_params_from_search

# Report the best hyperparameters found by the search above. The arguments mirror the
# search call, so pass the same probe type, dataset, model and layer list used there.
load_best_params_from_search(
    'attention_torch',        # probe type
    'refusal_llama_3b_5k',    # activations/labels dataset
    'llama_3b',               # presumably the model identifier
    [6, 9, 12, 15, 18, 21],   # layers that were searched
)

Best roc_auc: 0.928752
Best params: {'layer': 12, 'lr': 0.0001, 'use_bias': True, 'normalize': True, 'weight_decay': 0.0}
