diff --git a/mushroom_rl/algorithms/value/td/double_q_learning.py b/mushroom_rl/algorithms/value/td/double_q_learning.py
index 360be7e3..0073524d 100644
--- a/mushroom_rl/algorithms/value/td/double_q_learning.py
+++ b/mushroom_rl/algorithms/value/td/double_q_learning.py
@@ -12,11 +12,9 @@ class DoubleQLearning(TD):
     """
     def __init__(self, mdp_info, policy, learning_rate):
-        self.Q = EnsembleTable(2, mdp_info.size)
+        Q = EnsembleTable(2, mdp_info.size)
 
-        self._add_save_attr(Q='pickle', alpha='pickle')
-
-        super().__init__(mdp_info, policy, self.Q, learning_rate)
+        super().__init__(mdp_info, policy, Q, learning_rate)
 
         self.alpha = [deepcopy(self.alpha), deepcopy(self.alpha)]
diff --git a/mushroom_rl/algorithms/value/td/expected_sarsa.py b/mushroom_rl/algorithms/value/td/expected_sarsa.py
index 412a54b3..66187608 100644
--- a/mushroom_rl/algorithms/value/td/expected_sarsa.py
+++ b/mushroom_rl/algorithms/value/td/expected_sarsa.py
@@ -10,10 +10,9 @@ class ExpectedSARSA(TD):
     """
     def __init__(self, mdp_info, policy, learning_rate):
-        self.Q = Table(mdp_info.size)
-        self._add_save_attr(Q='mushroom')
+        Q = Table(mdp_info.size)
 
-        super().__init__(mdp_info, policy, self.Q, learning_rate)
+        super().__init__(mdp_info, policy, Q, learning_rate)
 
     def _update(self, state, action, reward, next_state, absorbing):
         q_current = self.Q[state, action]
diff --git a/mushroom_rl/algorithms/value/td/q_learning.py b/mushroom_rl/algorithms/value/td/q_learning.py
index 65b14349..2f762b74 100644
--- a/mushroom_rl/algorithms/value/td/q_learning.py
+++ b/mushroom_rl/algorithms/value/td/q_learning.py
@@ -11,11 +11,9 @@ class QLearning(TD):
     """
     def __init__(self, mdp_info, policy, learning_rate):
-        self.Q = Table(mdp_info.size)
+        Q = Table(mdp_info.size)
 
-        self._add_save_attr(Q='mushroom')
-
-        super().__init__(mdp_info, policy, self.Q, learning_rate)
+        super().__init__(mdp_info, policy, Q, learning_rate)
 
     def _update(self, state, action, reward, next_state, absorbing):
         q_current = self.Q[state, action]
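Note: with the hunks above, none of the tabular algorithms registers Q anymore; persistence is delegated to the TD base class (see the td.py hunk further down), which registers it with the 'mushroom' method. A minimal sketch of the intended round trip, assuming the usual MushroomRL imports, the GridWorld toy environment and an arbitrary output path:

import numpy as np

from mushroom_rl.algorithms.value import QLearning
from mushroom_rl.environments import GridWorld
from mushroom_rl.policy import EpsGreedy
from mushroom_rl.utils.parameters import Parameter

# Small tabular agent on a toy environment.
mdp = GridWorld(height=3, width=3, goal=(2, 2))
agent = QLearning(mdp.info, EpsGreedy(Parameter(.1)), Parameter(.5))

# Q is registered by the TD base class with _add_save_attr(Q='mushroom'),
# so the table survives the round trip even though QLearning no longer
# registers it itself.
agent.save('/tmp/q_learning_agent.zip')
loaded = QLearning.load('/tmp/q_learning_agent.zip')
assert np.array_equal(loaded.Q.table, agent.Q.table)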
diff --git a/mushroom_rl/algorithms/value/td/r_learning.py b/mushroom_rl/algorithms/value/td/r_learning.py
index 18e52367..e4730b9c 100644
--- a/mushroom_rl/algorithms/value/td/r_learning.py
+++ b/mushroom_rl/algorithms/value/td/r_learning.py
@@ -19,13 +19,13 @@ def __init__(self, mdp_info, policy, learning_rate, beta):
             beta (Parameter): beta coefficient.
 
         """
-        self.Q = Table(mdp_info.size)
+        Q = Table(mdp_info.size)
         self._rho = 0.
         self.beta = beta
 
-        self._add_save_attr(Q='mushroom', _rho='primitive', beta='pickle')
+        self._add_save_attr(_rho='primitive', beta='pickle')
 
-        super().__init__(mdp_info, policy, self.Q, learning_rate)
+        super().__init__(mdp_info, policy, Q, learning_rate)
 
     def _update(self, state, action, reward, next_state, absorbing):
         q_current = self.Q[state, action]
diff --git a/mushroom_rl/algorithms/value/td/rq_learning.py b/mushroom_rl/algorithms/value/td/rq_learning.py
index b19e02b0..a3d18870 100644
--- a/mushroom_rl/algorithms/value/td/rq_learning.py
+++ b/mushroom_rl/algorithms/value/td/rq_learning.py
@@ -33,20 +33,19 @@ def __init__(self, mdp_info, policy, learning_rate, off_policy=False,
         else:
             raise ValueError('delta or beta parameters needed.')
 
-        self.Q = Table(mdp_info.size)
+        Q = Table(mdp_info.size)
         self.Q_tilde = Table(mdp_info.size)
         self.R_tilde = Table(mdp_info.size)
 
         self._add_save_attr(
-            off_policy='pickle',
+            off_policy='primitive',
             delta='pickle',
             beta='pickle',
-            Q='mushroom',
             Q_tilde='mushroom',
             R_tilde='mushroom'
         )
 
-        super().__init__(mdp_info, policy, self.Q, learning_rate)
+        super().__init__(mdp_info, policy, Q, learning_rate)
 
     def _update(self, state, action, reward, next_state, absorbing):
         alpha = self.alpha(state, action, target=reward)
diff --git a/mushroom_rl/algorithms/value/td/sarsa.py b/mushroom_rl/algorithms/value/td/sarsa.py
index 42397842..673b8b06 100644
--- a/mushroom_rl/algorithms/value/td/sarsa.py
+++ b/mushroom_rl/algorithms/value/td/sarsa.py
@@ -8,10 +8,9 @@ class SARSA(TD):
     """
     def __init__(self, mdp_info, policy, learning_rate):
-        self.Q = Table(mdp_info.size)
-        self._add_save_attr(Q='mushroom')
+        Q = Table(mdp_info.size)
 
-        super().__init__(mdp_info, policy, self.Q, learning_rate)
+        super().__init__(mdp_info, policy, Q, learning_rate)
 
     def _update(self, state, action, reward, next_state, absorbing):
         q_current = self.Q[state, action]
diff --git a/mushroom_rl/algorithms/value/td/sarsa_lambda.py b/mushroom_rl/algorithms/value/td/sarsa_lambda.py
index 2574d2b0..ef29356d 100644
--- a/mushroom_rl/algorithms/value/td/sarsa_lambda.py
+++ b/mushroom_rl/algorithms/value/td/sarsa_lambda.py
@@ -18,17 +18,16 @@ def __init__(self, mdp_info, policy, learning_rate, lambda_coeff,
             trace (str, 'replacing'): type of eligibility trace to use.
 
         """
-        self.Q = Table(mdp_info.size)
+        Q = Table(mdp_info.size)
         self._lambda = lambda_coeff
 
-        self.e = EligibilityTrace(self.Q.shape, trace)
+        self.e = EligibilityTrace(Q.shape, trace)
         self._add_save_attr(
-            Q='mushroom',
             _lambda='primitive',
             e='pickle'
         )
 
-        super().__init__(mdp_info, policy, self.Q, learning_rate)
+        super().__init__(mdp_info, policy, Q, learning_rate)
 
     def _update(self, state, action, reward, next_state, absorbing):
         q_current = self.Q[state, action]
diff --git a/mushroom_rl/algorithms/value/td/sarsa_lambda_continuous.py b/mushroom_rl/algorithms/value/td/sarsa_lambda_continuous.py
index d2120611..c0122af0 100644
--- a/mushroom_rl/algorithms/value/td/sarsa_lambda_continuous.py
+++ b/mushroom_rl/algorithms/value/td/sarsa_lambda_continuous.py
@@ -18,21 +18,19 @@ def __init__(self, mdp_info, policy, approximator, learning_rate,
             lambda_coeff (float): eligibility trace coefficient.
""" - self._approximator_params = dict() if approximator_params is None else \ + approximator_params = dict() if approximator_params is None else \ approximator_params - self.Q = Regressor(approximator, **self._approximator_params) - self.e = np.zeros(self.Q.weights_size) + Q = Regressor(approximator, **approximator_params) + self.e = np.zeros(Q.weights_size) self._lambda = lambda_coeff self._add_save_attr( - _approximator_params='pickle', - Q='pickle', _lambda='primitive', e='numpy' ) - super().__init__(mdp_info, policy, self.Q, learning_rate, features) + super().__init__(mdp_info, policy, Q, learning_rate, features) def _update(self, state, action, reward, next_state, absorbing): phi_state = self.phi(state) diff --git a/mushroom_rl/algorithms/value/td/speedy_q_learning.py b/mushroom_rl/algorithms/value/td/speedy_q_learning.py index fee592de..d5ca4fda 100644 --- a/mushroom_rl/algorithms/value/td/speedy_q_learning.py +++ b/mushroom_rl/algorithms/value/td/speedy_q_learning.py @@ -12,12 +12,12 @@ class SpeedyQLearning(TD): """ def __init__(self, mdp_info, policy, learning_rate): - self.Q = Table(mdp_info.size) - self.old_q = deepcopy(self.Q) + Q = Table(mdp_info.size) + self.old_q = deepcopy(Q) - self._add_save_attr(Q='mushroom', old_q='mushroom') + self._add_save_attr(old_q='mushroom') - super().__init__(mdp_info, policy, self.Q, learning_rate) + super().__init__(mdp_info, policy, Q, learning_rate) def _update(self, state, action, reward, next_state, absorbing): old_q = deepcopy(self.Q) diff --git a/mushroom_rl/algorithms/value/td/td.py b/mushroom_rl/algorithms/value/td/td.py index c296b49d..08bcc2dd 100644 --- a/mushroom_rl/algorithms/value/td/td.py +++ b/mushroom_rl/algorithms/value/td/td.py @@ -22,9 +22,9 @@ def __init__(self, mdp_info, policy, approximator, learning_rate, self.alpha = learning_rate policy.set_q(approximator) - self.approximator = approximator + self.Q = approximator - self._add_save_attr(alpha='pickle', approximator='pickle') + self._add_save_attr(alpha='pickle', Q='mushroom') super().__init__(mdp_info, policy, features) diff --git a/mushroom_rl/algorithms/value/td/true_online_sarsa_lambda.py b/mushroom_rl/algorithms/value/td/true_online_sarsa_lambda.py index 8c114354..ad50b921 100644 --- a/mushroom_rl/algorithms/value/td/true_online_sarsa_lambda.py +++ b/mushroom_rl/algorithms/value/td/true_online_sarsa_lambda.py @@ -21,23 +21,21 @@ def __init__(self, mdp_info, policy, learning_rate, lambda_coeff, lambda_coeff (float): eligibility trace coefficient. 
""" - self._approximator_params = dict() if approximator_params is None else \ + approximator_params = dict() if approximator_params is None else \ approximator_params - self.Q = Regressor(LinearApproximator, **self._approximator_params) - self.e = np.zeros(self.Q.weights_size) + Q = Regressor(LinearApproximator, **approximator_params) + self.e = np.zeros(Q.weights_size) self._lambda = lambda_coeff self._q_old = None self._add_save_attr( - _approximator_params='pickle', - Q='pickle', - _q_old='pickle', + _q_old='numpy', _lambda='primitive', e='numpy' ) - super().__init__(mdp_info, policy, self.Q, learning_rate, features) + super().__init__(mdp_info, policy, Q, learning_rate, features) def _update(self, state, action, reward, next_state, absorbing): phi_state = self.phi(state) diff --git a/mushroom_rl/algorithms/value/td/weighted_q_learning.py b/mushroom_rl/algorithms/value/td/weighted_q_learning.py index c9a7833b..51d5e4dc 100644 --- a/mushroom_rl/algorithms/value/td/weighted_q_learning.py +++ b/mushroom_rl/algorithms/value/td/weighted_q_learning.py @@ -23,23 +23,22 @@ def __init__(self, mdp_info, policy, learning_rate, sampling=True, version. """ - self.Q = Table(mdp_info.size) + Q = Table(mdp_info.size) self._sampling = sampling self._precision = precision self._add_save_attr( - Q='mushroom', _sampling='primitive', _precision='primitive', - _n_updates='pickle', - _sigma='pickle', + _n_updates='mushroom', + _sigma='mushroom', _Q2='mushroom', _w='primitive', _w1='mushroom', _w2='mushroom' ) - super().__init__(mdp_info, policy, self.Q, learning_rate) + super().__init__(mdp_info, policy, Q, learning_rate) self._n_updates = Table(mdp_info.size) self._sigma = Table(mdp_info.size, initial_value=1e10) diff --git a/mushroom_rl/approximators/_implementations/action_regressor.py b/mushroom_rl/approximators/_implementations/action_regressor.py index 60899fa0..6accb164 100644 --- a/mushroom_rl/approximators/_implementations/action_regressor.py +++ b/mushroom_rl/approximators/_implementations/action_regressor.py @@ -1,7 +1,9 @@ import numpy as np +from mushroom_rl.core import Serializable -class ActionRegressor: + +class ActionRegressor(Serializable): """ This class is used to approximate the Q-function with a different approximator of the provided class for each action. It is often used in MDPs @@ -14,7 +16,7 @@ def __init__(self, approximator, n_actions, **params): Constructor. Args: - approximator (object): the model class to approximate the + approximator (class): the model class to approximate the Q-function of each action; n_actions (int): number of different actions of the problem. It determines the number of different regressors in the action @@ -28,6 +30,11 @@ def __init__(self, approximator, n_actions, **params): for i in range(self._n_actions): self.model.append(approximator(**params)) + self._add_save_attr( + _n_actions='primitive', + model=self._get_serialization_method(approximator) + ) + def fit(self, state, action, q, **fit_params): """ Fit the model. diff --git a/mushroom_rl/approximators/_implementations/ensemble.py b/mushroom_rl/approximators/_implementations/ensemble.py index 543b8805..6c4171e1 100644 --- a/mushroom_rl/approximators/_implementations/ensemble.py +++ b/mushroom_rl/approximators/_implementations/ensemble.py @@ -1,8 +1,10 @@ import numpy as np from sklearn.exceptions import NotFittedError +from mushroom_rl.core import Serializable -class Ensemble(object): + +class Ensemble(Serializable): """ This class is used to create an ensemble of regressors. 
diff --git a/mushroom_rl/approximators/_implementations/ensemble.py b/mushroom_rl/approximators/_implementations/ensemble.py
index 543b8805..6c4171e1 100644
--- a/mushroom_rl/approximators/_implementations/ensemble.py
+++ b/mushroom_rl/approximators/_implementations/ensemble.py
@@ -1,8 +1,10 @@
 import numpy as np
 from sklearn.exceptions import NotFittedError
 
+from mushroom_rl.core import Serializable
 
-class Ensemble(object):
+
+class Ensemble(Serializable):
     """
     This class is used to create an ensemble of regressors.
 
@@ -12,7 +14,7 @@ def __init__(self, model, n_models, **params):
         Constructor.
 
         Args:
-            approximator (object): the model class to approximate the
+            approximator (class): the model class to approximate the
                 Q-function.
             n_models (int): number of regressors in the ensemble;
             **params: parameters dictionary to create each regressor.
@@ -23,6 +25,10 @@
         for _ in range(n_models):
             self._model.append(model(**params))
 
+        self._add_save_attr(
+            _model=self._get_serialization_method(model)
+        )
+
     def fit(self, *z, idx=None, **fit_params):
         """
         Fit the ``idx``-th model of the ensemble if ``idx`` is provided, every
diff --git a/mushroom_rl/approximators/_implementations/generic_regressor.py b/mushroom_rl/approximators/_implementations/generic_regressor.py
index a5da4416..84441f96 100644
--- a/mushroom_rl/approximators/_implementations/generic_regressor.py
+++ b/mushroom_rl/approximators/_implementations/generic_regressor.py
@@ -1,4 +1,7 @@
-class GenericRegressor:
+from mushroom_rl.core import Serializable
+
+
+class GenericRegressor(Serializable):
     """
     This class is used to create a regressor that approximates a generic
     function. An arbitrary number of inputs and outputs is supported.
@@ -9,7 +12,7 @@ def __init__(self, approximator, n_inputs, **params):
         Constructor.
 
         Args:
-            approximator (object): the model class to approximate the
+            approximator (class): the model class to approximate the
                 a generic function;
             n_inputs (int): number of inputs of the regressor;
             **params: parameters dictionary to the regressor;
@@ -18,6 +21,11 @@
         self._n_inputs = n_inputs
         self.model = approximator(**params)
 
+        self._add_save_attr(
+            _n_inputs='primitive',
+            model=self._get_serialization_method(approximator)
+        )
+
     def fit(self, *z, **fit_params):
         """
         Fit the model.
diff --git a/mushroom_rl/approximators/_implementations/q_regressor.py b/mushroom_rl/approximators/_implementations/q_regressor.py
index 71b24773..494d30f1 100644
--- a/mushroom_rl/approximators/_implementations/q_regressor.py
+++ b/mushroom_rl/approximators/_implementations/q_regressor.py
@@ -1,7 +1,8 @@
 import numpy as np
 
+from mushroom_rl.core import Serializable
 
-class QRegressor:
+class QRegressor(Serializable):
     """
     This class is used to create a regressor that approximates the Q-function
     using a multi-dimensional output where each output corresponds to the
@@ -14,13 +15,17 @@ def __init__(self, approximator, **params):
         Constructor.
 
         Args:
-            approximator (object): the model class to approximate the
+            approximator (class): the model class to approximate the
                 Q-function;
             **params: parameters dictionary to the regressor.
 
         """
         self.model = approximator(**params)
 
+        self._add_save_attr(
+            model=self._get_serialization_method(approximator)
+        )
+
     def fit(self, state, action, q, **fit_params):
         """
         Fit the model.
diff --git a/mushroom_rl/approximators/parametric/linear.py b/mushroom_rl/approximators/parametric/linear.py
index 01d18552..942a8c5b 100644
--- a/mushroom_rl/approximators/parametric/linear.py
+++ b/mushroom_rl/approximators/parametric/linear.py
@@ -1,7 +1,9 @@
 import numpy as np
 
+from mushroom_rl.core import Serializable
 
-class LinearApproximator:
+
+class LinearApproximator(Serializable):
     """
     This class implements a linear approximator.
 
@@ -34,6 +36,8 @@ def __init__(self, weights=None, input_shape=None, output_shape=(1,),
             raise ValueError('You should specify the initial parameter vector'
                              ' or the input dimension')
 
+        self._add_save_attr(_w='numpy')
+
     def fit(self, x, y, **fit_params):
         """
         Fit the model.
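With LinearApproximator now a Serializable that registers its weight vector as a numpy array, the weights should round-trip through save/load on their own. A minimal sketch, with an arbitrary output path:

import numpy as np

from mushroom_rl.approximators.parametric import LinearApproximator

approx = LinearApproximator(input_shape=(3,), output_shape=(2,))
approx.set_weights(np.arange(approx.weights_size, dtype=float))

# _w is stored with the 'numpy' method registered above.
approx.save('/tmp/linear.zip')
loaded = LinearApproximator.load('/tmp/linear.zip')
assert np.array_equal(loaded.get_weights(), approx.get_weights())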
diff --git a/mushroom_rl/approximators/parametric/torch_approximator.py b/mushroom_rl/approximators/parametric/torch_approximator.py
index 5177aa52..f008ebc7 100644
--- a/mushroom_rl/approximators/parametric/torch_approximator.py
+++ b/mushroom_rl/approximators/parametric/torch_approximator.py
@@ -2,11 +2,12 @@
 import numpy as np
 from tqdm import trange, tqdm
 
+from mushroom_rl.core import Serializable
 from mushroom_rl.utils.minibatches import minibatch_generator
 from mushroom_rl.utils.torch import get_weights, set_weights, zero_grad
 
 
-class TorchApproximator:
+class TorchApproximator(Serializable):
     """
     Class to interface a pytorch model to the mushroom Regressor interface.
     This class implements all is needed to use a generic pytorch model and train
@@ -64,6 +65,18 @@
                                              **optimizer['params'])
         self._loss = loss
 
+        self._add_save_attr(
+            _batch_size='primitive',
+            _reinitialize='primitive',
+            _use_cuda='primitive',
+            _dropout='primitive',
+            _quiet='primitive',
+            _n_fit_targets='primitive',
+            network='torch',
+            _optimizer='torch',
+            _loss='pickle'
+        )
+
     def predict(self, *args, output_tensor=False, **kwargs):
         """
         Predict.
diff --git a/mushroom_rl/approximators/regressor.py b/mushroom_rl/approximators/regressor.py
index 6b9f5c59..4df38bf7 100644
--- a/mushroom_rl/approximators/regressor.py
+++ b/mushroom_rl/approximators/regressor.py
@@ -1,12 +1,13 @@
 import numpy as np
 
+from mushroom_rl.core import Serializable
 from ._implementations.q_regressor import QRegressor
 from ._implementations.action_regressor import ActionRegressor
 from ._implementations.ensemble import Ensemble
 from ._implementations.generic_regressor import GenericRegressor
 
 
-class Regressor:
+class Regressor(Serializable):
     """
     This class implements the function to manage a function approximator. This
     class selects the appropriate kind of regressor to implement according to
@@ -29,7 +30,7 @@ def __init__(self, approximator, input_shape, output_shape=(1,),
         Constructor.
 
         Args:
-            approximator (object): the approximator class to use to create
+            approximator (class): the approximator class to use to create
                 the model;
             input_shape (tuple): the shape of the input of the model;
             output_shape (tuple, (1,)): the shape of the output of the model;
@@ -67,6 +68,14 @@
                                              len(self.input_shape),
                                              **params)
 
+        self._add_save_attr(
+            _input_shape='primitive',
+            _output_shape='primitive',
+            n_actions='primitive',
+            _n_models='primitive',
+            _impl='mushroom'
+        )
+
     def __call__(self, *z, **predict_params):
         return self.predict(*z, **predict_params)
diff --git a/mushroom_rl/core/serialization.py b/mushroom_rl/core/serialization.py
index b7765727..71ba2bb7 100644
--- a/mushroom_rl/core/serialization.py
+++ b/mushroom_rl/core/serialization.py
@@ -126,6 +126,7 @@ def load_zip(cls, zip_file, folder=''):
                         'implemented'.format(method))
                 att_val = load_method(zip_file, file_name)
                 setattr(loaded_object, att, att_val)
+
             else:
                 setattr(loaded_object, att, None)
 
@@ -248,3 +249,10 @@
         else:
             obj.save_zip(zip_file, full_save=full_save, folder=new_folder)
 
+    @staticmethod
+    def _get_serialization_method(class_name):
+        if issubclass(class_name, Serializable):
+            return 'mushroom'
+        else:
+            return 'pickle'
+
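Taken together, these changes make the approximator stack save itself recursively: Regressor registers _impl as 'mushroom', the implementation registers its model with the method returned by _get_serialization_method, and a Serializable approximator then registers its own leaves (e.g. the 'numpy' weights of LinearApproximator). A sketch of the resulting behaviour, assuming Regressor's weight accessors delegate to the underlying model as they do in the continuous SARSA(lambda) code above, and an arbitrary output path:

import numpy as np

from mushroom_rl.approximators import Regressor
from mushroom_rl.approximators.parametric import LinearApproximator

# Chain: Regressor (_impl='mushroom') -> GenericRegressor
# (model='mushroom', via _get_serialization_method) -> LinearApproximator
# (_w='numpy').
approx = Regressor(LinearApproximator, input_shape=(3,), output_shape=(2,))
approx.set_weights(np.arange(approx.weights_size, dtype=float))

approx.save('/tmp/linear_regressor.zip')
loaded = Regressor.load('/tmp/linear_regressor.zip')
assert np.array_equal(loaded.get_weights(), approx.get_weights())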