Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require TLS protected endpoints for key and SAS authentication #21832

Merged
merged 1 commit into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/azcore/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* Fixed an issue that could cause some allowed HTTP header values to not show up in logs.
* Include error text instead of error type in traces when the transport returns an error.
* Fixed an issue that could cause an HTTP/2 request to hang when the TCP connection becomes unresponsive.
* Block key and SAS authentication for non TLS protected endpoints.

### Other Changes

Expand Down
11 changes: 9 additions & 2 deletions sdk/azcore/runtime/policy_bearer_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ func (b *BearerTokenPolicy) authenticateAndAuthorize(req *policy.Request) func(p

// Do authorizes a request with a bearer token
func (b *BearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) {
if strings.ToLower(req.Raw().URL.Scheme) != "https" {
return nil, shared.NonRetriableError(errors.New("bearer token authentication is not permitted for non TLS protected (https) endpoints"))
if err := checkHTTPSForAuth(req); err != nil {
return nil, err
jhendrixMSFT marked this conversation as resolved.
Show resolved Hide resolved
}
var err error
if b.authzHandler.OnRequest != nil {
Expand Down Expand Up @@ -103,3 +103,10 @@ func (b *BearerTokenPolicy) Do(req *policy.Request) (*http.Response, error) {
}
return res, err
}

func checkHTTPSForAuth(req *policy.Request) error {
if strings.ToLower(req.Raw().URL.Scheme) != "https" {
return shared.NonRetriableError(errors.New("authenticated requests are not permitted for non TLS protected (https) endpoints"))
}
return nil
}
9 changes: 9 additions & 0 deletions sdk/azcore/runtime/policy_bearer_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,12 @@ func TestBearerTokenPolicy_RequiresHTTPS(t *testing.T) {
var nre errorinfo.NonRetriable
require.ErrorAs(t, err, &nre)
}

func TestCheckHTTPSForAuth(t *testing.T) {
req, err := NewRequest(context.Background(), http.MethodGet, "http://contoso.com")
require.NoError(t, err)
require.Error(t, checkHTTPSForAuth(req))
req, err = NewRequest(context.Background(), http.MethodGet, "https://contoso.com")
require.NoError(t, err)
require.NoError(t, checkHTTPSForAuth(req))
}
3 changes: 3 additions & 0 deletions sdk/azcore/runtime/policy_key_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ func NewKeyCredentialPolicy(cred *exported.KeyCredential, header string, options

// Do implementes the Do method on the [policy.Polilcy] interface.
func (k *KeyCredentialPolicy) Do(req *policy.Request) (*http.Response, error) {
if err := checkHTTPSForAuth(req); err != nil {
return nil, err
}
val := exported.KeyCredentialGet(k.cred)
if k.prefix != "" {
val = k.prefix + val
Expand Down
21 changes: 19 additions & 2 deletions sdk/azcore/runtime/policy_key_credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func TestKeyCredentialPolicy(t *testing.T) {
return &http.Response{}, nil
}), policy)

req, err := NewRequest(context.Background(), http.MethodGet, "http://contoso.com")
req, err := NewRequest(context.Background(), http.MethodGet, "https://contoso.com")
require.NoError(t, err)

_, err = pl.Do(req)
Expand All @@ -42,9 +42,26 @@ func TestKeyCredentialPolicy(t *testing.T) {
return &http.Response{}, nil
}), policy)

req, err = NewRequest(context.Background(), http.MethodGet, "http://contoso.com")
req, err = NewRequest(context.Background(), http.MethodGet, "https://contoso.com")
require.NoError(t, err)

_, err = pl.Do(req)
require.NoError(t, err)
}

func TestKeyCredentialPolicy_RequiresHTTPS(t *testing.T) {
cred := exported.NewKeyCredential("foo")

policy := NewKeyCredentialPolicy(cred, "fake-auth", nil)
require.NotNil(t, policy)

pl := exported.NewPipeline(shared.TransportFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{}, nil
}), policy)

req, err := NewRequest(context.Background(), http.MethodGet, "http://contoso.com")
require.NoError(t, err)

_, err = pl.Do(req)
require.Error(t, err)
}
3 changes: 3 additions & 0 deletions sdk/azcore/runtime/policy_sas_credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ func NewSASCredentialPolicy(cred *exported.SASCredential, header string, options

// Do implementes the Do method on the [policy.Polilcy] interface.
func (k *SASCredentialPolicy) Do(req *policy.Request) (*http.Response, error) {
if err := checkHTTPSForAuth(req); err != nil {
return nil, err
}
req.Raw().Header.Add(k.header, exported.SASCredentialGet(k.cred))
return req.Next()
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,26 @@ func TestSASCredentialPolicy(t *testing.T) {
return &http.Response{}, nil
}), policy)

req, err := NewRequest(context.Background(), http.MethodGet, "http://contoso.com")
req, err := NewRequest(context.Background(), http.MethodGet, "https://contoso.com")
require.NoError(t, err)

_, err = pl.Do(req)
require.NoError(t, err)
}

func TestSASCredentialPolicy_RequiresHTTPS(t *testing.T) {
cred := exported.NewSASCredential("foo")

policy := NewSASCredentialPolicy(cred, "fake-auth", nil)
require.NotNil(t, policy)

pl := exported.NewPipeline(shared.TransportFunc(func(req *http.Request) (*http.Response, error) {
return &http.Response{}, nil
}), policy)

req, err := NewRequest(context.Background(), http.MethodGet, "http://contoso.com")
require.NoError(t, err)

_, err = pl.Do(req)
require.Error(t, err)
}