
Add an option to run SQIL with various off-policy algorithms #778

Merged
11 commits merged into master from 767-sqil-other-algos on Sep 8, 2023

Conversation

michalzajac-ml (Contributor)

Description

This PR adds the option to combine SQIL with off-policy algorithms other than DQN, such as SAC, TD3, and DDPG, as requested in #767.
A tutorial training SQIL+SAC on the HalfCheetah environment is also provided. A random policy scores below 0 and the expert demonstrations average ~3400; SQIL+SAC reaches 1400.7 +/- 254.1 after 300K steps (mean +/- std over 5 runs).
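For readers skimming the PR, here is a rough sketch of how the new option might be used, based on the tutorial; the exact argument names (e.g. rl_algo_class, rl_kwargs) are assumptions that should be checked against the merged code.

```python
import stable_baselines3 as sb3
from imitation.algorithms import sqil


def train_sqil_sac(venv, rollouts, total_timesteps: int = 300_000):
    """Sketch of SQIL+SAC training.

    `venv` is a vectorized continuous-control environment (e.g. HalfCheetah)
    and `rollouts` are expert demonstrations, prepared as in the other
    imitation tutorials.
    """
    trainer = sqil.SQIL(
        venv=venv,
        demonstrations=rollouts,
        policy="MlpPolicy",
        rl_algo_class=sb3.SAC,  # previously DQN was the only supported choice
        rl_kwargs=dict(seed=42),
    )
    trainer.train(total_timesteps=total_timesteps)
    return trainer
```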

Testing

pytest tests/algorithms/test_sqil.py (the relevant tests were adapted to work with the new base algorithms).
One can also run the provided tutorial.

Base automatically changed from dependency_fixes to master September 7, 2023 22:57
@AdamGleave (Member) left a comment:


Thanks for the implementation! Overall looks strong, just a few relatively minor changes.

    cache = pytestconfig.cache
    assert cache is not None
    return expert_trajectories.make_expert_transition_loader(
-       cache_dir=cache.mkdir("experts"),
+       cache_dir=cache.mkdir(env_name.replace("/", "_")),
@AdamGleave (Member) commented:

Why do we need the environment name in the cache directory? It should already be included in the environment path in https://github.com/HumanCompatibleAI/imitation/blob/master/src/imitation/testing/expert_trajectories.py#L74

michalzajac-ml (Contributor, Author) replied:

Yes, indeed. I was confused by the implementation of this function that uses the cache and was not sure whether I needed to make the directory unique. Now I see that this root cache dir can be shared.
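(For context, a tiny illustration of why a shared root is enough: the loader is assumed, based on the code linked above, to namespace its on-disk cache by environment internally, so callers do not need environment-specific roots. The helper below is illustrative only, not imitation's actual implementation.)

```python
from pathlib import Path


def expert_cache_path(cache_root: Path, env_name: str) -> Path:
    """Illustrative only: derive a per-environment subdirectory under a
    shared cache root, so the root itself can be shared across tests."""
    return cache_root / env_name.replace("/", "_")


# e.g. expert_cache_path(Path(".pytest_cache/experts"), "seals/HalfCheetah-v1")
# yields .pytest_cache/experts/seals_HalfCheetah-v1
```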

tests/algorithms/test_sqil.py (outdated comment thread, resolved)
docs/tutorials/8a_train_sqil_sac.ipynb (outdated comment thread, resolved)
"cell_type": "markdown",
"metadata": {},
"source": [
"After we collected our expert trajectories, it's time to set up our behavior cloning algorithm."
@AdamGleave (Member) commented:

I know this was just copied from the original tutorial, but I find the reference to behavior cloning potentially ambiguous: it usually refers to supervised learning on expert trajectories (and we have a BC class that does exactly that), while SQIL does something conceptually similar but quite different in the details (RL rather than supervised learning).

Would suggest rephrasing this (and the original tutorial); we could just call it an imitation algorithm rather than a supervised learning algorithm.

"cell_type": "markdown",
"metadata": {},
"source": [
"After training, we can observe that agent is quite improved (> 1000), although it does not reach the expert performance in this case."
@AdamGleave (Member) commented:

If you have time to do more tuning, great, but not a priority; this is enough to illustrate the algorithm.
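(For reference, return figures like the "> 1000" above are typically computed with Stable Baselines3's evaluate_policy; a minimal sketch, assuming the trainer from the earlier snippet exposes its learned policy as trainer.policy.)

```python
from stable_baselines3.common.evaluation import evaluate_policy

# Evaluate the SQIL+SAC policy on the evaluation environment.
mean_return, std_return = evaluate_policy(trainer.policy, venv, n_eval_episodes=10)
print(f"Mean return over 10 episodes: {mean_return:.1f} +/- {std_return:.1f}")
```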

@AdamGleave (Member) left a comment:

LGTM

@AdamGleave merged commit 5c85ebf into master on Sep 8, 2023.
7 of 9 checks passed.
@AdamGleave deleted the 767-sqil-other-algos branch on September 8, 2023 at 16:26.
lukasberglund pushed a commit to lukasberglund/imitation that referenced this pull request Sep 12, 2023
…mpatibleAI#778)

* Pin huggingface_sb3 version.

* Properly specify the compatible seals version so it does not auto-upgrade to 0.2.

* Make random_mdp test deterministic by seeding the environment.

* Add an option to run SQIL with various off-policy algorithms

* Add 8a_train_sqil_sac to toctree

* Fix performance tests for SQIL

* fix

* Update docs/tutorials/8a_train_sqil_sac.ipynb

Co-authored-by: Adam Gleave <adam@gleave.me>

* minor fixes

* Bring back performance tests for SQIL

---------

Co-authored-by: Maximilian Ernestus <maximilian@ernestus.de>
Co-authored-by: Adam Gleave <adam@gleave.me>