Skip to content

Commit

Permalink
Truncate when converting to images; don't show error/targets by defau…
Browse files Browse the repository at this point in the history
…lt; Network.get_weights(name); all widget pages are same length; default colormap for all
  • Loading branch information
dsblank committed Aug 11, 2017
1 parent f43c9b1 commit fa62dba
Show file tree
Hide file tree
Showing 4 changed files with 169 additions and 51 deletions.
8 changes: 4 additions & 4 deletions conx/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(self, name, *args, **params):
self.vshape = None
self.image_maxdim = None
self.visible = True
self.colormap = None
self.colormap = "RdGy"
self.minmax = None
self.model = None
self.decode_model = None
Expand Down Expand Up @@ -223,7 +223,7 @@ def make_image(self, vector, config={}):
minmax = config.get("minmax")
if minmax is None:
minmax = self.get_minmax(vector)
vector = self.scale_output_for_image(vector, minmax)
vector = self.scale_output_for_image(vector, minmax, truncate=True)
if len(vector.shape) == 1:
vector = vector.reshape((1, vector.shape[0]))
size = config["pixels_per_unit"]
Expand All @@ -242,12 +242,12 @@ def make_image(self, vector, config={}):
image = image.resize((new_height, new_width))
return image

def scale_output_for_image(self, vector, minmax):
def scale_output_for_image(self, vector, minmax, truncate=False):
"""
Given an activation name (or something else) and an output
vector, scale the vector.
"""
return rescale_numpy_array(vector, minmax, (0,255), 'uint8')
return rescale_numpy_array(vector, minmax, (0,255), 'uint8', truncate=truncate)

def make_dummy_vector(self):
"""
Expand Down
28 changes: 18 additions & 10 deletions conx/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,10 @@ def __init__(self, name: str, *sizes: int, **config: Any):
"arrow_width": "2",
"border_width": "2",
"border_color": "blue",
"show_targets": True,
"show_targets": False,
"show_errors": False,
"minmax": None,
"colormap": None,
"show_errors": True,
"pixels_per_unit": 1,
"pp_max_length": 20,
"pp_precision": 1,
Expand Down Expand Up @@ -1002,6 +1002,14 @@ def get_test_target(self, i):
targets.append(list(self.test_targets[c][i]))
return targets

def get_weights(self, layer_name):
"""
Get the weights from the model in an easy to read format.
"""
weights = [layer.get_weights() for layer in self.model.layers
if layer_name == layer.name][0]
return [m.tolist() for m in weights]

def propagate(self, input, batch_size=32):
"""
Propagate an input (in human API) through the network.
Expand Down Expand Up @@ -1777,7 +1785,7 @@ def get_train_targets_length(self):
else:
return self.train_targets.shape[0]

def dashboard(self, width="100%", max_height="550px", iwidth="800px"): ## FIXME: iwidth hack
def dashboard(self, width="100%", height="550px", iwidth="960px"): ## FIXME: iwidth hack
"""
Build the dashboard for Jupyter widgets. Requires running
in a notebook/jupyterlab.
Expand Down Expand Up @@ -1855,7 +1863,7 @@ def prop_one(button):
update_slider_control({"name": "value"})

net_svg = HTML(value=self.build_svg(), layout=Layout(
width=width, height="100%", max_height=max_height, overflow_x='auto',
width=width, height=height, overflow_x='auto',
justify_content="center"))
button_begin = Button(icon="fast-backward", layout=Layout(width='100%'))
button_prev = Button(icon="backward", layout=Layout(width='100%'))
Expand Down Expand Up @@ -1894,12 +1902,12 @@ def prop_one(button):

# Put them together:
control = VBox([control_select, control_slider, control_buttons], layout=Layout(width='100%'))
net_page = VBox([net_svg, control], layout=Layout(width='100%'))
graph_page = VBox()
analysis_page = VBox()
camera_page = VBox([Button(description="Turn on webcamera")])
help_page = HTML('<iframe style="width: %s" src="https://conx.readthedocs.io" width="100%%" height="%s"></frame>' % (iwidth, max_height),
layout=Layout(width="100%"))
net_page = VBox([net_svg, control], layout=Layout(width='100%', height=height))
graph_page = VBox(layout=Layout(width='100%', height=height))
analysis_page = VBox(layout=Layout(width='100%', height=height))
camera_page = VBox([Button(description="Turn on webcamera")], layout=Layout(width='100%', height=height))
help_page = HTML('<iframe style="width: %s" src="https://conx.readthedocs.io" width="100%%" height="%s"></frame>' % (iwidth, height),
layout=Layout(width="100%", height=height))
net_page.on_displayed(lambda widget: update_slider_control({"name": "value"}))
tabs = [("Network", net_page), ("Graphs", graph_page), ("Analysis", analysis_page),
("Camera", camera_page), ("Help", help_page)]
Expand Down
8 changes: 6 additions & 2 deletions conx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# Boston, MA 02110-1301 USA

import numbers
import numpy as np
from keras.utils import to_categorical

#------------------------------------------------------------------------
Expand Down Expand Up @@ -84,15 +85,18 @@ def valid_vshape(x):
# vshape must be a single int or a 2-dimensional tuple
return valid_shape(x) and (isinstance(x, numbers.Integral) or len(x) == 2)

def rescale_numpy_array(a, old_range, new_range, new_dtype):
def rescale_numpy_array(a, old_range, new_range, new_dtype, truncate=False):
"""
Given a vector, old min/max, a new min/max and a numpy type,
create a new vector scaling the old values.
"""
assert isinstance(old_range, (tuple, list)) and isinstance(new_range, (tuple, list))
old_min, old_max = old_range
if a.min() < old_min or a.max() > old_max:
raise Exception('array values are outside range %s' % (old_range,))
if truncate:
a = np.clip(a, old_min, old_max)
else:
raise Exception('array values are outside range %s' % (old_range,))
new_min, new_max = new_range
old_delta = old_max - old_min
new_delta = new_max - new_min
Expand Down

0 comments on commit fa62dba

Please sign in to comment.