Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

minor bugfix for forgetting recurrent networks

Signed-off-by: Tom Schaul <schaul@gmail.com>
  • Loading branch information...
commit 2fcd4a53dd525a7495b53e205e9ea1efeed1870b 1 parent a4af9ce
@schaul schaul authored
View
4 .gitignore
@@ -8,4 +8,6 @@
build
dist
docs/sphinx/.build
-.DS_Store
+.DS_Store
+/pybrain/rl/environments/cartpole/fast_version/cartpolewrap.cpp
+/pybrain/rl/environments/cartpole/fast_version/cartpolewrap.pyd
View
6 pybrain/structure/networks/recurrent.py
@@ -146,6 +146,10 @@ class RecurrentNetwork(RecurrentNetworkComponent, Network):
bufferlist = Network.bufferlist
- def __init__(self, forget=False, *args, **kwargs):
+ def __init__(self, *args, **kwargs):
Network.__init__(self, *args, **kwargs)
+ if 'forget' in kwargs:
+ forget = kwargs['forget']
+ else:
+ forget = False
RecurrentNetworkComponent.__init__(self, forget, *args, **kwargs)
View
4 pybrain/tests/unittests/test_peephole_mdlstm.py
@@ -44,8 +44,8 @@
True
List all the states again, explicitly (buffer size is 8 by now).
- >>> fListToString(N['mdlstm'].outputbuffer[:,1], 3)
- '[0.4 , 0.4 , 0.814 , 0.407 , -0.152, -0.152, 0 , 0 ]'
+ >>> fListToString(N['mdlstm'].outputbuffer[:,1], 2)
+ '[0.4 , 0.4 , 0.81 , 0.41 , -0.15, -0.15, 0 , 0 ]'
"""
Please sign in to comment.
Something went wrong with that request. Please try again.