A Python library for implementing neural controllers with decoder-only Large Language Models (LLMs), as described in our paper. Our API allows you to steer the output of language models toward desired concepts and generate lightweight detectors for arbitrary pre-defined concepts. The approach can be implemented with any decoder-only LLM, with demonstrated success on models like instruction-tuned Llama-3.1-8B, Llama-3.3-70B, and Gemma-2-9B.
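As a rough illustration of what steering means mechanically, the toy NumPy sketch below adds a scaled, unit-norm concept vector to the hidden states of selected layers during a forward pass. It is not the library's implementation. The residual-stack "model", the `forward` function and the random directions are hypothetical, and the update `h + control_coef * v / ||v||` is an assumption about how a control coefficient might be applied. The parameter names `layers_to_control` and `control_coef` only mirror the arguments of `controller.generate` by name.

```python
import numpy as np


def forward(x, weights, directions=None, layers_to_control=(), control_coef=0.0):
    """Toy residual stack that optionally steers selected layers (hypothetical sketch).

    weights: list of (d, d) matrices, one per block.
    directions: dict mapping a negative layer index (-1 = last block) to a (d,) vector.
    """
    n_layers = len(weights)
    h = x
    for i, W in enumerate(weights):
        h = h + np.tanh(h @ W)      # toy residual block
        layer = i - n_layers        # negative indexing: the last block is -1
        if layer in layers_to_control:
            v = directions[layer]
            # Assumed steering rule: add the scaled unit-norm direction at every token position
            h = h + control_coef * v / np.linalg.norm(v)
    return h


rng = np.random.default_rng(0)
d, n_layers = 8, 4
weights = [rng.normal(scale=0.1, size=(d, d)) for _ in range(n_layers)]
directions = {layer: rng.normal(size=d) for layer in range(-1, -n_layers - 1, -1)}
x = rng.normal(size=(3, d))  # 3 token positions, hidden size d

base = forward(x, weights)
steered_last = forward(x, weights, directions, layers_to_control=[-1], control_coef=0.5)
steered_many = forward(x, weights, directions, layers_to_control=[-1, -2, -3], control_coef=0.5)
```

Steering only the last layer shifts the output by exactly the scaled direction. When earlier layers are steered, the shift passes through every later block, so the effects of controlling several layers compound rather than simply add.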
We choose Recursive Feature Machines (RFMs) as our nonlinear predictor at every layer and, often, as our aggregation model. RFMs are simple, lightweight kernel machines. Our aggregation technique can also be combined with other baseline predictors, including linear/logistic probing and contrastive methods such as PCA and difference-in-means. The RFM library can be installed from the xRFM GitHub repository. The xRFM repo has since been updated, but the version of xRFM with MSE and AUC-based metrics used here can be installed with:

```bash
pip install git+https://github.com/dmbeaglehole/xRFM.git@773fae8
```
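For intuition about what an RFM does, here is a minimal NumPy sketch of the basic RFM iteration. It fits Laplace-kernel ridge regression under a learned Mahalanobis metric M, replaces M with the average gradient outer product (AGOP) of the fitted predictor, and repeats. It does not use the xRFM API. The function names, the bandwidth/regularization defaults and the trace normalization of M are illustrative assumptions. Reading a candidate concept direction off the top eigenvector of M is also an assumption in this sketch. A difference-in-means direction, one of the baselines named above, is computed for comparison.

```python
import numpy as np


def laplace_kernel(X, Z, M, bandwidth):
    """Return K[i, j] = exp(-||x_i - z_j||_M / bandwidth) and the distances."""
    sq = ((X @ M) * X).sum(1)[:, None] + ((Z @ M) * Z).sum(1)[None, :] - 2 * (X @ M) @ Z.T
    dist = np.sqrt(np.clip(sq, 0.0, None))
    return np.exp(-dist / bandwidth), dist


def rfm_fit(X, y, iters=3, bandwidth=10.0, reg=1e-3):
    """Minimal RFM sketch: kernel ridge regression alternated with AGOP updates."""
    n, d = X.shape
    M = np.eye(d)
    for _ in range(iters):
        K, dist = laplace_kernel(X, X, M, bandwidth)
        alpha = np.linalg.solve(K + reg * np.eye(n), y)
        # Gradient of f(x) = sum_i alpha_i K(x, x_i) at each training point x_j:
        # grad_j = -(1 / bandwidth) * M @ sum_i alpha_i K_ji (x_j - x_i) / dist_ji
        with np.errstate(divide="ignore", invalid="ignore"):
            W = np.where(dist > 0, K / dist, 0.0) * alpha[None, :]
        np.fill_diagonal(W, 0.0)  # drop the self term, where x_j - x_j = 0
        diffs = X * W.sum(1, keepdims=True) - W @ X
        grads = -(diffs @ M) / bandwidth  # row j is the gradient at x_j (M is symmetric)
        M = grads.T @ grads / n           # average gradient outer product (AGOP)
        M *= d / np.trace(M)              # rescale so distances stay comparable (assumption)
    K, _ = laplace_kernel(X, X, M, bandwidth)
    alpha = np.linalg.solve(K + reg * np.eye(n), y)
    return M, alpha


def top_eigenvector(M):
    """Unit-norm eigenvector of the largest eigenvalue (np.linalg.eigh sorts ascending)."""
    _, eigvecs = np.linalg.eigh(M)
    return eigvecs[:, -1]


def difference_in_means(X, y):
    """Baseline: normalized difference between the class-conditional means."""
    v = X[y == 1].mean(0) - X[y == 0].mean(0)
    return v / np.linalg.norm(v)


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 10))    # stand-in for one layer's activations
y = (X[:, 0] > 0).astype(float)   # synthetic label that depends only on coordinate 0
M, alpha = rfm_fit(X, y)
rfm_direction = top_eigenvector(M)
dim_direction = difference_in_means(X, y)
```

On this synthetic task, one would expect both directions to load mainly on the first coordinate. Comparing `abs(rfm_direction[0])` with `abs(dim_direction[0])` is a quick check.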
See the notebooks folder for examples of steering:
- Style transfer capabilities (e.g., English to Shakespearean or Poetic)
- Language transfer capabilities (e.g., English to Spanish, Mandarin to English)
- Harmful steering (exposing social security numbers)

Requirements:

- Python 3.10.15
- PyTorch 2.4.0+cu118
- Transformers 4.47.0
- Datasets 3.1.0
- NumPy 1.26.4
- tqdm
- torchmetrics
- scikit-learn
- xRFM (previous commit 773fae8 of https://github.com/dmbeaglehole/xRFM)
- Access to decoder-only LLM weights, such as Llama-3.1-8B-it and Gemma-2-9B-it.

Figure: Methodology for (B) steering and (C) detecting concepts in language models by aggregating layer-wise predictors. Examples include harmfulness, Shakespearean/Poetic English, and dishonesty.
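Here is one simple way to picture the detection side of this methodology. Each layer's predictor produces a score per prompt, and a second-stage model is fit on the stacked per-layer scores. The NumPy sketch below uses ordinary least squares as a stand-in aggregator for brevity. It is not the RFM aggregator described above. The synthetic scores, the hypothetical `fit_aggregator`/`predict_aggregated` helpers and the 0.5 threshold are illustrative assumptions.

```python
import numpy as np


def fit_aggregator(layer_scores, labels):
    """Least-squares weights (plus intercept) over per-layer scores of shape (n, n_layers)."""
    design = np.hstack([layer_scores, np.ones((layer_scores.shape[0], 1))])
    weights, *_ = np.linalg.lstsq(design, labels, rcond=None)
    return weights


def predict_aggregated(layer_scores, weights, threshold=0.5):
    """Combine per-layer scores into one aggregated score and a binary prediction."""
    design = np.hstack([layer_scores, np.ones((layer_scores.shape[0], 1))])
    scores = design @ weights
    return scores, (scores > threshold).astype(int)


rng = np.random.default_rng(0)
n, n_layers = 100, 4
labels = rng.integers(0, 2, size=n).astype(float)
# Synthetic per-layer scores: column 0 (standing for layer -1) is informative,
# the remaining columns are pure noise
layer_scores = rng.normal(size=(n, n_layers))
layer_scores[:, 0] = labels + 0.01 * rng.normal(size=n)

weights = fit_aggregator(layer_scores, labels)
agg_scores, preds = predict_aggregated(layer_scores, weights)
accuracy = (preds == labels).mean()
```

The aggregator learns to rely on the informative layer and to ignore the noisy ones. In practice one would fit it on data held out from the per-layer predictors' training set and report metrics on a separate test split.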
To use the notebooks, you must create a data folder within the neural_controllers directory:

```bash
cd neural_controllers
mkdir data
```
For each concept, you must create a subfolder within this data folder. For Shakespeare and Spanish steering:

```bash
mkdir data/languages
```
Example English/Shakespeare and English/Spanish translation data can be found here and here. Otherwise, to generate directions within the notebooks, you must place the appropriate datasets in this directory. Datasets for the other notebooks used in the paper will be released as a benchmark.
Steering example with pre-trained directions:

```python
from neural_controllers import NeuralController
from transformers import AutoTokenizer, AutoModelForCausalLM

# Initialize tokenizer and model
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
language_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda")

# Create neural controller
controller = NeuralController(
    language_model,
    tokenizer,
    rfm_iters=8,
    batch_size=2,
    n_components=5,
    control_method='rfm'
)

# Load pre-trained directions
controller.load(concept='shakespeare',
                model_name='llama_3_8b_it',
                path='../directions/')

# Generate controlled text, steering layers -1 through -30
prompt = controller.format_prompt("What can I do to treat flu symptoms?")
controlled_output = controller.generate(
    prompt,
    layers_to_control=list(range(-1, -31, -1)),
    control_coef=0.5,
    max_new_tokens=150
)
print(controlled_output)
```

Detection example: training and evaluating a toxicity detector on ToxicChat [1]:

```python
from neural_controllers import NeuralController
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset

# Initialize tokenizer and model
model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
language_model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda")

# Create neural controller
controller = NeuralController(
    language_model,
    tokenizer,
    rfm_iters=8,
    batch_size=2,
    n_components=5,
    control_method='rfm'
)

def get_data(controller):
    # Load the ToxicChat dataset
    ds = load_dataset("lmsys/toxic-chat", "toxicchat1123")

    # Extract inputs
    all_train_inputs = [x['user_input'] for x in ds['train']]
    test_inputs = [x['user_input'] for x in ds['test']]

    # Split the original training inputs in half: train / validation
    n = len(all_train_inputs)
    train_inputs, val_inputs = all_train_inputs[:n//2], all_train_inputs[n//2:]

    # Format prompts using the controller
    val_inputs = [controller.format_prompt(x) for x in val_inputs]
    train_inputs = [controller.format_prompt(x) for x in train_inputs]
    test_inputs = [controller.format_prompt(x) for x in test_inputs]

    # Extract labels
    all_train_labels = [x['toxicity'] for x in ds['train']]
    test_labels = [x['toxicity'] for x in ds['test']]

    # Split the training labels the same way as the inputs
    train_labels, val_labels = all_train_labels[:n//2], all_train_labels[n//2:]

    return train_inputs, train_labels, val_inputs, val_labels, test_inputs, test_labels

train_inputs, train_labels, val_inputs, val_labels, test_inputs, test_labels = get_data(controller)

# Compute directions from the training split
controller.compute_directions(train_inputs, train_labels)

# Evaluate per-layer and aggregated detectors on the validation and test splits
val_metrics, test_metrics, _ = controller.evaluate_directions(
    train_inputs, train_labels,
    val_inputs, val_labels,
    test_inputs, test_labels,
)
```

The validation and test metrics are structured as nested dictionaries, with one key per layer (counting backwards from the final layer, indexed -1), plus an 'aggregated' key holding the scores of the aggregated detector. For example, for Llama-3.1-8B (layer keys -1 through -31):
```python
val_metrics = {
    -1: {
        'auc': float,        # AUROC score
        'acc': float,        # Accuracy score
        'f1 score': float,   # F1 score
        'recall': float,     # Recall score
        'precision': float   # Precision score
    },
    -2: {
        'auc': float,
        'acc': float,
        'f1 score': float,
        'recall': float,
        'precision': float
    },
    # ... continues through layer -31
    'aggregated': {
        'auc': float,        # Aggregated AUROC across layers
        'acc': float,        # Aggregated accuracy
        'f1 score': float,   # Aggregated F1 score
        'recall': float,     # Aggregated recall
        'precision': float   # Aggregated precision
    }
}
```

If you find this work useful in your research, please consider citing:
```bibtex
@misc{beaglehole2025universalsteeringmonitoringai,
      title={Toward universal steering and monitoring of AI models},
      author={Daniel Beaglehole and Adityanarayanan Radhakrishnan and Enric Boix-Adserà and Mikhail Belkin},
      year={2025},
      eprint={2502.03708},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2502.03708},
}
```

[1] Lin, Z., Wang, Z., Tong, Y., Wang, Y., Guo, Y., Wang, Y., & Shang, J. (2023). ToxicChat: Unveiling Hidden Challenges of Toxicity Detection in Real-World User-AI Conversation. arXiv preprint arXiv:2310.17389.