default_pytorch_inference_handler.py

# Copyright 2019-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import os

import torch
from sagemaker_inference import (
    content_types,
    decoder,
    default_inference_handler,
    encoder,
    errors,
    utils,
)

INFERENCE_ACCELERATOR_PRESENT_ENV = "SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"
DEFAULT_MODEL_FILENAME = "model.pt"


class ModelLoadError(Exception):
    pass


class DefaultPytorchInferenceHandler(default_inference_handler.DefaultInferenceHandler):
    VALID_CONTENT_TYPES = (content_types.JSON, content_types.NPY)

    @staticmethod
    def _is_model_file(filename):
        is_model_file = False
        if os.path.isfile(filename):
            _, ext = os.path.splitext(filename)
            is_model_file = ext in [".pt", ".pth"]
        return is_model_file

    def default_model_fn(self, model_dir):
        """Loads a model. For PyTorch, the default is to load a TorchScript model saved as
        ``model.pt``. Users whose model is saved in any other format should provide a
        customized model_fn() in the entry-point script.

        Args:
            model_dir: a directory where the model is saved.

        Returns: A PyTorch model.
        """
        if os.getenv(INFERENCE_ACCELERATOR_PRESENT_ENV) == "true":
            model_path = os.path.join(model_dir, DEFAULT_MODEL_FILENAME)
            if not os.path.exists(model_path):
                raise FileNotFoundError(
                    "Failed to load model with default model_fn: missing file {}.".format(
                        DEFAULT_MODEL_FILENAME
                    )
                )
            # The client framework is CPU-only; the model runs with CUDA on the
            # Elastic Inference server, so load it onto the CPU here.
            try:
                return torch.jit.load(model_path, map_location=torch.device("cpu"))
            except RuntimeError as e:
                raise ModelLoadError(
                    "Failed to load {}. Please ensure model is saved using torchscript.".format(model_path)
                ) from e
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model_path = os.path.join(model_dir, DEFAULT_MODEL_FILENAME)
            if not os.path.exists(model_path):
                # Join with model_dir so the os.path.isfile check in _is_model_file
                # works regardless of the current working directory.
                model_files = [
                    f for f in os.listdir(model_dir) if self._is_model_file(os.path.join(model_dir, f))
                ]
                if len(model_files) != 1:
                    raise ValueError(
                        "Exactly one .pth or .pt file is required for PyTorch models: {}".format(model_files)
                    )
                model_path = os.path.join(model_dir, model_files[0])
            try:
                model = torch.jit.load(model_path, map_location=device)
            except RuntimeError as e:
                raise ModelLoadError(
                    "Failed to load {}. Please ensure model is saved using torchscript.".format(model_path)
                ) from e
            model = model.to(device)
            return model
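
    # Illustrative sketch (an assumption, not toolkit code): for this default model_fn to
    # work, the model must be exported as TorchScript under the default filename, e.g.:
    #
    #     scripted = torch.jit.script(my_module)  # or torch.jit.trace(my_module, example_input)
    #     scripted.save(os.path.join(model_dir, DEFAULT_MODEL_FILENAME))
    #
    # ``my_module`` and ``example_input`` are hypothetical names for the user's nn.Module
    # and a sample input tensor.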

    def default_input_fn(self, input_data, content_type):
        """A default input_fn that can handle JSON, CSV and NPY formats.

        Args:
            input_data: the request payload serialized in the content_type format
            content_type: the request content_type

        Returns: input_data deserialized into torch.FloatTensor or torch.cuda.FloatTensor,
            depending on whether cuda is available.
        """
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        np_array = decoder.decode(input_data, content_type)
        tensor = (
            torch.FloatTensor(np_array)
            if content_type in content_types.UTF8_TYPES
            else torch.from_numpy(np_array)
        )
        return tensor.to(device)
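
    # Illustrative sketch (an assumption, not toolkit code): a client payload that this
    # default input_fn accepts can be produced with the toolkit's own encoder, e.g.:
    #
    #     import numpy as np
    #     payload = encoder.encode(np.zeros((1, 3), dtype="float32"), content_types.NPY)
    #
    # Passing (payload, content_types.NPY) here round-trips the array into a torch tensor.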

    def default_predict_fn(self, data, model):
        """A default predict_fn for PyTorch. Calls a model on data deserialized in input_fn.
        Runs prediction on GPU if cuda is available.

        Args:
            data: input data (torch.Tensor) for prediction deserialized by input_fn
            model: PyTorch model loaded in memory by model_fn

        Returns: a prediction
        """
        with torch.no_grad():
            if os.getenv(INFERENCE_ACCELERATOR_PRESENT_ENV) == "true":
                device = torch.device("cpu")
                model = model.to(device)
                input_data = data.to(device)
                model.eval()
                with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
                    output = model(input_data)
            else:
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                model = model.to(device)
                input_data = data.to(device)
                model.eval()
                output = model(input_data)
        return output
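
    # Illustrative sketch (an assumption): given the tensor returned by default_input_fn
    # and the model returned by default_model_fn, a single forward pass is just
    #
    #     output = handler.default_predict_fn(tensor, model)
    #
    # where ``handler`` is a hypothetical DefaultPytorchInferenceHandler instance.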

    def default_output_fn(self, prediction, accept):
        """A default output_fn for PyTorch. Serializes predictions from predict_fn to JSON, CSV or NPY format.

        Args:
            prediction: a prediction result from predict_fn
            accept: the content type the client expects the output to be serialized into

        Returns: output data serialized in the accept content type
        """
        if isinstance(prediction, torch.Tensor):
            prediction = prediction.detach().cpu().numpy().tolist()

        for content_type in utils.parse_accept(accept):
            if content_type in encoder.SUPPORTED_CONTENT_TYPES:
                encoded_prediction = encoder.encode(prediction, content_type)
                if content_type == content_types.CSV:
                    encoded_prediction = encoded_prediction.encode("utf-8")
                return encoded_prediction

        raise errors.UnsupportedFormatError(accept)
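

# Minimal end-to-end sketch of the handler chain (an illustrative assumption, not part of
# the toolkit): it builds a tiny TorchScript model in a temporary directory, then runs the
# default model_fn -> input_fn -> predict_fn -> output_fn pipeline on an NPY payload.
# ``_Identity`` is a hypothetical module used only for this demonstration.
if __name__ == "__main__":
    import tempfile

    import numpy as np

    class _Identity(torch.nn.Module):
        def forward(self, x):
            return x

    handler = DefaultPytorchInferenceHandler()
    with tempfile.TemporaryDirectory() as model_dir:
        # Save a scripted model under the default filename so default_model_fn finds it.
        torch.jit.script(_Identity()).save(os.path.join(model_dir, DEFAULT_MODEL_FILENAME))
        model = handler.default_model_fn(model_dir)

        payload = encoder.encode(np.ones((1, 3), dtype="float32"), content_types.NPY)
        tensor = handler.default_input_fn(payload, content_types.NPY)
        prediction = handler.default_predict_fn(tensor, model)
        response = handler.default_output_fn(prediction, content_types.NPY)
        print(type(response))  # the serialized NPY response body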