Transfer likelihood tests
awav committed Mar 12, 2019
1 parent 1ee9146 commit b393722
Showing 17 changed files with 453 additions and 589 deletions.
2 changes: 1 addition & 1 deletion gpflow/base.py
@@ -124,7 +124,7 @@ def op(self):
@property
def shape(self):
if self.transform is not None:
- return self.transform.forward_event_shape(self._unconstrained)
+ return self.transform.forward_event_shape(self._unconstrained.shape)
return self._unconstrained.shape

def get_shape(self):
4 changes: 2 additions & 2 deletions gpflow/conditionals/util.py
@@ -104,7 +104,7 @@ def sample_mvn(mean, cov, cov_structure):
- "full": cov holds the full covariance matrix (without jitter)
:return: sample from the MVN of shape N x D
"""
- eps = tf.random_normal(mean.shape, dtype=mean.dtype) # N x P
+ eps = tf.random.normal(mean.shape, dtype=mean.dtype) # N x P
if cov_structure == "diag":
sample = mean + tf.sqrt(cov) * eps # N x P
elif cov_structure == "full":
@@ -313,4 +313,4 @@ def fully_correlated_conditional_repeat(Kmn, Kmm, Knn, f, *, full_cov=False, ful
elif not full_cov and not full_output_cov:
addvar = tf.reshape(tf.reduce_sum(LTA ** 2, axis=1), (R, N, K)) # [R, N, K]
fvar = fvar[None, ...] + addvar # [R, N, K]
- return fmean, fvar
+ return fmean, fvar
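
For context on the sample_mvn hunk above: the "diag" branch scales independent noise element-wise, while the "full" branch (outside the visible hunk) draws correlated samples through a Cholesky factor of the covariance. A minimal TF 2.x sketch of that pattern, assuming mean is [N, D] and cov is [N, D] ("diag") or [N, D, D] ("full"); this is an illustrative sketch, not the GPflow source:

    import tensorflow as tf

    def sample_mvn_sketch(mean, cov, cov_structure):
        eps = tf.random.normal(tf.shape(mean), dtype=mean.dtype)   # standard normal noise, [N, D]
        if cov_structure == "diag":
            return mean + tf.sqrt(cov) * eps                       # independent per-dimension scaling
        if cov_structure == "full":
            chol = tf.linalg.cholesky(cov)                         # lower-triangular factor, [N, D, D]
            return mean + tf.linalg.matvec(chol, eps)              # correlated sample
        raise NotImplementedError(cov_structure)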
3 changes: 2 additions & 1 deletion gpflow/covariances/kuus.py
@@ -1,4 +1,5 @@
import tensorflow as tf

from ..features import InducingPoints, Multiscale
from ..kernels import Kernel, RBF
from .dispatch import Kuu
@@ -13,7 +14,7 @@ def _Kuu(feat: InducingPoints, kern: Kernel, *, jitter=0.0):

@Kuu.register(Multiscale, RBF)
def _Kuu(feat: Multiscale, kern: RBF, *, jitter=0.0):
- Zmu, Zlen = kern.slice(feat.Z, feat.scales())
+ Zmu, Zlen = kern.slice(feat.Z, feat.scales)
idlengthscales2 = tf.square(kern.lengthscales + Zlen)
sc = tf.sqrt(idlengthscales2[None, ...] + idlengthscales2[:, None, ...]
- kern.lengthscales ** 2)
5 changes: 2 additions & 3 deletions gpflow/features/__init__.py
@@ -1,4 +1,3 @@
- from .features import (InducingFeature, InducingPoints, InducingPointsBase,
- Multiscale)
+ from .features import InducingFeature, InducingPoints, Multiscale, InducingPointsBase
from .mo_features import (MixedKernelSharedMof, Mof, SeparateIndependentMof,
- SharedIndependentMof)
+ SharedIndependentMof)
1 change: 1 addition & 0 deletions gpflow/kernels/base.py
@@ -76,6 +76,7 @@ def slice(self, X: tf.Tensor, Y: Optional[tf.Tensor] = None):
X = X[..., dims]
Y = Y[..., dims] if Y is not None else X
elif dims is not None:
+ # TODO(@awav): Convert when TF2.0 will support proper slicing.
X = tf.gather(X, dims, axis=-1)
Y = tf.gather(Y, dims, axis=-1) if Y is not None else X
return X, Y
2 changes: 1 addition & 1 deletion gpflow/kullback_leiblers.py
@@ -60,7 +60,7 @@ def gauss_kl(q_mu, q_sqrt, K=None):
Lq = Lq_diag = q_sqrt
Lq_full = tf.linalg.diag(tf.transpose(q_sqrt)) # [B, M, M]
else:
- Lq = Lq_full = tf.matrix_band_part(q_sqrt, -1, 0) # force lower triangle # [B, M, M]
+ Lq = Lq_full = tf.linalg.band_part(q_sqrt, -1, 0) # force lower triangle # [B, M, M]
Lq_diag = tf.linalg.diag_part(Lq) # M x B

# Mahalanobis term: μqᵀ Σp⁻¹ μq
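
For reference, gauss_kl evaluates the standard closed-form KL divergence between q = N(μq, Σq), with Σq = Lq Lqᵀ, and the prior p = N(0, K) (the identity is used when K is not passed); this is textbook Gaussian algebra quoted for orientation, not code from this commit:

    KL[q ‖ p] = ½ [ tr(K⁻¹ Σq) + μqᵀ K⁻¹ μq − M + log|K| − log|Σq| ]

The Mahalanobis term flagged in the comment above is the μqᵀ K⁻¹ μq piece, and M is the number of inducing points.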
2 changes: 1 addition & 1 deletion gpflow/likelihoods/__init__.py
@@ -1,4 +1,4 @@
from .likelihoods import (Bernoulli, Beta, Exponential, Gamma, Gaussian,
GaussianMC, Likelihood, MonteCarloLikelihood,
MultiClass, Ordinal, Poisson, Softmax,
- SwitchedLikelihood)
+ SwitchedLikelihood, StudentT, RobustMax)
26 changes: 13 additions & 13 deletions gpflow/likelihoods/likelihoods.py
@@ -27,7 +27,7 @@

def inv_probit(x):
jitter = 1e-3 # ensures output is strictly between 0 and 1
- return 0.5 * (1.0 + tf.erf(x / np.sqrt(2.0))) * (1 - 2 * jitter) + jitter
+ return 0.5 * (1.0 + tf.math.erf(x / np.sqrt(2.0))) * (1 - 2 * jitter) + jitter



@@ -172,7 +172,7 @@ def conditional_mean(self, F):
def variational_expectations(self, Fmu, Fvar, Y):
if self.invlink is tf.exp:
return Y * Fmu - tf.exp(Fmu + Fvar / 2) * self.binsize \
- - tf.lgamma(Y + 1) + Y * tf.math.log(self.binsize)
+ - tf.math.lgamma(Y + 1) + Y * tf.math.log(self.binsize)
return super(Poisson, self).variational_expectations(Fmu, Fvar, Y)


@@ -267,7 +267,8 @@ def conditional_variance(self, F):

def variational_expectations(self, Fmu, Fvar, Y):
if self.invlink is tf.exp:
- return -self.shape * Fmu - tf.lgamma(self.shape) + (self.shape - 1.) * tf.math.log(Y) - Y * tf.exp(-Fmu + Fvar / 2.)
+ return -self.shape * Fmu - tf.math.lgamma(self.shape) + (self.shape - 1.) * tf.math.log(
+ Y) - Y * tf.exp(-Fmu + Fvar / 2.)
else:
return super().variational_expectations(Fmu, Fvar, Y)

@@ -328,15 +329,15 @@ def __init__(self, num_classes, invlink=None, **kwargs):

def logp(self, F, Y):
hits = tf.equal(tf.expand_dims(tf.argmax(F, 1), 1), tf.cast(Y, tf.int64))
- yes = tf.ones(Y.shape, dtype=default_float()) - self.invlink.epsilon()
- no = tf.zeros(Y.shape, dtype=default_float()) + self.invlink._eps_K1
+ yes = tf.ones(Y.shape, dtype=default_float()) - self.invlink.epsilon
+ no = tf.zeros(Y.shape, dtype=default_float()) + self.invlink.eps_k1
p = tf.where(hits, yes, no)
return tf.math.log(p)

def variational_expectations(self, Fmu, Fvar, Y):
gh_x, gh_w = hermgauss(self.num_gauss_hermite_points)
p = self.invlink.prob_is_largest(Y, Fmu, Fvar, gh_x, gh_w)
- ve = p * tf.math.log(1. - self.invlink.epsilon()) + (1. - p) * tf.math.log(self.invlink._eps_K1)
+ ve = p * tf.math.log(1. - self.invlink.epsilon) + (1. - p) * tf.math.log(self.invlink.eps_k1)
return ve

def predict_mean_and_var(self, Fmu, Fvar):
@@ -352,7 +353,7 @@ def predict_density(self, Fmu, Fvar, Y):
def _predict_non_logged_density(self, Fmu, Fvar, Y):
gh_x, gh_w = hermgauss(self.num_gauss_hermite_points)
p = self.invlink.prob_is_largest(Y, Fmu, Fvar, gh_x, gh_w)
- den = p * (1. - self.invlink.epsilon()) + (1. - p) * (self.invlink._eps_K1)
+ den = p * (1. - self.invlink.epsilon) + (1. - p) * (self.invlink.eps_k1)
return den

def conditional_mean(self, F):
@@ -394,8 +395,7 @@ def _partition_and_stitch(self, args, func_name):
args = zip(*[tf.dynamic_partition(X, ind, len(self.likelihoods)) for X in args])

# apply the likelihood-function to each section of the data
- with params_as_tensors_for(self, convert=False):
- funcs = [getattr(lik, func_name) for lik in self.likelihoods]
+ funcs = [getattr(lik, func_name) for lik in self.likelihoods]
results = [f(*args_i) for f, args_i in zip(funcs, args)]

# stitch the results back together
@@ -478,10 +478,10 @@ def _make_phi(self, F):
Note that a matrix of F values is flattened.
"""
- scaled_bins_left = tf.concat([self.bin_edges / self.sigma(), np.array([np.inf])], 0)
- scaled_bins_right = tf.concat([np.array([-np.inf]), self.bin_edges / self.sigma()], 0)
- return inv_probit(scaled_bins_left - tf.reshape(F, (-1, 1)) / self.sigma()) \
- - inv_probit(scaled_bins_right - tf.reshape(F, (-1, 1)) / self.sigma())
+ scaled_bins_left = tf.concat([self.bin_edges / self.sigma, np.array([np.inf])], 0)
+ scaled_bins_right = tf.concat([np.array([-np.inf]), self.bin_edges / self.sigma], 0)
+ return inv_probit(scaled_bins_left - tf.reshape(F, (-1, 1)) / self.sigma) \
+ - inv_probit(scaled_bins_right - tf.reshape(F, (-1, 1)) / self.sigma)

def conditional_mean(self, F):
phi = self._make_phi(F)
12 changes: 6 additions & 6 deletions gpflow/likelihoods/robustmax.py
@@ -21,18 +21,18 @@ class RobustMax(tf.Module):

def __init__(self, num_classes, epsilon=1e-3, **kwargs):
super().__init__(**kwargs)
- transform = tfp.bijectors.Logistic()
+ transform = tfp.bijectors.Sigmoid()
prior = tfp.distributions.Beta(0.2, 5.)
self.epsilon = Parameter(epsilon, transform=transform, prior=prior, trainable=False)
self.num_classes = num_classes

def __call__(self, F):
i = tf.argmax(F, 1)
- return tf.one_hot(i, self.num_classes, tf.squeeze(1. - self.epsilon()), tf.squeeze(self._eps_K1))
+ return tf.one_hot(i, self.num_classes, tf.squeeze(1. - self.epsilon), tf.squeeze(self.eps_k1))

@property
- def _eps_K1(self):
- return self.epsilon() / (self.num_classes - 1.)
+ def eps_k1(self):
+ return self.epsilon / (self.num_classes - 1.)

def prob_is_largest(self, Y, mu, var, gh_x, gh_w):
Y = tf.cast(Y, default_int())
@@ -48,7 +48,7 @@ def prob_is_largest(self, Y, mu, var, gh_x, gh_w):
# compute the CDF of the Gaussian between the latent functions and the grid (including the selected function)
dist = (tf.expand_dims(X, 1) - tf.expand_dims(mu, 2)) / tf.expand_dims(
tf.sqrt(tf.clip_by_value(var, 1e-10, np.inf)), 2)
- cdfs = 0.5 * (1.0 + tf.erf(dist / np.sqrt(2.0)))
+ cdfs = 0.5 * (1.0 + tf.math.erf(dist / np.sqrt(2.0)))

cdfs = cdfs * (1 - 2e-4) + 1e-4

@@ -57,4 +57,4 @@ def prob_is_largest(self, Y, mu, var, gh_x, gh_w):
cdfs = cdfs * tf.expand_dims(oh_off, 2) + tf.expand_dims(oh_on, 2)

# take the product over the latent functions, and the sum over the GH grid.
- return tf.reduce_prod(cdfs, reduction_indices=[1]) @ tf.reshape(gh_w / np.sqrt(np.pi), (-1, 1))
+ return tf.reduce_prod(cdfs, axis=[1]) @ tf.reshape(gh_w / np.sqrt(np.pi), (-1, 1))
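
The epsilon / eps_k1 rename above reads more easily with the RobustMax link written out: the arg-max class gets probability 1 − ε and each of the other K − 1 classes gets ε / (K − 1), so the class probabilities sum to one. A tiny NumPy check of that bookkeeping, with hypothetical values, not part of the commit:

    import numpy as np

    num_classes = 3
    epsilon = 1e-3
    f = np.array([0.2, 1.5, -0.3])                        # latent function values for one input
    probs = np.full(num_classes, epsilon / (num_classes - 1))
    probs[np.argmax(f)] = 1.0 - epsilon
    assert np.isclose(probs.sum(), 1.0)                   # (1 - eps) + (K - 1) * eps / (K - 1) == 1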
12 changes: 6 additions & 6 deletions gpflow/logdensities.py
@@ -35,21 +35,21 @@ def bernoulli(x, p):


def poisson(x, lam):
- return x * tf.math.log(lam) - lam - tf.lgamma(x + 1.)
+ return x * tf.math.log(lam) - lam - tf.math.lgamma(x + 1.)


def exponential(x, scale):
return - x/scale - tf.math.log(scale)


def gamma(x, shape, scale):
- return -shape * tf.math.log(scale) - tf.lgamma(shape) \
+ return -shape * tf.math.log(scale) - tf.math.lgamma(shape) \
+ (shape - 1.) * tf.math.log(x) - x / scale


def student_t(x, mean, scale, df):
df = tf.cast(df, default_float())
- const = tf.lgamma((df + 1.) * 0.5) - tf.lgamma(df * 0.5) \
+ const = tf.math.lgamma((df + 1.) * 0.5) - tf.math.lgamma(df * 0.5) \
- 0.5 * (tf.math.log(tf.square(scale)) + tf.math.log(df) + np.log(np.pi))
const = tf.cast(const, default_float())
return const - 0.5 * (df + 1.) * \
@@ -60,9 +60,9 @@ def beta(x, alpha, beta):
# need to clip x, since log of 0 is nan...
x = tf.clip_by_value(x, 1e-6, 1-1e-6)
return (alpha - 1.) * tf.math.log(x) + (beta - 1.) * tf.math.log(1. - x) \
- + tf.lgamma(alpha + beta)\
- - tf.lgamma(alpha)\
- - tf.lgamma(beta)
+ + tf.math.lgamma(alpha + beta)\
+ - tf.math.lgamma(alpha)\
+ - tf.math.lgamma(beta)


def laplace(x, mu, sigma):
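
As a cross-check on the tf.math.lgamma renames, student_t above computes the standard location-scale Student-t log-density, quoted here for reference only (standard formula, not new code):

    log p(x) = lgamma((ν + 1)/2) − lgamma(ν/2) − ½ [log σ² + log ν + log π] − ((ν + 1)/2) · log(1 + (x − μ)² / (ν σ²))

with μ = mean, σ = scale and ν = df; everything up to the final log term is the const assembled above.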
2 changes: 1 addition & 1 deletion gpflow/models/model.py
@@ -112,7 +112,7 @@ def predict_f_samples(self, X, num_samples, full_cov=False, full_output_cov=Fals
for i in range(self.num_latent):
L = tf.linalg.cholesky(var[i, :, :] + jitter)
shape = tf.stack([L.shape[0], num_samples])
- V = tf.random_normal(shape, dtype=L.dtype, seed=self.seed)
+ V = tf.random.normal(shape, dtype=L.dtype, seed=self.seed)
samples[i] = mu[:, i:(i+1)] + L @ V
return tf.linalg.transpose(tf.stack(samples))

6 changes: 3 additions & 3 deletions gpflow/models/svgp.py
@@ -131,7 +131,7 @@ def prior_kl(self):
K = None
if not self.whiten:
K = Kuu(self.feature, self.kernel, jitter=default_jitter()) # [P, M, M] or [M, M]
- return kullback_leiblers.gauss_kl(self.q_mu(), self.q_sqrt(), K)
+ return kullback_leiblers.gauss_kl(self.q_mu, self.q_sqrt, K)

def log_likelihood(self, X: tf.Tensor, Y: tf.Tensor) -> tf.Tensor:
"""
@@ -152,8 +152,8 @@ def elbo(self, X: tf.Tensor, Y: tf.Tensor) -> tf.Tensor:
return self.neg_log_marginal_likelihood(X, Y)

def predict_f(self, Xnew, full_cov=False, full_output_cov=False) -> tf.Tensor:
- q_mu = self.q_mu()
- q_sqrt = self.q_sqrt()
+ q_mu = self.q_mu
+ q_sqrt = self.q_sqrt
mu, var = conditional(Xnew, self.feature, self.kernel, q_mu, q_sqrt=q_sqrt,
full_cov=full_cov, white=self.whiten,
full_output_cov=full_output_cov)
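
For orientation (the standard sparse variational GP bound, not text from the diff): prior_kl and the expected log-likelihood above are the two halves of the objective

    ELBO = Σₙ E_{q(fₙ)} [ log p(yₙ | fₙ) ] − KL[ q(u) ‖ p(u) ]

so switching q_mu / q_sqrt from calls to plain attribute access only changes how the variational parameters are read, not the bound itself.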
4 changes: 2 additions & 2 deletions gpflow/models/util.py
@@ -10,5 +10,5 @@ def inducingpoint_wrapper(feature):
for the methods.
"""
if isinstance(feature, np.ndarray):
- feat = InducingPoints(feature)
- return feat
+ feature = InducingPoints(feature)
+ return feature
2 changes: 1 addition & 1 deletion gpflow/optimizers/hmc.py
@@ -269,7 +269,7 @@ def _copy_variables(variables):


def _init_ps(xs):
- return _map(lambda x: tf.random_normal(x.shape, dtype=x.dtype.as_numpy_dtype), xs)
+ return _map(lambda x: tf.random.normal(x.shape, dtype=x.dtype.as_numpy_dtype), xs)


def _update_ps(ps, grads, epsilon, coeff=1):
2 changes: 1 addition & 1 deletion gpflow/quadrature.py
@@ -175,7 +175,7 @@ def ndiag_mc(funcs, S: int, Fmu, Fvar, logspace: bool=False, epsilon=None, **Ys)
N, D = Fmu.shape[0], Fvar.shape[1]

if epsilon is None:
- epsilon = tf.random_normal((S, N, D), dtype=default_float())
+ epsilon = tf.random.normal((S, N, D), dtype=default_float())

mc_x = Fmu[None, :, :] + tf.sqrt(Fvar[None, :, :]) * epsilon
mc_Xr = tf.reshape(mc_x, (S * N, D))
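
The ndiag_mc hunk relies on the usual reparameterisation of a diagonal Gaussian: draw ε ~ N(0, I) and evaluate the integrand at Fmu + √Fvar · ε. A minimal standalone sketch of that estimator, assuming Fmu and Fvar are [N, D] tensors with static shapes; this is illustrative code, not the GPflow helper itself:

    import tensorflow as tf

    def mc_expectation_sketch(func, S, Fmu, Fvar):
        # Monte Carlo estimate of E_{N(Fmu, Fvar)}[func(x)] using S reparameterised samples.
        eps = tf.random.normal((S,) + tuple(Fmu.shape), dtype=Fmu.dtype)   # [S, N, D]
        x = Fmu[None, ...] + tf.sqrt(Fvar[None, ...]) * eps                # [S, N, D]
        return tf.reduce_mean(func(x), axis=0)                             # average over the S draws

For example, mc_expectation_sketch(tf.math.square, 1000, Fmu, Fvar) approaches Fmu² + Fvar as the sample count grows.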
