Skip to content
This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Commit

Permalink
Fix save and load of wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
CyberZHG committed Oct 29, 2018
1 parent 55d1c4d commit e2334aa
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
31 changes: 29 additions & 2 deletions keras_bert/layers/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,14 @@

class Wrapper(keras.layers.Layer):

def __init__(self, **kwargs):
self.layers = {}
CONFIG_PREFIX = 'wrapper_layer_'

def __init__(self, layers=None, **kwargs):
if layers is None:
self.layers = {}
else:
self.layers = layers
self.built = True
super(Wrapper, self).__init__(**kwargs)

@property
Expand All @@ -20,3 +26,24 @@ def non_trainable_weights(self):
for key in sorted(self.layers.keys()):
weights += self.layers[key].non_trainable_weights
return weights

def get_config(self):
config = {}
for name, layer in self.layers.items():
config[self.CONFIG_PREFIX + name] = {
'class_name': layer.__class__.__name__,
'config': layer.get_config(),
}
base_config = super(Wrapper, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

@classmethod
def from_config(cls, config, custom_objects=None):
layers = {}
print(config)
keys = list(filter(lambda key: key.startswith(Wrapper.CONFIG_PREFIX), config.keys()))
for key in keys:
if key.startswith(Wrapper.CONFIG_PREFIX):
name = key[len(Wrapper.CONFIG_PREFIX):]
layers[name] = keras.layers.deserialize(config.pop(key), custom_objects=custom_objects)
return cls(layers=layers, **config)
Binary file modified tests/test_bert_fit.h5
Binary file not shown.

0 comments on commit e2334aa

Please sign in to comment.