FEAT allow metadata to be transformed in a Pipeline #28901
So for simple cases where metadata is only used in

Specifically, in this test:

    @pytest.mark.usefixtures("enable_slep006")
    @pytest.mark.parametrize("method", ["fit", "fit_transform"])
    def test_transform_input_pipeline(method):
        """Test that with transform_input, data is correctly transformed for each step."""

        def get_transformer(registry, sample_weight, metadata):
            """Get a transformer with requests set."""
            return (
                ConsumingTransformer(registry=registry)
                .set_fit_request(sample_weight=sample_weight, metadata=metadata)
                .set_transform_request(sample_weight=sample_weight, metadata=metadata)
            )

        def get_pipeline():
            """Get a pipeline and corresponding registries.

            The pipeline has 4 steps, with different request values set to test different
            cases. One is aliased.
            """
            registry_1, registry_2, registry_3, registry_4 = (
                _Registry(),
                _Registry(),
                _Registry(),
                _Registry(),
            )
            pipe = make_pipeline(
                get_transformer(registry_1, sample_weight=True, metadata=True),
                get_transformer(registry_2, sample_weight=False, metadata=False),
                get_transformer(registry_3, sample_weight=True, metadata=True),
                get_transformer(registry_4, sample_weight="other_weights", metadata=True),
                transform_input=["sample_weight"],
            )
            return pipe, registry_1, registry_2, registry_3, registry_4

        def check_metadata(registry, methods, **metadata):
            """Check that the right metadata was recorded for the given methods."""
            assert registry
            for estimator in registry:
                for method in methods:
                    check_recorded_metadata(
                        estimator,
                        method=method,
                        **metadata,
                    )

        X = np.array([[1, 2], [3, 4]])
        y = np.array([0, 1])
        sample_weight = np.array([[1, 2]])
        other_weights = np.array([[30, 40]])
        metadata = np.array([[100, 200]])

        pipe, registry_1, registry_2, registry_3, registry_4 = get_pipeline()
        pipe.fit(
            X,
            y,
            sample_weight=sample_weight,
            other_weights=other_weights,
            metadata=metadata,
        )

        check_metadata(
            registry_1, ["fit", "transform"], sample_weight=sample_weight, metadata=metadata
        )
        check_metadata(registry_2, ["fit", "transform"])
        check_metadata(
            registry_3,
            ["fit", "transform"],
            sample_weight=sample_weight + 2,
            metadata=metadata,
        )
        check_metadata(
            registry_4,
            method.split("_"),  # ["fit", "transform"] if "fit_transform", ["fit"] otherwise
            sample_weight=other_weights + 3,
            metadata=metadata,
        )

Step 3 receives transformed data in its

The question is, what should be the expected behavior? Do we want
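The `+ 2` and `+ 3` offsets in the expected values above can be reproduced with a toy model. This is a minimal pure-Python sketch, not scikit-learn code: `PlusOne` and `fit_with_transform_input` are hypothetical names, there is no request/routing logic, and it assumes each transformer adds 1 in `transform`, which is what the offsets in the test suggest.

```python
class PlusOne:
    """Hypothetical toy transformer: records the weights seen in fit, adds 1 in transform."""

    def __init__(self):
        self.seen_in_fit = None

    def fit(self, X, sample_weight=None):
        # Remember what this step was given, so it can be inspected afterwards.
        self.seen_in_fit = sample_weight
        return self

    def transform(self, X):
        return [x + 1 for x in X]


def fit_with_transform_input(steps, X, sample_weight):
    """Fit the steps in order.

    After each step is fitted, both X and sample_weight are passed through
    that step's transform, so the next step sees metadata transformed by
    every step before it.
    """
    for step in steps:
        step.fit(X, sample_weight=sample_weight)
        X = step.transform(X)
        sample_weight = step.transform(sample_weight)
    return steps


steps = fit_with_transform_input(
    [PlusOne() for _ in range(4)], X=[1, 3], sample_weight=[1, 2]
)
for i, step in enumerate(steps, start=1):
    print(f"step {i} was fitted with sample_weight={step.seen_in_fit}")
```

In this sketch step 1 sees the original `[1, 2]`, step 3 sees `[3, 4]` (two transforms applied) and step 4 sees `[4, 5]` (three transforms applied).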
Actually, in

    if _routing_enabled():
        transform_params = self.get_metadata_routing().consumes(
            method="transform", params=fit_params.keys()
        )
        if transform_params:
            warnings.warn(
                (
                    f"This object ({self.__class__.__name__}) has a `transform`"
                    " method which consumes metadata, but `fit_transform` does not"
                    " forward metadata to `transform`. Please implement a custom"
                    " `fit_transform` method to forward metadata to `transform` as"
                    " well. Alternatively, you can explicitly do"
                    " `set_transform_request`and set all values to `False` to"
                    " disable metadata routed to `transform`, if that's an option."
                ),
                UserWarning,
            )

and we never send anything to

However, for third party transformers where they can have their own
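The behavior of that snippet can be illustrated in isolation. The following is a hedged, stand-alone sketch rather than scikit-learn's implementation: `ToyTransformer` and its `consumes` dictionary are hypothetical stand-ins for real routing information, and the only point is that a default `fit_transform` warns when `transform` would consume metadata that is not being forwarded to it.

```python
import warnings


class ToyTransformer:
    """Hypothetical transformer whose transform consumes `sample_weight`."""

    # Metadata names each method consumes (a stand-in for real routing info).
    consumes = {"fit": {"sample_weight"}, "transform": {"sample_weight"}}

    def fit(self, X, sample_weight=None):
        return self

    def transform(self, X, sample_weight=None):
        return X

    def fit_transform(self, X, **fit_params):
        # Metadata that transform would consume but that is not forwarded to it.
        dropped = self.consumes["transform"] & set(fit_params)
        if dropped:
            warnings.warn(
                f"fit_transform does not forward {sorted(dropped)} to transform",
                UserWarning,
            )
        # Metadata only reaches fit; transform is called without it.
        return self.fit(X, **fit_params).transform(X)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ToyTransformer().fit_transform([1, 2], sample_weight=[0.5, 0.5])

print(len(caught), caught[0].category.__name__)
```

A transformer that needs the metadata in both methods would have to override `fit_transform` and pass it on explicitly, as the warning message suggests.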
Another question is, do we want to have this syntactic sugar?

    pipe = make_pipeline(
        StandardScaler(),
        HistGradientBoostingClassifier(..., early_stopping=True)
    ).fit(X, y, X_val, y_val)

The above code would:

It wouldn't change what we have now implemented in

For that to happen, HGBC needs to have:

    class HistGradientBoostingClassifier(...):
        ...

        def get_metadata_routing(self):
            routing = super().get_metadata_routing()
            if self.early_stopping:
                routing.fit.add(X_val=True, y_val=True)
            return routing

        def __sklearn_get_transforming_data__(self):
            return ["X_val"]

cc @glemaitre

This goes in the direction of having more default routing info, which @ogrisel really likes (ref #26179). Note that this could come later separately as an enhancement to this PR.
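To make the proposal concrete, here is a rough pure-Python sketch of how a pipeline could use such a hook. It is not scikit-learn code: `AddOneStep`, `ToyEarlyStopper` and `fit_pipeline` are hypothetical, and `__sklearn_get_transforming_data__` is only the name proposed in this comment, not an existing API.

```python
class AddOneStep:
    """Hypothetical toy transformer that adds 1 to every value."""

    def fit(self, X):
        return self

    def transform(self, X):
        return [x + 1 for x in X]


class ToyEarlyStopper:
    """Hypothetical final estimator that wants X_val transformed like X."""

    def __sklearn_get_transforming_data__(self):
        # Proposed hook: names of metadata the pipeline should transform.
        return ["X_val"]

    def fit(self, X, y, X_val=None, y_val=None):
        self.X_seen, self.X_val_seen, self.y_val_seen = X, X_val, y_val
        return self


def fit_pipeline(transformers, final, X, y, **metadata):
    """Fit transformers in order, transforming the metadata the final step names."""
    hook = getattr(final, "__sklearn_get_transforming_data__", None)
    to_transform = hook() if hook is not None else []
    for step in transformers:
        step.fit(X)
        X = step.transform(X)
        for name in to_transform:
            if name in metadata:
                metadata[name] = step.transform(metadata[name])
    return final.fit(X, y, **metadata)


final = fit_pipeline(
    [AddOneStep(), AddOneStep()],
    ToyEarlyStopper(),
    X=[1, 2],
    y=[0, 1],
    X_val=[10, 20],
    y_val=[1, 0],
)
print(final.X_seen, final.X_val_seen, final.y_val_seen)
```

Here `X_val` goes through the same two transforms as `X`, while `y_val` is passed to the final step untouched because the hook does not list it.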
Initial proposal: #28440 (comment)
xref: #28440 (comment)
This adds `transform_input` as a constructor argument to `Pipeline`, as:

It simply allows metadata to be transformed by the fitted estimators up to the step which needs the metadata.
How does this look?
cc @lorentzenchr @ogrisel @amueller @betatim