Skip to content

Commit

Permalink
Added select=slice to net.evaluate() and net.evaluate_and_label()
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Sep 13, 2018
1 parent 648af24 commit 4e5b77a
Showing 1 changed file with 33 additions and 18 deletions.
51 changes: 33 additions & 18 deletions conx/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,7 @@ def reset(self, clear=False, **overrides):

def evaluate(self, batch_size=None, show=False, show_inputs=True, show_targets=True,
kverbose=0, sample_weight=None, steps=None, tolerance=None, force=False,
max_col_width=15):
max_col_width=15, select=None):
"""
Evaluate the train and/or test sets.
Expand All @@ -879,8 +879,12 @@ def evaluate(self, batch_size=None, show=False, show_inputs=True, show_targets=T
return
if tolerance is not None:
self.tolerance = tolerance
print("%s:" % self.name)
if 0 < self.dataset._split <= 1:
if select is not None:
if isinstance(select, int):
select = (select,)
slice_select = slice(*select) if select is not None else slice(len(self.dataset))
print("%s:" % self.name) ## network name
if 0 < self.dataset._split <= 1 and select is None:
size, num_train, num_test = self.dataset._get_split_sizes()
results = self.model.evaluate(self.dataset._inputs, self.dataset._targets,
batch_size=batch_size, verbose=kverbose,
Expand Down Expand Up @@ -910,15 +914,23 @@ def evaluate(self, batch_size=None, show=False, show_inputs=True, show_targets=T
sample_weight=sample_weight, steps=steps)
for i in range(len(self.model.metrics_names)):
print(" %15s: %10s" % (self.model.metrics_names[i], self.pf(results[i])))

else: # all data, no split:
print("All Data Results:")
else: # all (or select data):
if select is None:
print("All Data Results:")
else:
print("Selected Data Results: range(%s)" % (", ".join([str(n) for n in select])))
if show:
self._evaluate_range(slice(len(self.dataset)), show_inputs, show_targets, batch_size,
self._evaluate_range(slice_select, show_inputs, show_targets, batch_size,
kverbose, sample_weight, steps, force, max_col_width)
results = self.model.evaluate(self.dataset._inputs, self.dataset._targets,
batch_size=batch_size, verbose=kverbose,
sample_weight=sample_weight, steps=steps)
if select is None:
results = self.model.evaluate(self.dataset._inputs, self.dataset._targets,
batch_size=batch_size, verbose=kverbose,
sample_weight=sample_weight, steps=steps)
else:
results = self.model.evaluate([bank[slice_select] for bank in self.dataset._inputs],
[bank[slice_select] for bank in self.dataset._targets],
batch_size=batch_size, verbose=kverbose,
sample_weight=sample_weight, steps=steps)
for i in range(len(self.model.metrics_names)):
print(" %15s: %10s" % (self.model.metrics_names[i], self.pf(results[i])))

Expand Down Expand Up @@ -1016,25 +1028,28 @@ def split_heading(name, fill=""):
print("-" * column_widths[c], end="---")
print()

def evaluate_and_label(self, batch_size=32, tolerance=None):
def evaluate_and_label(self, batch_size=32, tolerance=None,
select=None):
"""
Test the network on the dataset, and categorize the results.
"""
tolerance = tolerance if tolerance is not None else self.tolerance
length = len(self.dataset.train_targets)
if self.dataset._split == 1.0: ## special case; use entire set
if select is not None:
if isinstance(select, int):
select = (select,)
slice_select = slice(*select) if select is not None else slice(len(self.dataset))
if 0 < self.dataset._split < 1.0 and select is None: ## special case; use entire set
inputs = self.dataset._inputs
targets = self.dataset._targets
else:
## need to split; check format based on output banks:
targets = [column[:length] for column in self.dataset._targets]
inputs = [column[:length] for column in self.dataset._inputs]
inputs = [column[slice_select] for column in self.dataset._inputs]
targets = [column[slice_select] for column in self.dataset._targets]
outputs = self.model.predict(inputs, batch_size=batch_size)
correct = self.compute_correct([outputs], targets, tolerance)
categories = {}
for i in range(length):
for c, i in enumerate(range(len(self.dataset))[slice_select]):
label_i = self.dataset.labels[i] if i < len(self.dataset.labels) else "Unlabeled"
label = "%s (%s)" % (label_i, "correct" if correct[i] else "wrong")
label = "%s (%s)" % (label_i, "correct" if correct[c] else "wrong")
if not label in categories:
categories[label] = []
categories[label].append(self.dataset.inputs[i])
Expand Down

0 comments on commit 4e5b77a

Please sign in to comment.