From 5cffe8bb7988c798992c0a9053779eeb025a7547 Mon Sep 17 00:00:00 2001 From: lixy9474 Date: Wed, 29 Dec 2021 16:04:13 +0800 Subject: [PATCH] [Embedding] Fix the feature filter UT in feature_column_v2_test. (#19) --- .../feature_column/feature_column_v2_test.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py index 36ff0250158..d206a9d1578 100644 --- a/tensorflow/python/feature_column/feature_column_v2_test.py +++ b/tensorflow/python/feature_column/feature_column_v2_test.py @@ -7577,13 +7577,17 @@ def testEmbeddingVariableForFeatureFilter(self): sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS)) sess.run([init]) emb1, top, l = sess.run([emb, train_op, loss]) + for val1 in emb1.tolist(): + for val in val1: + self.assertEqual(val, 1.0) emb1, top, l = sess.run([emb, train_op, loss]) - emb1, top, l = sess.run([emb, train_op, loss]) - for val in emb1.tolist()[0]: - self.assertEqual(val, 1.0) - emb1, top, l = sess.run([emb, train_op, loss]) - for val in emb1.tolist()[0]: - self.assertNotEqual(val, 1.0) + for index, val1 in enumerate(emb1.tolist()): + if index < 7: + for val in val1: + self.assertNotEqual(val, 1.0) + else: + for val in val1: + self.assertEqual(val, 1.0) @test_util.run_deprecated_v1 def testEmbeddingVariableForAdaptiveEmbedding(self):