Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tune hyperparameters in tutorials for GAIL and AIRL #772

Merged
merged 7 commits into from Sep 7, 2023

Conversation

michalzajac-ml
Copy link
Contributor

Description

This PR tunes hyperparameters for the GAIL and AIRL tutorials.
For GAIL, the expert performance is reached (~500 on CartPole) with 800K PPO steps (~2 min run time on MacBook Air M1). For AIRL, the default is the "fast" version which improves over random but does not reach the expert performance (800K steps, ~2 min run time); if we switch off "fast" then the expert performance is reached (2M steps, ~5 min run time).

The hyperparameters were inspired by configs for half-cheetah from the benchmarking directory + a bit of manual tuning. Also, for GAIL I needed to change from BasicShapedRewardNet to BasicRewardNet to make it work (not exactly sure why but it affected performance a lot!).

Testing

Just ran the notebooks, and also tested with a few different seeds to make sure results are stable.

@michalzajac-ml michalzajac-ml added the docs Documentation missing, incorrect or unclear label Sep 5, 2023
@michalzajac-ml michalzajac-ml linked an issue Sep 5, 2023 that may be closed by this pull request
8 tasks
Copy link
Collaborator

@ernestum ernestum left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution! I meant to do this for quite some time!
The changes itsel LGTM.
I think the pipeline fails because we get the newest seals version (0.2) which is made for gymnasium. If we change our seals version specifier in setup.py to seals~=0.1.5, this should be fixed.

@@ -84,7 +88,7 @@ Detailed example notebook: :doc:`../tutorials/4_train_airl`
learner_rewards_before_training, _ = evaluate_policy(
learner, env, 100, return_episode_rewards=True,
)
airl_trainer.train(20000)
airl_trainer.train(20000) # Train for 2_000_000 steps to match expert.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 million timesteps is a lot of timesteps for something as simple as CartPole, I expect we can do better but this seems fine for the purpose of this PR, at least the environment runs quickly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd keep it for now (it's already an improvement) and possibly revisit in another PR.

"print(\"mean reward after training:\", np.mean(learner_rewards_after_training))\n",
"print(\"mean reward before training:\", np.mean(learner_rewards_before_training))\n",
"\n",
"plt.hist(\n",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you removing histogram (here and in AIRL)? Fine to remove if it's not informative. But perhaps we should report the SD as well as the means?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the reason was I thought it was not super informative (especially in case we reach expert perf). Good suggestion with SD though, will add!

Copy link
Collaborator

@ernestum ernestum Sep 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shameless plug: this would be a nice application for my newly release data-samples-printer:

import data_samples_printer as dsp
dsp.pprint(
    before_training=learner_rewards_before_training,
    after_training=learner_rewards_after_training
)

prints something like:

▁  ▁      ▁▄  ▄▄▄█▇▄▄▇▄▇█▄█▃▃▇▄▇ ▇▁▃▄▁▃ ▄▃▁ ▁▁   ▁ -0.00 ±1.08 before_training
                      ▂▃▇█▄▄▂▁                     -0.01 ±0.20 after_training

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ernestum , thanks for this, the lib looks quite cool! I'll remember about it in the future. For this PR I decided to not introduce additional dependency though.

@michalzajac-ml michalzajac-ml changed the base branch from master to dependency_fixes September 6, 2023 08:03
Copy link
Member

@AdamGleave AdamGleave left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Base automatically changed from dependency_fixes to master September 7, 2023 22:56
@AdamGleave AdamGleave merged commit 74b63ff into master Sep 7, 2023
7 of 9 checks passed
@AdamGleave AdamGleave deleted the 763-tune-gail-airl branch September 7, 2023 23:33
lukasberglund pushed a commit to lukasberglund/imitation that referenced this pull request Sep 12, 2023
…I#772)

* Pin huggingface_sb3 version.

* Properly specify the compatible seals version so it does not auto-upgrade to 0.2.

* Make random_mdp test deterministic by seeding the environment.

* Tune hyperparameters in tutorials for GAIL and AIRL

* Modify .rst docs for GAIL and AIRL to match tutorials

* GAIL and AIRL tutorials: report also std in results

---------

Co-authored-by: Maximilian Ernestus <maximilian@ernestus.de>
Co-authored-by: Adam Gleave <adam@gleave.me>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
docs Documentation missing, incorrect or unclear
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Ensure all tutorials work as expected
3 participants