/
predictionserviceconn.go
64 lines (57 loc) · 1.66 KB
/
predictionserviceconn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
package tensorflow_service_apis
import (
tfserv "github.com/Golang-Tools/tensorflow_service_apis/tensorflow_serving/apis"
grpc "google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
)
// PredictionServiceConn is a client connection for the TensorFlow Serving
// PredictionService. It embeds the generated PredictionServiceClient, so the
// RPC methods can be called directly on a *PredictionServiceConn.
type PredictionServiceConn struct {
	// PredictionServiceClient is the generated gRPC client bound to conn.
	tfserv.PredictionServiceClient
	// conn is the underlying gRPC client connection; Close shuts it down.
	conn *grpc.ClientConn
	// sdk is the SDK that created this connection.
	sdk *SDK
}
// newPredictionServiceConn dials addr with the given gRPC dial options and
// wraps the resulting client connection in a PredictionServiceConn owned by
// sdk. The dial error, if any, is returned unchanged and no value is built.
func newPredictionServiceConn(sdk *SDK, addr string, opts ...grpc.DialOption) (*PredictionServiceConn, error) {
	cc, dialErr := grpc.Dial(addr, opts...)
	if dialErr != nil {
		return nil, dialErr
	}
	return &PredictionServiceConn{
		PredictionServiceClient: tfserv.NewPredictionServiceClient(cc),
		conn:                    cc,
		sdk:                     sdk,
	}, nil
}
// NewPredictionServiceConn dials a brand-new PredictionServiceClient
// connection using the SDK's address and dial options. On success it stores
// the connection as the SDK's cached connection and returns it.
//
// NOTE(review): any previously cached connection is replaced but not closed;
// whoever still holds it remains responsible for closing it.
func (c *SDK) NewPredictionServiceConn() (*PredictionServiceConn, error) {
	fresh, dialErr := newPredictionServiceConn(c, c.addr, c.opts...)
	if dialErr != nil {
		return nil, dialErr
	}
	c.getConnLock.Lock()
	c.predictionServiceConn = fresh
	c.getConnLock.Unlock()
	return fresh, nil
}
// predictionServiceConnLocked returns the cached PredictionServiceConn when
// one exists and its underlying gRPC connection is not in the
// connectivity.Shutdown state; otherwise it returns nil.
// The caller must already hold c.getConnLock (read or write).
func (c *SDK) predictionServiceConnLocked() *PredictionServiceConn {
	cached := c.predictionServiceConn
	if cached == nil || cached.conn.GetState() == connectivity.Shutdown {
		return nil
	}
	return cached
}

// getPredictionServiceConn is the read-locked fast path: it returns the cached,
// still-usable connection, or nil when a new one has to be dialed.
func (c *SDK) getPredictionServiceConn() *PredictionServiceConn {
	c.getConnLock.RLock()
	defer c.getConnLock.RUnlock()
	return c.predictionServiceConnLocked()
}

// GetPredictionServiceConn returns the SDK's shared PredictionServiceClient
// connection. It dials and caches a new one when nothing is cached or the
// cached connection has been shut down.
//
// Concurrent callers that all miss the cache may each dial. Only the first to
// publish its connection wins; the others close their redundant connection
// and return the winner. Every caller therefore shares one connection, and no
// dialed connection is orphaned.
func (c *SDK) GetPredictionServiceConn() (*PredictionServiceConn, error) {
	if conn := c.getPredictionServiceConn(); conn != nil {
		return conn, nil
	}
	// Dial outside the lock so a slow dial (e.g. with a blocking DialOption)
	// does not hold getConnLock and block other goroutines.
	fresh, err := newPredictionServiceConn(c, c.addr, c.opts...)
	if err != nil {
		return nil, err
	}
	c.getConnLock.Lock()
	defer c.getConnLock.Unlock()
	// Re-check under the write lock: another goroutine may have published a
	// usable connection while we were dialing.
	if winner := c.predictionServiceConnLocked(); winner != nil {
		// Best-effort cleanup of our redundant connection. The caller is served
		// by the winner, so a close error is not actionable here.
		_ = fresh.Close()
		return winner, nil
	}
	c.predictionServiceConn = fresh
	return fresh, nil
}
// Close shuts down the underlying gRPC client connection and returns the
// error reported by grpc.ClientConn.Close.
func (c *PredictionServiceConn) Close() error {
	cc := c.conn
	return cc.Close()
}