How does the predict function know what model to use? #7
There is a PR for this already, which will be merged today: #2
ok perfect. Yes indeed, with the current `predict_fn(processed_data)` the inference fails. I'm using the functions below:

```python
import os

import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering


def load_fn(model_dir):
    """This function reads the model from disk."""
    print('load_fn dir view:')
    print(os.listdir())
    # load model
    model = TFAutoModelForQuestionAnswering.from_pretrained('/opt/ml/model')
    # load tokenizer
    tokenizer = AutoTokenizer.from_pretrained('/opt/ml/model')
    return model, tokenizer


def predict_fn(processed_data):
    """This function runs inference."""
    print('processed_data received: ')
    print(processed_data)
    # NOTE: `model` and `tokenizer` are never defined in this scope --
    # predict_fn only receives processed_data, which is why inference fails.
    print('model name:')
    print(model.name)
    print('tok name:')
    print(tokenizer.name)
    question, text = processed_data['inputs']['question'], processed_data['inputs']['context']
    input_dict = tokenizer(question, text, return_tensors='tf')
    outputs = model(input_dict)
    start_logits = outputs.start_logits
    end_logits = outputs.end_logits
    all_tokens = tokenizer.convert_ids_to_tokens(input_dict["input_ids"].numpy()[0])
    answer = ' '.join(all_tokens[tf.math.argmax(start_logits, 1)[0] : tf.math.argmax(end_logits, 1)[0] + 1])
    return answer
```
With the new signature, where `predict_fn` also receives the model, this should work.
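For illustration, here is a minimal sketch of what `predict_fn` could look like once it also receives the objects returned by `load_fn`. The two-tuple unpacking matches the `load_fn` above; the exact argument name and shape are assumptions, not a confirmed API:

```python
import tensorflow as tf


def predict_fn(processed_data, model_and_tokenizer):
    """Runs inference with the objects returned by load_fn.

    Sketch only: assumes the container passes load_fn's return value
    (here a (model, tokenizer) tuple) as the second argument.
    """
    model, tokenizer = model_and_tokenizer
    question = processed_data['inputs']['question']
    text = processed_data['inputs']['context']
    input_dict = tokenizer(question, text, return_tensors='tf')
    outputs = model(input_dict)
    # most likely start/end token positions of the answer span
    start = int(tf.math.argmax(outputs.start_logits, axis=1)[0])
    end = int(tf.math.argmax(outputs.end_logits, axis=1)[0])
    all_tokens = tokenizer.convert_ids_to_tokens(input_dict['input_ids'].numpy()[0])
    return ' '.join(all_tokens[start:end + 1])
```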
Hi,
In current SageMaker Framework hosting (e.g. Sklearn-on-Flask, PyTorch-on-TorchServe), the prediction function is a `predict(input_object, model)` that takes as input the result of `model_fn`. In SM HF Hosting, the predict function only receives the `processed_data`. How does it then know which model to work on? Are the objects returned by `load_fn` available in memory?
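For comparison, a rough sketch of the existing Framework convention the question refers to (e.g. the SageMaker PyTorch container), where `model_fn`'s return value is passed straight into the prediction function as its second argument; `load_my_model` is a hypothetical loader used only for illustration:

```python
def model_fn(model_dir):
    """Called once at container startup; loads the model from model_dir."""
    model = load_my_model(model_dir)  # hypothetical loader, for illustration
    return model


def predict_fn(input_object, model):
    """Called per request; `model` is whatever model_fn returned."""
    return model(input_object)
```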