1
1
package handles
2
2
3
3
import (
4
- "encoding/base32"
5
4
"encoding/base64"
6
5
"errors"
7
6
"fmt"
7
+ "github.com/Xhofe/go-cache"
8
8
"net/http"
9
9
"net/url"
10
10
"path"
@@ -21,19 +21,28 @@ import (
21
21
"github.com/coreos/go-oidc"
22
22
"github.com/gin-gonic/gin"
23
23
"github.com/go-resty/resty/v2"
24
- "github.com/pquerna/otp"
25
- "github.com/pquerna/otp/totp"
26
24
"golang.org/x/oauth2"
27
25
"gorm.io/gorm"
28
26
)
29
27
30
- var opts = totp.ValidateOpts {
31
- // state verify won't expire in 30 secs, which is quite enough for the callback
32
- Period : 30 ,
33
- Skew : 1 ,
34
- // in some OIDC providers(such as Authelia), state parameter must be at least 8 characters
35
- Digits : otp .DigitsEight ,
36
- Algorithm : otp .AlgorithmSHA1 ,
28
+ const stateLength = 16
29
+ const stateExpire = time .Minute * 5
30
+
31
+ var stateCache = cache.NewMemCache [string ](cache.WithShards [string ](stateLength ))
32
+
33
+ func _keyState (clientID , state string ) string {
34
+ return fmt .Sprintf ("%s_%s" , clientID , state )
35
+ }
36
+
37
+ func generateState (clientID , ip string ) string {
38
+ state := random .String (stateLength )
39
+ stateCache .Set (_keyState (clientID , state ), ip , cache.WithEx [string ](stateExpire ))
40
+ return state
41
+ }
42
+
43
+ func verifyState (clientID , ip , state string ) bool {
44
+ value , ok := stateCache .Get (_keyState (clientID , state ))
45
+ return ok && value == ip
37
46
}
38
47
39
48
func ssoRedirectUri (c * gin.Context , useCompatibility bool , method string ) string {
@@ -91,12 +100,7 @@ func SSOLoginRedirect(c *gin.Context) {
91
100
common .ErrorStrResp (c , err .Error (), 400 )
92
101
return
93
102
}
94
- // generate state parameter
95
- state , err := totp .GenerateCodeCustom (base32 .StdEncoding .EncodeToString ([]byte (oauth2Config .ClientSecret )), time .Now (), opts )
96
- if err != nil {
97
- common .ErrorStrResp (c , err .Error (), 400 )
98
- return
99
- }
103
+ state := generateState (clientId , c .ClientIP ())
100
104
c .Redirect (http .StatusFound , oauth2Config .AuthCodeURL (state ))
101
105
return
102
106
default :
@@ -192,13 +196,7 @@ func OIDCLoginCallback(c *gin.Context) {
192
196
common .ErrorResp (c , err , 400 )
193
197
return
194
198
}
195
- // add state verify process
196
- stateVerification , err := totp .ValidateCustom (c .Query ("state" ), base32 .StdEncoding .EncodeToString ([]byte (oauth2Config .ClientSecret )), time .Now (), opts )
197
- if err != nil {
198
- common .ErrorResp (c , err , 400 )
199
- return
200
- }
201
- if ! stateVerification {
199
+ if ! verifyState (clientId , c .ClientIP (), c .Query ("state" )) {
202
200
common .ErrorStrResp (c , "incorrect or expired state parameter" , 400 )
203
201
return
204
202
}
0 commit comments