-
Notifications
You must be signed in to change notification settings - Fork 13
/
client.go
182 lines (156 loc) · 3.97 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
package watchmanager
import (
"context"
"encoding/json"
"fmt"
"math/big"
"sync"
"time"
"google.golang.org/grpc"
"github.com/Axway/agent-sdk/pkg/watchmanager/proto"
"github.com/golang-jwt/jwt"
)
// clientConfig holds the channels and settings a watchClient works with.
type clientConfig struct {
// errors receives every error passed to watchClient.handleError.
errors chan error
// events receives each event read from the stream by watchClient.recv.
events chan *proto.Event
// tokenGetter is called by watchClient.send to obtain the token placed in each request.
tokenGetter TokenGetter
// topicSelfLink is sent as the SelfLink of every watch request.
topicSelfLink string
}
// watchClient wraps a single subscribe stream together with the timer that
// schedules when the next request (carrying a new token) is sent on it.
type watchClient struct {
// cancelStreamCtx cancels streamCtx; it is called by handleError.
cancelStreamCtx context.CancelFunc
cfg clientConfig
// getTokenExpirationTime returns the duration that send uses to reset timer.
getTokenExpirationTime getTokenExpFunc
// isRunning is set to true by newWatchClient and to false by handleError.
isRunning bool
// stream is the subscription used to send requests and receive events.
stream proto.Watch_SubscribeClient
// streamCtx is the context the stream was opened with.
streamCtx context.Context
// timer signals processRequest that a request must be sent; it is created with a
// zero duration and reset by send after every successful request.
timer *time.Timer
// mutex is held for the duration of handleError.
mutex sync.Mutex
}
// newWatchClientFunc is the signature of a factory that builds a proto.WatchClient
// on top of a grpc connection. newWatchClient receives one as a parameter and uses
// it to obtain the service client it subscribes with.
type newWatchClientFunc func(cc grpc.ClientConnInterface) proto.WatchClient
// getTokenExpFunc is the signature of a function that takes a token and returns a
// duration or an error. watchClient.send uses the returned duration to reset the
// client's timer.
type getTokenExpFunc func(token string) (time.Duration, error)
// newWatchClient builds a service client with newClient, opens a subscribe stream
// on a fresh cancellable context, and returns a watchClient wrapping that stream.
// The returned client is marked as running and its timer is created with a zero
// duration. If subscribing fails, the context is cancelled and the error is returned.
func newWatchClient(cc grpc.ClientConnInterface, clientCfg clientConfig, newClient newWatchClientFunc) (*watchClient, error) {
	svc := newClient(cc)

	ctx, cancel := context.WithCancel(context.Background())
	subscription, subscribeErr := svc.Subscribe(ctx)
	if subscribeErr != nil {
		// release the context since no stream will ever use it
		cancel()
		return nil, subscribeErr
	}

	return &watchClient{
		cfg:                    clientCfg,
		stream:                 subscription,
		streamCtx:              ctx,
		cancelStreamCtx:        cancel,
		getTokenExpirationTime: getTokenExpirationTime,
		timer:                  time.NewTimer(0),
		isRunning:              true,
	}, nil
}
// processEvents keeps receiving events from the stream until recv fails, then
// hands the failure to handleError and returns.
func (c *watchClient) processEvents() {
	var err error
	for err == nil {
		err = c.recv()
	}
	c.handleError(err)
}
// recv blocks until the stream yields an event or an error. A received event is
// forwarded to the configured events channel; an error is returned to the caller.
func (c *watchClient) recv() error {
	evt, recvErr := c.stream.Recv()
	if recvErr == nil {
		c.cfg.events <- evt
	}
	return recvErr
}
// processRequest starts a goroutine that sends a request on the stream every time
// the client's timer fires, and that stops when a send fails or when either the
// stream context or the stream's own context is done. Every such failure is passed
// to handleError.
//
// processRequest blocks until the first outcome is known and returns it: the result
// of the first send, or the context error if the stream was shut down before any
// request could be sent. Later failures are only reported through handleError.
func (c *watchClient) processRequest() error {
	// Buffered so the goroutine never blocks on reporting the first outcome.
	first := make(chan error, 1)

	go func() {
		reported := false
		// report publishes the first outcome exactly once; every exit path calls it so
		// the caller can never be left waiting.
		report := func(err error) {
			if !reported {
				reported = true
				first <- err
			}
		}

		for {
			select {
			case <-c.streamCtx.Done():
				err := c.streamCtx.Err()
				report(err)
				c.handleError(err)
				return
			case <-c.stream.Context().Done():
				err := c.stream.Context().Err()
				report(err)
				c.handleError(err)
				return
			case <-c.timer.C:
				// err is local to the goroutine so later iterations cannot race with the
				// caller reading the first result.
				err := c.send()
				// unblock the caller before handleError, which may block on the error channel
				report(err)
				if err != nil {
					c.handleError(err)
					return
				}
			}
		}
	}()

	return <-first
}
// send obtains a token from the configured token getter, sends a watch request
// carrying that token on the stream, and resets the timer to the duration
// computed from the token. It returns the first error encountered; the timer is
// only reset when the request was sent successfully.
func (c *watchClient) send() error {
	tkn, err := c.cfg.tokenGetter()
	if err != nil {
		return err
	}

	refreshIn, err := c.getTokenExpirationTime(tkn)
	if err != nil {
		return err
	}

	if sendErr := c.stream.Send(createWatchRequest(c.cfg.topicSelfLink, tkn)); sendErr != nil {
		return sendErr
	}

	c.timer.Reset(refreshIn)
	return nil
}
// handleError shuts the client down while holding the client's mutex: it stops the
// timer, marks the client as no longer running, publishes err on the configured
// error channel, and then cancels the stream context.
func (c *watchClient) handleError(err error) {
	c.mutex.Lock()
	defer c.mutex.Unlock()

	c.timer.Stop()
	c.isRunning = false

	// the error is published before the stream context is cancelled
	c.cfg.errors <- err
	c.cancelStreamCtx()
}
// createWatchRequest builds the request for the given watch topic self link, with
// the token prefixed by "Bearer ".
func createWatchRequest(watchTopicSelfLink, token string) *proto.Request {
	req := new(proto.Request)
	req.SelfLink = watchTopicSelfLink
	req.Token = "Bearer " + token
	return req
}
// getTokenExpirationTime parses token as a JWT without verifying its signature or
// validating its claims, and returns 4/5 of the time remaining until the instant
// given by its "exp" claim. The result is negative when that instant is in the past.
//
// An error is returned when the token cannot be parsed, or when the "exp" claim is
// missing or not numeric. Without that check the zero time would be used and a
// huge negative duration would be returned as if it were valid.
func getTokenExpirationTime(token string) (time.Duration, error) {
	parser := new(jwt.Parser)
	parser.SkipClaimsValidation = true

	claims := jwt.MapClaims{}
	if _, _, err := parser.ParseUnverified(token, claims); err != nil {
		return 0, fmt.Errorf("getTokenExpirationTime failed to parse token: %w", err)
	}

	var expiresAt time.Time
	switch exp := claims["exp"].(type) {
	case float64:
		expiresAt = time.Unix(int64(exp), 0)
	case json.Number:
		sec, err := exp.Int64()
		if err != nil {
			// the number may be written with a fraction or an exponent
			f, floatErr := exp.Float64()
			if floatErr != nil {
				return 0, fmt.Errorf("getTokenExpirationTime failed to read the exp claim: %w", floatErr)
			}
			sec = int64(f)
		}
		expiresAt = time.Unix(sec, 0)
	default:
		return 0, fmt.Errorf("getTokenExpirationTime failed to find a numeric exp claim, got %T", claims["exp"])
	}

	remaining := time.Until(expiresAt)
	// use big.Int so that multiplying by 4 cannot overflow an int64
	scaled := big.NewInt(int64(remaining))
	scaled.Mul(scaled, big.NewInt(4))
	scaled.Div(scaled, big.NewInt(5))
	return time.Duration(scaled.Int64()), nil
}