Sorry for the nasty commit history on this one.
This is example code that shows how `torchgfn` can handle the case where actions are sampled from one of two distributions, depending on the state. In this example, when `t=0` we sample from one distribution, and when `t>0` we sample from a second distribution. Therefore, each batch element of the `States` tensor must sample from one of two distributions.

`DistributionWrapper` exposes a `.sample()` method which sequentially samples from `QuarterDisk` and `QuarterCircleWithExit`. These two output spaces are zero-padded so that they are cross-compatible.

`BoxPFNeuralNet` similarly computes both the `S_t=0` parameters and the `S_t>0` parameters for all batch elements, and then replaces the `S_t>0` outputs with the `S_t=0` parameters where necessary.

Note that the number of parameters used for `S_t=0` and `S_t>0` can differ. We therefore also mask the `BoxPFNeuralNet` outputs to account for the cases where the smaller number of parameters is used.

TODO (In a Follow Up PR):
- `BoxPFEstimator` and `BoxPFNeuralNet` to be two distinct classes. It would be easier if all the logic existed in the `BoxPFEstimator`. But we can debate this.
- `Tabular` case works as intended.

Extras
- 1.1.314.
- `pyproject.toml`.
- `torch.pi / 2` and `2 / torch.pi` expressed as globals in `box_utils.py`.
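
For reviewers who want a picture of the per-state sampling described above, here is a minimal sketch of the general pattern: sample from both distributions for every batch element, zero-pad the narrower output, and select per row with a boolean mask. It is not the code in this PR. The functions `sample_first_step`, `sample_later_step`, and `sample_actions`, the output widths (2 and 3), and the 0.1 step size are all hypothetical stand-ins chosen for illustration, and they do not reproduce `QuarterDisk`, `QuarterCircleWithExit`, or `DistributionWrapper`.

```python
import math

import torch
import torch.nn.functional as F


def sample_first_step(n: int) -> torch.Tensor:
    """Hypothetical t=0 distribution: uniform points in the unit quarter disk, shape (n, 2)."""
    r = torch.sqrt(torch.rand(n))  # sqrt makes the density uniform over area
    theta = torch.rand(n) * (math.pi / 2)  # angle in [0, pi/2]
    return torch.stack([r * torch.cos(theta), r * torch.sin(theta)], dim=-1)


def sample_later_step(n: int, exit_prob: float = 0.1) -> torch.Tensor:
    """Hypothetical t>0 distribution: fixed-length step plus an exit flag, shape (n, 3)."""
    theta = torch.rand(n) * (math.pi / 2)
    step = 0.1 * torch.stack([torch.cos(theta), torch.sin(theta)], dim=-1)
    is_exit = (torch.rand(n) < exit_prob).float().unsqueeze(-1)
    return torch.cat([step, is_exit], dim=-1)


def sample_actions(is_initial: torch.Tensor) -> torch.Tensor:
    """Sample one action per batch element, choosing the distribution by state."""
    n = is_initial.shape[0]
    # Sample from both distributions for the whole batch, one after the other.
    first = sample_first_step(n)
    later = sample_later_step(n)
    # Zero-pad the last dimension so both samples have the same width.
    width = max(first.shape[-1], later.shape[-1])
    first = F.pad(first, (0, width - first.shape[-1]))
    later = F.pad(later, (0, width - later.shape[-1]))
    # Keep the t=0 sample where the state is initial, the t>0 sample elsewhere.
    return torch.where(is_initial.unsqueeze(-1), first, later)


# Boolean mask marking which batch elements are at t=0.
is_initial = torch.tensor([True, False, True, False])
actions = sample_actions(is_initial)
print(actions.shape)
```

Sampling both distributions for every row wastes some compute, but it keeps the operation fully batched and avoids index bookkeeping. The same mask-and-select idea applies to the network outputs, where the unused parameter slots would be masked instead of zero-padded.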