Describe the bug
I am trying to deploy a BERT model that I fine-tuned for text classification. The model artifact is stored in S3 as a model.tar.gz. When I deploy it, the contents of my custom code/inference.py inside the model.tar.gz are replaced with a different script, and as a result I see a model load error in the endpoint logs.
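For context, the deployment is along these lines (a minimal sketch; the S3 path, IAM role, container versions, and instance type are placeholders rather than the exact values used):

from sagemaker.huggingface import HuggingFaceModel

# inference.py is bundled under code/ inside model.tar.gz, so no entry_point
# is passed here; the HuggingFace inference toolkit is expected to pick it up.
huggingface_model = HuggingFaceModel(
    model_data="s3://my-bucket/path/model.tar.gz",   # fine-tuned BERT artifact (placeholder path)
    role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder role
    transformers_version="4.26",   # placeholder container versions
    pytorch_version="1.13",
    py_version="py39",
)

predictor = huggingface_model.deploy(
    initial_instance_count=1,
    instance_type="ml.m5.xlarge",  # placeholder instance type
)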
To reproduce
Below is my custom inference.py:
import os
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import json
import logging

# Initialize logging
logging.basicConfig(level=logging.INFO)

def model_fn(model_dir):
    """Load model and tokenizer from the specified directory."""
    logging.info(f"Loading model and tokenizer from {model_dir}")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    logging.info("Model and tokenizer loaded successfully")
    return model, tokenizer

def input_fn(request_body, request_content_type):
    """Parse input data from request body."""
    if request_content_type == 'application/json':
        input_data = json.loads(request_body)
        return input_data['text']
    else:
        raise ValueError(f"Unsupported content type: {request_content_type}")

def predict_fn(input_data, model_tokenizer):
    """Make a prediction based on the input data and model."""
    model, tokenizer = model_tokenizer
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    try:
        # Tokenize input
        inputs = tokenizer(input_data, return_tensors="pt").to(device)
        # Perform prediction
        with torch.no_grad():
            output = model(**inputs)
        return output
    except Exception as e:
        logging.error(f"Error during prediction: {str(e)}")
        raise ValueError(f"Error during prediction: {str(e)}")

def output_fn(prediction, content_type):
    """Convert the model prediction output to a JSON serializable format."""
    if content_type == 'application/json':
        if hasattr(prediction, 'logits'):
            return json.dumps(prediction.logits[0].cpu().numpy().tolist())
        else:
            raise ValueError("Prediction output does not contain logits")
    else:
        raise ValueError(f"Unsupported content type: {content_type}")
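The archive is packaged with the handler under code/, along these lines (a sketch; the local file names are illustrative and assume the model and tokenizer were saved with save_pretrained()):

import tarfile

# The fine-tuned model/tokenizer files sit at the archive root and the custom
# handler goes under code/, which is the layout the container expects.
with tarfile.open("model.tar.gz", "w:gz") as tar:
    for name in ["config.json", "model.safetensors", "tokenizer.json",
                 "tokenizer_config.json", "vocab.txt"]:
        tar.add(f"model/{name}", arcname=name)
    tar.add("code/inference.py", arcname="code/inference.py")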
When deployed, the inference.py above is replaced by the following:
import os
import torch
from safetensors.torch import load_file
from transformers import BertTokenizer

def model_fn(model_dir):
    # Load model
    model_path = os.path.join(model_dir, 'model.safetensors')
    model = load_file(model_path)
    model.eval()

    # Load tokenizer
    tokenizer_dir = os.path.join(model_dir, 'tokenizer')
    tokenizer = BertTokenizer.from_pretrained(tokenizer_dir)

    return model, tokenizer

def input_fn(request_body, request_content_type):
    # Assuming JSON input
    import json
    input_data = json.loads(request_body)
    return input_data['text']

def predict_fn(input_data, model_tokenizer):
    model, tokenizer = model_tokenizer

    # Tokenize input
    inputs = tokenizer(input_data, return_tensors="pt")

    # Perform prediction
    with torch.no_grad():
        output = model(**inputs)

    return output

def output_fn(prediction, content_type):
    # Assuming the model returns a tensor
    return prediction[0].numpy().tolist()
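To rule out a packaging mistake on my side, the archive contents can be checked like this (a sketch, assuming a local copy of the uploaded model.tar.gz):

import tarfile

# List everything in the archive and print the packaged handler to confirm it
# is the custom code/inference.py shown above, not the script the endpoint loads.
with tarfile.open("model.tar.gz") as tar:
    print(tar.getnames())
    handler = tar.extractfile("code/inference.py")
    if handler is not None:
        print(handler.read().decode())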