Skip to content

Commit

Permalink
Completed redrawing connection arrows, fixes #141
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Sep 10, 2018
1 parent 3f9b0b0 commit 01c4282
Showing 1 changed file with 61 additions and 63 deletions.
124 changes: 61 additions & 63 deletions conx/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def __init__(self, name: str, *sizes: int, load_config=True, debug=False,
self.update_pictures = get_ipython() is not None
self._comm = None
self.model = None
self._level_ordering = None
self.prop_from_dict = {} ## FIXME: can be multiple paths
self.keras_functions = {}
self._svg_counter = 1
Expand Down Expand Up @@ -1597,6 +1598,7 @@ def set_activation(self, layer_name, activation):
filename = tf.name
self.model.save(filename)
self.model = load_model(filename)
self._level_ordering = None
else:
raise Exception("can't change activation until after compile")

Expand Down Expand Up @@ -2707,6 +2709,7 @@ def build_model(self, starting_layers=None):
output_k_layers = self._get_output_ks_in_order()
input_k_layers = self._get_input_ks_in_order(self.input_bank_order)
self.model = keras.models.Model(inputs=input_k_layers, outputs=output_k_layers)
self._level_ordering = None
for layer in self.layers:
layer.keras_layer = self._find_keras_layer(layer.name)

Expand Down Expand Up @@ -2753,6 +2756,7 @@ def delete_layer(self, layer_name):
layer = self[layer_name]
self._delete_layer_from_connections(layer)
self.model = None
self._level_ordering = None

def _delete_layer_from_connections(self, layer):
## Remove layer.outgoing_connections to deleted layer:
Expand Down Expand Up @@ -3587,6 +3591,8 @@ def _get_level_ordering(self):
If anchor is True, this is just an anchor point.
"""
if self._level_ordering is not None:
return self._level_ordering
## First, get a level for all layers:
levels = {}
for layer in topological_sort(self, self.layers):
Expand Down Expand Up @@ -3631,12 +3637,7 @@ def _get_level_ordering(self):
next_level = [(n, anchor) for (n, anchor, fname) in ordering[level + 1]]
if (layer.name, False) not in next_level:
ordering[level + 1].append((layer.name, True, name)) # add anchor point
## replace level with sorted level:
#lev = sorted([(self._column_order(fname if anchor else name, order_cache), name, anchor, fname)
# for (name, anchor, fname) in ordering[level]])
#ordering[level] = [(name, anchor, fname) for (index, name, anchor, fname) in lev]
## wait until all assembled before ordering to use anchors:
ordering = self._optimize_ordering(ordering)
self._level_ordering = ordering = self._optimize_ordering(ordering)
return ordering

def _optimize_ordering(self, ordering):
Expand All @@ -3646,77 +3647,73 @@ def perms(l):
def distance(xy1, xy2):
return math.sqrt((xy1[0] - xy2[0]) ** 2 + (xy1[1] - xy2[1]) ** 2)

permutations = list(itertools.product(*[perms(x) for x in ordering]))
#print("permutations:", len(permutations))
if True: #len(permutations) < 1000: ## globally minimize
def find_start(cend, canchor, name, plevel):
"""
Return position and weight of link from cend/name to
col in previous level.
"""
col = 1
for bank in plevel:
pend, panchor, pstart_names = bank
if (name == pend):
if (not panchor and not canchor):
weight = 10.0
else:
weight = 1.0
return col, weight
elif cend == pend and name == pstart_names:
return col, 5.0
col += 1
raise Exception("connecting layer not found!")

## First level needs to be in bank_order, and cannot permutate:
ordering[0] = [(bank_name, False, []) for bank_name in self.input_bank_order]
permutations = [ordering[0]] + list(itertools.product(*[perms(x) for x in ordering[1:]]))
if len(permutations) < 70000: ## globally minimize
## measure arrow distances for them all and find the shortest:
best = (10000000, None)
best = (10000000, None, None)
for ordering in permutations:
sum = 0.0
for level_num in range(1, len(ordering)):
level = ordering[level_num]
plevel = ordering[level_num - 1]
col1 = 0
col1 = 1
for bank in level: # starts at level 1
end_name, anchor, start_names = bank
if anchor:
start_names = [start_names] # put in list
else:
for name in start_names:
col2 = 0
for prev_bank in ordering[level_num - 1]:
if prev_bank[0] == name:
sum += distance((col1/(len(level) + 1), 0),
(col2/(len(plevel) + 1), .1))
col2 += 1
cend, canchor, cstart_names = bank
if canchor:
cstart_names = [cstart_names] # put in list
for name in cstart_names:
col2, weight = find_start(cend, canchor, name, plevel)
dist = distance((col1/(len(level) + 1), 0),
(col2/(len(plevel) + 1), .1)) * weight
sum += dist
col1 += 1
if sum < best[0]:
best = (sum, ordering)
return best[1]
else: # locally minimize, between layers:
del permutations
for level_num in range(1, len(ordering)):
best = (10000000, None, None)
plevel = ordering[level_num - 1]
for level in itertools.permutations(ordering[level_num]):
sum = 0.0
col1 = 1
for bank in level: # starts at level 1
cend, canchor, cstart_names = bank
if canchor:
cstart_names = [cstart_names] # put in list
for name in cstart_names:
col2, weight = find_start(cend, canchor, name, plevel)
dist = distance((col1/(len(level) + 1), 0),
(col2/(len(plevel) + 1), .1)) * weight
sum += dist
col1 += 1
if sum < best[0]:
best = (sum, level)
ordering[level_num] = best[1]
return ordering

# ((('input', False, []),),
# (('hidden1', False, ['input']),
# ('hidden2', True, 'input'),
# ('hidden3', True, 'input'),
# ('output', True, 'input')),
# (('hidden3', True, 'hidden1'),
# ('output', True, 'hidden1'),
# ('hidden2', False, ['input', 'hidden1']),
# ('hidden3', True, 'input'),
# ('output', True, 'input')),
# (('output', True, 'hidden1'),
# ('hidden3', False, ['input', 'hidden1', 'hidden2']),
# ('output', True, 'hidden2'),
# ('output', True, 'input')),
# (('output', False, ['input', 'hidden1', 'hidden2', 'hidden3']),))


def _column_order(self, layer_name, order_cache):
"""
Get the column order of a layer_name. Note that in this
version, the path grows on each split, and never shrinks.
"""
## special case to get started:
if layer_name in self.input_bank_order:
order_cache[layer_name] = [self.input_bank_order.index(layer_name)]
## Get path to this node:
path = order_cache[layer_name]
## Put next layer in cache:
if len(self[layer_name].outgoing_connections) > 1: ## split!
count = 0
for layer in self[layer_name].outgoing_connections:
order_cache[layer.name] = path + [count]
count += 1
elif len(self[layer_name].outgoing_connections) == 0: ## output layer
pass
else:
## just one output, no split:
order_cache[self[layer_name].outgoing_connections[0].name] = path
## should we worry about merges at all?
return order_cache[layer_name]

def describe_connection_to(self, layer1, layer2):
"""
Returns a textual description of the weights for the SVG tooltip.
Expand Down Expand Up @@ -3798,6 +3795,7 @@ def load_model(self, dir=None, filename=None):
if filename is None:
filename = "model.h5"
self.model = load_model(os.path.join(dir, filename))
self._level_ordering = None
if self.compile_options:
self.reset()

Expand Down

0 comments on commit 01c4282

Please sign in to comment.