
Add A2C #183

Closed

wants to merge 13 commits into from

Conversation

shermansiu
Contributor

Since A2C is a special case of PPO (Huang et al., 2022) [1], adding A2C to trlX is a matter of adding a few dedicated A2C configurations rather than implementing A2C from scratch.

Some minor refactoring was necessary to get everything to work (and comply with flake8's standards), but overall the changes are quite minimal.

[1] S. Huang, A. Kanervisto, A. Raffin, W. Wang, S. Ontañón, and R. F. J. Dossa, "A2C is a special case of PPO," 2022. https://arxiv.org/pdf/2205.09123.pdf
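For readers unfamiliar with the equivalence, here is a minimal sketch (plain PyTorch, not trlX code; the function names are made up for this illustration) of the core observation in Huang et al. (2022): when PPO performs a single update epoch over a single full-batch minibatch of freshly collected rollouts, the importance ratio is exactly 1, so the clipped surrogate has the same gradient as the vanilla A2C policy-gradient loss.

import torch

def ppo_clipped_policy_loss(logp_new, logp_old, advantages, clip_range=0.2):
    # Standard PPO clipped surrogate objective (negated for gradient descent).
    ratio = torch.exp(logp_new - logp_old)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages
    return -torch.min(unclipped, clipped).mean()

def a2c_policy_loss(logp_new, advantages):
    # Vanilla advantage actor-critic policy-gradient loss.
    return -(logp_new * advantages).mean()

# On the first (and, for A2C, only) update after collecting rollouts,
# logp_new == logp_old, so ratio == 1 and the clip is inactive; both losses
# then have identical gradients with respect to the policy parameters.

The remaining differences noted by Huang et al. (2022), such as the optimizer and advantage normalization, are what the dedicated configurations in this PR address.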

@shermansiu shermansiu mentioned this pull request Jan 11, 2023
@shermansiu
Contributor Author

I added a2c_sentiments.py, as requested by Louis Castricato. Unfortunately, there is a lot of overlap with ppo_sentiments.py, but adding relative imports in the examples folder might increase confusion.

@shermansiu
Contributor Author

Related issue: #16.

@shermansiu
Contributor Author

@Dahoas On the Discord channel, I think you indicated that you were interested in reviewing code for #16?

@LouisCastricato
Contributor

> I added a2c_sentiments.py, as requested by Louis Castricato. Unfortunately, there is a lot of overlap with ppo_sentiments.py, but adding relative imports in the examples folder might increase confusion.

That's fine.

@LouisCastricato
Contributor

LouisCastricato commented Jan 11, 2023

Do you have compute access yet? Or colab? Can you run a2c sentiment and post the wandb here?

I can merge after we verify it runs.

SGD: str = "sgd"


def get_optimizer_class(name: OptimizerName):
torch_optimizers: Dict[str, type] = dict(
Contributor

@jon-tow I remember you being somewhat against this?

Collaborator

This should be fine. We don't really care much about torch optimizers outside of the current ones. I'm not even sure we should supply RMSProp, as it's never really used for optimizing transformer language models.

@LouisCastricato
Contributor

Do NOT merge until we verify on wandb runs.

@shermansiu
Contributor Author

I don't have access to the compute cluster yet. I can run the script after I get home from the university; I have a research computer at home with a GPU that has more VRAM.

@jon-tow
Collaborator

Thanks, @shermansiu! I've made some minor change requests. Let's get some reports on this before doing a final review for merging. If you're unable to get resources soon-ish we can run things for you later in the week 👍

Comment on lines 18 to 20
TRLX_PATH = pathlib.Path(__file__).resolve().parent.parent
with TRLX_PATH.joinpath("configs/ilql_config.yml").open() as f:
    default_config = yaml.safe_load(f)
Collaborator
If we want to make this change, we have to do it for every example (not just sentiments) for consistency. Revert otherwise.

@@ -0,0 +1,55 @@
train:
Collaborator

Remove the a2c_gptj.yml unless tested thoroughly. We don't want folks wasting their compute resources on a large-ish tune with untested hparams.

@jon-tow
Collaborator

jon-tow commented Jan 12, 2023

> I don't have access to the compute cluster yet. I can run the script after I get home from the university; I have a research computer at home with a GPU that has more VRAM.

You should be able to get a reasonable signal + sanity checks from the gpt2-imdb model, which is rather small. Let us know how it goes!

@LouisCastricato
Contributor

LouisCastricato commented Jan 12, 2023

If you can get us runs in the next 12 hours we can merge this for 0.4

The RMSProp optimizer is used in the original A3C/A2C paper (Mnih et al., 2016). As suggested by Huang et al. (2022), we switch to the RMSProp optimizer to "implement" A2C using our existing implementation of PPO. We follow the steps taken by Huang et al. (2022) in implementing A2C as a special case of PPO.
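As a concrete sketch of what that switch amounts to in PyTorch (purely illustrative; these are the RMSProp settings commonly used to match the original A2C setup, not necessarily the values in this PR's configs, and the learning rate is only a placeholder):

import torch
import torch.nn as nn

# Toy module just so the optimizer has parameters to manage.
model = nn.Linear(8, 2)

# RMSProp with alpha=0.99 and eps=1e-5, the settings typically used when
# reproducing the original A2C recipe.
optimizer = torch.optim.RMSprop(model.parameters(), lr=7e-4, alpha=0.99, eps=1e-5)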

Because 'scale_reward' is set to False within the existing configurations, we don't need to remove advantage normalization, as it's already disabled. Moreover, entropy regularization is not implemented in trlX's implementation of PPO, so we don't need to manually set the entropy coefficient to 0.
Flake8 was raising a C901 (complexity) complaint, so I refactored the `get_optimizer_class` function. Adding a default PyTorch or `bitsandbytes` optimizer is now a one-liner.
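For illustration, a dict-based registry along these lines (a hedged sketch, not the actual trlX code; the optimizer set and error handling are simplified) keeps the function flat enough to satisfy flake8's C901 complexity check, and adding an optimizer becomes a single dictionary entry:

from typing import Dict, Type

import torch

# Hypothetical registry: each supported optimizer is one entry.
TORCH_OPTIMIZERS: Dict[str, Type[torch.optim.Optimizer]] = {
    "adam": torch.optim.Adam,
    "adamw": torch.optim.AdamW,
    "rmsprop": torch.optim.RMSprop,
    "sgd": torch.optim.SGD,
}

def get_optimizer_class(name: str) -> Type[torch.optim.Optimizer]:
    try:
        return TORCH_OPTIMIZERS[name]
    except KeyError:
        raise ValueError(f"Unknown optimizer name: {name!r}") from None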
@shermansiu
Contributor Author

I made a run of A2C: unfortunately, the mean reward is quite volatile compared to that of PPO. I no longer think A2C is a good candidate algorithm for RLHF.

https://api.wandb.ai/report/shermansiu/lydosf6p

@shermansiu
Contributor Author

This is the case even when I slash the learning rate from 1e-5 to 1e-7... Both are equally volatile, but 1e-7 yields a lower mean reward.

@shermansiu shermansiu mentioned this pull request Jan 13, 2023
@RobertKirk
Contributor

> I no longer think A2C is a good candidate algorithm for RLHF.

It is the algorithm that DeepMind is using for all their RLHF work, and it seems to work well there. There are probably some other hyperparameters, etc. that make it work in their setting. They also apply the KL penalty as an auxiliary loss rather than as part of the reward, which makes the RL reward simpler and may be why A2C works for them.
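For clarity, a rough sketch of that distinction (function and variable names are illustrative, not any library's API): the KL-to-reference penalty can either be folded into the reward the RL algorithm sees, or added to the loss as a separate term.

# rewards, kl_per_token, and pg_loss are torch tensors; kl_per_token is the
# per-token KL(policy || reference policy) estimate.

def reward_with_kl_penalty(rewards, kl_per_token, kl_coef):
    # "KL in the reward": the penalty changes the return that the policy-gradient
    # algorithm optimizes, so the reward signal itself becomes more complex.
    return rewards - kl_coef * kl_per_token

def loss_with_kl_penalty(pg_loss, kl_per_token, kl_coef):
    # "KL as an auxiliary loss": the reward is left untouched and the penalty is
    # added directly to the optimization objective instead.
    return pg_loss + kl_coef * kl_per_token.mean()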

@shermansiu
Contributor Author

This is the relevant paper. Muesli is a higher priority item, but I could look at this later.
https://arxiv.org/pdf/1706.03741.pdf

@conceptofmind

> This is the case even when I slash the learning rate from 1e-5 to 1e-7... Both are equally volatile, but 1e-7 yields a lower mean reward.

This is what was used in the Sparrow paper for A2C:

"We extend the RL scheme of Menick et al. (2022); Perez et al. (2022), training a 70B A2C policy using Adafactor (Shazeer and Stern, 2018), a learning rate of 2 × 10−6, an effective batch size of 16, and 𝑙2-norm gradient clipping to a max norm of 1.0. Instead of the typical entropy term, we regularise by adding the KL divergence between the RL policy and the initial language model (SFT or Chinchilla) to the loss, with a weight 0.2. To reduce memory usage, we freeze the first 80% of the weights (64/80 transformer layers) to the pretrained values, share parameters between policy and value functions, and train with reduced precision using bfloat16 as in Rae et al. (2021) and stochastic rounding (Gupta et al., 2015). The value function predicts the final reward (without discounting) at each token. We implement the value function as an MLP with two hidden layers of size 2048, which takes as input the final transformer representation at each time step. We shard the models across 64 TPU v3 machines (Shoeybi et al., 2019)".
