Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,30 +78,18 @@ public class ImdsClient private constructor(builder: Builder) : InstanceMetadata
}

// cached middleware instances
private val middleware: List<Feature> = listOf(
ResolveEndpoint.create {
resolver = ImdsEndpointResolver(platformProvider, endpointConfiguration)
},
UserAgent.create {
staticMetadata = AwsUserAgentMetadata.fromEnvironment(ApiMetadata(SERVICE, "unknown"))
},
Retry.create {
val tokenBucket = StandardRetryTokenBucket(StandardRetryTokenBucketOptions.Default)
val delayProvider = ExponentialBackoffWithJitter(ExponentialBackoffWithJitterOptions.Default)
strategy = StandardRetryStrategy(
StandardRetryStrategyOptions.Default.copy(maxAttempts = maxRetries),
tokenBucket,
delayProvider
)
policy = ImdsRetryPolicy()
},
// must come after retries
TokenMiddleware.create {
httpClient = this@ImdsClient.httpClient
ttl = tokenTtl
clock = this@ImdsClient.clock
},
private val resolveEndpointMiddleware = ResolveEndpoint(ImdsEndpointResolver(platformProvider, endpointConfiguration))
private val userAgentMiddleware = UserAgent(
staticMetadata = AwsUserAgentMetadata.fromEnvironment(ApiMetadata(SERVICE, "unknown"))
)
private val retryMiddleware = run {
val tokenBucket = StandardRetryTokenBucket(StandardRetryTokenBucketOptions.Default)
val delayProvider = ExponentialBackoffWithJitter(ExponentialBackoffWithJitterOptions.Default)
val strategy = StandardRetryStrategy(StandardRetryStrategyOptions.Default, tokenBucket, delayProvider)
val policy = ImdsRetryPolicy()
Retry<String>(strategy, policy)
}
private val tokenMiddleware = TokenMiddleware(httpClient, tokenTtl, clock)

public companion object {
public operator fun invoke(block: Builder.() -> Unit): ImdsClient = ImdsClient(Builder().apply(block))
Expand Down Expand Up @@ -142,7 +130,10 @@ public class ImdsClient private constructor(builder: Builder) : InstanceMetadata
set(SdkClientOption.LogMode, sdkLogMode)
}
}
middleware.forEach { it.install(op) }
op.install(resolveEndpointMiddleware)
op.install(userAgentMiddleware)
op.install(retryMiddleware)
op.install(tokenMiddleware)
op.execution.mutate.intercept(Phase.Order.Before) { req, next ->
req.subject.url.path = path
next.call(req)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package aws.sdk.kotlin.runtime.config.imds
import aws.sdk.kotlin.runtime.config.CachedValue
import aws.sdk.kotlin.runtime.config.ExpiringValue
import aws.smithy.kotlin.runtime.http.*
import aws.smithy.kotlin.runtime.http.operation.ModifyRequestMiddleware
import aws.smithy.kotlin.runtime.http.operation.SdkHttpOperation
import aws.smithy.kotlin.runtime.http.operation.SdkHttpRequest
import aws.smithy.kotlin.runtime.http.operation.getLogger
Expand All @@ -29,33 +30,21 @@ internal const val X_AWS_EC2_METADATA_TOKEN_TTL_SECONDS = "x-aws-ec2-metadata-to
internal const val X_AWS_EC2_METADATA_TOKEN = "x-aws-ec2-metadata-token"

@OptIn(ExperimentalTime::class)
internal class TokenMiddleware(config: Config) : Feature {
private val ttl: Duration = config.ttl
private val httpClient = requireNotNull(config.httpClient) { "SdkHttpClient is required for token middleware to make requests" }
private val clock: Clock = config.clock
internal class TokenMiddleware(
private val httpClient: SdkHttpClient,
private val ttl: Duration = Duration.seconds(DEFAULT_TOKEN_TTL_SECONDS),
private val clock: Clock = Clock.System
) : ModifyRequestMiddleware {
private var cachedToken = CachedValue<Token>(null, bufferTime = Duration.seconds(TOKEN_REFRESH_BUFFER_SECONDS), clock = clock)

public class Config {
var ttl: Duration = Duration.seconds(DEFAULT_TOKEN_TTL_SECONDS)
var httpClient: SdkHttpClient? = null
var clock: Clock = Clock.System
override fun install(op: SdkHttpOperation<*, *>) {
op.execution.finalize.register(this)
}

public companion object Feature :
HttpClientFeatureFactory<Config, TokenMiddleware> {
override val key: FeatureKey<TokenMiddleware> = FeatureKey("EC2Metadata_Token_Middleware")
override fun create(block: Config.() -> Unit): TokenMiddleware {
val config = Config().apply(block)
return TokenMiddleware(config)
}
}

override fun <I, O> install(operation: SdkHttpOperation<I, O>) {
operation.execution.finalize.intercept { req, next ->
val token = cachedToken.getOrLoad { getToken(clock, req).let { ExpiringValue(it, it.expires) } }
req.subject.headers.append(X_AWS_EC2_METADATA_TOKEN, token.value.decodeToString())
next.call(req)
}
override suspend fun modifyRequest(req: SdkHttpRequest): SdkHttpRequest {
val token = cachedToken.getOrLoad { getToken(clock, req).let { ExpiringValue(it, it.expires) } }
req.subject.headers.append(X_AWS_EC2_METADATA_TOKEN, token.value.decodeToString())
return req
}

private suspend fun getToken(clock: Clock, req: SdkHttpRequest): Token {
Expand All @@ -76,7 +65,6 @@ internal class TokenMiddleware(config: Config) : Feature {
}
}

// TODO - retries with custom policy around 400 and 403
val call = httpClient.call(tokenReq)
return try {
when (call.response.status) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ import aws.sdk.kotlin.runtime.endpoint.AwsEndpointResolver
import aws.sdk.kotlin.runtime.execution.AuthAttributes
import aws.smithy.kotlin.runtime.http.*
import aws.smithy.kotlin.runtime.http.middleware.setRequestEndpoint
import aws.smithy.kotlin.runtime.http.operation.SdkHttpOperation
import aws.smithy.kotlin.runtime.http.operation.ModifyRequestMiddleware
import aws.smithy.kotlin.runtime.http.operation.SdkHttpRequest
import aws.smithy.kotlin.runtime.http.operation.getLogger
import aws.smithy.kotlin.runtime.util.get

Expand All @@ -20,54 +21,35 @@ import aws.smithy.kotlin.runtime.util.get
*/
@InternalSdkApi
public class ResolveAwsEndpoint(
config: Config
) : Feature {

private val serviceId: String = requireNotNull(config.serviceId) { "ServiceId must not be null" }
private val resolver: AwsEndpointResolver = requireNotNull(config.resolver) { "EndpointResolver must not be null" }

public class Config {
/**
* The AWS service ID to resolve endpoints for
*/
public var serviceId: String? = null

/**
* The resolver to use
*/
public var resolver: AwsEndpointResolver? = null
}

public companion object Feature : HttpClientFeatureFactory<Config, ResolveAwsEndpoint> {
override val key: FeatureKey<ResolveAwsEndpoint> = FeatureKey("ServiceEndpointResolver")

override fun create(block: Config.() -> Unit): ResolveAwsEndpoint {
val config = Config().apply(block)
return ResolveAwsEndpoint(config)
}
}

override fun <I, O> install(operation: SdkHttpOperation<I, O>) {
operation.execution.mutate.intercept { req, next ->

val region = req.context[AwsClientOption.Region]
val endpoint = resolver.resolve(serviceId, region)
setRequestEndpoint(req, endpoint.endpoint)

endpoint.credentialScope?.let { scope ->
// resolved endpoint has credential scope override(s), update the context for downstream consumers
scope.service?.let {
if (it.isNotBlank()) req.context[AuthAttributes.SigningService] = it
}
scope.region?.let {
if (it.isNotBlank()) req.context[AuthAttributes.SigningRegion] = it
}
/**
* The AWS service ID to resolve endpoints for
*/
private val serviceId: String,

/**
* The resolver to use
*/
private val resolver: AwsEndpointResolver

) : ModifyRequestMiddleware {

override suspend fun modifyRequest(req: SdkHttpRequest): SdkHttpRequest {
val region = req.context[AwsClientOption.Region]
val endpoint = resolver.resolve(serviceId, region)
setRequestEndpoint(req, endpoint.endpoint)

endpoint.credentialScope?.let { scope ->
// resolved endpoint has credential scope override(s), update the context for downstream consumers
scope.service?.let {
if (it.isNotBlank()) req.context[AuthAttributes.SigningService] = it
}
scope.region?.let {
if (it.isNotBlank()) req.context[AuthAttributes.SigningRegion] = it
}

val logger = req.context.getLogger("ResolveAwsEndpoint")
logger.debug { "resolved endpoint: $endpoint" }

next.call(req)
}

val logger = req.context.getLogger("ResolveAwsEndpoint")
logger.debug { "resolved endpoint: $endpoint" }
return req
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ package aws.sdk.kotlin.runtime.http.middleware
import aws.sdk.kotlin.runtime.InternalSdkApi
import aws.sdk.kotlin.runtime.http.AwsUserAgentMetadata
import aws.sdk.kotlin.runtime.http.operation.CustomUserAgentMetadata
import aws.smithy.kotlin.runtime.http.Feature
import aws.smithy.kotlin.runtime.http.FeatureKey
import aws.smithy.kotlin.runtime.http.HttpClientFeatureFactory
import aws.smithy.kotlin.runtime.http.operation.ModifyRequestMiddleware
import aws.smithy.kotlin.runtime.http.operation.SdkHttpOperation
import aws.smithy.kotlin.runtime.http.operation.SdkHttpRequest
import aws.smithy.kotlin.runtime.io.middleware.Phase

internal const val X_AMZ_USER_AGENT: String = "x-amz-user-agent"
Expand All @@ -22,43 +21,30 @@ internal const val USER_AGENT: String = "User-Agent"
*/
@InternalSdkApi
public class UserAgent(
/**
* Metadata that doesn't change per/request (e.g. sdk and environment related metadata)
*/
private val staticMetadata: AwsUserAgentMetadata
) : Feature {
) : ModifyRequestMiddleware {

public class Config {
/**
* Metadata that doesn't change per/request (e.g. sdk and environment related metadata)
*/
public var staticMetadata: AwsUserAgentMetadata? = null
override fun install(op: SdkHttpOperation<*, *>) {
op.execution.mutate.register(this, Phase.Order.After)
}
Comment on lines +30 to 32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: Isn't Order.After the default? How is this different from the default implementation of install?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it is yes. Probably just a miss, will double check. Good catch


public companion object Feature :
HttpClientFeatureFactory<Config, UserAgent> {
override val key: FeatureKey<UserAgent> = FeatureKey("UserAgent")
override suspend fun modifyRequest(req: SdkHttpRequest): SdkHttpRequest {
// pull dynamic values out of the context
val customMetadata = req.context.getOrNull(CustomUserAgentMetadata.ContextKey)

override fun create(block: Config.() -> Unit): UserAgent {
val config = Config().apply(block)
val metadata = requireNotNull(config.staticMetadata) { "staticMetadata is required" }
return UserAgent(metadata)
}
}

override fun <I, O> install(operation: SdkHttpOperation<I, O>) {
operation.execution.mutate.intercept(Phase.Order.After) { req, next ->

// pull dynamic values out of the context
val customMetadata = req.context.getOrNull(CustomUserAgentMetadata.ContextKey)
// resolve the metadata for the request which is a combination of the static and per/operation metadata
val requestMetadata = staticMetadata.copy(customMetadata = customMetadata)

// resolve the metadata for the request which is a combination of the static and per/operation metadata
val requestMetadata = staticMetadata.copy(customMetadata = customMetadata)
// NOTE: Due to legacy issues with processing the user agent, the original content for
// x-amz-user-agent and User-Agent is swapped here. See top note in the
// sdk-user-agent-header SEP and https://github.com/awslabs/smithy-kotlin/issues/373
// for further details.
req.subject.headers[USER_AGENT] = requestMetadata.xAmzUserAgent
req.subject.headers[X_AMZ_USER_AGENT] = requestMetadata.userAgent

// NOTE: Due to legacy issues with processing the user agent, the original content for
// x-amz-user-agent and User-Agent is swapped here. See top note in the
// sdk-user-agent-header SEP and https://github.com/awslabs/smithy-kotlin/issues/373
// for further details.
req.subject.headers[USER_AGENT] = requestMetadata.xAmzUserAgent
req.subject.headers[X_AMZ_USER_AGENT] = requestMetadata.userAgent
next.call(req)
}
return req
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,8 @@ class ResolveAwsEndpointTest {
}

val endpoint = AwsEndpoint("https://api.test.com")
op.install(ResolveAwsEndpoint) {
resolver = AwsEndpointResolver { _, _ -> endpoint }
serviceId = "TestService"
}
val resolver = AwsEndpointResolver { _, _ -> endpoint }
op.install(ResolveAwsEndpoint("TestService", resolver))

op.roundTrip(client, Unit)
val actual = op.context[HttpOperationContext.HttpCallList].first().request
Expand All @@ -78,10 +76,8 @@ class ResolveAwsEndpointTest {
}

val endpoint = AwsEndpoint("https://api.test.com", CredentialScope("us-west-2", "foo"))
op.install(ResolveAwsEndpoint) {
resolver = AwsEndpointResolver { _, _ -> endpoint }
serviceId = "TestService"
}
val resolver = AwsEndpointResolver { _, _ -> endpoint }
op.install(ResolveAwsEndpoint("TestService", resolver))

op.roundTrip(client, Unit)
val actual = op.context[HttpOperationContext.HttpCallList].first().request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,8 @@ class UserAgentTest {
}

val provider = TestPlatformProvider()
op.install(UserAgent) {
staticMetadata = loadAwsUserAgentMetadataFromEnvironment(provider, ApiMetadata("Test Service", "1.2.3"))
}
val metadata = loadAwsUserAgentMetadataFromEnvironment(provider, ApiMetadata("Test Service", "1.2.3"))
op.install(UserAgent(metadata))

op.roundTrip(client, Unit)
val request = op.context[HttpOperationContext.HttpCallList].last().request
Expand All @@ -75,9 +74,7 @@ class UserAgentTest {

val provider = TestPlatformProvider()
val staticMeta = loadAwsUserAgentMetadataFromEnvironment(provider, ApiMetadata("Test Service", "1.2.3"))
op.install(UserAgent) {
staticMetadata = staticMeta
}
op.install(UserAgent(staticMeta))

op.context.customUserAgentMetadata.add("foo", "bar")

Expand All @@ -96,9 +93,7 @@ class UserAgentTest {
}
}

op2.install(UserAgent) {
staticMetadata = staticMeta
}
op2.install(UserAgent(staticMeta))

op2.context.customUserAgentMetadata.add("baz", "quux")

Expand Down
Loading