Multi-agent Social Simulation + an Efficient, Effective, and Stable alternative to RLHF. Code for the paper "Training Socially Aligned Language Models in Simulated Human Society".

agi-templar/Stable-Alignment


Stable Alignment

Stable Alignment - Alignment Learning in Social Games

License: Apache 2.0

This is the official repo for the Stable Alignment project. We aim to provide an RLHF alternative that is superior in alignment performance, highly efficient in data learning, and easy to deploy in scaled-up settings. Instead of training an extra reward model that can be gamed during optimization, we train directly on the interaction data recorded in simulated social games. We find that high-quality data combined with a reliable algorithm is the key to stable alignment learning.

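The repo contains:

  • the code for running social simulations in the Sandbox;
  • the alignment data collected from these simulations;
  • the code for training with Stable Alignment;
  • the released So(cially)-Good language models.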

Life is a game. Play by your rules!

Overview of Stable Alignment

Sandbox Simulation

Installation

# install the development environment (Python dependencies)
pip install -r requirements.txt
# install this repo as a package in editable (development) mode
pip install -e .

Simulation Setup

  • The initial data is stored at assets/hh-rlhf/labeled_prior.jsonl and tracked with Git LFS, so Git LFS must be installed to fetch the actual file.
  • After a round of simulation, the simulated interaction data and metrics are saved to data/cache/world_<world_id>/.
  • Put your OpenAI API key in a .env file in the project root folder.
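The setup steps above can be illustrated with a small standard-library sketch. It is not code from this repository. It parses simple `KEY=VALUE` lines from a `.env` file, decodes a JSONL file one object per line, and builds the per-world cache path. The helper names are hypothetical. The key name `OPENAI_API_KEY` is also an assumption, so check how the repo actually reads its configuration.

```python
import json
from pathlib import Path
from typing import Dict, List


def parse_env(text: str) -> Dict[str, str]:
    """Parse minimal KEY=VALUE lines, as found in a simple .env file."""
    env: Dict[str, str] = {}
    for raw in text.splitlines():
        line = raw.strip()
        # skip blank lines, comments, and lines without an '='
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        # drop optional surrounding quotes around the value
        env[key.strip()] = value.strip().strip("\"'")
    return env


def parse_jsonl(text: str) -> List[dict]:
    """Decode one JSON object per non-empty line (the .jsonl convention)."""
    return [json.loads(line) for line in text.splitlines() if line.strip()]


def world_cache_dir(world_id: int, root: Path = Path("data/cache")) -> Path:
    """Per-world output folder, following data/cache/world_<world_id>/."""
    return root / f"world_{world_id}"
```

For example, you could call `parse_env(Path(".env").read_text())` and check for an `OPENAI_API_KEY` entry (assumed name). If Git LFS has not fetched `labeled_prior.jsonl`, the file is only a small text pointer stub. `parse_jsonl` will then raise a JSON decoding error.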

Run Simulation

Navigate to the project root folder and run the simulation with custom settings:

python stable_alignment/simulation.py \
    -model_type 'text-davinci-002' \
    -obs_model_type 'gpt-3.5-turbo' \
    -world_id 1 \
    -init_setting 'all_bad' \
    -n_round '2' \
    -size '4' \
    -dataset_name 'hh-rlhf'
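The simulation flags use a single leading dash. Python's `argparse` accepts multi-character options written this way and strips the dash to form the attribute name. The sketch below shows a parser that would accept the command above. It is a hypothetical mirror for illustration only. The types and help strings are assumptions, for example that `-size` is the number of social agents. The real definitions in stable_alignment/simulation.py may differ.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical mirror of the simulation flags; types/help texts are assumptions.
    parser = argparse.ArgumentParser(description="Sandbox simulation (sketch)")
    parser.add_argument("-model_type", type=str, help="LLM behind the social agents")
    parser.add_argument("-obs_model_type", type=str, help="LLM behind the observer agents")
    parser.add_argument("-world_id", type=int, help="outputs go to data/cache/world_<world_id>/")
    parser.add_argument("-init_setting", type=str, help="initial agent setting, e.g. 'all_bad'")
    parser.add_argument("-n_round", type=int, help="number of interaction rounds")
    parser.add_argument("-size", type=int, help="number of social agents (assumed)")
    parser.add_argument("-dataset_name", type=str, help="source of the initial data")
    return parser


def parse(argv: Optional[List[str]] = None) -> argparse.Namespace:
    # argparse maps "-model_type" to args.model_type, "-n_round" to args.n_round, etc.
    return build_parser().parse_args(argv)
```

The shell quotes in `-n_round '2'` are removed before Python sees the value. With `type=int`, the string "2" becomes the integer 2.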

We provide an example simulation result in assets/sample_world. It was simulated with 100 text-davinci-003-based social agents and ChatGPT-based observer agents, running for 50 rounds of interaction.

Alignment Data Release

Back-Scatter in Sandbox

The alignment data used for training is already included at assets/sandbox_v1.json and assets/sandbox_v2.json. Note that these files are sampled from the full set of interaction data at a ratio of 5:1:1 for Alignment Imitation, Self-Critic, and Realignment data, respectively. The full set of interaction data is available upon request.
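The 5:1:1 mixture can be illustrated with a short sketch. It is not the script used to build sandbox_v1.json or sandbox_v2.json. The pool names, the helper functions, and the rounding rule (leftover examples go to the first categories) are assumptions made for illustration.

```python
import random
from typing import Dict, List, Sequence


def split_by_ratio(total: int, ratio: Sequence[int]) -> List[int]:
    """Split `total` into integer counts proportional to `ratio`."""
    weight = sum(ratio)
    sizes = [total * r // weight for r in ratio]
    # integer division can leave a few examples unassigned; hand them out in order
    for i in range(total - sum(sizes)):
        sizes[i % len(sizes)] += 1
    return sizes


def sample_mixture(
    pools: Dict[str, Sequence[dict]],
    total: int,
    ratio: Dict[str, int],
    seed: int = 0,
) -> Dict[str, List[dict]]:
    """Draw a mixture from per-type pools, without replacement, at the given ratio."""
    rng = random.Random(seed)
    names = list(ratio)
    sizes = split_by_ratio(total, [ratio[n] for n in names])
    mixture: Dict[str, List[dict]] = {}
    for name, k in zip(names, sizes):
        if k > len(pools[name]):
            raise ValueError(f"pool {name!r} has only {len(pools[name])} items, need {k}")
        mixture[name] = rng.sample(list(pools[name]), k)
    return mixture
```

For example, `split_by_ratio(700, [5, 1, 1])` gives 500 Alignment Imitation, 100 Self-Critic, and 100 Realignment examples.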

The Statistics of Alignment Data (Full Set)
  • sandbox_v1.json

| Data / Social Agent Type | text-davinci-002 | text-davinci-003 | ChatGPT | Total |
|---|---|---|---|---|
| Alignment Imitation | 9.8k | 10k | 10k | 29.8k |
| Self-Critic | 17k | 20k | 20k | 57k |
| Realignment | 3.3k | 3k | 0.7k | 7k |
| Total | 30.1k | 33k | 30.7k | 93.8k |

  • sandbox_v2.json

| Data / Social Agent Type | text-davinci-002 | text-davinci-003 | GPT4 | Total |
|---|---|---|---|---|
| Alignment Imitation | 18.2k | 10.4k | 20.2k | 48.8k |
| Self-Critic | 36.3k | 18.3k | 40k | 94.6k |
| Realignment | 18.2k | 3.4k | 4.0k | 25.6k |
| Total | 72.7k | 32.1k | 64.2k | 169k |

Training with Stable Alignment

# Notes on selected flags:
#   --model_name_or_path            path to your SFT model
#   --data_path                     path to the alignment data
#   --per_device_train_batch_size   has to be 1 for alignment training
#   --fsdp                          change to "full_shard auto_wrap" if OOM
#   --model_max_length              change to a shorter length if OOM
#   --rating_scale                  scale of the ratings: 7 for 1-7, 10 for 1-10, etc.
#   --margin                        constant, see the paper
#   --max_flow                      mean or max for the penalty
#   --ratio                         controls the ratio of the penalty
torchrun --nproc_per_node=4 --master_port=36646 train_alignment.py \
      --model_name_or_path "/workspace/hhh_sft" \
      --data_path "./assets/sandbox_v1.json" \
      --bf16 True \
      --output_dir "/workspace/<your_output_lm_name>" \
      --num_train_epochs 7 \
      --per_device_train_batch_size 1 \
      --per_device_eval_batch_size 1 \
      --gradient_accumulation_steps 8 \
      --evaluation_strategy "no" \
      --save_strategy "steps" \
      --save_steps 200 \
      --save_total_limit 1 \
      --learning_rate 2e-5 \
      --weight_decay 0. \
      --warmup_ratio 0.03 \
      --lr_scheduler_type "cosine" \
      --logging_steps 1 \
      --fsdp "shard_grad_op auto_wrap" \
      --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
      --tf32 True \
      --model_max_length 360 \
      --rating_scale 7 \
      --margin 10 \
      --max_flow False \
      --ratio 0.2 \
      --num_comp 3
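The sketch below gives some intuition for the penalty-related flags (`--rating_scale`, `--margin`, `--max_flow`, `--ratio`). It is a hypothetical scalar version of a rating-aware contrastive objective. The loss is the supervised negative log-likelihood (NLL) of the highest-rated response. A hinge penalty is added whenever a lower-rated response is not sufficiently less likely. The exact formula, the normalization of rating gaps by `rating_scale`, and the function name are assumptions, not the implementation in train_alignment.py. See the paper for the actual loss. `--num_comp` is not modeled here.

```python
from typing import Sequence


def penalized_alignment_loss(
    best_nll: float,
    other_nlls: Sequence[float],
    best_rating: float,
    other_ratings: Sequence[float],
    rating_scale: int = 7,
    margin: float = 10.0,
    max_flow: bool = False,
    ratio: float = 0.2,
) -> float:
    """Hypothetical rating-aware hinge objective (not the repo's implementation).

    best_nll: negative log-likelihood of the highest-rated response.
    other_nlls / other_ratings: NLLs and ratings of lower-rated responses.
    """
    penalties = []
    for nll, rating in zip(other_nlls, other_ratings):
        # assumption: the rating gap is normalized by the rating scale
        gap = (best_rating - rating) / rating_scale
        # hinge: a low-rated response should have an NLL at least
        # margin * gap above the best response's NLL
        penalties.append(max(0.0, best_nll - nll + margin * gap))
    if not penalties:
        return best_nll
    # max_flow selects max (True) or mean (False) aggregation of the penalties
    penalty = max(penalties) if max_flow else sum(penalties) / len(penalties)
    # ratio weights the penalty against the supervised term
    return best_nll + ratio * penalty
```

Under this sketch, a larger `margin` demands a wider likelihood gap between well-rated and poorly rated responses. Raising `ratio` makes that demand weigh more against plain imitation of the best response.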

So(cially)-Good Language Model

Model Release

We have released our models on huggingface! 🤗

Released models include:

  1. better-base: a base model fine-tuned from LLaMA on AlpacaDataCleaned (a cleaned version of the Alpaca instruction-tuning dataset) and codealpaca (a code instruction-following dataset).

  2. hh-rlhf-sft: better-base, supervised fine-tuned on the socially aligned demonstrations in the Anthropic HH-RLHF dataset (the chosen responses in the dataset).

  3. socially-good-lm: a socially aligned language model trained from hh-rlhf-sft with the Stable Alignment method.

After you download the model, you can run inference with the following command:

python stable_alignment/run_inference.py \
    --model_path './models/socially-good-lm' \
    --device 'cuda:0'
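You may prefer to call the model from your own Python code. A minimal sketch using Hugging Face transformers might look like the following. It assumes the downloaded folder is a standard transformers checkpoint that `AutoModelForCausalLM` and `AutoTokenizer` can load. The plain-text prompt format, the greedy generation settings, and the helper name `generate_reply` are assumptions. run_inference.py may format prompts differently.

```python
def generate_reply(
    prompt: str,
    model_path: str = "./models/socially-good-lm",
    device: str = "cuda:0",
    max_new_tokens: int = 128,
) -> str:
    """Load a causal LM checkpoint and greedily continue `prompt` (sketch)."""
    # imported lazily so this module can be imported without torch/transformers
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
    model.to(device)
    model.eval()

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    # drop the prompt tokens and decode only the newly generated continuation
    new_tokens = output_ids[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
```

This sketch reloads the model on every call for simplicity. For repeated queries, load the tokenizer and model once and reuse them.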

Citation

Please cite our paper if you use the data or code in this repo:

@misc{liu2023sociallyaligned,
      title={Training Socially Aligned Language Models in Simulated Human Society},
      author={Ruibo Liu and Ruixin Yang and Chenyan Jia and Ge Zhang and Denny Zhou and Andrew M. Dai and Diyi Yang and Soroush Vosoughi},
      year={2023},
      eprint={2305.16960},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
