Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
st-- committed Nov 24, 2017
1 parent 1a886fc commit 6a5e6a0
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
6 changes: 3 additions & 3 deletions gpflow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,12 +275,12 @@ def forward(self, x):

def backward(self, y):
# Return diagonals of matrices
if not (y.shape[1] == y.shape[2] == self.dim) and (len(y.shape) == 3):
if len(y.shape) not in (2, 3) or not (y.shape[-1] == y.shape[-2] == self.dim):
raise ValueError("shape of input does not match this transform")
return y.diagonal(offset=0, axis1=1, axis2=2).flatten()
return y.reshape((-1, self.dim, self.dim)).diagonal(offset=0, axis1=1, axis2=2).flatten()

def forward_tensor(self, x):
# create diagonal; matrices
# create diagonal matrices
return tf.matrix_diag(tf.reshape(x, (-1, self.dim)))

def log_jacobian_tensor(self, x):
Expand Down
15 changes: 12 additions & 3 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,16 +186,25 @@ def testDivByZero(self):
self.assertFalse(np.any(np.isnan(y)))


class TestLowerTriTransform(GPflowTestCase):
class TestMatrixTransforms(GPflowTestCase):
"""
Some extra tests for the LowerTriangle transformation.
Some extra tests for the matrix transformations.
"""
def testErrors(self):
def test_LowerTriangular(self):
t = gpflow.transforms.LowerTriangular(1, 3)
t.forward(np.ones(3 * 6))
with self.assertRaises(ValueError):
t.forward(np.ones(3 * 7))

def test_DiagMatrix(self):
t = gpflow.transforms.DiagMatrix(3)
t.backward(np.eye(3))
t.backward(np.eye(3)[None, :, :])
t.backward(np.eye(3)[None, :, :] * np.array([1, 2])[:, None, None])
with self.assertRaises(ValueError):
t.backward(np.eye(4))
t.backward(np.eye(2)[None, :, :] * np.array([1, 2, 3])[:, None, None])


class TestDiagMatrixTransform(GPflowTestCase):
def setUp(self):
Expand Down

0 comments on commit 6a5e6a0

Please sign in to comment.