From e1d0dc2cbb8c93fa852476e9ef0ee4a0dfd6263b Mon Sep 17 00:00:00 2001 From: Ken Chatfield Date: Fri, 11 Nov 2016 17:17:40 +0000 Subject: [PATCH 01/27] Support for TB hist visualisation when using generator for validation (Fixes #3358) --- keras/callbacks.py | 72 ++++++++++++++++++++++++++++++++-------- keras/engine/training.py | 6 ++++ 2 files changed, 64 insertions(+), 14 deletions(-) diff --git a/keras/callbacks.py b/keras/callbacks.py index b44236b4f1e5..0772393cd2bc 100644 --- a/keras/callbacks.py +++ b/keras/callbacks.py @@ -527,21 +527,65 @@ def _set_model(self, model): def on_epoch_end(self, epoch, logs={}): import tensorflow as tf - if self.model.validation_data and self.histogram_freq: - if epoch % self.histogram_freq == 0: - # TODO: implement batched calls to sess.run - # (current call will likely go OOM on GPU) - if self.model.uses_learning_phase: - cut_v_data = len(self.model.inputs) - val_data = self.model.validation_data[:cut_v_data] + [0] - tensors = self.model.inputs + [K.learning_phase()] - else: - val_data = self.model.validation_data - tensors = self.model.inputs - feed_dict = dict(zip(tensors, val_data)) - result = self.sess.run([self.merged], feed_dict=feed_dict) - summary_str = result[0] + def get_val_summary(validation_data): + if self.model.uses_learning_phase: + cut_v_data = len(self.model.inputs) + val_data = list(validation_data[:cut_v_data]) + [0] + tensors = self.model.inputs + [K.learning_phase()] + else: + val_data = validation_data + tensors = self.model.inputs + feed_dict = dict(zip(tensors, val_data)) + result = self.sess.run([self.merged], feed_dict=feed_dict) + return result[0] + + if self.histogram_freq and epoch % self.histogram_freq == 0: + if self.model.validation_data: + summary_str = get_val_summary(self.model.validation_data) self.writer.add_summary(summary_str, epoch) + elif self.model.validation_gen: + val_gen = self.model.validation_gen['generator'] + nb_val_samples = self.model.validation_gen['nb_samples'] + # process nb_samples from validation data generator + sub_summaries = [] + processed_samples = 0 + while processed_samples < nb_val_samples: + validation_data = next(val_gen) + summary = tf.Summary.FromString(get_val_summary(validation_data)) + sub_summaries.append(summary) + processed_samples += validation_data[0].shape[0] + # convert summaries to dict of lists + sub_summaries_dict = {} + for sub_summary in sub_summaries: + for value in sub_summary.value: + value_field = value.WhichOneof('value') + value_ifo = sub_summaries_dict.setdefault(value.tag, {'value_field': None, 'values': []}) + if not value_ifo['value_field']: + value_ifo['value_field'] = value_field + else: + assert value_ifo['value_field'] == value_field + value_ifo['values'].append(getattr(value, value_field)) + # aggregate summaries + summary = tf.Summary() + for name, value_ifo in sub_summaries_dict.items(): + summary_value = summary.value.add() + summary_value.tag = name + if value_ifo['value_field'] == 'histo': + values = value_ifo['values'] + summary_value.histo.min = min([x.min for x in values]) + summary_value.histo.max = max([x.max for x in values]) + summary_value.histo.num = sum([x.num for x in values]) + summary_value.histo.sum = sum([x.sum for x in values]) + summary_value.histo.sum_squares = sum([x.sum_squares for x in values]) + # for histogram values, just take first batch for now + # TODO: aggregate histograms over batches + for lim in values[0].bucket_limit: + summary_value.histo.bucket_limit.append(lim) + for bucket in values[0].bucket: + 
summary_value.histo.bucket.append(bucket) + else: + print('Warning: could not aggregate summary of type {}'.format(value_ifo['value_field'])) + self.writer.add_summary(summary, epoch) for name, value in logs.items(): if name in ['batch', 'size']: diff --git a/keras/engine/training.py b/keras/engine/training.py index adc84339683e..938ba0bd0839 100644 --- a/keras/engine/training.py +++ b/keras/engine/training.py @@ -818,6 +818,7 @@ def _fit_loop(self, f, ins, out_labels=[], batch_size=32, callbacks.on_train_begin() callback_model.stop_training = False self.validation_data = val_ins + self.validation_gen = None for epoch in range(nb_epoch): callbacks.on_epoch_begin(epoch) @@ -1412,8 +1413,13 @@ def generate_arrays_from_file(path): 'or (val_x, val_y). Found: ' + str(validation_data)) val_x, val_y, val_sample_weights = self._standardize_user_data(val_x, val_y, val_sample_weight) self.validation_data = val_x + [val_y, val_sample_weights] + self.validation_gen = None else: self.validation_data = None + self.validation_gen = dict( + generator=validation_data, + nb_samples=nb_val_samples + ) # start generator thread storing batches into a queue data_gen_queue, _stop, generator_threads = generator_queue(generator, max_q_size=max_q_size, nb_worker=nb_worker, From 2498100a8bffbaffb6acba932aa623fb110994c0 Mon Sep 17 00:00:00 2001 From: Ken Chatfield Date: Tue, 15 Nov 2016 14:55:27 +0000 Subject: [PATCH 02/27] Added execute_args parameter to compile function * This allows parameters to be passed through to calls to sess.run() in the TensorFlow backend --- keras/backend/tensorflow_backend.py | 6 ++++-- keras/backend/theano_backend.py | 9 ++++++++- keras/engine/training.py | 22 +++++++++++++++------- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index 03dd2e552afd..08bb246da82a 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -1076,8 +1076,10 @@ def __init__(self, inputs, outputs, updates=[]): updates_ops.append(update) self.updates_op = tf.group(*updates_ops) - def __call__(self, inputs): + def __call__(self, inputs, **kwargs): assert type(inputs) in {list, tuple} + unrecognized_kwargs = set(kwargs.keys()) - {'options', 'run_metadata'} + assert len(unrecognized_kwargs) == 0, 'Unrecognised kwargs: {}'.format(unrecognized_kwargs) feed_dict = {} for tensor, value in zip(self.inputs, inputs): if is_sparse(tensor): @@ -1087,7 +1089,7 @@ def __call__(self, inputs): value = (indices, sparse_coo.data, sparse_coo.shape) feed_dict[tensor] = value session = get_session() - updated = session.run(self.outputs + [self.updates_op], feed_dict=feed_dict) + updated = session.run(self.outputs + [self.updates_op], feed_dict=feed_dict, **kwargs) return updated[:len(self.outputs)] diff --git a/keras/backend/theano_backend.py b/keras/backend/theano_backend.py index 2cd3c7a4bde7..5411a8a73449 100644 --- a/keras/backend/theano_backend.py +++ b/keras/backend/theano_backend.py @@ -14,6 +14,7 @@ from theano.sandbox.softsign import softsign as T_softsign import inspect import numpy as np +import warnings from .common import _FLOATX, _EPSILON, image_dim_ordering py_all = all @@ -806,8 +807,14 @@ def __init__(self, inputs, outputs, updates=[], **kwargs): on_unused_input='ignore', **kwargs) - def __call__(self, inputs): + def __call__(self, inputs, **kwargs): assert type(inputs) in {list, tuple} + if len(kwargs) > 0: + msg = [ + 'Expected no kwargs, you passed %s' % len(kwargs), + 'kwargs passed to 
F() are ignored with Theano backend' + ] + warnings.warn('\n'.join(msg)) return self.function(*inputs) diff --git a/keras/engine/training.py b/keras/engine/training.py index adc84339683e..5d3d8c34a995 100644 --- a/keras/engine/training.py +++ b/keras/engine/training.py @@ -466,7 +466,7 @@ def data_generator_task(): class Model(Container): def compile(self, optimizer, loss, metrics=[], loss_weights=None, - sample_weight_mode=None, **kwargs): + sample_weight_mode=None, execute_kwargs=None, **kwargs): '''Configures the model for training. # Arguments @@ -488,6 +488,9 @@ def compile(self, optimizer, loss, metrics=[], loss_weights=None, If the model has multiple outputs, you can use a different `sample_weight_mode` on each output by passing a dictionary or a list of modes. + execute_args: when using the Tensorflow backend, these arguments + are passed into calls to sess.run(func, feed_dict, **execute_args). + Ignored for Theano backend. kwargs: when using the Theano backend, these arguments are passed into K.function. Ignored for Tensorflow backend. ''' @@ -694,6 +697,11 @@ def append_metric(layer_num, metric_name, metric_tensor): self.total_loss = total_loss self.sample_weights = sample_weights + # these arguments will be passed into calls to sess.run() + # for the TensorFlow backend when executing the functions + # for train, test and predict + self._function_execute_args = execute_kwargs or {} + # functions for train, test and predict will # be compiled lazily when required. # This saves time when the user is not using all functions. @@ -844,7 +852,7 @@ def _fit_loop(self, f, ins, out_labels=[], batch_size=32, batch_logs['batch'] = batch_index batch_logs['size'] = len(batch_ids) callbacks.on_batch_begin(batch_index, batch_logs) - outs = f(ins_batch) + outs = f(ins_batch, **self._function_execute_args) if type(outs) != list: outs = [outs] for l, o in zip(out_labels, outs): @@ -898,7 +906,7 @@ def _predict_loop(self, f, ins, batch_size=32, verbose=0): else: ins_batch = slice_X(ins, batch_ids) - batch_outs = f(ins_batch) + batch_outs = f(ins_batch, **self._function_execute_args) if type(batch_outs) != list: batch_outs = [batch_outs] if batch_index == 0: @@ -943,7 +951,7 @@ def _test_loop(self, f, ins, batch_size=32, verbose=0): else: ins_batch = slice_X(ins, batch_ids) - batch_outs = f(ins_batch) + batch_outs = f(ins_batch, **self._function_execute_args) if type(batch_outs) == list: if batch_index == 0: for batch_out in enumerate(batch_outs): @@ -1241,7 +1249,7 @@ def train_on_batch(self, x, y, else: ins = x + y + sample_weights self._make_train_function() - outputs = self.train_function(ins) + outputs = self.train_function(ins, **self._function_execute_args) if len(outputs) == 1: return outputs[0] return outputs @@ -1279,7 +1287,7 @@ def test_on_batch(self, x, y, sample_weight=None): else: ins = x + y + sample_weights self._make_test_function() - outputs = self.test_function(ins) + outputs = self.test_function(ins, **self._function_execute_args) if len(outputs) == 1: return outputs[0] return outputs @@ -1294,7 +1302,7 @@ def predict_on_batch(self, x): else: ins = x self._make_predict_function() - outputs = self.predict_function(ins) + outputs = self.predict_function(ins, **self._function_execute_args) if len(outputs) == 1: return outputs[0] return outputs From 6afbf75bcb8723ee2f6c5fa42908bca22ca4d722 Mon Sep 17 00:00:00 2001 From: Ken Chatfield Date: Fri, 18 Nov 2016 15:00:06 +0000 Subject: [PATCH 03/27] Make Tensorboard histogram visualisation work for model duplicates * No longer fails for 
submodels with multiple inputs/outputs * Includes weight histograms from submodels by constructing layer list recursively --- keras/callbacks.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/keras/callbacks.py b/keras/callbacks.py index b44236b4f1e5..f0e14a7594c8 100644 --- a/keras/callbacks.py +++ b/keras/callbacks.py @@ -491,7 +491,16 @@ def _set_model(self, model): self.model = model self.sess = KTF.get_session() if self.histogram_freq and self.merged is None: - for layer in self.model.layers: + def get_layers_flattened(model_layers): + layers = [] + for layer in model_layers: + if layer.__class__.__name__ == 'Model': + layers.extend(get_layers_flattened(layer.layers)) + else: + layers.append(layer) + return layers + layers = get_layers_flattened(self.model.layers) + for layer in layers: for weight in layer.weights: tf.histogram_summary(weight.name, weight) @@ -510,9 +519,14 @@ def _set_model(self, model): tf.image_summary(weight.name, w_img) - if hasattr(layer, 'output'): - tf.histogram_summary('{}_out'.format(layer.name), - layer.output) + if layer in self.model.layers: + try: + if hasattr(layer, 'output'): + tf.histogram_summary('{}_out'.format(layer.name), + layer.output) + except AttributeError: + pass + self.merged = tf.merge_all_summaries() if self.write_graph: if parse_version(tf.__version__) >= parse_version('0.8.0'): From ee032ef8ecf174d5bf0d92878008990e545e7afa Mon Sep 17 00:00:00 2001 From: Ken Chatfield Date: Fri, 18 Nov 2016 15:26:41 +0000 Subject: [PATCH 04/27] Fix to validation_gen set condition --- keras/callbacks.py | 4 ++-- keras/engine/training.py | 39 +++++++++++++++++++++++---------------- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/keras/callbacks.py b/keras/callbacks.py index 0772393cd2bc..755406c9f8a9 100644 --- a/keras/callbacks.py +++ b/keras/callbacks.py @@ -544,8 +544,8 @@ def get_val_summary(validation_data): summary_str = get_val_summary(self.model.validation_data) self.writer.add_summary(summary_str, epoch) elif self.model.validation_gen: - val_gen = self.model.validation_gen['generator'] - nb_val_samples = self.model.validation_gen['nb_samples'] + val_gen = self.model.validation_gen.generator + nb_val_samples = self.model.validation_gen.nb_samples # process nb_samples from validation data generator sub_summaries = [] processed_samples = 0 diff --git a/keras/engine/training.py b/keras/engine/training.py index 938ba0bd0839..7843778548f4 100644 --- a/keras/engine/training.py +++ b/keras/engine/training.py @@ -7,6 +7,7 @@ import numpy as np import multiprocessing import threading +from collections import namedtuple import six @@ -1401,25 +1402,31 @@ def generate_arrays_from_file(path): }) callbacks.on_train_begin() - if do_validation and not val_gen: - if len(validation_data) == 2: - val_x, val_y = validation_data - val_sample_weight = None - elif len(validation_data) == 3: - val_x, val_y, val_sample_weight = validation_data + if do_validation: + if not val_gen: + if len(validation_data) == 2: + val_x, val_y = validation_data + val_sample_weight = None + elif len(validation_data) == 3: + val_x, val_y, val_sample_weight = validation_data + else: + raise Exception('validation_data should be a tuple ' + '(val_x, val_y, val_sample_weight) ' + 'or (val_x, val_y). 
Found: ' + str(validation_data)) + val_x, val_y, val_sample_weights = self._standardize_user_data(val_x, val_y, val_sample_weight) + self.validation_data = val_x + [val_y, val_sample_weights] + self.validation_gen = None else: - raise Exception('validation_data should be a tuple ' - '(val_x, val_y, val_sample_weight) ' - 'or (val_x, val_y). Found: ' + str(validation_data)) - val_x, val_y, val_sample_weights = self._standardize_user_data(val_x, val_y, val_sample_weight) - self.validation_data = val_x + [val_y, val_sample_weights] - self.validation_gen = None + ValidationGen = namedtuple('ValidationGen', ['generator', 'nb_samples']) + + self.validation_data = None + self.validation_gen = ValidationGen( + generator=validation_data, + nb_samples=nb_val_samples + ) else: self.validation_data = None - self.validation_gen = dict( - generator=validation_data, - nb_samples=nb_val_samples - ) + self.validation_gen = None # start generator thread storing batches into a queue data_gen_queue, _stop, generator_threads = generator_queue(generator, max_q_size=max_q_size, nb_worker=nb_worker, From 7572cbebf385ee2f0ec43131b267758a73ed80fb Mon Sep 17 00:00:00 2001 From: Ken Chatfield Date: Fri, 18 Nov 2016 17:14:51 +0000 Subject: [PATCH 05/27] Added missing validation_gen property to Sequential model --- keras/models.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/keras/models.py b/keras/models.py index 786bc1d94d07..015f1b2d35ff 100644 --- a/keras/models.py +++ b/keras/models.py @@ -517,6 +517,10 @@ def set_weights(self, weights): def validation_data(self): return self.model.validation_data + @property + def validation_gen(self): + return self.model.validation_gen + @property def training_data(self): return self.model.training_data From 46fcd0a23e6afbefdc07e1ceb16af01d58d257bc Mon Sep 17 00:00:00 2001 From: Ken Chatfield Date: Tue, 22 Nov 2016 18:32:34 +0000 Subject: [PATCH 06/27] Added on_val_start callback and batch/val timings to progbar --- keras/callbacks.py | 33 +++++++++++++++++++++++++++++++++ keras/engine/training.py | 4 ++++ keras/utils/generic_utils.py | 3 +++ 3 files changed, 40 insertions(+) diff --git a/keras/callbacks.py b/keras/callbacks.py index b44236b4f1e5..f06969cd6e51 100644 --- a/keras/callbacks.py +++ b/keras/callbacks.py @@ -68,6 +68,14 @@ def on_batch_end(self, batch, logs={}): 'to the batch update (%f). Check your callbacks.' 
% delta_t_median) + def on_val_begin(self, epoch, logs={}): + for callback in self.callbacks: + callback.on_val_begin(epoch, logs) + + def on_val_end(self, epoch, logs={}): + for callback in self.callbacks: + callback.on_val_end(epoch, logs) + def on_train_begin(self, logs={}): for callback in self.callbacks: callback.on_train_begin(logs) @@ -124,6 +132,12 @@ def on_batch_begin(self, batch, logs={}): def on_batch_end(self, batch, logs={}): pass + def on_val_begin(self, epoch, logs={}): + pass + + def on_val_end(self, epoch, logs={}): + pass + def on_train_begin(self, logs={}): pass @@ -172,6 +186,9 @@ def on_epoch_begin(self, epoch, logs={}): self.progbar = Progbar(target=self.params['nb_sample'], verbose=self.verbose) self.seen = 0 + self.epoch_t = None + self.batch_t = None + self.val_t = None def on_batch_begin(self, batch, logs={}): if self.seen < self.params['nb_sample']: @@ -190,10 +207,26 @@ def on_batch_end(self, batch, logs={}): if self.verbose and self.seen < self.params['nb_sample']: self.progbar.update(self.seen, self.log_values) + def on_val_begin(self, epoch, logs={}): + self.epoch_t = self.progbar.elapsed_time() + self.batch_t = self.epoch_t / self.params['nb_sample'] + self.val_start_t = time.time() + + def on_val_end(self, epoch, logs={}): + self.val_t = time.time() - self.val_start_t + def on_epoch_end(self, epoch, logs={}): for k in self.params['metrics']: if k in logs: self.log_values.append((k, logs[k])) + if not self.epoch_t: + self.epoch_t = self.progbar.elapsed_time() + self.batch_t = self.epoch_t / self.params['nb_sample'] + # add timings + self.log_values.append(('epoch_t', self.epoch_t)) + self.log_values.append(('batch_t', self.batch_t)) + if self.val_t: + self.log_values.append(('val_t', self.val_t)) if self.verbose: self.progbar.update(self.seen, self.log_values, force=True) diff --git a/keras/engine/training.py b/keras/engine/training.py index 1458ccb4f384..789a050079e0 100644 --- a/keras/engine/training.py +++ b/keras/engine/training.py @@ -858,6 +858,7 @@ def _fit_loop(self, f, ins, out_labels=[], batch_size=32, if batch_index == len(batches) - 1: # last batch # validation if do_validation: + callbacks.on_val_begin(epoch) # replace with self._evaluate val_outs = self._test_loop(val_f, val_ins, batch_size=batch_size, @@ -867,6 +868,7 @@ def _fit_loop(self, f, ins, out_labels=[], batch_size=32, # same labels assumed for l, o in zip(out_labels, val_outs): epoch_logs['val_' + l] = o + callbacks.on_val_end(epoch, epoch_logs) callbacks.on_epoch_end(epoch, epoch_logs) if callback_model.stop_training: break @@ -1497,6 +1499,7 @@ def generate_arrays_from_file(path): 'Set `samples_per_epoch` correctly ' 'to avoid this warning.') if samples_seen >= samples_per_epoch and do_validation: + callbacks.on_val_begin(epoch) if val_gen: val_outs = self.evaluate_generator(validation_data, nb_val_samples, @@ -1515,6 +1518,7 @@ def generate_arrays_from_file(path): # same labels assumed for l, o in zip(out_labels, val_outs): epoch_logs['val_' + l] = o + callbacks.on_val_end(epoch, epoch_logs) callbacks.on_epoch_end(epoch, epoch_logs) epoch += 1 diff --git a/keras/utils/generic_utils.py b/keras/utils/generic_utils.py index d6eab4729c95..ad8bb104c4b2 100644 --- a/keras/utils/generic_utils.py +++ b/keras/utils/generic_utils.py @@ -178,6 +178,9 @@ def update(self, current, values=[], force=False): def add(self, n, values=[]): self.update(self.seen_so_far + n, values) + def elapsed_time(self): + return time.time() - self.start + def display_table(rows, positions): From 
d1aa19037413a2599670b0f472343272f07d4bbf Mon Sep 17 00:00:00 2001 From: Ken Chatfield Date: Sun, 30 Apr 2017 19:51:26 +0100 Subject: [PATCH 07/27] Added learning rate multiplier support * Based on https://github.com/fchollet/keras/pull/3004 --- keras/engine/topology.py | 23 ++- keras/engine/training.py | 5 +- keras/layers/convolutional.py | 96 ++++++++++- keras/layers/core.py | 75 ++++++++- keras/optimizers.py | 75 ++++++--- tests/keras/test_learning_rate_multipliers.py | 159 ++++++++++++++++++ 6 files changed, 395 insertions(+), 38 deletions(-) create mode 100644 tests/keras/test_learning_rate_multipliers.py diff --git a/keras/engine/topology.py b/keras/engine/topology.py index 79fb85400b9e..91d9b3c2819d 100644 --- a/keras/engine/topology.py +++ b/keras/engine/topology.py @@ -252,6 +252,7 @@ class Layer(object): weights: The concatenation of the lists trainable_weights and non_trainable_weights (in this order). constraints: Dict mapping weights to constraints. + multipliers: dict mapping weights to learning rates multipliers. # Methods call(x, mask=None): Where the layer's logic lives. @@ -310,6 +311,8 @@ def __init__(self, **kwargs): self.losses = [] if not hasattr(self, 'constraints'): self.constraints = {} # dict {tensor: constraint instance} + if not hasattr(self, 'multipliers'): + self.multipliers = {} # dict {tensor: multiplier value} self.built = False # These properties should be set by the user via keyword arguments. @@ -403,7 +406,8 @@ def create_input_layer(self, batch_input_shape, def add_weight(self, shape, initializer, name=None, trainable=True, regularizer=None, - constraint=None): + constraint=None, + multiplier=None): """Adds a weight variable to the layer. # Arguments @@ -420,6 +424,8 @@ def add_weight(self, shape, initializer, name=None, self.add_loss(regularizer(weight)) if constraint is not None: self.constraints[weight] = constraint + if multiplier is not None: + self.multipliers[weight] = multiplier if trainable: self._trainable_weights.append(weight) else: @@ -1068,6 +1074,7 @@ def __init__(self, input_shape=None, batch_input_shape=None, self.inbound_nodes = [] self.outbound_nodes = [] self.constraints = {} + self.multipliers = {} self.sparse = sparse if not name: @@ -1275,6 +1282,7 @@ def __init__(self, layers=None, mode='sum', concat_axis=-1, self.inbound_nodes = [] self.outbound_nodes = [] self.constraints = {} + self.multipliers = {} self._trainable_weights = [] self._non_trainable_weights = [] self.supports_masking = True @@ -1715,6 +1723,7 @@ class Container(Layer): trainable_weights (list of variables) non_trainable_weights (list of variables) constraints (list of tuples (weight, constraint)) + multipliers (list of tuples (weight, learning_rate_multiplier)) # Methods summary @@ -2031,6 +2040,7 @@ def build_map_of_graph(tensor, seen_nodes=set(), depth=0, self.supports_masking = False # The following are implemented as property functions: # self.constraints + # self.multipliers # self.trainable_weights # self.non_trainable_weights # self.input_spec @@ -2141,6 +2151,17 @@ def constraints(self): cons[key] = value return cons + @property + def multipliers(self): + mults = {} + for layer in self.layers: + for key, value in layer.multipliers.items(): + if key in mults: + raise Exception('Received multiple learning rate multipliers ' + 'for one weight tensor: ' + str(key)) + mults[key] = value + return mults + @property def regularizers(self): warnings.warn('The `regularizers` attribute of layers/models ' diff --git a/keras/engine/training.py 
b/keras/engine/training.py index a6991f51d41e..9ea12a2aaab4 100644 --- a/keras/engine/training.py +++ b/keras/engine/training.py @@ -765,7 +765,7 @@ def _make_train_function(self): inputs = self.inputs + self.targets + self.sample_weights training_updates = self.optimizer.get_updates(self._collected_trainable_weights, - self.constraints, + self.multipliers, self.constraints, self.total_loss) updates = self.updates + training_updates @@ -954,7 +954,8 @@ def _predict_loop(self, f, ins, batch_size=32, verbose=0): else: ins_batch = slice_X(ins, batch_ids) - batch_outs = f(ins_batch, **self._function_execute_args) + execute_kwargs = getattr(self, '_function_execute_args', {}) + batch_outs = f(ins_batch, **execute_kwargs) if not isinstance(batch_outs, list): batch_outs = [batch_outs] if batch_index == 0: diff --git a/keras/layers/convolutional.py b/keras/layers/convolutional.py index c9f907a5a585..eabea7b081cc 100644 --- a/keras/layers/convolutional.py +++ b/keras/layers/convolutional.py @@ -72,6 +72,10 @@ class Convolution1D(Layer): (eg. maxnorm, nonneg), applied to the main weights matrix. b_constraint: instance of the [constraints](../constraints.md) module, applied to the bias. + W_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the main weights matrix. + b_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the bias. bias: whether to include a bias (i.e. make the layer affine rather than linear). input_dim: Number of channels/dimensions in the input. @@ -96,6 +100,7 @@ def __init__(self, nb_filter, filter_length, W_regularizer=None, b_regularizer=None, activity_regularizer=None, W_constraint=None, b_constraint=None, + W_learning_rate_multiplier=None, b_learning_rate_multiplier=None, bias=True, input_dim=None, input_length=None, **kwargs): if border_mode not in {'valid', 'same', 'full'}: @@ -116,6 +121,11 @@ def __init__(self, nb_filter, filter_length, self.W_constraint = constraints.get(W_constraint) self.b_constraint = constraints.get(b_constraint) + if not bias and b_learning_rate_multiplier is not None: + raise ValueError('b_learning_rate_multiplier provided with no bias.') + self.W_learning_rate_multiplier = W_learning_rate_multiplier + self.b_learning_rate_multiplier = b_learning_rate_multiplier + self.bias = bias self.input_spec = [InputSpec(ndim=3)] self.initial_weights = weights @@ -134,13 +144,15 @@ def build(self, input_shape): dim_ordering='th'), name='{}_W'.format(self.name), regularizer=self.W_regularizer, - constraint=self.W_constraint) + constraint=self.W_constraint, + multiplier=self.W_learning_rate_multiplier) if self.bias: self.b = self.add_weight((self.nb_filter,), initializer='zero', name='{}_b'.format(self.name), regularizer=self.b_regularizer, - constraint=self.b_constraint) + constraint=self.b_constraint, + multiplier=self.b_learning_rate_multiplier) else: self.b = None @@ -179,6 +191,8 @@ def get_config(self): 'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None, 'W_constraint': self.W_constraint.get_config() if self.W_constraint else None, 'b_constraint': self.b_constraint.get_config() if self.b_constraint else None, + 'W_learning_rate_multiplier': self.W_learning_rate_multiplier, + 'b_learning_rate_multiplier': self.b_learning_rate_multiplier, 'bias': self.bias, 'input_dim': self.input_dim, 'input_length': self.input_length} @@ -244,6 +258,10 @@ class AtrousConvolution1D(Convolution1D): (eg. 
maxnorm, nonneg), applied to the main weights matrix. b_constraint: instance of the [constraints](../constraints.md) module, applied to the bias. + W_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the main weights matrix. + b_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the bias. bias: whether to include a bias (i.e. make the layer affine rather than linear). input_dim: Number of channels/dimensions in the input. @@ -268,6 +286,7 @@ def __init__(self, nb_filter, filter_length, W_regularizer=None, b_regularizer=None, activity_regularizer=None, W_constraint=None, b_constraint=None, + W_learning_rate_multiplier=None, b_learning_rate_multiplier=None, bias=True, **kwargs): if border_mode not in {'valid', 'same', 'full'}: @@ -283,6 +302,7 @@ def __init__(self, nb_filter, filter_length, W_regularizer=W_regularizer, b_regularizer=b_regularizer, activity_regularizer=activity_regularizer, W_constraint=W_constraint, b_constraint=b_constraint, + W_learning_rate_multiplier=W_learning_rate_multiplier, b_learning_rate_multiplier=b_learning_rate_multiplier, bias=bias, **kwargs) def get_output_shape_for(self, input_shape): @@ -363,6 +383,10 @@ class Convolution2D(Layer): (eg. maxnorm, nonneg), applied to the main weights matrix. b_constraint: instance of the [constraints](../constraints.md) module, applied to the bias. + W_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the main weights matrix. + b_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the bias. dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension (the depth) is at index 1, in 'tf' mode is it at index 3. It defaults to the `image_dim_ordering` value found in your @@ -391,6 +415,7 @@ def __init__(self, nb_filter, nb_row, nb_col, W_regularizer=None, b_regularizer=None, activity_regularizer=None, W_constraint=None, b_constraint=None, + W_learning_rate_multiplier=None, b_learning_rate_multiplier=None, bias=True, **kwargs): if dim_ordering == 'default': dim_ordering = K.image_dim_ordering() @@ -414,6 +439,11 @@ def __init__(self, nb_filter, nb_row, nb_col, self.W_constraint = constraints.get(W_constraint) self.b_constraint = constraints.get(b_constraint) + if not bias and b_learning_rate_multiplier is not None: + raise ValueError('b_learning_rate_multiplier provided with no bias.') + self.W_learning_rate_multiplier = W_learning_rate_multiplier + self.b_learning_rate_multiplier = b_learning_rate_multiplier + self.bias = bias self.input_spec = [InputSpec(ndim=4)] self.initial_weights = weights @@ -433,13 +463,15 @@ def build(self, input_shape): dim_ordering=self.dim_ordering), name='{}_W'.format(self.name), regularizer=self.W_regularizer, - constraint=self.W_constraint) + constraint=self.W_constraint, + multiplier=self.W_learning_rate_multiplier) if self.bias: self.b = self.add_weight((self.nb_filter,), initializer='zero', name='{}_b'.format(self.name), regularizer=self.b_regularizer, - constraint=self.b_constraint) + constraint=self.b_constraint, + multiplier=self.b_learning_rate_multiplier) else: self.b = None @@ -497,6 +529,8 @@ def get_config(self): 'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None, 'W_constraint': self.W_constraint.get_config() if self.W_constraint else None, 'b_constraint': self.b_constraint.get_config() if self.b_constraint else None, + 'W_learning_rate_multiplier': 
self.W_learning_rate_multiplier, + 'b_learning_rate_multiplier': self.b_learning_rate_multiplier, 'bias': self.bias} base_config = super(Convolution2D, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -595,6 +629,10 @@ class Deconvolution2D(Convolution2D): (eg. maxnorm, nonneg), applied to the main weights matrix. b_constraint: instance of the [constraints](../constraints.md) module, applied to the bias. + W_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the main weights matrix. + b_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the bias. dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension (the depth) is at index 1, in 'tf' mode is it at index 3. It defaults to the `image_dim_ordering` value found in your @@ -628,6 +666,7 @@ def __init__(self, nb_filter, nb_row, nb_col, output_shape, dim_ordering='default', W_regularizer=None, b_regularizer=None, activity_regularizer=None, W_constraint=None, b_constraint=None, + W_learning_rate_multiplier=None, b_learning_rate_multiplier=None, bias=True, **kwargs): if dim_ordering == 'default': dim_ordering = K.image_dim_ordering() @@ -648,6 +687,8 @@ def __init__(self, nb_filter, nb_row, nb_col, output_shape, activity_regularizer=activity_regularizer, W_constraint=W_constraint, b_constraint=b_constraint, + W_learning_rate_multiplier=W_learning_rate_multiplier, + b_learning_rate_multiplier=b_learning_rate_multiplier, bias=bias, **kwargs) @@ -742,6 +783,10 @@ class AtrousConvolution2D(Convolution2D): (eg. maxnorm, nonneg), applied to the main weights matrix. b_constraint: instance of the [constraints](../constraints.md) module, applied to the bias. + W_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the main weights matrix. + b_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the bias. dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension (the depth) is at index 1, in 'tf' mode is it at index 3. It defaults to the `image_dim_ordering` value found in your @@ -774,6 +819,7 @@ def __init__(self, nb_filter, nb_row, nb_col, W_regularizer=None, b_regularizer=None, activity_regularizer=None, W_constraint=None, b_constraint=None, + W_learning_rate_multiplier=None, b_learning_rate_multiplier=None, bias=True, **kwargs): if dim_ordering == 'default': dim_ordering = K.image_dim_ordering() @@ -795,6 +841,8 @@ def __init__(self, nb_filter, nb_row, nb_col, activity_regularizer=activity_regularizer, W_constraint=W_constraint, b_constraint=b_constraint, + W_learning_rate_multiplier=W_learning_rate_multiplier, + b_learning_rate_multiplier=b_learning_rate_multiplier, bias=bias, **kwargs) @@ -900,6 +948,12 @@ class SeparableConvolution2D(Layer): (eg. maxnorm, nonneg), applied to the pointwise weights matrix. b_constraint: instance of the [constraints](../constraints.md) module, applied to the bias. + depthwise_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the depthwise weights matrix. + pointwise_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the pointwise weights matrix. + b_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the bias. dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension (the depth) is at index 1, in 'tf' mode is it at index 3. 
It defaults to the `image_dim_ordering` value found in your @@ -930,6 +984,8 @@ def __init__(self, nb_filter, nb_row, nb_col, b_regularizer=None, activity_regularizer=None, depthwise_constraint=None, pointwise_constraint=None, b_constraint=None, + depthwise_learning_rate_multiplier=None, pointwise_learning_rate_multiplier=None, + b_learning_rate_multiplier=None, bias=True, **kwargs): if K.backend() != 'tensorflow': @@ -967,6 +1023,12 @@ def __init__(self, nb_filter, nb_row, nb_col, self.pointwise_constraint = constraints.get(pointwise_constraint) self.b_constraint = constraints.get(b_constraint) + if not bias and b_learning_rate_multiplier is not None: + raise ValueError('b_learning_rate_multiplier provided with no bias.') + self.depthwise_learning_rate_multiplier = depthwise_learning_rate_multiplier + self.pointwise_learning_rate_multiplier = pointwise_learning_rate_multiplier + self.b_learning_rate_multiplier = b_learning_rate_multiplier + self.bias = bias self.input_spec = [InputSpec(ndim=4)] self.initial_weights = weights @@ -989,19 +1051,22 @@ def build(self, input_shape): dim_ordering=self.dim_ordering), regularizer=self.depthwise_regularizer, constraint=self.depthwise_constraint, + multiplier=self.depthwise_learning_rate_multiplier, name='{}_depthwise_kernel'.format(self.name)) self.pointwise_kernel = self.add_weight(pointwise_shape, initializer=functools.partial(self.init, dim_ordering=self.dim_ordering), regularizer=self.pointwise_regularizer, constraint=self.pointwise_constraint, + multiplier=self.pointwise_learning_rate_multiplier, name='{}_pointwise_kernel'.format(self.name)) if self.bias: self.b = self.add_weight((self.nb_filter,), initializer='zero', name='{}_b'.format(self.name), regularizer=self.b_regularizer, - constraint=self.b_constraint) + constraint=self.b_constraint, + multiplier=self.b_learning_rate_multiplier) else: self.b = None @@ -1065,6 +1130,9 @@ def get_config(self): 'depthwise_constraint': self.depthwise_constraint.get_config() if self.depthwise_constraint else None, 'pointwise_constraint': self.pointwise_constraint.get_config() if self.pointwise_constraint else None, 'b_constraint': self.b_constraint.get_config() if self.b_constraint else None, + 'depthwise_learning_rate_multiplier': self.depthwise_learning_rate_multiplier, + 'pointwise_learning_rate_multiplier': self.pointwise_learning_rate_multiplier, + 'b_learning_rate_multiplier': self.b_learning_rate_multiplier, 'bias': self.bias} base_config = super(SeparableConvolution2D, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -1110,6 +1178,10 @@ class Convolution3D(Layer): (eg. maxnorm, nonneg), applied to the main weights matrix. b_constraint: instance of the [constraints](../constraints.md) module, applied to the bias. + W_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the main weights matrix. + b_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the bias. dim_ordering: 'th' or 'tf'. In 'th' mode, the channels dimension (the depth) is at index 1, in 'tf' mode is it at index 4. 
It defaults to the `image_dim_ordering` value found in your @@ -1137,6 +1209,7 @@ def __init__(self, nb_filter, kernel_dim1, kernel_dim2, kernel_dim3, border_mode='valid', subsample=(1, 1, 1), dim_ordering='default', W_regularizer=None, b_regularizer=None, activity_regularizer=None, W_constraint=None, b_constraint=None, + W_learning_rate_multiplier=None, b_learning_rate_multiplier=None, bias=True, **kwargs): if dim_ordering == 'default': dim_ordering = K.image_dim_ordering() @@ -1162,6 +1235,11 @@ def __init__(self, nb_filter, kernel_dim1, kernel_dim2, kernel_dim3, self.W_constraint = constraints.get(W_constraint) self.b_constraint = constraints.get(b_constraint) + if not bias and b_learning_rate_multiplier is not None: + raise ValueError('b_learning_rate_multiplier provided with no bias.') + self.W_learning_rate_multiplier = W_learning_rate_multiplier + self.b_learning_rate_multiplier = b_learning_rate_multiplier + self.bias = bias self.input_spec = [InputSpec(ndim=5)] self.initial_weights = weights @@ -1186,13 +1264,15 @@ def build(self, input_shape): dim_ordering=self.dim_ordering), name='{}_W'.format(self.name), regularizer=self.W_regularizer, - constraint=self.W_constraint) + constraint=self.W_constraint, + multiplier=self.W_learning_rate_multiplier) if self.bias: self.b = self.add_weight((self.nb_filter,), initializer='zero', name='{}_b'.format(self.name), regularizer=self.b_regularizer, - constraint=self.b_constraint) + constraint=self.b_constraint, + multiplier=self.b_learning_rate_multiplier) else: self.b = None @@ -1257,6 +1337,8 @@ def get_config(self): 'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None, 'W_constraint': self.W_constraint.get_config() if self.W_constraint else None, 'b_constraint': self.b_constraint.get_config() if self.b_constraint else None, + 'W_learning_rate_multiplier': self.W_learning_rate_multiplier, + 'b_learning_rate_multiplier': self.b_learning_rate_multiplier, 'bias': self.bias} base_config = super(Convolution3D, self).get_config() return dict(list(base_config.items()) + list(config.items())) diff --git a/keras/layers/core.py b/keras/layers/core.py index 21157289def4..135618c3244d 100644 --- a/keras/layers/core.py +++ b/keras/layers/core.py @@ -742,6 +742,10 @@ class Dense(Layer): (eg. maxnorm, nonneg), applied to the main weights matrix. b_constraint: instance of the [constraints](../constraints.md) module, applied to the bias. + W_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the main weights matrix. + b_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the bias. bias: whether to include a bias (i.e. make the layer affine rather than linear). input_dim: dimensionality of the input (integer). 
This argument @@ -763,6 +767,7 @@ def __init__(self, output_dim, init='glorot_uniform', activation=None, weights=None, W_regularizer=None, b_regularizer=None, activity_regularizer=None, W_constraint=None, b_constraint=None, + W_learning_rate_multiplier=None, b_learning_rate_multiplier=None, bias=True, input_dim=None, **kwargs): self.init = initializations.get(init) self.activation = activations.get(activation) @@ -776,6 +781,11 @@ def __init__(self, output_dim, init='glorot_uniform', self.W_constraint = constraints.get(W_constraint) self.b_constraint = constraints.get(b_constraint) + if not bias and b_learning_rate_multiplier is not None: + raise ValueError('b_learning_rate_multiplier provided with no bias.') + self.W_learning_rate_multiplier = W_learning_rate_multiplier + self.b_learning_rate_multiplier = b_learning_rate_multiplier + self.bias = bias self.initial_weights = weights self.input_spec = [InputSpec(ndim='2+')] @@ -795,13 +805,15 @@ def build(self, input_shape): initializer=self.init, name='{}_W'.format(self.name), regularizer=self.W_regularizer, - constraint=self.W_constraint) + constraint=self.W_constraint, + multiplier=self.W_learning_rate_multiplier) if self.bias: self.b = self.add_weight((self.output_dim,), initializer='zero', name='{}_b'.format(self.name), regularizer=self.b_regularizer, - constraint=self.b_constraint) + constraint=self.b_constraint, + multiplier=self.b_learning_rate_multiplier) else: self.b = None @@ -832,6 +844,8 @@ def get_config(self): 'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None, 'W_constraint': self.W_constraint.get_config() if self.W_constraint else None, 'b_constraint': self.b_constraint.get_config() if self.b_constraint else None, + 'W_learning_rate_multiplier': self.W_learning_rate_multiplier, + 'b_learning_rate_multiplier': self.b_learning_rate_multiplier, 'bias': self.bias, 'input_dim': self.input_dim} base_config = super(Dense, self).get_config() @@ -904,6 +918,10 @@ class MaxoutDense(Layer): (eg. maxnorm, nonneg), applied to the main weights matrix. b_constraint: instance of the [constraints](../constraints.md) module, applied to the bias. + W_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the main weights matrix. + b_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the bias. bias: whether to include a bias (i.e. make the layer affine rather than linear). input_dim: dimensionality of the input (integer). 
This argument @@ -929,6 +947,8 @@ def __init__(self, output_dim, activity_regularizer=None, W_constraint=None, b_constraint=None, + W_learning_rate_multiplier=None, + b_learning_rate_multiplier=None, bias=True, input_dim=None, **kwargs): @@ -943,6 +963,11 @@ def __init__(self, output_dim, self.W_constraint = constraints.get(W_constraint) self.b_constraint = constraints.get(b_constraint) + if not bias and b_learning_rate_multiplier is not None: + raise ValueError('b_learning_rate_multiplier provided with no bias.') + self.W_learning_rate_multiplier = W_learning_rate_multiplier + self.b_learning_rate_multiplier = b_learning_rate_multiplier + self.bias = bias self.initial_weights = weights self.input_spec = [InputSpec(ndim=2)] @@ -961,13 +986,15 @@ def build(self, input_shape): initializer=self.init, name='{}_W'.format(self.name), regularizer=self.W_regularizer, - constraint=self.W_constraint) + constraint=self.W_constraint, + multiplier=self.W_learning_rate_multiplier) if self.bias: self.b = self.add_weight((self.nb_feature, self.output_dim,), initializer='zero', name='{}_b'.format(self.name), regularizer=self.b_regularizer, - constraint=self.b_constraint) + constraint=self.b_constraint, + multiplier=self.b_learning_rate_multiplier) else: self.b = None @@ -997,6 +1024,8 @@ def get_config(self): 'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None, 'W_constraint': self.W_constraint.get_config() if self.W_constraint else None, 'b_constraint': self.b_constraint.get_config() if self.b_constraint else None, + 'W_learning_rate_multiplier': self.W_learning_rate_multiplier, + 'b_learning_rate_multiplier': self.b_learning_rate_multiplier, 'bias': self.bias, 'input_dim': self.input_dim} base_config = super(MaxoutDense, self).get_config() @@ -1032,6 +1061,10 @@ class Highway(Layer): (eg. maxnorm, nonneg), applied to the main weights matrix. b_constraint: instance of the [constraints](../constraints.md) module, applied to the bias. + W_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the main weights matrix. + b_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the bias. bias: whether to include a bias (i.e. make the layer affine rather than linear). input_dim: dimensionality of the input (integer). 
This argument @@ -1057,6 +1090,8 @@ def __init__(self, activity_regularizer=None, W_constraint=None, b_constraint=None, + W_learning_rate_multiplier=None, + b_learning_rate_multiplier=None, bias=True, input_dim=None, **kwargs): @@ -1074,6 +1109,11 @@ def __init__(self, self.W_constraint = constraints.get(W_constraint) self.b_constraint = constraints.get(b_constraint) + if not bias and b_learning_rate_multiplier is not None: + raise ValueError('b_learning_rate_multiplier provided with no bias.') + self.W_learning_rate_multiplier = W_learning_rate_multiplier + self.b_learning_rate_multiplier = b_learning_rate_multiplier + self.bias = bias self.initial_weights = weights self.input_spec = [InputSpec(ndim=2)] @@ -1092,7 +1132,8 @@ def build(self, input_shape): initializer=self.init, name='{}_W'.format(self.name), regularizer=self.W_regularizer, - constraint=self.W_constraint) + constraint=self.W_constraint, + multiplier=self.W_learning_rate_multiplier) self.W_carry = self.add_weight((input_dim, input_dim), initializer=self.init, name='{}_W_carry'.format(self.name)) @@ -1101,7 +1142,8 @@ def build(self, input_shape): initializer='zero', name='{}_b'.format(self.name), regularizer=self.b_regularizer, - constraint=self.b_constraint) + constraint=self.b_constraint, + multiplier=self.b_learning_rate_multiplier) self.b_carry = self.add_weight((input_dim,), initializer='one', name='{}_b_carry'.format(self.name)) @@ -1134,6 +1176,8 @@ def get_config(self): 'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None, 'W_constraint': self.W_constraint.get_config() if self.W_constraint else None, 'b_constraint': self.b_constraint.get_config() if self.b_constraint else None, + 'W_learning_rate_multiplier': self.W_learning_rate_multiplier, + 'b_learning_rate_multiplier': self.b_learning_rate_multiplier, 'bias': self.bias, 'input_dim': self.input_dim} base_config = super(Highway, self).get_config() @@ -1181,6 +1225,10 @@ class TimeDistributedDense(Layer): (eg. maxnorm, nonneg), applied to the main weights matrix. b_constraint: instance of the [constraints](../constraints.md) module, applied to the bias. + W_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the main weights matrix. + b_learning_rate_multiplier: Multiplier (between 0.0 and 1.0) applied to the + learning rate of the bias. bias: whether to include a bias (i.e. make the layer affine rather than linear). input_dim: dimensionality of the input (integer). 
This argument @@ -1199,6 +1247,8 @@ def __init__(self, output_dim, activity_regularizer=None, W_constraint=None, b_constraint=None, + W_learning_rate_multiplier=None, + b_learning_rate_multiplier=None, bias=True, input_dim=None, input_length=None, @@ -1217,6 +1267,11 @@ def __init__(self, output_dim, self.W_constraint = constraints.get(W_constraint) self.b_constraint = constraints.get(b_constraint) + if not bias and b_learning_rate_multiplier is not None: + raise ValueError('b_learning_rate_multiplier provided with no bias.') + self.W_learning_rate_multiplier = W_learning_rate_multiplier + self.b_learning_rate_multiplier = b_learning_rate_multiplier + self.bias = bias self.initial_weights = weights self.input_spec = [InputSpec(ndim=3)] @@ -1237,13 +1292,15 @@ def build(self, input_shape): initializer=self.init, name='{}_W'.format(self.name), regularizer=self.W_regularizer, - constraint=self.W_constraint) + constraint=self.W_constraint, + multiplier=self.W_learning_rate_multiplier) if self.bias: self.b = self.add_weight((self.output_dim,), initializer='zero', name='{}_b'.format(self.name), regularizer=self.b_regularizer, - constraint=self.b_constraint) + constraint=self.b_constraint, + multiplier=self.b_learning_rate_multiplier) else: self.b = None @@ -1292,6 +1349,8 @@ def get_config(self): 'activity_regularizer': self.activity_regularizer.get_config() if self.activity_regularizer else None, 'W_constraint': self.W_constraint.get_config() if self.W_constraint else None, 'b_constraint': self.b_constraint.get_config() if self.b_constraint else None, + 'W_learning_rate_multiplier': self.W_learning_rate_multiplier, + 'b_learning_rate_multiplier': self.b_learning_rate_multiplier, 'bias': self.bias, 'input_dim': self.input_dim, 'input_length': self.input_length} diff --git a/keras/optimizers.py b/keras/optimizers.py index 81a2b529a604..a3935a8871ca 100644 --- a/keras/optimizers.py +++ b/keras/optimizers.py @@ -75,7 +75,7 @@ def __init__(self, **kwargs): self.updates = [] self.weights = [] - def get_updates(self, params, constraints, loss): + def get_updates(self, params, multipliers, constraints, loss): raise NotImplementedError def get_gradients(self, loss, params): @@ -159,7 +159,7 @@ def __init__(self, lr=0.01, momentum=0., decay=0., self.initial_decay = decay self.nesterov = nesterov - def get_updates(self, params, constraints, loss): + def get_updates(self, params, multipliers, constraints, loss): grads = self.get_gradients(loss, params) self.updates = [] @@ -173,11 +173,16 @@ def get_updates(self, params, constraints, loss): moments = [K.zeros(shape) for shape in shapes] self.weights = [self.iterations] + moments for p, g, m in zip(params, grads, moments): - v = self.momentum * m - lr * g # velocity + # Apply learning rate multipliers if needed + if p in multipliers: + lrm = K.variable(multipliers[p]) + else: + lrm = K.variable(1.0) + v = self.momentum * m - (lr*lrm) * g # velocity self.updates.append(K.update(m, v)) if self.nesterov: - new_p = p + self.momentum * v - lr * g + new_p = p + self.momentum * v - (lr*lrm) * g else: new_p = p + v @@ -228,7 +233,7 @@ def __init__(self, lr=0.001, rho=0.9, epsilon=1e-8, decay=0., self.initial_decay = decay self.iterations = K.variable(0.) 
- def get_updates(self, params, constraints, loss): + def get_updates(self, params, multipliers, constraints, loss): grads = self.get_gradients(loss, params) shapes = [K.get_variable_shape(p) for p in params] accumulators = [K.zeros(shape) for shape in shapes] @@ -241,10 +246,15 @@ def get_updates(self, params, constraints, loss): self.updates.append(K.update_add(self.iterations, 1)) for p, g, a in zip(params, grads, accumulators): + # Apply learning rate multipliers if needed + if p in multipliers: + lrm = K.variable(multipliers[p]) + else: + lrm = K.variable(1.0) # update accumulator new_a = self.rho * a + (1. - self.rho) * K.square(g) self.updates.append(K.update(a, new_a)) - new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon) + new_p = p - (lr * lrm) * g / (K.sqrt(new_a) + self.epsilon) # apply constraints if p in constraints: @@ -285,7 +295,7 @@ def __init__(self, lr=0.01, epsilon=1e-8, decay=0., **kwargs): self.initial_decay = decay self.iterations = K.variable(0.) - def get_updates(self, params, constraints, loss): + def get_updates(self, params, multipliers, constraints, loss): grads = self.get_gradients(loss, params) shapes = [K.get_variable_shape(p) for p in params] accumulators = [K.zeros(shape) for shape in shapes] @@ -298,9 +308,14 @@ def get_updates(self, params, constraints, loss): self.updates.append(K.update_add(self.iterations, 1)) for p, g, a in zip(params, grads, accumulators): + # Apply learning rate multipliers if needed + if p in multipliers: + lrm = K.variable(multipliers[p]) + else: + lrm = K.variable(1.0) new_a = a + K.square(g) # update accumulator self.updates.append(K.update(a, new_a)) - new_p = p - lr * g / (K.sqrt(new_a) + self.epsilon) + new_p = p - (lr * lrm) * g / (K.sqrt(new_a) + self.epsilon) # apply constraints if p in constraints: c = constraints[p] @@ -343,7 +358,7 @@ def __init__(self, lr=1.0, rho=0.95, epsilon=1e-8, decay=0., self.initial_decay = decay self.iterations = K.variable(0.) - def get_updates(self, params, constraints, loss): + def get_updates(self, params, multipliers, constraints, loss): grads = self.get_gradients(loss, params) shapes = [K.get_variable_shape(p) for p in params] accumulators = [K.zeros(shape) for shape in shapes] @@ -357,6 +372,11 @@ def get_updates(self, params, constraints, loss): self.updates.append(K.update_add(self.iterations, 1)) for p, g, a, d_a in zip(params, grads, accumulators, delta_accumulators): + # Apply learning rate multipliers if needed + if p in multipliers: + lrm = K.variable(multipliers[p]) + else: + lrm = K.variable(1.0) # update accumulator new_a = self.rho * a + (1. 
- self.rho) * K.square(g) self.updates.append(K.update(a, new_a)) @@ -364,7 +384,7 @@ def get_updates(self, params, constraints, loss): # use the new accumulator and the *old* delta_accumulator update = g * K.sqrt(d_a + self.epsilon) / K.sqrt(new_a + self.epsilon) - new_p = p - lr * update + new_p = p - (lr * lrm) * update # apply constraints if p in constraints: c = constraints[p] @@ -412,7 +432,7 @@ def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999, self.decay = K.variable(decay) self.initial_decay = decay - def get_updates(self, params, constraints, loss): + def get_updates(self, params, multipliers, constraints, loss): grads = self.get_gradients(loss, params) self.updates = [K.update_add(self.iterations, 1)] @@ -430,9 +450,14 @@ def get_updates(self, params, constraints, loss): self.weights = [self.iterations] + ms + vs for p, g, m, v in zip(params, grads, ms, vs): + # Apply learning rate multipliers if needed + if p in multipliers: + lrm = K.variable(multipliers[p]) + else: + lrm = K.variable(1.0) m_t = (self.beta_1 * m) + (1. - self.beta_1) * g v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g) - p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon) + p_t = p - (lr_t * lrm) * m_t / (K.sqrt(v_t) + self.epsilon) self.updates.append(K.update(m, m_t)) self.updates.append(K.update(v, v_t)) @@ -482,7 +507,7 @@ def __init__(self, lr=0.002, beta_1=0.9, beta_2=0.999, self.decay = K.variable(decay) self.initial_decay = decay - def get_updates(self, params, constraints, loss): + def get_updates(self, params, multipliers, constraints, loss): grads = self.get_gradients(loss, params) self.updates = [K.update_add(self.iterations, 1)] @@ -501,10 +526,15 @@ def get_updates(self, params, constraints, loss): self.weights = [self.iterations] + ms + us for p, g, m, u in zip(params, grads, ms, us): + # Apply learning rate multipliers if needed + if p in multipliers: + lrm = K.variable(multipliers[p]) + else: + lrm = K.variable(1.0) m_t = (self.beta_1 * m) + (1. - self.beta_1) * g u_t = K.maximum(self.beta_2 * u, K.abs(g)) - p_t = p - lr_t * m_t / (u_t + self.epsilon) + p_t = p - (lr_t * lrm) * m_t / (u_t + self.epsilon) self.updates.append(K.update(m, m_t)) self.updates.append(K.update(u, u_t)) @@ -558,7 +588,7 @@ def __init__(self, lr=0.002, beta_1=0.9, beta_2=0.999, self.epsilon = epsilon self.schedule_decay = schedule_decay - def get_updates(self, params, constraints, loss): + def get_updates(self, params, multipliers, constraints, loss): grads = self.get_gradients(loss, params) self.updates = [K.update_add(self.iterations, 1)] @@ -578,6 +608,11 @@ def get_updates(self, params, constraints, loss): self.weights = [self.iterations] + ms + vs for p, g, m, v in zip(params, grads, ms, vs): + # Apply learning rate multipliers if needed + if p in multipliers: + lrm = K.variable(multipliers[p]) + else: + lrm = K.variable(1.0) # the following equations given in [1] g_prime = g / (1. - m_schedule_new) m_t = self.beta_1 * m + (1. - self.beta_1) * g @@ -589,7 +624,7 @@ def get_updates(self, params, constraints, loss): self.updates.append(K.update(m, m_t)) self.updates.append(K.update(v, v_t)) - p_t = p - self.lr * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon) + p_t = p - (self.lr * lrm) * m_t_bar / (K.sqrt(v_t_prime) + self.epsilon) new_p = p_t # apply constraints @@ -618,11 +653,11 @@ def __init__(self, optimizer): self.iterations = K.variable(0.) 
self.updates = [] - def get_updates(self, params, constraints, loss): - if constraints: + def get_updates(self, params, multipliers, constraints, loss): + if constraints or multipliers: raise ValueError('TF optimizers do not support ' - 'weights constraints. Either remove ' - 'all weights constraints in your model, ' + 'weights multipliers or constraints. Either remove ' + 'all weights multipliers and constraints in your model, ' 'or use a Keras optimizer.') grads = self.optimizer.compute_gradients(loss, params) opt_update = self.optimizer.apply_gradients( diff --git a/tests/keras/test_learning_rate_multipliers.py b/tests/keras/test_learning_rate_multipliers.py new file mode 100644 index 000000000000..38c239253aca --- /dev/null +++ b/tests/keras/test_learning_rate_multipliers.py @@ -0,0 +1,159 @@ +from __future__ import print_function +import pytest +import numpy as np +from keras.utils.test_utils import get_test_data +from keras.utils.np_utils import to_categorical +from keras.models import Sequential +from keras.utils.test_utils import layer_test +from keras import backend as K + +seed = 1224 + + +def test_learning_rate_multipliers_maxout_dense(): + from keras.layers.core import MaxoutDense + + layer_test(MaxoutDense, + kwargs={'output_dim': 3, + 'W_learning_rate_multiplier': 0.1, + 'b_learning_rate_multiplier': 0.1}, + input_shape=(3, 2)) + + with pytest.raises(Exception) as e_info: + layer_test(MaxoutDense, + kwargs={'output_dim': 3, + 'bias': False, + 'W_learning_rate_multiplier': 0.1, + 'b_learning_rate_multiplier': 0.1}, + input_shape=(3, 2)) + + +def test_learning_rate_multipliers_conv1d(): + from keras.layers.convolutional import Convolution1D + + layer_test(Convolution1D, + kwargs={'nb_filter': 4, + 'filter_length': 3, + 'W_learning_rate_multiplier': 0.1, + 'b_learning_rate_multiplier': 0.1}, + input_shape=(2, 8, 5)) + + with pytest.raises(Exception) as e_info: + layer_test(Convolution1D, + kwargs={'nb_filter': 4, + 'filter_length': 3, + 'bias': False, + 'W_learning_rate_multiplier': 0.1, + 'b_learning_rate_multiplier': 0.1}, + input_shape=(2, 8, 5)) + + +@pytest.mark.skipif((K._BACKEND != 'theano'), + reason="Requires theano backend or be able to set random seed in tensorflow") +def test_learning_rate_multipliers_dense(): + from keras.layers.core import Dense + from keras.optimizers import SGD + + layer_test(Dense, + kwargs={'output_dim': 3, + 'W_learning_rate_multiplier': 0.1, + 'b_learning_rate_multiplier': 0.1}, + input_shape=(3, 2)) + + # This should raise an error + with pytest.raises(Exception) as e_info: + layer_test(Dense, + kwargs={'output_dim': 3, + 'bias': False, + 'W_learning_rate_multiplier': 0.1, + 'b_learning_rate_multiplier': 0.1}, + input_shape=(3, 2)) + + np.random.seed(seed) + (X_train, y_train), (X_test, y_test) = get_test_data(nb_train=10, + nb_test=1, + input_shape=(5,), + classification=True, + nb_class=2) + y_train = to_categorical(y_train) + y_test = to_categorical(y_test) + + np.random.seed(seed) + model0 = Sequential() + model0.add(Dense(output_dim=2, input_dim=5)) + sgd = SGD(lr=0.4, momentum=0., decay=0.) + model0.compile(loss='mse', optimizer=sgd) + (m0w0_ini, m0b0_ini) = model0.layers[0].get_weights() + model0.train_on_batch(X_train, y_train) + (m0w0_end, m0b0_end) = model0.layers[0].get_weights() + + np.random.seed(seed) + model1 = Sequential() + model1.add(Dense(output_dim=2, input_dim=5, + W_learning_rate_multiplier=0.5, b_learning_rate_multiplier=0.5)) + sgd = SGD(lr=0.4, momentum=0., decay=0.) 
+ model1.compile(loss='mse', optimizer=sgd) + (m1w0_ini, m1b0_ini) = model1.layers[0].get_weights() + model1.train_on_batch(X_train, y_train) + (m1w0_end, m1b0_end) = model1.layers[0].get_weights() + + # This should be ~0.5 + np.testing.assert_almost_equal(np.mean((m1w0_end - m1w0_ini) / (m0w0_end - m0w0_ini)), 0.5, decimal=2) + np.testing.assert_almost_equal(np.mean((m1b0_end - m1b0_ini) / (m0b0_end - m0b0_ini)), 0.5, decimal=2) + + +@pytest.mark.skipif((K._BACKEND != 'theano'), + reason="Requires theano backend or be able to set random seed in tensorflow") +def test_learning_rate_multipliers_conv2d(): + from keras.layers.convolutional import Convolution2D + + layer_test(Convolution2D, + kwargs={'nb_filter': 3, + 'nb_row': 3, + 'nb_col': 3, + 'W_learning_rate_multiplier': 0.1, + 'b_learning_rate_multiplier': 0.1}, + input_shape=(8, 4, 10, 6)) + + with pytest.raises(Exception) as e_info: + layer_test(Convolution2D, + kwargs={'nb_filter': 3, + 'nb_row': 3, + 'nb_col': 3, + 'bias': False, + 'W_learning_rate_multiplier': 0.1, + 'b_learning_rate_multiplier': 0.1}, + input_shape=(8, 4, 10, 6)) + + np.random.seed(seed) + X_train = np.random.rand(10, 3, 10, 10) + y_train = np.random.rand(10, 2, 8, 8) + + np.random.seed(seed) + model0 = Sequential() + model0.add(Convolution2D(2, 3, 3, + input_shape=(3, 10, 10), + border_mode='valid')) + model0.compile(loss='mse', optimizer='sgd') + (m0w0_ini, m0b0_ini) = model0.layers[0].get_weights() + model0.train_on_batch(X_train, y_train) + (m0w0_end, m0b0_end) = model0.layers[0].get_weights() + + np.random.seed(seed) + model1 = Sequential() + model1.add(Convolution2D(2, 3, 3, + input_shape=(3, 10, 10), + border_mode='valid', + W_learning_rate_multiplier=0.5, b_learning_rate_multiplier=0.5)) + model1.compile(loss='mse', optimizer='sgd') + (m1w0_ini, m1b0_ini) = model1.layers[0].get_weights() + model1.train_on_batch(X_train, y_train) + (m1w0_end, m1b0_end) = model1.layers[0].get_weights() + + # This should be ~0.5 + np.testing.assert_almost_equal(np.mean((m1w0_end - m1w0_ini) / (m0w0_end - m0w0_ini)), 0.5, decimal=2) + np.testing.assert_almost_equal(np.mean((m1b0_end - m1b0_ini) / (m0b0_end - m0b0_ini)), 0.5, decimal=2) + + +if __name__ == '__main__': + pytest.main([__file__]) From dc75b3daf094f475a526252c3b2fcacaf8de559c Mon Sep 17 00:00:00 2001 From: Ken Chatfield Date: Mon, 1 May 2017 14:16:30 +0100 Subject: [PATCH 08/27] Remove exception when specifying b_learning_rate_multiplier on layers without bias --- keras/layers/convolutional.py | 8 ----- keras/layers/core.py | 8 ----- tests/keras/test_learning_rate_multipliers.py | 36 ------------------- 3 files changed, 52 deletions(-) diff --git a/keras/layers/convolutional.py b/keras/layers/convolutional.py index eabea7b081cc..25cbce156c45 100644 --- a/keras/layers/convolutional.py +++ b/keras/layers/convolutional.py @@ -121,8 +121,6 @@ def __init__(self, nb_filter, filter_length, self.W_constraint = constraints.get(W_constraint) self.b_constraint = constraints.get(b_constraint) - if not bias and b_learning_rate_multiplier is not None: - raise ValueError('b_learning_rate_multiplier provided with no bias.') self.W_learning_rate_multiplier = W_learning_rate_multiplier self.b_learning_rate_multiplier = b_learning_rate_multiplier @@ -439,8 +437,6 @@ def __init__(self, nb_filter, nb_row, nb_col, self.W_constraint = constraints.get(W_constraint) self.b_constraint = constraints.get(b_constraint) - if not bias and b_learning_rate_multiplier is not None: - raise ValueError('b_learning_rate_multiplier provided with no 
bias.') self.W_learning_rate_multiplier = W_learning_rate_multiplier self.b_learning_rate_multiplier = b_learning_rate_multiplier @@ -1023,8 +1019,6 @@ def __init__(self, nb_filter, nb_row, nb_col, self.pointwise_constraint = constraints.get(pointwise_constraint) self.b_constraint = constraints.get(b_constraint) - if not bias and b_learning_rate_multiplier is not None: - raise ValueError('b_learning_rate_multiplier provided with no bias.') self.depthwise_learning_rate_multiplier = depthwise_learning_rate_multiplier self.pointwise_learning_rate_multiplier = pointwise_learning_rate_multiplier self.b_learning_rate_multiplier = b_learning_rate_multiplier @@ -1235,8 +1229,6 @@ def __init__(self, nb_filter, kernel_dim1, kernel_dim2, kernel_dim3, self.W_constraint = constraints.get(W_constraint) self.b_constraint = constraints.get(b_constraint) - if not bias and b_learning_rate_multiplier is not None: - raise ValueError('b_learning_rate_multiplier provided with no bias.') self.W_learning_rate_multiplier = W_learning_rate_multiplier self.b_learning_rate_multiplier = b_learning_rate_multiplier diff --git a/keras/layers/core.py b/keras/layers/core.py index 135618c3244d..1f8c8cf38da9 100644 --- a/keras/layers/core.py +++ b/keras/layers/core.py @@ -781,8 +781,6 @@ def __init__(self, output_dim, init='glorot_uniform', self.W_constraint = constraints.get(W_constraint) self.b_constraint = constraints.get(b_constraint) - if not bias and b_learning_rate_multiplier is not None: - raise ValueError('b_learning_rate_multiplier provided with no bias.') self.W_learning_rate_multiplier = W_learning_rate_multiplier self.b_learning_rate_multiplier = b_learning_rate_multiplier @@ -963,8 +961,6 @@ def __init__(self, output_dim, self.W_constraint = constraints.get(W_constraint) self.b_constraint = constraints.get(b_constraint) - if not bias and b_learning_rate_multiplier is not None: - raise ValueError('b_learning_rate_multiplier provided with no bias.') self.W_learning_rate_multiplier = W_learning_rate_multiplier self.b_learning_rate_multiplier = b_learning_rate_multiplier @@ -1109,8 +1105,6 @@ def __init__(self, self.W_constraint = constraints.get(W_constraint) self.b_constraint = constraints.get(b_constraint) - if not bias and b_learning_rate_multiplier is not None: - raise ValueError('b_learning_rate_multiplier provided with no bias.') self.W_learning_rate_multiplier = W_learning_rate_multiplier self.b_learning_rate_multiplier = b_learning_rate_multiplier @@ -1267,8 +1261,6 @@ def __init__(self, output_dim, self.W_constraint = constraints.get(W_constraint) self.b_constraint = constraints.get(b_constraint) - if not bias and b_learning_rate_multiplier is not None: - raise ValueError('b_learning_rate_multiplier provided with no bias.') self.W_learning_rate_multiplier = W_learning_rate_multiplier self.b_learning_rate_multiplier = b_learning_rate_multiplier diff --git a/tests/keras/test_learning_rate_multipliers.py b/tests/keras/test_learning_rate_multipliers.py index 38c239253aca..8a43c02d1f88 100644 --- a/tests/keras/test_learning_rate_multipliers.py +++ b/tests/keras/test_learning_rate_multipliers.py @@ -19,14 +19,6 @@ def test_learning_rate_multipliers_maxout_dense(): 'b_learning_rate_multiplier': 0.1}, input_shape=(3, 2)) - with pytest.raises(Exception) as e_info: - layer_test(MaxoutDense, - kwargs={'output_dim': 3, - 'bias': False, - 'W_learning_rate_multiplier': 0.1, - 'b_learning_rate_multiplier': 0.1}, - input_shape=(3, 2)) - def test_learning_rate_multipliers_conv1d(): from keras.layers.convolutional import 
Convolution1D @@ -38,15 +30,6 @@ def test_learning_rate_multipliers_conv1d(): 'b_learning_rate_multiplier': 0.1}, input_shape=(2, 8, 5)) - with pytest.raises(Exception) as e_info: - layer_test(Convolution1D, - kwargs={'nb_filter': 4, - 'filter_length': 3, - 'bias': False, - 'W_learning_rate_multiplier': 0.1, - 'b_learning_rate_multiplier': 0.1}, - input_shape=(2, 8, 5)) - @pytest.mark.skipif((K._BACKEND != 'theano'), reason="Requires theano backend or be able to set random seed in tensorflow") @@ -60,15 +43,6 @@ def test_learning_rate_multipliers_dense(): 'b_learning_rate_multiplier': 0.1}, input_shape=(3, 2)) - # This should raise an error - with pytest.raises(Exception) as e_info: - layer_test(Dense, - kwargs={'output_dim': 3, - 'bias': False, - 'W_learning_rate_multiplier': 0.1, - 'b_learning_rate_multiplier': 0.1}, - input_shape=(3, 2)) - np.random.seed(seed) (X_train, y_train), (X_test, y_test) = get_test_data(nb_train=10, nb_test=1, @@ -115,16 +89,6 @@ def test_learning_rate_multipliers_conv2d(): 'b_learning_rate_multiplier': 0.1}, input_shape=(8, 4, 10, 6)) - with pytest.raises(Exception) as e_info: - layer_test(Convolution2D, - kwargs={'nb_filter': 3, - 'nb_row': 3, - 'nb_col': 3, - 'bias': False, - 'W_learning_rate_multiplier': 0.1, - 'b_learning_rate_multiplier': 0.1}, - input_shape=(8, 4, 10, 6)) - np.random.seed(seed) X_train = np.random.rand(10, 3, 10, 10) y_train = np.random.rand(10, 2, 8, 8) From 81d7d0ceae79ccf56e3cdc412c6bd4b59886fb0d Mon Sep 17 00:00:00 2001 From: Guilherme Pombo Date: Mon, 15 May 2017 14:23:20 +0100 Subject: [PATCH 09/27] override model definition --- keras/models.py | 60 ++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 52 insertions(+), 8 deletions(-) diff --git a/keras/models.py b/keras/models.py index 1b50b2234c50..389cad47261f 100644 --- a/keras/models.py +++ b/keras/models.py @@ -4,6 +4,7 @@ import json import os import numpy as np +import inspect from . import backend as K from . import optimizers @@ -13,6 +14,19 @@ from .optimizers import optimizer_from_config +def model_mappings(class_name): + + model_map = { + 'Model': Model, + 'Container': Model, + 'Sequential': Sequential, + 'OrdinalModel': Model, + 'ExtendedModel': Model + } + + return model_map.get(class_name, None) + + def save_model(model, filepath, overwrite=True): def get_json_type(obj): @@ -106,7 +120,12 @@ def get_json_type(obj): f.close() -def load_model(filepath, custom_objects=None): +def load_model(filepath, classify=True, custom_objects=None): + """ + :param filepath: Path to the model.h5 file + :param custom_objects: Custom layers necessary to build the model + :return: Object of Keras Model class or custom model class. + """ if not custom_objects: custom_objects = {} @@ -144,13 +163,11 @@ def deserialize(obj): # set weights model.load_weights_from_hdf5_group(f['model_weights']) - # instantiate optimizer - training_config = f.attrs.get('training_config') - if training_config is None: - warnings.warn('No training configuration found in save file: ' - 'the model was *not* compiled. Compile it manually.') + if classify: f.close() return model + + training_config = f.attrs.get('training_config') training_config = json.loads(training_config.decode('utf-8')) optimizer_config = training_config['optimizer_config'] optimizer = optimizer_from_config(optimizer_config, @@ -185,12 +202,39 @@ def deserialize(obj): def model_from_config(config, custom_objects=None): - from keras.utils.layer_utils import layer_from_config + """Instantiate a layer from a config dictionary. 
+ + # Arguments + config: dict of the form {'class_name': str, 'config': dict} + custom_objects: dict mapping class names (or function names) + of custom (non-Keras) objects to class/functions + + # Returns + Layer instance (may be Model, Sequential, Layer...) + """ if isinstance(config, list): raise TypeError('`model_fom_config` expects a dictionary, not a list. ' 'Maybe you meant to use ' '`Sequential.from_config(config)`?') - return layer_from_config(config, custom_objects=custom_objects) + # Insert custom layers into globals so they can + # be accessed by `get_from_module`. + if custom_objects: + for cls_key in custom_objects: + globals()[cls_key] = custom_objects[cls_key] + + class_name = config['class_name'] + + if model_mappings(class_name): + layer_class = model_mappings(class_name) + else: + raise Exception('Unable to find %s in model_mappings' % class_name) + + arg_spec = inspect.getfullargspec(layer_class.from_config) + if 'custom_objects' in arg_spec.args: + return layer_class.from_config(config['config'], + custom_objects=custom_objects) + else: + return layer_class.from_config(config['config']) def model_from_yaml(yaml_string, custom_objects=None): From dc9a16070711fb0342a1535ad5f16ce25e5578cd Mon Sep 17 00:00:00 2001 From: Guilherme Pombo Date: Mon, 15 May 2017 15:15:42 +0100 Subject: [PATCH 10/27] removed unecessary docstrings --- keras/models.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/keras/models.py b/keras/models.py index 389cad47261f..a53c839727a8 100644 --- a/keras/models.py +++ b/keras/models.py @@ -121,11 +121,7 @@ def get_json_type(obj): def load_model(filepath, classify=True, custom_objects=None): - """ - :param filepath: Path to the model.h5 file - :param custom_objects: Custom layers necessary to build the model - :return: Object of Keras Model class or custom model class. - """ + if not custom_objects: custom_objects = {} @@ -202,16 +198,7 @@ def deserialize(obj): def model_from_config(config, custom_objects=None): - """Instantiate a layer from a config dictionary. - # Arguments - config: dict of the form {'class_name': str, 'config': dict} - custom_objects: dict mapping class names (or function names) - of custom (non-Keras) objects to class/functions - - # Returns - Layer instance (may be Model, Sequential, Layer...) - """ if isinstance(config, list): raise TypeError('`model_fom_config` expects a dictionary, not a list. ' 'Maybe you meant to use ' From 2891f3a85cdd3ce9dd684d6dc8eeffb3b25a2faa Mon Sep 17 00:00:00 2001 From: Guilherme Pombo Date: Mon, 15 May 2017 15:18:57 +0100 Subject: [PATCH 11/27] removed duplicated behaviour in load_model() --- keras/models.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/keras/models.py b/keras/models.py index a53c839727a8..52dab2e0e378 100644 --- a/keras/models.py +++ b/keras/models.py @@ -120,7 +120,7 @@ def get_json_type(obj): f.close() -def load_model(filepath, classify=True, custom_objects=None): +def load_model(filepath, custom_objects=None): if not custom_objects: custom_objects = {} @@ -159,7 +159,11 @@ def deserialize(obj): # set weights model.load_weights_from_hdf5_group(f['model_weights']) - if classify: + # instantiate optimizer + training_config = f.attrs.get('training_config') + if training_config is None: + warnings.warn('No training configuration found in save file: ' + 'the model was *not* compiled. 
Compile it manually.') f.close() return model From 28763b7a66a88872b7584e23c7f8da64dfa88272 Mon Sep 17 00:00:00 2001 From: Guilherme Pombo Date: Mon, 15 May 2017 15:20:22 +0100 Subject: [PATCH 12/27] white lines diff --- keras/models.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/keras/models.py b/keras/models.py index 52dab2e0e378..6df71691f97d 100644 --- a/keras/models.py +++ b/keras/models.py @@ -121,7 +121,6 @@ def get_json_type(obj): def load_model(filepath, custom_objects=None): - if not custom_objects: custom_objects = {} @@ -167,7 +166,6 @@ def deserialize(obj): f.close() return model - training_config = f.attrs.get('training_config') training_config = json.loads(training_config.decode('utf-8')) optimizer_config = training_config['optimizer_config'] optimizer = optimizer_from_config(optimizer_config, From 7fb918ccc4953ec35f9c2c62791fbb9f6b010ac9 Mon Sep 17 00:00:00 2001 From: Guilherme Pombo Date: Mon, 15 May 2017 15:45:36 +0100 Subject: [PATCH 13/27] code simplification --- keras/models.py | 30 ++---------------------------- 1 file changed, 2 insertions(+), 28 deletions(-) diff --git a/keras/models.py b/keras/models.py index 6df71691f97d..9d316b6cab82 100644 --- a/keras/models.py +++ b/keras/models.py @@ -14,19 +14,6 @@ from .optimizers import optimizer_from_config -def model_mappings(class_name): - - model_map = { - 'Model': Model, - 'Container': Model, - 'Sequential': Sequential, - 'OrdinalModel': Model, - 'ExtendedModel': Model - } - - return model_map.get(class_name, None) - - def save_model(model, filepath, overwrite=True): def get_json_type(obj): @@ -165,7 +152,6 @@ def deserialize(obj): 'the model was *not* compiled. Compile it manually.') f.close() return model - training_config = json.loads(training_config.decode('utf-8')) optimizer_config = training_config['optimizer_config'] optimizer = optimizer_from_config(optimizer_config, @@ -199,8 +185,7 @@ def deserialize(obj): return model -def model_from_config(config, custom_objects=None): - +def model_from_config(config, custom_objects=None, layer_class=Model): if isinstance(config, list): raise TypeError('`model_fom_config` expects a dictionary, not a list. 
' 'Maybe you meant to use ' @@ -211,19 +196,8 @@ def model_from_config(config, custom_objects=None): for cls_key in custom_objects: globals()[cls_key] = custom_objects[cls_key] - class_name = config['class_name'] - - if model_mappings(class_name): - layer_class = model_mappings(class_name) - else: - raise Exception('Unable to find %s in model_mappings' % class_name) - arg_spec = inspect.getfullargspec(layer_class.from_config) - if 'custom_objects' in arg_spec.args: - return layer_class.from_config(config['config'], - custom_objects=custom_objects) - else: - return layer_class.from_config(config['config']) + return layer_class.from_config(config['config'], custom_objects=getattr(arg_spec.args, 'custom_objects', None)) def model_from_yaml(yaml_string, custom_objects=None): From 65c244c442ff41edb10c2f77c2eed6917916c7fd Mon Sep 17 00:00:00 2001 From: Guilherme Pombo Date: Mon, 15 May 2017 15:50:08 +0100 Subject: [PATCH 14/27] pass layer class through load_model --- keras/models.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/keras/models.py b/keras/models.py index 9d316b6cab82..3648f52c15d3 100644 --- a/keras/models.py +++ b/keras/models.py @@ -107,7 +107,7 @@ def get_json_type(obj): f.close() -def load_model(filepath, custom_objects=None): +def load_model(filepath, custom_objects=None, layer_class=Model): if not custom_objects: custom_objects = {} @@ -140,7 +140,7 @@ def deserialize(obj): if model_config is None: raise ValueError('No model found in config file.') model_config = json.loads(model_config.decode('utf-8')) - model = model_from_config(model_config, custom_objects=custom_objects) + model = model_from_config(model_config, custom_objects=custom_objects, layer_class=layer_class) # set weights model.load_weights_from_hdf5_group(f['model_weights']) @@ -190,12 +190,10 @@ def model_from_config(config, custom_objects=None, layer_class=Model): raise TypeError('`model_fom_config` expects a dictionary, not a list. ' 'Maybe you meant to use ' '`Sequential.from_config(config)`?') - # Insert custom layers into globals so they can - # be accessed by `get_from_module`. + # Insert custom layers into globals so they can be accessed by `get_from_module`. 
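
At this stage the custom class is passed through an explicit `layer_class` argument. A usage sketch against the signatures above (the subclass and file name are hypothetical); note that `getattr(arg_spec.args, 'custom_objects', None)` is a `getattr` on a list of argument names, so it always returns `None` and `custom_objects` is effectively dropped on this path until later patches in the series rework it:

    from keras.models import Model, load_model

    class ExtendedModel(Model):
        # a subclass whose from_config signature matches Model.from_config
        pass

    model = load_model('extended_model.h5', layer_class=ExtendedModel)
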
if custom_objects: for cls_key in custom_objects: globals()[cls_key] = custom_objects[cls_key] - arg_spec = inspect.getfullargspec(layer_class.from_config) return layer_class.from_config(config['config'], custom_objects=getattr(arg_spec.args, 'custom_objects', None)) From 0c345fb01d4ceb31d0622e96196547471fbd4a06 Mon Sep 17 00:00:00 2001 From: Guilherme Pombo Date: Mon, 15 May 2017 16:39:25 +0100 Subject: [PATCH 15/27] default behaviour is now the same as original keras --- keras/models.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/keras/models.py b/keras/models.py index 3648f52c15d3..c32b3f85c254 100644 --- a/keras/models.py +++ b/keras/models.py @@ -12,6 +12,7 @@ from .engine.training import Model from .engine.topology import get_source_inputs, Node, Layer, Merge from .optimizers import optimizer_from_config +from keras.utils.layer_utils import layer_from_config def save_model(model, filepath, overwrite=True): @@ -107,7 +108,7 @@ def get_json_type(obj): f.close() -def load_model(filepath, custom_objects=None, layer_class=Model): +def load_model(filepath, custom_objects=None, layer_class=None): if not custom_objects: custom_objects = {} @@ -185,7 +186,7 @@ def deserialize(obj): return model -def model_from_config(config, custom_objects=None, layer_class=Model): +def model_from_config(config, custom_objects=None, layer_class=None): if isinstance(config, list): raise TypeError('`model_fom_config` expects a dictionary, not a list. ' 'Maybe you meant to use ' @@ -194,8 +195,12 @@ def model_from_config(config, custom_objects=None, layer_class=Model): if custom_objects: for cls_key in custom_objects: globals()[cls_key] = custom_objects[cls_key] - arg_spec = inspect.getfullargspec(layer_class.from_config) - return layer_class.from_config(config['config'], custom_objects=getattr(arg_spec.args, 'custom_objects', None)) + if layer_class: + arg_spec = inspect.getfullargspec(layer_class.from_config) + return layer_class.from_config(config['config'], custom_objects=getattr(arg_spec.args, 'custom_objects', None)) + else: + from keras.utils.layer_utils import layer_from_config + return layer_from_config(config, custom_objects=custom_objects) def model_from_yaml(yaml_string, custom_objects=None): From fda2ede06e6ede7eefd0a67b56c943a7d853a679 Mon Sep 17 00:00:00 2001 From: Guilherme Pombo Date: Mon, 15 May 2017 16:40:33 +0100 Subject: [PATCH 16/27] put global definition in the right plcae --- keras/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/keras/models.py b/keras/models.py index c32b3f85c254..b8cd00fdf61a 100644 --- a/keras/models.py +++ b/keras/models.py @@ -191,11 +191,11 @@ def model_from_config(config, custom_objects=None, layer_class=None): raise TypeError('`model_fom_config` expects a dictionary, not a list. ' 'Maybe you meant to use ' '`Sequential.from_config(config)`?') - # Insert custom layers into globals so they can be accessed by `get_from_module`. - if custom_objects: - for cls_key in custom_objects: - globals()[cls_key] = custom_objects[cls_key] if layer_class: + # Insert custom layers into globals so they can be accessed by `get_from_module`. 
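
When no `layer_class` is supplied, the function now falls straight through to `layer_from_config`, so the stock Keras config round-trip is unchanged. A quick sketch of that default path:

    import json
    from keras.layers import Dense
    from keras.models import Sequential, model_from_config

    model = Sequential([Dense(2, input_dim=3)])
    config = json.loads(model.to_json())   # {'class_name': 'Sequential', 'config': ...}
    rebuilt = model_from_config(config)    # same behaviour as unpatched Keras
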
+ if custom_objects: + for cls_key in custom_objects: + globals()[cls_key] = custom_objects[cls_key] arg_spec = inspect.getfullargspec(layer_class.from_config) return layer_class.from_config(config['config'], custom_objects=getattr(arg_spec.args, 'custom_objects', None)) else: From 37d70c34852745a83ed9fe97570c62c8370fba1e Mon Sep 17 00:00:00 2001 From: Guilherme Pombo Date: Mon, 15 May 2017 16:42:20 +0100 Subject: [PATCH 17/27] removed uncessary import --- keras/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/models.py b/keras/models.py index b8cd00fdf61a..6631437150a6 100644 --- a/keras/models.py +++ b/keras/models.py @@ -12,7 +12,6 @@ from .engine.training import Model from .engine.topology import get_source_inputs, Node, Layer, Merge from .optimizers import optimizer_from_config -from keras.utils.layer_utils import layer_from_config def save_model(model, filepath, overwrite=True): @@ -198,6 +197,7 @@ def model_from_config(config, custom_objects=None, layer_class=None): globals()[cls_key] = custom_objects[cls_key] arg_spec = inspect.getfullargspec(layer_class.from_config) return layer_class.from_config(config['config'], custom_objects=getattr(arg_spec.args, 'custom_objects', None)) + # Default Keras behaviour if the layer class parameter is not passed else: from keras.utils.layer_utils import layer_from_config return layer_from_config(config, custom_objects=custom_objects) From c0ff731d741a3f1d9969f9b98f815dd3809f3681 Mon Sep 17 00:00:00 2001 From: Guilherme Pombo Date: Mon, 15 May 2017 18:09:22 +0100 Subject: [PATCH 18/27] remove extra parameter, pass layer_class through custom_objects --- keras/models.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/keras/models.py b/keras/models.py index 6631437150a6..bdfc56cbc0f3 100644 --- a/keras/models.py +++ b/keras/models.py @@ -107,7 +107,7 @@ def get_json_type(obj): f.close() -def load_model(filepath, custom_objects=None, layer_class=None): +def load_model(filepath, custom_objects=None): if not custom_objects: custom_objects = {} @@ -140,7 +140,7 @@ def deserialize(obj): if model_config is None: raise ValueError('No model found in config file.') model_config = json.loads(model_config.decode('utf-8')) - model = model_from_config(model_config, custom_objects=custom_objects, layer_class=layer_class) + model = model_from_config(model_config, custom_objects=custom_objects) # set weights model.load_weights_from_hdf5_group(f['model_weights']) @@ -185,16 +185,16 @@ def deserialize(obj): return model -def model_from_config(config, custom_objects=None, layer_class=None): +def model_from_config(config, custom_objects=None): if isinstance(config, list): raise TypeError('`model_fom_config` expects a dictionary, not a list. ' 'Maybe you meant to use ' '`Sequential.from_config(config)`?') - if layer_class: + if custom_objects and 'layer_class' in custom_objects: # Insert custom layers into globals so they can be accessed by `get_from_module`. 
- if custom_objects: - for cls_key in custom_objects: - globals()[cls_key] = custom_objects[cls_key] + for cls_key in custom_objects: + globals()[cls_key] = custom_objects[cls_key] + layer_class = custom_objects['layer_class'] arg_spec = inspect.getfullargspec(layer_class.from_config) return layer_class.from_config(config['config'], custom_objects=getattr(arg_spec.args, 'custom_objects', None)) # Default Keras behaviour if the layer class parameter is not passed From 791c4060247b7ab47a2b43fafc874738186888af Mon Sep 17 00:00:00 2001 From: Guilherme Pombo Date: Fri, 19 May 2017 12:22:11 +0100 Subject: [PATCH 19/27] code cleanup --- keras/models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/keras/models.py b/keras/models.py index bdfc56cbc0f3..d934400f1411 100644 --- a/keras/models.py +++ b/keras/models.py @@ -195,8 +195,9 @@ def model_from_config(config, custom_objects=None): for cls_key in custom_objects: globals()[cls_key] = custom_objects[cls_key] layer_class = custom_objects['layer_class'] - arg_spec = inspect.getfullargspec(layer_class.from_config) - return layer_class.from_config(config['config'], custom_objects=getattr(arg_spec.args, 'custom_objects', None)) + # Remove layer class from custom_objects as it's not needed anymore + custom_objects.pop('layer_class', None) + return layer_class.from_config(config['config'], custom_objects=custom_objects) # Default Keras behaviour if the layer class parameter is not passed else: from keras.utils.layer_utils import layer_from_config From bf847f2e0bbd0574a5ebdfce54c29373cdbe8bd4 Mon Sep 17 00:00:00 2001 From: Guilherme Pombo Date: Fri, 19 May 2017 12:30:17 +0100 Subject: [PATCH 20/27] removed unused import --- keras/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/keras/models.py b/keras/models.py index d934400f1411..a345ecffe789 100644 --- a/keras/models.py +++ b/keras/models.py @@ -4,7 +4,6 @@ import json import os import numpy as np -import inspect from . import backend as K from . import optimizers From a1f7fe388e6368a03aa43d31b546b76896229d4f Mon Sep 17 00:00:00 2001 From: Guilherme Pombo Date: Fri, 19 May 2017 15:11:56 +0100 Subject: [PATCH 21/27] refactored code to fallback to default Keras behaviour properly --- keras/models.py | 17 ++++------------- keras/utils/layer_utils.py | 9 +++++++++ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/keras/models.py b/keras/models.py index a345ecffe789..8721b8a9834f 100644 --- a/keras/models.py +++ b/keras/models.py @@ -146,7 +146,7 @@ def deserialize(obj): # instantiate optimizer training_config = f.attrs.get('training_config') - if training_config is None: + if training_config is None or ('classify' in custom_objects and custom_objects['classify']): warnings.warn('No training configuration found in save file: ' 'the model was *not* compiled. Compile it manually.') f.close() @@ -189,18 +189,9 @@ def model_from_config(config, custom_objects=None): raise TypeError('`model_fom_config` expects a dictionary, not a list. ' 'Maybe you meant to use ' '`Sequential.from_config(config)`?') - if custom_objects and 'layer_class' in custom_objects: - # Insert custom layers into globals so they can be accessed by `get_from_module`. 
- for cls_key in custom_objects: - globals()[cls_key] = custom_objects[cls_key] - layer_class = custom_objects['layer_class'] - # Remove layer class from custom_objects as it's not needed anymore - custom_objects.pop('layer_class', None) - return layer_class.from_config(config['config'], custom_objects=custom_objects) - # Default Keras behaviour if the layer class parameter is not passed - else: - from keras.utils.layer_utils import layer_from_config - return layer_from_config(config, custom_objects=custom_objects) + + from keras.utils.layer_utils import layer_from_config + return layer_from_config(config, custom_objects=custom_objects) def model_from_yaml(yaml_string, custom_objects=None): diff --git a/keras/utils/layer_utils.py b/keras/utils/layer_utils.py index 748f8743b013..66829d0c2303 100644 --- a/keras/utils/layer_utils.py +++ b/keras/utils/layer_utils.py @@ -24,6 +24,15 @@ def layer_from_config(config, custom_objects=None): if custom_objects: get_custom_objects().update(custom_objects) + # New behaviour + if custom_objects and 'layer_class' in custom_objects: + layer_class = custom_objects['layer_class'] + # Remove layer class from custom_objects as it's not needed anymore + custom_objects.pop('layer_class', None) + for cls_key in custom_objects: + globals()[cls_key] = custom_objects[cls_key] + return layer_class.from_config(config['config'], custom_objects=custom_objects) + class_name = config['class_name'] if class_name == 'Sequential': From 3c5f9222b926d9b75c70c6bb8c064c8b1c9e2a56 Mon Sep 17 00:00:00 2001 From: Guilherme Pombo Date: Fri, 19 May 2017 15:13:58 +0100 Subject: [PATCH 22/27] reduce diff to original code --- keras/models.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/keras/models.py b/keras/models.py index 8721b8a9834f..e59dbece76b8 100644 --- a/keras/models.py +++ b/keras/models.py @@ -185,12 +185,11 @@ def deserialize(obj): def model_from_config(config, custom_objects=None): + from keras.utils.layer_utils import layer_from_config if isinstance(config, list): raise TypeError('`model_fom_config` expects a dictionary, not a list. ' 'Maybe you meant to use ' '`Sequential.from_config(config)`?') - - from keras.utils.layer_utils import layer_from_config return layer_from_config(config, custom_objects=custom_objects) From 43cddbcb401e6836d2ea640665a514fbd7092c94 Mon Sep 17 00:00:00 2001 From: Guilherme Pombo Date: Fri, 19 May 2017 15:29:54 +0100 Subject: [PATCH 23/27] remove classify info from custom_objects after use --- keras/models.py | 1 + 1 file changed, 1 insertion(+) diff --git a/keras/models.py b/keras/models.py index e59dbece76b8..721d73c7fabf 100644 --- a/keras/models.py +++ b/keras/models.py @@ -147,6 +147,7 @@ def deserialize(obj): # instantiate optimizer training_config = f.attrs.get('training_config') if training_config is None or ('classify' in custom_objects and custom_objects['classify']): + custom_objects.pop('layer_class', None) warnings.warn('No training configuration found in save file: ' 'the model was *not* compiled. 
Compile it manually.') f.close() From 5b45baacc7e27a04f2e7558536fda6af17eeff26 Mon Sep 17 00:00:00 2001 From: Guilherme Pombo Date: Fri, 19 May 2017 15:47:52 +0100 Subject: [PATCH 24/27] adressing PR comments --- keras/models.py | 2 +- keras/utils/layer_utils.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/keras/models.py b/keras/models.py index 721d73c7fabf..d4a513676333 100644 --- a/keras/models.py +++ b/keras/models.py @@ -147,7 +147,7 @@ def deserialize(obj): # instantiate optimizer training_config = f.attrs.get('training_config') if training_config is None or ('classify' in custom_objects and custom_objects['classify']): - custom_objects.pop('layer_class', None) + custom_objects.pop('classify', None) warnings.warn('No training configuration found in save file: ' 'the model was *not* compiled. Compile it manually.') f.close() diff --git a/keras/utils/layer_utils.py b/keras/utils/layer_utils.py index 66829d0c2303..06d32dff08ad 100644 --- a/keras/utils/layer_utils.py +++ b/keras/utils/layer_utils.py @@ -29,8 +29,6 @@ def layer_from_config(config, custom_objects=None): layer_class = custom_objects['layer_class'] # Remove layer class from custom_objects as it's not needed anymore custom_objects.pop('layer_class', None) - for cls_key in custom_objects: - globals()[cls_key] = custom_objects[cls_key] return layer_class.from_config(config['config'], custom_objects=custom_objects) class_name = config['class_name'] From ac4a14e8226727273211993d1c64b5133d6dd3de Mon Sep 17 00:00:00 2001 From: Ken Chatfield Date: Fri, 14 Jul 2017 16:08:04 +0100 Subject: [PATCH 25/27] Add flag to toggle metrics inclusion when loading model --- keras/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/keras/models.py b/keras/models.py index d4a513676333..1403635d8349 100644 --- a/keras/models.py +++ b/keras/models.py @@ -106,7 +106,7 @@ def get_json_type(obj): f.close() -def load_model(filepath, custom_objects=None): +def load_model(filepath, custom_objects=None, include_metrics=True): if not custom_objects: custom_objects = {} @@ -159,7 +159,7 @@ def deserialize(obj): # recover loss functions and metrics loss = deserialize(training_config['loss']) - metrics = deserialize(training_config['metrics']) + metrics = deserialize(training_config['metrics']) if include_metrics else None sample_weight_mode = training_config['sample_weight_mode'] loss_weights = training_config['loss_weights'] From 931529cc8ff17fbc4b7a2c3fbbda872fa7005d25 Mon Sep 17 00:00:00 2001 From: Ken Chatfield Date: Thu, 27 Jul 2017 09:56:47 +0100 Subject: [PATCH 26/27] Don't construct random_transform matrix unless necessary in datagen --- keras/preprocessing/image.py | 88 +++++++++++++++++++----------------- 1 file changed, 46 insertions(+), 42 deletions(-) diff --git a/keras/preprocessing/image.py b/keras/preprocessing/image.py index 4fead051d406..39c2638f6c4a 100644 --- a/keras/preprocessing/image.py +++ b/keras/preprocessing/image.py @@ -491,53 +491,57 @@ def random_transform(self, x): img_col_axis = self.col_axis - 1 img_channel_axis = self.channel_axis - 1 - # use composition of homographies - # to generate final transform that needs to be applied - if self.rotation_range: - theta = np.pi / 180 * np.random.uniform(-self.rotation_range, self.rotation_range) - else: - theta = 0 - rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], - [np.sin(theta), np.cos(theta), 0], - [0, 0, 1]]) - if self.height_shift_range: - tx = np.random.uniform(-self.height_shift_range, self.height_shift_range) * 
x.shape[img_row_axis] - else: - tx = 0 + transform_matrix_vars = [self.rotation_range, self.height_shift_range, self.width_shift_range, + self.shear_range, self.zoom_range] + if any(transform_matrix_vars): + # use composition of homographies + # to generate final transform that needs to be applied + if self.rotation_range: + theta = np.pi / 180 * np.random.uniform(-self.rotation_range, self.rotation_range) + else: + theta = 0 + rotation_matrix = np.array([[np.cos(theta), -np.sin(theta), 0], + [np.sin(theta), np.cos(theta), 0], + [0, 0, 1]]) + if self.height_shift_range: + tx = np.random.uniform(-self.height_shift_range, self.height_shift_range) * x.shape[img_row_axis] + else: + tx = 0 - if self.width_shift_range: - ty = np.random.uniform(-self.width_shift_range, self.width_shift_range) * x.shape[img_col_axis] - else: - ty = 0 + if self.width_shift_range: + ty = np.random.uniform(-self.width_shift_range, self.width_shift_range) * x.shape[img_col_axis] + else: + ty = 0 - translation_matrix = np.array([[1, 0, tx], - [0, 1, ty], - [0, 0, 1]]) - if self.shear_range: - shear = np.random.uniform(-self.shear_range, self.shear_range) - else: - shear = 0 - shear_matrix = np.array([[1, -np.sin(shear), 0], - [0, np.cos(shear), 0], - [0, 0, 1]]) + translation_matrix = np.array([[1, 0, tx], + [0, 1, ty], + [0, 0, 1]]) + if self.shear_range: + shear = np.random.uniform(-self.shear_range, self.shear_range) + else: + shear = 0 + shear_matrix = np.array([[1, -np.sin(shear), 0], + [0, np.cos(shear), 0], + [0, 0, 1]]) - if self.zoom_range[0] == 1 and self.zoom_range[1] == 1: - zx, zy = 1, 1 - else: - zx, zy = np.random.uniform(self.zoom_range[0], self.zoom_range[1], 2) - zoom_matrix = np.array([[zx, 0, 0], - [0, zy, 0], - [0, 0, 1]]) + if self.zoom_range[0] == 1 and self.zoom_range[1] == 1: + zx, zy = 1, 1 + else: + zx, zy = np.random.uniform(self.zoom_range[0], self.zoom_range[1], 2) + zoom_matrix = np.array([[zx, 0, 0], + [0, zy, 0], + [0, 0, 1]]) + + transform_matrix = np.dot(np.dot(np.dot(rotation_matrix, + translation_matrix), + shear_matrix), + zoom_matrix) - transform_matrix = np.dot(np.dot(np.dot(rotation_matrix, - translation_matrix), - shear_matrix), - zoom_matrix) + h, w = x.shape[img_row_axis], x.shape[img_col_axis] + transform_matrix = transform_matrix_offset_center(transform_matrix, h, w) + x = apply_transform(x, transform_matrix, img_channel_axis, + fill_mode=self.fill_mode, cval=self.cval) - h, w = x.shape[img_row_axis], x.shape[img_col_axis] - transform_matrix = transform_matrix_offset_center(transform_matrix, h, w) - x = apply_transform(x, transform_matrix, img_channel_axis, - fill_mode=self.fill_mode, cval=self.cval) if self.channel_shift_range != 0: x = random_channel_shift(x, self.channel_shift_range, From 65be0fb2729aea8b0a75b781376544aa47dd7c00 Mon Sep 17 00:00:00 2001 From: Agis Oikonomou Date: Thu, 22 Nov 2018 14:19:59 +0000 Subject: [PATCH 27/27] fixed bug --- keras/engine/training.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/engine/training.py b/keras/engine/training.py index 9ea12a2aaab4..d1c9ede7a625 100644 --- a/keras/engine/training.py +++ b/keras/engine/training.py @@ -440,7 +440,7 @@ def data_generator_task(): self.queue = multiprocessing.Queue(maxsize=max_q_size) self._stop_event = multiprocessing.Event() else: - self.queue = queue.Queue() + self.queue = queue.Queue(maxsize=max_q_size) self._stop_event = threading.Event() for i in range(nb_worker):
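
The one-line fix in the final patch matters because the threaded branch previously created an unbounded `queue.Queue()`, so generator workers could buffer arbitrarily many batches ahead of training, whereas the multiprocessing branch was already capped at `max_q_size`. A minimal standard-library illustration of the difference (not Keras code):

    import queue
    import threading

    q = queue.Queue(maxsize=10)      # bounded, like queue.Queue(maxsize=max_q_size)

    def producer():
        for batch in range(1000):
            q.put(batch)             # blocks once 10 batches are waiting,
                                     # instead of buffering all 1000 in memory

    threading.Thread(target=producer, daemon=True).start()
    first_batch = q.get()            # each consumer get() frees a slot for the producer
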