Log likelihood estimation #167
…ckward, depending on argument.
…ity-checked on 10x10 Grid)
I reviewed it, and it seems OK to me (apart from some minor comments). I would love to get one more review from someone else (at least for the estimate_logprobs_data function; the rest should be fine), but I guess that is unlikely to happen within a reasonable time at the moment.
One question: did you confirm that the results with this are comparable to previous versions on your test environments (since technically you change the TB loss)?
Thank you for this great work! I left a couple of comments on simplifying the code and making it more readable, but overall it looks good to me.
One more thing: I'd add computing and tracking two variances of the log probs:
I did compare the results with previous versions, at least with the Grid, and I have later run more instances. See this report on the sanity checks project on wandb.
…d of each test data point to the test config.
…points have n_trajectories.
…ent Batch.make_indices_consecutive().
…cutive traj. indices (without changing the batch).
I agree. I will not do this in this PR so as to merge asap. I have opened an issue instead: #192
Thank you for the changes and the tests - looks good to me!
Summary
Code to estimate the log likelihood of sampling data points with the current GFlowNet policy, which can be used to compute potentially insightful metrics. This PR includes the calculation of:
You can see one example on wandb.
Estimation of the log-likelihood
The log-likelihood is estimated by sampling backward trajectories according to the backward policy, then calculating the log probability of a sample with importance sampling, where the weights are the backward transition probabilities of the trajectories. In particular:
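The estimator described above can be sketched as follows. This is a minimal illustration, not the code in this PR: it assumes the usual importance sampling form, log p(x) ≈ log (1/N) Σ_i P_F(τ_i) / P_B(τ_i | x) with τ_i sampled from the backward policy, and the function name estimate_logprob and the array layout are hypothetical.

```python
import numpy as np


def estimate_logprob(logprobs_forward, logprobs_backward):
    """Importance sampling estimate of log p(x) for each data point.

    Hypothetical helper, not part of the PR. Both inputs have shape
    (n_data, n_trajectories): entry [k, i] is the log probability of the
    i-th backward-sampled trajectory of data point k under the forward
    policy and under the backward policy, respectively.
    """
    logprobs_forward = np.asarray(logprobs_forward, dtype=float)
    logprobs_backward = np.asarray(logprobs_backward, dtype=float)
    # Log importance weights: log P_F(tau) - log P_B(tau | x)
    log_weights = logprobs_forward - logprobs_backward
    n_trajectories = log_weights.shape[-1]
    # Numerically stable log-mean-exp over the trajectories of each data point
    max_log_weight = log_weights.max(axis=-1, keepdims=True)
    log_sum = max_log_weight.squeeze(-1) + np.log(
        np.exp(log_weights - max_log_weight).sum(axis=-1)
    )
    return log_sum - np.log(n_trajectories)


# Toy check: x is reachable through two trajectories with forward
# probabilities 0.1 and 0.3, so p(x) = 0.4. The backward policy picks
# either trajectory with probability 0.5, and each one is sampled once.
toy_forward = np.log([[0.1, 0.3]])
toy_backward = np.log([[0.5, 0.5]])
toy_estimate = estimate_logprob(toy_forward, toy_backward)
```

In the toy case the estimate recovers p(x) = 0.4 exactly because every trajectory reaching x is sampled in proportion to its backward probability; with real policies the estimate is only approximate and depends on the number of sampled trajectories.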
Other notes
GFlowNetAgent.compute_logprobs_trajectories(), which is now also used by the trajectory balance loss method.
estimate_logprobs_data() may be simplified and modularised a bit, but it's not a big deal anyway.
The core of this PR is ready to review, but note that the code to calculate the metrics in test() has become even messier and there may be issues before/after merging. A future PR should organise the evaluation code.
Things in this PR not directly related to the goal
set_state() for the Tetris environment, to catch cases where the state and done are incompatible.