Skip to content

Commit

Permalink
feat: support IMDSv2 for ECS metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
yndu13 committed May 10, 2024
1 parent 98dedb8 commit 4613843
Show file tree
Hide file tree
Showing 8 changed files with 352 additions and 19 deletions.
45 changes: 40 additions & 5 deletions credentials/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ type Config struct {
RoleSessionName *string `json:"role_session_name"`
PublicKeyId *string `json:"public_key_id"`
RoleName *string `json:"role_name"`
EnableIMDSv2 *bool `json:"enable_imds_v2"`
MetadataTokenDuration *int `json:"metadata_token_duration"`
SessionExpiration *int `json:"session_expiration"`
PrivateKeyFile *string `json:"private_key_file"`
BearerToken *string `json:"bearer_token"`
Expand Down Expand Up @@ -106,6 +108,16 @@ func (s *Config) SetRoleName(v string) *Config {
return s
}

func (s *Config) SetEnableIMDSv2(v bool) *Config {
s.EnableIMDSv2 = &v
return s
}

func (s *Config) SetMetadataTokenDuration(v int) *Config {
s.MetadataTokenDuration = &v
return s
}

func (s *Config) SetSessionExpiration(v int) *Config {
s.SessionExpiration = &v
return s
Expand Down Expand Up @@ -205,19 +217,33 @@ func NewCredential(config *Config) (credential Credential, err error) {
ConnectTimeout: tea.IntValue(config.ConnectTimeout),
STSEndpoint: tea.StringValue(config.STSEndpoint),
}
credential = newOIDCRoleArnCredential(tea.StringValue(config.AccessKeyId), tea.StringValue(config.AccessKeySecret), tea.StringValue(config.RoleArn), tea.StringValue(config.OIDCProviderArn), tea.StringValue(config.OIDCTokenFilePath), tea.StringValue(config.RoleSessionName), tea.StringValue(config.Policy), tea.IntValue(config.RoleSessionExpiration), runtime)
credential = newOIDCRoleArnCredential(
tea.StringValue(config.AccessKeyId),
tea.StringValue(config.AccessKeySecret),
tea.StringValue(config.RoleArn),
tea.StringValue(config.OIDCProviderArn),
tea.StringValue(config.OIDCTokenFilePath),
tea.StringValue(config.RoleSessionName),
tea.StringValue(config.Policy),
tea.IntValue(config.RoleSessionExpiration),
runtime)
case "access_key":
err = checkAccessKey(config)
if err != nil {
return
}
credential = newAccessKeyCredential(tea.StringValue(config.AccessKeyId), tea.StringValue(config.AccessKeySecret))
credential = newAccessKeyCredential(
tea.StringValue(config.AccessKeyId),
tea.StringValue(config.AccessKeySecret))
case "sts":
err = checkSTS(config)
if err != nil {
return
}
credential = newStsTokenCredential(tea.StringValue(config.AccessKeyId), tea.StringValue(config.AccessKeySecret), tea.StringValue(config.SecurityToken))
credential = newStsTokenCredential(
tea.StringValue(config.AccessKeyId),
tea.StringValue(config.AccessKeySecret),
tea.StringValue(config.SecurityToken))
case "ecs_ram_role":
checkEcsRAMRole(config)
runtime := &utils.Runtime{
Expand All @@ -226,7 +252,12 @@ func NewCredential(config *Config) (credential Credential, err error) {
ReadTimeout: tea.IntValue(config.Timeout),
ConnectTimeout: tea.IntValue(config.ConnectTimeout),
}
credential = newEcsRAMRoleCredential(tea.StringValue(config.RoleName), tea.Float64Value(config.InAdvanceScale), runtime)
credential = newEcsRAMRoleCredentialWithEnableIMDSv2(
tea.StringValue(config.RoleName),
tea.BoolValue(config.EnableIMDSv2),
tea.IntValue(config.MetadataTokenDuration),
tea.Float64Value(config.InAdvanceScale),
runtime)
case "ram_role_arn":
err = checkRAMRoleArn(config)
if err != nil {
Expand Down Expand Up @@ -274,7 +305,11 @@ func NewCredential(config *Config) (credential Credential, err error) {
ConnectTimeout: tea.IntValue(config.ConnectTimeout),
STSEndpoint: tea.StringValue(config.STSEndpoint),
}
credential = newRsaKeyPairCredential(privateKey, tea.StringValue(config.PublicKeyId), tea.IntValue(config.SessionExpiration), runtime)
credential = newRsaKeyPairCredential(
privateKey,
tea.StringValue(config.PublicKeyId),
tea.IntValue(config.SessionExpiration),
runtime)
case "bearer":
if tea.StringValue(config.BearerToken) == "" {
err = errors.New("BearerToken cannot be empty")
Expand Down
20 changes: 18 additions & 2 deletions credentials/credential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ this is privatekey`

func TestConfig(t *testing.T) {
config := new(Config)
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.String())
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.GoString())
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.String())
assert.Equal(t, "{\n \"type\": null,\n \"access_key_id\": null,\n \"access_key_secret\": null,\n \"oidc_provider_arn\": null,\n \"oidc_token\": null,\n \"role_arn\": null,\n \"role_session_name\": null,\n \"public_key_id\": null,\n \"role_name\": null,\n \"enable_imds_v2\": null,\n \"metadata_token_duration\": null,\n \"session_expiration\": null,\n \"private_key_file\": null,\n \"bearer_token\": null,\n \"security_token\": null,\n \"role_session_expiratioon\": null,\n \"policy\": null,\n \"host\": null,\n \"timeout\": null,\n \"connect_timeout\": null,\n \"proxy\": null,\n \"inAdvanceScale\": null,\n \"url\": null,\n \"sts_endpoint\": null,\n \"external_id\": null\n}", config.GoString())

config.SetSTSEndpoint("sts.cn-hangzhou.aliyuncs.com")
assert.Equal(t, "sts.cn-hangzhou.aliyuncs.com", *config.STSEndpoint)
Expand Down Expand Up @@ -96,6 +96,22 @@ func TestNewCredentialWithECSRAMRole(t *testing.T) {
cred, err = NewCredential(config)
assert.Nil(t, err)
assert.NotNil(t, cred)

config.SetEnableIMDSv2(false)
cred, err = NewCredential(config)
assert.Nil(t, err)
assert.NotNil(t, cred)

config.SetEnableIMDSv2(true)
cred, err = NewCredential(config)
assert.Nil(t, err)
assert.NotNil(t, cred)

config.SetEnableIMDSv2(true)
config.SetMetadataTokenDuration(180)
cred, err = NewCredential(config)
assert.Nil(t, err)
assert.NotNil(t, cred)
}

func TestNewCredentialWithRSAKeyPair(t *testing.T) {
Expand Down
60 changes: 57 additions & 3 deletions credentials/ecs_ram_role.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package credentials
import (
"encoding/json"
"fmt"
"strconv"
"time"

"github.com/alibabacloud-go/tea/tea"
Expand All @@ -11,13 +12,20 @@ import (
)

var securityCredURL = "http://100.100.100.200/latest/meta-data/ram/security-credentials/"
var securityCredTokenURL = "http://100.100.100.200/latest/api/token"

const defaultMetadataTokenDuration = int(21600)

// EcsRAMRoleCredential is a kind of credential
type EcsRAMRoleCredential struct {
*credentialUpdater
RoleName string
sessionCredential *sessionCredential
runtime *utils.Runtime
RoleName string
EnableIMDSv2 bool
MetadataTokenDuration int
sessionCredential *sessionCredential
runtime *utils.Runtime
metadataToken string
staleTime int64
}

type ecsRAMRoleResponse struct {
Expand All @@ -40,6 +48,20 @@ func newEcsRAMRoleCredential(roleName string, inAdvanceScale float64, runtime *u
}
}

func newEcsRAMRoleCredentialWithEnableIMDSv2(roleName string, enableIMDSv2 bool, metadataTokenDuration int, inAdvanceScale float64, runtime *utils.Runtime) *EcsRAMRoleCredential {
credentialUpdater := new(credentialUpdater)
if inAdvanceScale < 1 && inAdvanceScale > 0 {
credentialUpdater.inAdvanceScale = inAdvanceScale
}
return &EcsRAMRoleCredential{
RoleName: roleName,
EnableIMDSv2: enableIMDSv2,
MetadataTokenDuration: metadataTokenDuration,
credentialUpdater: credentialUpdater,
runtime: runtime,
}
}

func (e *EcsRAMRoleCredential) GetCredential() (*CredentialModel, error) {
if e.sessionCredential == nil || e.needUpdateCredential() {
err := e.updateCredential()
Expand Down Expand Up @@ -123,6 +145,26 @@ func getRoleName() (string, error) {
return string(content), nil
}

func (e *EcsRAMRoleCredential) getMetadataToken() (err error) {
if e.needToRefresh() {
if e.MetadataTokenDuration <= 0 {
e.MetadataTokenDuration = defaultMetadataTokenDuration
}
tmpTime := time.Now().Unix() + int64(e.MetadataTokenDuration*1000)
request := request.NewCommonRequest()
request.URL = securityCredTokenURL
request.Method = "PUT"
request.Headers["X-aliyun-ecs-metadata-token-ttl-seconds"] = strconv.Itoa(e.MetadataTokenDuration)
content, err := doAction(request, e.runtime)
if err != nil {
return err
}
e.staleTime = tmpTime
e.metadataToken = string(content)
}
return
}

func (e *EcsRAMRoleCredential) updateCredential() (err error) {
if e.runtime == nil {
e.runtime = new(utils.Runtime)
Expand All @@ -134,6 +176,13 @@ func (e *EcsRAMRoleCredential) updateCredential() (err error) {
return fmt.Errorf("refresh Ecs sts token err: %s", err.Error())
}
}
if e.EnableIMDSv2 {
err = e.getMetadataToken()
if err != nil {
return fmt.Errorf("Failed to get token from ECS Metadata Service: %s", err.Error())
}
request.Headers["X-aliyun-ecs-metadata-token"] = e.metadataToken
}
request.URL = securityCredURL + e.RoleName
request.Method = "GET"
content, err := doAction(request, e.runtime)
Expand Down Expand Up @@ -163,3 +212,8 @@ func (e *EcsRAMRoleCredential) updateCredential() (err error) {

return
}

func (e *EcsRAMRoleCredential) needToRefresh() (needToRefresh bool) {
needToRefresh = time.Now().Unix() >= e.staleTime
return
}
127 changes: 127 additions & 0 deletions credentials/ecs_ram_role_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,130 @@ func Test_EcsRAmRoleCredential(t *testing.T) {
assert.Equal(t, "refresh Ecs sts token err: error parse", err.Error())
assert.Equal(t, "", *accesskeyId)
}

func Test_EcsRAmRoleCredentialEnableIMDSv2(t *testing.T) {
auth := newEcsRAMRoleCredentialWithEnableIMDSv2("go sdk", false, 0, 0.5, nil)
origTestHookDo := hookDo
defer func() { hookDo = origTestHookDo }()

hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
return mockResponse(300, ``, errors.New("sdk test"))
}
}
accesskeyId, err := auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "refresh Ecs sts token err: sdk test", err.Error())
assert.Equal(t, "", *accesskeyId)

auth = newEcsRAMRoleCredentialWithEnableIMDSv2("go sdk", true, 0, 0.5, nil)
accesskeyId, err = auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "Failed to get token from ECS Metadata Service: sdk test", err.Error())
assert.Equal(t, "", *accesskeyId)

auth = newEcsRAMRoleCredentialWithEnableIMDSv2("go sdk", true, 180, 0.5, nil)
accesskeyId, err = auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "Failed to get token from ECS Metadata Service: sdk test", err.Error())
assert.Equal(t, "", *accesskeyId)

hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
return mockResponse(300, ``, nil)
}
}
accesskeyId, err = auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "Failed to get token from ECS Metadata Service: httpStatus: 300, message = ", err.Error())
assert.Equal(t, "", *accesskeyId)

hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
return mockResponse(400, `role`, nil)
}
}
auth.RoleName = ""
accesskeyId, err = auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "refresh Ecs sts token err: httpStatus: 400, message = role", err.Error())

hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
return mockResponse(200, `role`, nil)
}
}
accesskeyId, err = auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "refresh Ecs sts token err: Json Unmarshal fail: invalid character 'r' looking for beginning of value", err.Error())
hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
return mockResponse(200, `"AccessKeyId":"accessKeyId","AccessKeySecret":"accessKeySecret","SecurityToken":"securitytoken","Expiration":"expiration"`, nil)
}
}
auth.RoleName = "role"
accesskeyId, err = auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "refresh Ecs sts token err: Json Unmarshal fail: invalid character ':' after top-level value", err.Error())
assert.Equal(t, "", *accesskeyId)

hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
return mockResponse(200, `{"AccessKeySecret":"accessKeySecret","SecurityToken":"securitytoken","Expiration":"expiration","Code":"fail"}`, nil)
}
}
accesskeyId, err = auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "refresh Ecs sts token err: Code is not Success", err.Error())
assert.Equal(t, "", *accesskeyId)

hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
return mockResponse(200, `{"AccessKeySecret":"accessKeySecret","SecurityToken":"securitytoken","Expiration":"expiration","Code":"Success"}`, nil)
}
}
accesskeyId, err = auth.GetAccessKeyId()
assert.NotNil(t, err)
assert.Equal(t, "refresh Ecs sts token err: AccessKeyId: , AccessKeySecret: accessKeySecret, SecurityToken: securitytoken, Expiration: expiration", err.Error())
assert.Equal(t, "", *accesskeyId)

hookDo = func(fn func(req *http.Request) (*http.Response, error)) func(req *http.Request) (*http.Response, error) {
return func(req *http.Request) (*http.Response, error) {
return mockResponse(200, `{"AccessKeyId":"accessKeyId","AccessKeySecret":"accessKeySecret","SecurityToken":"securitytoken","Expiration":"2018-01-02T15:04:05Z","Code":"Success"}`, nil)
}
}
accesskeyId, err = auth.GetAccessKeyId()
assert.Nil(t, err)
assert.Equal(t, "accessKeyId", *accesskeyId)

accesskeySecret, err := auth.GetAccessKeySecret()
assert.Nil(t, err)
assert.Equal(t, "accessKeySecret", *accesskeySecret)

ststoken, err := auth.GetSecurityToken()
assert.Nil(t, err)
assert.Equal(t, "securitytoken", *ststoken)

err = errors.New("credentials")
err = hookParse(err)
assert.Equal(t, "credentials", err.Error())

cred, err := auth.GetCredential()
assert.Nil(t, err)
assert.Equal(t, "accessKeyId", *cred.AccessKeyId)
assert.Equal(t, "accessKeySecret", *cred.AccessKeySecret)
assert.Equal(t, "securitytoken", *cred.SecurityToken)
assert.Nil(t, cred.BearerToken)
assert.Equal(t, "ecs_ram_role", *cred.Type)

originHookParse := hookParse
hookParse = func(err error) error {
return errors.New("error parse")
}
defer func() {
hookParse = originHookParse
}()
accesskeyId, err = auth.GetAccessKeyId()
assert.Equal(t, "refresh Ecs sts token err: error parse", err.Error())
assert.Equal(t, "", *accesskeyId)
}
7 changes: 5 additions & 2 deletions credentials/instance_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package credentials

import (
"os"
"strings"

"github.com/alibabacloud-go/tea/tea"
)
Expand All @@ -19,10 +20,12 @@ func (p *instanceCredentialsProvider) resolve() (*Config, error) {
if !ok {
return nil, nil
}
enableIMDSv2, _ := os.LookupEnv(ENVEcsMetadataIMDSv2Enable)

config := &Config{
Type: tea.String("ecs_ram_role"),
RoleName: tea.String(roleName),
Type: tea.String("ecs_ram_role"),
RoleName: tea.String(roleName),
EnableIMDSv2: tea.Bool(strings.ToLower(enableIMDSv2) == "true"),
}
return config, nil
}
Loading

0 comments on commit 4613843

Please sign in to comment.