forked from InVisionApp/saml
/
session_jwt.go
155 lines (133 loc) · 4.05 KB
/
session_jwt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
package samlsp
import (
"crypto/rsa"
"errors"
"fmt"
"time"
"github.com/form3tech-oss/jwt-go"
"github.com/amboss-mededu/saml"
)
const (
// defaultSessionMaxAge is a one-hour session lifetime.
// NOTE(review): not referenced in this part of the file; presumably the
// default for JWTSessionCodec.MaxAge set elsewhere in the package — confirm.
defaultSessionMaxAge = time.Hour
// claimNameSessionIndex is the Attributes key under which New stores the
// SessionIndex of every AuthnStatement in the assertion.
claimNameSessionIndex = "SessionIndex"
)
// JWTSessionCodec implements SessionCodec to encode and decode Sessions from
// the corresponding JWT.
type JWTSessionCodec struct {
// SigningMethod is the JWT algorithm Encode signs with; Decode only accepts
// tokens whose "alg" equals SigningMethod.Alg().
SigningMethod jwt.SigningMethod
// Audience is written by New as the sole "aud" entry and is checked by
// Decode via VerifyAudience with the claim marked as required.
Audience string
// Issuer is written by New as "iss" and is checked by Decode via
// VerifyIssuer with the claim marked as required.
Issuer string
// MaxAge is the session lifetime: New sets "exp" to the issue time plus
// MaxAge. NOTE(review): a zero MaxAge makes "exp" equal to the issue time.
MaxAge time.Duration
// Key signs tokens in Encode; its public half verifies them in Decode.
Key *rsa.PrivateKey
}
// Compile-time check that JWTSessionCodec satisfies SessionCodec.
var _ SessionCodec = JWTSessionCodec{}
// New builds a Session from a SAML assertion.
//
// The returned Session is always a JWTSessionClaims. The claims are issued
// at saml.TimeNow(), are valid from that instant, and expire c.MaxAge later.
// The audience and issuer come from the codec. The subject is the assertion's
// NameID value when one is present.
//
// Attribute values are gathered under the attribute's FriendlyName, or under
// its Name when FriendlyName is empty. The SessionIndex of each
// AuthnStatement is then appended under the "SessionIndex" attribute.
// The returned error is always nil.
func (c JWTSessionCodec) New(assertion *saml.Assertion) (Session, error) {
	issuedAt := saml.TimeNow()

	attrs := make(map[string][]string)
	for _, stmt := range assertion.AttributeStatements {
		for _, attr := range stmt.Attributes {
			key := attr.Name
			if attr.FriendlyName != "" {
				key = attr.FriendlyName
			}
			for _, v := range attr.Values {
				val := v.Value
				// Some attributes (e.g. eduPersonTargetedID) carry their value in
				// a nested NameID instead of as plain text.
				if val == "" && v.NameID != nil {
					val = v.NameID.Value
				}
				attrs[key] = append(attrs[key], val)
			}
		}
	}

	// Record the session index(es) alongside the regular attributes.
	for _, authn := range assertion.AuthnStatements {
		attrs[claimNameSessionIndex] = append(attrs[claimNameSessionIndex], authn.SessionIndex)
	}

	var subject string
	if sub := assertion.Subject; sub != nil && sub.NameID != nil {
		subject = sub.NameID.Value
	}

	return JWTSessionClaims{
		StandardClaims: jwt.StandardClaims{
			Audience:  []string{c.Audience},
			Issuer:    c.Issuer,
			Subject:   subject,
			IssuedAt:  issuedAt.Unix(),
			NotBefore: issuedAt.Unix(),
			ExpiresAt: issuedAt.Add(c.MaxAge).Unix(),
		},
		Attributes:  attrs,
		SAMLSession: true,
	}, nil
}
// Encode signs the session and returns it as a compact JWT string.
//
// The session must be a JWTSessionClaims (as returned by New or Decode) or a
// non-nil *JWTSessionClaims. Any other type is reported as an error rather
// than causing a panic. The token is signed with c.SigningMethod using c.Key.
// A codec missing either of them also yields an error.
func (c JWTSessionCodec) Encode(s Session) (string, error) {
	var claims JWTSessionClaims
	switch v := s.(type) {
	case JWTSessionClaims:
		claims = v
	case *JWTSessionClaims:
		if v == nil {
			return "", errors.New("cannot encode a nil *JWTSessionClaims")
		}
		claims = *v
	default:
		return "", fmt.Errorf("cannot encode session of type %T: expected JWTSessionClaims", s)
	}

	// jwt.NewWithClaims and the RSA signer dereference these unconditionally,
	// so a misconfigured codec would otherwise panic.
	if c.SigningMethod == nil {
		return "", errors.New("no signing method configured")
	}
	if c.Key == nil {
		return "", errors.New("no signing key configured")
	}

	return jwt.NewWithClaims(c.SigningMethod, claims).SignedString(c.Key)
}
// Decode parses the serialized session that may have been returned by Encode
// and returns a Session (always a JWTSessionClaims on success).
//
// A token is accepted only if all of the following hold:
//   - its "alg" equals c.SigningMethod.Alg();
//   - its signature verifies against the public half of c.Key;
//   - the claims pass jwt-go's ParseWithClaims validation (the claims' Valid
//     method);
//   - its audience includes c.Audience;
//   - its issuer matches c.Issuer;
//   - its saml-session claim is true.
//
// Errors from jwt-go are returned unwrapped so callers can still inspect
// them by type. A codec with no signing method or key yields an error
// instead of panicking.
func (c JWTSessionCodec) Decode(signed string) (Session, error) {
	if c.SigningMethod == nil {
		return nil, errors.New("no signing method configured")
	}
	if c.Key == nil {
		return nil, errors.New("no signing key configured")
	}

	parser := jwt.Parser{
		// Pin the algorithm so a token cannot choose a weaker/different one.
		ValidMethods: []string{c.SigningMethod.Alg()},
	}
	claims := JWTSessionClaims{}
	_, err := parser.ParseWithClaims(signed, &claims, func(*jwt.Token) (interface{}, error) {
		return c.Key.Public(), nil
	})
	// TODO(ross): check for errors due to bad time and return ErrNoSession
	if err != nil {
		return nil, err
	}
	if !claims.VerifyAudience(c.Audience, true) {
		return nil, fmt.Errorf("expected audience %q, got %q", c.Audience, claims.Audience)
	}
	if !claims.VerifyIssuer(c.Issuer, true) {
		return nil, fmt.Errorf("expected issuer %q, got %q", c.Issuer, claims.Issuer)
	}
	if !claims.SAMLSession {
		return nil, errors.New("expected saml-session")
	}
	return claims, nil
}
// JWTSessionClaims represents the JWT claims in the encoded session.
//
// It embeds jwt.StandardClaims for the registered claims and adds the SAML
// attributes (JSON key "attr") and a session marker (JSON key "saml-session").
type JWTSessionClaims struct {
jwt.StandardClaims
// Attributes holds the attribute values collected by JWTSessionCodec.New,
// keyed by attribute FriendlyName (or Name when that is empty), plus any
// SessionIndex values under the "SessionIndex" key.
Attributes Attributes `json:"attr"`
// SAMLSession is set to true by JWTSessionCodec.New; JWTSessionCodec.Decode
// rejects tokens in which it is false.
SAMLSession bool `json:"saml-session"`
}
// Compile-time check that JWTSessionClaims satisfies Session.
var _ Session = JWTSessionClaims{}
// GetAttributes returns the SAML attributes stored in the session claims,
// satisfying SessionWithAttributes.
func (s JWTSessionClaims) GetAttributes() Attributes {
	attrs := s.Attributes
	return attrs
}
// Attributes is a map of attributes provided in the SAML assertion. Each key
// maps to all values recorded for that attribute.
type Attributes map[string][]string

// Get returns the first value of the attribute named key, or an empty string
// if the attribute is absent or has no values. It is safe to call on a nil
// Attributes.
func (a Attributes) Get(key string) string {
	// Indexing a nil map yields the zero value, so no separate nil check is
	// needed.
	if v := a[key]; len(v) > 0 {
		return v[0]
	}
	return ""
}

// GetAll returns every value of the attribute named key, or nil if the
// attribute is absent. A nil Attributes behaves like an empty one, so
// len(a.GetAll(key)) > 0 reliably reports presence.
//
// The returned slice is the one stored in the map, not a copy.
func (a Attributes) GetAll(key string) []string {
	return a[key]
}