
Backward sampling for continuous torus (ctorus) #193

Merged: 38 commits into main, Sep 7, 2023

Conversation

alexhernandezgarcia (Owner) commented Sep 6, 2023

The main goal of this PR is to enable backward sampling in the Continuous Torus (ctorus) environment. This is needed to sample from the replay buffer and to compute evaluation metrics related to the likelihood of test data.

Enabling backward sampling in the CTorus required changing, among other things, the arguments of the environments' (formerly named) sample_actions() method. Specifically, the method needs to know:

  • The originating states
  • Whether the actions are forward or backward

This in turn required changes wherever the method is used.

I have taken the chance to make other changes in this method:

  • Add a good amount of docstrings
  • Rename the method to sample_actions_batch()
  • Other refactoring changes
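The new arguments can be illustrated with a minimal sketch. Apart from the names sample_actions_batch, states_from, and is_backward taken from this PR, everything below is hypothetical and simplified, not the repository's implementation:

```python
import random
from typing import List, Tuple

def sample_actions_batch(
    policy_outputs: List[Tuple[float, float]],
    states_from: List[List[float]],
    is_backward: bool = False,
) -> List[Tuple[float]]:
    """Toy sketch: sample one angle increment per state in the batch.

    The originating states and the sampling direction are passed in
    because, in a continuous environment, the distribution over backward
    actions depends on the state each action leaves from.
    """
    actions = []
    for (mu, sigma), state in zip(policy_outputs, states_from):
        # state is unused in this toy; a real env would condition on it.
        increment = random.gauss(mu, sigma)
        # A backward action undoes a forward increment at the source state.
        if is_backward:
            increment = -increment
        actions.append((increment,))
    return actions
```

With sigma set to 0, the sketch returns the mean increment forward and its negation backward, which makes the role of is_backward visible.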

More things:

  • Important fix: set_state(state) now copies the state before setting it to self.state. This was the source of painful hidden errors.
  • Extended tests for the ctorus.
  • Documentation here and there, especially in ctorus.py
  • statebatch2policy() of the tori calls statetorch2proxy instead of doing numpy-based transformations.
  • Remove get_uniform_terminating_states() of Tetris since it is outdated; the env can use the base env's get_random_terminating_states() instead.
  • Enable test of backward sampling for all environments through the use of get_uniform_terminating_states(). (See important note about this below)
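The set_state() fix listed above can be sketched as follows. The class is a hypothetical stand-in for the base environment, not the actual code:

```python
import copy

class EnvSketch:
    """Hypothetical stand-in for an environment with a mutable list state."""

    def __init__(self):
        self.state = [0, 0, 0]
        self.done = False

    def set_state(self, state, done=False):
        # Copy before assigning: without this, self.state aliases the
        # caller's list, and any later in-place mutation on either side
        # silently changes both -- the kind of hidden error described above.
        self.state = copy.deepcopy(state)
        self.done = done
        return self
```

Mutating the caller's list after set_state() no longer leaks into the environment's state.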

⚠️ After enabling backward-sampling tests for environments that previously lacked them, some now fail. This is the case of:

  • Crystal: it gets stuck in no valid actions available, so we need to look into it.
  • Tree: Tried to execute action (2, 0.5769401788711548) not present in action space. I think this may be fixed in another PR.

Sanity checks:

wandb sanity check runs

…implement it with step_random(); Remove get_parents test from continuous.
…step() which is called by both forward and backward step().
… add states_from and is_backward args needed for continuous envs backward sampling 3) Rename mask_invalid_actions -> mask 4) Add docstring.
… tests and use instead of get_uniform_... for the tetris env.
@alexhernandezgarcia alexhernandezgarcia marked this pull request as ready for review September 6, 2023 19:41
Review comments on: gflownet/envs/base.py, gflownet/envs/ctorus.py, gflownet/envs/tree.py
Co-authored-by: Michał Koziarski <michal.koziarski@gmail.com>
…ezgarcia/gflownet into backward-sampling-continuous
…ntinuous-mk

Changes to continuous backward sampling
@@ -143,54 +173,106 @@ def get_parents(
parents = [state]
return parents, [action]

def sample_actions(
def action2representative(self, action: Tuple) -> Tuple:
Collaborator
If this method returns self.representative_action regardless of the action, why is it action2representative with an action argument, rather than just get_representative_action() without any arguments?

if self.done:
assert action == self.eos
self.done = False
self.n_actions += 1
Collaborator
Why is self.n_actions incremented in the backward step rather than decremented? I thought it should match the last dimension of the state.

Owner Author
No, it is not the same. self.n_actions counts the number of (valid) actions in a trajectory, regardless of whether it is forward or backward.
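A minimal illustration of that counting behavior (the class and method names are hypothetical, not the repository's):

```python
class TrajectorySketch:
    """Counts valid actions taken in a trajectory, in either direction."""

    def __init__(self):
        self.n_actions = 0

    def step(self, action):
        # Forward step: one more valid action in the trajectory.
        self.n_actions += 1

    def step_backwards(self, action):
        # A backward step also *adds* one valid action to the trajectory;
        # the counter tracks actions taken, not a position in the state.
        self.n_actions += 1
```

Two forward steps followed by one backward step yield a count of three, not one, which is why the counter need not match the last dimension of the state.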

Collaborator

@AlexandraVolokhova left a comment
I didn't get all the details, but overall it looks good to me.

@alexhernandezgarcia alexhernandezgarcia merged commit e73b143 into main Sep 7, 2023
1 check passed
@josephdviviano josephdviviano deleted the backward-sampling-continuous branch January 31, 2024 21:46