Merge branch 'master' of github.com:pybrain/pybrain into rl

2 parents aecfa12 + 2fcd4a5; commit 311cf07546f08bbbd8e95b20f32e9617d94f5ca4; committed by @schaul on Nov 9, 2010
.gitignore (23 changed lines)
@@ -1,12 +1,13 @@
-*.pyc
-*.egg-info
-*.xml
-*.o
-*.so
-*.log
-.settings
-build
-dist
-docs/sphinx/.build
+*.pyc
+*.egg-info
+*.xml
+*.o
+*.so
+*.log
+.settings
+build
+dist
+docs/sphinx/.build
.DS_Store
-/pybrain/rl/environments/cartpole/fast_version/cartpolewrap.cpp
+/pybrain/rl/environments/cartpole/fast_version/cartpolewrap.cpp
+/pybrain/rl/environments/cartpole/fast_version/cartpolewrap.pyd
pybrain/structure/modules/module.py (17 changed lines)
@@ -1,6 +1,6 @@
__author__ = 'Daan Wierstra and Tom Schaul'
-from scipy import zeros
+from scipy import append, zeros
from pybrain.utilities import abstractMethod, Named
@@ -86,6 +86,21 @@ def reset(self):
            buf = getattr(self, buffername)
            buf[:] = zeros(l)

+    def shift(self, items):
+        """Shift all buffers up or down a defined number of items on offset axis.
+        Negative values indicate backward shift."""
+        if items == 0:
+            return
+        self.offset += items
+        for buffername, l in self.bufferlist:
+            buf = getattr(self, buffername)
+            assert abs(items) <= len(buf), "Cannot shift further than length of buffer."
+            fill = zeros((abs(items), len(buf[0])))
+            if items < 0:
+                buf[:] = append(buf[-items:], fill, 0)
+            else:
+                buf[:] = append(fill, buf[0:-items], 0)
+
    def activateOnDataset(self, dataset):
        """Run the module's forward pass on the given dataset unconditionally
        and return the output."""
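As an aside, a minimal sketch (with this commit applied) of what the new Module.shift does to a module's buffers. The LinearLayer instance and the hand-filled two-step outputbuffer are contrived purely for illustration; in normal use the buffers are grown and filled by the owning network during activation.

    from numpy import array
    from pybrain.structure import LinearLayer

    layer = LinearLayer(2)
    # Contrived setup: pretend the layer has been active for two timesteps.
    layer.outputbuffer = array([[1., 2.], [3., 4.]])

    # A negative argument shifts every buffer backward along the offset (time)
    # axis: the oldest rows drop off the front, zero rows are appended at the end.
    layer.shift(-1)
    print(layer.outputbuffer)
    # prints [[3. 4.]
    #         [0. 0.]]   (exact spacing depends on the numpy version)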
pybrain/structure/networks/recurrent.py (40 changed lines)
@@ -15,13 +15,10 @@ class RecurrentNetworkComponent(object):
    sequential = True

-    def __init__(self, forget, name=None, *args, **kwargs):
+    def __init__(self, forget=None, name=None, *args, **kwargs):
        self.recurrentConns = []
        self.maxoffset = 0
-        if forget:
-            self.increment = 0
-        else:
-            self.increment = 1
+        self.forget = forget

    def __str__(self):
        s = super(RecurrentNetworkComponent, self).__str__()
@@ -51,26 +48,29 @@ def activate(self, inpt):
"""Do one transformation of an input and return the result."""
self.inputbuffer[self.offset] = inpt
self.forward()
- return self.outputbuffer[self.offset - self.increment].copy()
+ if self.forget:
+ return self.outputbuffer[self.offset].copy()
+ else:
+ return self.outputbuffer[self.offset - 1].copy()
def backActivate(self, outerr):
"""Do one transformation of an output error outerr backward and return
the error on the input."""
- self.outputerror[self.offset - self.increment] = outerr
+ self.outputerror[self.offset - 1] = outerr
self.backward()
return self.inputerror[self.offset].copy()
def forward(self):
"""Produce the output from the input."""
- if not (self.offset + self.increment < self.inputbuffer.shape[0]):
+ if not (self.offset + 1 < self.inputbuffer.shape[0]):
self._growBuffers()
super(RecurrentNetworkComponent, self).forward()
- self.offset += self.increment
+ self.offset += 1
self.maxoffset = max(self.offset, self.maxoffset)
def backward(self):
"""Produce the input error from the output error."""
- self.offset -= self.increment
+ self.offset -= 1
super(RecurrentNetworkComponent, self).backward()
def _isLastTimestep(self):
@@ -79,6 +79,9 @@ def _isLastTimestep(self):
    def _forwardImplementation(self, inbuf, outbuf):
        assert self.sorted, ".sortModules() has not been called"
+        if self.forget:
+            self.offset += 1
+
        index = 0
        offset = self.offset
        for m in self.inmodules:
@@ -87,19 +90,26 @@ def _forwardImplementation(self, inbuf, outbuf):
        if offset > 0:
            for c in self.recurrentConns:
-                c.forward(offset - self.increment, offset)
+                c.forward(offset - 1, offset)

        for m in self.modulesSorted:
            m.forward()
            for c in self.connections[m]:
                c.forward(offset, offset)

+        if self.forget:
+            for m in self.modules:
+                m.shift(-1)
+            offset -= 1
+            self.offset -= 2
+
        index = 0
        for m in self.outmodules:
            outbuf[index:index + m.outdim] = m.outputbuffer[offset]
            index += m.outdim

    def _backwardImplementation(self, outerr, inerr, outbuf, inbuf):
+        assert not self.forget, "Cannot back propagate a forgetful network"
        assert self.sorted, ".sortModules() has not been called"
        index = 0
        offset = self.offset
@@ -109,7 +119,7 @@ def _backwardImplementation(self, outerr, inerr, outbuf, inbuf):
        if not self._isLastTimestep():
            for c in self.recurrentConns:
-                c.backward(offset, offset + self.increment)
+                c.backward(offset, offset + 1)

        for m in reversed(self.modulesSorted):
            for c in self.connections[m]:
@@ -136,6 +146,10 @@ class RecurrentNetwork(RecurrentNetworkComponent, Network):
    bufferlist = Network.bufferlist

-    def __init__(self, forget=False, *args, **kwargs):
+    def __init__(self, *args, **kwargs):
        Network.__init__(self, *args, **kwargs)
+        if 'forget' in kwargs:
+            forget = kwargs['forget']
+        else:
+            forget = False
        RecurrentNetworkComponent.__init__(self, forget, *args, **kwargs)
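For orientation, a usage sketch of the new forget semantics from the caller's side. The tiny topology, layer sizes, and input values below are arbitrary placeholders that only exercise the recurrent connection; the construction calls are the usual pybrain network-building methods. With the default forget=False the offset still advances one slot per activation exactly as before; with forget=True the offset is wound back and the module buffers are shifted after every step, so memory stays bounded over arbitrarily long sequences.

    from pybrain.structure import (RecurrentNetwork, LinearLayer,
                                   SigmoidLayer, FullConnection)

    # Arbitrary demo topology: 1 input, 3 recurrent hidden units, 1 output.
    net = RecurrentNetwork(forget=True)
    net.addInputModule(LinearLayer(1, name='in'))
    net.addModule(SigmoidLayer(3, name='hidden'))
    net.addOutputModule(LinearLayer(1, name='out'))
    net.addConnection(FullConnection(net['in'], net['hidden']))
    net.addConnection(FullConnection(net['hidden'], net['out']))
    net.addRecurrentConnection(FullConnection(net['hidden'], net['hidden']))
    net.sortModules()

    # Each step reuses the same buffer slot: the offset is rolled back and the
    # module buffers are shifted, so they never grow with the sequence length.
    for _ in range(1000):
        out = net.activate([0.5])

The trade-off is explicit in the diff: a forgetful network can run online for arbitrarily long episodes with constant memory, but the new assert in _backwardImplementation refuses to backpropagate through it, so training through time still needs the default, growing buffers.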
pybrain/tests/unittests/test_peephole_mdlstm.py (4 changed lines)
@@ -44,8 +44,8 @@
True
List all the states again, explicitly (buffer size is 8 by now).
- >>> fListToString(N['mdlstm'].outputbuffer[:,1], 3)
- '[0.4 , 0.4 , 0.814 , 0.407 , -0.152, -0.152, 0 , 0 ]'
+ >>> fListToString(N['mdlstm'].outputbuffer[:,1], 2)
+ '[0.4 , 0.4 , 0.81 , 0.41 , -0.15, -0.15, 0 , 0 ]'
"""
