Skip to content

Commit

Permalink
Return expiry time along with auth token (#6)
Browse files Browse the repository at this point in the history
Co-authored-by: Mohit Paliwal <mpaliwa@amazon.com>
  • Loading branch information
mohitpali and Mohit Paliwal committed Aug 8, 2023
1 parent bf63255 commit d7d4a0f
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 24 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ type MSKAccessTokenProvider struct {
}

func (m *MSKAccessTokenProvider) Token() (*sarama.AccessToken, error) {
token, err := signer.GenerateAuthToken(context.TODO(), "<region>")
token, _, err := signer.GenerateAuthToken(context.TODO(), "<region>")
return &sarama.AccessToken{Token: token}, err}

func main() {
Expand Down Expand Up @@ -132,7 +132,7 @@ type MSKAccessTokenProvider struct {
}

func (m *MSKAccessTokenProvider) Token() (*sarama.AccessToken, error) {
token, err := signer.GenerateAuthToken(context.TODO(), "<region>")
token, _, err := signer.GenerateAuthToken(context.TODO(), "<region>")
return &sarama.AccessToken{Token: token}, err
}

Expand Down Expand Up @@ -229,7 +229,7 @@ func consumeMessages(consumer sarama.Consumer) {
* To use IAM credentials from a named profile, update the Token() function:
```go
func (t *MSKAccessTokenProvider) Token() (*sarama.AccessToken, error) {
token, err := signer.GenerateAuthTokenFromProfile(context.TODO(), "<region>", "<namedProfile>")
token, _, err := signer.GenerateAuthTokenFromProfile(context.TODO(), "<region>", "<namedProfile>")
return &sarama.AccessToken{Token: token}, err
}
```
Expand All @@ -238,14 +238,14 @@ func (t *MSKAccessTokenProvider) Token() (*sarama.AccessToken, error) {

```go
func (t *MSKAccessTokenProvider) Token() (*sarama.AccessToken, error) {
token, err := signer.GenerateAuthTokenFromRole(context.TODO(), "<region>", "<my-role-arn>", "my-sts-session-name")
token, _, err := signer.GenerateAuthTokenFromRole(context.TODO(), "<region>", "<my-role-arn>", "my-sts-session-name")
return &sarama.AccessToken{Token: token}, err
}
```
* To use IAM credentials from a credentials provider, update the Token() function:
```go
func (t *MSKAccessTokenProvider) Token() (*sarama.AccessToken, error) {
token, err := signer.GenerateAuthTokenFromCredentialsProvider(context.TODO(), "<region>", <MyCredentialsProvider>)
token, _, err := signer.GenerateAuthTokenFromCredentialsProvider(context.TODO(), "<region>", <MyCredentialsProvider>)
return &sarama.AccessToken{Token: token}, err
}
```
Expand Down
60 changes: 46 additions & 14 deletions signer/msk_auth_token_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,22 @@ var (

// GenerateAuthToken generates base64 encoded signed url as auth token from default credentials.
// Loads the IAM credentials from default credentials provider chain.
func GenerateAuthToken(ctx context.Context, region string) (string, error) {
func GenerateAuthToken(ctx context.Context, region string) (string, int64, error) {
credentials, err := loadDefaultCredentials(ctx, region)

if err != nil {
return "", fmt.Errorf("failed to load credentials: %w", err)
return "", 0, fmt.Errorf("failed to load credentials: %w", err)
}

return constructAuthToken(ctx, region, credentials)
}

// GenerateAuthTokenFromProfile generates base64 encoded signed url as auth token by loading IAM credentials from an AWS named profile.
func GenerateAuthTokenFromProfile(ctx context.Context, region string, awsProfile string) (string, error) {
func GenerateAuthTokenFromProfile(ctx context.Context, region string, awsProfile string) (string, int64, error) {
credentials, err := loadCredentialsFromProfile(ctx, region, awsProfile)

if err != nil {
return "", fmt.Errorf("failed to load credentials: %w", err)
return "", 0, fmt.Errorf("failed to load credentials: %w", err)
}

return constructAuthToken(ctx, region, credentials)
Expand All @@ -60,14 +60,14 @@ func GenerateAuthTokenFromProfile(ctx context.Context, region string, awsProfile
// GenerateAuthTokenFromRole generates base64 encoded signed url as auth token by loading IAM credentials from an aws role Arn
func GenerateAuthTokenFromRole(
ctx context.Context, region string, roleArn string, stsSessionName string,
) (string, error) {
) (string, int64, error) {
if stsSessionName == "" {
stsSessionName = DefaultSessionName
}
credentials, err := loadCredentialsFromRoleArn(ctx, region, roleArn, stsSessionName)

if err != nil {
return "", fmt.Errorf("failed to load credentials: %w", err)
return "", 0, fmt.Errorf("failed to load credentials: %w", err)
}

return constructAuthToken(ctx, region, credentials)
Expand All @@ -77,11 +77,11 @@ func GenerateAuthTokenFromRole(
// from an aws credentials provider
func GenerateAuthTokenFromCredentialsProvider(
ctx context.Context, region string, credentialsProvider aws.CredentialsProvider,
) (string, error) {
) (string, int64, error) {
credentials, err := loadCredentialsFromCredentialsProvider(ctx, credentialsProvider)

if err != nil {
return "", fmt.Errorf("failed to load credentials: %w", err)
return "", 0, fmt.Errorf("failed to load credentials: %w", err)
}

return constructAuthToken(ctx, region, credentials)
Expand Down Expand Up @@ -155,29 +155,34 @@ func loadCredentialsFromCredentialsProvider(
}

// Constructs Auth Token.
func constructAuthToken(ctx context.Context, region string, credentials *aws.Credentials) (string, error) {
func constructAuthToken(ctx context.Context, region string, credentials *aws.Credentials) (string, int64, error) {
endpointURL := fmt.Sprintf(endpointURLTemplate, region)

if credentials == nil || credentials.AccessKeyID == "" || credentials.SecretAccessKey == "" {
return "", fmt.Errorf("aws credentials cannot be empty")
return "", 0, fmt.Errorf("aws credentials cannot be empty")
}

req, err := buildRequest(DefaultExpirySeconds, endpointURL)
if err != nil {
return "", fmt.Errorf("failed to build request for signing: %w", err)
return "", 0, fmt.Errorf("failed to build request for signing: %w", err)
}

signedURL, err := signRequest(ctx, req, region, credentials)
if err != nil {
return "", fmt.Errorf("failed to sign request with aws sig v4: %w", err)
return "", 0, fmt.Errorf("failed to sign request with aws sig v4: %w", err)
}

expirationTimeMs, err := getExpirationTimeMs(signedURL)
if err != nil {
return "", 0, fmt.Errorf("failed to extract expiration from signed url: %w", err)
}

signedURLWithUserAgent, err := addUserAgent(signedURL)
if err != nil {
return "", fmt.Errorf("failed to add user agent to the signed url: %w", err)
return "", 0, fmt.Errorf("failed to add user agent to the signed url: %w", err)
}

return base64Encode(signedURLWithUserAgent), nil
return base64Encode(signedURLWithUserAgent), expirationTimeMs, nil
}

// Build https request with query parameters in order to sign.
Expand Down Expand Up @@ -210,6 +215,33 @@ func signRequest(ctx context.Context, req *http.Request, region string, credenti
return signedURL, err
}

// Parses the URL and gets the expiration time in millis associated with the signed url
func getExpirationTimeMs(signedURL string) (int64, error) {
parsedURL, err := url.Parse(signedURL)

if err != nil {
return 0, fmt.Errorf("failed to parse the signed url: %w", err)
}

params := parsedURL.Query()
date, err := time.Parse("20060102T150405Z", params.Get("X-Amz-Date"))

if err != nil {
return 0, fmt.Errorf("failed to parse the 'X-Amz-Date' param from signed url: %w", err)
}

signingTimeMs := date.UnixNano() / int64(time.Millisecond)
expiryDurationSeconds, err := strconv.ParseInt(params.Get("X-Amz-Expires"), 10, 64)

if err != nil {
return 0, fmt.Errorf("failed to parse the 'X-Amz-Expires' param from signed url: %w", err)
}

expiryDurationMs := expiryDurationSeconds * 1000
expiryMs := signingTimeMs + expiryDurationMs
return expiryMs, nil
}

// Calculate sha256Hash and hex encode it.
func calculateSHA256Hash(input string) string {
hash := sha256.Sum256([]byte(input))
Expand Down
25 changes: 20 additions & 5 deletions signer/msk_auth_token_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net/url"
"os"
"strconv"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -80,10 +81,11 @@ func TestConstructAuthToken(t *testing.T) {
SessionToken: "MOCK-SESSION-TOKEN",
}

token, err := constructAuthToken(Ctx, TestRegion, &mockCreds)
token, expiryMs, err := constructAuthToken(Ctx, TestRegion, &mockCreds)

assert.NoError(t, err)
assert.NotNil(t, token)
assert.NotEqual(t, int64(0), expiryMs)

decodedSignedURLBytes, err := base64.RawURLEncoding.DecodeString(token)
assert.NoError(t, err)
Expand Down Expand Up @@ -118,10 +120,11 @@ func TestConstructAuthToken(t *testing.T) {
func TestGenerateAuthTokenEmptyCredentials(t *testing.T) {
mockCreds := aws.AnonymousCredentials{}

token, err := GenerateAuthTokenFromCredentialsProvider(Ctx, TestRegion, &mockCreds)
token, expiryMs, err := GenerateAuthTokenFromCredentialsProvider(Ctx, TestRegion, &mockCreds)

assert.Error(t, err)
assert.Equal(t, token, "")
assert.Equal(t, int64(0), expiryMs)
}

func TestGenerateAuthToken(t *testing.T) {
Expand All @@ -135,10 +138,11 @@ func TestGenerateAuthToken(t *testing.T) {
os.Setenv("AWS_SECRET_ACCESS_KEY", mockCreds.SecretAccessKey)
os.Setenv("AWS_SESSION_TOKEN", mockCreds.SessionToken)

token, err := GenerateAuthToken(Ctx, TestRegion)
token, expiryMs, err := GenerateAuthToken(Ctx, TestRegion)

assert.NoError(t, err)
assert.NotNil(t, token)
assert.NotEqual(t, int64(0), expiryMs)

decodedSignedURLBytes, err := base64.RawURLEncoding.DecodeString(token)
assert.NoError(t, err)
Expand Down Expand Up @@ -183,10 +187,11 @@ func TestGenerateAuthTokenWithCredentialsProvider(t *testing.T) {

mockCredentialsProvider := MockCredentialsProvider{credentials: mockCreds}

token, err := GenerateAuthTokenFromCredentialsProvider(Ctx, TestRegion, mockCredentialsProvider)
token, expiryMs, err := GenerateAuthTokenFromCredentialsProvider(Ctx, TestRegion, mockCredentialsProvider)

assert.NoError(t, err)
assert.NotNil(t, token)
assert.NotEqual(t, int64(0), expiryMs)

decodedSignedURLBytes, err := base64.RawURLEncoding.DecodeString(token)
assert.NoError(t, err)
Expand Down Expand Up @@ -216,13 +221,23 @@ func TestGenerateAuthTokenWithCredentialsProvider(t *testing.T) {
assert.NoError(t, err)
assert.True(t, date.Before(time.Now().UTC()))
assert.True(t, strings.HasPrefix(params.Get(UserAgentKey), "aws-msk-iam-sasl-signer-go/"))

signingTimeMs := date.UnixNano() / int64(time.Millisecond)
expiryDurationSeconds, err := strconv.ParseInt(params.Get("X-Amz-Expires"), 10, 64)
assert.NoError(t, err)
expiryDurationMs := expiryDurationSeconds * 1000
assert.Equal(t, expiryMs, signingTimeMs+expiryDurationMs)

currentMillis := time.Now().UnixNano() / int64(time.Millisecond)
assert.True(t, expiryMs > currentMillis)
}

func TestGenerateAuthTokenWithFailingCredentialsProvider(t *testing.T) {
mockCredentialsProvider := aws.AnonymousCredentials{}

token, err := GenerateAuthTokenFromCredentialsProvider(Ctx, TestRegion, mockCredentialsProvider)
token, expiryMs, err := GenerateAuthTokenFromCredentialsProvider(Ctx, TestRegion, mockCredentialsProvider)

assert.Error(t, err)
assert.NotNil(t, token)
assert.Equal(t, int64(0), expiryMs)
}

0 comments on commit d7d4a0f

Please sign in to comment.