Skip to content

Commit

Permalink
Minor fix
Browse files Browse the repository at this point in the history
  • Loading branch information
lvapeab committed Jan 28, 2019
1 parent 724772d commit 382e69c
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
6 changes: 4 additions & 2 deletions keras/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,9 +1127,10 @@ def is_indexed_slices(grad):
self.sess.graph)
else:
self.writer = tf.summary.FileWriter(self.log_dir)

if self.embeddings_freq:
embeddings_layer_names = self.embeddings_layer_names

if not embeddings_layer_names:
embeddings_layer_names = [layer.name for layer in self.model.layers
if type(layer).__name__ == 'Embedding']
Expand All @@ -1155,14 +1156,15 @@ def is_indexed_slices(grad):
batch = tf.assign(embedding[batch_id:batch_id + step],
embedding_input)
self.assign_embeddings.append(batch)
self.saver = tf.train.Saver(list(embeddings_vars.values()))
else:
if not self.saved:
embeddings_vars = {layer.name: layer.weights[0]
for layer in self.model.layers
if layer.name in embeddings_layer_names}
self.saver = tf.train.Saver(list(embeddings_vars.values()))
self.saved = True
self.saver = tf.train.Saver(list(embeddings_vars.values()))

if not isinstance(self.embeddings_metadata, str):
embeddings_metadata = self.embeddings_metadata
else:
Expand Down
1 change: 0 additions & 1 deletion keras/layers/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def __init__(self, n_heads,
bias_initializer='zeros',
bias_regularizer=None,
bias_constraint=None,

**kwargs):
super(MultiHeadAttention, self).__init__(**kwargs)
self.supports_masking = True
Expand Down

0 comments on commit 382e69c

Please sign in to comment.