Skip to content

Commit

Permalink
fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikko Kotila committed Apr 23, 2022
1 parent 864275a commit 73c6202
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 7 deletions.
16 changes: 15 additions & 1 deletion talos/templates/params.py
@@ -1,4 +1,4 @@
def titanic():
def titanic(debug=False):

from tensorflow.keras.optimizers import Adam, Nadam

Expand All @@ -15,6 +15,20 @@ def titanic():
'activation': ['relu', 'elu'],
'last_activation': ['sigmoid']}

if debug:

p = {'lr': [0.1, 0.2],
'first_neuron': [4, 8],
'batch_size': [20, 30],
'dropout': [0.2, 0.3],
'optimizer': [Adam(), Nadam()],
'epochs': [50, 100],
'losses': ['logcosh', 'binary_crossentropy'],
'shapes': ['brick', 'triangle', 0.2],
'hidden_layers': [0, 1],
'activation': ['relu', 'elu'],
'last_activation': ['sigmoid']}

return p


Expand Down
8 changes: 4 additions & 4 deletions talos/templates/pipelines.py
Expand Up @@ -40,13 +40,13 @@ def iris(round_limit=2, random_method='uniform_mersenne'):
return scan_object


def titanic(round_limit=2, random_method='uniform_mersenne'):
def titanic(round_limit=2, random_method='uniform_mersenne', debug=False):

'''Performs a Scan with Iris dataset and simple dense net'''
import talos as ta
scan_object = ta.Scan(ta.templates.datasets.titanic()[0][:50],
ta.templates.datasets.titanic()[1][:50],
ta.templates.params.titanic(),
scan_object = ta.Scan(ta.templates.datasets.titanic()[0],
ta.templates.datasets.titanic()[1],
ta.templates.params.titanic(debug),
ta.templates.models.titanic,
'test',
random_method=random_method,
Expand Down
2 changes: 1 addition & 1 deletion tests/commands/test_predict.py
Expand Up @@ -20,7 +20,7 @@ def test_predict():
predict = talos.Predict(scan_object)

_preds = predict.predict(x, 'val_acc', False)
_preds = predict.predict_classes(x, 'val_acc', False)
_preds = predict.predict_classes(x, 'val_acc', False, task='multi_label')

print('finised Predict() \n')

Expand Down
2 changes: 1 addition & 1 deletion tests/commands/test_random_methods.py
Expand Up @@ -21,6 +21,6 @@ def test_random_methods():
]

for method in random_methods:
talos.templates.pipelines.titanic(random_method=method)
talos.templates.pipelines.titanic(random_method=method, debug=True)

print('finish Random Methods \n')

0 comments on commit 73c6202

Please sign in to comment.