Skip to content

Commit

Permalink
Now all widgets/pictures are Jupyter Lab compatible; made dynamic_pic…
Browse files Browse the repository at this point in the history
…tures optional; net.picture() can now show targets/errors
  • Loading branch information
dsblank committed Sep 11, 2018
1 parent 15a1f60 commit 26d667f
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 78 deletions.
133 changes: 60 additions & 73 deletions conx/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,6 @@
from .dataset import Dataset, VirtualDataset
import conx.networks

try:
from IPython import get_ipython
except:
get_ipython = lambda: None

#------------------------------------------------------------------------

def as_sum(item):
Expand Down Expand Up @@ -326,15 +321,12 @@ def __init__(self, name: str, *sizes: int, load_config=True, debug=False,
self.epoch_count = 0
self.history = []
self.weight_history = {}
self.update_pictures = get_ipython() is not None
self._comm = None
self.model = None
self._level_ordering = None
self.prop_from_dict = {} ## FIXME: can be multiple paths
self.keras_functions = {}
self._svg_counter = 1
self._need_to_show_headings = True
self._initialized_javascript = False
# If simple feed-forward network:
for i in range(len(sizes)):
if i > 0:
Expand Down Expand Up @@ -447,8 +439,8 @@ def __getitem__(self, layer_name):
return self.layer_dict[layer_name]

def _repr_svg_(self):
return self.to_svg(show_errors=False, show_targets=False, svg_rotate=False,
svg_scale=None)
return self.to_svg(show_errors=False, show_targets=False,
svg_rotate=False, svg_scale=None)

def __repr__(self):
return "<Network name='%s' (%s)>" % (
Expand Down Expand Up @@ -594,7 +586,7 @@ def movie(self, function, movie_name=None, start=0, stop=None, step=1,

def picture(self, inputs=None, dynamic=False, rotate=False, scale=None,
show_errors=False, show_targets=False, format="html", class_id=None,
minmax=None, **kwargs):
minmax=None, targets=None, **kwargs):
"""
Create an SVG of the network given some inputs (optional).
Expand Down Expand Up @@ -629,6 +621,9 @@ def picture(self, inputs=None, dynamic=False, rotate=False, scale=None,
print("WARNING: class_id given but ignored", file=sys.stderr)
r = random.randint(1, 1000000)
class_id = "picture-static-%s-%s" % (self.name, r)
elif not dynamic_pictures_check():
print("WARNING: use dynamic_pictures() to allow dynamic pictures",
file=sys.stderr)
orig_rotate = self.config["svg_rotate"]
orig_show_errors = self.config["show_errors"]
orig_show_targets = self.config["show_targets"]
Expand All @@ -643,7 +638,7 @@ def picture(self, inputs=None, dynamic=False, rotate=False, scale=None,
elif len(self.dataset) == 0 and inputs is not None:
self.layers[0].minmax = (minimum(inputs), maximum(inputs))
## else, leave minmax as None
svg = self.to_svg(inputs=inputs, class_id=class_id, **kwargs)
svg = self.to_svg(inputs=inputs, class_id=class_id, targets=targets, **kwargs)
self.config["svg_rotate"] = orig_rotate
self.config["show_errors"] = orig_show_errors
self.config["show_targets"] = orig_show_targets
Expand Down Expand Up @@ -1776,11 +1771,8 @@ def propagate_from(self, layer_name, input, output_layer_names=None,
outputs.append([list(x) for x in prop_model.predict(inputs)][0])
## FYI: outputs not shaped
if update_pictures:
if not self._comm:
from ipykernel.comm import Comm
self._comm = Comm(target_name='conx_svg_control')
## Update from start to rest of graph
if self._comm.kernel:
if dynamic_pictures_check():
## viz this layer:
if self[layer_name].visible:
image = self[layer_name].make_image(inputs, config=self.config)
Expand All @@ -1789,7 +1781,7 @@ def propagate_from(self, layer_name, input, output_layer_names=None,
if self.config["svg_rotate"]:
class_id_name += "-rotated"
if self.debug: print("propagate_from 1: class_id_name:", class_id_name)
self._comm.send({'class': class_id_name, "xlink:href": data_uri})
dynamic_pictures_send({'class': class_id_name, "xlink:href": data_uri})
for output_layer_name in output_layer_names:
path = find_path(self, layer_name, output_layer_name)
if path is not None:
Expand All @@ -1808,7 +1800,7 @@ def propagate_from(self, layer_name, input, output_layer_names=None,
if self.config["svg_rotate"]:
class_id_name += "-rotated"
if self.debug: print("propagate_from 2: class_id_name:", class_id_name)
self._comm.send({'class': class_id_name, "xlink:href": data_uri})
dynamic_pictures_send({'class': class_id_name, "xlink:href": data_uri})
if sequence:
if isinstance(outputs, list):
outputs = [bank.tolist() for bank in outputs]
Expand Down Expand Up @@ -1837,7 +1829,7 @@ def display_component(self, vector, component, class_id=None, **opts):
config = copy.copy(self.config)
config.update(opts)
output_names = self.output_bank_order
if self._comm.kernel:
if dynamic_pictures_check():
for (target, layer_name) in zip(vector, output_names):
array = np.array(target)
if component == "targets":
Expand All @@ -1851,7 +1843,7 @@ def display_component(self, vector, component, class_id=None, **opts):
else:
class_id_name = "%s_%s" % (class_id, layer_name)
if self.debug: print("display_component: sending to class_id:", class_id_name + "_" + component)
self._comm.send({'class': class_id_name + "_" + component,
dynamic_pictures_send({'class': class_id_name + "_" + component,
"xlink:href": data_uri})

def propagate_to(self, layer_name, inputs, batch_size=32, class_id=None,
Expand Down Expand Up @@ -1890,10 +1882,7 @@ def propagate_to(self, layer_name, inputs, batch_size=32, class_id=None,
outputs = self[layer_name].model.predict(vector, batch_size=batch_size)
## output shaped below:
if update_pictures:
if not self._comm:
from ipykernel.comm import Comm
self._comm = Comm(target_name='conx_svg_control')
if self._comm.kernel:
if dynamic_pictures_check():
if update_path: ## update the whole path, from all inputs to the layer_name, if a path
## don't repeat any updates, so keep track of what you have done:
updated = set([])
Expand All @@ -1908,7 +1897,7 @@ def propagate_to(self, layer_name, inputs, batch_size=32, class_id=None,
if self.config["svg_rotate"]:
class_id_name += "-rotated"
if self.debug: print("propagate_to 1: sending to class_id_name:", class_id_name)
self._comm.send({'class': class_id_name, "xlink:href": data_uri})
dynamic_pictures_send({'class': class_id_name, "xlink:href": data_uri})
updated.add(input_layer_name)
path = find_path(self, input_layer_name, layer_name)
if path is not None:
Expand All @@ -1924,7 +1913,7 @@ def propagate_to(self, layer_name, inputs, batch_size=32, class_id=None,
if self.config["svg_rotate"]:
class_id_name += "-rotated"
if self.debug: print("propagate_to 2: sending to class_id_name:", class_id_name)
self._comm.send({'class': class_id_name, "xlink:href": data_uri})
dynamic_pictures_send({'class': class_id_name, "xlink:href": data_uri})
updated.add(layer.name)
else: # not the whole path, just to the layer_name
image = self._propagate_to_image(layer_name, inputs, sequence=sequence)
Expand All @@ -1936,7 +1925,7 @@ def propagate_to(self, layer_name, inputs, batch_size=32, class_id=None,
if self.config["svg_rotate"]:
class_id_name += "-rotated"
if self.debug: print("propagate_to 3: sending to class_id_name:", class_id_name)
self._comm.send({'class': class_id_name, "xlink:href": data_uri})
dynamic_pictures_send({'class': class_id_name, "xlink:href": data_uri})
## Shape the outputs:
if sequence:
if isinstance(outputs, list):
Expand Down Expand Up @@ -2020,11 +2009,8 @@ def propagate_to_features(self, layer_name, inputs, cols=5, resize=None, scale=1
if scale != 1.0:
image = image.resize((int(image.size[0] * scale), int(image.size[1] * scale)))
data_uri = self._image_to_uri(image)
if not self._comm:
from ipykernel.comm import Comm
self._comm = Comm(target_name='conx_svg_control')
if self._comm.kernel:
self._comm.send({'class': "%s_%s_feature%s" % (self.name, layer_name, i), "src": data_uri})
if dynamic_pictures_check():
dynamic_pictures_send({'class': "%s_%s_feature%s" % (self.name, layer_name, i), "src": data_uri})
self[layer_name].feature = orig_feature
else:
raise Exception("layer '%s' has no features" % layer_name)
Expand Down Expand Up @@ -2958,7 +2944,7 @@ def vshape(self, layer_name):
vshape = layer.get_output_shape()
return vshape

def _pre_process_struct(self, inputs, config, ordering):
def _pre_process_struct(self, inputs, config, ordering, targets):
"""
Determine sizes and pre-compute images.
"""
Expand All @@ -2970,6 +2956,16 @@ def _pre_process_struct(self, inputs, config, ordering):
max_height = 0
images = {}
image_dims = {}
## if targets, then need to propagate for error:
if targets is not None and self.model is not None:
outputs = self.propagate(inputs)
if len(self.output_bank_order) == 1:
targets = [targets]
errors = (np.array(outputs) - np.array(targets)).tolist()
else:
errors = []
for bank in range(len(self.output_bank_order)):
errors.append((np.array(outputs[bank]) - np.array(targets[bank])).tolist())
#######################################################################
## For each level:
#######################################################################
Expand Down Expand Up @@ -3006,6 +3002,7 @@ def _pre_process_struct(self, inputs, config, ordering):
if inputs is not None:
v = inputs
elif len(self.dataset.inputs) > 0 and not isinstance(self.dataset, VirtualDataset):
## don't change cache if virtual... could take some time to rebuild cache
v = self.dataset.inputs[0]
else:
if self.num_input_layers > 1:
Expand All @@ -3016,20 +3013,37 @@ def _pre_process_struct(self, inputs, config, ordering):
in_layer = [layer for layer in self.layers if layer.kind() == "input"][0]
v = in_layer.make_dummy_vector()
if self[layer_name].model:
orig_svg_rotate = self.config["svg_rotate"]
self.config["svg_rotate"] = config["svg_rotate"]
try:
orig_svg_rotate = self.config["svg_rotate"]
self.config["svg_rotate"] = config["svg_rotate"]
image = self._propagate_to_image(layer_name, v)
self.config["svg_rotate"] = orig_svg_rotate
except:
image = self[layer_name].make_image(np.array(self[layer_name].make_dummy_vector()), config=config)
self.config["svg_rotate"] = orig_svg_rotate
else:
if not warned:
print("WARNING: network is uncompiled; activations cannot be visualized", file=sys.stderr)
warned = True
image = self[layer_name].make_image(np.array(self[layer_name].make_dummy_vector()), config=config)
(width, height) = image.size
images[layer_name] = image ## little image
if self[layer_name].kind() == "output":
if self.model is not None and targets is not None:
## Target image, targets set above:
target_colormap = self[layer_name].colormap
target_bank = targets[self.output_bank_order.index(layer_name)]
target_array = np.array(target_bank)
target_image = self[layer_name].make_image(target_array, target_colormap, config)
## Error image, error set above:
error_colormap = get_error_colormap()
error_bank = errors[self.output_bank_order.index(layer_name)]
error_array = np.array(error_bank)
error_image = self[layer_name].make_image(error_array, error_colormap, config)
images[layer_name + "_errors"] = error_image
images[layer_name + "_targets"] = target_image
else:
images[layer_name + "_errors"] = image
images[layer_name + "_targets"] = image
### Layer settings:
if self[layer_name].image_maxdim:
image_maxdim = self[layer_name].image_maxdim
Expand Down Expand Up @@ -3090,9 +3104,9 @@ def _find_spacing(self, row, ordering, max_width):
"""
return max_width / (len(ordering[row]) + 1)

def build_struct(self, inputs, class_id, config):
def build_struct(self, inputs, class_id, config, targets):
ordering = list(reversed(self._get_level_ordering())) # list of names per level, input to output
max_width, max_height, row_heights, images, image_dims = self._pre_process_struct(inputs, config, ordering)
max_width, max_height, row_heights, images, image_dims = self._pre_process_struct(inputs, config, ordering, targets)
### Now that we know the dimensions:
struct = []
cheight = config["border_top"] # top border
Expand All @@ -3104,7 +3118,7 @@ def build_struct(self, inputs, class_id, config):
# draw the row of targets:
cwidth = 0
for (layer_name, anchor, fname) in ordering[0]: ## no anchors in output
image = images[layer_name]
image = images[layer_name + "_targets"]
(width, height) = image_dims[layer_name]
cwidth += (spacing - width/2)
struct.append(["image_svg", {"name": layer_name + "_targets",
Expand Down Expand Up @@ -3139,7 +3153,7 @@ def build_struct(self, inputs, class_id, config):
# draw the row of errors:
cwidth = 0
for (layer_name, anchor, fname) in ordering[0]: # no anchors in output
image = images[layer_name]
image = images[layer_name + "_errors"]
(width, height) = image_dims[layer_name]
cwidth += (spacing - (width/2))
struct.append(["image_svg", {"name": layer_name + "_errors",
Expand Down Expand Up @@ -3441,8 +3455,9 @@ def build_struct(self, inputs, class_id, config):
cheight += config["border_bottom"]
### DONE!
## Draw live/static sign
if (class_id is None):
label = "*" # lightning bold, dynamic image
if (class_id is None and dynamic_pictures_check()):
# dynamic image:
label = "*"
if config["svg_rotate"]:
struct.append(["label_svg", {"x": 10,
"y": cheight - 10,
Expand Down Expand Up @@ -3521,33 +3536,7 @@ def build_struct(self, inputs, class_id, config):
}])
return struct

def _initialize_javascript(self):
from IPython.display import Javascript, display
js = """
require(['base/js/namespace'], function(Jupyter) {
Jupyter.notebook.kernel.comm_manager.register_target('conx_svg_control', function(comm, msg) {
comm.on_msg(function(msg) {
console.log("received!")
console.log(msg)
var data = msg["content"]["data"];
var images = document.getElementsByClassName(data["class"]);
for (var i = 0; i < images.length; i++) {
if (data["xlink:href"]) {
var xlinkns="http://www.w3.org/1999/xlink";
images[i].setAttributeNS(xlinkns, "href", data["xlink:href"]);
}
if (data["src"]) {
images[i].setAttributeNS(null, "src", data["src"]);
}
}
});
});
});
"""
display(Javascript(js))
self._initialized_javascript = True

def to_svg(self, inputs=None, class_id=None, **kwargs):
def to_svg(self, inputs=None, class_id=None, targets=None, **kwargs):
"""
opts - temporary override of config
Expand All @@ -3567,7 +3556,7 @@ def to_svg(self, inputs=None, class_id=None, **kwargs):
# defaults:
config = copy.copy(self.config)
config.update(kwargs)
struct = self.build_struct(inputs, class_id, config)
struct = self.build_struct(inputs, class_id, config, targets)
### Define the SVG strings:
image_svg = """<rect x="{{rx}}" y="{{ry}}" width="{{rw}}" height="{{rh}}" style="fill:none;stroke:{border_color};stroke-width:{border_width}"/><image id="{netname}_{{name}}_{{svg_counter}}" class="{netname}_{{name}}" x="{{x}}" y="{{y}}" height="{{height}}" width="{{width}}" preserveAspectRatio="none" image-rendering="optimizeSpeed" xlink:href="{{image}}"><title>{{tooltip}}</title></image>""".format(
**{
Expand Down Expand Up @@ -3621,8 +3610,6 @@ def to_svg(self, inputs=None, class_id=None, **kwargs):
t = templates[template_name]
svg += t.format(**dict)
svg += """</svg></g></svg>"""
if (not self._initialized_javascript and get_ipython()):
self._initialize_javascript()
return svg

def _render_curve(self, start, struct, end_svg):
Expand All @@ -3639,7 +3626,7 @@ def _render_curve(self, start, struct, end_svg):
dict["drawn"] = True
end_html = end_svg.format(**start)
if len(points) == 2: ## direct, no anchors, no curve:
svg_html = """<path d="M {sx} {sy} L {ex} {ey}" """.format(**{
svg_html = """<path d="M {sx} {sy} L {ex} {ey} """.format(**{
"sx": points[-1][0],
"sy": points[-1][1],
"ex": points[0][0],
Expand Down

0 comments on commit 26d667f

Please sign in to comment.