diff --git a/internal/bootstrap/data/setting.go b/internal/bootstrap/data/setting.go index 4eeb7ce140d..14be12f27d9 100644 --- a/internal/bootstrap/data/setting.go +++ b/internal/bootstrap/data/setting.go @@ -165,6 +165,7 @@ func InitialSettings() []model.SettingItem { {Key: conf.SSOAutoRegister, Value: "false", Type: conf.TypeBool, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSODefaultDir, Value: "/", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSODefaultPermission, Value: "0", Type: conf.TypeNumber, Group: model.SSO, Flag: model.PRIVATE}, + {Key: conf.SSOCompatibilityMode, Value: "false", Type: conf.TypeBool, Group: model.SSO, Flag: model.PUBLIC}, // qbittorrent settings {Key: conf.QbittorrentUrl, Value: "http://admin:adminadmin@localhost:8080/", Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE}, diff --git a/internal/conf/const.go b/internal/conf/const.go index 1cda3e322dc..2876bdefcd5 100644 --- a/internal/conf/const.go +++ b/internal/conf/const.go @@ -69,6 +69,7 @@ const ( SSOAutoRegister = "sso_auto_register" SSODefaultDir = "sso_default_dir" SSODefaultPermission = "sso_default_permission" + SSOCompatibilityMode = "sso_compatibility_mode" // qbittorrent QbittorrentUrl = "qbittorrent_url" diff --git a/server/handles/ssologin.go b/server/handles/ssologin.go index f7e85807925..236ebf67418 100644 --- a/server/handles/ssologin.go +++ b/server/handles/ssologin.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "net/url" + "path" "strings" "time" @@ -36,71 +37,85 @@ var opts = totp.ValidateOpts{ func SSOLoginRedirect(c *gin.Context) { method := c.Query("method") + usecompatibility := setting.GetBool(conf.SSOCompatibilityMode) enabled := setting.GetBool(conf.SSOLoginEnabled) clientId := setting.GetStr(conf.SSOClientId) platform := setting.GetStr(conf.SSOLoginPlatform) var r_url string var redirect_uri string - if enabled { - urlValues := url.Values{} - if method == "" { - common.ErrorStrResp(c, "no method provided", 400) - return - } + if !enabled { + common.ErrorStrResp(c, "Single sign-on is not enabled", 403) + return + } + urlValues := url.Values{} + if method == "" { + common.ErrorStrResp(c, "no method provided", 400) + return + } + if usecompatibility { + redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + method + } else { redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + method + } + urlValues.Add("response_type", "code") + urlValues.Add("redirect_uri", redirect_uri) + urlValues.Add("client_id", clientId) + switch platform { + case "Github": + r_url = "https://github.com/login/oauth/authorize?" + urlValues.Add("scope", "read:user") + case "Microsoft": + r_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?" + urlValues.Add("scope", "user.read") + urlValues.Add("response_mode", "query") + case "Google": + r_url = "https://accounts.google.com/o/oauth2/v2/auth?" + urlValues.Add("scope", "https://www.googleapis.com/auth/userinfo.profile") + case "Dingtalk": + r_url = "https://login.dingtalk.com/oauth2/auth?" + urlValues.Add("scope", "openid") + urlValues.Add("prompt", "consent") urlValues.Add("response_type", "code") - urlValues.Add("redirect_uri", redirect_uri) - urlValues.Add("client_id", clientId) - switch platform { - case "Github": - r_url = "https://github.com/login/oauth/authorize?" - urlValues.Add("scope", "read:user") - case "Microsoft": - r_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?" - urlValues.Add("scope", "user.read") - urlValues.Add("response_mode", "query") - case "Google": - r_url = "https://accounts.google.com/o/oauth2/v2/auth?" - urlValues.Add("scope", "https://www.googleapis.com/auth/userinfo.profile") - case "Dingtalk": - r_url = "https://login.dingtalk.com/oauth2/auth?" - urlValues.Add("scope", "openid") - urlValues.Add("prompt", "consent") - urlValues.Add("response_type", "code") - case "Casdoor": - endpoint := strings.TrimSuffix(setting.GetStr(conf.SSOEndpointName), "/") - r_url = endpoint + "/login/oauth/authorize?" - urlValues.Add("scope", "profile") - urlValues.Add("state", endpoint) - case "OIDC": - oauth2Config, err := GetOIDCClient(c) - if err != nil { - common.ErrorStrResp(c, err.Error(), 400) - return - } - // generate state parameter - state, err := totp.GenerateCodeCustom(base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts) - if err != nil { - common.ErrorStrResp(c, err.Error(), 400) - return - } - c.Redirect(http.StatusFound, oauth2Config.AuthCodeURL(state)) + case "Casdoor": + endpoint := strings.TrimSuffix(setting.GetStr(conf.SSOEndpointName), "/") + r_url = endpoint + "/login/oauth/authorize?" + urlValues.Add("scope", "profile") + urlValues.Add("state", endpoint) + case "OIDC": + oauth2Config, err := GetOIDCClient(c) + if err != nil { + common.ErrorStrResp(c, err.Error(), 400) return - default: - common.ErrorStrResp(c, "invalid platform", 400) + } + // generate state parameter + state, err := totp.GenerateCodeCustom(base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts) + if err != nil { + common.ErrorStrResp(c, err.Error(), 400) return } - c.Redirect(302, r_url+urlValues.Encode()) - } else { - common.ErrorStrResp(c, "Single sign-on is not enabled", 403) + c.Redirect(http.StatusFound, oauth2Config.AuthCodeURL(state)) + return + default: + common.ErrorStrResp(c, "invalid platform", 400) + return } + c.Redirect(302, r_url+urlValues.Encode()) } var ssoClient = resty.New().SetRetryCount(3) func GetOIDCClient(c *gin.Context) (*oauth2.Config, error) { + var redirect_uri string + usecompatibility := setting.GetBool(conf.SSOCompatibilityMode) argument := c.Query("method") - redirect_uri := common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + argument + if usecompatibility { + argument = path.Base(c.Request.URL.Path) + } + if usecompatibility { + redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + argument + } else { + redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + argument + } endpoint := setting.GetStr(conf.SSOEndpointName) provider, err := oidc.NewProvider(c, endpoint) if err != nil { @@ -152,7 +167,11 @@ func autoRegister(username, userID string, err error) (*model.User, error) { } func OIDCLoginCallback(c *gin.Context) { + usecompatibility := setting.GetBool(conf.SSOCompatibilityMode) argument := c.Query("method") + if usecompatibility { + argument = path.Base(c.Request.URL.Path) + } clientId := setting.GetStr(conf.SSOClientId) endpoint := setting.GetStr(conf.SSOEndpointName) provider, err := oidc.NewProvider(c, endpoint) @@ -204,6 +223,10 @@ func OIDCLoginCallback(c *gin.Context) { } UserID := claims.Name if argument == "get_sso_id" { + if usecompatibility { + c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+UserID) + return + } html := fmt.Sprintf(` @@ -227,6 +250,10 @@ func OIDCLoginCallback(c *gin.Context) { if err != nil { common.ErrorResp(c, err, 400) } + if usecompatibility { + c.Redirect(302, common.GetApiUrl(c.Request)+"/@login?token="+token) + return + } html := fmt.Sprintf(` @@ -242,12 +269,18 @@ func OIDCLoginCallback(c *gin.Context) { func SSOLoginCallback(c *gin.Context) { enabled := setting.GetBool(conf.SSOLoginEnabled) + usecompatibility := setting.GetBool(conf.SSOCompatibilityMode) if !enabled { common.ErrorResp(c, errors.New("sso login is disabled"), 500) + return } argument := c.Query("method") + if usecompatibility { + argument = path.Base(c.Request.URL.Path) + } if !utils.SliceContains([]string{"get_sso_id", "sso_get_token"}, argument) { common.ErrorResp(c, errors.New("invalid request"), 500) + return } clientId := setting.GetStr(conf.SSOClientId) platform := setting.GetStr(conf.SSOLoginPlatform) @@ -317,12 +350,18 @@ func SSOLoginCallback(c *gin.Context) { }). Post(tokenUrl) } else { + var redirect_uri string + if usecompatibility { + redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + argument + } else { + redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + argument + } resp, err = ssoClient.R().SetHeader("Accept", "application/json"). SetFormData(map[string]string{ "client_id": clientId, "client_secret": clientSecret, "code": callbackCode, - "redirect_uri": common.GetApiUrl(c.Request) + "/api/auth/sso_callback?method=" + argument, + "redirect_uri": redirect_uri, "scope": scope, }).SetFormData(additionalForm).Post(tokenUrl) } @@ -349,6 +388,10 @@ func SSOLoginCallback(c *gin.Context) { return } if argument == "get_sso_id" { + if usecompatibility { + c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+userID) + return + } html := fmt.Sprintf(` @@ -373,6 +416,10 @@ func SSOLoginCallback(c *gin.Context) { if err != nil { common.ErrorResp(c, err, 400) } + if usecompatibility { + c.Redirect(302, common.GetApiUrl(c.Request)+"/@login?token="+token) + return + } html := fmt.Sprintf(` diff --git a/server/router.go b/server/router.go index 26e8e14724d..92ede88bfde 100644 --- a/server/router.go +++ b/server/router.go @@ -56,6 +56,8 @@ func Init(e *gin.Engine) { // auth api.GET("/auth/sso", handles.SSOLoginRedirect) api.GET("/auth/sso_callback", handles.SSOLoginCallback) + api.GET("/auth/get_sso_id", handles.SSOLoginCallback) + api.GET("/auth/sso_get_token", handles.SSOLoginCallback) //webauthn webauthn.GET("/webauthn_begin_registration", handles.BeginAuthnRegistration)