Skip to content

Commit

Permalink
Protections for non-ipython and pickle
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Aug 10, 2017
1 parent c2e7968 commit aa2e1c8
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 40 deletions.
8 changes: 7 additions & 1 deletion conx/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def __init__(self, name, *args, **params):
self.decode_model = None
self.input_names = []
# used to determine image ranges:
self.activation = params.get("activation", None) # make a copy, if one
self.activation = params.get("activation", None) # make a copy, if one, and str
if not isinstance(self.activation, str):
self.activation = None
# set visual shape for display purposes
if 'vshape' in params:
vs = params['vshape']
Expand Down Expand Up @@ -123,6 +125,8 @@ def __init__(self, name, *args, **params):

if 'activation' in params: # let's keep a copy of it
self.activation = params["activation"]
if not isinstance(self.activation, str):
self.activation = None

self.incoming_connections = []
self.outgoing_connections = []
Expand Down Expand Up @@ -343,6 +347,8 @@ def __init__(self, name: str, shape, **params):
if not (callable(act) or act in Layer.ACTIVATION_FUNCTIONS):
raise Exception('unknown activation function: %s' % (act,))
self.activation = act
if not isinstance(self.activation, str):
self.activation = None

def __repr__(self):
return "<Layer name='%s', shape=%s, act='%s'>" % (
Expand Down
118 changes: 79 additions & 39 deletions conx/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,10 +1027,11 @@ def propagate(self, input, batch_size=32):
if not self._comm:
from ipykernel.comm import Comm
self._comm = Comm(target_name='conx_svg_control')
for layer in self.layers:
image = self.propagate_to_image(layer.name, input, batch_size)
data_uri = self._image_to_uri(image)
self._comm.send({'class': "%s_%s" % (self.name, layer.name), "href": data_uri})
if self._comm.kernel:
for layer in self.layers:
image = self.propagate_to_image(layer.name, input, batch_size)
data_uri = self._image_to_uri(image)
self._comm.send({'class': "%s_%s" % (self.name, layer.name), "href": data_uri})
return outputs

def propagate_from(self, layer_name, input, output_layer_names=None, batch_size=32):
Expand Down Expand Up @@ -1079,12 +1080,13 @@ def propagate_from(self, layer_name, input, output_layer_names=None, batch_size=
from ipykernel.comm import Comm
self._comm = Comm(target_name='conx_svg_control')
## Update from start to rest of graph
for layer in topological_sort(self, [self[layer_name]]):
model = self.prop_from_dict[(layer_name, layer.name)]
vector = model.predict(inputs)[0]
image = layer.make_image(vector, self.config)
data_uri = self._image_to_uri(image)
self._comm.send({'class': "%s_%s" % (self.name, layer.name), "href": data_uri})
if self._comm.kernel:
for layer in topological_sort(self, [self[layer_name]]):
model = self.prop_from_dict[(layer_name, layer.name)]
vector = model.predict(inputs)[0]
image = layer.make_image(vector, self.config)
data_uri = self._image_to_uri(image)
self._comm.send({'class': "%s_%s" % (self.name, layer.name), "href": data_uri})
if len(output_layer_names) == 1:
return outputs[0]
else:
Expand All @@ -1101,11 +1103,12 @@ def display_component(self, vector, component, **opts): #minmax=None, colormap=N
output_names = self.output_layer_order
else:
output_names = [layer.name for layer in self.layers if layer.kind() == "output"]
for (target, layer_name) in zip(vector, output_names):
array = np.array(target)
image = self[layer_name].make_image(array, config) # minmax=minmax, colormap=colormap)
data_uri = self._image_to_uri(image)
self._comm.send({'class': "%s_%s_%s" % (self.name, layer_name, component), "href": data_uri})
if self._comm.kernel:
for (target, layer_name) in zip(vector, output_names):
array = np.array(target)
image = self[layer_name].make_image(array, config) # minmax=minmax, colormap=colormap)
data_uri = self._image_to_uri(image)
self._comm.send({'class': "%s_%s_%s" % (self.name, layer_name, component), "href": data_uri})

def propagate_to(self, layer_name, inputs, batch_size=32, visualize=True):
"""
Expand All @@ -1128,11 +1131,12 @@ def propagate_to(self, layer_name, inputs, batch_size=32, visualize=True):
from ipykernel.comm import Comm
self._comm = Comm(target_name='conx_svg_control')
# Update path from input to output
for layer in self.layers: # FIXME??: update all layers for now
out = self.propagate_to(layer.name, inputs, visualize=False)
image = self[layer.name].make_image(np.array(out), self.config) # single vector, as an np.array
data_uri = self._image_to_uri(image)
self._comm.send({'class': "%s_%s" % (self.name, layer.name), "href": data_uri})
if self._comm.kernel:
for layer in self.layers: # FIXME??: update all layers for now
out = self.propagate_to(layer.name, inputs, visualize=False)
image = self[layer.name].make_image(np.array(out), self.config) # single vector, as an np.array
data_uri = self._image_to_uri(image)
self._comm.send({'class': "%s_%s" % (self.name, layer.name), "href": data_uri})
outputs = outputs[0].tolist()
return outputs

Expand Down Expand Up @@ -1190,6 +1194,27 @@ def compile(self, **kwargs):
raise Exception("layer '%s' is not listed in set_output_layer_order()" % layer.name)
else:
raise Exception("improper set_output_layer_order() names")
self._build_intermediary_models()
output_k_layers = self._get_ordered_output_layers()
input_k_layers = self._get_ordered_input_layers()
self.model = keras.models.Model(inputs=input_k_layers, outputs=output_k_layers)
kwargs['metrics'] = ['accuracy']
self.compile_options = copy.copy(kwargs)
self.model.compile(**kwargs)

def _delete_intermediary_models(self):
"""
Remove these, as they don't pickle.
"""
for layer in self.layers:
layer.k = None
layer.input_names = []
layer.model = None

def _build_intermediary_models(self):
"""
Construct the layer.k, layer.input_names, and layer.model's.
"""
sequence = topological_sort(self, self.layers)
for layer in sequence:
if layer.kind() == 'input':
Expand All @@ -1215,12 +1240,6 @@ def compile(self, **kwargs):
## get the inputs to this branch, in order:
input_ks = self._get_input_ks_in_order(layer.input_names)
layer.model = keras.models.Model(inputs=input_ks, outputs=layer.k)
output_k_layers = self._get_ordered_output_layers()
input_k_layers = self._get_ordered_input_layers()
self.model = keras.models.Model(inputs=input_k_layers, outputs=output_k_layers)
kwargs['metrics'] = ['accuracy']
self.compile_options = copy.copy(kwargs)
self.model.compile(**kwargs)

def _get_input_ks_in_order(self, layer_names):
"""
Expand Down Expand Up @@ -1602,32 +1621,53 @@ def describe_connection_to(self, layer1, layer2):
## FIXME: how to show merged layer weights?
return retval

def save(self, foldername=None):
def save(self, foldername=None, save_all=True):
"""
Save the network to a folder.
"""
if foldername is None:
foldername = "%s.conx" % self.name
if not os.path.isdir(foldername):
os.makedirs(foldername)
if self.model:
if self.model and save_all:
self.save_model(foldername)
self.save_weights(foldername)
self._delete_intermediary_models()
self.model, tmp_model = None, self.model
with open("%s/network.pickle" % foldername, "wb") as fp:
pickle.dump(self, fp)
self.model = tmp_model

@classmethod
def load(cls, foldername):
self._comm, tmp_comm = None, self._comm
self.compile_options, tmp_co = {}, self.compile_options
try:
with open("%s/network.pickle" % foldername, "wb") as fp:
pickle.dump(self, fp)
except:
raise
finally:
self.model = tmp_model
self._comm = tmp_comm
self.compile_options = tmp_co
if self.model and save_all:
self._build_intermediary_models()

## classmethod or method
def load(self, foldername=None):
"""
Load the network from a folder.
"""
with open("%s/network.pickle" % foldername, "rb") as fp:
net = pickle.load(fp)
net.load_model(foldername)
net.load_weights(foldername)
return net
if self is None or isinstance(self, str):
foldername = self
if foldername is None:
raise Exception("foldername is required")
net = Network("Temp")
net.load_model(foldername)
net.load_weights(foldername)
if os.path.isfile("%s/network.pickle" % foldername):
with open("%s/network.pickle" % foldername, "rb") as fp:
net = pickle.load(fp)
net._build_intermediary_models()
return net
else:
self.load_model(foldername)
self.load_weights(foldername)

def save_weights(self, foldername=None):
"""
Expand Down

0 comments on commit aa2e1c8

Please sign in to comment.