Update parameter to be savable #1518

Merged: 22 commits, Aug 17, 2020
6 changes: 3 additions & 3 deletions doc/source/notebooks/theory/FITCvsVFE.py
@@ -103,9 +103,9 @@ def repeatMinimization(model, Xtest, Ytest):

def trainSparseModel(Xtrain, Ytrain, exact_model, isFITC, Xtest, Ytest):
sparse_model = getSparseModel(Xtrain, Ytrain, isFITC)
sparse_model.likelihood.variance = exact_model.likelihood.variance.read_value().copy()
sparse_model.kern.lengthscales = exact_model.kern.lengthscales.read_value().copy()
sparse_model.kern.variance = exact_model.kern.variance.read_value().copy()
sparse_model.likelihood.variance = exact_model.likelihood.variance.numpy()
sparse_model.kern.lengthscales = exact_model.kern.lengthscales.numpy()
sparse_model.kern.variance = exact_model.kern.variance.numpy()
return sparse_model, repeatMinimization(sparse_model, Xtest, Ytest)
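
A minimal standalone sketch of the value-copying idiom above, using the public assign API rather than direct attribute assignment (model setup is illustrative, not part of this diff):

    import gpflow
    import numpy as np

    X, Y = np.random.rand(20, 1), np.random.rand(20, 1)
    exact_model = gpflow.models.GPR((X, Y), kernel=gpflow.kernels.SquaredExponential())
    sparse_model = gpflow.models.SGPR(
        (X, Y), kernel=gpflow.kernels.SquaredExponential(), inducing_variable=X[:5].copy()
    )

    # .numpy() reads the constrained value; .assign() validates and writes it back.
    sparse_model.likelihood.variance.assign(exact_model.likelihood.variance.numpy())
    sparse_model.kernel.lengthscales.assign(exact_model.kernel.lengthscales.numpy())
    sparse_model.kernel.variance.assign(exact_model.kernel.variance.numpy())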


7 changes: 4 additions & 3 deletions doc/source/notebooks/understanding/models.pct.py
@@ -119,10 +119,11 @@
p.transform.inverse(p)

# %% [markdown]
# You can also change the `transform` attribute in place:
# To replace the `transform` of a parameter, you need to recreate the parameter with the updated transform:

# %%
m.kernel.kernels[0].variance.transform = tfp.bijectors.Exp()
p = m.kernel.kernels[0].variance
m.kernel.kernels[0].variance = gpflow.Parameter(p.numpy(), transform=tfp.bijectors.Exp())

# %%
print_summary(m, fmt="notebook")
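
A standalone sketch of the same idiom (parameter names are illustrative): since the transform can no longer be swapped in place, recreate the Parameter and carry over its current value, prior and trainability.

    import gpflow
    import tensorflow_probability as tfp

    old = gpflow.Parameter(1.0, transform=gpflow.utilities.positive())
    new = gpflow.Parameter(
        old.numpy(),                     # keep the current constrained value
        transform=tfp.bijectors.Exp(),   # the new transform
        prior=old.prior,
        trainable=old.trainable,
    )
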
@@ -281,7 +282,7 @@ def maximum_log_likelihood_objective(self):
# %%
xx, yy = np.mgrid[-4:4:200j, -4:4:200j]
X_test = np.vstack([xx.flatten(), yy.flatten()]).T
f_test = np.dot(X_test, m.W.read_value()) + m.b.read_value()
f_test = np.dot(X_test, m.W.numpy()) + m.b.numpy()
p_test = np.exp(f_test)
p_test /= p_test.sum(1)[:, None]

191 changes: 27 additions & 164 deletions gpflow/base.py
@@ -84,7 +84,7 @@ class PriorOn(Enum):
UNCONSTRAINED = "unconstrained"


class Parameter(tf.Module):
class Parameter(tfp.util.TransformedVariable):
def __init__(
self,
value: TensorData,
@@ -103,35 +103,31 @@ def __init__(
therefore we need a positive constraint and it is natural to use constrained values.
A prior can be imposed either on the constrained version (default) or on the unconstrained version of the parameter.
"""
super().__init__()
if transform is None:
transform = tfp.bijectors.Identity()

value = _cast_to_dtype(value, dtype)
_validate_unconstrained_value(value, transform, dtype)
super().__init__(value, transform, dtype=value.dtype, trainable=trainable, name=name)

self._transform = transform
self.prior = prior
self.prior_on = prior_on # type: ignore # see https://github.com/python/mypy/issues/3004

if isinstance(value, tf.Variable):
self._unconstrained = value
else:
unconstrained_value = self.validate_unconstrained_value(value, dtype)
self._unconstrained = tf.Variable(
unconstrained_value, dtype=dtype, name=name, trainable=trainable
)

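A hedged sketch of what backing Parameter with tfp.util.TransformedVariable enables (the motivation in the PR title): a module holding Parameters round-trips through TensorFlow's saving machinery. Data and paths below are illustrative and not part of this diff.

    import gpflow
    import numpy as np
    import tensorflow as tf

    X, Y = np.zeros((1, 1)), np.zeros((1, 1))
    model = gpflow.models.GPR((X, Y), kernel=gpflow.kernels.SquaredExponential())

    ckpt = tf.train.Checkpoint(model=model)
    path = ckpt.save("/tmp/gpr_ckpt")  # illustrative path
    ckpt.restore(path)                 # parameters are restored from the checkpoint
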
def log_prior_density(self) -> tf.Tensor:
""" Log of the prior probability density of the constrained variable. """

if self.prior is None:
return tf.convert_to_tensor(0.0, dtype=self.dtype)

y = self.read_value()
y = self

if self.prior_on == PriorOn.CONSTRAINED:
# evaluation is in same space as prior
return tf.reduce_sum(self.prior.log_prob(y))

else:
# prior on unconstrained, but evaluating log-prior in constrained space
x = self._unconstrained
x = self.unconstrained_variable
log_p = tf.reduce_sum(self.prior.log_prob(x))

if self.transform is not None:
@@ -149,31 +145,13 @@ def prior_on(self) -> PriorOn:
def prior_on(self, value: Union[str, PriorOn]) -> None:
self._prior_on = PriorOn(value)

def value(self) -> tf.Tensor:
return _to_constrained(self._unconstrained.value(), self.transform) # type: ignore # assumes _to_constrained returns a tf.Tensor

def read_value(self) -> tf.Tensor:
return _to_constrained(self._unconstrained.read_value(), self.transform) # type: ignore # assumes _to_constrained returns a tf.Tensor

def experimental_ref(self) -> "Parameter":
return self

def deref(self) -> "Parameter":
return self

@property
def unconstrained_variable(self) -> tf.Variable:
return self._unconstrained
return self._pretransformed_input

@property
def transform(self) -> Optional[Transform]:
Contributor: I understand this is mainly because of legacy, but why are we still using the term transform? We could make the interface cleaner if we followed TFP's naming and used the word bijector instead. Having two words for the same object makes the code harder to read and understand. This PR seems like the right place and time to update this?

Member Author: I don't like the Bijector name, because not every bijector that TFP has is actually a bijection. The Transform name has been used since the start of GPflow, and I think keeping that name reflects the meaning better. Another reason for transform: it keeps the commit diff minimal.

return self._transform

@transform.setter
def transform(self, new_transform: Optional[Transform]) -> None:
constrained_value = self.read_value()
self._transform = new_transform
self.assign(constrained_value)
return self.bijector
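
With Parameter now a tfp.util.TransformedVariable, transform is simply an alias for the underlying bijector. A minimal sketch:

    import gpflow
    import tensorflow_probability as tfp

    p = gpflow.Parameter(1.0, transform=tfp.bijectors.Softplus())
    assert p.transform is p.bijector                    # one object, two names
    print(p.numpy(), p.unconstrained_variable.numpy())  # constrained vs. unconstrained value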

@property
def trainable(self) -> bool:
@@ -182,21 +160,7 @@ def trainable(self) -> bool:

This attribute cannot be set directly. Use :func:`gpflow.set_trainable`.
"""
return self._unconstrained.trainable

@property
def initial_value(self) -> tf.Tensor:
return self._unconstrained.initial_value

def validate_unconstrained_value(self, value: TensorData, dtype: DType) -> tf.Tensor:
value = _cast_to_dtype(value, dtype)
unconstrained_value = _to_unconstrained(value, self.transform)
message = (
"gpflow.Parameter: the value to be assigned is incompatible with this parameter's "
"transform (the corresponding unconstrained value has NaN or Inf) and hence cannot be "
"assigned."
)
return tf.debugging.assert_all_finite(unconstrained_value, message=message)
return self.unconstrained_variable.trainable

def assign(
self,
@@ -225,125 +189,11 @@ def assign(
:param read_value: if True, will return something which evaluates to the new
value of the variable; if False will return the assign op.
"""
unconstrained_value = self.validate_unconstrained_value(value, self.dtype)
return self._unconstrained.assign(
unconstrained_value = _validate_unconstrained_value(value, self.transform, self.dtype)
return self.unconstrained_variable.assign(
Contributor: The super class already has this implemented. Making use of it makes it clearer what is happening:

  1. validate that the unconstrained value exists.
  2. assign it.

Suggested change:

        return self.unconstrained_variable.assign(
        return super().assign(value, use_locking, name, read_value)

Contributor (@sam-willis, Aug 13, 2020): I think your comment is right, but the suggestion is wrong (I guess you know that though). If you're going to do that, you can simply remove the method; it will be inherited from the parent.

I suppose what you're suggesting is

    unconstrained_value = super().assign(value, use_locking, name, read_value)
    message = (
        "gpflow.Parameter: the value to be assigned is incompatible with this parameter's "
        "transform (the corresponding unconstrained value has NaN or Inf) and hence cannot be "
        "assigned."
    )
    return tf.debugging.assert_all_finite(unconstrained_value, message=message)

which makes sense to me.

Member Author: In that case, the transformation will be performed twice. We redefine the method so that we can do the extra steps before assigning the value without sacrificing performance.

Contributor: Not sure I understand what you mean. The snippet @sam-willis posted seems correct to me?

Contributor: It will only be performed twice if you call _validate_unconstrained_value, which is why I suggested just using tf.debugging.assert_all_finite. You could make the message a constant if you don't want code duplication.

Contributor: I think you can remove _validate_unconstrained_value, _unconstrained_value and _constrained_value, and just check that the unconstrained value is finite after either assignment or initialisation.

Member Author (@awav, Aug 13, 2020): I can't apply Sam's suggestion because it changes the value first and then checks correctness, whereas it should be vice versa.

Also, I don't understand what you mean by "I think you can remove _validate_unconstrained_value, _unconstrained_value and _constrained_value".

Contributor (@sam-willis, Aug 13, 2020): Ah, I do see your point; maybe it's better not to change the value if it's invalid, although an exception is thrown, so I'm not sure how important it is. I guess that if you catch the exception, your current code makes recovering easier.

Contributor (@sam-willis, Aug 13, 2020): _constrained_value is unused, and technically you should always have the identity for the transform now, so you don't need the if in _unconstrained_value, but I think those are very minor.
unconstrained_value, use_locking=use_locking, name=name, read_value=read_value
)
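
A hedged sketch of the behaviour discussed in the thread above: assign validates the unconstrained value before touching the variable, so an invalid value raises and leaves the parameter unchanged.

    import gpflow
    import tensorflow as tf

    p = gpflow.Parameter(1.0, transform=gpflow.utilities.positive())
    try:
        p.assign(-1.0)  # inverse softplus of a negative value is NaN
    except tf.errors.InvalidArgumentError:
        pass
    print(p.numpy())  # still 1.0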

@property
def is_tensor_like(self) -> bool:
"""
This method means that TensorFlow's `tensor_util.is_tensor` function
will return `True`
"""
return True

@property
def name(self) -> str:
return self._unconstrained.name

@property
def initializer(self): # type unknown
return self._unconstrained.initializer

@property
def device(self) -> Optional[str]:
return self._unconstrained.device

@property
def dtype(self) -> tf.DType:
return self._unconstrained.dtype

@property
def op(self) -> tf.Operation:
return self._unconstrained.op

@property
def shape(self) -> tf.TensorShape:
if self.transform is not None:
return self.transform.forward_event_shape(self._unconstrained.shape)
return self._unconstrained.shape

def numpy(self) -> np.ndarray:
return self.read_value().numpy()

def get_shape(self) -> tf.TensorShape:
return self.shape

def _should_act_as_resource_variable(self): # type unknown
# needed so that Parameters are correctly identified by TensorFlow's
# is_resource_variable() in resource_variable_ops.py
pass # only checked by TensorFlow using hasattr()

@property
def handle(self): # type unknown
return self._unconstrained.handle

def __repr__(self) -> str:
unconstrained = self.unconstrained_variable
constrained = self.read_value()
if tf.executing_eagerly():
info = (
f"unconstrained-shape={unconstrained.shape} "
f"unconstrained-value={unconstrained.numpy()} "
f"constrained-shape={constrained.shape} "
f"constrained-value={constrained.numpy()}"
)
else:
if unconstrained.shape == constrained.shape:
info = f"shape={constrained.shape}"
else:
info = (
f"unconstrained-shape={unconstrained.shape} "
f"constrained-shape={constrained.shape}"
)

return f"<gpflow.Parameter {self.name!r} dtype={self.dtype.name} {info}>"

# Below
# TensorFlow copy-paste code to make variable-like object to work

@classmethod
def _OverloadAllOperators(cls): # pylint: disable=invalid-name
"""Register overloads for all operators."""
for operator in tf.Tensor.OVERLOADABLE_OPERATORS:
cls._OverloadOperator(operator)
# For slicing, bind getitem differently than a tensor (use SliceHelperVar
# instead)
# pylint: disable=protected-access
setattr(cls, "__getitem__", array_ops._SliceHelperVar)

@classmethod
def _OverloadOperator(cls, operator): # pylint: disable=invalid-name
"""Defer an operator overload to `ops.Tensor`.

We pull the operator out of ops.Tensor dynamically to avoid ordering issues.

Args:
operator: string. The operator name.
"""
tensor_oper = getattr(tf.Tensor, operator)

def _run_op(a, *args, **kwargs):
# pylint: disable=protected-access
return tensor_oper(a.read_value(), *args, **kwargs)

functools.update_wrapper(_run_op, tensor_oper)
setattr(cls, operator, _run_op)

# NOTE(mrry): This enables the Variable's overloaded "right" binary
# operators to run when the left operand is an ndarray, because it
# accords the Variable class higher priority than an ndarray, or a
# numpy matrix.
# TODO(mrry): Convert this to using numpy's __numpy_ufunc__
# mechanism, which allows more control over how Variables interact
# with ndarrays.
__array_priority__ = 100


Parameter._OverloadAllOperators()
tf.register_tensor_conversion_function(Parameter, lambda x, *args, **kwds: x.read_value())


def _cast_to_dtype(
Contributor: Didn't TF 2.2 resolve the issue with cast? Can we simply use tf.cast now?

Contributor: We still support 2.1.
value: TensorData, dtype: Optional[DType] = None
@@ -360,6 +210,19 @@ def _cast_to_dtype(
return tf.convert_to_tensor(value, dtype=dtype)


def _validate_unconstrained_value(
value: TensorData, transform: tfp.bijectors.Bijector, dtype: DType
) -> tf.Tensor:
value = _cast_to_dtype(value, dtype)
unconstrained_value = _to_unconstrained(value, transform)
message = (
"gpflow.Parameter: the value to be assigned is incompatible with this parameter's "
"transform (the corresponding unconstrained value has NaN or Inf) and hence cannot be "
"assigned."
)
return tf.debugging.assert_all_finite(unconstrained_value, message=message)


def _to_constrained(value: TensorType, transform: Optional[Transform]) -> TensorType:
if transform is not None:
return transform.forward(value)
9 changes: 4 additions & 5 deletions gpflow/optimizers/natgrad.py
@@ -195,15 +195,14 @@ def _natgrad_steps(
:param parameters: List of tuples (q_mu, q_sqrt, xi_transform)
"""
q_mus, q_sqrts, xis = zip(*parameters)
unconstrained_variables = [
p.unconstrained_variable for params in (q_mus, q_sqrts) for p in params
]
q_mu_vars = [p.unconstrained_variable for p in q_mus]
q_sqrt_vars = [p.unconstrained_variable for p in q_sqrts]

with tf.GradientTape(watch_accessed_variables=False) as tape:
tape.watch(unconstrained_variables)
tape.watch(q_mu_vars + q_sqrt_vars)
loss = loss_fn()

q_mu_grads, q_sqrt_grads = tape.gradient(loss, [q_mus, q_sqrts])
q_mu_grads, q_sqrt_grads = tape.gradient(loss, [q_mu_vars, q_sqrt_vars])
# NOTE that these are the gradients in *unconstrained* space

with tf.name_scope(f"{self._name}/natural_gradient_steps"):
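
A simplified sketch of the pattern above with a toy loss (not the actual natural-gradient objective): gradients are taken with respect to the unconstrained variables backing each Parameter, not the Parameter objects themselves.

    import gpflow
    import numpy as np
    import tensorflow as tf

    q_mu = gpflow.Parameter(np.zeros((2, 1)))
    q_sqrt = gpflow.Parameter(np.eye(2)[None], transform=gpflow.utilities.triangular())

    variables = [q_mu.unconstrained_variable, q_sqrt.unconstrained_variable]
    with tf.GradientTape(watch_accessed_variables=False) as tape:
        tape.watch(variables)
        loss = tf.reduce_sum(tf.square(q_mu)) + tf.reduce_sum(tf.square(q_sqrt))

    q_mu_grad, q_sqrt_grad = tape.gradient(loss, variables)  # gradients in unconstrained space
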
13 changes: 9 additions & 4 deletions gpflow/utilities/utilities.py
@@ -190,17 +190,22 @@ def leaf_components(input: tf.Module):
def _merge_leaf_components(
input: Dict[str, Union[tf.Variable, tf.Tensor, Parameter]]
) -> Dict[str, Union[tf.Variable, tf.Tensor, Parameter]]:
input_values = set([value.experimental_ref() for value in input.values()])

ref_fn = lambda x: (x if isinstance(x, Parameter) else x.ref())
Contributor: Why not add a ref / deref method to Parameter? It seems like this is just missing from TFP's TransformedVariable API.

    def ref(self):
        return self.unconstrained.ref()

Member Author (@awav, Aug 13, 2020): The parameter is not a tensor object and is not a trait of the tensor class; it is simply a wrapper around a tensor. You cannot use the same reference for two different objects. That was one of the reasons why saving didn't work.

Contributor: Would it not be cleaner to have

    def ref(self):
        return self

in order to get rid of ref_fn and deref_fn? It seems to me that this referencing functionality could be handy elsewhere as well at some point.

Member Author: ref() belongs to the tensor API, and ref() returns a different type. There must be a distinction between these two for clarity. I would prefer not to introduce this method to the parameter class because of a printing issue.

deref_fn = lambda x: (x if isinstance(x, Parameter) else x.deref())

input_values = set([ref_fn(value) for value in input.values()])
if len(input_values) == len(input):
return input

tmp_dict = dict() # Type: Dict[ref, str]
for key, variable in input.items():
ref = variable.experimental_ref()
for key, value in input.items():
ref = ref_fn(value)
if ref in tmp_dict:
tmp_dict[ref] = f"{tmp_dict[ref]}\n{key}"
else:
tmp_dict[ref] = key
return {key: ref.deref() for ref, key in tmp_dict.items()}
return {key: deref_fn(ref) for ref, key in tmp_dict.items()}
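
A hypothetical illustration of the merging behaviour (the class name is made up): two attributes pointing at the same Parameter should collapse into a single summary row.

    import gpflow
    import tensorflow as tf

    class Shared(tf.Module):
        def __init__(self):
            super().__init__()
            self.a = gpflow.Parameter(1.0)
            self.b = self.a  # the same Parameter reachable under two names

    gpflow.utilities.print_summary(Shared())  # "a" and "b" expected on one merged row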


def _get_leaf_components(input_module: tf.Module):
2 changes: 1 addition & 1 deletion tests/gpflow/expectations/test_expectations.py
@@ -214,7 +214,7 @@ def test_RBF_eKzxKxz_gradient_notNaN():

with tf.GradientTape() as tape:
ekz = expectation(p, (kernel, z), (kernel, z))
grad = tape.gradient(ekz, kernel.lengthscales)
grad = tape.gradient(ekz, kernel.lengthscales.unconstrained_variable)
assert grad is not None and not np.isnan(grad)

