Skip to content

Commit

Permalink
decompositions: test for different dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
Bihaqo committed Oct 28, 2018
1 parent ac7a2ee commit 77022dc
Showing 1 changed file with 40 additions and 16 deletions.
56 changes: 40 additions & 16 deletions t3f/decompositions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from t3f import initializers


class DecompositionsTest(tf.test.TestCase):
class _DecompositionsTest():

def testTTTensor(self):
shape = (2, 1, 4, 3)
np.random.seed(1)
tens = np.random.rand(*shape).astype(np.float32)
tens = np.random.rand(*shape).astype(self.np_dtype)
tf_tens = tf.constant(tens)
tt_tens = decompositions.to_tt_tensor(tf_tens, max_tt_rank=3)
with self.test_session():
Expand All @@ -22,7 +22,7 @@ def testTTTensor(self):
self.assertAllEqual(dynamic_tt_ranks, static_tt_ranks)

# Try to decompose the same tensor with unknown shape.
tf_tens_pl = tf.placeholder(tf.float32, (None, None, 4, None))
tf_tens_pl = tf.placeholder(self.tf_dtype, (None, None, 4, None))
tt_tens = decompositions.to_tt_tensor(tf_tens_pl, max_tt_rank=3)
tt_val = ops.full(tt_tens).eval({tf_tens_pl: tens})
self.assertAllClose(tens, tt_val)
Expand All @@ -33,8 +33,8 @@ def testTTTensorSimple(self):
# Test that a tensor of ones and of zeros can be converted into TT with
# TT-rank 1.
shape = (2, 1, 4, 3)
tens_arr = (np.zeros(shape).astype(np.float32),
np.ones(shape).astype(np.float32))
tens_arr = (np.zeros(shape).astype(self.np_dtype),
np.ones(shape).astype(self.np_dtype))
for tens in tens_arr:
tf_tens = tf.constant(tens)
tt_tens = decompositions.to_tt_tensor(tf_tens, max_tt_rank=1)
Expand All @@ -45,7 +45,7 @@ def testTTTensorSimple(self):
self.assertAllEqual(dynamic_tt_ranks, static_tt_ranks)

# Try to decompose the same tensor with unknown shape.
tf_tens_pl = tf.placeholder(tf.float32, (None, None, None, None))
tf_tens_pl = tf.placeholder(self.tf_dtype, (None, None, None, None))
tt_tens = decompositions.to_tt_tensor(tf_tens_pl, max_tt_rank=1)
tt_val = ops.full(tt_tens).eval({tf_tens_pl: tens})
self.assertAllClose(tens, tt_val)
Expand All @@ -56,7 +56,7 @@ def testTTVector(self):
vec_shape = (2, 1, 4, 3)
np.random.seed(1)
rows = np.prod(vec_shape)
vec = np.random.rand(rows, 1).astype(np.float32)
vec = np.random.rand(rows, 1).astype(self.np_dtype)
tf_vec = tf.constant(vec)
tt_vec = decompositions.to_tt_matrix(tf_vec, (vec_shape, None))
with self.test_session():
Expand All @@ -66,7 +66,7 @@ def testTTCompositeRankTensor(self):
# Test if a composite rank (list of ranks) can be used for decomposition
# for tensor.
np.random.seed(1)
np_tensor = np.random.rand(2, 3, 3, 1)
np_tensor = np.random.rand(2, 3, 3, 1).astype(self.np_dtype)
tf_tensor = tf.constant(np_tensor)

tt_ranks = [1, 2, 3, 3, 1]
Expand All @@ -81,7 +81,7 @@ def testTTCompositeRankMatrix(self):
out_shape = (1, 2, 2, 1)
np.random.seed(1)
mat = np.random.rand(np.prod(out_shape), np.prod(inp_shape))
mat = mat.astype(np.float32)
mat = mat.astype(self.np_dtype)
tf_mat = tf.constant(mat)
tt_ranks = [10, 20, 30, 40, 30]
tt_mat = decompositions.to_tt_matrix(tf_mat, (out_shape, inp_shape),
Expand All @@ -96,7 +96,7 @@ def testTTMatrix(self):
out_shape = (3, 3, 2, 3)
np.random.seed(1)
mat = np.random.rand(np.prod(out_shape), np.prod(inp_shape))
mat = mat.astype(np.float32)
mat = mat.astype(self.np_dtype)
tf_mat = tf.constant(mat)
tt_mat = decompositions.to_tt_matrix(tf_mat, (out_shape, inp_shape),
max_tt_rank=90)
Expand All @@ -107,7 +107,8 @@ def testTTMatrix(self):
def testRoundTensor(self):
shape = (2, 1, 4, 3, 3)
np.random.seed(1)
tens = initializers.random_tensor(shape, tt_rank=15)
tens = initializers.random_tensor(shape, tt_rank=15,
dtype=self.tf_dtype)
rounded_tens = decompositions.round(tens, max_tt_rank=9)
with self.test_session() as sess:
vars = [ops.full(tens), ops.full(rounded_tens)]
Expand All @@ -121,7 +122,8 @@ def testOrthogonalizeLeftToRight(self):
shape = (2, 4, 3, 3)
tt_ranks = (1, 5, 2, 17, 1)
updated_tt_ranks = (1, 2, 2, 6, 1)
tens = initializers.random_tensor(shape, tt_rank=tt_ranks)
tens = initializers.random_tensor(shape, tt_rank=tt_ranks,
dtype=self.tf_dtype)
orthogonal = decompositions.orthogonalize_tt_cores(tens)
with self.test_session() as sess:
tens_val, orthogonal_val = sess.run([ops.full(tens), ops.full(orthogonal)])
Expand All @@ -142,7 +144,8 @@ def testOrthogonalizeRightToLeft(self):
shape = (2, 4, 3, 3)
tt_ranks = (1, 5, 2, 17, 1)
updated_tt_ranks = (1, 5, 2, 3, 1)
tens = initializers.random_tensor(shape, tt_rank=tt_ranks)
tens = initializers.random_tensor(shape, tt_rank=tt_ranks,
dtype=self.tf_dtype)
orthogonal = decompositions.orthogonalize_tt_cores(tens, left_to_right=False)
with self.test_session() as sess:
tens_val, orthogonal_val = sess.run([ops.full(tens), ops.full(orthogonal)])
Expand All @@ -160,14 +163,14 @@ def testOrthogonalizeRightToLeft(self):
should_be_eye_val)


class DecompositionsBatchTest(tf.test.TestCase):
class _DecompositionsBatchTest():

def testOrthogonalizeLeftToRight(self):
shape = (2, 4, 3, 3)
tt_ranks = (1, 5, 2, 17, 1)
updated_tt_ranks = (1, 2, 2, 6, 1)
tens = initializers.random_tensor_batch(shape, tt_rank=tt_ranks,
batch_size=2)
batch_size=2, dtype=self.tf_dtype)
orthogonal = decompositions.orthogonalize_tt_cores(tens)
with self.test_session() as sess:
tens_val, orthogonal_val = sess.run([ops.full(tens), ops.full(orthogonal)])
Expand All @@ -187,7 +190,8 @@ def testOrthogonalizeLeftToRight(self):

def testRoundTensor(self):
shape = (2, 1, 4, 3, 3)
tens = initializers.random_tensor_batch(shape, tt_rank=15, batch_size=3)
tens = initializers.random_tensor_batch(shape, tt_rank=15, batch_size=3,
dtype=self.tf_dtype)
rounded_tens = decompositions.round(tens, max_tt_rank=9)
with self.test_session() as sess:
vars = [ops.full(tens), ops.full(rounded_tens)]
Expand All @@ -199,5 +203,25 @@ def testRoundTensor(self):
self.assertAllEqual([1, 2, 2, 8, 3, 1], dynamic_tt_ranks)


class DecompositionsTestFloat32(tf.test.TestCase, _DecompositionsTest):
np_dtype = np.float32
tf_dtype = tf.float32


class DecompositionsTestFloat64(tf.test.TestCase, _DecompositionsTest):
np_dtype = np.float64
tf_dtype = tf.float64


class DecompositionsBatchTestFloat32(tf.test.TestCase, _DecompositionsBatchTest):
np_dtype = np.float32
tf_dtype = tf.float32


class DecompositionsBatchTestFloat64(tf.test.TestCase, _DecompositionsBatchTest):
np_dtype = np.float64
tf_dtype = tf.float64


if __name__ == "__main__":
tf.test.main()

0 comments on commit 77022dc

Please sign in to comment.