Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

minor adjustment to LSPI (pseudo-inverse is now recomputed only every `lazyInversions` updates), additional balance-task (double-pole RBF variant)

  • Loading branch information...
commit 29fe692b73f6d87590907b71da59f6f38ff5f221 1 parent ddae5b1
Tom Schaul schaul authored
29 pybrain/rl/environments/cartpole/balancetask.py
View
@@ -1,12 +1,14 @@
-from pybrain.utilities import crossproduct
+from pybrain.rl.environments.cartpole.doublepole import DoublePoleEnvironment
__author__ = 'Thomas Rueckstiess and Tom Schaul'
-from scipy import pi, dot, array
+from scipy import pi, dot, array, ones, exp
+from scipy.linalg import norm
from pybrain.rl.environments.cartpole.nonmarkovpole import NonMarkovPoleEnvironment
from pybrain.rl.environments import EpisodicTask
from cartpole import CartPoleEnvironment
-
+from pybrain.utilities import crossproduct
+
class BalanceTask(EpisodicTask):
""" The task of balancing some pole(s) on a cart """
@@ -230,14 +232,23 @@ def isFinished(self):
return False
def getObservation(self):
- from scipy import ones, exp
- from scipy.linalg import norm
- res = ones(10)
- sensors = self.env.getSensors()[:2]
+ res = ones(1+len(self.CENTERS))
+ sensors = self.env.getSensors()[:-2]
res[1:] = exp(-array(map(norm, self.CENTERS-sensors))**2/2)
return res
@property
def outdim(self):
- return 10
-
+ return 1+len(self.CENTERS)
+
+
+class DiscreteDoubleBalanceTaskRBF(DiscreteBalanceTaskRBF):
+ """ Same idea, but two poles. """
+
+ CENTERS = array(crossproduct([[-pi/4, 0, pi/4], [1, 0, -1]]*2))
+
+ def __init__(self, env=None, maxsteps=1000):
+ if env == None:
+ env = DoublePoleEnvironment()
+ DiscreteBalanceTask.__init__(self, env, maxsteps)
+
24 pybrain/rl/learners/valuebased/linearfa.py
View
@@ -167,11 +167,14 @@ class LSPI(LinearFALearner):
passNextAction = True
+ lazyInversions = 20
+
def _additionalInit(self):
phi_size = self.num_actions * self.num_features
self._A = zeros((phi_size, phi_size))
self._b = zeros(phi_size)
self._untouched = ones(phi_size, dtype=bool)
+ self._count = 0
def _updateWeights(self, state, action, reward, next_state, next_action):
phi = zeros((self.num_actions, self.num_features))
@@ -184,14 +187,19 @@ def _updateWeights(self, state, action, reward, next_state, next_action):
self._A += outer(phi, phi - self.rewardDiscount * phi_n)
self._b += reward * phi
- if self.exploring:
- # add something to all the entries that are untouched
- self._untouched &= (phi == 0)
- res = dot(pinv2(self._A), self._b + self.explorationReward * self._untouched)
- else:
- res = dot(pinv2(self._A), self._b)
- self._theta = res.reshape(self.num_actions, self.num_features)
-
+
+
+ if self.lazyInversions is None or self._count % self.lazyInversions == 0:
+ if self.exploring:
+ # add something to all the entries that are untouched
+ self._untouched &= (phi == 0)
+ res = dot(pinv2(self._A), self._b + self.explorationReward * self._untouched)
+ else:
+ res = dot(pinv2(self._A), self._b)
+ self._theta = res.reshape(self.num_actions, self.num_features)
+
+ self._count += 1
+
class GQLambda(QLambda_LinFA):
""" From Maei/Sutton 2010, with additional info from Adam White. """
Please sign in to comment.
Something went wrong with that request. Please try again.