Skip to content

Commit

Permalink
Issue warnings when altering a virtual dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Sep 13, 2018
1 parent d1825a4 commit 6ee7221
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 6 deletions.
24 changes: 24 additions & 0 deletions conx/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1800,6 +1800,13 @@ def __len__(self):
else:
return size

def reshape(self, *args, **kwargs):
print("WARNING: applying to virtual dataset, this cache only!",
file=sys.stderr)
super().reshape(*args, **kwargs)

## FIXME: any other virtual functions that should get a datavector warning?

class VirtualDataset(Dataset):
"""
Create a virtual dataset. VirtualData set takes:
Expand Down Expand Up @@ -2075,6 +2082,23 @@ def get_validation_generator(self, batch_size):
## split can be 1.0 (use all), or > 0 and < 1
return DatasetGenerator(self, validation_set=True)

def set_targets_from_labels(self, *args, **kwargs):
print("WARNING: applying to virtual dataset, this cache only!",
file=sys.stderr)
super().set_targets_from_labels(*args, **kwargs)

def set_targets_from_inputs(self, *args, **kwargs):
print("WARNING: applying to virtual dataset, this cache only!",
file=sys.stderr)
super().set_targets_from_inputs(*args, **kwargs)

def set_inputs_from_targets(self, *args, **kwargs):
print("WARNING: applying to virtual dataset, this cache only!",
file=sys.stderr)
super().set_inputs_from_targets(*args, **kwargs)

## FIXME: any other virtual functions that should get a dataset warning?

class DatasetGenerator(keras.utils.Sequence):
"""
DatasetGenerator takes a VirtualDataset and can be trained
Expand Down
12 changes: 6 additions & 6 deletions conx/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ def evaluate(self, batch_size=None, show=False, show_inputs=True, show_targets=T
If show is True, then it will show details for each
training/test pair, the amount of detail then determined by
show_inputs, and show_outputs.
show_inputs, and show_targets.
If force is True, then it will show all patterns, even if there are many.
Expand Down Expand Up @@ -1055,7 +1055,7 @@ def evaluate_and_label(self, batch_size=32, tolerance=None):
categories[label] = []
categories[label].append(self.dataset.inputs[i])
return sorted(categories.items())

def compute_correct(self, outputs, targets, tolerance=None):
"""
Both are np.arrays. Return [True, ...].
Expand Down Expand Up @@ -4093,11 +4093,11 @@ def pf(self, vector, max_line_width=79, **opts):
formatter={'float_kind': precision.format,
'int_kind': precision.format},
separator=", ",
max_line_width=max_line_width).replace("\n", "")[:max_line_width]
if len(retval) == max_line_width:
return retval + "..."
else:
max_line_width=max_line_width).replace("\n", "")
if len(retval) <= max_line_width:
return retval
else:
return retval[:max_line_width - 3] + "..."

def set_weights(self, weights, layer_name=None):
"""
Expand Down

0 comments on commit 6ee7221

Please sign in to comment.