Skip to content

Commit

Permalink
test batch_ops with different dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
Bihaqo committed Oct 28, 2018
1 parent 1bd9395 commit ae6c28e
Showing 1 changed file with 44 additions and 17 deletions.
61 changes: 44 additions & 17 deletions t3f/batch_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@
from t3f import initializers


class BatchOpsTest(tf.test.TestCase):
class _BatchOpsTest():

def testConcatMatrix(self):
# Test concating TTMatrix batches along batch dimension.
first = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=1)
second = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=4)
third = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=3)
first = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=1,
dtype=self.tf_dtype)
second = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=4,
dtype=self.tf_dtype)
third = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=3,
dtype=self.tf_dtype)
first_res = batch_ops.concat_along_batch_dim((first))
first_res = ops.full(first_res)
first_second_res = batch_ops.concat_along_batch_dim((first, second))
Expand Down Expand Up @@ -45,7 +48,8 @@ def testConcatMatrix(self):
def testConcatTensorPlaceholders(self):
# Test concating TTTensors of unknown batch sizes along batch dimension.
number_of_objects = tf.placeholder(tf.int32)
all = initializers.random_tensor_batch((2, 3), batch_size=5)
all = initializers.random_tensor_batch((2, 3), batch_size=5,
dtype=self.tf_dtype)
actual = batch_ops.concat_along_batch_dim((all[:number_of_objects],
all[number_of_objects:]))
with self.test_session() as sess:
Expand All @@ -56,7 +60,8 @@ def testConcatTensorPlaceholders(self):
def testConcatMatrixPlaceholders(self):
# Test concating TTMatrices of unknown batch sizes along batch dimension.
number_of_objects = tf.placeholder(tf.int32)
all = initializers.random_matrix_batch(((2, 3), (2, 3)), batch_size=5)
all = initializers.random_matrix_batch(((2, 3), (2, 3)), batch_size=5,
dtype=self.tf_dtype)
actual = batch_ops.concat_along_batch_dim((all[:number_of_objects],
all[number_of_objects:]))
with self.test_session() as sess:
Expand All @@ -66,7 +71,8 @@ def testConcatMatrixPlaceholders(self):

def testBatchMultiply(self):
# Test multiplying batch of TTMatrices by individual numbers.
tt = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=3)
tt = initializers.random_matrix_batch(((2, 3), (3, 3)), batch_size=3,
dtype=self.tf_dtype)
weights = [0.1, 0, -10]
actual = batch_ops.multiply_along_batch_dim(tt, weights)
individual_desired = [weights[i] * tt[i:i+1] for i in range(3)]
Expand All @@ -77,7 +83,8 @@ def testBatchMultiply(self):

def testGramMatrix(self):
# Test Gram Matrix of a batch of TT vectors.
tt_vectors = initializers.random_matrix_batch(((2, 3), None), batch_size=5)
tt_vectors = initializers.random_matrix_batch(((2, 3), None), batch_size=5,
dtype=self.tf_dtype)
res_actual = batch_ops.gram_matrix(tt_vectors)
full_vectors = tf.reshape(ops.full(tt_vectors), (5, 6))
res_desired = tf.matmul(full_vectors, tf.transpose(full_vectors))
Expand All @@ -90,8 +97,9 @@ def testGramMatrixWithMatrix(self):
# Test Gram Matrix of a batch of TT vectors with providing a matrix, so we
# should compute
# res[i, j] = tt_vectors[i] ^ T * matrix * tt_vectors[j]
tt_vectors = initializers.random_matrix_batch(((2, 3), None), batch_size=4)
matrix = initializers.random_matrix(((2, 3), (2, 3)))
tt_vectors = initializers.random_matrix_batch(((2, 3), None), batch_size=4,
dtype=self.tf_dtype)
matrix = initializers.random_matrix(((2, 3), (2, 3)), dtype=self.tf_dtype)
res_actual = batch_ops.gram_matrix(tt_vectors, matrix)
full_vectors = tf.reshape(ops.full(tt_vectors), (4, 6))
with self.test_session() as sess:
Expand All @@ -107,8 +115,10 @@ def testGramMatrixWithMatrix(self):

def testPairwiseFlatInnerTensor(self):
# Test pairwise_flat_inner of a batch of TT tensors.
tt_tensors_1 = initializers.random_tensor_batch((2, 3, 2), batch_size=5)
tt_tensors_2 = initializers.random_tensor_batch((2, 3, 2), batch_size=5)
tt_tensors_1 = initializers.random_tensor_batch((2, 3, 2), batch_size=5,
dtype=self.tf_dtype)
tt_tensors_2 = initializers.random_tensor_batch((2, 3, 2), batch_size=5,
dtype=self.tf_dtype)
res_actual = batch_ops.pairwise_flat_inner(tt_tensors_1, tt_tensors_2)
full_tensors_1 = tf.reshape(ops.full(tt_tensors_1), (5, 12))
full_tensors_2 = tf.reshape(ops.full(tt_tensors_2), (5, 12))
Expand All @@ -121,9 +131,11 @@ def testPairwiseFlatInnerTensor(self):
def testPairwiseFlatInnerMatrix(self):
# Test pairwise_flat_inner of a batch of TT matrices.
tt_vectors_1 = initializers.random_matrix_batch(((2, 3), (2, 3)),
batch_size=5)
batch_size=5,
dtype=self.tf_dtype)
tt_vectors_2 = initializers.random_matrix_batch(((2, 3), (2, 3)),
batch_size=5)
batch_size=5,
dtype=self.tf_dtype)
res_actual = batch_ops.pairwise_flat_inner(tt_vectors_1, tt_vectors_2)
full_vectors_1 = tf.reshape(ops.full(tt_vectors_1), (5, 36))
full_vectors_2 = tf.reshape(ops.full(tt_vectors_2), (5, 36))
Expand All @@ -137,9 +149,13 @@ def testPairwiseFlatInnerVectorsWithMatrix(self):
# Test pairwise_flat_inner of a batch of TT vectors with providing a matrix,
# so we should compute
# res[i, j] = tt_vectors[i] ^ T * matrix * tt_vectors[j]
tt_vectors_1 = initializers.random_matrix_batch(((2, 3), None), batch_size=2)
tt_vectors_2 = initializers.random_matrix_batch(((2, 3), None), batch_size=3)
matrix = initializers.random_matrix(((2, 3), (2, 3)))
tt_vectors_1 = initializers.random_matrix_batch(((2, 3), None),
batch_size=2,
dtype=self.tf_dtype)
tt_vectors_2 = initializers.random_matrix_batch(((2, 3), None),
batch_size=3,
dtype=self.tf_dtype)
matrix = initializers.random_matrix(((2, 3), (2, 3)), dtype=self.tf_dtype)
res_actual = batch_ops.pairwise_flat_inner(tt_vectors_1, tt_vectors_2,
matrix)
full_vectors_1 = tf.reshape(ops.full(tt_vectors_1), (2, 6))
Expand All @@ -156,6 +172,17 @@ def testPairwiseFlatInnerVectorsWithMatrix(self):
res_desired_val[i, j] = curr_val
self.assertAllClose(res_desired_val, res_actual_val)


class BatchOpsTestFloat32(tf.test.TestCase, _BatchOpsTest):
np_dtype = np.float32
tf_dtype = tf.float32


class BatchOpsTestFloat64(tf.test.TestCase, _BatchOpsTest):
np_dtype = np.float64
tf_dtype = tf.float64


if __name__ == "__main__":
tf.test.main()

0 comments on commit ae6c28e

Please sign in to comment.