# Probe Training and Evaluation

In [None]:
import probe_gen.probes as probes
from sklearn.metrics import classification_report
from probe_gen.config import ConfigDict

# Which probe to train. Names without 'torch' are sklearn probes, which do not
# use a validation set when fitting.
# NOTE: mean_torch may not be fully implemented, so prefer one of the other two.
probe_type = ["mean", "attention_torch", "mean_torch"][1]
# Synthetic training set
dataset_name = "deception_rp_mistral_7b_3k"

# Hyperparameters can be set by hand, or the best ones found earlier can be loaded.
# cfg = ConfigDict(layer=12, use_bias=True, normalize=True, c=0.001)
cfg = ConfigDict.from_json(probe_type, dataset_name)

# Load the activations at the configured layer and build train/val/test datasets.
activations_tensor, attention_mask, labels_tensor = probes.load_hf_activations_and_labels_at_layer(
    dataset_name, layer=cfg.layer, verbose=True
)
if "mean" in probe_type:
    # Mean probes work on activations aggregated with the attention mask.
    activations_tensor = probes.MeanAggregation()(activations_tensor, attention_mask)
# The test split is requested with size 0 here; test sets are loaded separately
# from other datasets in the evaluation cell.
train_dataset, val_dataset, test_dataset = probes.create_activation_datasets(
    activations_tensor, labels_tensor, splits=[2500, 500, 0], verbose=True
)

# Initialise the probe matching probe_type and fit it.
if probe_type == "mean":
    probe = probes.SklearnLogisticProbe(cfg)
elif probe_type == "mean_torch":
    probe = probes.TorchLinearProbe(cfg)
elif probe_type == "attention_torch":
    probe = probes.TorchAttentionProbe(cfg)
else:
    # Fail loudly here instead of hitting a NameError on `probe` below.
    raise ValueError(f"Unknown probe_type: {probe_type!r}")
probe.fit(train_dataset, val_dataset)

# Print validation results
eval_dict, y_pred, y_pred_proba = probe.eval(val_dataset)
print('\nroc_auc:', eval_dict['roc_auc'])

In [None]:
# Evaluate the fitted probe on the synthetic test set and on the real test set.
# The goal is to reduce the gap between them and make the real one
# ("deception_rp_llama_3b_500") perform better.
eval_dataset_names = ("deception_rp_mistral_7b_500", "deception_rp_llama_3b_500")
for new_dataset_name in eval_dataset_names:
    print(f"\nEvaluating on {new_dataset_name}")

    # Load activations at the same layer the probe was configured with.
    activations_tensor, attention_mask, labels_tensor = probes.load_hf_activations_and_labels_at_layer(
        new_dataset_name,
        layer=cfg.layer,
        verbose=True,
    )
    if "mean" in probe_type:
        # Same aggregation step as applied to the training activations.
        activations_tensor = probes.MeanAggregation()(activations_tensor, attention_mask)

    # Train and val sizes are 0, so only a test split is requested.
    _, _, test_dataset = probes.create_activation_datasets(
        activations_tensor,
        labels_tensor,
        splits=[0, 0, 1000],
        verbose=True,
    )

    # Score the probe and show the summary metrics plus a per-class report.
    eval_dict, y_pred, y_pred_proba = probe.eval(test_dataset)
    print(eval_dict)
    report = classification_report(test_dataset['y'], y_pred)
    print(report)

In [None]:
# Visualise how the probe's predicted probabilities split the two classes.
# test_dataset and y_pred_proba hold whatever the most recently run cell assigned;
# straight after the evaluation loop that is the last dataset in its list.
from probe_gen.standard_experiments.experiment_plotting import plot_per_class_prediction_distributions

true_labels = test_dataset['y']
plot_per_class_prediction_distributions(true_labels, y_pred_proba)

## Hyperparameter Search

In [None]:
from probe_gen.standard_experiments.hyperparameter_search import run_full_hyp_search_on_layers
from probe_gen.standard_experiments.hyperparameter_search import load_best_params_from_search
from probe_gen.config import ConfigDict

# Search settings. These reassign probe_type and dataset_name if the training
# cell was run earlier in the same session.
probe_type = "attention_torch"
dataset_name = "deception_rp_mistral_7b_3k"
# NOTE(review): model_name differs from the model named in dataset_name;
# presumably it is the model the activations come from. Confirm.
model_name = "llama_3b"

# Layers to search: 6, 9, 12, 15, 18, 21. All layers may not fit in one run, so
# they can be done in batches, or use only layer 12 for faster results.
layer_list = list(range(6, 22, 3))
run_full_hyp_search_on_layers(probe_type, dataset_name, model_name, layer_list)