-
Notifications
You must be signed in to change notification settings - Fork 51
/
tokenaccessor.go
212 lines (170 loc) · 5.86 KB
/
tokenaccessor.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
package tokenaccessor
import (
"bytes"
"errors"
"sync"
"time"
enforcerconstants "go.aporeto.io/trireme-lib/controller/internal/enforcer/constants"
"go.aporeto.io/trireme-lib/controller/pkg/claimsheader"
"go.aporeto.io/trireme-lib/controller/pkg/connection"
"go.aporeto.io/trireme-lib/controller/pkg/pucontext"
"go.aporeto.io/trireme-lib/controller/pkg/secrets"
"go.aporeto.io/trireme-lib/controller/pkg/tokens"
"go.uber.org/zap"
)
// tokenAccessor is a wrapper around a tokens.TokenEngine that adds locking
// around access to the engine, which SetToken can replace at runtime.
type tokenAccessor struct {
	// RWMutex guards tokens; getToken and SetToken lock it around access.
	sync.RWMutex
	// tokens is the currently installed token engine.
	tokens tokens.TokenEngine
	// serverID is the server ID reported by GetTokenServerID.
	serverID string
	// validity is the token validity reported by GetTokenValidity.
	validity time.Duration
	// binary selects tokens.NewBinaryJWT instead of tokens.NewJWT whenever an
	// engine is (re)built by New or SetToken.
	binary bool
}
// New returns a TokenAccessor backed by a freshly created token engine.
//
// When binary is true the engine comes from tokens.NewBinaryJWT (logged as
// datapath v2.0); otherwise tokens.NewJWT is used (logged as datapath v1.0).
// The chosen flavour is remembered and reused by SetToken. Any error from the
// engine constructor is returned unchanged together with a nil accessor.
func New(serverID string, validity time.Duration, secret secrets.Secrets, binary bool) (TokenAccessor, error) {
	var (
		engine tokens.TokenEngine
		err    error
	)

	switch {
	case binary:
		zap.L().Info("Enabling Trireme Datapath v2.0")
		engine, err = tokens.NewBinaryJWT(validity, serverID, secret)
	default:
		zap.L().Info("Enabling Trireme Datapath v1.0")
		engine, err = tokens.NewJWT(validity, serverID, secret)
	}
	if err != nil {
		return nil, err
	}

	accessor := &tokenAccessor{
		binary:   binary,
		validity: validity,
		serverID: serverID,
		tokens:   engine,
	}
	return accessor, nil
}
// getToken returns the token engine currently installed in the accessor.
//
// Only the read side of the RWMutex is taken: this is a pure read on the
// packet path, so concurrent callers must not serialize on each other. SetToken
// takes the write lock when it swaps the engine. The returned engine is a
// snapshot; a later SetToken does not affect an engine already handed out.
func (t *tokenAccessor) getToken() tokens.TokenEngine {
	t.RLock()
	defer t.RUnlock()
	return t.tokens
}
// SetToken builds a new token engine from serverID, validity and secret and
// installs it in place of the current one. The engine flavour (binary or JWT)
// is the one selected when the accessor was created by New.
//
// The server ID and validity reported by GetTokenServerID and GetTokenValidity
// are updated together with the engine, so they always describe the engine in
// use. If the engine cannot be created, the constructor's error is returned and
// the accessor keeps its previous engine, server ID and validity.
func (t *tokenAccessor) SetToken(serverID string, validity time.Duration, secret secrets.Secrets) error {
	var (
		tokenEngine tokens.TokenEngine
		err         error
	)

	// Build the engine before taking the lock so readers are not blocked while
	// it is constructed. t.binary is set in New and not modified by any method
	// in this file, so reading it without the lock is safe.
	if t.binary {
		tokenEngine, err = tokens.NewBinaryJWT(validity, serverID, secret)
	} else {
		tokenEngine, err = tokens.NewJWT(validity, serverID, secret)
	}
	if err != nil {
		// This method already returns an error, so report the failure to the
		// caller rather than panicking and taking the whole process down.
		return err
	}

	t.Lock()
	defer t.Unlock()
	t.tokens = tokenEngine
	t.serverID = serverID
	t.validity = validity
	return nil
}

// GetTokenValidity returns the validity duration used by the currently
// installed token engine.
func (t *tokenAccessor) GetTokenValidity() time.Duration {
	// SetToken may update validity concurrently; read it under the lock.
	t.RLock()
	defer t.RUnlock()
	return t.validity
}

// GetTokenServerID returns the server ID used to generate tokens with the
// currently installed token engine.
func (t *tokenAccessor) GetTokenServerID() string {
	// SetToken may update serverID concurrently; read it under the lock.
	t.RLock()
	defer t.RUnlock()
	return t.serverID
}
// CreateAckPacketToken creates the authentication token sent in Ack packets.
//
// The claims carry only this PU's management ID plus the remote context and
// remote context ID recorded in auth. The token is signed with the same
// boolean mode flag (true) that ParseAckToken decodes with, using
// auth.LocalContext and a new header from claimsheader.NewClaimsHeader().
// On failure an empty, non-nil slice is returned together with the error.
func (t *tokenAccessor) CreateAckPacketToken(context *pucontext.PUContext, auth *connection.AuthInfo) ([]byte, error) {
	ackClaims := &tokens.ConnectionClaims{
		ID:       context.ManagementID(),
		RMT:      auth.RemoteContext,
		RemoteID: auth.RemoteContextID,
	}

	signed, err := t.getToken().CreateAndSign(true, ackClaims, auth.LocalContext, claimsheader.NewClaimsHeader())
	if err == nil {
		return signed, nil
	}
	return []byte{}, err
}
// CreateSynPacketToken creates the authentication token sent in Syn packets.
//
// If the PU context returns a cached token without error and the cached
// service context equals auth.LocalServiceContext, the cached token is reused
// after Randomize has been applied to it with auth.LocalContext. Otherwise (or
// if Randomize fails) a new token is signed from the PU's identity, compressed
// tags and management ID, and stored back into the context's cache.
//
// If signing fails, an empty, non-nil slice and the signing error are returned.
func (t *tokenAccessor) CreateSynPacketToken(context *pucontext.PUContext, auth *connection.AuthInfo) (token []byte, err error) {
	cached, serviceContext, cacheErr := context.GetCachedTokenAndServiceContext()
	if cacheErr == nil && bytes.Equal(auth.LocalServiceContext, serviceContext) {
		// Reuse the cached token with a fresh nonce. A Randomize failure is not
		// fatal: fall through and build a brand-new token instead.
		if err = t.getToken().Randomize(cached, auth.LocalContext); err == nil {
			return cached, nil
		}
	}

	claims := &tokens.ConnectionClaims{
		LCL: auth.LocalContext,
		EK:  auth.LocalServiceContext,
		T:   context.Identity(),
		CT:  context.CompressedTags(),
		ID:  context.ManagementID(),
	}

	token, err = t.getToken().CreateAndSign(false, claims, auth.LocalContext, claimsheader.NewClaimsHeader())
	if err != nil {
		// Propagate the signing error so callers never mistake an empty token
		// for success.
		return []byte{}, err
	}

	context.UpdateCachedTokenAndServiceContext(token, auth.LocalServiceContext)
	return token, nil
}
// CreateSynAckPacketToken creates the authentication token sent in SynAck
// packets.
//
// The claims include the remote context and remote context ID recorded in auth
// for the peer, so the token is specific to this exchange and cannot be cached.
// claimsHeader is passed through to the token engine unchanged.
//
// If signing fails, an empty, non-nil slice and the signing error are returned.
func (t *tokenAccessor) CreateSynAckPacketToken(context *pucontext.PUContext, auth *connection.AuthInfo, claimsHeader *claimsheader.ClaimsHeader) (token []byte, err error) {
	claims := &tokens.ConnectionClaims{
		T:        context.Identity(),
		CT:       context.CompressedTags(),
		LCL:      auth.LocalContext,
		RMT:      auth.RemoteContext,
		EK:       auth.LocalServiceContext,
		ID:       context.ManagementID(),
		RemoteID: auth.RemoteContextID,
	}

	token, err = t.getToken().CreateAndSign(false, claims, auth.LocalContext, claimsHeader)
	if err != nil {
		// Propagate the signing error so callers never mistake an empty token
		// for success.
		return []byte{}, err
	}
	return token, nil
}
// ParsePacketToken verifies and decodes a packet token, using the same
// non-Ack mode flag (false) that CreateSynPacketToken and
// CreateSynAckPacketToken sign with.
//
// data is checked against auth.RemotePublicKey. The decoded claims must carry
// tags (claims.T) containing the transmitter label, which identifies the remote
// context. Only when all of that succeeds is auth updated with the peer's
// certificate, nonce, context ID and service context. Any failure returns a nil
// claims pointer and leaves auth untouched.
func (t *tokenAccessor) ParsePacketToken(auth *connection.AuthInfo, data []byte) (*tokens.ConnectionClaims, error) {
	engine := t.getToken()
	claims, nonce, cert, err := engine.Decode(false, data, auth.RemotePublicKey)

	switch {
	case err != nil:
		return nil, err
	case claims.T == nil:
		// A remote context ID is mandatory and it lives in the tags.
		return nil, errors.New("no claims found")
	}

	transmitterID, found := claims.T.Get(enforcerconstants.TransmitterLabel)
	if !found {
		return nil, errors.New("no transmitter label")
	}

	// Record the peer's state only after decoding and label lookup succeeded.
	auth.RemotePublicKey = cert
	auth.RemoteContext = nonce
	auth.RemoteContextID = transmitterID
	auth.RemoteServiceContext = claims.EK

	return claims, nil
}
// ParseAckToken verifies and decodes the token carried in an Ack packet.
//
// Ack tokens do not carry all of the connection's state context. This method
// decodes data in Ack mode (flag true, matching CreateAckPacketToken) against
// auth.RemotePublicKey. It then checks that the token's remote context (RMT)
// equals auth.LocalContext. It returns an error if no token engine is
// installed, if auth is nil, if decoding fails, or if the contexts do not match.
func (t *tokenAccessor) ParseAckToken(auth *connection.AuthInfo, data []byte) (*tokens.ConnectionClaims, error) {
	engine := t.getToken()
	if engine == nil {
		return nil, errors.New("token is nil")
	}
	if auth == nil {
		return nil, errors.New("auth is nil")
	}

	// Decode with the engine snapshot checked above. Calling getToken again
	// would re-acquire the lock and could pick up a different engine installed
	// by a concurrent SetToken.
	claims, _, _, err := engine.Decode(true, data, auth.RemotePublicKey)
	if err != nil {
		return nil, err
	}

	if !bytes.Equal(claims.RMT, auth.LocalContext) {
		return nil, errors.New("failed to match context in ack packet")
	}
	return claims, nil
}