-
Notifications
You must be signed in to change notification settings - Fork 435
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
predict_f and predict_y as NamedTuple #1657
base: develop
Are you sure you want to change the base?
Changes from 3 commits
05873f7
b318676
35e3617
c737f20
9a0ad7e
1bafd34
415bad7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -13,7 +13,7 @@ | |||||
# limitations under the License. | ||||||
|
||||||
import abc | ||||||
from typing import Optional, Tuple | ||||||
from typing import NamedTuple, Optional | ||||||
|
||||||
import tensorflow as tf | ||||||
|
||||||
|
@@ -25,7 +25,12 @@ | |||||
from ..utilities import to_default_float | ||||||
from .training_mixins import InputData, RegressionData | ||||||
|
||||||
MeanAndVariance = Tuple[tf.Tensor, tf.Tensor] | ||||||
|
||||||
class MeanAndVariance(NamedTuple): | ||||||
""" NamedTuple to access mean- and variance-function separately """ | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
mean: tf.Tensor | ||||||
variance: tf.Tensor | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. More of a general question at other maintainers (@vdutor @markvdw @awav) & anyone else ...: what should the name of this field be? In code we often abbreviate the variance to "var", but also it sometimes represents the covariance... or maybe that should be a different type? MeanAndVariance and MeanAndCov (and return a different one depending on full_cov etc)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i like the idea of the two types. Haven't thought about it in detail There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you always know for sure if you've got the cov or var? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you pass in full_cov=False and full_output_cov=False, you get the marginals back. If one of full_cov or full_output_cov is True, you get the covariance over inputs or outputs, respectively. If both are True, you should get the N P x N P covariance matrix (though this combination isn't actually implemented in several cases, I believe). So the output type is solely determined by the full_cov and full_output_cov arguments. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. btw you can use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've created a small example, which should work as wanted: But I' not comfortable with the if-else statement :/ from typing import Literal, NamedTuple, overload
class MeanAndVariance(NamedTuple):
mean: int
variance: int
class MeanAndCovariance(NamedTuple):
mean: int
covariance: int
@overload
def predict_f(auto_cov: Literal[False]) -> MeanAndVariance:
...
@overload
def predict_f(auto_cov: Literal[True]) -> MeanAndCovariance:
...
def predict_f(auto_cov: bool = False) -> (MeanAndVariance | MeanAndCovariance):
# calculations
return MeanAndCovariance(1, 2) if auto_cov else MeanAndVariance(1, 2) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think there's a good way around the if-else, and I think it's better to be explicit than the ambiguity of having to remember whether it's a [N, Q] or [Q, N, N] tensor ...:) I'd be happy with this. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As it pointed out, typing.Literal is only available from Python 3.8 and up :/ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @antonykamp the typing_extensions module provides backports for older versions of Python, it does seem to include Literal. :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added this pattern with Also, I wanted to ask if the parameters of the model constructor should be listed in the parametrization? |
||||||
|
||||||
|
||||||
class BayesianModel(Module, metaclass=abc.ABCMeta): | ||||||
|
@@ -218,7 +223,7 @@ def predict_y( | |||||
) | ||||||
|
||||||
f_mean, f_var = self.predict_f(Xnew, full_cov=full_cov, full_output_cov=full_output_cov) | ||||||
return self.likelihood.predict_mean_and_var(f_mean, f_var) | ||||||
return MeanAndVariance(*self.likelihood.predict_mean_and_var(f_mean, f_var)) | ||||||
|
||||||
def predict_log_density( | ||||||
self, data: RegressionData, full_cov: bool = False, full_output_cov: bool = False | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(note- if we rename the variance field we'll have to update this)