Skip to content

Commit

Permalink
WIP: working on drawing better connection paths, see #141
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Sep 9, 2018
1 parent ed9ed52 commit 3f9b0b0
Showing 1 changed file with 70 additions and 11 deletions.
81 changes: 70 additions & 11 deletions conx/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -1541,7 +1541,7 @@ def get_dataset(self, dataset_name):
>>> net.get_dataset("vmnist")
"""
self.set_dataset(Dataset.get(dataset_name))

def set_dataset(self, dataset):
"""
Set the dataset for the network.
Expand Down Expand Up @@ -3598,16 +3598,19 @@ def _get_level_ordering(self):
ordering = []
for i in range(max_level + 1): # input to output
layer_names = [layer.name for layer in self.layers if levels[layer.name] == i]
ordering.append([(name, False, None) for name in layer_names]) # (going_to/layer_name, anchor, coming_from)
ordering.append([(name, False, [x.name for x in self[name].incoming_connections])
for name in layer_names]) # (going_to/layer_name, anchor, coming_from)
## promote all output banks to last row:
for level in range(len(ordering)): # input to output
tuples = ordering[level]
for (name, anchor, none) in tuples[:]: # go through copy
index = 0
for (name, anchor, none) in tuples[:]:
if self[name].kind() == "output":
## move it to last row
## find it and remove
index = tuples.index((name, anchor, None))
ordering[-1].append(tuples.pop(index))
else:
index += 1
## insert anchor points for any in next level
## that doesn't go to a bank in this level
order_cache = {}
Expand All @@ -3616,24 +3619,80 @@ def _get_level_ordering(self):
for (name, anchor, fname) in tuples:
if anchor:
## is this in next? if not add it
next_level = [(n, hfname) for (n, anchor, hfname) in ordering[level + 1]]
if (name, None) not in next_level and (name, fname) not in next_level:
next_level = [(n, anchor) for (n, anchor, hfname) in ordering[level + 1]]
if (name, False) not in next_level: ## actual layer not in next level
ordering[level + 1].append((name, True, fname)) # add anchor point
else:
pass ## finally!
else:
## if next level doesn't contain an outgoing
## connection, add it to next level as anchor point
for layer in self[name].outgoing_connections:
next_level = [(n,fname) for (n, anchor, fname) in ordering[level + 1]]
if (layer.name, None) not in next_level:
next_level = [(n, anchor) for (n, anchor, fname) in ordering[level + 1]]
if (layer.name, False) not in next_level:
ordering[level + 1].append((layer.name, True, name)) # add anchor point
## replace level with sorted level:
lev = sorted([(self._column_order(fname if anchor else name, order_cache), name, anchor, fname)
for (name, anchor, fname) in ordering[level]])
ordering[level] = [(name, anchor, fname) for (index, name, anchor, fname) in lev]
#lev = sorted([(self._column_order(fname if anchor else name, order_cache), name, anchor, fname)
# for (name, anchor, fname) in ordering[level]])
#ordering[level] = [(name, anchor, fname) for (index, name, anchor, fname) in lev]
## wait until all assembled before ordering to use anchors:
ordering = self._optimize_ordering(ordering)
return ordering

def _optimize_ordering(self, ordering):
def perms(l):
return list(itertools.permutations(l))

def distance(xy1, xy2):
return math.sqrt((xy1[0] - xy2[0]) ** 2 + (xy1[1] - xy2[1]) ** 2)

permutations = list(itertools.product(*[perms(x) for x in ordering]))
#print("permutations:", len(permutations))
if True: #len(permutations) < 1000: ## globally minimize
## measure arrow distances for them all and find the shortest:
best = (10000000, None)
for ordering in permutations:
sum = 0.0
for level_num in range(1, len(ordering)):
level = ordering[level_num]
plevel = ordering[level_num - 1]
col1 = 0
for bank in level: # starts at level 1
end_name, anchor, start_names = bank
if anchor:
start_names = [start_names] # put in list
else:
for name in start_names:
col2 = 0
for prev_bank in ordering[level_num - 1]:
if prev_bank[0] == name:
sum += distance((col1/(len(level) + 1), 0),
(col2/(len(plevel) + 1), .1))
col2 += 1
col1 += 1
if sum < best[0]:
best = (sum, ordering)
return best[1]
else: # locally minimize, between layers:
return ordering

# ((('input', False, []),),
# (('hidden1', False, ['input']),
# ('hidden2', True, 'input'),
# ('hidden3', True, 'input'),
# ('output', True, 'input')),
# (('hidden3', True, 'hidden1'),
# ('output', True, 'hidden1'),
# ('hidden2', False, ['input', 'hidden1']),
# ('hidden3', True, 'input'),
# ('output', True, 'input')),
# (('output', True, 'hidden1'),
# ('hidden3', False, ['input', 'hidden1', 'hidden2']),
# ('output', True, 'hidden2'),
# ('output', True, 'input')),
# (('output', False, ['input', 'hidden1', 'hidden2', 'hidden3']),))


def _column_order(self, layer_name, order_cache):
"""
Get the column order of a layer_name. Note that in this
Expand Down

0 comments on commit 3f9b0b0

Please sign in to comment.