diff --git a/autorest/adal/go.mod b/autorest/adal/go.mod index f5114e422..ad9d847ee 100644 --- a/autorest/adal/go.mod +++ b/autorest/adal/go.mod @@ -9,6 +9,7 @@ require ( github.com/Azure/go-autorest/logger v0.2.1 github.com/Azure/go-autorest/tracing v0.6.0 github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/stretchr/testify v1.8.2 golang.org/x/crypto v0.6.0 ) diff --git a/autorest/adal/go.sum b/autorest/adal/go.sum index f910e3f9d..28eaaa4a5 100644 --- a/autorest/adal/go.sum +++ b/autorest/adal/go.sum @@ -8,8 +8,20 @@ github.com/Azure/go-autorest/logger v0.2.1 h1:IG7i4p/mDa2Ce4TRyAO8IHnVhAVF3RFU+Z github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8= github.com/Azure/go-autorest/tracing v0.6.0 h1:TYi4+3m5t6K48TGI9AUdb+IzbnSxvnvUMfuitfgcfuo= github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= @@ -39,3 +51,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/autorest/adal/token.go b/autorest/adal/token.go index c90209a94..2a24ab80c 100644 --- a/autorest/adal/token.go +++ b/autorest/adal/token.go @@ -127,6 +127,9 @@ type TokenRefreshCallback func(Token) error // TokenRefresh is a type representing a custom callback to refresh a token type TokenRefresh func(ctx context.Context, resource string) (*Token, error) +// JWTCallback is the type representing callback that will be called to get the federated OIDC JWT +type JWTCallback func() (string, error) + // Token encapsulates the access token used to authorize Azure requests. // https://docs.microsoft.com/en-us/azure/active-directory/develop/v1-oauth2-client-creds-grant-flow#service-to-service-access-token-response type Token struct { @@ -367,14 +370,18 @@ func (secret ServicePrincipalAuthorizationCodeSecret) MarshalJSON() ([]byte, err // ServicePrincipalFederatedSecret implements ServicePrincipalSecret for Federated JWTs. type ServicePrincipalFederatedSecret struct { - jwt string + jwtCallback JWTCallback } // SetAuthenticationValues is a method of the interface ServicePrincipalSecret. // It will populate the form submitted during OAuth Token Acquisition using a JWT signed by an OIDC issuer. -func (secret *ServicePrincipalFederatedSecret) SetAuthenticationValues(spt *ServicePrincipalToken, v *url.Values) error { +func (secret *ServicePrincipalFederatedSecret) SetAuthenticationValues(_ *ServicePrincipalToken, v *url.Values) error { + jwt, err := secret.jwtCallback() + if err != nil { + return err + } - v.Set("client_assertion", secret.jwt) + v.Set("client_assertion", jwt) v.Set("client_assertion_type", "urn:ietf:params:oauth:client-assertion-type:jwt-bearer") return nil } @@ -687,6 +694,8 @@ func NewServicePrincipalTokenFromAuthorizationCode(oauthConfig OAuthConfig, clie } // NewServicePrincipalTokenFromFederatedToken creates a ServicePrincipalToken from the supplied federated OIDC JWT. +// +// Deprecated: Use NewServicePrincipalTokenFromFederatedTokenWithCallback to refresh jwt dynamically. func NewServicePrincipalTokenFromFederatedToken(oauthConfig OAuthConfig, clientID string, jwt string, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { if err := validateOAuthConfig(oauthConfig); err != nil { return nil, err @@ -700,12 +709,37 @@ func NewServicePrincipalTokenFromFederatedToken(oauthConfig OAuthConfig, clientI if jwt == "" { return nil, fmt.Errorf("parameter 'jwt' cannot be empty") } + return NewServicePrincipalTokenFromFederatedTokenCallback( + oauthConfig, + clientID, + func() (string, error) { + return jwt, nil + }, + resource, + callbacks..., + ) +} + +// NewServicePrincipalTokenFromFederatedTokenCallback creates a ServicePrincipalToken from the supplied federated OIDC JWTCallback. +func NewServicePrincipalTokenFromFederatedTokenCallback(oauthConfig OAuthConfig, clientID string, jwtCallback JWTCallback, resource string, callbacks ...TokenRefreshCallback) (*ServicePrincipalToken, error) { + if err := validateOAuthConfig(oauthConfig); err != nil { + return nil, err + } + if err := validateStringParam(clientID, "clientID"); err != nil { + return nil, err + } + if err := validateStringParam(resource, "resource"); err != nil { + return nil, err + } + if jwtCallback == nil { + return nil, fmt.Errorf("parameter 'jwtCallback' cannot be empty") + } return NewServicePrincipalTokenWithSecret( oauthConfig, clientID, resource, &ServicePrincipalFederatedSecret{ - jwt: jwt, + jwtCallback: jwtCallback, }, callbacks..., ) diff --git a/autorest/adal/token_test.go b/autorest/adal/token_test.go index 4c9ecf342..728a0f153 100644 --- a/autorest/adal/token_test.go +++ b/autorest/adal/token_test.go @@ -25,8 +25,10 @@ import ( "io/ioutil" "math/big" "net/http" + "net/http/httptest" "net/url" "os" + "path/filepath" "reflect" "strconv" "strings" @@ -37,6 +39,7 @@ import ( "github.com/Azure/go-autorest/autorest/date" "github.com/Azure/go-autorest/autorest/mocks" jwt "github.com/golang-jwt/jwt/v4" + "github.com/stretchr/testify/assert" ) const ( @@ -206,6 +209,62 @@ func TestServicePrincipalTokenRefreshUsesCustomRefreshFunc(t *testing.T) { } } +func TestFederatedTokenRefreshUsesJwtCallback(t *testing.T) { + baseDir, err := os.MkdirTemp("", "") + assert.NoError(t, err) + jwtFile := filepath.Join(baseDir, "token") + + jwtCallback := func() (string, error) { + jwt, err := os.ReadFile(jwtFile) + if err != nil { + return "", fmt.Errorf("failed to read a file with a federated token: %w", err) + } + return string(jwt), nil + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + jwt := r.FormValue("client_assertion") + refreshToken := r.FormValue("refresh_token") + + if jwt == "aaa.aaa" { + w.Write([]byte(`{"access_token":"A","expires_in":"3600"}`)) + } else if jwt == "bbb.bbb" { + w.Write([]byte(`{"access_token":"B","expires_in":"3600","refresh_token":"R"}`)) + } else if refreshToken == "R" { + w.Write([]byte(`{"access_token":"C","expires_in":"3600"}`)) + } else { + w.WriteHeader(http.StatusBadRequest) + } + })) + + spt := newServicePrincipalTokenFederatedJwtCallback(t, jwtCallback, server.URL) + + // token file does not exist, no such file error + err = spt.refreshInternal(context.Background(), "") + assert.ErrorIs(t, err, os.ErrNotExist) + + // get jwt token from jwtFile + err = os.WriteFile(jwtFile, []byte("aaa.aaa"), 0600) + assert.NoError(t, err) + err = spt.refreshInternal(context.Background(), "") + assert.NoError(t, err) + assert.Equal(t, "A", spt.inner.Token.AccessToken) + + // jwtFile is refreshed + err = os.WriteFile(jwtFile, []byte("bbb.bbb"), 0600) + assert.NoError(t, err) + err = spt.refreshInternal(context.Background(), "") + assert.NoError(t, err) + assert.Equal(t, "B", spt.inner.Token.AccessToken) + // refresh_token is set + assert.Equal(t, "R", spt.inner.Token.RefreshToken) + + // after refresh_token is set, the callback won't be called + err = spt.refreshInternal(context.Background(), "") + assert.NoError(t, err) + assert.Equal(t, "C", spt.inner.Token.AccessToken) +} + func TestServicePrincipalTokenRefreshUsesPOST(t *testing.T) { spt := newServicePrincipalToken() @@ -1602,3 +1661,9 @@ func newServicePrincipalTokenFederatedJwt(t *testing.T) *ServicePrincipalToken { spt, _ := NewServicePrincipalTokenFromFederatedToken(TestOAuthConfig, "id", signedString, "resource") return spt } + +func newServicePrincipalTokenFederatedJwtCallback(t *testing.T, callback JWTCallback, fakeEndpoint string) *ServicePrincipalToken { + outhConfig, _ := NewOAuthConfig(fakeEndpoint, TestTenantID) + spt, _ := NewServicePrincipalTokenFromFederatedTokenCallback(*outhConfig, "id", callback, "resource") + return spt +}