Commit 33b7841

Improved serialization of mushroom objects
- Approximators and Regressors are now serializable
- Fixed the TD implementations to use the Q function of the parent class
- Simplified serialization for TD algorithms
boris-il-forte committed Apr 24, 2020
1 parent e680a30 commit 33b7841
Showing 20 changed files with 108 additions and 61 deletions.
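Taken together, the changes let a whole agent round-trip through the unified Serializable interface. A minimal sketch of the intended usage, assuming the save/load API on this branch (the GridWorld setup and the 'agent.msh' file name are illustrative):

    from mushroom_rl.algorithms.value import QLearning
    from mushroom_rl.environments import GridWorld
    from mushroom_rl.policy import EpsGreedy
    from mushroom_rl.utils.parameters import Parameter

    # The Q table is now registered by the TD base class via
    # _add_save_attr(Q='mushroom'), so it travels with the agent.
    mdp = GridWorld(height=3, width=3, goal=(2, 2), start=(0, 0))
    agent = QLearning(mdp.info, EpsGreedy(epsilon=Parameter(0.1)),
                      learning_rate=Parameter(0.1))

    agent.save('agent.msh')                 # illustrative file name
    restored = QLearning.load('agent.msh')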
6 changes: 2 additions & 4 deletions mushroom_rl/algorithms/value/td/double_q_learning.py
@@ -12,11 +12,9 @@ class DoubleQLearning(TD):
     """
     def __init__(self, mdp_info, policy, learning_rate):
-        self.Q = EnsembleTable(2, mdp_info.size)
+        Q = EnsembleTable(2, mdp_info.size)
 
-        self._add_save_attr(Q='pickle', alpha='pickle')
-
-        super().__init__(mdp_info, policy, self.Q, learning_rate)
+        super().__init__(mdp_info, policy, Q, learning_rate)
 
         self.alpha = [deepcopy(self.alpha), deepcopy(self.alpha)]
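The same refactoring repeats in the TD subclasses below: each constructor builds a local Q and hands it to TD.__init__, which stores it as self.Q and registers it for serialization, so subclasses no longer list Q in their own _add_save_attr calls. The pattern in isolation (MyTDVariant is a hypothetical subclass, assuming the TD base class as changed by this commit):

    from mushroom_rl.algorithms.value.td.td import TD
    from mushroom_rl.utils.table import Table

    class MyTDVariant(TD):  # hypothetical, for illustration only
        def __init__(self, mdp_info, policy, learning_rate):
            Q = Table(mdp_info.size)  # local variable, not self.Q
            # TD.__init__ stores Q as self.Q and registers it with
            # _add_save_attr(Q='mushroom'); no per-subclass bookkeeping.
            super().__init__(mdp_info, policy, Q, learning_rate)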
5 changes: 2 additions & 3 deletions mushroom_rl/algorithms/value/td/expected_sarsa.py
@@ -10,10 +10,9 @@ class ExpectedSARSA(TD):
     """
     def __init__(self, mdp_info, policy, learning_rate):
-        self.Q = Table(mdp_info.size)
-        self._add_save_attr(Q='mushroom')
+        Q = Table(mdp_info.size)
 
-        super().__init__(mdp_info, policy, self.Q, learning_rate)
+        super().__init__(mdp_info, policy, Q, learning_rate)
 
     def _update(self, state, action, reward, next_state, absorbing):
         q_current = self.Q[state, action]
6 changes: 2 additions & 4 deletions mushroom_rl/algorithms/value/td/q_learning.py
@@ -11,11 +11,9 @@ class QLearning(TD):
     """
     def __init__(self, mdp_info, policy, learning_rate):
-        self.Q = Table(mdp_info.size)
+        Q = Table(mdp_info.size)
 
-        self._add_save_attr(Q='mushroom')
-
-        super().__init__(mdp_info, policy, self.Q, learning_rate)
+        super().__init__(mdp_info, policy, Q, learning_rate)
 
     def _update(self, state, action, reward, next_state, absorbing):
         q_current = self.Q[state, action]
6 changes: 3 additions & 3 deletions mushroom_rl/algorithms/value/td/r_learning.py
@@ -19,13 +19,13 @@ def __init__(self, mdp_info, policy, learning_rate, beta):
             beta (Parameter): beta coefficient.
 
         """
-        self.Q = Table(mdp_info.size)
+        Q = Table(mdp_info.size)
         self._rho = 0.
         self.beta = beta
 
-        self._add_save_attr(Q='mushroom', _rho='primitive', beta='pickle')
+        self._add_save_attr(_rho='primitive', beta='pickle')
 
-        super().__init__(mdp_info, policy, self.Q, learning_rate)
+        super().__init__(mdp_info, policy, Q, learning_rate)
 
     def _update(self, state, action, reward, next_state, absorbing):
         q_current = self.Q[state, action]
7 changes: 3 additions & 4 deletions mushroom_rl/algorithms/value/td/rq_learning.py
@@ -33,20 +33,19 @@ def __init__(self, mdp_info, policy, learning_rate, off_policy=False,
         else:
             raise ValueError('delta or beta parameters needed.')
 
-        self.Q = Table(mdp_info.size)
+        Q = Table(mdp_info.size)
         self.Q_tilde = Table(mdp_info.size)
         self.R_tilde = Table(mdp_info.size)
 
         self._add_save_attr(
-            off_policy='pickle',
+            off_policy='primitive',
             delta='pickle',
             beta='pickle',
-            Q='mushroom',
             Q_tilde='mushroom',
             R_tilde='mushroom'
         )
 
-        super().__init__(mdp_info, policy, self.Q, learning_rate)
+        super().__init__(mdp_info, policy, Q, learning_rate)
 
     def _update(self, state, action, reward, next_state, absorbing):
         alpha = self.alpha(state, action, target=reward)
5 changes: 2 additions & 3 deletions mushroom_rl/algorithms/value/td/sarsa.py
@@ -8,10 +8,9 @@ class SARSA(TD):
     """
    def __init__(self, mdp_info, policy, learning_rate):
-        self.Q = Table(mdp_info.size)
-        self._add_save_attr(Q='mushroom')
+        Q = Table(mdp_info.size)
 
-        super().__init__(mdp_info, policy, self.Q, learning_rate)
+        super().__init__(mdp_info, policy, Q, learning_rate)
 
     def _update(self, state, action, reward, next_state, absorbing):
         q_current = self.Q[state, action]
7 changes: 3 additions & 4 deletions mushroom_rl/algorithms/value/td/sarsa_lambda.py
@@ -18,17 +18,16 @@ def __init__(self, mdp_info, policy, learning_rate, lambda_coeff,
             trace (str, 'replacing'): type of eligibility trace to use.
 
         """
-        self.Q = Table(mdp_info.size)
+        Q = Table(mdp_info.size)
         self._lambda = lambda_coeff
 
-        self.e = EligibilityTrace(self.Q.shape, trace)
+        self.e = EligibilityTrace(Q.shape, trace)
         self._add_save_attr(
-            Q='mushroom',
             _lambda='primitive',
             e='pickle'
         )
 
-        super().__init__(mdp_info, policy, self.Q, learning_rate)
+        super().__init__(mdp_info, policy, Q, learning_rate)
 
     def _update(self, state, action, reward, next_state, absorbing):
         q_current = self.Q[state, action]
10 changes: 4 additions & 6 deletions mushroom_rl/algorithms/value/td/sarsa_lambda_continuous.py
@@ -18,21 +18,19 @@ def __init__(self, mdp_info, policy, approximator, learning_rate,
             lambda_coeff (float): eligibility trace coefficient.
 
         """
-        self._approximator_params = dict() if approximator_params is None else \
+        approximator_params = dict() if approximator_params is None else \
             approximator_params
 
-        self.Q = Regressor(approximator, **self._approximator_params)
-        self.e = np.zeros(self.Q.weights_size)
+        Q = Regressor(approximator, **approximator_params)
+        self.e = np.zeros(Q.weights_size)
         self._lambda = lambda_coeff
 
         self._add_save_attr(
-            _approximator_params='pickle',
-            Q='pickle',
            _lambda='primitive',
             e='numpy'
         )
 
-        super().__init__(mdp_info, policy, self.Q, learning_rate, features)
+        super().__init__(mdp_info, policy, Q, learning_rate, features)
 
     def _update(self, state, action, reward, next_state, absorbing):
         phi_state = self.phi(state)
8 changes: 4 additions & 4 deletions mushroom_rl/algorithms/value/td/speedy_q_learning.py
@@ -12,12 +12,12 @@ class SpeedyQLearning(TD):
     """
     def __init__(self, mdp_info, policy, learning_rate):
-        self.Q = Table(mdp_info.size)
-        self.old_q = deepcopy(self.Q)
+        Q = Table(mdp_info.size)
+        self.old_q = deepcopy(Q)
 
-        self._add_save_attr(Q='mushroom', old_q='mushroom')
+        self._add_save_attr(old_q='mushroom')
 
-        super().__init__(mdp_info, policy, self.Q, learning_rate)
+        super().__init__(mdp_info, policy, Q, learning_rate)
 
     def _update(self, state, action, reward, next_state, absorbing):
         old_q = deepcopy(self.Q)
4 changes: 2 additions & 2 deletions mushroom_rl/algorithms/value/td/td.py
@@ -22,9 +22,9 @@ def __init__(self, mdp_info, policy, approximator, learning_rate,
         self.alpha = learning_rate
 
         policy.set_q(approximator)
-        self.approximator = approximator
+        self.Q = approximator
 
-        self._add_save_attr(alpha='pickle', approximator='pickle')
+        self._add_save_attr(alpha='pickle', Q='mushroom')
 
         super().__init__(mdp_info, policy, features)
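This is the pivotal hunk: the TD base class now owns the approximator under the single name Q and registers it with the 'mushroom' serializer, replacing the pickled approximator attribute. Reconstructed from the hunk, the constructor now reads roughly as follows (a sketch, not the full file):

    from mushroom_rl.core import Agent

    class TD(Agent):
        def __init__(self, mdp_info, policy, approximator, learning_rate,
                     features=None):
            self.alpha = learning_rate

            policy.set_q(approximator)
            self.Q = approximator  # one canonical name for all subclasses

            # 'mushroom' means: serialize through the object's own
            # Serializable machinery instead of pickling it wholesale.
            self._add_save_attr(alpha='pickle', Q='mushroom')

            super().__init__(mdp_info, policy, features)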
12 changes: 5 additions & 7 deletions mushroom_rl/algorithms/value/td/true_online_sarsa_lambda.py
@@ -21,23 +21,21 @@ def __init__(self, mdp_info, policy, learning_rate, lambda_coeff,
             lambda_coeff (float): eligibility trace coefficient.
 
         """
-        self._approximator_params = dict() if approximator_params is None else \
+        approximator_params = dict() if approximator_params is None else \
             approximator_params
 
-        self.Q = Regressor(LinearApproximator, **self._approximator_params)
-        self.e = np.zeros(self.Q.weights_size)
+        Q = Regressor(LinearApproximator, **approximator_params)
+        self.e = np.zeros(Q.weights_size)
         self._lambda = lambda_coeff
         self._q_old = None
 
         self._add_save_attr(
-            _approximator_params='pickle',
-            Q='pickle',
-            _q_old='pickle',
+            _q_old='numpy',
             _lambda='primitive',
             e='numpy'
         )
 
-        super().__init__(mdp_info, policy, self.Q, learning_rate, features)
+        super().__init__(mdp_info, policy, Q, learning_rate, features)
 
     def _update(self, state, action, reward, next_state, absorbing):
         phi_state = self.phi(state)
9 changes: 4 additions & 5 deletions mushroom_rl/algorithms/value/td/weighted_q_learning.py
@@ -23,23 +23,22 @@ def __init__(self, mdp_info, policy, learning_rate, sampling=True,
             version.
 
         """
-        self.Q = Table(mdp_info.size)
+        Q = Table(mdp_info.size)
         self._sampling = sampling
         self._precision = precision
 
         self._add_save_attr(
-            Q='mushroom',
             _sampling='primitive',
             _precision='primitive',
-            _n_updates='pickle',
-            _sigma='pickle',
+            _n_updates='mushroom',
+            _sigma='mushroom',
             _Q2='mushroom',
             _w='primitive',
             _w1='mushroom',
             _w2='mushroom'
         )
 
-        super().__init__(mdp_info, policy, self.Q, learning_rate)
+        super().__init__(mdp_info, policy, Q, learning_rate)
 
         self._n_updates = Table(mdp_info.size)
         self._sigma = Table(mdp_info.size, initial_value=1e10)
11 changes: 9 additions & 2 deletions mushroom_rl/approximators/_implementations/action_regressor.py
@@ -1,7 +1,9 @@
 import numpy as np
 
+from mushroom_rl.core import Serializable
+
 
-class ActionRegressor:
+class ActionRegressor(Serializable):
     """
     This class is used to approximate the Q-function with a different
     approximator of the provided class for each action. It is often used in MDPs
@@ -14,7 +16,7 @@ def __init__(self, approximator, n_actions, **params):
         Constructor.
 
         Args:
-            approximator (object): the model class to approximate the
+            approximator (class): the model class to approximate the
                 Q-function of each action;
             n_actions (int): number of different actions of the problem. It
                 determines the number of different regressors in the action
@@ -28,6 +30,11 @@ def __init__(self, approximator, n_actions, **params):
         for i in range(self._n_actions):
             self.model.append(approximator(**params))
 
+        self._add_save_attr(
+            _n_actions='primitive',
+            model=self._get_serialization_method(approximator)
+        )
+
     def fit(self, state, action, q, **fit_params):
         """
         Fit the model.
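Here _get_serialization_method lets the wrapper pick a serialization strategy per model class. Its assumed behavior, sketched below (not the helper's actual body): return 'mushroom' for models that are themselves Serializable, and fall back to 'pickle' for third-party ones.

    # Assumed behavior of Serializable._get_serialization_method:
    def _get_serialization_method(self, model_class):
        if issubclass(model_class, Serializable):
            return 'mushroom'  # the model handles its own save/load
        else:
            return 'pickle'    # opaque fallback for third-party models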
10 changes: 8 additions & 2 deletions mushroom_rl/approximators/_implementations/ensemble.py
@@ -1,8 +1,10 @@
 import numpy as np
 from sklearn.exceptions import NotFittedError
 
+from mushroom_rl.core import Serializable
+
 
-class Ensemble(object):
+class Ensemble(Serializable):
     """
     This class is used to create an ensemble of regressors.
 
@@ -12,7 +14,7 @@ def __init__(self, model, n_models, **params):
         Constructor.
 
         Args:
-            approximator (object): the model class to approximate the
+            approximator (class): the model class to approximate the
                 Q-function.
             n_models (int): number of regressors in the ensemble;
             **params: parameters dictionary to create each regressor.
@@ -23,6 +25,10 @@ def __init__(self, model, n_models, **params):
         for _ in range(n_models):
             self._model.append(model(**params))
 
+        self._add_save_attr(
+            _model=self._get_serialization_method(model)
+        )
+
     def fit(self, *z, idx=None, **fit_params):
         """
         Fit the ``idx``-th model of the ensemble if ``idx`` is provided, every
12 changes: 10 additions & 2 deletions mushroom_rl/approximators/_implementations/generic_regressor.py
@@ -1,4 +1,7 @@
-class GenericRegressor:
+from mushroom_rl.core import Serializable
+
+
+class GenericRegressor(Serializable):
     """
     This class is used to create a regressor that approximates a generic
     function. An arbitrary number of inputs and outputs is supported.
@@ -9,7 +12,7 @@ def __init__(self, approximator, n_inputs, **params):
         Constructor.
 
         Args:
-            approximator (object): the model class to approximate the
+            approximator (class): the model class to approximate the
                 a generic function;
             n_inputs (int): number of inputs of the regressor;
             **params: parameters dictionary to the regressor;
@@ -18,6 +21,11 @@ def __init__(self, approximator, n_inputs, **params):
         self._n_inputs = n_inputs
         self.model = approximator(**params)
 
+        self._add_save_attr(
+            _n_inputs='primitive',
+            model=self._get_serialization_method(approximator)
+        )
+
     def fit(self, *z, **fit_params):
         """
         Fit the model.
9 changes: 7 additions & 2 deletions mushroom_rl/approximators/_implementations/q_regressor.py
@@ -1,7 +1,8 @@
 import numpy as np
+from mushroom_rl.core import Serializable
 
 
-class QRegressor:
+class QRegressor(Serializable):
     """
     This class is used to create a regressor that approximates the Q-function
     using a multi-dimensional output where each output corresponds to the
@@ -14,13 +15,17 @@ def __init__(self, approximator, **params):
         Constructor.
 
         Args:
-            approximator (object): the model class to approximate the
+            approximator (class): the model class to approximate the
                 Q-function;
             **params: parameters dictionary to the regressor.
 
         """
         self.model = approximator(**params)
 
+        self._add_save_attr(
+            model=self._get_serialization_method(approximator)
+        )
+
     def fit(self, state, action, q, **fit_params):
         """
         Fit the model.
6 changes: 5 additions & 1 deletion mushroom_rl/approximators/parametric/linear.py
@@ -1,7 +1,9 @@
 import numpy as np
 
+from mushroom_rl.core import Serializable
+
 
-class LinearApproximator:
+class LinearApproximator(Serializable):
     """
     This class implements a linear approximator.
 
@@ -34,6 +36,8 @@ def __init__(self, weights=None, input_shape=None, output_shape=(1,),
             raise ValueError('You should specify the initial parameter vector'
                              ' or the input dimension')
 
+        self._add_save_attr(_w='numpy')
+
     def fit(self, x, y, **fit_params):
         """
         Fit the model.
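With LinearApproximator now Serializable and its weight vector registered as a numpy attribute, the approximator can also be saved standalone. A quick sketch, assuming the Serializable save/load API (the 'linear.msh' file name is illustrative):

    import numpy as np
    from mushroom_rl.approximators.parametric import LinearApproximator

    approximator = LinearApproximator(input_shape=(4,), output_shape=(2,))
    approximator.set_weights(np.random.rand(approximator.weights_size))

    approximator.save('linear.msh')    # _w is written as a numpy array
    restored = LinearApproximator.load('linear.msh')
    assert np.allclose(restored.get_weights(), approximator.get_weights())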
