Skip to content

Commit

Permalink
Merge pull request #55 from Azure/aaqib-m/rate-limited-client
Browse files Browse the repository at this point in the history
feat: client side rate limiting with IMDS
  • Loading branch information
aaqib-m committed Mar 26, 2024
2 parents 3c37f33 + 9cbb334 commit 95c7ea6
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 17 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/Azure/msi-acrpull

go 1.20
go 1.21

require (
github.com/go-logr/logr v1.2.4
Expand Down
45 changes: 45 additions & 0 deletions pkg/authorizer/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package authorizer

import (
"context"
"fmt"
"net/http"

"golang.org/x/time/rate"
)

const (
defaultRPS = 1
defaultBurst = 5
)

type rateLimitedClient struct {
httpClient *http.Client
rateLimiter *rate.Limiter
}

func newRateLimitedClient() *rateLimitedClient {
return newRateLimitedClientWithRPS(defaultRPS, defaultBurst)
}

func newRateLimitedClientWithRPS(rps float64, burst int) *rateLimitedClient {
client := &rateLimitedClient{
httpClient: http.DefaultClient,
rateLimiter: rate.NewLimiter(rate.Limit(rps), burst),
}
return client
}

func (client *rateLimitedClient) Do(req *http.Request) (*http.Response, error) {
ctx := context.Background()
err := client.rateLimiter.Wait(ctx)
if err != nil {
return nil, fmt.Errorf("failed to wait for rate limit token: %w", err)
}

resp, err := client.httpClient.Do(req)
if err != nil {
return nil, err
}
return resp, nil
}
5 changes: 3 additions & 2 deletions pkg/authorizer/token_exchanger.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@ import (
// TokenExchanger is an instance of ACRTokenExchanger
type TokenExchanger struct {
acrServerScheme string
client *rateLimitedClient
}

// NewTokenExchanger returns a new token exchanger
func NewTokenExchanger() *TokenExchanger {
return &TokenExchanger{
acrServerScheme: "https",
client: newRateLimitedClient(),
}
}

Expand Down Expand Up @@ -55,11 +57,10 @@ func (te *TokenExchanger) ExchangeACRAccessToken(armToken types.AccessToken, acr
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
req.Header.Add("Content-Length", strconv.Itoa(len(parameters.Encode())))

client := &http.Client{}
var resp *http.Response
defer closeResponse(resp)

resp, err = client.Do(req)
resp, err = te.client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to send token exchange request: %w", err)
}
Expand Down
10 changes: 7 additions & 3 deletions pkg/authorizer/token_exchanger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ var _ = Describe("Token Exchanger Tests", func() {
ghttp.RespondWithJSONEncoded(200, tokenResp),
))

te := newTestTokenExchanger()
te := newTestTokenExchanger(server)
token, err := te.ExchangeACRAccessToken(armToken, ul.Host)

Expect(err).To(BeNil())
Expand All @@ -73,7 +73,7 @@ var _ = Describe("Token Exchanger Tests", func() {
ghttp.RespondWith(403, "Unauthorized"),
))

te := newTestTokenExchanger()
te := newTestTokenExchanger(server)
token, err := te.ExchangeACRAccessToken(armToken, ul.Host)

Expect(err).NotTo(BeNil())
Expand All @@ -85,8 +85,12 @@ var _ = Describe("Token Exchanger Tests", func() {
})
})

func newTestTokenExchanger() *TokenExchanger {
func newTestTokenExchanger(server *ghttp.Server) *TokenExchanger {
client := newRateLimitedClient()
client.httpClient = server.HTTPTestServer.Client()

return &TokenExchanger{
acrServerScheme: "http",
client: client,
}
}
5 changes: 3 additions & 2 deletions pkg/authorizer/token_retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type TokenRetriever struct {
metadataEndpoint string
cache sync.Map
cacheExpiration time.Duration
client *rateLimitedClient
}

type cachedToken struct {
Expand All @@ -39,6 +40,7 @@ func NewTokenRetriever() *TokenRetriever {
metadataEndpoint: msiMetadataEndpoint,
cache: sync.Map{},
cacheExpiration: time.Duration(defaultCacheExpirationInSeconds) * time.Second,
client: newRateLimitedClient(),
}
}

Expand Down Expand Up @@ -98,11 +100,10 @@ func (tr *TokenRetriever) refreshToken(clientID, resourceID string) (types.Acces
}
req.Header.Add("Metadata", "true")

client := &http.Client{}
var resp *http.Response
defer closeResponse(resp)

resp, err = client.Do(req)
resp, err = tr.client.Do(req)
if err != nil {
return "", fmt.Errorf("failed to send metadata endpoint request: %w", err)
}
Expand Down
22 changes: 13 additions & 9 deletions pkg/authorizer/token_retriever_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ var _ = Describe("Token Retriever Tests", func() {
ghttp.RespondWithJSONEncoded(200, tokenResp),
))

tr := newTestTokenRetriever(server.URL(), defaultCacheExpirationInSeconds)
tr := newTestTokenRetriever(server, defaultCacheExpirationInSeconds)
token, err := tr.AcquireARMToken("", testResourceID)

Expect(err).To(BeNil())
Expand All @@ -60,7 +60,7 @@ var _ = Describe("Token Retriever Tests", func() {
ghttp.RespondWithJSONEncoded(200, tokenResp),
))

tr := newTestTokenRetriever(server.URL(), defaultCacheExpirationInSeconds)
tr := newTestTokenRetriever(server, defaultCacheExpirationInSeconds)
token, err := tr.AcquireARMToken("", testResourceID)

os.Unsetenv(customARMResourceEnvVar)

Check failure on line 66 in pkg/authorizer/token_retriever_test.go

View workflow job for this annotation

GitHub Actions / Build

Error return value of `os.Unsetenv` is not checked (errcheck)
Expand All @@ -82,7 +82,7 @@ var _ = Describe("Token Retriever Tests", func() {
ghttp.RespondWithJSONEncoded(200, tokenResp),
))

tr := newTestTokenRetriever(server.URL(), defaultCacheExpirationInSeconds)
tr := newTestTokenRetriever(server, defaultCacheExpirationInSeconds)
token, err := tr.AcquireARMToken(testClientID, "")

Expect(err).To(BeNil())
Expand All @@ -97,7 +97,7 @@ var _ = Describe("Token Retriever Tests", func() {
ghttp.RespondWith(404, ""),
))

tr := newTestTokenRetriever(server.URL(), defaultCacheExpirationInSeconds)
tr := newTestTokenRetriever(server, defaultCacheExpirationInSeconds)
token, err := tr.AcquireARMToken(testClientID, "")

Expect(err).NotTo(BeNil())
Expand All @@ -118,7 +118,7 @@ var _ = Describe("Token Retriever Tests", func() {
ghttp.RespondWithJSONEncoded(200, tokenResp),
))

tr := newTestTokenRetriever(server.URL(), defaultCacheExpirationInSeconds*1000)
tr := newTestTokenRetriever(server, defaultCacheExpirationInSeconds*1000)
token, err := tr.AcquireARMToken(testClientID, "")
Expect(err).To(BeNil())
Expect(token).To(Equal(armToken))
Expand All @@ -142,7 +142,7 @@ var _ = Describe("Token Retriever Tests", func() {
ghttp.RespondWithJSONEncoded(200, tokenResp),
))

tr := newTestTokenRetriever(server.URL(), defaultCacheExpirationInSeconds*1000)
tr := newTestTokenRetriever(server, defaultCacheExpirationInSeconds*1000)
token, err := tr.AcquireARMToken("", testResourceID)
Expect(err).To(BeNil())
Expect(token).To(Equal(armToken))
Expand Down Expand Up @@ -171,7 +171,7 @@ var _ = Describe("Token Retriever Tests", func() {
))

// set cache expire immediately
tr := newTestTokenRetriever(server.URL(), 0)
tr := newTestTokenRetriever(server, 0)
token, err := tr.AcquireARMToken(testClientID, "")
Expect(err).To(BeNil())
Expect(token).To(Equal(armToken))
Expand All @@ -185,10 +185,14 @@ var _ = Describe("Token Retriever Tests", func() {
})
})

func newTestTokenRetriever(metadataEndpoint string, cacheExpirationInMilliSeconds int) *TokenRetriever {
func newTestTokenRetriever(server *ghttp.Server, cacheExpirationInMilliSeconds int) *TokenRetriever {
client := newRateLimitedClient()
client.httpClient = server.HTTPTestServer.Client()

return &TokenRetriever{
metadataEndpoint: metadataEndpoint,
metadataEndpoint: server.URL(),
cache: sync.Map{},
cacheExpiration: time.Duration(cacheExpirationInMilliSeconds) * time.Millisecond,
client: client,
}
}

0 comments on commit 95c7ea6

Please sign in to comment.