Skip to content

Commit

Permalink
net.save(), net.load() are robust
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Sep 13, 2018
1 parent 4e5b77a commit 5230145
Showing 1 changed file with 48 additions and 38 deletions.
86 changes: 48 additions & 38 deletions conx/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ def __init__(self, name: str, *sizes: int, load_config=True, debug=False,
self.output_bank_order = []
self.dataset = Dataset(self)
self.compile_options = {}
self.compile_args = {}
self.train_options = {}
self._tolerance = K.variable(0.1, dtype='float32', name='tolerance')
self.layer_dict = {}
Expand Down Expand Up @@ -882,7 +883,7 @@ def evaluate(self, batch_size=None, show=False, show_inputs=True, show_targets=T
if select is not None:
if isinstance(select, int):
select = (select,)
slice_select = slice(*select) if select is not None else slice(len(self.dataset))
slice_select = slice(*select) if select is not None else slice(len(self.dataset))
print("%s:" % self.name) ## network name
if 0 < self.dataset._split <= 1 and select is None:
size, num_train, num_test = self.dataset._get_split_sizes()
Expand Down Expand Up @@ -2591,6 +2592,10 @@ def compile(self, **kwargs):
>>> net.compile(error="mse", optimizer="adam")
"""
try:
self.compile_args = copy.deepcopy(kwargs)
except:
self.compile_args = {} # can't copy the state of args
if self.model is None:
self.build_model()
self.compile_model(**kwargs)
Expand Down Expand Up @@ -2718,12 +2723,12 @@ def _delete_intermediary_models(self):
layer.input_names = set([])
layer.model = None

def build_model(self, starting_layers=None, build_intermediary=True):
def build_model(self, starting_layers=None):
"""
Build the model.
"""
self._reset_layer_metadata()
self._build_intermediary_models(starting_layers=starting_layers, build_intermediary=build_intermediary)
self._build_intermediary_models(starting_layers=starting_layers)
output_k_layers = self._get_output_ks_in_order()
input_k_layers = self._get_input_ks_in_order(self.input_bank_order)
self.model = keras.models.Model(inputs=input_k_layers, outputs=output_k_layers)
Expand Down Expand Up @@ -2826,23 +2831,14 @@ def connect_network(self, output_layer_name, network):
self.build_model(starting_layers=[self[output_layer_name]])
self.compile_model(**network.compile_options)

def _build_intermediary_models(self, starting_layers=None, build_intermediary=True):
def _build_intermediary_models(self, starting_layers=None):
"""
Construct the layer.k, layer.input_names, and layer.model's.
"""
if starting_layers is None:
starting_layers = self.layers
self.prop_from_dict.clear()
self.keras_functions.clear()
if not build_intermediary:
## fill keras_functions from existing model's layers:
for layer in self.layers:
## the k:
layer.keras_layer = self._find_keras_layer(layer.name)
if layer.kind() == "input":
self.keras_functions[layer.name] = layer.keras_layer.get_input_at(0)
else:
self.keras_functions[layer.name] = [layer.keras_layer]
sequence = topological_sort(self, starting_layers)
if self.debug: print("topological sort:", [l.name for l in sequence])
for layer in sequence:
Expand Down Expand Up @@ -3919,22 +3915,15 @@ def load(self, dir=None):
return network
else:
self.load_config(dir)
self.load_models(dir)
self.load_weights(dir)

def save(self, dir=None):
"""
Save the model and the weights/history (if compiled) to a dir.
"""
save_network(dir, self)
self.save_config(dir)
if self.model:
self.save_models(dir)
self.save_weights(dir)
else:
raise Exception("need to build network before saving")

def load_models(self, dir=None, filename=None):
def load_model(self, dir=None, filename=None):
"""
Load a model from a dir/filename.
"""
Expand All @@ -3944,15 +3933,11 @@ def load_models(self, dir=None, filename=None):
if filename is None:
filename = "model.h5"
self.model = load_model(os.path.join(dir, filename))
for layer in self.layers:
filename = os.path.join(dir, "%s.model.h5" % layer.name)
if os.path.exists(filename):
layer.model = load_model(filename)
self._level_ordering = None
if self.compile_options:
self.reset()

def save_models(self, dir=None, filename=None):
def save_model(self, dir=None, filename=None):
"""
Save a model (if compiled) to a dir/filename.
"""
Expand All @@ -3964,10 +3949,6 @@ def save_models(self, dir=None, filename=None):
if not os.path.isdir(dir):
os.makedirs(dir)
self.model.save(os.path.join(dir, filename))
for layer in self.layers:
if layer.model is not None:
filename = "%s.model.h5" % layer.name
layer.model.save(os.path.join(dir, filename))
else:
raise Exception("need to build network before saving")

Expand Down Expand Up @@ -4017,7 +3998,9 @@ def load_weights(self, dir=None, filename=None):
dir = "%s.conx" % self.name.replace(" ", "_")
if filename is None:
filename = "weights.h5"
self.model.load_weights(os.path.join(dir, filename))
full_filename = os.path.join(dir, filename)
if os.path.exists(full_filename):
self.model.load_weights(full_filename)
self.load_history(dir)
else:
raise Exception("need to build network before loading weights")
Expand Down Expand Up @@ -4310,14 +4293,23 @@ def load_network(dir):
for connection in config["connections"]:
from_name, to_name = connection
network.connect(from_name, to_name)
network.load_models(dir)
network.build_model(build_intermediary=False)
network.load_weights(dir)
## FIXME: don't forget about error/loss functions on other layers
## after we build model
if config["compile_args"]:
network.compile(**config["compile_args"])
if network.model:
network.load_weights(dir)
return network

def save_network(datadir, network):
"""
Save the network description in order to be able to recreate
it.
Saves the network name, layers, conecctions, compile args
to network.pickle.
Saves the weights to weights.h5.
Saves the training history to history.pickle.
Saves the network config to config.json.
"""
import pickle
if datadir is None:
datadir = "%s.conx" % network.name.replace(" ", "_")
Expand All @@ -4327,17 +4319,35 @@ def save_network(datadir, network):
except:
datadir = os.path.join('/tmp', datadir)
os.makedirs(datadir)
## get latest from layers:
config = {
"name": network.name,
"layers": {},
"connections": network.connections,
"compile_args": network.compile_args,
}
## get latest from layers:
for layer in network.layers:
d = {}
config["layers"][layer.name] = d
for item in ["config"]:
d[item] = getattr(layer, item)
success = False
with open("%s/network.pickle" % (("%s.conx" % network.name.replace(" ", "_"))
if datadir is None else datadir), "wb") as fp:
pickle.dump(config, fp)
try:
pickle.dump(config, fp)
success = True
except TypeError:
success = False

if not success:
with open("%s/network.pickle" % (("%s.conx" % network.name.replace(" ", "_"))
if datadir is None else datadir), "wb") as fp:
config["compile_args"] = {}
pickle.dump(config, fp)
print("WARNING: can't save compile args; recompile after loading from disk",
file=sys.stderr)

network.save_config(datadir)
if network.model:
network.save_weights(datadir)

0 comments on commit 5230145

Please sign in to comment.