Backward sampling for continuous torus (ctorus) #193
Conversation
…implement it with step_random(); Remove get_parents test from continuous.
…step() which is called by both forward and backward step().
… add states_from and is_backward args needed for continuous envs backward sampling 3) Rename mask_invalid_actions -> mask 4) Add docstring.
… tests and use instead of get_uniform_... for the tetris env.
# Conflicts:
#	config/logger/base.yaml
#	gflownet/envs/tree.py
Co-authored-by: Michał Koziarski <michal.koziarski@gmail.com>
…ezgarcia/gflownet into backward-sampling-continuous
…ntinuous-mk Changes to continuous backward sampling
…ezgarcia/gflownet into backward-sampling-continuous
```diff
@@ -143,54 +173,106 @@ def get_parents(
         parents = [state]
         return parents, [action]

-    def sample_actions(
+    def action2representative(self, action: Tuple) -> Tuple:
```
If this method returns `self.representative_action` regardless of the action, why is it `action2representative` with an argument `action`, and not just `get_representative_action` (without any arguments)?
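For illustration, the reviewer's suggestion would look something like the sketch below (hypothetical code; storing the representative action as an attribute like `self.representative_action` is an assumption based on the comment above):

```python
from typing import Tuple


class CTorusSketch:
    # Hypothetical attribute holding the single representative action.
    representative_action: Tuple = (0, 0.0)

    def get_representative_action(self) -> Tuple:
        """Return the representative action; no `action` argument is
        needed because the result does not depend on it."""
        return self.representative_action
```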
```python
if self.done:
    assert action == self.eos
    self.done = False
    self.n_actions += 1
```
Why is `self.n_actions` incremented in the backward step, not decremented? I thought it should be the same as the last dimension of the state.
No, it is not the same. `self.n_actions` counts the number of (valid) actions in a trajectory, regardless of whether it is forward or backward.
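A toy sketch of this convention (hypothetical code, not the repository's actual implementation): `n_actions` counts valid actions taken so far in the trajectory, so both directions increment it.

```python
class CounterSketch:
    """Minimal skeleton illustrating the n_actions convention."""

    def __init__(self):
        self.n_actions = 0  # valid actions taken so far in the trajectory

    def step(self, action):
        # ... apply the action in the forward direction ...
        self.n_actions += 1  # one more valid action

    def step_backwards(self, action):
        # ... undo the action ...
        # Also incremented, not decremented: the counter tracks how many
        # valid actions have been sampled, regardless of direction.
        self.n_actions += 1
```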
I didn't get all the details, but overall it looks good to me.
The main goal of this PR is to enable backward sampling in the Continuous Torus (`ctorus`) environment. This is needed to sample from the replay buffer and to compute evaluation metrics related to the likelihood of test data.

Enabling backward sampling in the CTorus has required changing, among other things, the arguments of the (formerly called) `sample_actions()` method of the environments. Specifically, it needs to know the states from which the actions are sampled (`states_from`) and whether the sampling is backward (`is_backward`). Therefore, this has required changes wherever this method is used.
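Based on the commit messages above, the updated method presumably looks roughly like the sketch below; the class name, argument types, and defaults here are assumptions for illustration, not the repository's exact signature.

```python
from typing import List, Optional, Tuple

import torch


class EnvSketch:
    def sample_actions_batch(
        self,
        policy_outputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None,  # renamed from mask_invalid_actions
        states_from: Optional[List] = None,  # states from which actions are sampled
        is_backward: bool = False,  # True when sampling backward
    ) -> Tuple[List[Tuple], torch.Tensor]:
        """Sample a batch of actions, forward or backward, given policy outputs."""
        raise NotImplementedError
```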
I have taken the chance to make other changes in this method:

- Rename it to `sample_actions_batch()`.
- Rename `mask_invalid_actions` -> `mask`.
- Add a docstring.
More things:

- `set_state(state)` now copies the state before setting it to `self.state`. This was the source of painful hidden errors (see the sketch after this list).
- Backward sampling changes in the `ctorus` environment (`ctorus.py`).
- `statebatch2policy()` of the tori calls `statetorch2proxy` instead of doing numpy-based transformations.
- Remove `get_uniform_terminating_states()` of Tetris, since it is outdated and it can use the base env's `get_random_terminating_states()`.
- `get_uniform_terminating_states()`: see the important note about this below.
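A minimal sketch of the `set_state` fix mentioned in the first item above (hypothetical code; the actual method and its copy strategy may differ, e.g. a shallow copy for flat list states):

```python
import copy


class EnvSketch:
    def set_state(self, state, done: bool = False):
        # Copy the incoming state so that later in-place mutations of
        # self.state cannot silently mutate the caller's object; this
        # aliasing was the source of the "painful hidden errors".
        self.state = copy.deepcopy(state)
        self.done = done
        return self
```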
Important note: the following error is currently raised: `Tried to execute action (2, 0.5769401788711548) not present in action space.` I think this may be fixed in another PR.

Sanity checks:

- wandb sanity check runs