Skip to content

Commit

Permalink
Better warning system for networks an datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Sep 13, 2018
1 parent 6ee7221 commit 9cc52b6
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 24 deletions.
36 changes: 24 additions & 12 deletions conx/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,8 @@ def __init__(self,
Defaults inputs and targets are given as a list of tuple shapes,
one shape per bank.
"""
self.warnings = {}
self.warn_categories = {}
self.network = network
self.name = name
self.description = description
Expand Down Expand Up @@ -740,7 +742,6 @@ def clear(self):
"""
Remove all of the inputs/targets.
"""
self._warning_set = False
self._inputs = []
self._targets = []
self._labels = []
Expand Down Expand Up @@ -1194,7 +1195,7 @@ def _cache_values_actual(self):
self._target_shapes = [shape(x[0]) for x in self._targets]
# Final checks:
if len(self.inputs) != len(self.targets):
print("WARNING: inputs/targets lengths do not match", file=sys.stderr)
self.warn_once("WARNING: inputs/targets lengths do not match")
if self.network:
self.network.test_dataset_ranges()
self._verify_network_dataset_match()
Expand All @@ -1208,24 +1209,35 @@ def _verify_network_dataset_match(self):
## check to see if number of input banks match
if len(self.network.input_bank_order) != self._num_input_banks():
warning = True
print("WARNING: number of dataset input banks != network input banks in network '%s'" % self.network.name,
file=sys.stderr)
self.warn_once("WARNING: number of dataset input banks != network input banks in network '%s'" % self.network.name, "VERIFY")
if len(self.inputs) > 0 and not isinstance(self, VirtualDataset):
try:
self.network.propagate(self.inputs[0])
except:
warning = True
print("WARNING: dataset does not yet work with network '%s'" % self.network.name,
file=sys.stderr)
self.warn_once("WARNING: dataset does not yet work with network '%s'" % self.network.name, "VERIFY")
## check to see if number of output banks match
if len(self.network.output_bank_order) != self._num_target_banks():
warning = True
print("WARNING: number of dataset target banks != network output banks in network '%s'" % self.network.name,
file=sys.stderr)
if self._warning_set and not warning:
print("INFO: dataset now works with network '%s'" % self.network.name,
file=sys.stderr)
self._warning_set = warning
self.warn_once("WARNING: number of dataset target banks != network output banks in network '%s'" % self.network.name, "VERIFY")
if not warning and self.warned("VERIFY"):
self.warn_once("INFO: dataset now works with network '%s'" % self.network.name)

def warned(self, category):
"""
Has the user been warned about this category of error before?
"""
return category in self.warn_categories

def warn_once(self, message, category=None):
"""
Warning the user just once about this particular message.
"""
if category:
self.warn_categories[category] = True
if message not in self.warnings:
print(message, file=sys.stderr)
self.warnings[message] = True

def set_targets_from_inputs(self, f=None, input_bank=0, target_bank=0):
"""
Expand Down
27 changes: 15 additions & 12 deletions conx/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ class Network():

def __init__(self, name: str, *sizes: int, load_config=True, debug=False,
build_propagate_from_models=True, **config: Any):
self.warnings = {}
self.NETWORKS = {name: function for (name, function) in
inspect.getmembers(conx.networks, inspect.isfunction)}
if not isinstance(name, str):
Expand Down Expand Up @@ -1224,8 +1225,8 @@ def test_dataset_ranges(self):
return # nothing to test
for index in range(len(self.dataset._targets)):
if len(self.dataset._targets[index].shape) > 2:
print("WARNING: network '%s' target bank #%s has a multi-dimensional shape; is this correct?" %
(self.name, index), file=sys.stderr)
self.warn_once("WARNING: network '%s' target bank #%s has a multi-dimensional shape; is this correct?" %
(self.name, index))
for index in range(len(self.output_bank_order)):
layer_name = self.output_bank_order[index]
if self[layer_name].activation == "linear":
Expand All @@ -1234,12 +1235,12 @@ def test_dataset_ranges(self):
# test dataset min to see if in range of act output:
if self[layer_name].activation is not None:
if not (lmin <= self.dataset._targets_range[index][0] <= lmax):
print("WARNING: output bank '%s' has activation function, '%s', that is not consistent with minimum value of targets" %
(layer_name, self[layer_name].activation), file=sys.stderr)
self.warn_once("WARNING: output bank '%s' has activation function, '%s', that is not consistent with minimum value of targets" %
(layer_name, self[layer_name].activation))
# test dataset min to see if in range of act output:
if not (lmin <= self.dataset._targets_range[index][1] <= lmax):
print("WARNING: output bank '%s' has activation function, '%s', that is not consistent with maximum value of targets" %
(layer_name, self[layer_name].activation), file=sys.stderr)
self.warn_once("WARNING: output bank '%s' has activation function, '%s', that is not consistent with maximum value of targets" %
(layer_name, self[layer_name].activation))

def train(self, epochs=1, accuracy=None, error=None, batch_size=32,
report_rate=1, verbose=1, kverbose=0, shuffle=True, tolerance=None,
Expand Down Expand Up @@ -2229,8 +2230,8 @@ def plot_layer_weights(self, layer_name, units='all', wrange=None, wmin=None, wm
aspect_ratio = max(rows,cols)/min(rows,cols)
#print("aspect_ratio is", aspect_ratio)
if aspect_ratio > 50: # threshold may need further refinement
print("WARNING: using a visual display shape of (%d, %d), which may be hard to see."
% (rows, cols), file=sys.stderr)
self.warn_once("WARNING: using a visual display shape of (%d, %d), which may be hard to see."
% (rows, cols))
print("You can use vshape=(rows, cols) to specify a different display shape.")
if not isinstance(wmin, (numbers.Number, type(None))):
raise Exception("wmin: expected a number or None but got %s" % (wmin,))
Expand Down Expand Up @@ -2970,7 +2971,6 @@ def _pre_process_struct(self, inputs, config, ordering, targets):
"""
Determine sizes and pre-compute images.
"""
warned = False
### find max_width, image_dims, and row_height
# Go through and build images, compute max_width:
row_heights = []
Expand Down Expand Up @@ -3043,9 +3043,7 @@ def _pre_process_struct(self, inputs, config, ordering, targets):
image = self[layer_name].make_image(np.array(self[layer_name].make_dummy_vector()), config=config)
self.config["svg_rotate"] = orig_svg_rotate
else:
if not warned:
print("WARNING: network is uncompiled; activations cannot be visualized", file=sys.stderr)
warned = True
self.warn_once("WARNING: network is uncompiled; activations cannot be visualized")
image = self[layer_name].make_image(np.array(self[layer_name].make_dummy_vector()), config=config)
(width, height) = image.size
images[layer_name] = image ## little image
Expand Down Expand Up @@ -4242,6 +4240,11 @@ def update_layer_from_config(self, layer):
for item in self.config["config_layers"][layer.name]:
setattr(layer, item, self.config["config_layers"][layer.name][item])

def warn_once(self, message):
if message not in self.warnings:
print(message, file=sys.stderr)
self.warnings[message] = True

class _InterruptHandler():
"""
Class for handling interrupts so that state is not left
Expand Down

0 comments on commit 9cc52b6

Please sign in to comment.