Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor ADFS provider to support AzureMFA as well #380

Merged
merged 4 commits into from
Dec 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
205 changes: 117 additions & 88 deletions pkg/provider/adfs/adfs.go
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"net/url"
"strings"
"time"

"github.com/PuerkitoBio/goquery"
"github.com/pkg/errors"
Expand All @@ -25,6 +26,15 @@ type Client struct {
idpAccount *cfg.IDPAccount
}

type AuthResponseType int

const (
UNKNOWN AuthResponseType = iota
SAML_RESPONSE
MFA_PROMPT
AZURE_MFA_WAIT
)

// New create a new ADFS client
func New(idpAccount *cfg.IDPAccount) (*Client, error) {

Expand All @@ -49,19 +59,17 @@ func (ac *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error)

var authSubmitURL string
var samlAssertion string
var instructions string

awsURN := url.QueryEscape(ac.idpAccount.AmazonWebservicesURN)

adfsURL := fmt.Sprintf("%s/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp=%s", loginDetails.URL, awsURN)

res, err := ac.client.Get(adfsURL)
if err != nil {
return samlAssertion, errors.Wrap(err, "error retrieving form")
}
mfaToken := loginDetails.MFAToken

doc, err := goquery.NewDocumentFromResponse(res)
doc, err := ac.get(adfsURL)
if err != nil {
return samlAssertion, errors.Wrap(err, "failed to build document from response")
return "", err
}

authForm := url.Values{}
Expand All @@ -82,33 +90,91 @@ func (ac *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error)
return samlAssertion, fmt.Errorf("unable to locate IDP authentication form submit URL")
}

//log.Printf("id authentication url: %s", authSubmitURL)
doc, err = ac.submit(authSubmitURL, authForm)

for {
responseType, samlAssertion, err := checkResponse(doc)

req, err := http.NewRequest("POST", authSubmitURL, strings.NewReader(authForm.Encode()))
switch responseType {
case SAML_RESPONSE:
return samlAssertion, err
case MFA_PROMPT:
otpForm := url.Values{}
if mfaToken == "" {
mfaToken = prompter.RequestSecurityCode("000000")
}

doc.Find("input").Each(func(i int, s *goquery.Selection) {
updateOTPFormData(otpForm, s, mfaToken)
})
doc, err = ac.submit(authSubmitURL, otpForm)
if err != nil {
return samlAssertion, errors.Wrap(err, "error retrieving mfa form results")
}
mfaToken = ""
case AZURE_MFA_WAIT:
azureForm := url.Values{}
doc.Find("input").Each(func(i int, s *goquery.Selection) {
updatePassthroughFormData(azureForm, s)
})
sel := doc.Find("p#instructions")
if sel.Index() != -1 {
if instructions != sel.Text() {
instructions = sel.Text()
fmt.Println(instructions)
}
}
time.Sleep(1 * time.Second)
doc, err = ac.submit(authSubmitURL, azureForm)
if err != nil {
return samlAssertion, errors.Wrap(err, "error retrieving mfa form results")
}
case UNKNOWN:
return samlAssertion, errors.New("unable to classify response from auth server")
}
}
}

func (ac *Client) get(url string) (*goquery.Document, error) {
res, err := ac.client.Get(url)
if err != nil {
return samlAssertion, errors.Wrap(err, "error building authentication request")
return nil, errors.Wrap(err, "error retieving form")

}
defer res.Body.Close()

req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
doc, err := goquery.NewDocumentFromReader(res.Body)
if err != nil {
return nil, errors.Wrap(err, "failed to build document from response")
}
return doc, nil
}

res, err = ac.client.Do(req)
func (ac *Client) submit(url string, form url.Values) (*goquery.Document, error) {
req, err := http.NewRequest("POST", url, strings.NewReader(form.Encode()))
if err != nil {
return samlAssertion, errors.Wrap(err, "error retrieving login form results")
return nil, errors.Wrap(err, "error building request")
}

switch ac.idpAccount.MFA {
case "VIP":
res, err = ac.vipMFA(authSubmitURL, loginDetails.MFAToken, res)
if err != nil {
return samlAssertion, errors.Wrap(err, "error retrieving mfa form results")
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

res, err := ac.client.Do(req)
if err != nil {
return nil, errors.Wrap(err, "error submitting form")

}
defer res.Body.Close()

// just parse the response whether res is from the login form or MFA form
doc, err = goquery.NewDocumentFromResponse(res)
doc, err := goquery.NewDocumentFromReader(res.Body)
if err != nil {
return samlAssertion, errors.Wrap(err, "error retrieving login response body")
return nil, errors.Wrap(err, "failed to build document from response")
}
return doc, nil
}

func checkResponse(doc *goquery.Document) (AuthResponseType, string, error) {
samlAssertion := ""
responseType := UNKNOWN

doc.Find("input").Each(func(i int, s *goquery.Selection) {
name, ok := s.Attr("name")
Expand All @@ -121,67 +187,26 @@ func (ac *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error)
log.Fatalf("unable to locate saml assertion value")
}
samlAssertion = val
responseType = SAML_RESPONSE
}
})

return samlAssertion, nil
}

// vipMFA when supplied with the the form response document attempt to extract the VIP mfa related field
// then use that to trigger a submit of the MFA security token
func (ac *Client) vipMFA(authSubmitURL string, mfaToken string, res *http.Response) (*http.Response, error) {

doc, err := goquery.NewDocumentFromResponse(res)
if err != nil {
return nil, errors.Wrap(err, "error retrieving saml response body")
}

otpForm := url.Values{}

vipIndex := doc.Find("input#authMethod[value=VIPAuthenticationProviderWindowsAccountName]").Index()

if vipIndex == -1 {
return res, nil // if we didn't find the MFA flag then just continue
}

if mfaToken == "" {
mfaToken = prompter.RequestSecurityCode("000000")
}

doc.Find("input").Each(func(i int, s *goquery.Selection) {
updateOTPFormData(otpForm, s, mfaToken)
})

doc.Find("form").Each(func(i int, s *goquery.Selection) {
action, ok := s.Attr("action")
if !ok {
return
if name == "AuthMethod" {
val, _ := s.Attr("value")
if val == "VIPAuthenticationProviderWindowsAccountName" {
responseType = MFA_PROMPT
}
if val == "AzureMfaAuthentication" {
responseType = AZURE_MFA_WAIT
}
}
if name == "VerificationCode" {
responseType = MFA_PROMPT
}
authSubmitURL = action
})

if authSubmitURL == "" {
return nil, fmt.Errorf("unable to locate IDP MFA form submit URL")
}

req, err := http.NewRequest("POST", authSubmitURL, strings.NewReader(otpForm.Encode()))
if err != nil {
return nil, errors.Wrap(err, "error building MFA request")
}

req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

res, err = ac.client.Do(req)
if err != nil {
return nil, errors.Wrap(err, "error retrieving content")
}

return res, nil
return responseType, samlAssertion, nil
}

func updateFormData(authForm url.Values, s *goquery.Selection, user *creds.LoginDetails) {
name, ok := s.Attr("name")
// log.Printf("name = %s ok = %v", name, ok)
if !ok {
return
}
Expand All @@ -203,31 +228,35 @@ func updateFormData(authForm url.Values, s *goquery.Selection, user *creds.Login
authForm.Add(name, user.Password)
}
} else {
// pass through any hidden fields
val, ok := s.Attr("value")
if !ok {
return
}
authForm.Add(name, val)
updatePassthroughFormData(authForm, s)
}
}

func updateOTPFormData(otpForm url.Values, s *goquery.Selection, token string) {
name, ok := s.Attr("name")
// log.Printf("name = %s ok = %v", name, ok)
if !ok {
return
}
lname := strings.ToLower(name)
if strings.Contains(lname, "security_code") {
otpForm.Add(name, token)
} else if strings.Contains(lname, "verificationcode") {
otpForm.Add(name, token)
} else {
// pass through any hidden fields
val, ok := s.Attr("value")
if !ok {
return
}
otpForm.Add(name, val)
updatePassthroughFormData(otpForm, s)
}

}

func updatePassthroughFormData(otpForm url.Values, s *goquery.Selection) {
name, ok := s.Attr("name")
if !ok {
return
}
val, ok := s.Attr("value")
if !ok {
return
}
otpForm.Add(name, val)

}
2 changes: 1 addition & 1 deletion saml2aws.go
Expand Up @@ -30,7 +30,7 @@ type ProviderList map[string][]string
// MFAsByProvider a list of providers with their respective supported MFAs
var MFAsByProvider = ProviderList{
"AzureAD": []string{"Auto", "PhoneAppOTP", "PhoneAppNotification", "OneWaySMS"},
"ADFS": []string{"Auto", "VIP"},
"ADFS": []string{"Auto", "VIP", "Azure"},
"ADFS2": []string{"Auto", "RSA"}, // nothing automatic about ADFS 2.x
"Ping": []string{"Auto"}, // automatically detects PingID
"PingOne": []string{"Auto"}, // automatically detects PingID
Expand Down