Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support kwargs in predict() and predict_proba() #113

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from

Conversation

suzhoum
Copy link
Contributor

@suzhoum suzhoum commented May 22, 2024

Issue #, if available:

Description of changes:

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@@ -704,6 +702,7 @@ def predict(
instance_count: int = 1,
custom_image_uri: Optional[str] = None,
wait: bool = True,
inference_kwargs: Optional[Dict[str, Any]] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add the new parameter in the doctring

@@ -265,6 +267,10 @@ def run(
else:
model_cls = AutoGluonNonRepackInferenceModel
logger.log(20, "Creating inference model...")
inference_kwargs_str = json.dumps(inference_kwargs) if inference_kwargs is not None else None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we support complex payload in the predict args such as known_covariates from TimeSeriesPredictor

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might need to handle DataFrame specifically. I can add that to the PR.

@suzhoum suzhoum marked this pull request as draft May 29, 2024 01:15
cols.insert(0, cols.pop(id_index))
if target is not None:
target_index = cols.index(target)
cols.append(cols.pop(target_index))
data = data[cols]

if static_features is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making id_column an optional argument might break the pd.merge in line 45.

I think that keeping the id_column, timestamp_column and target as part of the TimeSeriesSagemakerBackend API is fine since this class is not user-facing. In the public API of the TimeSeriesCloudPredictor these arguments are optional.

id_column: str = "item_id",
timestamp_column: str = "timestamp",
id_column: Optional[str] = None,
timestamp_column: Optional[str] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is our motivation for changing the defaults here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was trying to make the .predict() and .predict_real_time() API align with what we have in the Chronos tutorial. I see that we might not have id_column and timestamp_column in the train_data, but please correct me if I misunderstood the example.

@@ -175,13 +178,18 @@ def predict_real_time(
Pandas.DataFrame
Predict results in DataFrame
"""
self.id_column = id_column or self.id_column
self.timestamp_column = timestamp_column or self.timestamp_column
self.target_column = target or self.target_column
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What will happen if both self.target_column is None and target is None?

Copy link
Contributor Author

@suzhoum suzhoum Jun 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the TabularPredictor API does not require target at the moment, that's why I made it optional. See this tutorial, e.g.

predictor = TimeSeriesPredictor(prediction_length=14).fit(train_data)

It has been handled https://github.com/autogluon/autogluon/blob/bda6174f4a1fb8398aef4f375d9eacfd29bb46d9/timeseries/src/autogluon/timeseries/predictor.py#L179

@@ -805,6 +805,7 @@ def predict_proba(
instance_count: int = 1,
custom_image_uri: Optional[str] = None,
wait: bool = True,
inference_kwargs: Optional[Dict[str, Any]] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add docstring

@@ -556,6 +556,7 @@ def predict(
custom_image_uri: Optional[str] = None,
wait: bool = True,
backend_kwargs: Optional[Dict] = None,
**kwargs,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add docstring

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add example code should be added to tutorials that showcase specifying kwargs. Otherwise it will be hard for users to realize how to do this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. I will add some tutorials with this PR.

@@ -648,6 +650,7 @@ def predict_proba(
custom_image_uri: Optional[str] = None,
wait: bool = True,
backend_kwargs: Optional[Dict] = None,
**kwargs,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add docstring

@@ -199,6 +210,7 @@ def predict(
custom_image_uri: Optional[str] = None,
wait: bool = True,
backend_kwargs: Optional[Dict] = None,
**kwargs,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants