-
Notifications
You must be signed in to change notification settings - Fork 0
/
mjwt.go
129 lines (111 loc) · 3.38 KB
/
mjwt.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package mjwt
import (
"encoding/json"
"github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"
"time"
)
var (
	// ErrClaimTypeMismatch is returned by BaseTypeClaims.Valid when the claim
	// type recorded in the token differs from the Type of the generic claims
	// struct the token is being decoded into.
	ErrClaimTypeMismatch = errors.New("claim type mismatch")
)
// wrapClaims builds a BaseTypeClaims wrapper around claims.
//
// The registered claims take their issuer from p and use the given subject,
// audience and ID. The token is valid from the current time (nbf and iat are
// both "now") until now+dur (exp). Before returning, init fills in ClaimType
// from claims.Type().
func wrapClaims[T Claims](p Signer, sub, id string, aud jwt.ClaimStrings, dur time.Duration, claims T) *BaseTypeClaims[T] {
	issuedAt := time.Now()
	expiresAt := issuedAt.Add(dur)

	registered := jwt.RegisteredClaims{
		ID:        id,
		Issuer:    p.Issuer(),
		Subject:   sub,
		Audience:  aud,
		IssuedAt:  jwt.NewNumericDate(issuedAt),
		NotBefore: jwt.NewNumericDate(issuedAt),
		ExpiresAt: jwt.NewNumericDate(expiresAt),
	}

	wrapped := &BaseTypeClaims[T]{
		RegisteredClaims: registered,
		Claims:           claims,
	}
	return wrapped.init()
}
// ExtractClaims verifies token with p and decodes its payload into a
// BaseTypeClaims[T].
//
// It returns the parsed token, the decoded claims and whatever error
// p.VerifyJwt reported. The claims value is returned even when err is
// non-nil, so callers must check err before trusting it.
func ExtractClaims[T Claims](p Verifier, token string) (*jwt.Token, BaseTypeClaims[T], error) {
	var claims BaseTypeClaims[T]
	parsed, err := p.VerifyJwt(token, &claims)
	return parsed, claims, err
}
type (
	// Claims is implemented by the generic claim structs carried inside a
	// BaseTypeClaims. On top of jwt.Claims it requires Type, which returns a
	// name identifying the concrete claim struct.
	Claims interface {
		Type() string
		jwt.Claims
	}

	// baseTypeClaim is a jwt.Claims that can also report the type name of the
	// generic claim struct it wraps.
	baseTypeClaim interface {
		InternalClaimType() string
		jwt.Claims
	}
)
// BaseTypeClaims combines the standard jwt.RegisteredClaims with a ClaimType
// name and a generic Claims payload of type T.
//
// Its MarshalJSON and UnmarshalJSON methods flatten all three into a single
// JSON object: the registered claims, the claim type under the "mct" key, and
// the fields of T.
type BaseTypeClaims[T Claims] struct {
	jwt.RegisteredClaims
	// ClaimType is the claim type name. init sets it from Claims.Type().
	// UnmarshalJSON reads it from the token's "mct" field, and Valid compares
	// it against Claims.Type().
	ClaimType string
	// Claims holds the application-specific claims.
	Claims T
}
// init stamps b.ClaimType with the type name reported by the wrapped generic
// claims. It returns b so it can be chained directly after construction.
func (b *BaseTypeClaims[T]) init() *BaseTypeClaims[T] {
	name := b.InternalClaimType()
	b.ClaimType = name
	return b
}
// Valid implements jwt.Claims. It returns an error when any of these hold:
//   - the claim type carried by the token (ClaimType) differs from the Type of
//     the generic claims struct T, in which case it returns
//     ErrClaimTypeMismatch;
//   - the registered claims fail validation (jwt v4 checks exp, iat and nbf);
//   - the generic claims' own Valid method returns an error.
//
// Declaring Valid on BaseTypeClaims shadows the promoted
// jwt.RegisteredClaims.Valid method. It must therefore be called explicitly
// here; otherwise expiry and not-before times would never be enforced when
// the token is parsed.
func (b *BaseTypeClaims[T]) Valid() error {
	if b.ClaimType != b.InternalClaimType() {
		return ErrClaimTypeMismatch
	}
	if err := b.RegisteredClaims.Valid(); err != nil {
		return err
	}
	return b.Claims.Valid()
}
// InternalClaimType reports the type name of the wrapped generic claims, as
// returned by their Type method.
func (b *BaseTypeClaims[T]) InternalClaimType() string {
	return b.Claims.Type()
}
// MarshalJSON serializes the registered claims, the claim type (under "mct")
// and the fields of the generic claims struct into one flat JSON object.
//
// The two halves are marshaled separately and then spliced together. The
// closing '}' of the first object is dropped, and the opening '{' of the
// second is replaced by ','. The generic claims must therefore marshal to a
// JSON object; any other JSON value (for example null from a nil pointer T)
// is reported as an error. An empty object contributes no fields.
func (b *BaseTypeClaims[T]) MarshalJSON() ([]byte, error) {
	// encode the registered claims together with the claim type
	head, err := json.Marshal(internalBaseTypeClaims{
		RegisteredClaims: b.RegisteredClaims,
		ClaimType:        b.InternalClaimType(),
	})
	if err != nil {
		return nil, err
	}

	// encode the generic claims struct
	tail, err := json.Marshal(b.Claims)
	if err != nil {
		return nil, err
	}

	// json.Marshal output is compact, so a JSON object starts with '{' and
	// ends with '}' with no surrounding whitespace.
	if len(tail) < 2 || tail[0] != '{' || tail[len(tail)-1] != '}' {
		return nil, errors.New("claims must marshal to a JSON object")
	}
	// "{}" adds no fields. Splicing it in would leave a trailing comma
	// ({...,"mct":"x",}) and produce invalid JSON.
	if len(tail) == 2 {
		return head, nil
	}

	// replace the leading '{' of the generic claims with ',' and append it to
	// the first object with its closing '}' removed
	tail[0] = ','
	return append(head[:len(head)-1], tail...), nil
}
// UnmarshalJSON decodes a flat JSON object such as the one produced by
// MarshalJSON. The same input is decoded twice:
//   - into internalBaseTypeClaims, for the registered claims and the "mct"
//     claim type;
//   - into a fresh T, for the generic claims.
//
// b is modified only after both decodes succeed.
func (b *BaseTypeClaims[T]) UnmarshalJSON(bytes []byte) error {
	var base internalBaseTypeClaims
	if err := json.Unmarshal(bytes, &base); err != nil {
		return err
	}

	var generic T
	if err := json.Unmarshal(bytes, &generic); err != nil {
		return err
	}

	b.RegisteredClaims = base.RegisteredClaims
	b.ClaimType = base.ClaimType
	b.Claims = generic
	return nil
}
// internalBaseTypeClaims is the JSON shape of the non-generic part of a
// BaseTypeClaims: the registered claims plus the claim type, serialized under
// the "mct" key. BaseTypeClaims.MarshalJSON and UnmarshalJSON use it to encode
// and decode those fields separately from the generic claims struct.
type internalBaseTypeClaims struct {
	jwt.RegisteredClaims
	// ClaimType holds the claim type name (see Claims.Type).
	ClaimType string `json:"mct"`
}