Skip to content

Commit 398c043

Browse files
authored
feat(sso): generate and verify OAuth state with go-cache (#7527)
1 parent 12b4295 commit 398c043

File tree

1 file changed

+21
-23
lines changed

1 file changed

+21
-23
lines changed

server/handles/ssologin.go

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
package handles
22

33
import (
4-
"encoding/base32"
54
"encoding/base64"
65
"errors"
76
"fmt"
7+
"github.com/Xhofe/go-cache"
88
"net/http"
99
"net/url"
1010
"path"
@@ -21,19 +21,28 @@ import (
2121
"github.com/coreos/go-oidc"
2222
"github.com/gin-gonic/gin"
2323
"github.com/go-resty/resty/v2"
24-
"github.com/pquerna/otp"
25-
"github.com/pquerna/otp/totp"
2624
"golang.org/x/oauth2"
2725
"gorm.io/gorm"
2826
)
2927

30-
var opts = totp.ValidateOpts{
31-
// state verify won't expire in 30 secs, which is quite enough for the callback
32-
Period: 30,
33-
Skew: 1,
34-
// in some OIDC providers(such as Authelia), state parameter must be at least 8 characters
35-
Digits: otp.DigitsEight,
36-
Algorithm: otp.AlgorithmSHA1,
28+
const stateLength = 16
29+
const stateExpire = time.Minute * 5
30+
31+
var stateCache = cache.NewMemCache[string](cache.WithShards[string](stateLength))
32+
33+
func _keyState(clientID, state string) string {
34+
return fmt.Sprintf("%s_%s", clientID, state)
35+
}
36+
37+
func generateState(clientID, ip string) string {
38+
state := random.String(stateLength)
39+
stateCache.Set(_keyState(clientID, state), ip, cache.WithEx[string](stateExpire))
40+
return state
41+
}
42+
43+
func verifyState(clientID, ip, state string) bool {
44+
value, ok := stateCache.Get(_keyState(clientID, state))
45+
return ok && value == ip
3746
}
3847

3948
func ssoRedirectUri(c *gin.Context, useCompatibility bool, method string) string {
@@ -91,12 +100,7 @@ func SSOLoginRedirect(c *gin.Context) {
91100
common.ErrorStrResp(c, err.Error(), 400)
92101
return
93102
}
94-
// generate state parameter
95-
state, err := totp.GenerateCodeCustom(base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts)
96-
if err != nil {
97-
common.ErrorStrResp(c, err.Error(), 400)
98-
return
99-
}
103+
state := generateState(clientId, c.ClientIP())
100104
c.Redirect(http.StatusFound, oauth2Config.AuthCodeURL(state))
101105
return
102106
default:
@@ -192,13 +196,7 @@ func OIDCLoginCallback(c *gin.Context) {
192196
common.ErrorResp(c, err, 400)
193197
return
194198
}
195-
// add state verify process
196-
stateVerification, err := totp.ValidateCustom(c.Query("state"), base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts)
197-
if err != nil {
198-
common.ErrorResp(c, err, 400)
199-
return
200-
}
201-
if !stateVerification {
199+
if !verifyState(clientId, c.ClientIP(), c.Query("state")) {
202200
common.ErrorStrResp(c, "incorrect or expired state parameter", 400)
203201
return
204202
}

0 commit comments

Comments
 (0)