Add A2C #183
Conversation
I added a2c_sentiments.py, as requested by Louis Castricato. Unfortunately, there is a lot of overlap with ppo_sentiments.py, but adding relative imports in the examples folder might increase confusion.
Related issue: #16.
That's fine.
Do you have compute access yet? Or Colab? Can you run the A2C sentiment example and post the wandb run here? I can merge after we verify it runs.
SGD: str = "sgd"

def get_optimizer_class(name: OptimizerName):
    torch_optimizers: Dict[str, type] = dict(
@jon-tow I remember you being somewhat against this?
This should be fine. We don't really care about most of the torch optimizers outside of the current ones. I'm not even sure we should supply RMSProp, as it's never really used for optimizing transformer language models.
Do NOT merge until we verify the wandb runs.
I don't have access to the compute cluster yet. I can run the script after I get home from the university: I have a research computer at home with a GPU that has more VRAM.
Thanks, @shermansiu! I've made some minor change requests. Let's get some reports on this before doing a final review for merging. If you're unable to get resources soon-ish we can run things for you later in the week 👍
examples/ilql_sentiments.py
Outdated
TRLX_PATH = pathlib.Path(__file__).resolve().parent.parent
with TRLX_PATH.joinpath("configs/ilql_config.yml").open() as f:
    default_config = yaml.safe_load(f)
If we want to make this change, we have to do it for every example (not just the sentiment ones) for consistency. Otherwise, revert.
configs/a2c_gptj.yml
Outdated
@@ -0,0 +1,55 @@
train:
Remove the a2c_gptj.yml unless it's tested thoroughly. We don't want folks wasting their compute resources on a large-ish tune with untested hparams. You should be able to get a reasonable signal + sanity checks from the
If you can get us runs in the next 12 hours, we can merge this for 0.4.
The RMSProp optimizer is used in the original A3C/A2C paper (Mnih et al., 2016). As suggested by Huang et al. (2022), we switch to the RMSProp optimizer to "implement" A2C on top of our existing implementation of PPO.
We follow the steps taken by Huang et al. (2022) in implementing A2C as a special case of PPO. Because `scale_reward` is set to False in the existing configurations, advantage normalization is already disabled, so we don't need to remove it. Moreover, entropy regularization is not implemented in trlX's PPO, so we don't need to manually set the entropy coefficient to 0.
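Concretely, the Huang et al. (2022) recipe amounts to a handful of config changes. A sketch of what the A2C-specific settings might look like (key names here are illustrative, not necessarily trlX's exact config schema):

```yaml
train:
  optimizer: rmsprop   # Mnih et al. (2016) use RMSProp for A3C/A2C
method:
  ppo_epochs: 1        # a single update per rollout batch: the importance
                       # ratio stays at 1, so clipping never activates and
                       # the PPO objective reduces to the A2C loss
```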
Flake8 was flagging the `get_optimizer_class` function with C901 (too complex), so I refactored it. Moreover, adding a default PyTorch or `bitsandbytes` optimizer is now a one-liner.
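As a sketch of what that refactor might look like (the registry contents and error handling here are illustrative, not the exact code in this PR):

```python
from typing import Dict, Type

import torch

# Registry of supported optimizers; adding a new torch optimizer
# is a one-line change to this dict.
TORCH_OPTIMIZERS: Dict[str, Type[torch.optim.Optimizer]] = dict(
    adam=torch.optim.Adam,
    adamw=torch.optim.AdamW,
    sgd=torch.optim.SGD,
    rmsprop=torch.optim.RMSprop,
)


def get_optimizer_class(name: str) -> Type[torch.optim.Optimizer]:
    """Look up an optimizer class by name, with a helpful error otherwise."""
    try:
        return TORCH_OPTIMIZERS[name]
    except KeyError:
        supported = ", ".join(sorted(TORCH_OPTIMIZERS))
        raise ValueError(f"Unknown optimizer '{name}'. Supported: {supported}")
```

A flat name-to-class dict keeps the lookup to a single `try`/`except`, which avoids the chained `if`/`elif` branches that trip flake8's C901 complexity check.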
I made a run of A2C: unfortunately, the mean reward is quite volatile compared to that of PPO. I no longer think A2C is a good candidate algorithm for RLHF.
This holds even when I slash the learning rate from 1e-5 to 1e-7: both runs are equally volatile, but 1e-7 has a lower mean reward.
It is the algorithm DeepMind uses for all their RLHF work, and it seems to work well there. There are probably other hyperparameters, etc., that make it work for them. They also use the KL penalty as an auxiliary loss rather than as part of the reward, which makes the RL reward simpler; that may be why A2C works for them.
This is the relevant paper. Muesli is a higher priority item, but I could look at this later. |
This is what was used in the Sparrow paper for A2C: "We extend the RL scheme of Menick et al. (2022); Perez et al. (2022), training a 70B A2C policy using Adafactor (Shazeer and Stern, 2018), a learning rate of 2 × 10−6, an effective batch size of 16, and ℓ2-norm gradient clipping to a max norm of 1.0. Instead of the typical entropy term, we regularise by adding the KL divergence between the RL policy and the initial language model (SFT or Chinchilla) to the loss, with a weight 0.2. To reduce memory usage, we freeze the first 80% of the weights (64/80 transformer layers) to the pretrained values, share parameters between policy and value functions, and train with reduced precision using bfloat16 as in Rae et al. (2021) and stochastic rounding (Gupta et al., 2015). The value function predicts the final reward (without discounting) at each token. We implement the value function as an MLP with two hidden layers of size 2048, which takes as input the final transformer representation at each time step. We shard the models across 64 TPU v3 machines (Shoeybi et al., 2019)".
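The KL-as-auxiliary-loss idea from the excerpt can be sketched as follows (a minimal illustration, assuming per-token logits from the policy and the frozen reference model; the 0.2 weight comes from the quote, and the function name is hypothetical):

```python
import torch
import torch.nn.functional as F


def kl_regularized_loss(
    policy_logits: torch.Tensor,
    ref_logits: torch.Tensor,
    rl_loss: torch.Tensor,
    kl_weight: float = 0.2,
) -> torch.Tensor:
    """Regularize by adding KL(policy || reference) to the RL loss,
    instead of folding the KL penalty into the reward signal."""
    policy_logp = F.log_softmax(policy_logits, dim=-1)
    ref_logp = F.log_softmax(ref_logits, dim=-1)
    # Per-position KL divergence over the vocabulary, averaged over the batch
    kl = (policy_logp.exp() * (policy_logp - ref_logp)).sum(dim=-1).mean()
    return rl_loss + kl_weight * kl
```

Keeping the KL term out of the reward leaves the reward model's signal untouched, which is the simplification the comment above is pointing at.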
Given how A2C is a special case of PPO (Huang et al., 2022), adding A2C to trlX becomes a matter of adding a few dedicated A2C configurations, as opposed to implementing A2C from scratch.
Some minor refactoring was necessary to get everything to work (and comply with flake8's standards), but overall the changes are quite minimal.
[1] S. Huang, A. Kanervisto, A. Raffin, W. Wang, S. Ontañón, and R. F. J. Dossa, "A2C is a special case of PPO," 2022. https://arxiv.org/pdf/2205.09123.pdf