Skip to content

Commit

Permalink
Abstract progress_bar; show totals in evaluate()
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Sep 14, 2018
1 parent c9809d3 commit 6d4a3ed
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
3 changes: 3 additions & 0 deletions conx/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,6 +902,7 @@ def evaluate(self, batch_size=None, show=False, show_inputs=True, show_targets=T
sample_weight=sample_weight, steps=steps)
for i in range(len(self.model.metrics_names)):
print(" %15s: %10s" % (self.model.metrics_names[i], self.pf(results[i])))
print(" %15s: %10s" % ("Total", num_train))
print()
print("Testing Data Results:")
if show:
Expand All @@ -915,6 +916,7 @@ def evaluate(self, batch_size=None, show=False, show_inputs=True, show_targets=T
sample_weight=sample_weight, steps=steps)
for i in range(len(self.model.metrics_names)):
print(" %15s: %10s" % (self.model.metrics_names[i], self.pf(results[i])))
print(" %15s: %10s" % ("Total", num_test))
else: # all (or select data):
if select is None:
print("All Data Results:")
Expand All @@ -934,6 +936,7 @@ def evaluate(self, batch_size=None, show=False, show_inputs=True, show_targets=T
sample_weight=sample_weight, steps=steps)
for i in range(len(self.model.metrics_names)):
print(" %15s: %10s" % (self.model.metrics_names[i], self.pf(results[i])))
print(" %15s: %10s" % ("Total", len(self.dataset)))

def _evaluate_range(self, slice, show_inputs, show_targets, batch_size,
kverbose, sample_weight, steps, force, max_col_width):
Expand Down
27 changes: 24 additions & 3 deletions conx/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,34 @@
get_ipython = lambda: None

#------------------------------------------------------------------------
# configuration constants
# configuration settings

AVAILABLE_COLORMAPS = sorted(list(plt.cm.cmap_d.keys()))
CURRENT_COLORMAP = "seismic_r"
ERROR_COLORMAP = "seismic_r"
_PROGRESS_BAR = 'standard'

array = np.array

def progress_bar(*args, **kwargs):
if _PROGRESS_BAR is None:
return items
elif _PROGRESS_BAR == "notebook":
tqdm.tqdm_notebook(*args, **kwargs)
elif _PROGRESS_BAR == "standard":
tqdm.tqdm(*args, **kwargs)
else:
return items

def set_progress_bar(mode):
if mode in [None, 'notebook', 'standard']:
_PROGRESS_BAR = mode
else:
raise Exception("no such progress mode: use None, 'notebook', or 'standard'")

def get_progress_bar():
return _PROGRESS_BAR

def set_colormap(s):
"""
Set the global colormap for displaying all network activations.
Expand Down Expand Up @@ -537,12 +557,13 @@ def download(url, directory="./", force=False, unzip=True, filename=None,
response = requests.get(url, stream=True)
total_length = response.headers.get('content-length')
if total_length:
bar = tqdm.tqdm_notebook(total=int(total_length))
bar = progress_bar(total=int(total_length))
with open(file_path, 'wb') as f:
for data in response.iter_content(chunk_size=4096):
f.write(data)
if total_length:
bar.update(4096)
if bar:
bar.update(4096)
print("Done!")
else:
print("Using cached %s as '%s'." % (url, file_path))
Expand Down

0 comments on commit 6d4a3ed

Please sign in to comment.