Skip to content

Commit

Permalink
added save_weights parameter to Scan
Browse files Browse the repository at this point in the history
- Related with #343 it's now possible to avoid saving model weights in `scan_object`, which might be desirable for very long runs with very large networks, due to the memory cost of keeping the weights throughout the experiment.
- fixed a small bug in `AutoParams` where choosing `network=False` resulted in 'dense' to be split into characters
- `max_param_values` is now optional in `AutoScan` and instead created the issue to handle the underlying problem properly #367
- fixed all the tests accordingly to the change in `AutoScan` arguments
  • Loading branch information
mikkokotila committed Aug 7, 2019
1 parent 55e3fb0 commit 2d69b6f
Show file tree
Hide file tree
Showing 8 changed files with 18 additions and 15 deletions.
3 changes: 2 additions & 1 deletion docs/Scan.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ Argument | Input | Description
`minimize_loss` | bool | `reduction_metric` is a loss
`disable_progress_bar` | bool | Disable live updating progress bar
`print_params` | bool | Print each permutation hyperparameters
`clear_tf_session` | bool | Clear backend session between permutations
`clear_session` | bool | Clear backend session between permutations
`save_weights` | bool | Save model weights (increases memory pressure for large models)

NOTE: `boolean_limit` will only work if its the last argument in `Scan()` and the following bracket is on a newline:

Expand Down
2 changes: 1 addition & 1 deletion talos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@
del commands, scan, model, metrics, key
del sub, keep_from_templates, template_sub, warnings

__version__ = "0.6.2"
__version__ = "0.6.3"
2 changes: 1 addition & 1 deletion talos/autom8/autoparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _automated(self, shapes='fixed'):
if self._network:
self.networks()
else:
self.params['network'] = 'dense'
self.params['network'] = ['dense']
self.last_activations()

def shapes(self, shapes='auto'):
Expand Down
8 changes: 5 additions & 3 deletions talos/autom8/autoscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ class AutoScan:

def __init__(self,
task,
max_param_values,
experiment_name):
experiment_name,
max_param_values=None):

'''Configure the `AutoScan()` experiment and then use
the property `start` in the returned class object to start
Expand Down Expand Up @@ -42,7 +42,9 @@ def start(self, x, y, **kwargs):
**kwargs)
except KeyError:
p = talos.autom8.AutoParams(task=self.task)
p.resample_params(self.max_param_values)

if self.max_param_values is not None:
p.resample_params(self.max_param_values)
params = p.params
scan_object = talos.Scan(x=x,
y=y,
Expand Down
2 changes: 1 addition & 1 deletion test/commands/test_analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_analyze(scan_object):
r = talos.Reporting(scan_object)

# read from file
list_of_files = glob.glob('./test_latest/' + '/*.csv')
list_of_files = glob.glob('./testing_latest/' + '/*.csv')

r = talos.Reporting(list_of_files[-1])

Expand Down
8 changes: 4 additions & 4 deletions test/commands/test_autom8.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,25 @@ def test_autom8():

x, y = wrangle.utils.create_synth_data('binary', 50, 10, 1)
p.losses(['binary_crossentropy'])
auto = talos.autom8.AutoScan('binary', 1, 'testinga')
auto = talos.autom8.AutoScan('binary', 'testinga', 1)
scan_object = auto.start(x, y, params=p.params)
talos.autom8.AutoPredict(scan_object, x, y, x, 'binary')

x, y = wrangle.utils.create_synth_data('multi_label', 50, 10, 4)
p.losses(['categorical_crossentropy'])
auto = talos.autom8.AutoScan('multi_label', 1, 'testingb')
auto = talos.autom8.AutoScan('multi_label', 'testingb', 1)
auto.start(x, y, params=p.params)
talos.autom8.AutoPredict(scan_object, x, y, x, 'multi_label')

x, y = wrangle.utils.create_synth_data('multi_class', 50, 10, 3)
p.losses(['sparse_categorical_crossentropy'])
auto = talos.autom8.AutoScan('multi_class', 1, 'testingc')
auto = talos.autom8.AutoScan('multi_class', 'testingc', 1)
auto.start(x, y, params=p.params)
talos.autom8.AutoPredict(scan_object, x, y, x, 'multi_class')

x, y = wrangle.utils.create_synth_data('regression', 50, 10, 1)
p.losses(['mae'])
auto = talos.autom8.AutoScan('continuous', 1, 'testingd')
auto = talos.autom8.AutoScan('continuous', 'testingd', 1)
auto.start(x, y, params=p.params)
talos.autom8.AutoPredict(scan_object, x, y, x, 'continuous')

Expand Down
4 changes: 2 additions & 2 deletions test/commands/test_latest.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def iris_model(x_train, y_train, x_val, y_val, params):

out = model.fit(x_train,
y_train,
callbacks=[talos.utils.ExperimentLogCallback('test_latest', params)],
callbacks=[talos.utils.ExperimentLogCallback('testing_latest', params)],
batch_size=params['batch_size'],
epochs=params['epochs'],
validation_data=[x_val, y_val],
Expand All @@ -44,7 +44,7 @@ def iris_model(x_train, y_train, x_val, y_val, params):
scan_object = talos.Scan(x, y,
model=iris_model,
params=p,
experiment_name='test_latest',
experiment_name='testing_latest',
round_limit=5,
reduction_method='gamify')

Expand Down
4 changes: 2 additions & 2 deletions test/commands/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def iris_model(x_train, y_train, x_val, y_val, params):
y=y,
params=p,
model=iris_model,
experiment_name='test',
experiment_name='testingq',
val_split=0.3,
random_method='uniform_mersenne',
round_limit=15,
Expand Down Expand Up @@ -101,7 +101,7 @@ def iris_model(x_train, y_train, x_val, y_val, params):
y=y,
params=p,
model=iris_model,
experiment_name="testint3",
experiment_name="testing3",
x_val=None,
y_val=None,
val_split=0.3,
Expand Down

0 comments on commit 2d69b6f

Please sign in to comment.