diff --git a/gpflow/conditionals/util.py b/gpflow/conditionals/util.py
index 7d54739d7..0a6ee4ecf 100644
--- a/gpflow/conditionals/util.py
+++ b/gpflow/conditionals/util.py
@@ -38,12 +38,34 @@ def base_conditional(
     :param white: bool
     :return: [N, R] or [R, N, N]
     """
+    Lm = tf.linalg.cholesky(Kmm)
+    return base_conditional_with_lm(
+        Kmn=Kmn, Lm=Lm, Knn=Knn, f=f, full_cov=full_cov, q_sqrt=q_sqrt, white=white
+    )
+
+
+def base_conditional_with_lm(
+    Kmn: tf.Tensor,
+    Lm: tf.Tensor,
+    Knn: tf.Tensor,
+    f: tf.Tensor,
+    *,
+    full_cov=False,
+    q_sqrt: Optional[tf.Tensor] = None,
+    white=False,
+):
+    r"""
+    Has the same functionality as the `base_conditional` function, except that instead of
+    `Kmm` this function accepts `Lm`, which is the Cholesky decomposition of `Kmm`.
+
+    This allows `Lm` to be precomputed, which can improve performance.
+    """
     # compute kernel stuff
     num_func = tf.shape(f)[-1]  # R
     N = tf.shape(Kmn)[-1]
     M = tf.shape(f)[-2]
 
-    # get the leadings dims in Kmn to the front of the tensor
+    # get the leading dims in Kmn to the front of the tensor
     # if Kmn has rank two, i.e. [M, N], this is the identity op.
     K = tf.rank(Kmn)
     perm = tf.concat(
@@ -58,7 +80,7 @@ def base_conditional(
 
     shape_constraints = [
         (Kmn, [..., "M", "N"]),
-        (Kmm, ["M", "M"]),
+        (Lm, ["M", "M"]),
         (Knn, [..., "N", "N"] if full_cov else [..., "N"]),
         (f, ["M", "R"]),
     ]
@@ -75,7 +97,6 @@ def base_conditional(
     )
 
     leading_dims = tf.shape(Kmn)[:-2]
-    Lm = tf.linalg.cholesky(Kmm)  # [M, M]
 
     # Compute the projection matrix A
     Lm = tf.broadcast_to(Lm, tf.concat([leading_dims, tf.shape(Lm)], 0))  # [..., M, M]
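
Usage note (illustrative, not part of the patch): the point of `base_conditional_with_lm` is that the Cholesky factor of `Kmm` can be computed once and reused across many conditional calls, instead of being re-factorised inside `base_conditional` on every call. The sketch below assumes GPflow with this patch applied; the toy RBF kernel, variable names, and sizes are all hypothetical stand-ins.

import numpy as np
import tensorflow as tf

from gpflow.conditionals.util import base_conditional_with_lm

M, N, R = 5, 100, 1  # inducing points, test points, output dimensions
rng = np.random.default_rng(0)
Z = rng.standard_normal((M, 1))  # toy inducing inputs
X = rng.standard_normal((N, 1))  # toy test inputs
f = rng.standard_normal((M, R))  # toy function values at Z


def rbf(a, b):
    # Toy squared-exponential kernel matrix between two point sets.
    sq_dist = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-0.5 * sq_dist)


Kmm = rbf(Z, Z) + 1e-6 * np.eye(M)  # jitter keeps the Cholesky stable
Kmn = rbf(Z, X)                     # [M, N]
Knn = np.ones(N)                    # RBF diagonal is 1

# Factorise Kmm once ...
Lm = tf.linalg.cholesky(tf.constant(Kmm))

# ... then reuse the factor across as many conditional calls as needed.
mean, var = base_conditional_with_lm(
    tf.constant(Kmn), Lm, tf.constant(Knn), tf.constant(f), full_cov=False
)
print(mean.shape, var.shape)  # (100, 1) (100, 1)

`base_conditional(Kmn, Kmm, Knn, f, ...)` remains the drop-in entry point: it now just factorises `Kmm` and delegates to `base_conditional_with_lm`, so existing callers are unaffected.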