Quadrature Refactoring #1505

Merged

merged 47 commits into develop from gustavocmv/quadrature on Jul 30, 2020
Changes from 33 commits
Commits
e437e67
WIP: quadrature refactoring
gustavo-delfosim Jun 8, 2020
138b41b
Removing old ndiagquad code
gustavo-delfosim Jun 8, 2020
1d9f5b9
deleted test code
gustavo-delfosim Jun 8, 2020
4e22337
formatting and type-hint
gustavo-delfosim Jun 8, 2020
941548d
merge modules
st-- Jun 8, 2020
f5bc0a6
black formatting
gustavo-delfosim Jun 8, 2020
fdf2bea
Merge branch 'st/quadrature' into gustavocmv/quadrature
gustavo-delfosim Jun 8, 2020
2412309
formatting
gustavo-delfosim Jun 8, 2020
0a64680
solving failing tests
gustavo-delfosim Jun 8, 2020
4592ed8
fixing failing tests
gustavo-delfosim Jun 8, 2020
d1afb32
fixes
gustavo-delfosim Jun 8, 2020
8ab2f4a
adapting tests for new syntax, keeping numerical behavior
gustavo-delfosim Jun 8, 2020
98428db
black formatting
gustavo-delfosim Jun 8, 2020
695d87f
remove printf
gustavo-delfosim Jun 8, 2020
62b79f8
changed code for compiled tf compatibility
gustavo-delfosim Jun 8, 2020
20ad0bc
black
gustavo-delfosim Jun 8, 2020
d4161d8
restored to original version
gustavo-delfosim Jun 8, 2020
029ec7a
undoing changes
gustavo-delfosim Jun 9, 2020
48516ea
renaming
gustavo-delfosim Jun 9, 2020
6ab2ede
renaming
gustavo-delfosim Jun 9, 2020
592bfdc
renaming
gustavo-delfosim Jun 9, 2020
89756c2
reshape kwargs
gustavo-delfosim Jun 9, 2020
0b76b7d
quadrature along axis=-2, simplified broadcasting
gustavo-delfosim Jun 9, 2020
da8b7d2
black
gustavo-delfosim Jun 9, 2020
41eac9b
docs
gustavo-delfosim Jun 9, 2020
77fcfb1
docs
gustavo-delfosim Jun 9, 2020
1b5c22b
helper function
gustavo-delfosim Jun 9, 2020
ed49cf0
docstrings and typing
gustavo-delfosim Jun 11, 2020
7b6df63
Merge branch 'develop' into gustavocmv/quadrature
gustavocmv Jun 11, 2020
dddc5b6
Merge remote-tracking branch 'origin/develop' into gustavocmv/quadrature
gustavo-delfosim Jun 11, 2020
650991d
added new and old quadrature equivalence tests
gustavo-delfosim Jun 13, 2020
d537f67
black
gustavo-delfosim Jun 13, 2020
6aa25a8
Merge branch 'develop' into gustavocmv/quadrature
gustavocmv Jul 14, 2020
bc98286
Removing comments
gustavocmv Jul 20, 2020
5621a9a
Typo
gustavocmv Jul 20, 2020
898f4c0
notation
gustavocmv Jul 20, 2020
72227d6
Merge branch 'develop' into gustavocmv/quadrature
gustavocmv Jul 20, 2020
a69acda
reshape_Z_dZ return docstring fix
gustavo-delfosim Jul 20, 2020
187bf35
FIX: quad_old computed with the ndiagquad_old
gustavocmv Jul 21, 2020
8e13837
more readable implementation
gustavocmv Jul 21, 2020
2524561
tf.ensure_shape added
gustavo-delfosim Jul 21, 2020
7bb0e9f
removed ndiagquad
gustavo-delfosim Jul 21, 2020
8e23524
removed ndiagquad
gustavo-delfosim Jul 21, 2020
d6e7e63
Revert "removed ndiagquad"
gustavo-delfosim Jul 21, 2020
8f7e949
FIX: shape checking of dZ
gustavo-delfosim Jul 21, 2020
d4f313c
Revert "removed ndiagquad"
gustavo-delfosim Jul 21, 2020
04525d8
Merge branch 'develop' into gustavocmv/quadrature
gustavocmv Jul 30, 2020
3 changes: 3 additions & 0 deletions gpflow/quadrature/__init__.py
@@ -0,0 +1,3 @@
from .base import GaussianQuadrature
from .gauss_hermite import NDiagGHQuadrature
from .deprecated import hermgauss, mvhermgauss, mvnquad, ndiagquad, ndiag_mc
93 changes: 93 additions & 0 deletions gpflow/quadrature/base.py
@@ -0,0 +1,93 @@
from typing import Callable, Union, Iterable

import abc
import tensorflow as tf

from ..base import TensorType


class GaussianQuadrature:
    """
    Abstract class implementing quadrature methods to compute Gaussian expectations.
    Inheriting classes must provide the method _build_X_W to create the points and weights
    to be used for quadrature.
    """

    @abc.abstractmethod
    def _build_X_W(self, mean: TensorType, var: TensorType):
        raise NotImplementedError

    def __call__(self, fun, mean, var, *args, **kwargs):
        r"""
        Compute the Gaussian expectation of a function f:

            X ~ N(mean, var)
            E[f(X)] = ∫f(x, *args, **kwargs)p(x)dx

        using the formula:

            E[f(X)] ≈ sum_{i=1}^{N_quad_points} f(x_i) * w_i

        where x_i, w_i must be provided by the inheriting class through self._build_X_W.
        The computations broadcast along batch-dimensions, represented by [b1, b2, ..., bX].

        :param fun: Callable or Iterable of Callables that operates elementwise, with
            signature f(X, *args, **kwargs). Moreover, it must satisfy the shape-mapping:
                X shape: [b1, b2, ..., bX, N_quad_points, d],
                    usually [N, N_quad_points, d]
                f(X) shape: [b1, b2, ..., bf, N_quad_points, d'],
                    usually [N, N_quad_points, 1] or [N, N_quad_points, d]
            In most cases, f should only operate over the last dimension of X.
        :param mean: Array/Tensor with shape [b1, b2, ..., bX, d], usually [N, d],
            representing the mean of a d-variate Gaussian distribution
        :param var: Array/Tensor with shape [b1, b2, ..., bX, d], usually [N, d],
            representing the variance of a d-variate Gaussian distribution
        :param *args: Passed to fun
        :param **kwargs: Passed to fun
        :return: Array/Tensor with shape [b1, b2, ..., bf, d'],
            usually [N, d] or [N, 1]
        """

        X, W = self._build_X_W(mean, var)
        if isinstance(fun, Iterable):
            # Maybe this can be better implemented by stacking [f1(X), f2(X), ...]
            # and sum-reducing all at once.
            # The problem: there is no guarantee that f1(X), f2(X), ...
            # have compatible shapes.
            return [tf.reduce_sum(f(X, *args, **kwargs) * W, axis=-2) for f in fun]
        return tf.reduce_sum(fun(X, *args, **kwargs) * W, axis=-2)

    def logspace(self, fun: Union[Callable, Iterable[Callable]], mean, var, *args, **kwargs):
        r"""
        Compute the Gaussian log-expectation of the exponential of a function f:

            X ~ N(mean, var)
            log E[exp[f(X)]] = log ∫exp[f(x, *args, **kwargs)]p(x)dx

        using the formula:

            log E[exp[f(X)]] ≈ log sum_{i=1}^{N_quad_points} exp[f(x_i) + log w_i]

        where x_i, w_i must be provided by the inheriting class through self._build_X_W.
        The computations broadcast along batch-dimensions, represented by [b1, b2, ..., bX].

        :param fun: Callable or Iterable of Callables that operates elementwise, with
            signature f(X, *args, **kwargs). Moreover, it must satisfy the shape-mapping:
                X shape: [b1, b2, ..., bX, N_quad_points, d],
                    usually [N, N_quad_points, d]
                f(X) shape: [b1, b2, ..., bf, N_quad_points, d'],
                    usually [N, N_quad_points, 1] or [N, N_quad_points, d]
            In most cases, f should only operate over the last dimension of X.
        :param mean: Array/Tensor with shape [b1, b2, ..., bX, d], usually [N, d],
            representing the mean of a d-variate Gaussian distribution
        :param var: Array/Tensor with shape [b1, b2, ..., bX, d], usually [N, d],
            representing the variance of a d-variate Gaussian distribution
        :param *args: Passed to fun
        :param **kwargs: Passed to fun
        :return: Array/Tensor with shape [b1, b2, ..., bf, d'],
            usually [N, d] or [N, 1]
        """

        X, W = self._build_X_W(mean, var)
        logW = tf.math.log(W)
        if isinstance(fun, Iterable):
            return [tf.reduce_logsumexp(f(X, *args, **kwargs) + logW, axis=-2) for f in fun]
        return tf.reduce_logsumexp(fun(X, *args, **kwargs) + logW, axis=-2)
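To make the shape contract above concrete, here is a minimal usage sketch (not part of the diff, assuming the package layout introduced in this PR) against the NDiagGHQuadrature subclass added below. The targets E[X**2] = mean**2 + var and log E[exp(X)] = mean + var/2 are standard Gaussian identities:

import numpy as np
import tensorflow as tf

from gpflow.quadrature import NDiagGHQuadrature

quadrature = NDiagGHQuadrature(dim=1, n_gh=20)

mean = tf.constant([[0.5], [1.0]], dtype=tf.float64)  # [N=2, d=1]
var = tf.constant([[1.0], [2.0]], dtype=tf.float64)   # [N=2, d=1]

# E[X**2] = mean**2 + var for X ~ N(mean, var)
e_x2 = quadrature(lambda X: X ** 2, mean, var)  # shape [2, 1]
np.testing.assert_allclose(e_x2, mean ** 2 + var)

# log E[exp(X)] = mean + var / 2 (Gaussian moment-generating function)
log_mgf = quadrature.logspace(lambda X: X, mean, var)  # shape [2, 1]
np.testing.assert_allclose(log_mgf, mean + var / 2, rtol=1e-6)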
71 changes: 35 additions & 36 deletions gpflow/quadrature.py → gpflow/quadrature/deprecated.py
@@ -18,8 +18,10 @@
 import numpy as np
 import tensorflow as tf

-from .config import default_float
-from .utilities import to_default_float
+from ..config import default_float
+from ..utilities import to_default_float
+
+from .gauss_hermite import NDiagGHQuadrature

def hermgauss(n: int):
@@ -117,51 +119,48 @@ def ndiagquad(funcs, H: int, Fmu, Fvar, logspace: bool = False, **Ys):
     Fmu, Fvar, Ys should all have same shape, with overall size `N`
     :return: shape is the same as that of the first Fmu
     """
+    n_gh = H
     if isinstance(Fmu, (tuple, list)):
-        Din = len(Fmu)
-
-        def unify(f_list):
-            """Stack a list of means/vars into a full block."""
-            return tf.reshape(
-                tensor=tf.concat([tf.reshape(f, shape=(-1, 1)) for f in f_list], axis=1),
-                shape=(-1, 1, Din),
-            )
-
+        dim = len(Fmu)
         shape = tf.shape(Fmu[0])
-        Fmu, Fvar = map(unify, [Fmu, Fvar])  # both [N, 1, Din]
+        Fmu = tf.stack(Fmu, axis=-1)
+        Fvar = tf.stack(Fvar, axis=-1)
     else:
-        Din = 1
+        dim = 1
         shape = tf.shape(Fmu)
-        Fmu, Fvar = [tf.reshape(f, (-1, 1, 1)) for f in [Fmu, Fvar]]

-    xn, wn = mvhermgauss(H, Din)
-    # xn: H**Din x Din, wn: H**Din
+    Fmu = tf.reshape(Fmu, (-1, dim))
+    Fvar = tf.reshape(Fvar, (-1, dim))

-    gh_x = xn.reshape(1, -1, Din)  # [1, H]**Din x Din
-    Xall = gh_x * tf.sqrt(2.0 * Fvar) + Fmu  # [N, H]**Din x Din
-    Xs = [Xall[:, :, i] for i in range(Din)]  # [N, H]**Din each
+    Ys = {Yname: tf.reshape(Y, (-1, 1)) for Yname, Y in Ys.items()}

-    gh_w = wn * np.pi ** (-0.5 * Din)  # H**Din x 1
+    def wrapper(old_fun):
+        def new_fun(X, **Ys):
+            Xs = tf.unstack(X, axis=-1)
+            fun_eval = old_fun(*Xs, **Ys)
+            if tf.rank(fun_eval) < tf.rank(X):
+                fun_eval = tf.expand_dims(fun_eval, axis=-1)
+            return fun_eval

-    for name, Y in Ys.items():
-        Y = tf.reshape(Y, (-1, 1))
-        Y = tf.tile(Y, [1, H ** Din])  # broadcast Y to match X
-        # without the tiling, some calls such as tf.where() (in bernoulli) fail
-        Ys[name] = Y  # now [N, H]**Din
-
-    def eval_func(f):
-        feval = f(*Xs, **Ys)  # f should be elementwise: return shape [N, H]**Din
-        if logspace:
-            log_gh_w = np.log(gh_w.reshape(1, -1))
-            result = tf.reduce_logsumexp(feval + log_gh_w, axis=1)
-        else:
-            result = tf.linalg.matmul(feval, gh_w.reshape(-1, 1))
-        return tf.reshape(result, shape)
+        return new_fun

     if isinstance(funcs, Iterable):
-        return [eval_func(f) for f in funcs]
+        funcs = [wrapper(f) for f in funcs]
+    else:
+        funcs = wrapper(funcs)

-    return eval_func(funcs)
+    quadrature = NDiagGHQuadrature(dim, n_gh)
+    if logspace:
+        result = quadrature.logspace(funcs, Fmu, Fvar, **Ys)
+    else:
+        result = quadrature(funcs, Fmu, Fvar, **Ys)
+
+    if isinstance(result, list):
+        result = [tf.reshape(r, shape) for r in result]
+    else:
+        result = tf.reshape(result, shape)
+
+    return result
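The core of this change is the wrapper: old-style ndiagquad integrands take one positional argument per input dimension, whereas GaussianQuadrature passes a single X with the dimensions stacked along the last axis. A standalone sketch of the adaptation, using a hypothetical two-dimensional integrand:

import tensorflow as tf

def old_style(X1, X2):  # ndiagquad-style: one argument per input dimension
    return tf.exp(X1 + X2)  # shape [N, n_quad]

def new_style(X):  # GaussianQuadrature-style: dimensions stacked on the last axis
    Xs = tf.unstack(X, axis=-1)  # -> [X1, X2], each [N, n_quad]
    fun_eval = old_style(*Xs)
    if tf.rank(fun_eval) < tf.rank(X):
        fun_eval = tf.expand_dims(fun_eval, axis=-1)  # -> [N, n_quad, 1]
    return fun_eval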


def ndiag_mc(funcs, S: int, Fmu, Fvar, logspace: bool = False, epsilon=None, **Ys):
110 changes: 110 additions & 0 deletions gpflow/quadrature/gauss_hermite.py
@@ -0,0 +1,110 @@
from typing import List

import numpy as np
import tensorflow as tf

from .base import GaussianQuadrature
from ..config import default_float

from ..base import TensorType


def gh_points_and_weights(n_gh: int):
    r"""
    Given the number of Gauss-Hermite points n_gh,
    returns the points z and the weights dz to perform the following
    uni-dimensional Gaussian quadrature:

        X ~ N(mean, stddev²)
        E[f(X)] = ∫f(x)p(x)dx ≈ sum_{i=1}^{n_gh} f(mean + stddev*z_i)*dz_i

    :param n_gh: Number of Gauss-Hermite points, integer
    :returns: Points z and weights dz, both tensors with shape [n_gh],
        to compute the uni-dimensional Gaussian expectation
    """
    z, dz = np.polynomial.hermite.hermgauss(n_gh)
    z = z * np.sqrt(2)
    dz = dz / np.sqrt(np.pi)
    z, dz = z.astype(default_float()), dz.astype(default_float())
    return tf.convert_to_tensor(z), tf.convert_to_tensor(dz)
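A quick NumPy-only sanity check of this rescaling (a sketch, not part of the diff): after multiplying z by √2 and dividing dz by √π, the weights sum to one and the rule reproduces standard-normal moments exactly for low-order polynomials:

import numpy as np

z, dz = np.polynomial.hermite.hermgauss(10)
z = z * np.sqrt(2)        # abscissae for a standard normal N(0, 1)
dz = dz / np.sqrt(np.pi)  # weights now sum to 1

np.testing.assert_allclose(dz.sum(), 1.0)             # E[1] = 1
np.testing.assert_allclose((z ** 2 * dz).sum(), 1.0)  # E[Z**2] = 1

# Shifting and scaling gives a rule for X ~ N(mean, stddev**2):
mean, stddev = 2.0, 0.5
x = mean + stddev * z
np.testing.assert_allclose((x * dz).sum(), mean)      # E[X] = mean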


def list_to_flat_grid(xs: List[TensorType]):
    """
    :param xs: List with d rank-1 Tensors, with shapes N1, N2, ..., Nd
    :return: Tensor with shape [N1*N2*...*Nd, d] representing the flattened
        d-dimensional grid built from the input tensors xs
    """
    return tf.reshape(tf.stack(tf.meshgrid(*xs), axis=-1), (-1, len(xs)))
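For intuition, a small sketch (not part of the diff) of what this helper produces; the row order follows tf.meshgrid's default 'xy' indexing, which is irrelevant for quadrature because the weights are flattened by the same function:

import tensorflow as tf

xs = [tf.constant([1.0, 2.0]), tf.constant([10.0, 20.0, 30.0])]
grid = tf.reshape(tf.stack(tf.meshgrid(*xs), axis=-1), (-1, len(xs)))
print(grid.shape)  # (6, 2): every pairing of an element of xs[0] with one of xs[1]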


def reshape_Z_dZ(zs: List[TensorType], dzs: List[TensorType]):
    """
    :param zs: List with d rank-1 Tensors, with shapes N1, N2, ..., Nd
    :param dzs: List with d rank-1 Tensors, with shapes N1, N2, ..., Nd
    :returns: points Z, Tensor with shape [n_gh**dim, dim],
        and weights dZ, Tensor with shape [n_gh**dim, 1]
    """
    Z = list_to_flat_grid(zs)
    dZ = tf.reduce_prod(list_to_flat_grid(dzs), axis=-1, keepdims=True)
    return Z, dZ


def repeat_as_list(x: TensorType, n: int):
    """
    :param x: Array/Tensor to be repeated
    :param n: Integer with the number of repetitions
    :return: List of n repetitions of Tensor x
    """
    return tf.unstack(tf.repeat(tf.expand_dims(x, axis=0), n, axis=0), axis=0)


def ndgh_points_and_weights(dim: int, n_gh: int):
    r"""
    :param dim: dimension of the multivariate normal, integer
    :param n_gh: number of Gauss-Hermite points per dimension, integer
    :returns: points Z, Tensor with shape [n_gh**dim, dim],
        and weights dZ, Tensor with shape [n_gh**dim, 1]
    """
    z, dz = gh_points_and_weights(n_gh)
    zs = repeat_as_list(z, dim)
    dzs = repeat_as_list(dz, dim)
    return reshape_Z_dZ(zs, dzs)


class NDiagGHQuadrature(GaussianQuadrature):
    def __init__(self, dim: int, n_gh: int):
        """
        :param dim: dimension of the multivariate normal, integer
        :param n_gh: number of Gauss-Hermite points per dimension, integer
        """
        Z, dZ = ndgh_points_and_weights(dim, n_gh)
        self.n_gh_total = n_gh ** dim
        self.Z = tf.convert_to_tensor(Z)
        self.dZ = tf.convert_to_tensor(dZ)
        # Z: [n_gh_total, dim]
        # dZ: [n_gh_total, 1]

    def _build_X_W(self, mean: TensorType, var: TensorType):
        """
        :param mean: Array/Tensor with shape [b1, b2, ..., bX, dim], usually [N, dim],
            representing the mean of a dim-variate Gaussian distribution
        :param var: Array/Tensor with shape [b1, b2, ..., bX, dim], usually [N, dim],
            representing the variance of a dim-variate Gaussian distribution
        :return: points X, Tensor with shape [b1, b2, ..., bX, n_gh_total, dim],
            usually [N, n_gh_total, dim],
            and weights W, a Tensor with shape [n_gh_total, 1],
            which broadcasts against X along the batch dimensions
        """

        # mean, stddev: [b1, b2, ..., bX, dim], usually [N, dim]
        mean = tf.expand_dims(mean, -2)
        stddev = tf.expand_dims(tf.sqrt(var), -2)
        # mean, stddev: [b1, b2, ..., bX, 1, dim], usually [N, 1, dim]

        X = mean + stddev * self.Z
        W = self.dZ
        # X: [b1, b2, ..., bX, n_gh_total, dim], usually [N, n_gh_total, dim]
        # W: [n_gh_total, 1], broadcasting against X

        return X, W
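A shape-focused sketch of the broadcasting behaviour documented above (not part of the diff; it calls the private _build_X_W purely for illustration):

import tensorflow as tf

from gpflow.quadrature import NDiagGHQuadrature

quad = NDiagGHQuadrature(dim=2, n_gh=5)  # n_gh_total = 5 ** 2 = 25

mean = tf.zeros((3, 4, 2), dtype=tf.float64)  # batch dims [3, 4], dim = 2
var = tf.ones((3, 4, 2), dtype=tf.float64)

X, W = quad._build_X_W(mean, var)
print(X.shape)  # (3, 4, 25, 2)
print(W.shape)  # (25, 1) -- broadcasts against X along the batch dimensions

E_X = quad(lambda X: X, mean, var)  # reduces over the n_gh_total axis
print(E_X.shape)  # (3, 4, 2) -- recovers the mean, since E[X] = mean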
73 changes: 73 additions & 0 deletions tests/gpflow/quadrature/ndiagquad_old.py
@@ -0,0 +1,73 @@
from collections.abc import Iterable

import numpy as np
import tensorflow as tf

from gpflow.quadrature.deprecated import mvhermgauss


def ndiagquad(funcs, H: int, Fmu, Fvar, logspace: bool = False, **Ys):
    """
    Computes N Gaussian expectation integrals of one or more functions
    using Gauss-Hermite quadrature. The Gaussians must be independent.
    The means and variances of the Gaussians are specified by Fmu and Fvar.
    The N-integrals are assumed to be taken wrt the last dimensions of Fmu, Fvar.

    :param funcs: the integrand(s):
        Callable or Iterable of Callables that operates elementwise
    :param H: number of Gauss-Hermite quadrature points
    :param Fmu: array/tensor or `Din`-tuple/list thereof
    :param Fvar: array/tensor or `Din`-tuple/list thereof
    :param logspace: if True, funcs are the log-integrands and this calculates
        the log-expectation of exp(funcs)
    :param **Ys: arrays/tensors; deterministic arguments to be passed by name
    Fmu, Fvar, Ys should all have same shape, with overall size `N`
    :return: shape is the same as that of the first Fmu
    """
    if isinstance(Fmu, (tuple, list)):
        Din = len(Fmu)

        def unify(f_list):
            """Stack a list of means/vars into a full block."""
            return tf.reshape(
                tensor=tf.concat([tf.reshape(f, shape=(-1, 1)) for f in f_list], axis=1),
                shape=(-1, 1, Din),
            )

        shape = tf.shape(Fmu[0])
        Fmu, Fvar = map(unify, [Fmu, Fvar])  # both [N, 1, Din]
    else:
        Din = 1
        shape = tf.shape(Fmu)
        Fmu, Fvar = [tf.reshape(f, (-1, 1, 1)) for f in [Fmu, Fvar]]

    xn, wn = mvhermgauss(H, Din)
    # xn: H**Din x Din, wn: H**Din

    gh_x = xn.reshape(1, -1, Din)  # [1, H]**Din x Din
    Xall = gh_x * tf.sqrt(2.0 * Fvar) + Fmu  # [N, H]**Din x Din
    Xs = [Xall[:, :, i] for i in range(Din)]  # [N, H]**Din each

    gh_w = wn * np.pi ** (-0.5 * Din)  # H**Din x 1

    for name, Y in Ys.items():
        Y = tf.reshape(Y, (-1, 1))
        Y = tf.tile(Y, [1, H ** Din])  # broadcast Y to match X
        # without the tiling, some calls such as tf.where() (in bernoulli) fail
        Ys[name] = Y  # now [N, H]**Din

    def eval_func(f):
        feval = f(*Xs, **Ys)  # f should be elementwise: return shape [N, H]**Din
        if logspace:
            log_gh_w = np.log(gh_w.reshape(1, -1))
            result = tf.reduce_logsumexp(feval + log_gh_w, axis=1)
        else:
            result = tf.linalg.matmul(feval, gh_w.reshape(-1, 1))
        return tf.reshape(result, shape)

    if isinstance(funcs, Iterable):
        return [eval_func(f) for f in funcs]

    return eval_func(funcs)
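The equivalence tests added in this PR can then compare this snapshot of the old implementation against the refactored ndiagquad; a sketch (hypothetical test code, assuming this module is importable as ndiagquad_old):

import numpy as np
import tensorflow as tf

from gpflow.quadrature import ndiagquad
from ndiagquad_old import ndiagquad as ndiagquad_old  # hypothetical import path

Fmu = tf.constant([[0.0], [1.0]], dtype=tf.float64)
Fvar = tf.constant([[1.0], [0.5]], dtype=tf.float64)

new = ndiagquad(tf.exp, 20, Fmu, Fvar)
old = ndiagquad_old(tf.exp, 20, Fmu, Fvar)
np.testing.assert_allclose(new, old)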