Skip to content

Commit

Permalink
[BUGFIX] Key by step.name in fit_params_steps to handle unhashable st…
Browse files Browse the repository at this point in the history
…eps (#43)

* Add test for fit_params using unhashable steps

* Key fit_params_steps by step name instead of the step object to
handle unhashable steps (step names are unique in a model)
  • Loading branch information
alegonz committed Nov 15, 2020
1 parent a342e14 commit 28fa0af
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
6 changes: 3 additions & 3 deletions baikal/_core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def get_step(self, name: str) -> Step:
-------
The step.
"""
# Steps are assumed to have unique names (guaranteed by success of _build_graph)
# Steps are assumed to have unique names (guaranteed by success of build_graph_from_outputs)
if name in self._steps.keys():
return self._steps[name]
raise ValueError("{} was not found in the model.".format(name))
Expand Down Expand Up @@ -384,7 +384,7 @@ def fit(
# TODO: Add check for __. Add error message if step was not found
step_name, _, param_name = param_key.partition("__")
step = self.get_step(step_name)
fit_params_steps[step][param_name] = param_value
fit_params_steps[step.name][param_name] = param_value

# Intermediate results are stored here
# keys: DataPlaceholder instances, values: actual data (e.g. numpy arrays)
Expand All @@ -411,7 +411,7 @@ def fit(
continue

ys = [results_cache[t] for t in node.targets]
fit_params = fit_params_steps.get(node.step, {})
fit_params = fit_params_steps.get(node.step.name, {})

if node.fit_compute_func is not None:
self._fit_compute_node(node, Xs, ys, results_cache, **fit_params)
Expand Down
18 changes: 17 additions & 1 deletion tests/_core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline, Pipeline

from baikal import Model, Input
from baikal import Model, Input, Step
from baikal._core.data_placeholder import DataPlaceholder
from baikal._core.typing import ArrayLike
from baikal.steps import Concatenate, ColumnStack, Stack, Lambda
Expand Down Expand Up @@ -1002,6 +1002,22 @@ def test_fit_params(teardown):
assert_allclose(model.get_step("logreg").coef_, pipe.named_steps["logreg"].coef_)


def test_fit_params_unhashable_step():
class UnhashableStep(Step, sklearn.linear_model.LogisticRegression):
def __eq__(self, other):
pass

x = Input()
y_t = Input()
y = UnhashableStep()(x, y_t)
model = Model(x, y, y_t)

mask = iris.target != 2 # Reduce to binary problem to avoid ConvergenceWarning
x_data = iris.data[mask]
y_t_data = iris.target[mask]
model.fit(x_data, y_t_data)


def test_get_params(teardown):
dummy1 = DummyEstimator(name="dummy1")
dummy2 = DummyEstimator(x=456, y="def", name="dummy2")
Expand Down

0 comments on commit 28fa0af

Please sign in to comment.