-
Notifications
You must be signed in to change notification settings - Fork 6
/
tfserving_request.go
45 lines (37 loc) · 1.09 KB
/
tfserving_request.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
package eas
import (
"context"
"errors"
"time"
tensorflow_serving "github.com/alibaba/pairec/v2/pkg/tensorflow_serving/apis"
"google.golang.org/grpc/metadata"
)
// TFServingRequest extends EasRequest with the fields needed to call a
// TensorFlow Serving prediction service over gRPC. Invoke reads the timeout
// and auth values through the embedded EasRequest.
type TFServingRequest struct {
EasRequest
// SignatureName is set via SetSignatureName. Invoke does not read it;
// presumably callers copy it into the PredictRequest they build — TODO confirm.
SignatureName string
// ModelName is set via SetModelName. Invoke does not read it;
// presumably callers copy it into the PredictRequest they build — TODO confirm.
ModelName string
// Outputs is set via SetOutputs. Invoke does not read it;
// presumably it names the output tensors to request — TODO confirm.
Outputs []string
// Client is the gRPC prediction client that Invoke calls Predict on.
Client tensorflow_serving.PredictionServiceClient
}
// SetSignatureName stores name in the request's SignatureName field.
func (r *TFServingRequest) SetSignatureName(name string) {
r.SignatureName = name
}
// SetModelName stores name in the request's ModelName field.
func (r *TFServingRequest) SetModelName(name string) {
r.ModelName = name
}
// SetOutputs stores outputs in the request's Outputs field. The slice is kept
// as passed (no copy is made), so it shares its backing array with the
// caller's slice.
func (r *TFServingRequest) SetOutputs(outputs []string) {
r.Outputs = outputs
}
// Invoke sends requestData to the prediction service through r.Client.Predict.
//
// requestData must be a *tensorflow_serving.PredictRequest; any other dynamic
// type is rejected with an error before any RPC is attempted. The call runs
// under a context whose timeout is r.timeout multiplied by time.Millisecond,
// and r.auth is attached as the "Authorization" outgoing gRPC metadata entry.
//
// On success, response holds the value returned by Predict and err is nil.
// On any failure, response is a nil interface and err describes the problem:
// a wrong requestData type, an unset Client, or the error from Predict.
func (r *TFServingRequest) Invoke(requestData interface{}) (response interface{}, err error) {
	request, ok := requestData.(*tensorflow_serving.PredictRequest)
	if !ok {
		return nil, errors.New("requestData is not tensorflow_serving.PredictRequest type")
	}

	// A zero-value TFServingRequest has no client. Report that as an error
	// rather than panicking on a method call through a nil interface.
	if r.Client == nil {
		return nil, errors.New("tfserving client is not set")
	}

	ctx, cancel := context.WithTimeout(context.Background(), r.timeout*time.Millisecond)
	defer cancel()

	md := metadata.New(map[string]string{"Authorization": r.auth})
	ctx = metadata.NewOutgoingContext(ctx, md)

	// Keep Predict's result in its own variable and only wrap it in the
	// interface return on success. Assigning it directly on failure would
	// hand the caller a non-nil interface wrapping a nil pointer (assuming
	// Predict returns a pointer type, as generated gRPC clients do).
	resp, err := r.Client.Predict(ctx, request)
	if err != nil {
		return nil, err
	}
	return resp, nil
}