Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
suzhoum committed May 29, 2024
1 parent 2afba82 commit ed5d97f
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
2 changes: 0 additions & 2 deletions src/autogluon/cloud/job/sagemaker_job.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import json
import logging
from abc import abstractmethod
from typing import Optional

import pandas as pd
import sagemaker

from ..utils.ag_sagemaker import (
Expand Down
4 changes: 2 additions & 2 deletions src/autogluon/cloud/predictor/timeseries_cloud_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,15 @@ def predict_real_time(
self.id_column = id_column or self.id_column
self.timestamp_column = timestamp_column or self.timestamp_column
self.target_column = target or self.target_column

return self.backend.predict_real_time(
test_data=test_data,
id_column=self.id_column,
timestamp_column=self.timestamp_column,
target=self.target_column,
static_features=static_features,
accept=accept,
inference_kwargs=kwargs
inference_kwargs=kwargs,
)

def predict_proba_real_time(self, **kwargs) -> pd.DataFrame:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def _save_image_and_update_dataframe_column(bytes):

return im_path


def _custom_json_deserializer(serialized_str):
"""
Deserialize the JSON string that may include representations of complex data types like DataFrames
Expand All @@ -55,6 +56,7 @@ def _custom_json_deserializer(serialized_str):

return deserialized_kwargs


def model_fn(model_dir):
"""loads model from previously saved artifact"""
logger.info("Loading the model")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
# flake8: noqa
import logging
import os
import pickle
import shutil
import sys
from io import BytesIO, StringIO

import pandas as pd
import logging
import sys

from autogluon.timeseries import TimeSeriesDataFrame, TimeSeriesPredictor

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)


def model_fn(model_dir):
"""loads model from previously saved artifact"""
# TSPredictor will write to the model file during inference while the default model_dir is read only
Expand Down

0 comments on commit ed5d97f

Please sign in to comment.