Skip to content

Commit

Permalink
[Embedding] Fix the feature filter UT in feature_column_v2_test. (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
lixy9474 committed Dec 29, 2021
1 parent 7dac910 commit 5cffe8b
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions tensorflow/python/feature_column/feature_column_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7577,13 +7577,17 @@ def testEmbeddingVariableForFeatureFilter(self):
sess.run(ops.get_collection(ops.GraphKeys.EV_INIT_SLOT_OPS))
sess.run([init])
emb1, top, l = sess.run([emb, train_op, loss])
for val1 in emb1.tolist():
for val in val1:
self.assertEqual(val, 1.0)
emb1, top, l = sess.run([emb, train_op, loss])
emb1, top, l = sess.run([emb, train_op, loss])
for val in emb1.tolist()[0]:
self.assertEqual(val, 1.0)
emb1, top, l = sess.run([emb, train_op, loss])
for val in emb1.tolist()[0]:
self.assertNotEqual(val, 1.0)
for index, val1 in enumerate(emb1.tolist()):
if index < 7:
for val in val1:
self.assertNotEqual(val, 1.0)
else:
for val in val1:
self.assertEqual(val, 1.0)

@test_util.run_deprecated_v1
def testEmbeddingVariableForAdaptiveEmbedding(self):
Expand Down

0 comments on commit 5cffe8b

Please sign in to comment.