diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index b61a29ddb4..dcd89e995a 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -492,8 +492,8 @@ def compile_model( 'var2':[1,1,28,28]} output_path (str): Specifies where to store the compiled model framework (str): The framework that is used to train the original - model. Allowed values: 'mxnet', 'tensorflow', 'pytorch', 'onnx', - 'xgboost' + model. Allowed values: 'mxnet', 'tensorflow', 'keras', 'pytorch', + 'onnx', 'xgboost' framework_version (str): The version of the framework compile_max_run (int): Timeout in seconds for compilation (default: 3 * 60). After this amount of time Amazon SageMaker Neo diff --git a/src/sagemaker/model.py b/src/sagemaker/model.py index aace01546d..ea66e2c65c 100644 --- a/src/sagemaker/model.py +++ b/src/sagemaker/model.py @@ -45,7 +45,7 @@ "qcs605", ] ) -NEO_ALLOWED_FRAMEWORKS = set(["mxnet", "tensorflow", "pytorch", "onnx", "xgboost"]) +NEO_ALLOWED_FRAMEWORKS = set(["mxnet", "tensorflow", "keras", "pytorch", "onnx", "xgboost"]) NEO_IMAGE_ACCOUNT = { "us-west-1": "710691900526", @@ -322,8 +322,8 @@ def compile( 3 * 60). After this amount of time Amazon SageMaker Neo terminates the compilation job regardless of its current status. framework (str): The framework that is used to train the original - model. Allowed values: 'mxnet', 'tensorflow', 'pytorch', 'onnx', - 'xgboost' + model. Allowed values: 'mxnet', 'tensorflow', 'keras', 'pytorch', + 'onnx', 'xgboost' framework_version (str): Returns: