Skip to content

Commit

Permalink
Merge pull request #546 from autonomio/fix_recover_best_model
Browse files Browse the repository at this point in the history
fixes recovery issue mentioned in #534
  • Loading branch information
mikkokotila committed May 27, 2021
2 parents 3e8bd31 + 569f533 commit b4c823c
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 7 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
URL = 'http://autonom.io'
LICENSE = 'MIT'
DOWNLOAD_URL = 'https://github.com/autonomio/talos/'
VERSION = '1.0.1'
VERSION = '1.0.2'


try:
Expand Down
2 changes: 1 addition & 1 deletion talos/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@
del commands, scan, model, metrics, key
del sub, keep_from_templates, template_sub, warnings

__version__ = "1.0.1"
__version__ = "1.0.2"
12 changes: 7 additions & 5 deletions talos/utils/recover_best_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ def recover_best_model(x_train,
y_val,
experiment_log,
input_model,
metric,
x_cross=None,
y_cross=None,
n_models=5,
Expand All @@ -15,10 +16,11 @@ def recover_best_model(x_train,
y_train | array | same as was used in the experiment
x_val | array | same as was used in the experiment
y_val | array | same as was used in the experiment
x_cross | array | data for the cross-validation or None for use x_val
y_cross | array | data for the cross-validation or None for use y_val
experiment_log | str | path to the Talos experiment log
input_model | function | model used in the experiment
metric | str | use this metric to pick evaluation candidates
x_cross | array | data for the cross-validation or None for use x_val
y_cross | array | data for the cross-validation or None for use y_val
n_models | int | number of models to cross-validate
task | str | binary, multi_class, multi_label or continuous
Expand Down Expand Up @@ -48,8 +50,8 @@ def recover_best_model(x_train,
for i in range(n_models):

# get the params for the model and train it
params = df.sort_values('val_acc', ascending=False).drop('val_acc', 1).iloc[i].to_dict()
history, model = input_model(x_train, y_train, x_val, y_val, params)
params = df.sort_values(metric, ascending=False).drop(metric, 1).iloc[i].to_dict()
_history, model = input_model(x_train, y_train, x_val, y_val, params)

# start kfold cross-validation
out = []
Expand Down Expand Up @@ -83,7 +85,7 @@ def recover_best_model(x_train,
results.append(np.mean(out))
models.append(model)

out = df.sort_values('val_acc', ascending=False).head(n_models)
out = df.sort_values(metric, ascending=False).head(n_models)
out['crossval_mean_f1score'] = results

return out, models
1 change: 1 addition & 0 deletions tests/commands/recover_best_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def iris_model(x_train, y_train, x_val, y_val, params):
y_val=y,
experiment_log=experiment_log,
input_model=iris_model,
metric='acc',
x_cross=x,
y_cross=y,
n_models=5,
Expand Down

0 comments on commit b4c823c

Please sign in to comment.