Skip to content

Commit

Permalink
shapes: test with different dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
Bihaqo committed Oct 30, 2018
1 parent 6f204dc commit 4cf2355
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions t3f/shapes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,22 @@
from t3f import shapes


class ShapesTest(tf.test.TestCase):
class _ShapesTest():

def testLazyShapeOverflow(self):
large_shape = [10] * 20
tensor = initializers.random_matrix_batch([large_shape, large_shape], batch_size=5)
tensor = initializers.random_matrix_batch([large_shape, large_shape],
batch_size=5, dtype=self.dtype)
self.assertAllEqual([5, 10 ** 20, 10 ** 20], shapes.lazy_shape(tensor))


class ShapesTestFloat32(tf.test.TestCase, _ShapesTest):
dtype = tf.float32


class ShapesTestFloat64(tf.test.TestCase, _ShapesTest):
dtype = tf.float64


if __name__ == "__main__":
tf.test.main()

0 comments on commit 4cf2355

Please sign in to comment.