Skip to content
This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Commit

Permalink
Fix tests for eager execution
Browse files Browse the repository at this point in the history
  • Loading branch information
CyberZHG committed Aug 29, 2019
1 parent 6cfd7b9 commit 7f9d179
Showing 1 changed file with 9 additions and 14 deletions.
23 changes: 9 additions & 14 deletions tests/test_trig_pos_embd.py
@@ -1,8 +1,9 @@
import unittest
import os
import tempfile
import unittest

import numpy as np
import tensorflow as tf

from keras_pos_embd.backend import keras
from keras_pos_embd import TrigPosEmbedding

Expand Down Expand Up @@ -39,15 +40,13 @@ def test_brute(self):
model = keras.models.load_model(model_path, custom_objects={'TrigPosEmbedding': TrigPosEmbedding})
model.summary()
predicts = model.predict(indices)[0].tolist()
sess = tf.Session()
for i in range(seq_len):
for j in range(embd_dim):
actual = predicts[i][j]
if j % 2 == 0:
expect = tf.sin(i / tf.pow(10000.0, float(j) / embd_dim))
expect = np.sin(i / 10000.0 ** (float(j) / embd_dim))
else:
expect = tf.cos(i / tf.pow(10000.0, (j - 1.0) / embd_dim))
expect = expect.eval(session=sess)
expect = np.cos(i / 10000.0 ** ((j - 1.0) / embd_dim))
self.assertAlmostEqual(expect, actual, places=6, msg=(embd_dim, i, j, expect, actual))

def test_add(self):
Expand All @@ -66,15 +65,13 @@ def test_add(self):
model = keras.models.load_model(model_path, custom_objects={'TrigPosEmbedding': TrigPosEmbedding})
model.summary()
predicts = model.predict(inputs)[0].tolist()
sess = tf.Session()
for i in range(seq_len):
for j in range(embed_dim):
actual = predicts[i][j]
if j % 2 == 0:
expect = 1.0 + tf.sin(i / tf.pow(10000.0, float(j) / embed_dim))
expect = 1.0 + np.sin(i / 10000.0 ** (float(j) / embed_dim))
else:
expect = 1.0 + tf.cos(i / tf.pow(10000.0, (j - 1.0) / embed_dim))
expect = expect.eval(session=sess)
expect = 1.0 + np.cos(i / 10000.0 ** ((j - 1.0) / embed_dim))
self.assertAlmostEqual(expect, actual, places=6, msg=(embed_dim, i, j, expect, actual))

def test_concat(self):
Expand All @@ -95,13 +92,11 @@ def test_concat(self):
model = keras.models.load_model(model_path, custom_objects={'TrigPosEmbedding': TrigPosEmbedding})
model.summary()
predicts = model.predict(inputs)[0].tolist()
sess = tf.Session()
for i in range(seq_len):
for j in range(embed_dim):
actual = predicts[i][feature_dim + j]
if j % 2 == 0:
expect = tf.sin(i / tf.pow(10000.0, float(j) / embed_dim))
expect = np.sin(i / 10000.0 ** (float(j) / embed_dim))
else:
expect = tf.cos(i / tf.pow(10000.0, (j - 1.0) / embed_dim))
expect = expect.eval(session=sess)
expect = np.cos(i / 10000.0 ** ((j - 1.0) / embed_dim))
self.assertAlmostEqual(expect, actual, places=6, msg=(embed_dim, i, j, expect, actual))

0 comments on commit 7f9d179

Please sign in to comment.