-
Notifications
You must be signed in to change notification settings - Fork 1
/
inference.py
98 lines (77 loc) · 2.56 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
import json
import warnings
import torch
from finetune import T5FineTuner
from transformers import (
T5Tokenizer
)
# Silence noisy FutureWarnings (from transformers/torch) in the serving logs.
warnings.filterwarnings("ignore",category=FutureWarning)
def model_fn(model_dir):
    """
    Load the fine-tuned T5 model and tokenizer for inference.

    Finds the newest Lightning checkpoint (file name containing
    'cktepoch') under /opt/ml/model, restores the T5FineTuner wrapper
    from it, and pairs it with a t5-base tokenizer.

    :param model_dir: directory the model artifact was extracted to
                      (unused; the standard /opt/ml/model path is read
                      directly — presumably identical under SageMaker)
    :return: dict with keys 'model' (T5FineTuner) and 'tokenizer'
    :raises FileNotFoundError: if no checkpoint file is present
    """
    # Environment values are strings — convert before the numeric compare.
    # (The original `os.environ['num_gpus'] > 0` raised TypeError whenever
    # the variable was actually set.)
    num_gpus = int(os.environ.get('num_gpus', 0))
    device = torch.device('cuda:0') if num_gpus > 0 else torch.device('cpu')

    saved_model_dir = '/opt/ml/model'
    # os.listdir order is arbitrary; sort so "latest" ([-1]) is deterministic.
    all_checkpoints = sorted(
        os.path.join(saved_model_dir, f)
        for f in os.listdir(saved_model_dir)
        if 'cktepoch' in f
    )
    print("all checkpoints: ", all_checkpoints)
    if not all_checkpoints:
        raise FileNotFoundError(
            "no 'cktepoch' checkpoint found in " + saved_model_dir
        )
    checkpoint = all_checkpoints[-1]

    # map_location keeps CPU-only hosts working with GPU-saved checkpoints.
    model_ckpt = torch.load(checkpoint, map_location=device)
    model = T5FineTuner(model_ckpt['hyper_parameters'])
    model.load_state_dict(model_ckpt['state_dict'])

    tokenizer = T5Tokenizer.from_pretrained('t5-base')
    model.model.to(device)
    model.model.eval()
    return {'model': model, 'tokenizer': tokenizer}
def predict_fn(input_data, model):
    """
    Run T5 text generation on the incoming request.

    :param input_data: deserialized request; expects an 'inputs' key
                       holding the text (or batch of texts) to generate from
    :param model: the dict returned by model_fn ('model' + 'tokenizer')
    :return: dict {'result': decoded text of the first generated sequence}
    """
    tokenizer = model['tokenizer']
    model = model['model']
    data = input_data['inputs']
    max_seq_length = 512
    # Environment values are strings — convert before the numeric compare.
    # (The original `os.environ['num_gpus'] > 0` raised TypeError whenever
    # the variable was actually set.)
    num_gpus = int(os.environ.get('num_gpus', 0))
    device = torch.device('cuda:0') if num_gpus > 0 else torch.device('cpu')
    # padding='max_length' replaces the deprecated pad_to_max_length=True.
    inputs = tokenizer(
        data, max_length=max_seq_length, padding='max_length', truncation=True,
        return_tensors="pt",
    )
    outs = model.model.generate(
        input_ids=inputs["input_ids"].to(device),
        attention_mask=inputs["attention_mask"].to(device),
        max_length=1024,
    )
    # Only the first generated sequence is decoded and returned.
    dec = tokenizer.decode(outs[0], skip_special_tokens=True)
    return {'result': dec}
def input_fn(request_body, request_content_type):
    """
    Deserialize the incoming request body.

    JSON payloads are parsed into Python objects; any other content
    type is passed through unchanged.
    """
    if request_content_type != "application/json":
        return request_body
    return json.loads(request_body)
def output_fn(prediction, response_content_type):
    """
    Serialize the prediction for the response.

    JSON responses are serialized with json.dumps; any other content
    type falls back to the str() representation.
    """
    wants_json = response_content_type == "application/json"
    return json.dumps(prediction) if wants_json else str(prediction)