Skip to content

Commit

Permalink
Merge pull request #380 from bsx/azure-mfa
Browse files Browse the repository at this point in the history
refactor ADFS provider to support AzureMFA as well
  • Loading branch information
Mark Wolfe committed Dec 28, 2019
2 parents 8657d7b + 93aed57 commit 28ff558
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 89 deletions.
205 changes: 117 additions & 88 deletions pkg/provider/adfs/adfs.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"net/url"
"strings"
"time"

"github.com/PuerkitoBio/goquery"
"github.com/pkg/errors"
Expand All @@ -25,6 +26,15 @@ type Client struct {
idpAccount *cfg.IDPAccount
}

type AuthResponseType int

const (
UNKNOWN AuthResponseType = iota
SAML_RESPONSE
MFA_PROMPT
AZURE_MFA_WAIT
)

// New create a new ADFS client
func New(idpAccount *cfg.IDPAccount) (*Client, error) {

Expand All @@ -49,19 +59,17 @@ func (ac *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error)

var authSubmitURL string
var samlAssertion string
var instructions string

awsURN := url.QueryEscape(ac.idpAccount.AmazonWebservicesURN)

adfsURL := fmt.Sprintf("%s/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp=%s", loginDetails.URL, awsURN)

res, err := ac.client.Get(adfsURL)
if err != nil {
return samlAssertion, errors.Wrap(err, "error retrieving form")
}
mfaToken := loginDetails.MFAToken

doc, err := goquery.NewDocumentFromResponse(res)
doc, err := ac.get(adfsURL)
if err != nil {
return samlAssertion, errors.Wrap(err, "failed to build document from response")
return "", err
}

authForm := url.Values{}
Expand All @@ -82,33 +90,91 @@ func (ac *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error)
return samlAssertion, fmt.Errorf("unable to locate IDP authentication form submit URL")
}

//log.Printf("id authentication url: %s", authSubmitURL)
doc, err = ac.submit(authSubmitURL, authForm)

for {
responseType, samlAssertion, err := checkResponse(doc)

req, err := http.NewRequest("POST", authSubmitURL, strings.NewReader(authForm.Encode()))
switch responseType {
case SAML_RESPONSE:
return samlAssertion, err
case MFA_PROMPT:
otpForm := url.Values{}
if mfaToken == "" {
mfaToken = prompter.RequestSecurityCode("000000")
}

doc.Find("input").Each(func(i int, s *goquery.Selection) {
updateOTPFormData(otpForm, s, mfaToken)
})
doc, err = ac.submit(authSubmitURL, otpForm)
if err != nil {
return samlAssertion, errors.Wrap(err, "error retrieving mfa form results")
}
mfaToken = ""
case AZURE_MFA_WAIT:
azureForm := url.Values{}
doc.Find("input").Each(func(i int, s *goquery.Selection) {
updatePassthroughFormData(azureForm, s)
})
sel := doc.Find("p#instructions")
if sel.Index() != -1 {
if instructions != sel.Text() {
instructions = sel.Text()
fmt.Println(instructions)
}
}
time.Sleep(1 * time.Second)
doc, err = ac.submit(authSubmitURL, azureForm)
if err != nil {
return samlAssertion, errors.Wrap(err, "error retrieving mfa form results")
}
case UNKNOWN:
return samlAssertion, errors.New("unable to classify response from auth server")
}
}
}

func (ac *Client) get(url string) (*goquery.Document, error) {
res, err := ac.client.Get(url)
if err != nil {
return samlAssertion, errors.Wrap(err, "error building authentication request")
return nil, errors.Wrap(err, "error retieving form")

}
defer res.Body.Close()

req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
doc, err := goquery.NewDocumentFromReader(res.Body)
if err != nil {
return nil, errors.Wrap(err, "failed to build document from response")
}
return doc, nil
}

res, err = ac.client.Do(req)
func (ac *Client) submit(url string, form url.Values) (*goquery.Document, error) {
req, err := http.NewRequest("POST", url, strings.NewReader(form.Encode()))
if err != nil {
return samlAssertion, errors.Wrap(err, "error retrieving login form results")
return nil, errors.Wrap(err, "error building request")
}

switch ac.idpAccount.MFA {
case "VIP":
res, err = ac.vipMFA(authSubmitURL, loginDetails.MFAToken, res)
if err != nil {
return samlAssertion, errors.Wrap(err, "error retrieving mfa form results")
}
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

res, err := ac.client.Do(req)
if err != nil {
return nil, errors.Wrap(err, "error submitting form")

}
defer res.Body.Close()

// just parse the response whether res is from the login form or MFA form
doc, err = goquery.NewDocumentFromResponse(res)
doc, err := goquery.NewDocumentFromReader(res.Body)
if err != nil {
return samlAssertion, errors.Wrap(err, "error retrieving login response body")
return nil, errors.Wrap(err, "failed to build document from response")
}
return doc, nil
}

func checkResponse(doc *goquery.Document) (AuthResponseType, string, error) {
samlAssertion := ""
responseType := UNKNOWN

doc.Find("input").Each(func(i int, s *goquery.Selection) {
name, ok := s.Attr("name")
Expand All @@ -121,67 +187,26 @@ func (ac *Client) Authenticate(loginDetails *creds.LoginDetails) (string, error)
log.Fatalf("unable to locate saml assertion value")
}
samlAssertion = val
responseType = SAML_RESPONSE
}
})

return samlAssertion, nil
}

// vipMFA when supplied with the the form response document attempt to extract the VIP mfa related field
// then use that to trigger a submit of the MFA security token
func (ac *Client) vipMFA(authSubmitURL string, mfaToken string, res *http.Response) (*http.Response, error) {

doc, err := goquery.NewDocumentFromResponse(res)
if err != nil {
return nil, errors.Wrap(err, "error retrieving saml response body")
}

otpForm := url.Values{}

vipIndex := doc.Find("input#authMethod[value=VIPAuthenticationProviderWindowsAccountName]").Index()

if vipIndex == -1 {
return res, nil // if we didn't find the MFA flag then just continue
}

if mfaToken == "" {
mfaToken = prompter.RequestSecurityCode("000000")
}

doc.Find("input").Each(func(i int, s *goquery.Selection) {
updateOTPFormData(otpForm, s, mfaToken)
})

doc.Find("form").Each(func(i int, s *goquery.Selection) {
action, ok := s.Attr("action")
if !ok {
return
if name == "AuthMethod" {
val, _ := s.Attr("value")
if val == "VIPAuthenticationProviderWindowsAccountName" {
responseType = MFA_PROMPT
}
if val == "AzureMfaAuthentication" {
responseType = AZURE_MFA_WAIT
}
}
if name == "VerificationCode" {
responseType = MFA_PROMPT
}
authSubmitURL = action
})

if authSubmitURL == "" {
return nil, fmt.Errorf("unable to locate IDP MFA form submit URL")
}

req, err := http.NewRequest("POST", authSubmitURL, strings.NewReader(otpForm.Encode()))
if err != nil {
return nil, errors.Wrap(err, "error building MFA request")
}

req.Header.Add("Content-Type", "application/x-www-form-urlencoded")

res, err = ac.client.Do(req)
if err != nil {
return nil, errors.Wrap(err, "error retrieving content")
}

return res, nil
return responseType, samlAssertion, nil
}

func updateFormData(authForm url.Values, s *goquery.Selection, user *creds.LoginDetails) {
name, ok := s.Attr("name")
// log.Printf("name = %s ok = %v", name, ok)
if !ok {
return
}
Expand All @@ -203,31 +228,35 @@ func updateFormData(authForm url.Values, s *goquery.Selection, user *creds.Login
authForm.Add(name, user.Password)
}
} else {
// pass through any hidden fields
val, ok := s.Attr("value")
if !ok {
return
}
authForm.Add(name, val)
updatePassthroughFormData(authForm, s)
}
}

func updateOTPFormData(otpForm url.Values, s *goquery.Selection, token string) {
name, ok := s.Attr("name")
// log.Printf("name = %s ok = %v", name, ok)
if !ok {
return
}
lname := strings.ToLower(name)
if strings.Contains(lname, "security_code") {
otpForm.Add(name, token)
} else if strings.Contains(lname, "verificationcode") {
otpForm.Add(name, token)
} else {
// pass through any hidden fields
val, ok := s.Attr("value")
if !ok {
return
}
otpForm.Add(name, val)
updatePassthroughFormData(otpForm, s)
}

}

func updatePassthroughFormData(otpForm url.Values, s *goquery.Selection) {
name, ok := s.Attr("name")
if !ok {
return
}
val, ok := s.Attr("value")
if !ok {
return
}
otpForm.Add(name, val)

}
2 changes: 1 addition & 1 deletion saml2aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ type ProviderList map[string][]string
// MFAsByProvider a list of providers with their respective supported MFAs
var MFAsByProvider = ProviderList{
"AzureAD": []string{"Auto", "PhoneAppOTP", "PhoneAppNotification", "OneWaySMS"},
"ADFS": []string{"Auto", "VIP"},
"ADFS": []string{"Auto", "VIP", "Azure"},
"ADFS2": []string{"Auto", "RSA"}, // nothing automatic about ADFS 2.x
"Ping": []string{"Auto"}, // automatically detects PingID
"PingOne": []string{"Auto"}, // automatically detects PingID
Expand Down

0 comments on commit 28ff558

Please sign in to comment.