Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support kwargs in predict() and predict_proba() #113

Draft
wants to merge 11 commits into
base: master
Choose a base branch
from
10 changes: 5 additions & 5 deletions .github/workflows/continuous_integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ jobs:
strategy:
fail-fast: false
matrix:
AG_VERSION: ["source", "0.7.0"]
AG_VERSION: ["source", "1.1.0"]
needs: cloud_lint_check
runs-on: ubuntu-latest
steps:
Expand Down Expand Up @@ -130,7 +130,7 @@ jobs:
strategy:
fail-fast: false
matrix:
AG_VERSION: ["source", "0.7.0"]
AG_VERSION: ["source", "1.1.0"]
needs: cloud_lint_check
runs-on: ubuntu-latest
steps:
Expand Down Expand Up @@ -162,7 +162,7 @@ jobs:
strategy:
fail-fast: false
matrix:
AG_VERSION: ["source", "0.7.0"]
AG_VERSION: ["source", "1.1.0"]
needs: cloud_lint_check
runs-on: ubuntu-latest
steps:
Expand Down Expand Up @@ -194,7 +194,7 @@ jobs:
strategy:
fail-fast: false
matrix:
AG_VERSION: ["source", "0.7.0"]
AG_VERSION: ["source", "1.1.0"]
needs: cloud_lint_check
runs-on: ubuntu-latest
steps:
Expand Down Expand Up @@ -226,7 +226,7 @@ jobs:
strategy:
fail-fast: false
matrix:
AG_VERSION: ["source", "0.7.0"]
AG_VERSION: ["source", "1.1.0"]
needs: cloud_lint_check
runs-on: ubuntu-latest
steps:
Expand Down
2 changes: 1 addition & 1 deletion src/autogluon/cloud/backend/ray_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def _get_image_uri(self, framework_version: str, instance_type: str, custom_imag
logger.log(20, f"Training with custom_image_uri=={custom_image_uri}")
else:
framework_version, py_version = parse_framework_version(
framework_version, "training", minimum_version="0.7.0"
framework_version, "training", minimum_version="1.0.0"
)
logger.log(20, f"Training with framework_version=={framework_version}")
image_uri = image_uris.retrieve(
Expand Down
8 changes: 6 additions & 2 deletions src/autogluon/cloud/backend/sagemaker_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,6 @@ def predict_real_time(
test_data_image_column: Optional[str] = None,
accept: str = "application/x-parquet",
inference_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Union[pd.DataFrame, pd.Series]:
"""
Predict with the deployed SageMaker endpoint. A deployed SageMaker endpoint is required.
Expand Down Expand Up @@ -595,7 +594,6 @@ def predict_proba_real_time(
test_data_image_column: Optional[str] = None,
accept: str = "application/x-parquet",
inference_kwargs: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Union[pd.DataFrame, pd.Series]:
"""
Predict probability with the deployed SageMaker endpoint. A deployed SageMaker endpoint is required.
Expand Down Expand Up @@ -704,6 +702,7 @@ def predict(
instance_count: int = 1,
custom_image_uri: Optional[str] = None,
wait: bool = True,
inference_kwargs: Optional[Dict[str, Any]] = None,
suzhoum marked this conversation as resolved.
Show resolved Hide resolved
download: bool = True,
persist: bool = True,
save_path: Optional[str] = None,
Expand Down Expand Up @@ -783,6 +782,7 @@ def predict(
instance_count=instance_count,
custom_image_uri=custom_image_uri,
wait=wait,
inference_kwargs=inference_kwargs,
download=download,
persist=persist,
save_path=save_path,
Expand All @@ -805,6 +805,7 @@ def predict_proba(
instance_count: int = 1,
custom_image_uri: Optional[str] = None,
wait: bool = True,
inference_kwargs: Optional[Dict[str, Any]] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add docstring

download: bool = True,
persist: bool = True,
save_path: Optional[str] = None,
Expand Down Expand Up @@ -889,6 +890,7 @@ def predict_proba(
instance_count=instance_count,
custom_image_uri=custom_image_uri,
wait=wait,
inference_kwargs=inference_kwargs,
download=download,
persist=persist,
save_path=save_path,
Expand Down Expand Up @@ -1133,6 +1135,7 @@ def _predict(
instance_count=1,
custom_image_uri=None,
wait=True,
inference_kwargs=None,
download=True,
persist=True,
save_path=None,
Expand Down Expand Up @@ -1256,6 +1259,7 @@ def _predict(
transformer_kwargs=transformer_kwargs,
model_kwargs=model_kwargs,
repack_model=repack_model,
inference_kwargs=inference_kwargs,
**transform_kwargs,
)
self._batch_transform_jobs[job_name] = batch_transform_job
Expand Down
31 changes: 17 additions & 14 deletions src/autogluon/cloud/backend/timeseries_sagemaker_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ class TimeSeriesSagemakerBackend(SagemakerBackend):
def _preprocess_data(
self,
data: Union[pd.DataFrame, str],
id_column: str,
timestamp_column: str,
target: str,
id_column: Optional[str] = None,
timestamp_column: Optional[str] = None,
target: Optional[str] = None,
static_features: Optional[Union[pd.DataFrame, str]] = None,
) -> pd.DataFrame:
if isinstance(data, str):
Expand All @@ -27,12 +27,15 @@ def _preprocess_data(
cols = data.columns.to_list()
# Make sure id and timestamp columns are the first two columns, and target column is in the end
# This is to ensure in the container we know how to find id and timestamp columns, and whether there are static features being merged
timestamp_index = cols.index(timestamp_column)
cols.insert(0, cols.pop(timestamp_index))
id_index = cols.index(id_column)
cols.insert(0, cols.pop(id_index))
target_index = cols.index(target)
cols.append(cols.pop(target_index))
if timestamp_column is not None:
timestamp_index = cols.index(timestamp_column)
cols.insert(0, cols.pop(timestamp_index))
if id_column is not None:
id_index = cols.index(id_column)
cols.insert(0, cols.pop(id_index))
if target is not None:
target_index = cols.index(target)
cols.append(cols.pop(target_index))
data = data[cols]

if static_features is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making id_column an optional argument might break the pd.merge in line 45.

I think that keeping the id_column, timestamp_column and target as part of the TimeSeriesSagemakerBackend API is fine since this class is not user-facing. In the public API of the TimeSeriesCloudPredictor these arguments are optional.

Expand All @@ -48,8 +51,8 @@ def fit(
*,
predictor_init_args: Dict[str, Any],
predictor_fit_args: Dict[str, Any],
id_column: str,
timestamp_column: str,
id_column: Optional[str] = None,
timestamp_column: Optional[str] = None,
static_features: Optional[Union[str, pd.DataFrame]] = None,
framework_version: str = "latest",
job_name: Optional[str] = None,
Expand Down Expand Up @@ -199,9 +202,9 @@ def predict_proba_real_time(self, **kwargs) -> pd.DataFrame:
def predict(
self,
test_data: Union[str, pd.DataFrame],
id_column: str,
timestamp_column: str,
target: str,
id_column: Optional[str] = None,
timestamp_column: Optional[str] = None,
target: Optional[str] = None,
static_features: Optional[Union[str, pd.DataFrame]] = None,
**kwargs,
) -> Optional[pd.DataFrame]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ max_workers: 1

docker:
# The image uri will be dynamically replaced by cloud predictor
image: "763104351884.dkr.ecr.us-east-1.amazonaws.com/autogluon-training:0.7.0-cpu-py39-ubuntu20.04"
image: "763104351884.dkr.ecr.us-east-1.amazonaws.com/autogluon-training:1.0.0-cpu-py39-ubuntu20.04"
container_name: "ag_dlc"

# Cloud-provider specific configuration.
Expand Down
7 changes: 7 additions & 0 deletions src/autogluon/cloud/job/sagemaker_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
AutoGluonSagemakerEstimator,
)
from ..utils.constants import LOCAL_MODE, LOCAL_MODE_GPU, MODEL_ARTIFACT_NAME
from ..utils.utils import serialize_kwargs
from .remote_job import RemoteJob

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -257,6 +258,7 @@ def run(
model_kwargs,
transformer_kwargs,
repack_model=False,
inference_kwargs=None,
**kwargs,
):
self._local_mode = instance_type in (LOCAL_MODE, LOCAL_MODE_GPU)
Expand All @@ -265,6 +267,10 @@ def run(
else:
model_cls = AutoGluonNonRepackInferenceModel
logger.log(20, "Creating inference model...")
inference_kwargs_str = serialize_kwargs(inference_kwargs) if inference_kwargs is not None else None
env = {}
if len(inference_kwargs_str) > 0:
env["inference_kwargs"] = inference_kwargs_str
model = model_cls(
model_data=model_data,
role=role,
Expand All @@ -275,6 +281,7 @@ def run(
custom_image_uri=custom_image_uri,
entry_point=entry_point,
predictor_cls=predictor_cls,
env=env,
**model_kwargs,
)
logger.log(20, "Inference model created successfully")
Expand Down
6 changes: 5 additions & 1 deletion src/autogluon/cloud/predictor/cloud_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ def predict_proba_real_time(
"""
self._validate_inference_kwargs(inference_kwargs=kwargs)
return self.backend.predict_proba_real_time(
test_data=test_data, test_data_image_column=test_data_image_column, accept=accept
test_data=test_data, test_data_image_column=test_data_image_column, accept=accept, inference_kwargs=kwargs
)

def predict(
Expand All @@ -556,6 +556,7 @@ def predict(
custom_image_uri: Optional[str] = None,
wait: bool = True,
backend_kwargs: Optional[Dict] = None,
**kwargs,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add docstring

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Example code should be added to tutorials that showcase specifying kwargs. Otherwise it will be hard for users to realize how to do this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea. I will add some tutorials with this PR.

) -> Optional[pd.Series]:
"""
Batch inference.
Expand Down Expand Up @@ -632,6 +633,7 @@ def predict(
instance_count=instance_count,
custom_image_uri=custom_image_uri,
wait=wait,
inference_kwargs=kwargs,
**backend_kwargs,
)

Expand All @@ -648,6 +650,7 @@ def predict_proba(
custom_image_uri: Optional[str] = None,
wait: bool = True,
backend_kwargs: Optional[Dict] = None,
**kwargs,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add docstring

) -> Optional[Union[Tuple[pd.Series, Union[pd.DataFrame, pd.Series]], Union[pd.DataFrame, pd.Series]]]:
"""
Batch inference
Expand Down Expand Up @@ -730,6 +733,7 @@ def predict_proba(
instance_count=instance_count,
custom_image_uri=custom_image_uri,
wait=wait,
inference_kwargs=kwargs,
**backend_kwargs,
)

Expand Down
23 changes: 20 additions & 3 deletions src/autogluon/cloud/predictor/timeseries_cloud_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def fit(
*,
predictor_init_args: Dict[str, Any],
predictor_fit_args: Dict[str, Any],
id_column: str = "item_id",
timestamp_column: str = "timestamp",
id_column: Optional[str] = None,
timestamp_column: Optional[str] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is our motivation for changing the defaults here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was trying to make the .predict() and .predict_real_time() API align with what we have in the Chronos tutorial. I see that we might not have id_column and timestamp_column in the train_data, but please correct me if I misunderstood the example.

static_features: Optional[Union[str, pd.DataFrame]] = None,
framework_version: str = "latest",
job_name: Optional[str] = None,
Expand Down Expand Up @@ -120,7 +120,7 @@ def fit(
if backend_kwargs is None:
backend_kwargs = {}

self.target_column = predictor_init_args.get("target", "target")
self.target_column = predictor_init_args.get("target")
self.id_column = id_column
self.timestamp_column = timestamp_column

Expand All @@ -146,6 +146,9 @@ def fit(
def predict_real_time(
self,
test_data: Union[str, pd.DataFrame],
id_column: Optional[str] = None,
timestamp_column: Optional[str] = None,
target: Optional[str] = None,
static_features: Optional[Union[str, pd.DataFrame]] = None,
accept: str = "application/x-parquet",
**kwargs,
Expand Down Expand Up @@ -175,13 +178,18 @@ def predict_real_time(
Pandas.DataFrame
Predict results in DataFrame
"""
self.id_column = id_column or self.id_column
self.timestamp_column = timestamp_column or self.timestamp_column
self.target_column = target or self.target_column
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What will happen if both self.target_column is None and target is None?

Copy link
Contributor Author

@suzhoum suzhoum Jun 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the TimeSeriesPredictor API does not require target at the moment, that's why I made it optional. See this tutorial, e.g.

predictor = TimeSeriesPredictor(prediction_length=14).fit(train_data)

It has been handled https://github.com/autogluon/autogluon/blob/bda6174f4a1fb8398aef4f375d9eacfd29bb46d9/timeseries/src/autogluon/timeseries/predictor.py#L179


return self.backend.predict_real_time(
test_data=test_data,
id_column=self.id_column,
timestamp_column=self.timestamp_column,
target=self.target_column,
static_features=static_features,
accept=accept,
inference_kwargs=kwargs,
)

def predict_proba_real_time(self, **kwargs) -> pd.DataFrame:
Expand All @@ -190,6 +198,9 @@ def predict_proba_real_time(self, **kwargs) -> pd.DataFrame:
def predict(
self,
test_data: Union[str, pd.DataFrame],
id_column: Optional[str] = None,
timestamp_column: Optional[str] = None,
target: Optional[str] = None,
static_features: Optional[Union[str, pd.DataFrame]] = None,
predictor_path: Optional[str] = None,
framework_version: str = "latest",
Expand All @@ -199,6 +210,7 @@ def predict(
custom_image_uri: Optional[str] = None,
wait: bool = True,
backend_kwargs: Optional[Dict] = None,
**kwargs,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

docstring

) -> Optional[pd.DataFrame]:
"""
Predict using SageMaker batch transform.
Expand Down Expand Up @@ -263,6 +275,10 @@ def predict(
Please refer to
https://sagemaker.readthedocs.io/en/stable/api/inference/transformer.html#sagemaker.transformer.Transformer.transform for all options.
"""
self.id_column = id_column or self.id_column
self.timestamp_column = timestamp_column or self.timestamp_column
self.target_column = target or self.target_column

if backend_kwargs is None:
backend_kwargs = {}
backend_kwargs = self.backend.parse_backend_predict_kwargs(backend_kwargs)
Expand All @@ -279,6 +295,7 @@ def predict(
instance_count=instance_count,
custom_image_uri=custom_image_uri,
wait=wait,
inference_kwargs=kwargs,
**backend_kwargs,
)

Expand Down
Loading
Loading