Skip to content

Commit

Permalink
Merge 9035216 into 4ae0aee
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexander Novikov committed Oct 30, 2018
2 parents 4ae0aee + 9035216 commit 15f987f
Show file tree
Hide file tree
Showing 19 changed files with 636 additions and 345 deletions.
3 changes: 0 additions & 3 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@ python:
env:
matrix:
- TF_VERSION=1.0
- TF_VERSION=1.7
- TF_VERSION=1.8
- TF_VERSION=1.9
- TF_VERSION=1.10
- TF_VERSION=1.11
# command to install dependencies
Expand Down
2 changes: 1 addition & 1 deletion t3f/approximate.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def reduce_sum_batch(tt_batch, max_tt_rank, coef=None):

is_batch_output = False
if coef is not None:
coef = tf.convert_to_tensor(coef)
coef = tf.convert_to_tensor(coef, dtype=tt_batch.dtype)
if len(coef.get_shape()) == 1:
tt_batch = batch_ops.multiply_along_batch_dim(tt_batch, coef)
elif len(coef.get_shape()) == 2:
Expand Down
29 changes: 20 additions & 9 deletions t3f/approximate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
from t3f import initializers


class ApproximateTest(tf.test.TestCase):
class _ApproximateTest():

def testAddN(self):
# Sum a bunch of TT-matrices.
tt_a = initializers.random_matrix(((2, 1, 4), (2, 2, 2)), tt_rank=2)
tt_a = initializers.random_matrix(((2, 1, 4), (2, 2, 2)), tt_rank=2,
dtype=self.dtype)
tt_b = initializers.random_matrix(((2, 1, 4), (2, 2, 2)),
tt_rank=[1, 2, 4, 1])
tt_rank=[1, 2, 4, 1], dtype=self.dtype)

def desired(tt_objects):
res = tt_objects[0]
Expand Down Expand Up @@ -48,7 +49,8 @@ def desired(tt_batch):
with self.test_session() as sess:
tt_batch = initializers.random_tensor_batch((4, 3, 5),
tt_rank=2,
batch_size=batch_size)
batch_size=batch_size,
dtype=self.dtype)
res_actual = ops.full(approximate.reduce_sum_batch(tt_batch, 10))
res_desired = ops.full(desired(tt_batch))
res_desired_val, res_actual_val = sess.run([res_desired, res_actual])
Expand All @@ -65,7 +67,8 @@ def desired(tt_batch, coef):
with self.test_session() as sess:
tt_batch = initializers.random_tensor_batch((4, 3, 5),
tt_rank=3,
batch_size=3)
batch_size=3,
dtype=self.dtype)
res_actual = ops.full(approximate.reduce_sum_batch(tt_batch, 9,
[1.2, -0.2, 1]))
res_desired = ops.full(desired(tt_batch, [1.2, -0.2, 1]))
Expand All @@ -81,13 +84,13 @@ def desired(tt_batch, coef):
res += coef[i] * tt_batch[i]
return res
with self.test_session() as sess:
tt_batch = initializers.random_tensor_batch((4, 3, 5),
tt_rank=2,
batch_size=3)
tt_batch = initializers.random_tensor_batch((4, 3, 5), tt_rank=2,
batch_size=3,
dtype=self.dtype)
coef = [[1., 0.1],
[0.9, -0.2],
[0.3, 0.3]]
coef = np.array(coef).astype(np.float32)
coef = np.array(coef)
res_actual = ops.full(approximate.reduce_sum_batch(tt_batch, 6,
coef))
res_desired_1 = ops.full(desired(tt_batch, coef[:, 0]))
Expand All @@ -97,5 +100,13 @@ def desired(tt_batch, coef):
self.assertAllClose(res_desired_val, res_actual_val, atol=1e-5, rtol=1e-5)


class ApproximateTestFloat32(tf.test.TestCase, _ApproximateTest):
dtype = tf.float32


class ApproximateTestFloat64(tf.test.TestCase, _ApproximateTest):
dtype = tf.float64


if __name__ == "__main__":
tf.test.main()
2 changes: 1 addition & 1 deletion t3f/batch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def multiply_along_batch_dim(batch_tt, weights):
Returns:
TensorTrainBatch
"""
weights = tf.convert_to_tensor(weights)
weights = tf.convert_to_tensor(weights, dtype=batch_tt.dtype)
tt_cores = list(batch_tt.tt_cores)
if batch_tt.is_tt_matrix():
weights = weights[:, tf.newaxis, tf.newaxis, tf.newaxis, tf.newaxis]
Expand Down
59 changes: 42 additions & 17 deletions t3f/batch_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@
from t3f import initializers


class BatchOpsTest(tf.test.TestCase):
class _BatchOpsTest():

def testConcatMatrix(self):
# Test concating TTMatrix batches along batch dimension.
first = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=1)
second = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=4)
third = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=3)
first = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=1,
dtype=self.dtype)
second = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=4,
dtype=self.dtype)
third = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=3,
dtype=self.dtype)
first_res = batch_ops.concat_along_batch_dim((first))
first_res = ops.full(first_res)
first_second_res = batch_ops.concat_along_batch_dim((first, second))
Expand Down Expand Up @@ -45,7 +48,8 @@ def testConcatMatrix(self):
def testConcatTensorPlaceholders(self):
# Test concating TTTensors of unknown batch sizes along batch dimension.
number_of_objects = tf.placeholder(tf.int32)
all = initializers.random_tensor_batch((2, 3), batch_size=5)
all = initializers.random_tensor_batch((2, 3), batch_size=5,
dtype=self.dtype)
actual = batch_ops.concat_along_batch_dim((all[:number_of_objects],
all[number_of_objects:]))
with self.test_session() as sess:
Expand All @@ -56,7 +60,8 @@ def testConcatTensorPlaceholders(self):
def testConcatMatrixPlaceholders(self):
# Test concating TTMatrices of unknown batch sizes along batch dimension.
number_of_objects = tf.placeholder(tf.int32)
all = initializers.random_matrix_batch(((2, 3), (2, 3)), batch_size=5)
all = initializers.random_matrix_batch(((2, 3), (2, 3)), batch_size=5,
dtype=self.dtype)
actual = batch_ops.concat_along_batch_dim((all[:number_of_objects],
all[number_of_objects:]))
with self.test_session() as sess:
Expand All @@ -66,7 +71,8 @@ def testConcatMatrixPlaceholders(self):

def testBatchMultiply(self):
# Test multiplying batch of TTMatrices by individual numbers.
tt = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=3)
tt = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=3,
dtype=self.dtype)
weights = [0.1, 0, -10]
actual = batch_ops.multiply_along_batch_dim(tt, weights)
individual_desired = [weights[i] * tt[i:i+1] for i in range(3)]
Expand All @@ -77,7 +83,8 @@ def testBatchMultiply(self):

def testGramMatrix(self):
# Test Gram Matrix of a batch of TT vectors.
tt_vectors = initializers.random_matrix_batch(((2, 3), None), batch_size=5)
tt_vectors = initializers.random_matrix_batch(((2, 3), None), batch_size=5,
dtype=self.dtype)
res_actual = batch_ops.gram_matrix(tt_vectors)
full_vectors = tf.reshape(ops.full(tt_vectors), (5, 6))
res_desired = tf.matmul(full_vectors, tf.transpose(full_vectors))
Expand All @@ -90,8 +97,9 @@ def testGramMatrixWithMatrix(self):
# Test Gram Matrix of a batch of TT vectors with providing a matrix, so we
# should compute
# res[i, j] = tt_vectors[i] ^ T * matrix * tt_vectors[j]
tt_vectors = initializers.random_matrix_batch(((2, 3), None), batch_size=4)
matrix = initializers.random_matrix(((2, 3), (2, 3)))
tt_vectors = initializers.random_matrix_batch(((2, 3), None), batch_size=4,
dtype=self.dtype)
matrix = initializers.random_matrix(((2, 3), (2, 3)), dtype=self.dtype)
res_actual = batch_ops.gram_matrix(tt_vectors, matrix)
full_vectors = tf.reshape(ops.full(tt_vectors), (4, 6))
with self.test_session() as sess:
Expand All @@ -107,8 +115,10 @@ def testGramMatrixWithMatrix(self):

def testPairwiseFlatInnerTensor(self):
# Test pairwise_flat_inner of a batch of TT tensors.
tt_tensors_1 = initializers.random_tensor_batch((2, 3, 2), batch_size=5)
tt_tensors_2 = initializers.random_tensor_batch((2, 3, 2), batch_size=5)
tt_tensors_1 = initializers.random_tensor_batch((2, 3, 2), batch_size=5,
dtype=self.dtype)
tt_tensors_2 = initializers.random_tensor_batch((2, 3, 2), batch_size=5,
dtype=self.dtype)
res_actual = batch_ops.pairwise_flat_inner(tt_tensors_1, tt_tensors_2)
full_tensors_1 = tf.reshape(ops.full(tt_tensors_1), (5, 12))
full_tensors_2 = tf.reshape(ops.full(tt_tensors_2), (5, 12))
Expand All @@ -121,9 +131,11 @@ def testPairwiseFlatInnerTensor(self):
def testPairwiseFlatInnerMatrix(self):
# Test pairwise_flat_inner of a batch of TT matrices.
tt_vectors_1 = initializers.random_matrix_batch(((2, 3), (2, 3)),
batch_size=5)
batch_size=5,
dtype=self.dtype)
tt_vectors_2 = initializers.random_matrix_batch(((2, 3), (2, 3)),
batch_size=5)
batch_size=5,
dtype=self.dtype)
res_actual = batch_ops.pairwise_flat_inner(tt_vectors_1, tt_vectors_2)
full_vectors_1 = tf.reshape(ops.full(tt_vectors_1), (5, 36))
full_vectors_2 = tf.reshape(ops.full(tt_vectors_2), (5, 36))
Expand All @@ -137,9 +149,13 @@ def testPairwiseFlatInnerVectorsWithMatrix(self):
# Test pairwise_flat_inner of a batch of TT vectors with providing a matrix,
# so we should compute
# res[i, j] = tt_vectors[i] ^ T * matrix * tt_vectors[j]
tt_vectors_1 = initializers.random_matrix_batch(((2, 3), None), batch_size=2)
tt_vectors_2 = initializers.random_matrix_batch(((2, 3), None), batch_size=3)
matrix = initializers.random_matrix(((2, 3), (2, 3)))
tt_vectors_1 = initializers.random_matrix_batch(((2, 3), None),
batch_size=2,
dtype=self.dtype)
tt_vectors_2 = initializers.random_matrix_batch(((2, 3), None),
batch_size=3,
dtype=self.dtype)
matrix = initializers.random_matrix(((2, 3), (2, 3)), dtype=self.dtype)
res_actual = batch_ops.pairwise_flat_inner(tt_vectors_1, tt_vectors_2,
matrix)
full_vectors_1 = tf.reshape(ops.full(tt_vectors_1), (2, 6))
Expand All @@ -156,6 +172,15 @@ def testPairwiseFlatInnerVectorsWithMatrix(self):
res_desired_val[i, j] = curr_val
self.assertAllClose(res_desired_val, res_actual_val)


class BatchOpsTestFloat32(tf.test.TestCase, _BatchOpsTest):
dtype = tf.float32


class BatchOpsTestFloat64(tf.test.TestCase, _BatchOpsTest):
dtype = tf.float64


if __name__ == "__main__":
tf.test.main()

52 changes: 36 additions & 16 deletions t3f/decompositions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
from t3f import initializers


class DecompositionsTest(tf.test.TestCase):
class _DecompositionsTest():

def testTTTensor(self):
shape = (2, 1, 4, 3)
np.random.seed(1)
tens = np.random.rand(*shape).astype(np.float32)
tens = np.random.rand(*shape).astype(self.dtype.as_numpy_dtype)
tf_tens = tf.constant(tens)
tt_tens = decompositions.to_tt_tensor(tf_tens, max_tt_rank=3)
with self.test_session():
Expand All @@ -22,7 +22,7 @@ def testTTTensor(self):
self.assertAllEqual(dynamic_tt_ranks, static_tt_ranks)

# Try to decompose the same tensor with unknown shape.
tf_tens_pl = tf.placeholder(tf.float32, (None, None, 4, None))
tf_tens_pl = tf.placeholder(self.dtype, (None, None, 4, None))
tt_tens = decompositions.to_tt_tensor(tf_tens_pl, max_tt_rank=3)
tt_val = ops.full(tt_tens).eval({tf_tens_pl: tens})
self.assertAllClose(tens, tt_val)
Expand All @@ -33,8 +33,8 @@ def testTTTensorSimple(self):
# Test that a tensor of ones and of zeros can be converted into TT with
# TT-rank 1.
shape = (2, 1, 4, 3)
tens_arr = (np.zeros(shape).astype(np.float32),
np.ones(shape).astype(np.float32))
tens_arr = (np.zeros(shape).astype(self.dtype.as_numpy_dtype),
np.ones(shape).astype(self.dtype.as_numpy_dtype))
for tens in tens_arr:
tf_tens = tf.constant(tens)
tt_tens = decompositions.to_tt_tensor(tf_tens, max_tt_rank=1)
Expand All @@ -45,7 +45,7 @@ def testTTTensorSimple(self):
self.assertAllEqual(dynamic_tt_ranks, static_tt_ranks)

# Try to decompose the same tensor with unknown shape.
tf_tens_pl = tf.placeholder(tf.float32, (None, None, None, None))
tf_tens_pl = tf.placeholder(self.dtype, (None, None, None, None))
tt_tens = decompositions.to_tt_tensor(tf_tens_pl, max_tt_rank=1)
tt_val = ops.full(tt_tens).eval({tf_tens_pl: tens})
self.assertAllClose(tens, tt_val)
Expand All @@ -56,7 +56,7 @@ def testTTVector(self):
vec_shape = (2, 1, 4, 3)
np.random.seed(1)
rows = np.prod(vec_shape)
vec = np.random.rand(rows, 1).astype(np.float32)
vec = np.random.rand(rows, 1).astype(self.dtype.as_numpy_dtype)
tf_vec = tf.constant(vec)
tt_vec = decompositions.to_tt_matrix(tf_vec, (vec_shape, None))
with self.test_session():
Expand All @@ -66,7 +66,7 @@ def testTTCompositeRankTensor(self):
# Test if a composite rank (list of ranks) can be used for decomposition
# for tensor.
np.random.seed(1)
np_tensor = np.random.rand(2, 3, 3, 1)
np_tensor = np.random.rand(2, 3, 3, 1).astype(self.dtype.as_numpy_dtype)
tf_tensor = tf.constant(np_tensor)

tt_ranks = [1, 2, 3, 3, 1]
Expand All @@ -81,7 +81,7 @@ def testTTCompositeRankMatrix(self):
out_shape = (1, 2, 2, 1)
np.random.seed(1)
mat = np.random.rand(np.prod(out_shape), np.prod(inp_shape))
mat = mat.astype(np.float32)
mat = mat.astype(self.dtype.as_numpy_dtype)
tf_mat = tf.constant(mat)
tt_ranks = [10, 20, 30, 40, 30]
tt_mat = decompositions.to_tt_matrix(tf_mat, (out_shape, inp_shape),
Expand All @@ -96,7 +96,7 @@ def testTTMatrix(self):
out_shape = (3, 3, 2, 3)
np.random.seed(1)
mat = np.random.rand(np.prod(out_shape), np.prod(inp_shape))
mat = mat.astype(np.float32)
mat = mat.astype(self.dtype.as_numpy_dtype)
tf_mat = tf.constant(mat)
tt_mat = decompositions.to_tt_matrix(tf_mat, (out_shape, inp_shape),
max_tt_rank=90)
Expand All @@ -107,7 +107,8 @@ def testTTMatrix(self):
def testRoundTensor(self):
shape = (2, 1, 4, 3, 3)
np.random.seed(1)
tens = initializers.random_tensor(shape, tt_rank=15)
tens = initializers.random_tensor(shape, tt_rank=15,
dtype=self.dtype)
rounded_tens = decompositions.round(tens, max_tt_rank=9)
with self.test_session() as sess:
vars = [ops.full(tens), ops.full(rounded_tens)]
Expand All @@ -121,7 +122,8 @@ def testOrthogonalizeLeftToRight(self):
shape = (2, 4, 3, 3)
tt_ranks = (1, 5, 2, 17, 1)
updated_tt_ranks = (1, 2, 2, 6, 1)
tens = initializers.random_tensor(shape, tt_rank=tt_ranks)
tens = initializers.random_tensor(shape, tt_rank=tt_ranks,
dtype=self.dtype)
orthogonal = decompositions.orthogonalize_tt_cores(tens)
with self.test_session() as sess:
tens_val, orthogonal_val = sess.run([ops.full(tens), ops.full(orthogonal)])
Expand All @@ -142,7 +144,8 @@ def testOrthogonalizeRightToLeft(self):
shape = (2, 4, 3, 3)
tt_ranks = (1, 5, 2, 17, 1)
updated_tt_ranks = (1, 5, 2, 3, 1)
tens = initializers.random_tensor(shape, tt_rank=tt_ranks)
tens = initializers.random_tensor(shape, tt_rank=tt_ranks,
dtype=self.dtype)
orthogonal = decompositions.orthogonalize_tt_cores(tens, left_to_right=False)
with self.test_session() as sess:
tens_val, orthogonal_val = sess.run([ops.full(tens), ops.full(orthogonal)])
Expand All @@ -160,14 +163,14 @@ def testOrthogonalizeRightToLeft(self):
should_be_eye_val)


class DecompositionsBatchTest(tf.test.TestCase):
class _DecompositionsBatchTest():

def testOrthogonalizeLeftToRight(self):
shape = (2, 4, 3, 3)
tt_ranks = (1, 5, 2, 17, 1)
updated_tt_ranks = (1, 2, 2, 6, 1)
tens = initializers.random_tensor_batch(shape, tt_rank=tt_ranks,
batch_size=2)
batch_size=2, dtype=self.dtype)
orthogonal = decompositions.orthogonalize_tt_cores(tens)
with self.test_session() as sess:
tens_val, orthogonal_val = sess.run([ops.full(tens), ops.full(orthogonal)])
Expand All @@ -187,7 +190,8 @@ def testOrthogonalizeLeftToRight(self):

def testRoundTensor(self):
shape = (2, 1, 4, 3, 3)
tens = initializers.random_tensor_batch(shape, tt_rank=15, batch_size=3)
tens = initializers.random_tensor_batch(shape, tt_rank=15, batch_size=3,
dtype=self.dtype)
rounded_tens = decompositions.round(tens, max_tt_rank=9)
with self.test_session() as sess:
vars = [ops.full(tens), ops.full(rounded_tens)]
Expand All @@ -199,5 +203,21 @@ def testRoundTensor(self):
self.assertAllEqual([1, 2, 2, 8, 3, 1], dynamic_tt_ranks)


class DecompositionsTestFloat32(tf.test.TestCase, _DecompositionsTest):
dtype = tf.float32


class DecompositionsTestFloat64(tf.test.TestCase, _DecompositionsTest):
dtype = tf.float64


class DecompositionsBatchTestFloat32(tf.test.TestCase, _DecompositionsBatchTest):
dtype = tf.float32


class DecompositionsBatchTestFloat64(tf.test.TestCase, _DecompositionsBatchTest):
dtype = tf.float64


if __name__ == "__main__":
tf.test.main()
Loading

0 comments on commit 15f987f

Please sign in to comment.