Skip to content

Commit

Permalink
fix feature map (#326)
Browse files Browse the repository at this point in the history
  • Loading branch information
frazane committed Jun 24, 2023
1 parent b7cc06c commit e98050e
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions gpjax/kernels/computations/basis_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def cross_covariance(
"""
z1 = self.compute_features(x)
z2 = self.compute_features(y)
z1 /= self.kernel.num_basis_fns
return self.kernel.base_kernel.variance * jnp.matmul(z1, z2.T)
return self.scaling * jnp.matmul(z1, z2.T)

def gram(self, inputs: Float[Array, "N D"]) -> DenseLinearOperator:
r"""Compute an approximate Gram matrix.
Expand All @@ -47,9 +46,7 @@ def gram(self, inputs: Float[Array, "N D"]) -> DenseLinearOperator:
$`N \times N`$ Gram matrix.
"""
z1 = self.compute_features(inputs)
matrix = jnp.matmul(z1, z1.T) # shape: (n_samples, n_samples)
matrix /= self.kernel.num_basis_fns
return DenseLinearOperator(self.kernel.base_kernel.variance * matrix)
return DenseLinearOperator(self.scaling * jnp.matmul(z1, z1.T))

def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]:
r"""Compute the features for the inputs.
Expand All @@ -66,3 +63,7 @@ def compute_features(self, x: Float[Array, "N D"]) -> Float[Array, "N L"]:
z = jnp.matmul(x, (frequencies / scaling_factor).T)
z = jnp.concatenate([jnp.cos(z), jnp.sin(z)], axis=-1)
return z

@property
def scaling(self):
return self.kernel.base_kernel.variance / self.kernel.num_basis_fns

0 comments on commit e98050e

Please sign in to comment.