Skip to content

Commit

Permalink
test dtypes of tt classes
Browse files Browse the repository at this point in the history
  • Loading branch information
Bihaqo committed Oct 28, 2018
1 parent a8da2cb commit d4343cb
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 15 deletions.
22 changes: 18 additions & 4 deletions t3f/tensor_train_batch_test.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import numpy as np
import tensorflow as tf

from t3f import initializers
from t3f import ops


class TensorTrainBatchTest(tf.test.TestCase):
class _TensorTrainBatchTest():

def testTensorIndexing(self):
tens = initializers.random_tensor_batch((3, 3, 4), batch_size=3)
tens = initializers.random_tensor_batch((3, 3, 4), batch_size=3,
dtype=self.tf_dtype)
with self.test_session() as sess:
desired = ops.full(tens)[:, :, :, :]
actual = ops.full(tens[:, :, :, :])
Expand Down Expand Up @@ -47,7 +49,8 @@ def testTensorIndexing(self):
tens[1, 1]

def testPlaceholderTensorIndexing(self):
tens = initializers.random_tensor_batch((3, 3, 4), batch_size=3)
tens = initializers.random_tensor_batch((3, 3, 4), batch_size=3,
dtype=self.tf_dtype)
with self.test_session() as sess:
start = tf.placeholder(tf.int32)
end = tf.placeholder(tf.int32)
Expand All @@ -74,10 +77,21 @@ def testPlaceholderTensorIndexing(self):

def testShapeOverflow(self):
large_shape = [10] * 20
tensor = initializers.random_matrix_batch([large_shape, large_shape], batch_size=5)
tensor = initializers.random_matrix_batch([large_shape, large_shape],
batch_size=5, dtype=self.tf_dtype)
shape = tensor.get_shape()
self.assertEqual([5, 10 ** 20, 10 ** 20], shape)


class TensorTrainBatchTestFloat32(tf.test.TestCase, _TensorTrainBatchTest):
np_dtype = np.float32
tf_dtype = tf.float32


class TensorTrainBatchTestFloat64(tf.test.TestCase, _TensorTrainBatchTest):
np_dtype = np.float64
tf_dtype = tf.float64


if __name__ == "__main__":
tf.test.main()
35 changes: 24 additions & 11 deletions t3f/tensor_train_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import numpy as np
import tensorflow as tf

from t3f import tensor_train
from t3f import initializers
from t3f import ops


class TensorTrainTest(tf.test.TestCase):
class _TensorTrainTest():

def testValidateTTCores2d(self):
schedule = (((1, 1, 1, 1), (1, 1, 1), True),
Expand All @@ -26,8 +27,8 @@ def testValidateTTCores2d(self):
((1, 2, 1, 1), (1, 2, 1), False))

for tt_ranks, claimed_tt_ranks, desired in schedule:
a = tf.random_normal((tt_ranks[0], 10, tt_ranks[1]))
b = tf.random_normal((tt_ranks[2], 9, tt_ranks[3]))
a = tf.random_normal((tt_ranks[0], 10, tt_ranks[1]), dtype=self.tf_dtype)
b = tf.random_normal((tt_ranks[2], 9, tt_ranks[3]), dtype=self.tf_dtype)
with self.test_session():
actual = tensor_train._are_tt_cores_valid((a, b), (10, 9),
claimed_tt_ranks)
Expand All @@ -41,7 +42,7 @@ def testValidateTTCores2d(self):
tensor_train.TensorTrain((a, b), (10, 9), claimed_tt_ranks)

# Make dtypes inconsistent.
b_new = tf.cast(b, tf.float64)
b_new = tf.cast(b, tf.float16)
actual = tensor_train._are_tt_cores_valid((a, b_new), (10, 9),
claimed_tt_ranks)
self.assertEqual(False, actual)
Expand Down Expand Up @@ -71,9 +72,9 @@ def testValidateTTCores3d(self):
((1, 2, 2, 3, 3, 1), None, True))

for tt_ranks, claimed_tt_ranks, desired in schedule:
a = tf.random_normal((tt_ranks[0], 10, tt_ranks[1]))
b = tf.random_normal((tt_ranks[2], 1, tt_ranks[3]))
c = tf.random_normal((tt_ranks[4], 2, tt_ranks[5]))
a = tf.random_normal((tt_ranks[0], 10, tt_ranks[1]), dtype=self.tf_dtype)
b = tf.random_normal((tt_ranks[2], 1, tt_ranks[3]), dtype=self.tf_dtype)
c = tf.random_normal((tt_ranks[4], 2, tt_ranks[5]), dtype=self.tf_dtype)
with self.test_session():
actual = tensor_train._are_tt_cores_valid((a, b, c), (10, 1, 2),
claimed_tt_ranks)
Expand All @@ -87,15 +88,15 @@ def testValidateTTCores3d(self):
tensor_train.TensorTrain((a, b, c), (10, 1, 2), claimed_tt_ranks)

# Make dtypes inconsistent.
b_new = tf.cast(b, tf.float64)
b_new = tf.cast(b, tf.float16)
actual = tensor_train._are_tt_cores_valid((a, b_new, c), (10, 1, 2),
claimed_tt_ranks)
self.assertEqual(False, actual)
with self.assertRaises(ValueError):
tensor_train.TensorTrain((a, b_new, c), (10, 1, 2), claimed_tt_ranks)

def testTensorIndexing(self):
tens = initializers.random_tensor((3, 3, 4))
tens = initializers.random_tensor((3, 3, 4), dtype=self.tf_dtype)
with self.test_session() as sess:
desired = ops.full(tens)[:, :, :]
actual = ops.full(tens[:, :, :])
Expand Down Expand Up @@ -125,7 +126,7 @@ def testTensorIndexing(self):
tens[1, 1]

def testPlaceholderTensorIndexing(self):
tens = initializers.random_tensor((3, 3, 4))
tens = initializers.random_tensor((3, 3, 4), dtype=self.tf_dtype)
with self.test_session() as sess:
start = tf.placeholder(tf.int32)
end = tf.placeholder(tf.int32)
Expand All @@ -136,9 +137,21 @@ def testPlaceholderTensorIndexing(self):

def testShapeOverflow(self):
large_shape = [10] * 20
matrix = initializers.matrix_zeros([large_shape, large_shape])
matrix = initializers.matrix_zeros([large_shape, large_shape],
dtype=self.tf_dtype)
shape = matrix.get_shape()
self.assertEqual([10 ** 20, 10 ** 20], shape)


class TensorTrainTestFloat32(tf.test.TestCase, _TensorTrainTest):
np_dtype = np.float32
tf_dtype = tf.float32


class TensorTrainTestFloat64(tf.test.TestCase, _TensorTrainTest):
np_dtype = np.float64
tf_dtype = tf.float64


if __name__ == "__main__":
tf.test.main()

0 comments on commit d4343cb

Please sign in to comment.