Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Log likelihood estimation #167

Merged
merged 55 commits into from
Sep 6, 2023
Merged

Conversation

alexhernandezgarcia
Copy link
Owner

@alexhernandezgarcia alexhernandezgarcia commented Jul 29, 2023

Summary

Code to estimate the log likelihood of sampling data points with the current GFLowNet policy, which can be used to compute potentially insightful metrics. This PR includes the calculation of:

  • The correlation between the estimated log-likelihood of test data and the rewards of the data.
  • The mean negative log likelihood of the test data.

You can see one example on wandb.

Estimation of the log-likelihood

The log-likelihood is estimated by sampling backward trajectories according to the backward policy, then calculating the log probability of a sample with importance sampling, where the weights are the backward transition probabilities of the trajectories. In particular:

$\log p_T(x) = \int_{x \in \tau} P_F(\tau)d\tau$

$= \log \mathbb{E}_{P_B(\tau|x)} \frac{P_F(x)}{P_B(\tau|x)}$

$\approx \log \frac{1}{N} \sum_{i=1}^{N} \frac{P_F(x_i)}{P_B(\tau|x_i)}, x_i \sim P_B(\tau|x_i)$

Other notes

  • For convenience, I have implemented a new method GFlowNetAgent.compute_logprobs_trajectories(), which is used now too by the trajectory balance loss method.
  • The code of estimate_logprobs_data()may be simplified and modularised a bit, but it's not a big deal anyway.

The core of this PR is ready to review but note that the code to calculate the metrics in test() has become even messier and there may be issues before/after merging. A future PR should organise the evaluation code.

Things in this PR not directly related to the goal

  • set_state() for the Tetris environment, to catch cases where the state and done are incompatible.
  • I had to implement a couple of methods in the Batch to make the trajectory indices consecutive.

alexhernandezgarcia and others added 30 commits June 22, 2023 09:52
Copy link
Collaborator

@michalkoziarski michalkoziarski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I reviewed it, it seems ok to me (minus some minor comments). I would love to get one more review from someone else (at least for the estimate_logprobs_data function, the rest should be fine), but I guess that will be unlikely to get at the moment in a reasonable time.

One question: did you confirm that the results with this are comparable to previous versions on your test environments (since technically you change the TB loss)?

gflownet/utils/batch.py Show resolved Hide resolved
gflownet/utils/batch.py Outdated Show resolved Hide resolved
gflownet/gflownet.py Outdated Show resolved Hide resolved
gflownet/gflownet.py Outdated Show resolved Hide resolved
gflownet/gflownet.py Outdated Show resolved Hide resolved
Copy link
Collaborator

@AlexandraVolokhova AlexandraVolokhova left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for this great job! I left a couple of comments for simplifying / making more readable the code, but in overall, it looks good to me.

gflownet/gflownet.py Outdated Show resolved Hide resolved
gflownet/gflownet.py Outdated Show resolved Hide resolved
gflownet/gflownet.py Show resolved Hide resolved
@AlexandraVolokhova
Copy link
Collaborator

One more thing: I'd add computing and tracking two variances of the log probs:

  1. variance over samples of logprobs_estimates (to understand better the behaviour of the correlation coefficient over the training)
  2. median over samples of the variances of the logprobs_estimates over trajectories for each sample (to get a sense of how noisy the estimation is). The math is a bit tricky here as we use log mean as an estimation, not just the mean. But there're some work around: https://stats.stackexchange.com/questions/418313/variance-of-x-and-variance-of-logx-how-to-relate-them
    But in any case, we will need to compute empirical var(P_F(tau) / P_B (tau)) / n_traj for each sample and then play around a bit with it to get variance for the log mean estimation.

@alexhernandezgarcia
Copy link
Owner Author

I did compare the results with previous versions, at least with the Grid, and I have later run more instances. See this report on the sanity checks project on wandb.

@alexhernandezgarcia
Copy link
Owner Author

One more thing: I'd add computing and tracking two variances of the log probs:

1. variance over samples of logprobs_estimates (to understand better the behaviour of the correlation coefficient over the training)

2. median over samples of the variances of the logprobs_estimates over trajectories for each sample (to get a sense of how noisy the estimation is). The math is a bit tricky here as we use log mean as an estimation, not just the mean. But there're some work around: https://stats.stackexchange.com/questions/418313/variance-of-x-and-variance-of-logx-how-to-relate-them
   But in any case, we will need to compute empirical var(P_F(tau) / P_B (tau)) / n_traj for each sample and then play around a bit with it to get variance for the log mean estimation.

I agree. I will not do this in this PR so as to merge asap. I have opened an issue instead: #192

Copy link
Collaborator

@michalkoziarski michalkoziarski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the changes and the tests - looks good to me!

@alexhernandezgarcia alexhernandezgarcia merged commit ca157c1 into main Sep 6, 2023
1 check passed
@josephdviviano josephdviviano deleted the log-likelihood-estimation branch January 31, 2024 21:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants