Skip to content

Commit

Permalink
test with different dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
Bihaqo committed Oct 28, 2018
1 parent 784cba3 commit 21ccf41
Showing 1 changed file with 22 additions and 9 deletions.
31 changes: 22 additions & 9 deletions t3f/approximate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
from t3f import initializers


class ApproximateTest(tf.test.TestCase):
class _ApproximateTest():

def testAddN(self):
# Sum a bunch of TT-matrices.
tt_a = initializers.random_matrix(((2, 1, 4), (2, 2, 2)), tt_rank=2)
tt_a = initializers.random_matrix(((2, 1, 4), (2, 2, 2)), tt_rank=2,
dtype=self.tf_dtype)
tt_b = initializers.random_matrix(((2, 1, 4), (2, 2, 2)),
tt_rank=[1, 2, 4, 1])
tt_rank=[1, 2, 4, 1], dtype=self.tf_dtype)

def desired(tt_objects):
res = tt_objects[0]
Expand Down Expand Up @@ -48,7 +49,8 @@ def desired(tt_batch):
with self.test_session() as sess:
tt_batch = initializers.random_tensor_batch((4, 3, 5),
tt_rank=2,
batch_size=batch_size)
batch_size=batch_size,
dtype=self.tf_dtype)
res_actual = ops.full(approximate.reduce_sum_batch(tt_batch, 10))
res_desired = ops.full(desired(tt_batch))
res_desired_val, res_actual_val = sess.run([res_desired, res_actual])
Expand All @@ -65,7 +67,8 @@ def desired(tt_batch, coef):
with self.test_session() as sess:
tt_batch = initializers.random_tensor_batch((4, 3, 5),
tt_rank=3,
batch_size=3)
batch_size=3,
dtype=self.tf_dtype)
res_actual = ops.full(approximate.reduce_sum_batch(tt_batch, 9,
[1.2, -0.2, 1]))
res_desired = ops.full(desired(tt_batch, [1.2, -0.2, 1]))
Expand All @@ -81,13 +84,13 @@ def desired(tt_batch, coef):
res += coef[i] * tt_batch[i]
return res
with self.test_session() as sess:
tt_batch = initializers.random_tensor_batch((4, 3, 5),
tt_rank=2,
batch_size=3)
tt_batch = initializers.random_tensor_batch((4, 3, 5), tt_rank=2,
batch_size=3,
dtype=self.tf_dtype)
coef = [[1., 0.1],
[0.9, -0.2],
[0.3, 0.3]]
coef = np.array(coef).astype(np.float32)
coef = np.array(coef).astype(self.np_dtype)
res_actual = ops.full(approximate.reduce_sum_batch(tt_batch, 6,
coef))
res_desired_1 = ops.full(desired(tt_batch, coef[:, 0]))
Expand All @@ -97,5 +100,15 @@ def desired(tt_batch, coef):
self.assertAllClose(res_desired_val, res_actual_val, atol=1e-5, rtol=1e-5)


class ApproximateTestFloat32(tf.test.TestCase, _ApproximateTest):
np_dtype = np.float32
tf_dtype = tf.float32


class ApproximateTestFloat64(tf.test.TestCase, _ApproximateTest):
np_dtype = np.float64
tf_dtype = tf.float64


if __name__ == "__main__":
tf.test.main()

0 comments on commit 21ccf41

Please sign in to comment.