An easy-to-use framework for training large language models with reinforcement learning in PyTorch, accelerated with Lightning Fabric. The framework is designed to be modular and extensible, supporting different RL algorithms, tasks, and datasets.
```bash
git clone https://github.com/Eclectic-Sheep/sheeprlhf.git && cd sheeprlhf
pip install -e "."
```
```bash
# Launch SFT training for the summarization task with the OPT model, accelerated by Lightning Fabric on GPU
python sheeprlhf.py train data=summarization model=opt task=sft fabric=auto_cuda
```
```bash
# Optionally, install the evaluation extras and evaluate the fine-tuned model
pip install -e ".[eval]"
python sheeprlhf.py eval task=perplexity experiment_dir=<path_to_sft_experiment>
```
This trains the OPT model on the summarization dataset: the dataset and the pretrained model are downloaded first, then training starts. Training is accelerated with Lightning Fabric, and all metrics are logged locally with TensorBoard.
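For reference, the perplexity reported by the optional evaluation command above is the exponential of the average per-token negative log-likelihood. The snippet below is a minimal illustration of that metric, not SheepRLHF's actual evaluation code (padding and label shifting are omitted):

```python
import math

import torch
import torch.nn.functional as F


def perplexity(logits: torch.Tensor, targets: torch.Tensor) -> float:
    """Illustrative perplexity: exp of the mean per-token cross-entropy.

    `logits` has shape (batch, seq_len, vocab) and `targets` has shape (batch, seq_len).
    """
    nll = F.cross_entropy(logits.flatten(0, 1), targets.flatten())
    return math.exp(nll.item())
```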
Here are the configurations available out-of-the-box for the framework:
| Dataset Name | Config Name |
| --- | --- |
| CarperAI/openai_summarize_comparisons | summarization |
| Dahoas/full-hh-rlhf | helpful_harmless |
| Model Name | Config Name |
| --- | --- |
| OPT | opt |
| GPT2 | gpt2 |
| Phi | phi |
| Train Task Name | Config Name |
| --- | --- |
| Supervised Fine-Tuning | sft |
| Reward Modeling | rm |
| Proximal Policy Optimization | ppo |
| Direct Preference Optimization | dpo |
| Evaluation Task Name | Config Name |
| --- | --- |
| Perplexity | perplexity |
| ROUGE | rouge |
We want to provide a framework for RL algorithms for LLMs, starting from common RLHF algorithms, that is both simple and scalable thanks to the Lightning Fabric framework. A single framework for different types of tasks and algorithms should allow developers to easily experiment with different configurations.
Reinforcement Learning with Human Feedback (RLHF) is a technique that combines traditional reinforcement learning (RL) with human judgments to train more effective and safer policies. Instead of relying solely on reward signals obtained from the environment, RLHF integrates feedback from humans to guide the learning process. With RLHF, reward signals are not crafted manually but are approximated by a model learned from human judgments. In addition, we have implemented Direct Preference Optimization (DPO) for aligning models to human preferences without training a reward model.
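As a rough sketch of the two objectives mentioned above, and not necessarily the exact losses implemented under `sheeprlhf/loss`, the reward model is typically trained with a pairwise ranking loss over chosen/rejected completions, while DPO optimizes the policy directly on the same preference pairs relative to a frozen reference model:

```python
import torch
import torch.nn.functional as F


def reward_model_loss(chosen_rewards: torch.Tensor, rejected_rewards: torch.Tensor) -> torch.Tensor:
    """Pairwise ranking loss: the chosen completion should score higher than the rejected one."""
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()


def dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    """DPO loss: widen the chosen/rejected log-probability margin, measured against a frozen reference model."""
    chosen_logratio = policy_chosen_logps - ref_chosen_logps
    rejected_logratio = policy_rejected_logps - ref_rejected_logps
    return -F.logsigmoid(beta * (chosen_logratio - rejected_logratio)).mean()
```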
SheepRLHF is designed to be modular and extensible. The framework provides two entry points: `train` and `eval`. The `train` entry point is used to train a model, while the `eval` entry point is used to evaluate one. After selecting the entry point, the user can select the task, the model, and the data to use. All other configurations can be changed by passing them as command line arguments.
The repository is structured as follows:
- `agent`: Contains the implementation of the agents for the RL algorithms.
- `config`: Contains the default configurations for entry points or experiments.
- `data`: Contains the implementation of the data processors that can be extended to support new datasets. It also includes dataset and data collator implementations.
- `loss`: Contains the implementation of the loss functions for the available tasks.
- `model`: Contains the implementation of wrapper model classes for LLMs.
- `structure`: Contains all configurations for the framework, including the defaults. The user can add new settings to the framework by adding new configurations to this folder.
  - `data.py`: Configuration for each available dataset.
  - `fabric.py`: Configurations for the Lightning Fabric instance.
  - `generation.py`: Parameters for the text-generation configuration.
  - `model.py`: Configuration for each available model.
  - `optim.py`: Optimization configuration.
  - `run.py`: Entry-point configurations for training and evaluation.
  - `task.py`: Configuration for each available task, such as SFT, DPO, and PPO.
- `task`: Contains the implementations of each task that the framework supports.
  - `train`: Contains the implementation of the training algorithms such as SFT, DPO, and PPO.
  - `eval`: Contains the implementation of the evaluation algorithms such as perplexity and ROUGE.
- `utils`: Contains utilities and helper functions.
- `cli.py`: Contains the entry points for the framework.
All models are defined as configuration dataclasses under the `sheeprlhf/structure/model.py` file. To add a new model available on Hugging Face, one can add a new configuration to this file. For example, to add the OPT 350M model, one can add the following code:
```python
@dataclass
class OPTConfig(ModelConfig):
    """Configurations for OPT based models."""

    config_name: str = "opt"
    repo_name: str = "facebook/opt-350m"
    embedding_dim_name: Optional[str] = "word_embed_proj_dim"
    lora_cfg: Optional[LORAConfig] = LORAConfig(targets="('q_proj','v_proj')")
```
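Assuming the new dataclass is registered alongside the built-in configurations in `sheeprlhf/structure/model.py`, the `config_name` field is what you pass on the command line, for example `model=opt` in the training command above.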
SheepRLHF supports LoRA out of the box, which helps reduce memory requirements by updating only a subset of the parameters. To enable LoRA, one can pass the following options on the command line:
```bash
python sheeprlhf.py train task=sft model=opt data=summarization model.finetune_mode=LORA model.lora_cfg.rank=16
```
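For intuition, the sketch below shows the general LoRA idea rather than SheepRLHF's actual implementation: the pretrained weight matrix is frozen and a low-rank update is learned, so only the two small matrices `A` and `B` receive gradients.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Minimal LoRA illustration: y = base(x) + (alpha / r) * x @ A^T @ B^T, with the base layer frozen."""

    def __init__(self, base: nn.Linear, rank: int = 16, alpha: float = 32.0):
        super().__init__()
        self.base = base
        self.base.weight.requires_grad_(False)  # freeze the pretrained weights
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        self.lora_a = nn.Parameter(torch.randn(rank, base.in_features) * 0.01)
        self.lora_b = nn.Parameter(torch.zeros(base.out_features, rank))
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen path plus the trainable low-rank correction
        return self.base(x) + self.scaling * (x @ self.lora_a.T @ self.lora_b.T)
```

With a small rank (e.g. `model.lora_cfg.rank=16`) applied to the attention projections listed in the OPT config above, only a small fraction of the model's parameters are trained.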
The best way to contribute is by opening an issue to discuss a new feature or a bug, or by opening a PR to fix a bug or add a new feature. For development, install the development dependencies and the pre-commit hooks by running the following commands:
```bash
pip install ".[dev]"
pre-commit install
```
This work and the code developed for it are the result of a long educational and experimental journey. Please reach out on GitHub with anything you need or anything that is unclear; contributions are more than welcome. We would like to thank the following works for their contributions to the field and for inspiring us to develop this project.
- StackLLaMa
- Implementing RLHF: Learning to Summarize with trlX
- RLHF: Reinforcement Learning from Human Feedback
- Fine-tune Llama 2 with DPO
- Learning to summarize from human feedback
- Training language models to follow instructions
- DeepSpeed-Chat: Easy, Fast and Affordable RLHF Training of ChatGPT-like Models at All Scales
- Secrets of RLHF in Large Language Models Part I: PPO
- LLAMA 2: Open Foundation and Fine-Tuned Chat Models
- Training a Helpful and Harmless Assistant with Reinforcement Learning from Human Feedback
You can contact us for any further questions or discussions:
- Refik Can Malli: refikcan.malli@orobix.com
- Federico Belotti: federico.belotti@orobix.com
- Davide Angioni: davide.angioni@orobix.com
- Michele Milesi: michele.milesi@orobix.com
This project is licensed under the terms of the Apache License 2.0. Please see the LICENSE file for details. Be aware that the project may also use third-party libraries or models available online, which may be licensed under different licenses.