From 242063f3728cf82d9bab33d34e2123baa5fca9f3 Mon Sep 17 00:00:00 2001 From: Pavel Nikolov Date: Wed, 18 Jul 2018 10:32:59 +1000 Subject: [PATCH] Add support for OneLogin --- Gopkg.lock | 103 +++++++- cmd/saml2aws/commands/configure.go | 14 + cmd/saml2aws/commands/login.go | 4 +- cmd/saml2aws/main.go | 8 +- helper/credentials/saml.go | 11 +- input.go | 22 +- pkg/cfg/cfg.go | 22 +- pkg/creds/creds.go | 10 +- pkg/flags/flags.go | 12 + pkg/provider/onelogin/mock/provider.go | 30 +++ pkg/provider/onelogin/onelogin.go | 341 +++++++++++++++++++++++++ pkg/provider/onelogin/onelogin_test.go | 40 +++ saml2aws.go | 7 + saml2aws_test.go | 2 +- 14 files changed, 602 insertions(+), 24 deletions(-) create mode 100644 pkg/provider/onelogin/mock/provider.go create mode 100644 pkg/provider/onelogin/onelogin.go create mode 100644 pkg/provider/onelogin/onelogin_test.go diff --git a/Gopkg.lock b/Gopkg.lock index ba4cb3755..e85c9fd6b 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -2,51 +2,66 @@ [[projects]] + digest = "1:f8810373a82e746e36ae75d76e979562b19640986c2fc2429e44776c70bd1839" name = "github.com/AlecAivazis/survey" packages = ["."] + pruneopts = "UT" revision = "e752db451e07e09c7d7dc8cada807a44bdb0fd47" version = "v1.5.3" [[projects]] branch = "master" + digest = "1:bf641dbf159d010db3c6da01311806f3b85a97cf26da3299bfa690ee4caf9ae4" name = "github.com/Azure/go-ntlmssp" packages = ["."] + pruneopts = "UT" revision = "4b934ac9dad38d389d34f0b98d98b2467c422012" [[projects]] + digest = "1:a62f6ed230a8cd138a9efbe718e7d0b0294f139266f5f55cd942769a9aac8de2" name = "github.com/PuerkitoBio/goquery" packages = ["."] + pruneopts = "UT" revision = "dc2ec5c7ca4d9aae063b79b9f581dd3ea6afd2b2" version = "v1.4.1" [[projects]] + digest = "1:c06d9e11d955af78ac3bbb26bd02e01d2f61f689e1a3bce2ef6fb683ef8a7f2d" name = "github.com/alecthomas/kingpin" packages = ["."] + pruneopts = "UT" revision = "947dcec5ba9c011838740e680966fd7087a71d0d" version = "v2.2.6" [[projects]] branch = "master" + digest = "1:45a787c1adea69a03a5384865b307c7a72bb28bd5844bd57679d889a726a588b" name = "github.com/alecthomas/template" packages = [ ".", - "parse" + "parse", ] + pruneopts = "UT" revision = "a0175ee3bccc567396460bf5acd36800cb10c49c" [[projects]] branch = "master" + digest = "1:c198fdc381e898e8fb62b8eb62758195091c313ad18e52a3067366e1dda2fb3c" name = "github.com/alecthomas/units" packages = ["."] + pruneopts = "UT" revision = "2efee857e7cfd4f3d0138cc3cbb1b4966962b93a" [[projects]] + digest = "1:66b3310cf22cdc96c35ef84ede4f7b9b370971c4025f394c89a2638729653b11" name = "github.com/andybalholm/cascadia" packages = ["."] + pruneopts = "UT" revision = "901648c87902174f774fac311d7f176f8647bdaa" version = "v1.0.0" [[projects]] + digest = "1:572a558f325deb4f667adda03a0242c4e98a20ddb47b22a89f2b2ba8ddf7aeae" name = "github.com/aws/aws-sdk-go" packages = [ "aws", @@ -74,152 +89,196 @@ "private/protocol/query/queryutil", "private/protocol/rest", "private/protocol/xml/xmlutil", - "service/sts" + "service/sts", ] + pruneopts = "UT" revision = "bfc1a07cf158c30c41a3eefba8aae043d0bb5bff" version = "v1.14.8" [[projects]] + digest = "1:f25dfe4538112147d22e4c8437efbc2e216112bec37277534ffe146935064ea8" name = "github.com/beevik/etree" packages = ["."] + pruneopts = "UT" revision = "9d7e8feddccb4ed1b8afb54e368bd323d2ff652c" version = "v1.0.1" [[projects]] + digest = "1:889290ee5c1f1888baa7caa2b4cdfa8a6abcfb86dd772fe6470ad7925cc44bff" name = "github.com/briandowns/spinner" packages = ["."] + pruneopts = "UT" revision = "48dbb65d7bd5c74ab50d53d04c949f20e3d14944" version = "1.0" [[projects]] + digest = "1:586992e81213a853bfe5c102709c0c92020d21b386907ceae783f13bbe899ad7" name = "github.com/danieljoos/wincred" packages = ["."] + pruneopts = "UT" revision = "412b574fb496839b312a75fba146bd32a89001cf" version = "v1.0.1" [[projects]] + digest = "1:a2c1d0e43bd3baaa071d1b9ed72c27d78169b2b269f71c105ac4ba34b1be4a39" name = "github.com/davecgh/go-spew" packages = ["spew"] + pruneopts = "UT" revision = "346938d642f2ec3594ed81d874461961cd0faa76" version = "v1.1.0" [[projects]] + digest = "1:865079840386857c809b72ce300be7580cb50d3d3129ce11bf9aa6ca2bc1934a" name = "github.com/fatih/color" packages = ["."] + pruneopts = "UT" revision = "5b77d2a35fb0ede96d138fc9a99f5c9b6aef11b4" version = "v1.7.0" [[projects]] + digest = "1:fb46255681497314debedde38b64be32a75bae50bad107586c22f1662bf2d352" name = "github.com/go-ini/ini" packages = ["."] + pruneopts = "UT" revision = "06f5f3d67269ccec1fe5fe4134ba6e982984f7f5" version = "v1.37.0" [[projects]] + digest = "1:e22af8c7518e1eab6f2eab2b7d7558927f816262586cd6ed9f349c97a6c285c4" name = "github.com/jmespath/go-jmespath" packages = ["."] + pruneopts = "UT" revision = "0b12d6b5" [[projects]] + digest = "1:c658e84ad3916da105a761660dcaeb01e63416c8ec7bc62256a9b411a05fcd67" name = "github.com/mattn/go-colorable" packages = ["."] + pruneopts = "UT" revision = "167de6bfdfba052fa6b2d3664c8f5272e23c9072" version = "v0.0.9" [[projects]] + digest = "1:d4d17353dbd05cb52a2a52b7fe1771883b682806f68db442b436294926bbfafb" name = "github.com/mattn/go-isatty" packages = ["."] + pruneopts = "UT" revision = "0360b2af4f38e8d38c7fce2a9f4e702702d73a39" version = "v0.0.3" [[projects]] branch = "master" + digest = "1:2b32af4d2a529083275afc192d1067d8126b578c7a9613b26600e4df9c735155" name = "github.com/mgutz/ansi" packages = ["."] + pruneopts = "UT" revision = "9520e82c474b0a04dd04f8a40959027271bab992" [[projects]] branch = "master" + digest = "1:8eb17c2ec4df79193ae65b621cd1c0c4697db3bc317fe6afdc76d7f2746abd05" name = "github.com/mitchellh/go-homedir" packages = ["."] + pruneopts = "UT" revision = "3864e76763d94a6df2f9960b16a20a33da9f9a66" [[projects]] + digest = "1:40e195917a951a8bf867cd05de2a46aaf1806c50cf92eebf4c16f78cd196f747" name = "github.com/pkg/errors" packages = ["."] + pruneopts = "UT" revision = "645ef00459ed84a119197bfb8d8205042c6df63d" version = "v0.8.0" [[projects]] + digest = "1:0028cb19b2e4c3112225cd871870f2d9cf49b9b4276531f03438a88e94be86fe" name = "github.com/pmezard/go-difflib" packages = ["difflib"] + pruneopts = "UT" revision = "792786c7400a136282c1664665ae0a8db921c6c2" version = "v1.0.0" [[projects]] + digest = "1:9e9193aa51197513b3abcb108970d831fbcf40ef96aa845c4f03276e1fa316d2" name = "github.com/sirupsen/logrus" packages = ["."] + pruneopts = "UT" revision = "c155da19408a8799da419ed3eeb0cb5db0ad5dbc" version = "v1.0.5" [[projects]] + digest = "1:ac83cf90d08b63ad5f7e020ef480d319ae890c208f8524622a2f3136e2686b02" name = "github.com/stretchr/objx" packages = ["."] + pruneopts = "UT" revision = "477a77ecc69700c7cdeb1fa9e129548e1c1c393c" version = "v0.1.1" [[projects]] + digest = "1:0ce644ed4e959cb140cb8ece625650cdad11499671a00f5878ccd0c38c334010" name = "github.com/stretchr/testify" packages = [ "assert", "mock", - "require" + "require", ] + pruneopts = "UT" revision = "f35b8ab0b5a2cef36673838d662e249dd9c94686" version = "v1.2.2" [[projects]] + digest = "1:5035b1f8df9ce0a77ce87eb485cb5f10e285b6772c953ab9fcd69f418620a998" name = "github.com/tidwall/gjson" packages = ["."] + pruneopts = "UT" revision = "afaeb9562041a8018c74e006551143666aed08bf" version = "v1.1.1" [[projects]] branch = "master" + digest = "1:d3f968e2a2c9f8506ed44b01b605ade0176ba6cf73ff679073e77cfdef2c0d55" name = "github.com/tidwall/match" packages = ["."] + pruneopts = "UT" revision = "1731857f09b1f38450e2c12409748407822dc6be" [[projects]] branch = "master" + digest = "1:0db47756e69b5d1749fbfc501974538c6885bb9593b62c54736c9de035dbe057" name = "golang.org/x/crypto" packages = [ "md4", - "ssh/terminal" + "ssh/terminal", ] + pruneopts = "UT" revision = "a8fb68e7206f8c78be19b432c58eb52a6aa34462" [[projects]] branch = "master" + digest = "1:d4c427db92c1d8c0df6c3ba45ce63c10d967dbb31d8bf5a87e16f1714ae370d2" name = "golang.org/x/net" packages = [ "html", "html/atom", "idna", - "publicsuffix" + "publicsuffix", ] + pruneopts = "UT" revision = "db08ff08e8622530d9ed3a0e8ac279f6d4c02196" [[projects]] branch = "master" + digest = "1:63c79e21224f8c86558234dbadf41df6cd77d9312dc60326129200b84e32d1d6" name = "golang.org/x/sys" packages = [ "unix", - "windows" + "windows", ] + pruneopts = "UT" revision = "8014b7b116a67fea23fbb82cd834c9ad656ea44b" [[projects]] + digest = "1:7509ba4347d1f8de6ae9be8818b0cd1abc3deeffe28aeaf4be6d4b6b5178d9ca" name = "golang.org/x/text" packages = [ "collate", @@ -235,29 +294,55 @@ "unicode/bidi", "unicode/cldr", "unicode/norm", - "unicode/rangetable" + "unicode/rangetable", ] + pruneopts = "UT" revision = "f21a4dfb5e38f5895301dc265a8def02365cc3d0" version = "v0.3.0" [[projects]] + digest = "1:7a08a75e5e4fad22b3377922414fbb3449a4a7c868f9da41e8acb2daac40497d" name = "gopkg.in/AlecAivazis/survey.v1" packages = [ "core", - "terminal" + "terminal", ] + pruneopts = "UT" revision = "e752db451e07e09c7d7dc8cada807a44bdb0fd47" version = "v1.5.3" [[projects]] + digest = "1:fb46255681497314debedde38b64be32a75bae50bad107586c22f1662bf2d352" name = "gopkg.in/ini.v1" packages = ["."] + pruneopts = "UT" revision = "06f5f3d67269ccec1fe5fe4134ba6e982984f7f5" version = "v1.37.0" [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "c1e9a52c74db3154d2934ab91fcfaed070abe1044c208e51283f8e6cf510a5c7" + input-imports = [ + "github.com/AlecAivazis/survey", + "github.com/Azure/go-ntlmssp", + "github.com/PuerkitoBio/goquery", + "github.com/alecthomas/kingpin", + "github.com/aws/aws-sdk-go/aws", + "github.com/aws/aws-sdk-go/aws/awserr", + "github.com/aws/aws-sdk-go/aws/session", + "github.com/aws/aws-sdk-go/service/sts", + "github.com/beevik/etree", + "github.com/briandowns/spinner", + "github.com/danieljoos/wincred", + "github.com/mitchellh/go-homedir", + "github.com/pkg/errors", + "github.com/sirupsen/logrus", + "github.com/stretchr/testify/assert", + "github.com/stretchr/testify/mock", + "github.com/stretchr/testify/require", + "github.com/tidwall/gjson", + "golang.org/x/net/publicsuffix", + "gopkg.in/ini.v1", + ] solver-name = "gps-cdcl" solver-version = 1 diff --git a/cmd/saml2aws/commands/configure.go b/cmd/saml2aws/commands/configure.go index fdc46e393..63c46fb46 100644 --- a/cmd/saml2aws/commands/configure.go +++ b/cmd/saml2aws/commands/configure.go @@ -3,6 +3,7 @@ package commands import ( "fmt" "os" + "path" "github.com/pkg/errors" "github.com/versent/saml2aws" @@ -10,8 +11,12 @@ import ( "github.com/versent/saml2aws/pkg/cfg" "github.com/versent/saml2aws/pkg/flags" "github.com/versent/saml2aws/pkg/prompter" + "github.com/versent/saml2aws/pkg/provider/onelogin" ) +// OneLoginOAuthPath is the path used to generate OAuth token in order to access OneLogin's API. +const OneLoginOAuthPath = "/auth/oauth2/v2/token" + // Configure configure account profiles func Configure(configFlags *flags.CommonFlags) error { @@ -78,5 +83,14 @@ func storeCredentials(configFlags *flags.CommonFlags, account *cfg.IDPAccount) e fmt.Println("No password supplied") } } + if account.Provider == onelogin.ProviderName { + if configFlags.ClientID == "" || configFlags.ClientSecret == "" { + fmt.Println("OneLogin provider requires --client_id and --client_secret flags to be set.") + os.Exit(1) + } + if err := credentials.SaveCredentials(path.Join(account.URL, OneLoginOAuthPath), configFlags.ClientID, configFlags.ClientSecret); err != nil { + return errors.Wrap(err, "error storing client_id and client_secret in keychain") + } + } return nil } diff --git a/cmd/saml2aws/commands/login.go b/cmd/saml2aws/commands/login.go index c88492957..bfb414f57 100644 --- a/cmd/saml2aws/commands/login.go +++ b/cmd/saml2aws/commands/login.go @@ -128,7 +128,7 @@ func resolveLoginDetails(account *cfg.IDPAccount, loginFlags *flags.LoginExecFla fmt.Printf("Using IDP Account %s to access %s %s\n", loginFlags.CommonFlags.IdpAccount, account.Provider, account.URL) - err := credentials.LookupCredentials(loginDetails) + err := credentials.LookupCredentials(loginDetails, account.Provider) if err != nil { if !credentials.IsErrCredentialsNotFound(err) { return nil, errors.Wrap(err, "error loading saved password") @@ -154,7 +154,7 @@ func resolveLoginDetails(account *cfg.IDPAccount, loginFlags *flags.LoginExecFla return loginDetails, nil } - err = saml2aws.PromptForLoginDetails(loginDetails) + err = saml2aws.PromptForLoginDetails(loginDetails, account.Provider) if err != nil { return nil, errors.Wrap(err, "Error occurred accepting input") } diff --git a/cmd/saml2aws/main.go b/cmd/saml2aws/main.go index 1c8257fac..ffdf43eef 100644 --- a/cmd/saml2aws/main.go +++ b/cmd/saml2aws/main.go @@ -46,12 +46,12 @@ func main() { // Settings not related to commands verbose := app.Flag("verbose", "Enable verbose logging").Bool() - provider := app.Flag("provider", "This flag it is obsolete see https://github.com/Versent/saml2aws#adding-idp-accounts.").Short('i').Enum("ADFS", "ADFS2", "Ping", "JumpCloud", "Okta", "KeyCloak") + provider := app.Flag("provider", "This flag it is obsolete see https://github.com/Versent/saml2aws#adding-idp-accounts.").Short('i').Enum("ADFS", "ADFS2", "Ping", "JumpCloud", "Okta", "OneLogin", "KeyCloak") // Common (to all commands) settings commonFlags := new(flags.CommonFlags) app.Flag("idp-account", "The name of the configured IDP account").Short('a').Default("default").StringVar(&commonFlags.IdpAccount) - app.Flag("idp-provider", "The configured IDP provider").EnumVar(&commonFlags.IdpProvider, "ADFS", "ADFS2", "Ping", "JumpCloud", "Okta", "KeyCloak") + app.Flag("idp-provider", "The configured IDP provider").EnumVar(&commonFlags.IdpProvider, "ADFS", "ADFS2", "Ping", "JumpCloud", "Okta", "OneLogin", "KeyCloak") app.Flag("mfa", "The name of the mfa").EnumVar(&commonFlags.MFA, "Auto", "VIP") app.Flag("skip-verify", "Skip verification of server certificate.").Short('s').BoolVar(&commonFlags.SkipVerify) app.Flag("url", "The URL of the SAML IDP server used to login.").StringVar(&commonFlags.URL) @@ -65,6 +65,10 @@ func main() { // `configure` command and settings cmdConfigure := app.Command("configure", "Configure a new IDP account.") + cmdConfigure.Flag("app-id", "OneLogin app id required for SAML assertion.").Envar("ONELOGIN_APP_ID").StringVar(&commonFlags.AppID) + cmdConfigure.Flag("client-id", "OneLogin client id, used to generate API access token.").Envar("ONELOGIN_CLIENT_ID").StringVar(&commonFlags.ClientID) + cmdConfigure.Flag("client-secret", "OneLogin client secret, used to generate API access token.").Envar("ONELOGIN_CLIENT_SECRET").StringVar(&commonFlags.ClientSecret) + cmdConfigure.Flag("subdomain", "OneLogin subdomain of your company account.").Envar("ONELOGIN_SUBDOMAIN").StringVar(&commonFlags.Subdomain) configFlags := commonFlags // `login` command and settings diff --git a/helper/credentials/saml.go b/helper/credentials/saml.go index dd20a9c4c..9e986035d 100644 --- a/helper/credentials/saml.go +++ b/helper/credentials/saml.go @@ -2,12 +2,13 @@ package credentials import ( "fmt" + "path" "github.com/versent/saml2aws/pkg/creds" ) // LookupCredentials lookup an existing set of credentials and validate it. -func LookupCredentials(loginDetails *creds.LoginDetails) error { +func LookupCredentials(loginDetails *creds.LoginDetails, provider string) error { username, password, err := CurrentHelper.Get(fmt.Sprintf("%s", loginDetails.URL)) if err != nil { @@ -17,6 +18,14 @@ func LookupCredentials(loginDetails *creds.LoginDetails) error { loginDetails.Username = username loginDetails.Password = password + if provider == "OneLogin" { + id, secret, err := CurrentHelper.Get(path.Join(loginDetails.URL, "/auth/oauth2/v2/token")) + if err != nil { + return err + } + loginDetails.ClientID = id + loginDetails.ClientSecret = secret + } return nil } diff --git a/input.go b/input.go index 0f461520d..f12040c88 100644 --- a/input.go +++ b/input.go @@ -2,12 +2,12 @@ package saml2aws import ( "fmt" + "sort" "github.com/pkg/errors" "github.com/versent/saml2aws/pkg/cfg" "github.com/versent/saml2aws/pkg/creds" "github.com/versent/saml2aws/pkg/prompter" - "sort" ) // PromptForConfigurationDetails prompt the user to present their hostname, username and mfa @@ -44,11 +44,18 @@ func PromptForConfigurationDetails(idpAccount *cfg.IDPAccount) error { fmt.Println("") + if idpAccount.Provider == "OneLogin" { + idpAccount.AppID = prompter.String("App ID", idpAccount.AppID) + fmt.Println("") + idpAccount.Subdomain = prompter.String("Subdomain", idpAccount.Subdomain) + fmt.Println("") + } + return nil } // PromptForLoginDetails prompt the user to present their username, password -func PromptForLoginDetails(loginDetails *creds.LoginDetails) error { +func PromptForLoginDetails(loginDetails *creds.LoginDetails, provider string) error { fmt.Println("To use saved password just hit enter.") @@ -57,8 +64,17 @@ func PromptForLoginDetails(loginDetails *creds.LoginDetails) error { if enteredPassword := prompter.Password("Password"); enteredPassword != "" { loginDetails.Password = enteredPassword } - fmt.Println("") + if provider == "OneLogin" { + if enteredClientID := prompter.Password("Client ID"); enteredClientID != "" { + loginDetails.ClientID = enteredClientID + } + fmt.Println("") + if enteredCientSecret := prompter.Password("Client Secret"); enteredCientSecret != "" { + loginDetails.ClientSecret = enteredCientSecret + } + fmt.Println("") + } return nil } diff --git a/pkg/cfg/cfg.go b/pkg/cfg/cfg.go index ec1b52895..8962c89ad 100644 --- a/pkg/cfg/cfg.go +++ b/pkg/cfg/cfg.go @@ -31,6 +31,7 @@ const ( // IDPAccount saml IDP account type IDPAccount struct { + AppID string `ini:"app_id"` // used by OneLogin URL string `ini:"url"` Username string `ini:"username"` Provider string `ini:"provider"` @@ -40,10 +41,18 @@ type IDPAccount struct { AmazonWebservicesURN string `ini:"aws_urn"` SessionDuration int `ini:"aws_session_duration"` Profile string `ini:"aws_profile"` + Subdomain string `ini:"subdomain"` // used by OneLogin } func (ia IDPAccount) String() string { - return fmt.Sprintf(`account { + var appID string + if ia.Provider == "OneLogin" { + appID = fmt.Sprintf(` + AppID: %s + Subdomain: %s`, ia.AppID, ia.Subdomain) + } + + return fmt.Sprintf(`account {%s URL: %s Username: %s Provider: %s @@ -52,11 +61,20 @@ func (ia IDPAccount) String() string { AmazonWebservicesURN: %s SessionDuration: %d Profile: %s -}`, ia.URL, ia.Username, ia.Provider, ia.MFA, ia.SkipVerify, ia.AmazonWebservicesURN, ia.SessionDuration, ia.Profile) +}`, appID, ia.URL, ia.Username, ia.Provider, ia.MFA, ia.SkipVerify, ia.AmazonWebservicesURN, ia.SessionDuration, ia.Profile) } // Validate validate the required / expected fields are set func (ia *IDPAccount) Validate() error { + if ia.Provider == "OneLogin" { + if ia.AppID == "" { + return errors.New("app ID empty in idp account") + } + if ia.Subdomain == "" { + return errors.New("subdomain empty in idp account") + } + } + if ia.URL == "" { return errors.New("URL empty in idp account") } diff --git a/pkg/creds/creds.go b/pkg/creds/creds.go index 45120f0f5..83a358325 100644 --- a/pkg/creds/creds.go +++ b/pkg/creds/creds.go @@ -4,10 +4,12 @@ import "errors" // LoginDetails used to authenticate type LoginDetails struct { - Username string - Password string - MFAToken string - URL string + ClientID string // used by OneLogin + ClientSecret string // used by OneLogin + Username string + Password string + MFAToken string + URL string } // Validate validate the login details diff --git a/pkg/flags/flags.go b/pkg/flags/flags.go index eccc01eb5..eb0ab6109 100644 --- a/pkg/flags/flags.go +++ b/pkg/flags/flags.go @@ -6,6 +6,9 @@ import ( // CommonFlags flags common to all of the `saml2aws` commands (except `help`) type CommonFlags struct { + AppID string + ClientID string + ClientSecret string IdpAccount string IdpProvider string MFA string @@ -19,6 +22,7 @@ type CommonFlags struct { SkipPrompt bool SkipVerify bool Profile string + Subdomain string } // RoleSupplied role arn has been passed as a flag @@ -33,6 +37,10 @@ type LoginExecFlags struct { // ApplyFlagOverrides overrides IDPAccount with command line settings func ApplyFlagOverrides(commonFlags *CommonFlags, account *cfg.IDPAccount) { + if commonFlags.AppID != "" { + account.AppID = commonFlags.AppID + } + if commonFlags.URL != "" { account.URL = commonFlags.URL } @@ -64,4 +72,8 @@ func ApplyFlagOverrides(commonFlags *CommonFlags, account *cfg.IDPAccount) { if commonFlags.Profile != "" { account.Profile = commonFlags.Profile } + + if commonFlags.Subdomain != "" { + account.Subdomain = commonFlags.Subdomain + } } diff --git a/pkg/provider/onelogin/mock/provider.go b/pkg/provider/onelogin/mock/provider.go new file mode 100644 index 000000000..9a7fceced --- /dev/null +++ b/pkg/provider/onelogin/mock/provider.go @@ -0,0 +1,30 @@ +package mock + +import ( + "net/http" + "net/http/httptest" + "testing" +) + +// New returns an instance of the mock OneLogin indetity provider. +func New(t *testing.T, requests []ExpectedRequest) *httptest.Server { + h := mockHandler(t, requests) + return httptest.NewServer(h) +} + +// ExpectedRequest represents a request that the mock identity provider expects and its predefined response. +type ExpectedRequest struct { + reqBody []byte + reqHeaders http.Header + reqMethod string + reqPath string + + resBody []byte + resHeaders http.Header + resStatus int +} + +func mockHandler(t *testing.T, requests []ExpectedRequest) http.Handler { + // WIP + return http.NotFoundHandler() +} diff --git a/pkg/provider/onelogin/onelogin.go b/pkg/provider/onelogin/onelogin.go new file mode 100644 index 000000000..6a90a3a5c --- /dev/null +++ b/pkg/provider/onelogin/onelogin.go @@ -0,0 +1,341 @@ +package onelogin + +import ( + "bytes" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "strings" + "time" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/versent/saml2aws/pkg/cfg" + "github.com/versent/saml2aws/pkg/creds" + "github.com/versent/saml2aws/pkg/prompter" + "github.com/versent/saml2aws/pkg/provider" +) + +// MFA identifier constants. +const ( + IdentifierOneLoginProtectMfa = "OneLogin Protect" + IdentifierSmsMfa = "OneLogin SMS" + IdentifierTotpMfa = "Google Authenticator" + + MessageMFARequired = "MFA is required for this user" + MessageSuccess = "Success" + TypePending = "pending" + TypeSuccess = "success" +) + +// ProviderName constant holds the name of the OneLogin IDP. +const ProviderName = "OneLogin" + +var logger = logrus.WithField("provider", ProviderName) + +var ( + supportedMfaOptions = map[string]string{ + IdentifierOneLoginProtectMfa: "OLP MFA authentication", + IdentifierSmsMfa: "SMS MFA authentication", + IdentifierTotpMfa: "TOTP MFA authentication", + } +) + +// Client is a wrapper representing a OneLogin SAML client. +type Client struct { + // AppID represents the OneLogin connector id. + AppID string + // Client is the HTTP client for accessing the IDP provider's APIs. + Client *provider.HTTPClient + // Subdomain is the organisation subdomain in OneLogin. + Subdomain string +} + +// AuthRequest represents an mfa OneLogin request. +type AuthRequest struct { + AppID string `json:"app_id"` + Password string `json:"password"` + Subdomain string `json:"subdomain"` + Username string `json:"username_or_email"` + IPAddress string `json:"ip_address,omitempty"` +} + +// VerifyRequest represents an mfa verify request +type VerifyRequest struct { + AppID string `json:"app_id"` + DeviceID string `json:"device_id"` + OTPToken string `json:"otp_token,omitempty"` + StateToken string `json:"state_token"` +} + +// New creates a new OneLogin client. +func New(idpAccount *cfg.IDPAccount) (*Client, error) { + tr := provider.NewDefaultTransport(idpAccount.SkipVerify) + client, err := provider.NewHTTPClient(tr) + if err != nil { + return nil, errors.Wrap(err, "error building http client") + } + // Assign a response validator to ensure all responses are either success or a redirect. + // This is to avoid have explicit checks for every single response. + client.CheckResponseStatus = provider.SuccessOrRedirectResponseValidator + return &Client{AppID: idpAccount.AppID, Client: client, Subdomain: idpAccount.Subdomain}, nil +} + +// Authenticate logs into OneLogin and returns a SAML response. +func (c *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error) { + providerURL, err := url.Parse(loginDetails.URL) + if err != nil { + return "", errors.Wrap(err, "error building providerURL") + } + host := providerURL.Host + + logger.Debug("Generating OneLogin access token") + // request oAuth token required for working with OneLogin APIs + oauthToken, err := generateToken(c, loginDetails, host) + if err != nil { + return "", errors.Wrap(err, "failed to generate oauth token") + } + + logger.Debug("Retrieved OneLogin OAuth token:", oauthToken) + + authReq := AuthRequest{Username: loginDetails.Username, Password: loginDetails.Password, AppID: c.AppID, Subdomain: c.Subdomain} + var authBody bytes.Buffer + err = json.NewEncoder(&authBody).Encode(authReq) + if err != nil { + return "", errors.Wrap(err, "error encoding authreq") + } + + authSubmitURL := fmt.Sprintf("https://%s/api/1/saml_assertion", host) + + req, err := http.NewRequest("POST", authSubmitURL, &authBody) + if err != nil { + return "", errors.Wrap(err, "error building authentication request") + } + + addContentHeaders(req) + addAuthHeader(req, oauthToken) + + logger.Debug("Requesting SAML Assertion") + + // request the SAML assertion. For more details check https://developers.onelogin.com/api-docs/1/saml-assertions/generate-saml-assertion + res, err := c.Client.Do(req) + if err != nil { + return "", errors.Wrap(err, "error retrieving auth response") + } + defer res.Body.Close() + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return "", errors.Wrap(err, "error retrieving body from response") + } + + resp := string(body) + + logger.Debug("SAML Assertion response code:", res.StatusCode) + logger.Debug("SAML Assertion response body:", resp) + + authError := gjson.Get(resp, "status.error").Bool() + authMessage := gjson.Get(resp, "status.message").String() + authType := gjson.Get(resp, "status.type").String() + if authError || authType != TypeSuccess { + return "", errors.New(authMessage) + } + + authData := gjson.Get(resp, "data") + var samlAssertion string + switch authMessage { + // MFA not required + case MessageSuccess: + if authData.IsArray() { + return "", errors.New("invalid SAML assertion returned") + } + samlAssertion = authData.String() + case MessageMFARequired: + if !authData.IsArray() { + return "", errors.New("invalid MFA data returned") + } + logger.Debug("Verifying MFA") + samlAssertion, err = verifyMFA(c, oauthToken, c.AppID, resp) + if err != nil { + return "", errors.Wrap(err, "error verifying MFA") + } + default: + return "", errors.New("unexpected SAML assertion response") + } + + return samlAssertion, nil +} + +// generateToken is used to generate access token for all OneLogin APIs. +// For more infor read https://developers.onelogin.com/api-docs/1/oauth20-tokens/generate-tokens-2 +func generateToken(oc *Client, loginDetails *creds.LoginDetails, host string) (string, error) { + oauthTokenURL := fmt.Sprintf("https://%s/auth/oauth2/v2/token", host) + + req, err := http.NewRequest("POST", oauthTokenURL, strings.NewReader(`{"grant_type":"client_credentials"}`)) + if err != nil { + return "", errors.Wrap(err, "error building oauth token request") + } + + addContentHeaders(req) + req.SetBasicAuth(loginDetails.ClientID, loginDetails.ClientSecret) + + res, err := oc.Client.Do(req) + if err != nil { + return "", errors.Wrap(err, "error retrieving oauth token response") + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return "", errors.Wrap(err, "error reading oauth token response") + } + defer res.Body.Close() + + return gjson.Get(string(body), "access_token").String(), nil +} + +func addAuthHeader(r *http.Request, oauthToken string) { + r.Header.Add("Authorization", "bearer: "+oauthToken) +} + +func addContentHeaders(r *http.Request) { + r.Header.Add("Content-Type", "application/json") + r.Header.Add("Accept", "application/json") +} + +// verifyMFA is used to either prompt to user for one time password or request approval using push notification. +// For more details check https://developers.onelogin.com/api-docs/1/saml-assertions/verify-factor +func verifyMFA(oc *Client, oauthToken, appID, resp string) (string, error) { + stateToken := gjson.Get(resp, "data.0.state_token").String() + // choose an mfa option if there are multiple enabled + var option int + var mfaOptions []string + for _, id := range gjson.Get(resp, "data.0.devices.#.device_type").Array() { + identifier := id.String() + if val, ok := supportedMfaOptions[identifier]; ok { + mfaOptions = append(mfaOptions, val) + } else { + mfaOptions = append(mfaOptions, "UNSUPPORTED: "+identifier) + } + } + if len(mfaOptions) > 1 { + option = prompter.Choose("Select which MFA option to use", mfaOptions) + } + + factorID := gjson.Get(resp, fmt.Sprintf("data.0.devices.%d.device_id", option)).String() + callbackURL := gjson.Get(resp, "data.0.callback_url").String() + mfaIdentifer := gjson.Get(resp, fmt.Sprintf("data.0.devices.%d.device_type", option)).String() + mfaDeviceID := gjson.Get(resp, fmt.Sprintf("data.0.devices.%d.device_id", option)).String() + + logger.WithField("factorID", factorID).WithField("callbackURL", callbackURL).WithField("mfaIdentifer", mfaIdentifer).Debug("MFA") + + if _, ok := supportedMfaOptions[mfaIdentifer]; !ok { + return "", errors.New("unsupported mfa provider") + } + + // get signature & callback + verifyReq := VerifyRequest{AppID: appID, DeviceID: mfaDeviceID, StateToken: stateToken} + verifyBody := new(bytes.Buffer) + err := json.NewEncoder(verifyBody).Encode(verifyReq) + if err != nil { + return "", errors.Wrap(err, "error encoding verifyReq") + } + + req, err := http.NewRequest("POST", callbackURL, verifyBody) + if err != nil { + return "", errors.Wrap(err, "error building verify request") + } + + addContentHeaders(req) + addAuthHeader(req, oauthToken) + + res, err := oc.Client.Do(req) + if err != nil { + return "", errors.Wrap(err, "error retrieving verify response") + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return "", errors.Wrap(err, "error retrieving body from response") + } + resp = string(body) + + switch mfa := mfaIdentifer; mfa { + case IdentifierSmsMfa, IdentifierTotpMfa: + verifyCode := prompter.StringRequired("Enter verification code") + tokenReq := VerifyRequest{AppID: appID, DeviceID: mfaDeviceID, StateToken: stateToken, OTPToken: verifyCode} + tokenBody := new(bytes.Buffer) + json.NewEncoder(tokenBody).Encode(tokenReq) + + req, err = http.NewRequest("POST", callbackURL, tokenBody) + if err != nil { + return "", errors.Wrap(err, "error building token post request") + } + + addContentHeaders(req) + addAuthHeader(req, oauthToken) + + res, err := oc.Client.Do(req) + if err != nil { + return "", errors.Wrap(err, "error retrieving token post response") + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + return "", errors.Wrap(err, "error retrieving body from response") + } + + resp = string(body) + + return gjson.Get(resp, "data").String(), nil + + case IdentifierOneLoginProtectMfa: + fmt.Printf("\nWaiting for approval, please check your OneLogin Protect app ...") + + started := time.Now() + // loop until success, error, or timeout + for { + if time.Since(started) > time.Minute { + fmt.Println(" Timeout") + return "", errors.New("User did not accept MFA in time") + } + + logger.Debug("Verifying with OneLogin Protect") + res, err = oc.Client.Do(req) + if err != nil { + return "", errors.Wrap(err, "error retrieving verify response") + } + + body, err = ioutil.ReadAll(res.Body) + if err != nil { + return "", errors.Wrap(err, "error retrieving body from response") + } + + message := gjson.Get(string(body), "status.message").String() + + // on 'error' status + if gjson.Get(string(body), "status.error").Bool() { + return "", errors.New(message) + } + + switch gjson.Get(string(body), "status.type").String() { + case TypePending: + time.Sleep(time.Second) + fmt.Printf(".") + + case TypeSuccess: + fmt.Println(" Approved") + return gjson.Get(string(body), "data").String(), nil + + default: + fmt.Println(" Error:") + return "", errors.New("unsupported response from OneLogin, please raise ticket with saml2aws") + } + } + } + + // catch all + return "", errors.New("no mfa options provided") +} diff --git a/pkg/provider/onelogin/onelogin_test.go b/pkg/provider/onelogin/onelogin_test.go new file mode 100644 index 000000000..e54c685d3 --- /dev/null +++ b/pkg/provider/onelogin/onelogin_test.go @@ -0,0 +1,40 @@ +package onelogin_test + +import ( + "testing" + + "github.com/versent/saml2aws/pkg/creds" + "github.com/versent/saml2aws/pkg/provider" + "github.com/versent/saml2aws/pkg/provider/onelogin" +) + +func TestClient_Authenticate(t *testing.T) { + type fields struct { + client *provider.HTTPClient + } + type args struct { + loginDetails *creds.LoginDetails + } + tests := []struct { + name string + fields fields + args args + want string + wantErr bool + }{ + // TODO: Add test cases. + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + oc := &onelogin.Client{Client: tt.fields.client} + got, err := oc.Authenticate(tt.args.loginDetails) + if (err != nil) != tt.wantErr { + t.Errorf("Client.Authenticate() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Client.Authenticate() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/saml2aws.go b/saml2aws.go index bee5747b4..38a61f1e1 100644 --- a/saml2aws.go +++ b/saml2aws.go @@ -12,6 +12,7 @@ import ( "github.com/versent/saml2aws/pkg/provider/jumpcloud" "github.com/versent/saml2aws/pkg/provider/keycloak" "github.com/versent/saml2aws/pkg/provider/okta" + "github.com/versent/saml2aws/pkg/provider/onelogin" "github.com/versent/saml2aws/pkg/provider/pingfed" "github.com/versent/saml2aws/pkg/provider/pingone" ) @@ -27,6 +28,7 @@ var MFAsByProvider = ProviderList{ "PingOne": []string{"Auto"}, // automatically detects PingID "JumpCloud": []string{"Auto"}, "Okta": []string{"Auto"}, // automatically detects DUO, SMS and ToTP + "OneLogin": []string{"Auto"}, // automatically detects OneLogin Protect, SMS and ToTP "KeyCloak": []string{"Auto"}, // automatically detects ToTP "GoogleApps": []string{"Auto"}, // automatically detects ToTP } @@ -104,6 +106,11 @@ func NewSAMLClient(idpAccount *cfg.IDPAccount) (SAMLClient, error) { return nil, fmt.Errorf("Invalid MFA type: %v for %v provider", idpAccount.MFA, idpAccount.Provider) } return okta.New(idpAccount) + case "OneLogin": + if invalidMFA(idpAccount.Provider, idpAccount.MFA) { + return nil, fmt.Errorf("Invalid MFA type: %v for %v provider", idpAccount.MFA, idpAccount.Provider) + } + return onelogin.New(idpAccount) case "KeyCloak": if invalidMFA(idpAccount.Provider, idpAccount.MFA) { return nil, fmt.Errorf("Invalid MFA type: %v for %v provider", idpAccount.MFA, idpAccount.Provider) diff --git a/saml2aws_test.go b/saml2aws_test.go index 74d3f7cd8..c9f842ef9 100644 --- a/saml2aws_test.go +++ b/saml2aws_test.go @@ -10,7 +10,7 @@ func TestProviderList_Keys(t *testing.T) { names := MFAsByProvider.Names() - require.Len(t, names, 8) + require.Len(t, names, 9) }