This repository contains the implementation of the paper "One Head to Rule Them All: Amplifying LVLM Safety through a Single Critical Attention Head", which proposes a novel approach to enhancing the safety of Large Vision-Language Models (LVLMs) by identifying and leveraging the attention heads that are critical for safety.
LVLMs have demonstrated impressive capabilities in multimodal understanding tasks, but they often exhibit degraded safety alignment compared to text-only LLMs. This project introduces a method to amplify LVLM safety by:
- Investigating internal multi-head attention mechanisms
- Identifying critical "safety" attention heads
- Measuring deflection angles of hidden states to efficiently discriminate between safe and unsafe inputs (see the sketch below)
- Implementing a defense mechanism that achieves near-perfect detection of harmful inputs while maintaining low false positive rates
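The detection signal is a deflection angle computed between hidden-state vectors; which two vectors are compared (for example, the hidden state with and without the critical head's contribution) is defined by the paper's method. The snippet below is only a minimal sketch of the angle computation itself, using random tensors in place of real hidden states:

import torch
import torch.nn.functional as F

def deflection_angle(h1: torch.Tensor, h2: torch.Tensor) -> float:
    """Angle in degrees between two hidden-state vectors."""
    cos = F.cosine_similarity(h1, h2, dim=-1).clamp(-1.0, 1.0)
    return torch.rad2deg(torch.arccos(cos)).item()

# Random tensors stand in for hidden states extracted from the LVLM.
h_a = torch.randn(4096)
h_b = torch.randn(4096)
print(f"deflection angle: {deflection_angle(h_a, h_b):.2f} degrees")

In the pipeline below, angles like this are compared against a threshold calibrated with utils/threshold.py to decide whether an input is flagged as unsafe.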
The implementation requires:
- PyTorch
- Transformers
- LVLMs (LLaVA-v1.5-7B, Qwen2-VL-7B-Instruct, Aya-Vision-8B, Phi-3.5-Vision, etc.)
# Install required dependencies
pip install -r requirements.txt
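A quick, optional sanity check that the core dependencies listed above are importable and that a GPU is visible:

import torch
import transformers

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("transformers:", transformers.__version__)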
The code works with several datasets:
- VLSafe: Contains harmful image-text pairs
- LLaVA-Instruct-80K: Used as the safe dataset
- ShareGPT4V: Used as the safe dataset for testing
- JailbreakV-28K: Contains various jailbreak attack scenarios
Prepare your datasets in the following directory structure:
data/
├── JailBreakV_28K/
│   └── JailBreakV_28K.csv
├── ShareGPT4V/
│   └── sharegpt4v.csv
├── VLSafe/
│   └── vlsafe.csv
└── LLaVA-Instruct-80K/
    └── safe.csv
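If the folders do not exist yet, a small helper such as the following (not part of the repository) creates the expected skeleton and reports missing CSV files; the CSVs themselves still need to be downloaded or generated separately:

from pathlib import Path

# Expected dataset files, mirroring the tree above.
csv_paths = [
    "data/JailBreakV_28K/JailBreakV_28K.csv",
    "data/ShareGPT4V/sharegpt4v.csv",
    "data/VLSafe/vlsafe.csv",
    "data/LLaVA-Instruct-80K/safe.csv",
]
for csv in csv_paths:
    path = Path(csv)
    path.parent.mkdir(parents=True, exist_ok=True)  # create data/<dataset>/
    if not path.exists():
        print(f"missing: {csv}")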
The repository is organized into several modules:
- attack/: Code for testing LVLMs with unsafe inputs
- eval/: Evaluation scripts using LLaMAGuard and Attack Success Rate calculation
- detect/: Implementation of the detection mechanism
- defense/: Implementation of the defense mechanism
- head/: Code for identifying and analyzing safety-critical attention heads
- utils/: Utility scripts for threshold determination and safe head selection
Test how LVLMs respond to unsafe inputs:
python attack/attack_vlm.py --model_path /path/to/models/llava-v1.5-7b \
--image_path /path/to/train2017 \
--csv_path data/VLSafe/vlsafe.csv
Evaluate responses using LLaMAGuard:
python eval/eval.py --data_path ./results/attack/VLSafe/LLaVA/vlsafe.csv
Calculate Attack Success Rate:
python eval/asr.py
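Attack Success Rate is the fraction of generated responses that the judge labels unsafe. The snippet below is a sketch of that calculation rather than eval/asr.py itself; the CSV path and the label column name are assumptions about the evaluation output format:

import pandas as pd

# Hypothetical evaluation output; the actual file and column names may differ.
df = pd.read_csv("results/attack/VLSafe/LLaVA/vlsafe.csv")
asr = (df["label"] == "unsafe").mean()  # fraction of responses judged unsafe by LLaMAGuard
print(f"Attack Success Rate: {asr:.2%}")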
Identify attention heads that are critical for safety:
python head/head_llava.py \
--model_path /path/to/models/llava-v1.5-7b \
--image_path /path/to/train2017 \
--csv_path data/VLSafe/vlsafe.csv
python head/head_llava.py \
--model_path /path/to/models/llava-v1.5-7b \
--image_path /path/to/train2017 \
--csv_path data/LLaVA-Instruct-80K/safe.csv
Search for the optimal safety attention heads:
python utils/search_safe_head.py
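Conceptually, the search scores each (layer, head) pair by how well its signal separates the unsafe statistics (VLSafe) from the safe statistics (LLaVA-Instruct-80K) and keeps the best-separating head. The sketch below assumes per-head scores are already available; the scoring and file formats used by utils/search_safe_head.py may differ:

import numpy as np

def separation(unsafe_scores: np.ndarray, safe_scores: np.ndarray) -> float:
    """Simple separation measure between two score distributions (higher is better)."""
    return abs(unsafe_scores.mean() - safe_scores.mean()) / (unsafe_scores.std() + safe_scores.std() + 1e-8)

# per_head_scores[(layer, head)] -> (scores on unsafe data, scores on safe data); illustrative values.
per_head_scores = {
    (8, 2): (np.random.normal(3.0, 0.5, 500), np.random.normal(1.0, 0.5, 500)),
    (5, 7): (np.random.normal(1.6, 0.5, 500), np.random.normal(1.4, 0.5, 500)),
}
best = max(per_head_scores, key=lambda k: separation(*per_head_scores[k]))
print("best safety head (layer, head):", best)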
Calculate deflection angles for different datasets:
# Calculate deflection angles for unsafe dataset
python detect/detect_llava.py \
--model_path /path/to/models/llava-v1.5-7b \
--image_path /path/to/train2017 \
--csv_path data/VLSafe/vlsafe.csv \
--hidden_layer -1 \
--safe_heads [[8,2]]
# Calculate deflection angles for safe dataset
python detect/detect_llava.py \
--model_path /path/to/models/llava-v1.5-7b \
--image_path /path/to/train2017 \
--csv_path data/LLaVA-Instruct-80K/safe.csv \
--hidden_layer -1 \
--safe_heads [[8,2]]
Determine the optimal threshold:
python utils/threshold.py \
--file1 results/detect/llava-v1.5-7b/vlsafe_layer-1/defense_results.csv \
--file2 results/detect/llava-v1.5-7b/safe_layer-1/defense_results.csv
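Conceptually, the threshold script looks for the angle value that best separates the unsafe-set scores (file1) from the safe-set scores (file2). A minimal sketch of such a sweep, assuming each defense_results.csv exposes a per-sample angle column and that larger angles indicate unsafe inputs (both assumptions about the actual output format):

import numpy as np
import pandas as pd

unsafe = pd.read_csv("results/detect/llava-v1.5-7b/vlsafe_layer-1/defense_results.csv")["angle"].to_numpy()
safe = pd.read_csv("results/detect/llava-v1.5-7b/safe_layer-1/defense_results.csv")["angle"].to_numpy()

# Sweep candidate thresholds and keep the one with the best balanced accuracy.
candidates = np.linspace(min(unsafe.min(), safe.min()), max(unsafe.max(), safe.max()), 1000)
def balanced_accuracy(t):
    return 0.5 * ((unsafe > t).mean() + (safe <= t).mean())
best = max(candidates, key=balanced_accuracy)
print(f"suggested threshold: {best:.2f}")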
Detect potentially harmful inputs using identified safety heads:
python detect/detect_llava.py \
--model_path /path/to/models/llava-v1.5-7b \
--image_path /path/to/train2017 \
--csv_path data/VLSafe/vlsafe.csv \
--hidden_layer -1 \
--safe_heads [[8,2]] \
--threshold 2.16
Implement the defense mechanism to prevent harmful outputs:
python defense/defense_llava.py \
--model_path /path/to/models/llava-v1.5-7b \
--image_path /path/to/train2017 \
--csv_path data/VLSafe/vlsafe.csv \
--hidden_layer -1 \
--safe_heads [[8,2]] \
--threshold 2.16
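At inference time the defense reduces to a simple gate: if the detector flags the input, the model refuses instead of answering. A minimal sketch of that control flow, with detect_input and generate_response standing in for the repository's detection and generation code, and assuming larger deflection angles indicate unsafe inputs:

REFUSAL = "I'm sorry, but I can't help with that request."

def guarded_generate(image, prompt, detect_input, generate_response, threshold=2.16):
    """Refuse when the detector flags the (image, prompt) pair; otherwise answer normally."""
    angle = detect_input(image, prompt)  # deflection angle from the safety-critical head
    if angle > threshold:                # threshold calibrated with utils/threshold.py
        return REFUSAL
    return generate_response(image, prompt)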