Skip to content

Commit

Permalink
Methods to save/load the network/model/weights
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Aug 10, 2017
1 parent 936c9b1 commit a6bc97d
Showing 1 changed file with 68 additions and 19 deletions.
87 changes: 68 additions & 19 deletions conx/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@
from functools import reduce
import signal
import numbers
import pickle
import base64
import html
import copy
import io
import os
import PIL

import numpy as np
Expand Down Expand Up @@ -1600,30 +1602,77 @@ def describe_connection_to(self, layer1, layer2):
## FIXME: how to show merged layer weights?
return retval

def save(self, filename=None):
def save(self, foldername=None):
"""
Save the weights to a file.
Save the network to a folder.
"""
if filename is None:
filename = "%s.wts" % self.name
with open(filename, "wb") as fp:
for layer in self.model.layers:
for weight in layer.get_weights():
np.save(fp, weight)
if foldername is None:
foldername = "%s.conx" % self.name
if not os.path.isdir(foldername):
os.makedirs(foldername)
if self.model:
self.save_model(foldername)
self.save_weights(foldername)
self.model, tmp_model = None, self.model
with open("%s/network.pickle" % foldername, "wb") as fp:
pickle.dump(self, fp)
self.model = tmp_model

@classmethod
def load(cls, foldername):
"""
Load the network from a folder.
"""
with open("%s/network.pickle" % foldername, "rb") as fp:
net = pickle.load(fp)
net.load_model(foldername)
net.load_weights(foldername)
return net

def save_weights(self, foldername=None):
"""
Save the model weights to a folder.
"""
if self.model:
if foldername is None:
foldername = "%s.conx" % self.name
if not os.path.isdir(foldername):
os.makedirs(foldername)
self.model.save_weights("%s/weights.h5" % foldername)
else:
raise Exception("need to compile network first")

def save_model(self, foldername=None):
"""
Save the model to a folder.
"""
if self.model:
if foldername is None:
foldername = "%s.conx" % self.name
if not os.path.isdir(foldername):
os.makedirs(foldername)
self.model.save("%s/model.h5" % foldername)
else:
raise Exception("need to compile network first")

def load_weights(self, foldername=None):
"""
Load the model weights from a folder.
"""
if self.model:
if foldername is None:
foldername = "%s.conx" % self.name
if os.path.isfile("%s/model.h5" % foldername):
self.model.load_weights("%s/weights.h5" % foldername)

def load(self, filename=None):
def load_model(self, foldername=None):
"""
Load the weights from a file.
Load and set the model from a folder.
"""
if filename is None:
filename = "%s.wts" % self.name
with open(filename, "rb") as fp:
for layer in self.model.layers:
weights = layer.get_weights()
new_weights = []
for w in range(len(weights)):
new_weights.append(np.load(fp))
layer.set_weights(new_weights)
if foldername is None:
foldername = "%s.conx" % self.name
if os.path.isfile("%s/model.h5" % foldername):
self.model = keras.models.load_model("%s/model.h5" % foldername)

def get_inputs_length(self):
"""
Expand Down

0 comments on commit a6bc97d

Please sign in to comment.