predictor_triton.go
/*
Copyright 2021 The KServe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package v1beta1

import (
	"fmt"

	"github.com/akravacyber/kserve/pkg/constants"
	"github.com/akravacyber/kserve/pkg/utils"
	"github.com/golang/protobuf/proto"
	v1 "k8s.io/api/core/v1"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
)
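
// Default gRPC and REST ports exposed by the Triton inference server container.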
var (
	TritonISGRPCPort = int32(9000)
	TritonISRestPort = int32(8080)
)

// TritonSpec defines arguments for configuring Triton model serving.
type TritonSpec struct {
	// Contains fields shared across all predictors
	PredictorExtensionSpec `json:",inline"`
}
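
// Compile-time assertions that TritonSpec satisfies the component and predictor interfaces.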
var (
	_ ComponentImplementation = &TritonSpec{}
	_ PredictorImplementation = &TritonSpec{}
)

// Validate returns an error if invalid
func (t *TritonSpec) Validate() error {
	return utils.FirstNonNilError([]error{
		validateStorageURI(t.GetStorageUri()),
	})
}

// Default sets defaults on the resource
func (t *TritonSpec) Default(config *InferenceServicesConfig) {
	t.Container.Name = constants.InferenceServiceContainerName
	if t.RuntimeVersion == nil {
		t.RuntimeVersion = proto.String(config.Predictors.Triton.DefaultImageVersion)
	}
	setResourceRequirementDefaults(&t.Resources)
}

// GetContainer transforms the resource into a container spec
func (t *TritonSpec) GetContainer(metadata metav1.ObjectMeta, extensions *ComponentExtensionSpec, config *InferenceServicesConfig) *v1.Container {
	arguments := []string{
		"tritonserver",
		fmt.Sprintf("%s=%s", "--model-store", constants.DefaultModelLocalMountPath),
		fmt.Sprintf("%s=%s", "--grpc-port", fmt.Sprint(TritonISGRPCPort)),
		fmt.Sprintf("%s=%s", "--http-port", fmt.Sprint(TritonISRestPort)),
		fmt.Sprintf("%s=%s", "--allow-grpc", "true"),
		fmt.Sprintf("%s=%s", "--allow-http", "true"),
	}
	if extensions.ContainerConcurrency != nil && *extensions.ContainerConcurrency != 0 {
		arguments = append(arguments, fmt.Sprintf("%s=%d", "--http-thread-count", *extensions.ContainerConcurrency))
	}
	// when StorageURI is nil we enable explicit model load/unload
	if t.StorageURI == nil {
		arguments = append(arguments, fmt.Sprintf("%s=%s", "--model-control-mode", "explicit"))
	}
	if t.Container.Image == "" {
		t.Container.Image = config.Predictors.Triton.ContainerImage + ":" + *t.RuntimeVersion
	}
	t.Name = constants.InferenceServiceContainerName
	arguments = append(arguments, t.Args...)
	t.Args = arguments
	return &t.Container
}
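
// GetStorageUri returns the storage URI of the model.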
func (t *TritonSpec) GetStorageUri() *string {
	return t.StorageURI
}
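
// GetProtocol returns the inference protocol; Triton serves the V2 protocol.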
func (t *TritonSpec) GetProtocol() constants.InferenceServiceProtocol {
	return constants.ProtocolV2
}
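
// IsMMS reports whether the Triton runtime is configured as a multi-model server.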
func (t *TritonSpec) IsMMS(config *InferenceServicesConfig) bool {
	return config.Predictors.Triton.MultiModelServer
}
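
// IsFrameworkSupported reports whether the given framework is in Triton's configured list of supported frameworks.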
func (t *TritonSpec) IsFrameworkSupported(framework string, config *InferenceServicesConfig) bool {
	supportedFrameworks := config.Predictors.Triton.SupportedFrameworks
	return isFrameworkIncluded(supportedFrameworks, framework)
}