ENH: user friendly error if there is no loss function (#86)
adriangb committed Oct 1, 2020
1 parent c7609e8 commit b8f1b63
Showing 3 changed files with 64 additions and 2 deletions.
15 changes: 15 additions & 0 deletions scikeras/wrappers.py
@@ -332,6 +332,21 @@ def _build_keras_model(self):
        compile_kwargs = self._get_compile_kwargs()
        model.compile(**compile_kwargs)

        if not getattr(model, "loss", None) or (
            isinstance(model.loss, list)
            and not any(callable(l) or isinstance(l, str) for l in model.loss)
        ):
            raise ValueError(
                "No valid loss function found."
                " You must provide a loss function to train."
                "\n\nTo resolve this issue, do one of the following:"
                "\n 1. Provide a loss function via the loss parameter."
                "\n 2. Compile your model with a loss function inside the"
                " model-building method."
                "\n\nSee https://www.tensorflow.org/api_docs/python/tf/keras/losses"
                " for more information on Keras losses."
            )

        return model

    def _fit_keras_model(self, X, y, sample_weight, warm_start):
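The check above runs the first time fit() builds and compiles the model. Below is a minimal reproduction sketch (not part of this commit) of how the new error surfaces to a user, assuming the wrapper API at this commit and that SciKeras passes meta to a build function that requests it, as the tests further down do; build_model here is purely illustrative.

# Hypothetical example: build_model is illustrative; KerasRegressor is SciKeras's.
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

from scikeras.wrappers import KerasRegressor


def build_model(meta):
    # Return the model uncompiled and rely on the wrapper to compile it;
    # with loss=None there is no loss function for SciKeras to compile with.
    inp = Input(shape=(meta["n_features_in_"],))
    out = Dense(1)(Dense(10, activation="relu")(inp))
    return Model(inp, out)


est = KerasRegressor(model=build_model, loss=None)
est.fit([[1.0]], [1.0])  # raises ValueError: "No valid loss function found. ..."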
1 change: 0 additions & 1 deletion tests/test_compile_kwargs.py
@@ -1,7 +1,6 @@
import numpy as np
import pytest

from numpy.lib.arraysetops import isin
from sklearn.datasets import make_classification
from tensorflow.keras import losses as losses_module
from tensorflow.keras import metrics as metrics_module
50 changes: 49 additions & 1 deletion tests/test_errors.py
@@ -4,8 +4,10 @@
import pytest

from sklearn.exceptions import NotFittedError
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

from scikeras.wrappers import BaseWrapper, KerasClassifier, KerasRegressor
from scikeras.wrappers import KerasClassifier, KerasRegressor

from .mlp_models import dynamic_classifier, dynamic_regressor

@@ -131,3 +133,49 @@ def no_bar(foo=42):
    est = wrapper(model=no_bar, bar=42, foo=43)
    with pytest.raises(TypeError, match="got an unexpected keyword argument"):
        est.fit([[1]], [1])


@pytest.mark.parametrize("loss", [None, [None]])
@pytest.mark.parametrize("compile", [True, False])
def test_no_loss(loss, compile):
def get_model(compile, meta, compile_kwargs):
inp = Input(shape=(meta["n_features_in_"],))
hidden = Dense(10, activation="relu")(inp)
out = [
Dense(1, activation="sigmoid", name=f"out{i+1}")(hidden)
for i in range(meta["n_outputs_"])
]
model = Model(inp, out)
if compile:
model.compile(**compile_kwargs)
return model

est = KerasRegressor(model=get_model, loss=loss, compile=compile)
with pytest.raises(ValueError, match="must provide a loss function"):
est.fit([[1]], [1])


@pytest.mark.parametrize("compile", [True, False])
def test_no_optimizer(compile):
def get_model(compile, meta, compile_kwargs):
inp = Input(shape=(meta["n_features_in_"],))
hidden = Dense(10, activation="relu")(inp)
out = [
Dense(1, activation="sigmoid", name=f"out{i+1}")(hidden)
for i in range(meta["n_outputs_"])
]
model = Model(inp, out)
if compile:
model.compile(**compile_kwargs)
return model

est = KerasRegressor(
model=get_model,
loss="categorical_crossentropy",
compile=compile,
optimizer=None,
)
with pytest.raises(
ValueError, match="Could not interpret optimizer identifier" # Keras error
):
est.fit([[1]], [1])
