Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

minor bugfix for forgetting recurrent networks

Signed-off-by: Tom Schaul <schaul@gmail.com>
  • Loading branch information...
commit 2fcd4a53dd525a7495b53e205e9ea1efeed1870b 1 parent a4af9ce
Tom Schaul schaul authored
4 .gitignore
@@ -8,4 +8,6 @@
8 8 build
9 9 dist
10 10 docs/sphinx/.build
11   -.DS_Store
  11 +.DS_Store
  12 +/pybrain/rl/environments/cartpole/fast_version/cartpolewrap.cpp
  13 +/pybrain/rl/environments/cartpole/fast_version/cartpolewrap.pyd
6 pybrain/structure/networks/recurrent.py
@@ -146,6 +146,10 @@ class RecurrentNetwork(RecurrentNetworkComponent, Network):
146 146
147 147 bufferlist = Network.bufferlist
148 148
149   - def __init__(self, forget=False, *args, **kwargs):
  149 + def __init__(self, *args, **kwargs):
150 150 Network.__init__(self, *args, **kwargs)
  151 + if 'forget' in kwargs:
  152 + forget = kwargs['forget']
  153 + else:
  154 + forget = False
151 155 RecurrentNetworkComponent.__init__(self, forget, *args, **kwargs)
4 pybrain/tests/unittests/test_peephole_mdlstm.py
@@ -44,8 +44,8 @@
44 44 True
45 45
46 46 List all the states again, explicitly (buffer size is 8 by now).
47   - >>> fListToString(N['mdlstm'].outputbuffer[:,1], 3)
48   - '[0.4 , 0.4 , 0.814 , 0.407 , -0.152, -0.152, 0 , 0 ]'
  47 + >>> fListToString(N['mdlstm'].outputbuffer[:,1], 2)
  48 + '[0.4 , 0.4 , 0.81 , 0.41 , -0.15, -0.15, 0 , 0 ]'
49 49
50 50 """
51 51

0 comments on commit 2fcd4a5

Please sign in to comment.
Something went wrong with that request. Please try again.