Skip to content

Commit

Permalink
variables: test dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
Bihaqo committed Oct 28, 2018
1 parent cd3e85b commit 57d00ae
Showing 1 changed file with 24 additions and 9 deletions.
33 changes: 24 additions & 9 deletions t3f/variables_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from t3f import ops
from t3f import initializers

class VariablesTest(tf.test.TestCase):

class _VariablesTest():

def testGetExistingVariable(self):
init = initializers.random_tensor([2, 3, 2], tt_rank=2)
init = initializers.random_tensor([2, 3, 2], tt_rank=2, dtype=self.tf_dtype)
tt_1 = variables.get_variable('tt_1', initializer=init)
with tf.variable_scope('test'):
tt_2 = variables.get_variable('tt_2', initializer=init)
Expand All @@ -23,13 +24,14 @@ def testGetExistingVariable(self):
variables.get_variable('tt_3')

with tf.variable_scope('', reuse=True):
tt_1_copy = variables.get_variable('tt_1')
tt_1_copy = variables.get_variable('tt_1', dtype=self.tf_dtype)
self.assertAllClose(ops.full(tt_1).eval(), ops.full(tt_1_copy).eval())

with tf.variable_scope('', reuse=True):
# Again try to retrieve an existing variable, but pass an initializer
# and check that it still works.
tt_1_copy = variables.get_variable('tt_1', initializer=0 * init)
tt_1_copy = variables.get_variable('tt_1', initializer=0 * init,
dtype=self.tf_dtype)
self.assertAllClose(ops.full(tt_1).eval(), ops.full(tt_1_copy).eval())

with self.assertRaises(ValueError):
Expand All @@ -43,15 +45,16 @@ def testGetExistingVariable(self):
variables.get_variable('tt_2')

with tf.variable_scope('test', reuse=True):
tt_2_copy = variables.get_variable('tt_2')
tt_2_copy = variables.get_variable('tt_2', dtype=self.tf_dtype)
self.assertAllClose(ops.full(tt_2).eval(), ops.full(tt_2_copy).eval())

def testAttributes(self):
# Test that after converting an initializer into a variable all the
# attributes stays the same.
tens = initializers.random_tensor([2, 3, 2], tt_rank=2)
tens = initializers.random_tensor([2, 3, 2], tt_rank=2, dtype=self.tf_dtype)
tens_v = variables.get_variable('tt_tens', initializer=tens)
mat = initializers.random_matrix([[3, 2, 2], [3, 3, 3]], tt_rank=3)
mat = initializers.random_matrix([[3, 2, 2], [3, 3, 3]], tt_rank=3,
dtype=self.tf_dtype)
mat_v = variables.get_variable('tt_mat', initializer=mat)
for (init, var) in [[tens, tens_v], [mat, mat_v]]:
self.assertEqual(init.get_shape(), var.get_shape())
Expand All @@ -61,9 +64,11 @@ def testAttributes(self):
self.assertEqual(init.is_tt_matrix(), var.is_tt_matrix())

def testAssign(self):
old_init = initializers.random_tensor([2, 3, 2], tt_rank=2)
old_init = initializers.random_tensor([2, 3, 2], tt_rank=2,
dtype=self.tf_dtype)
tt = variables.get_variable('tt', initializer=old_init)
new_init = initializers.random_tensor([2, 3, 2], tt_rank=2)
new_init = initializers.random_tensor([2, 3, 2], tt_rank=2,
dtype=self.tf_dtype)
assigner = variables.assign(tt, new_init)
with self.test_session():
tf.global_variables_initializer().run()
Expand All @@ -78,5 +83,15 @@ def testAssign(self):
self.assertGreater(rel_diff, 0.2)


class VariablesTestFloat32(tf.test.TestCase, _VariablesTest):
np_dtype = np.float32
tf_dtype = tf.float32


class VariablesTestFloat64(tf.test.TestCase, _VariablesTest):
np_dtype = np.float64
tf_dtype = tf.float64


if __name__ == "__main__":
tf.test.main()

0 comments on commit 57d00ae

Please sign in to comment.