forked from macewan-cs/lti
-
Notifications
You must be signed in to change notification settings - Fork 0
/
launch.go
336 lines (277 loc) · 10.7 KB
/
launch.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
// Copyright (c) 2021 MacEwan University. All rights reserved.
//
// This source code is licensed under the MIT-style license found in
// the LICENSE file in the root directory of this source tree.
// Package launch provides functions and methods for LTI's tool launch.
package launch
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"github.com/google/uuid"
"github.com/lestrrat-go/jwx/jwk"
"github.com/lestrrat-go/jwx/jwt"
"github.com/astiusa/lti/datastore"
"github.com/astiusa/lti/datastore/nonpersistent"
"github.com/astiusa/lti/login"
)
// A Launch implements an external application's role in the LTI specification's launch flow.
// Construct one with New; its ServeHTTP method validates an incoming launch request and, on success,
// forwards the request to the next handler.
type Launch struct {
	// cfg holds the datastores consulted while validating a launch (registrations, nonces) and the store that
	// receives the launch data. New replaces any nil store with nonpersistent.DefaultStore.
	cfg datastore.Config
	// next is called only after every launch validation has passed, with the launch ID placed in the request
	// context under ContextKey.
	next http.HandlerFunc
}
// ContextKeyType is the type of the request-context key under which the launch ID is stored. Using a dedicated
// named type keeps the key from colliding with plain string keys set by other packages.
type ContextKeyType string

// ContextKey is the request-context key whose value is the launch ID string assigned by ServeHTTP.
const ContextKey ContextKeyType = "LaunchID"
// Package-level settings used while validating and recording launches.
var (
	// supportedLTIVersion is the only value accepted in the LTI version claim.
	supportedLTIVersion = "1.3.0"
	// launchIDPrefix is prepended to a random UUID to form each stored launch ID.
	launchIDPrefix = "lti1p3-launch-"
	// maximumResourceLinkIDLength is the largest accepted len() of the resource link ID.
	maximumResourceLinkIDLength = 255
)
// New creates a *Launch, which implements the http.Handler interface for launching a tool.
//
// next is invoked after a launch passes validation. Any of cfg's LaunchData, Registrations, or Nonces stores
// that is nil is replaced with nonpersistent.DefaultStore before the Launch is built.
func New(cfg datastore.Config, next http.HandlerFunc) *Launch {
	if cfg.LaunchData == nil {
		cfg.LaunchData = nonpersistent.DefaultStore
	}
	if cfg.Registrations == nil {
		cfg.Registrations = nonpersistent.DefaultStore
	}
	if cfg.Nonces == nil {
		cfg.Nonces = nonpersistent.DefaultStore
	}

	return &Launch{next: next, cfg: cfg}
}
// ServeHTTP performs validations according the OIDC launch flow modified for use by the IMS Global LTI v1p3
// specifications. State is found in a user agent cookie and the POST body. Nonce is found embedded in the id_token and
// in a datastore.
func (l *Launch) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var (
rawToken []byte
statusCode int
err error
registration datastore.Registration
verifiedToken jwt.Token
launchData json.RawMessage
)
if rawToken, statusCode, err = getRawToken(r); err != nil {
http.Error(w, err.Error(), statusCode)
return
}
if registration, statusCode, err = validateRegistration(rawToken, l, r); err != nil {
http.Error(w, err.Error(), statusCode)
return
}
if verifiedToken, statusCode, err = validateSignature(rawToken, registration, r); err != nil {
http.Error(w, err.Error(), statusCode)
return
}
if statusCode, err = validateState(r); err != nil {
http.Error(w, err.Error(), statusCode)
return
}
if statusCode, err = validateClientID(verifiedToken, registration); err != nil {
http.Error(w, err.Error(), statusCode)
return
}
if statusCode, err = validateNonceAndTargetLinkURI(verifiedToken, l); err != nil {
http.Error(w, err.Error(), statusCode)
return
}
if statusCode, err = validateDeploymentID(verifiedToken, l); err != nil {
http.Error(w, err.Error(), statusCode)
return
}
if statusCode, err = validateVersionAndMessageType(verifiedToken); err != nil {
http.Error(w, err.Error(), statusCode)
return
}
if statusCode, err = validateResourceLink(verifiedToken); err != nil {
http.Error(w, err.Error(), statusCode)
return
}
if launchData, statusCode, err = getLaunchData(rawToken); err != nil {
http.Error(w, err.Error(), statusCode)
return
}
// Store the Launch data under a unique Launch ID for future reference.
launchID := launchIDPrefix + uuid.New().String()
l.cfg.LaunchData.StoreLaunchData(launchID, launchData)
// Put the launch ID in the request context for subsequent handlers.
r = r.WithContext(contextWithLaunchID(r.Context(), launchID))
l.next(w, r)
}
// getRawToken returns the OIDC id_token form value as bytes after confirming it parses as a JWT.
//
// No key set is supplied at this point, so the signature is not verified here; validateSignature does that later
// using the Platform's key set. A malformed token yields http.StatusBadRequest.
func getRawToken(r *http.Request) ([]byte, int, error) {
	raw := []byte(r.FormValue("id_token"))

	if _, parseErr := jwt.Parse(raw); parseErr != nil {
		return nil, http.StatusBadRequest, fmt.Errorf("get raw token: %w", parseErr)
	}

	return raw, http.StatusOK, nil
}
// validateRegistration finds the registration for the token's issuer (iss) and first audience (aud) entry.
//
// The token is parsed without signature verification, because the registration supplies the key set needed to
// verify it. It returns http.StatusBadRequest for a malformed token, a token with no audience, or an unknown
// issuer/client ID pair. Any other datastore failure yields http.StatusInternalServerError. The r parameter is
// currently unused.
func validateRegistration(rawToken []byte, l *Launch, r *http.Request) (datastore.Registration, int, error) {
	token, err := jwt.Parse(rawToken)
	if err != nil {
		return datastore.Registration{}, http.StatusBadRequest, fmt.Errorf("validate registration: %w", err)
	}

	issuer := token.Issuer()

	// The aud claim comes from an unverified, client-supplied token; indexing an empty audience would panic.
	audience := token.Audience()
	if len(audience) == 0 {
		return datastore.Registration{}, http.StatusBadRequest,
			errors.New("validate registration: audience (aud) claim not found in token")
	}
	clientID := audience[0]

	registration, err := l.cfg.Registrations.FindRegistrationByIssuerAndClientID(issuer, clientID)
	if err != nil {
		if errors.Is(err, datastore.ErrRegistrationNotFound) {
			return datastore.Registration{}, http.StatusBadRequest, fmt.Errorf("no registration found for iss %s", issuer)
		}
		return datastore.Registration{}, http.StatusInternalServerError, fmt.Errorf("validate registration: %w", err)
	}

	return registration, http.StatusOK, nil
}
// validateSignature checks the authenticity of the token and returns the verified token.
//
// The Platform's key set is fetched from the registration's KeysetURI using the request's context, so the fetch
// is abandoned if the client goes away or the server cancels the request. The token is then parsed against that
// key set with claim validation enabled, which rejects expired or not-yet-valid tokens (exp/iat/nbf), as the OIDC
// and LTI security rules require. A key set fetch failure yields http.StatusInternalServerError, because the URI
// comes from our own registration. A bad signature or failed claim validation yields http.StatusBadRequest.
func validateSignature(rawToken []byte, registration datastore.Registration, r *http.Request) (jwt.Token, int, error) {
	keyset, err := jwk.Fetch(r.Context(), registration.KeysetURI.String())
	if err != nil {
		return nil, http.StatusInternalServerError, fmt.Errorf("validate signature: %w", err)
	}

	// jwx v1 does not validate time-based claims unless asked; without WithValidate an expired id_token is accepted.
	verifiedToken, err := jwt.Parse(rawToken, jwt.WithKeySet(keyset), jwt.WithValidate(true))
	if err != nil {
		return nil, http.StatusBadRequest, fmt.Errorf("validate signature: %w", err)
	}

	return verifiedToken, http.StatusOK, nil
}
// validateState checks the state cookie against the state value returned by the Platform.
//
// The cookie named login.StateCookieName is preferred. Only when it is absent is login.LegacyStateCookieName
// consulted. A missing cookie or a value that differs from the "state" form value yields http.StatusBadRequest.
func validateState(r *http.Request) (int, error) {
	cookie, err := r.Cookie(login.StateCookieName)
	if errors.Is(err, http.ErrNoCookie) {
		cookie, err = r.Cookie(login.LegacyStateCookieName)
	}
	if err != nil {
		return http.StatusBadRequest, fmt.Errorf("cannot get cookie from request: %w", err)
	}

	if cookie.Value == r.FormValue("state") {
		return http.StatusOK, nil
	}
	return http.StatusBadRequest, errors.New("state validation failed")
}
// validateClientID checks that the registration's client ID appears among the verified token's audience (aud)
// values. It returns http.StatusBadRequest when it does not.
func validateClientID(verifiedToken jwt.Token, registration datastore.Registration) (int, error) {
	if contains(registration.ClientID, verifiedToken.Audience()) {
		return http.StatusOK, nil
	}
	return http.StatusBadRequest, errors.New("client ID not registered for this issuer")
}
// validateNonceAndTargetLinkURI verifies that the target link URI provided during the initial (login) auth request
// matches the one in the id_token. In the process it checks that the nonce exists and clears it, via
// Nonces.TestAndClearNonce.
//
// Missing or non-string claims, an unknown nonce, and a target link URI mismatch all yield http.StatusBadRequest.
// Any other datastore error yields http.StatusInternalServerError.
func validateNonceAndTargetLinkURI(verifiedToken jwt.Token, l *Launch) (int, error) {
	rawTargetLinkURI, ok := verifiedToken.Get("https://purl.imsglobal.org/spec/lti/claim/target_link_uri")
	if !ok {
		return http.StatusBadRequest, errors.New("target link URI not found in request")
	}
	// Claims are Platform-supplied JSON; a non-string value must be rejected rather than panic the handler.
	targetLinkURI, ok := rawTargetLinkURI.(string)
	if !ok {
		return http.StatusBadRequest, errors.New("target link URI improperly formatted")
	}

	rawNonce, ok := verifiedToken.Get("nonce")
	if !ok {
		return http.StatusBadRequest, errors.New("nonce not found in request")
	}
	nonce, ok := rawNonce.(string)
	if !ok {
		return http.StatusBadRequest, errors.New("nonce improperly formatted")
	}

	if err := l.cfg.Nonces.TestAndClearNonce(nonce, targetLinkURI); err != nil {
		if errors.Is(err, datastore.ErrNonceNotFound) || errors.Is(err, datastore.ErrNonceTargetLinkURIMismatch) {
			return http.StatusBadRequest, err
		}
		return http.StatusInternalServerError, err
	}

	return http.StatusOK, nil
}
// validateDeploymentID verifies that the token's deployment ID claim exists under the token's issuer.
//
// A missing or non-string claim, or an unknown deployment, yields http.StatusBadRequest. Any other datastore
// error yields http.StatusInternalServerError.
func validateDeploymentID(verifiedToken jwt.Token, l *Launch) (int, error) {
	rawDeploymentID, ok := verifiedToken.Get("https://purl.imsglobal.org/spec/lti/claim/deployment_id")
	if !ok {
		return http.StatusBadRequest, errors.New("deployment not found in request")
	}
	// Comma-ok assertion: a non-string claim from the Platform must produce 400, not a panic.
	deploymentID, ok := rawDeploymentID.(string)
	if !ok {
		return http.StatusBadRequest, errors.New("deployment ID improperly formatted")
	}

	if _, err := l.cfg.Registrations.FindDeployment(verifiedToken.Issuer(), deploymentID); err != nil {
		if errors.Is(err, datastore.ErrDeploymentNotFound) {
			return http.StatusBadRequest, err
		}
		return http.StatusInternalServerError, err
	}

	return http.StatusOK, nil
}
// validateVersionAndMessageType checks for a valid version and message type. Only the version in
// supportedLTIVersion and the 'Resource link launch request' message type (LtiResourceLinkRequest) are currently
// supported. Any missing, malformed, or unsupported value yields http.StatusBadRequest.
func validateVersionAndMessageType(verifiedToken jwt.Token) (int, error) {
	ltiVersion, ok := verifiedToken.Get("https://purl.imsglobal.org/spec/lti/claim/version")
	if !ok {
		return http.StatusBadRequest, errors.New("LTI version not found in request")
	}
	// Interface-to-string comparison is false for non-string values, so no assertion is needed here.
	if ltiVersion != supportedLTIVersion {
		return http.StatusBadRequest, errors.New("compatible version not found in request")
	}

	rawMessageType, ok := verifiedToken.Get("https://purl.imsglobal.org/spec/lti/claim/message_type")
	if !ok {
		return http.StatusBadRequest, errors.New("message type not found in request")
	}
	// Comma-ok assertion: a non-string claim from the Platform must produce 400, not a panic.
	messageType, ok := rawMessageType.(string)
	if !ok {
		return http.StatusBadRequest, errors.New("message type improperly formatted")
	}
	if messageType != "LtiResourceLinkRequest" {
		return http.StatusBadRequest, errors.New("supported message type not found in request")
	}

	return http.StatusOK, nil
}
// validateResourceLink verifies the resource link and ID.
//
// It checks three things: the resource link claim is present and is a JSON object; that object has a string
// "id"; and the id's len() does not exceed maximumResourceLinkIDLength. Any failure yields http.StatusBadRequest.
func validateResourceLink(verifiedToken jwt.Token) (int, error) {
	rawResourceLink, ok := verifiedToken.Get("https://purl.imsglobal.org/spec/lti/claim/resource_link")
	if !ok {
		return http.StatusBadRequest, errors.New("resource link not found in request")
	}
	resourceLink, ok := rawResourceLink.(map[string]interface{})
	if !ok {
		return http.StatusBadRequest, errors.New("resource link improperly formatted")
	}

	rawResourceLinkID, ok := resourceLink["id"]
	if !ok {
		return http.StatusBadRequest, errors.New("resource link ID not found")
	}
	// Comma-ok assertion: a non-string id from the Platform must produce 400, not a panic.
	resourceLinkID, ok := rawResourceLinkID.(string)
	if !ok {
		return http.StatusBadRequest, errors.New("resource link ID improperly formatted")
	}
	if len(resourceLinkID) > maximumResourceLinkIDLength {
		return http.StatusBadRequest, fmt.Errorf("resource link ID exceeds maximum length (%d)", maximumResourceLinkIDLength)
	}

	return http.StatusOK, nil
}
// getLaunchData extracts the JWT payload (the second dot-separated segment of the id_token) for storage.
//
// The payload segment is base64url-decoded without padding and must be valid JSON, because it is returned as a
// json.RawMessage. An empty token, a token without a payload segment, an undecodable segment, or a non-JSON
// payload all yield http.StatusBadRequest. No signature checking happens here.
func getLaunchData(rawToken []byte) (json.RawMessage, int, error) {
	if len(rawToken) == 0 {
		return nil, http.StatusBadRequest, errors.New("received empty raw token argument")
	}

	// Guard the index: without a "." there is no payload segment and rawTokenParts[1] would panic.
	rawTokenParts := strings.SplitN(string(rawToken), ".", 3)
	if len(rawTokenParts) < 2 {
		return nil, http.StatusBadRequest, errors.New("get launch data: token has no payload segment")
	}

	payload, err := base64.RawURLEncoding.DecodeString(rawTokenParts[1])
	if err != nil {
		return nil, http.StatusBadRequest, fmt.Errorf("get launch data: %w", err)
	}
	// A RawMessage holding invalid JSON would fail later when it is marshalled, so reject it up front.
	if !json.Valid(payload) {
		return nil, http.StatusBadRequest, errors.New("get launch data: payload is not valid JSON")
	}

	return json.RawMessage(payload), http.StatusOK, nil
}
// contains reports whether n is equal to at least one element of s. A nil or empty s never contains anything.
func contains(n string, s []string) bool {
	found := false
	for i := 0; i < len(s) && !found; i++ {
		found = s[i] == n
	}
	return found
}
// contextWithLaunchID returns a child of ctx that carries launchID under ContextKey, so later handlers can
// retrieve it from the request context.
func contextWithLaunchID(ctx context.Context, launchID string) context.Context {
	return context.WithValue(ctx, ContextKey, launchID)
}