Skip to content

Commit

Permalink
Fixed DBN unittest (Rnd gen init)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan authored and Jan committed Jan 26, 2018
1 parent 270f23a commit 2d1689c
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions pydeep/testunits/test_rbm/test_dbn.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def test_forward_propagate(self):
assert numx.sum(numx.abs(
TestDBNModel.stack.forward_propagate(numx.array([[1, 0], [0, 1]])) - forward_target)) < 0.000001

forward_target = numx.array([[0, 1], [1, 1]])
numx.random.seed(42)
forward_target = numx.array([[1, 1], [1, 0]])
assert numx.sum(
numx.abs(TestDBNModel.stack.forward_propagate(numx.array([[1, 0], [0, 1]]),
True) - forward_target)) < 0.000001
Expand All @@ -76,7 +77,8 @@ def test_backward_propagate(self):
numx.abs(
TestDBNModel.stack.backward_propagate(numx.array([[1, 0], [0, 1]])) - backward_target)) < 0.000001

backward_target = numx.array([[0, 0], [1, 1]])
numx.random.seed(42)
backward_target = numx.array([[0, 0], [1, 0]])
assert numx.sum(
numx.abs(TestDBNModel.stack.backward_propagate(numx.array([[1, 0], [0, 1]]),
True) - backward_target)) < 0.000001
Expand All @@ -91,7 +93,8 @@ def test_reconstruct(self):
assert numx.sum(
numx.abs(TestDBNModel.stack.reconstruct(numx.array([[1, 0], [0, 1]])) - rec_target)) < 0.000001

rec_target = numx.array([[0, 1], [1, 0]])
numx.random.seed(42)
rec_target = numx.array([[1, 0], [0, 1]])
assert numx.sum(
numx.abs(TestDBNModel.stack.reconstruct(numx.array([[1, 0], [0, 1]]), True) - rec_target)) < 0.000001
print(' successfully passed!')
Expand Down

0 comments on commit 2d1689c

Please sign in to comment.