Skip to content

Commit

Permalink
New dtype style testing
Browse files Browse the repository at this point in the history
  • Loading branch information
Bihaqo committed Oct 30, 2018
1 parent fea76f6 commit b690957
Showing 1 changed file with 17 additions and 9 deletions.
26 changes: 17 additions & 9 deletions t3f/autodiff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
from t3f import autodiff


class AutodiffTest(tf.test.TestCase):
class _AutodiffTest():

def testGradients(self):
w = initializers.random_matrix(([5] * 3, None))
A = initializers.random_matrix(([5] * 3, [5] * 3))
x = initializers.random_matrix(([5] * 3, None))
z = initializers.random_matrix(([5] * 3, None))
w = initializers.random_matrix(([5] * 3, None), dtype=self.dtype)
A = initializers.random_matrix(([5] * 3, [5] * 3), dtype=self.dtype)
x = initializers.random_matrix(([5] * 3, None), dtype=self.dtype)
z = initializers.random_matrix(([5] * 3, None), dtype=self.dtype)

def func1(x):
return 0.5 * ops.flat_inner(x, w) ** 2
Expand All @@ -36,10 +36,10 @@ def func2(x):
np.testing.assert_allclose(actual_v2, desired_v2, rtol=1e-4)

def testHessianVectorProduct(self):
w = initializers.random_matrix(([5] * 3, None))
A = initializers.random_matrix(([5] * 3, [5] * 3))
x = initializers.random_matrix(([5] * 3, None))
z = initializers.random_matrix(([5] * 3, None))
w = initializers.random_matrix(([5] * 3, None), dtype=self.dtype)
A = initializers.random_matrix(([5] * 3, [5] * 3), dtype=self.dtype)
x = initializers.random_matrix(([5] * 3, None), dtype=self.dtype)
z = initializers.random_matrix(([5] * 3, None), dtype=self.dtype)
projected_vector = ops.full(riemannian.project(z, x))

def func1(x):
Expand Down Expand Up @@ -68,6 +68,14 @@ def func2(x):
np.testing.assert_allclose(actual2_v, desired2_v, rtol=1e-3)


class AutodiffTestFloat32(tf.test.TestCase, _AutodiffTest):
dtype = tf.float32


class AutodiffTestFloat64(tf.test.TestCase, _AutodiffTest):
dtype = tf.float64


if __name__ == "__main__":
tf.test.main()

0 comments on commit b690957

Please sign in to comment.