Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: sso compatibility mode #5260

Merged
merged 1 commit into from
Sep 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions internal/bootstrap/data/setting.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ func InitialSettings() []model.SettingItem {
{Key: conf.SSOAutoRegister, Value: "false", Type: conf.TypeBool, Group: model.SSO, Flag: model.PRIVATE},
{Key: conf.SSODefaultDir, Value: "/", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE},
{Key: conf.SSODefaultPermission, Value: "0", Type: conf.TypeNumber, Group: model.SSO, Flag: model.PRIVATE},
{Key: conf.SSOCompatibilityMode, Value: "false", Type: conf.TypeBool, Group: model.SSO, Flag: model.PUBLIC},

// qbittorrent settings
{Key: conf.QbittorrentUrl, Value: "http://admin:adminadmin@localhost:8080/", Type: conf.TypeString, Group: model.SINGLE, Flag: model.PRIVATE},
Expand Down
1 change: 1 addition & 0 deletions internal/conf/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ const (
SSOAutoRegister = "sso_auto_register"
SSODefaultDir = "sso_default_dir"
SSODefaultPermission = "sso_default_permission"
SSOCompatibilityMode = "sso_compatibility_mode"

// qbittorrent
QbittorrentUrl = "qbittorrent_url"
Expand Down
145 changes: 96 additions & 49 deletions server/handles/ssologin.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"net/url"
"path"
"strings"
"time"

Expand Down Expand Up @@ -36,71 +37,85 @@ var opts = totp.ValidateOpts{

func SSOLoginRedirect(c *gin.Context) {
method := c.Query("method")
usecompatibility := setting.GetBool(conf.SSOCompatibilityMode)
enabled := setting.GetBool(conf.SSOLoginEnabled)
clientId := setting.GetStr(conf.SSOClientId)
platform := setting.GetStr(conf.SSOLoginPlatform)
var r_url string
var redirect_uri string
if enabled {
urlValues := url.Values{}
if method == "" {
common.ErrorStrResp(c, "no method provided", 400)
return
}
if !enabled {
common.ErrorStrResp(c, "Single sign-on is not enabled", 403)
return
}
urlValues := url.Values{}
if method == "" {
common.ErrorStrResp(c, "no method provided", 400)
return
}
if usecompatibility {
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + method
} else {
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + method
}
urlValues.Add("response_type", "code")
urlValues.Add("redirect_uri", redirect_uri)
urlValues.Add("client_id", clientId)
switch platform {
case "Github":
r_url = "https://github.com/login/oauth/authorize?"
urlValues.Add("scope", "read:user")
case "Microsoft":
r_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?"
urlValues.Add("scope", "user.read")
urlValues.Add("response_mode", "query")
case "Google":
r_url = "https://accounts.google.com/o/oauth2/v2/auth?"
urlValues.Add("scope", "https://www.googleapis.com/auth/userinfo.profile")
case "Dingtalk":
r_url = "https://login.dingtalk.com/oauth2/auth?"
urlValues.Add("scope", "openid")
urlValues.Add("prompt", "consent")
urlValues.Add("response_type", "code")
urlValues.Add("redirect_uri", redirect_uri)
urlValues.Add("client_id", clientId)
switch platform {
case "Github":
r_url = "https://github.com/login/oauth/authorize?"
urlValues.Add("scope", "read:user")
case "Microsoft":
r_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize?"
urlValues.Add("scope", "user.read")
urlValues.Add("response_mode", "query")
case "Google":
r_url = "https://accounts.google.com/o/oauth2/v2/auth?"
urlValues.Add("scope", "https://www.googleapis.com/auth/userinfo.profile")
case "Dingtalk":
r_url = "https://login.dingtalk.com/oauth2/auth?"
urlValues.Add("scope", "openid")
urlValues.Add("prompt", "consent")
urlValues.Add("response_type", "code")
case "Casdoor":
endpoint := strings.TrimSuffix(setting.GetStr(conf.SSOEndpointName), "/")
r_url = endpoint + "/login/oauth/authorize?"
urlValues.Add("scope", "profile")
urlValues.Add("state", endpoint)
case "OIDC":
oauth2Config, err := GetOIDCClient(c)
if err != nil {
common.ErrorStrResp(c, err.Error(), 400)
return
}
// generate state parameter
state, err := totp.GenerateCodeCustom(base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts)
if err != nil {
common.ErrorStrResp(c, err.Error(), 400)
return
}
c.Redirect(http.StatusFound, oauth2Config.AuthCodeURL(state))
case "Casdoor":
endpoint := strings.TrimSuffix(setting.GetStr(conf.SSOEndpointName), "/")
r_url = endpoint + "/login/oauth/authorize?"
urlValues.Add("scope", "profile")
urlValues.Add("state", endpoint)
case "OIDC":
oauth2Config, err := GetOIDCClient(c)
if err != nil {
common.ErrorStrResp(c, err.Error(), 400)
return
default:
common.ErrorStrResp(c, "invalid platform", 400)
}
// generate state parameter
state, err := totp.GenerateCodeCustom(base32.StdEncoding.EncodeToString([]byte(oauth2Config.ClientSecret)), time.Now(), opts)
if err != nil {
common.ErrorStrResp(c, err.Error(), 400)
return
}
c.Redirect(302, r_url+urlValues.Encode())
} else {
common.ErrorStrResp(c, "Single sign-on is not enabled", 403)
c.Redirect(http.StatusFound, oauth2Config.AuthCodeURL(state))
return
default:
common.ErrorStrResp(c, "invalid platform", 400)
return
}
c.Redirect(302, r_url+urlValues.Encode())
}

var ssoClient = resty.New().SetRetryCount(3)

func GetOIDCClient(c *gin.Context) (*oauth2.Config, error) {
var redirect_uri string
usecompatibility := setting.GetBool(conf.SSOCompatibilityMode)
argument := c.Query("method")
redirect_uri := common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + argument
if usecompatibility {
argument = path.Base(c.Request.URL.Path)
}
if usecompatibility {
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + argument
} else {
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + argument
}
endpoint := setting.GetStr(conf.SSOEndpointName)
provider, err := oidc.NewProvider(c, endpoint)
if err != nil {
Expand Down Expand Up @@ -152,7 +167,11 @@ func autoRegister(username, userID string, err error) (*model.User, error) {
}

func OIDCLoginCallback(c *gin.Context) {
usecompatibility := setting.GetBool(conf.SSOCompatibilityMode)
argument := c.Query("method")
if usecompatibility {
argument = path.Base(c.Request.URL.Path)
}
clientId := setting.GetStr(conf.SSOClientId)
endpoint := setting.GetStr(conf.SSOEndpointName)
provider, err := oidc.NewProvider(c, endpoint)
Expand Down Expand Up @@ -204,6 +223,10 @@ func OIDCLoginCallback(c *gin.Context) {
}
UserID := claims.Name
if argument == "get_sso_id" {
if usecompatibility {
c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+UserID)
return
}
html := fmt.Sprintf(`<!DOCTYPE html>
<head></head>
<body>
Expand All @@ -227,6 +250,10 @@ func OIDCLoginCallback(c *gin.Context) {
if err != nil {
common.ErrorResp(c, err, 400)
}
if usecompatibility {
c.Redirect(302, common.GetApiUrl(c.Request)+"/@login?token="+token)
return
}
html := fmt.Sprintf(`<!DOCTYPE html>
<head></head>
<body>
Expand All @@ -242,12 +269,18 @@ func OIDCLoginCallback(c *gin.Context) {

func SSOLoginCallback(c *gin.Context) {
enabled := setting.GetBool(conf.SSOLoginEnabled)
usecompatibility := setting.GetBool(conf.SSOCompatibilityMode)
if !enabled {
common.ErrorResp(c, errors.New("sso login is disabled"), 500)
return
}
argument := c.Query("method")
if usecompatibility {
argument = path.Base(c.Request.URL.Path)
}
if !utils.SliceContains([]string{"get_sso_id", "sso_get_token"}, argument) {
common.ErrorResp(c, errors.New("invalid request"), 500)
return
}
clientId := setting.GetStr(conf.SSOClientId)
platform := setting.GetStr(conf.SSOLoginPlatform)
Expand Down Expand Up @@ -317,12 +350,18 @@ func SSOLoginCallback(c *gin.Context) {
}).
Post(tokenUrl)
} else {
var redirect_uri string
if usecompatibility {
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/" + argument
} else {
redirect_uri = common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + argument
}
resp, err = ssoClient.R().SetHeader("Accept", "application/json").
SetFormData(map[string]string{
"client_id": clientId,
"client_secret": clientSecret,
"code": callbackCode,
"redirect_uri": common.GetApiUrl(c.Request) + "/api/auth/sso_callback?method=" + argument,
"redirect_uri": redirect_uri,
"scope": scope,
}).SetFormData(additionalForm).Post(tokenUrl)
}
Expand All @@ -349,6 +388,10 @@ func SSOLoginCallback(c *gin.Context) {
return
}
if argument == "get_sso_id" {
if usecompatibility {
c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+userID)
return
}
html := fmt.Sprintf(`<!DOCTYPE html>
<head></head>
<body>
Expand All @@ -373,6 +416,10 @@ func SSOLoginCallback(c *gin.Context) {
if err != nil {
common.ErrorResp(c, err, 400)
}
if usecompatibility {
c.Redirect(302, common.GetApiUrl(c.Request)+"/@login?token="+token)
return
}
html := fmt.Sprintf(`<!DOCTYPE html>
<head></head>
<body>
Expand Down
2 changes: 2 additions & 0 deletions server/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ func Init(e *gin.Engine) {
// auth
api.GET("/auth/sso", handles.SSOLoginRedirect)
api.GET("/auth/sso_callback", handles.SSOLoginCallback)
api.GET("/auth/get_sso_id", handles.SSOLoginCallback)
api.GET("/auth/sso_get_token", handles.SSOLoginCallback)

//webauthn
webauthn.GET("/webauthn_begin_registration", handles.BeginAuthnRegistration)
Expand Down