Skip to content

Commit

Permalink
fix(credential-providers): support custom middleware for sts client (#…
Browse files Browse the repository at this point in the history
…3887)

* feat(client-sts): allow setting custom middleware in the default role assumers

* feat(credential-providers): allow setting sts client middleware from cred providers
  • Loading branch information
AllanZhengYP committed Aug 30, 2022
1 parent fe23216 commit 072dea3
Show file tree
Hide file tree
Showing 17 changed files with 265 additions and 54 deletions.
5 changes: 5 additions & 0 deletions clients/client-sts/jest.config.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
const base = require("../../jest.config.base.js");

module.exports = {
...base,
};
4 changes: 3 additions & 1 deletion clients/client-sts/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
"build:es": "tsc -p tsconfig.es.json",
"build:types": "tsc -p tsconfig.types.json",
"build:types:downlevel": "downlevel-dts dist-types dist-types/ts3.4",
"clean": "rimraf ./dist-* && rimraf *.tsbuildinfo"
"clean": "rimraf ./dist-* && rimraf *.tsbuildinfo",
"test": "yarn test:unit",
"test:unit": "jest"
},
"main": "./dist-cjs/index.js",
"types": "./dist-types/index.d.ts",
Expand Down
31 changes: 26 additions & 5 deletions clients/client-sts/src/defaultRoleAssumers.ts
Original file line number Diff line number Diff line change
@@ -1,28 +1,49 @@
// smithy-typescript generated code
// Please do not touch this file. It's generated from template in:
// https://github.com/aws/aws-sdk-js-v3/blob/main/codegen/smithy-aws-typescript-codegen/src/main/resources/software/amazon/smithy/aws/typescript/codegen/sts-client-defaultRoleAssumers.ts
import { Pluggable } from "@aws-sdk/types";

import {
DefaultCredentialProvider,
getDefaultRoleAssumer as StsGetDefaultRoleAssumer,
getDefaultRoleAssumerWithWebIdentity as StsGetDefaultRoleAssumerWithWebIdentity,
RoleAssumer,
RoleAssumerWithWebIdentity,
} from "./defaultStsRoleAssumers";
import { STSClient, STSClientConfig } from "./STSClient";
import { ServiceInputTypes, ServiceOutputTypes, STSClient, STSClientConfig } from "./STSClient";

const getCustomizableStsClientCtor = (
baseCtor: new (config: STSClientConfig) => STSClient,
customizations?: Pluggable<ServiceInputTypes, ServiceOutputTypes>[]
) => {
if (!customizations) return baseCtor;
else
return class CustomizableSTSClient extends baseCtor {
constructor(config: STSClientConfig) {
super(config);
for (const customization of customizations!) {
this.middlewareStack.use(customization);
}
}
};
};

/**
* The default role assumer that used by credential providers when sts:AssumeRole API is needed.
*/
export const getDefaultRoleAssumer = (
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {}
): RoleAssumer => StsGetDefaultRoleAssumer(stsOptions, STSClient);
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {},
stsPlugins?: Pluggable<ServiceInputTypes, ServiceOutputTypes>[]
): RoleAssumer => StsGetDefaultRoleAssumer(stsOptions, getCustomizableStsClientCtor(STSClient, stsPlugins));

/**
* The default role assumer that used by credential providers when sts:AssumeRoleWithWebIdentity API is needed.
*/
export const getDefaultRoleAssumerWithWebIdentity = (
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {}
): RoleAssumerWithWebIdentity => StsGetDefaultRoleAssumerWithWebIdentity(stsOptions, STSClient);
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {},
stsPlugins?: Pluggable<ServiceInputTypes, ServiceOutputTypes>[]
): RoleAssumerWithWebIdentity =>
StsGetDefaultRoleAssumerWithWebIdentity(stsOptions, getCustomizableStsClientCtor(STSClient, stsPlugins));

/**
* The default credential providers depend STS client to assume role with desired API: sts:assumeRole,
Expand Down
56 changes: 51 additions & 5 deletions clients/client-sts/test/defaultRoleAssumers.spec.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
// Please do not touch this file. It's generated from template in:
// https://github.com/aws/aws-sdk-js-v3/blob/main/codegen/smithy-aws-typescript-codegen/src/main/resources/software/amazon/smithy/aws/typescript/codegen/sts-client-defaultRoleAssumers.spec.ts
import { NodeHttpHandler, streamCollector } from "@aws-sdk/node-http-handler";
import { HttpResponse } from "@aws-sdk/protocol-http";
import { Readable } from "stream";

import type { AssumeRoleCommandInput } from "../src/commands/AssumeRoleCommand";
import { AssumeRoleWithWebIdentityCommandInput } from "../src/commands/AssumeRoleWithWebIdentityCommand";
import { getDefaultRoleAssumer, getDefaultRoleAssumerWithWebIdentity } from "../src/defaultRoleAssumers";

const mockHandle = jest.fn().mockResolvedValue({
response: new HttpResponse({
statusCode: 200,
Expand All @@ -17,11 +22,6 @@ jest.mock("@aws-sdk/node-http-handler", () => ({
streamCollector: jest.fn(),
}));

import { NodeHttpHandler, streamCollector } from "@aws-sdk/node-http-handler";

import type { AssumeRoleCommandInput } from "../src/commands/AssumeRoleCommand";
import { AssumeRoleWithWebIdentityCommandInput } from "../src/commands/AssumeRoleWithWebIdentityCommand";
import { getDefaultRoleAssumer, getDefaultRoleAssumerWithWebIdentity } from "../src/defaultRoleAssumers";
const mockConstructorInput = jest.fn();
jest.mock("../src/STSClient", () => ({
STSClient: function (params: any) {
Expand Down Expand Up @@ -102,6 +102,29 @@ describe("getDefaultRoleAssumer", () => {
region,
});
});

it("should use the STS client middleware", async () => {
const customMiddlewareFunction = jest.fn();
const roleAssumer = getDefaultRoleAssumer({}, [
{
applyToStack: (stack) => {
stack.add((next) => (args) => {
customMiddlewareFunction(args);
return next(args);
});
},
},
]);
const params: AssumeRoleCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
await Promise.all([roleAssumer(sourceCred, params), roleAssumer(sourceCred, params)]);
expect(customMiddlewareFunction).toHaveBeenCalledTimes(2); // make sure the middleware is not added to stack multiple times.
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(1, expect.objectContaining({ input: params }));
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(2, expect.objectContaining({ input: params }));
});
});

describe("getDefaultRoleAssumerWithWebIdentity", () => {
Expand Down Expand Up @@ -146,4 +169,27 @@ describe("getDefaultRoleAssumerWithWebIdentity", () => {
region,
});
});

it("should use the STS client middleware", async () => {
const customMiddlewareFunction = jest.fn();
const roleAssumerWithWebIdentity = getDefaultRoleAssumerWithWebIdentity({}, [
{
applyToStack: (stack) => {
stack.add((next) => (args) => {
customMiddlewareFunction(args);
return next(args);
});
},
},
]);
const params: AssumeRoleWithWebIdentityCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
WebIdentityToken: "token",
};
await Promise.all([roleAssumerWithWebIdentity(params), roleAssumerWithWebIdentity(params)]);
expect(customMiddlewareFunction).toHaveBeenCalledTimes(2); // make sure the middleware is not added to stack multiple times.
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(1, expect.objectContaining({ input: params }));
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(2, expect.objectContaining({ input: params }));
});
});
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import { NodeHttpHandler, streamCollector } from "@aws-sdk/node-http-handler";
import { HttpResponse } from "@aws-sdk/protocol-http";
import { Readable } from "stream";

import type { AssumeRoleCommandInput } from "../src/commands/AssumeRoleCommand";
import { AssumeRoleWithWebIdentityCommandInput } from "../src/commands/AssumeRoleWithWebIdentityCommand";
import { getDefaultRoleAssumer, getDefaultRoleAssumerWithWebIdentity } from "../src/defaultRoleAssumers";

const mockHandle = jest.fn().mockResolvedValue({
response: new HttpResponse({
statusCode: 200,
Expand All @@ -15,11 +20,6 @@ jest.mock("@aws-sdk/node-http-handler", () => ({
streamCollector: jest.fn(),
}));

import { NodeHttpHandler, streamCollector } from "@aws-sdk/node-http-handler";

import type { AssumeRoleCommandInput } from "../src/commands/AssumeRoleCommand";
import { AssumeRoleWithWebIdentityCommandInput } from "../src/commands/AssumeRoleWithWebIdentityCommand";
import { getDefaultRoleAssumer, getDefaultRoleAssumerWithWebIdentity } from "../src/defaultRoleAssumers";
const mockConstructorInput = jest.fn();
jest.mock("../src/STSClient", () => ({
STSClient: function (params: any) {
Expand Down Expand Up @@ -100,6 +100,29 @@ describe("getDefaultRoleAssumer", () => {
region,
});
});

it("should use the STS client middleware", async () => {
const customMiddlewareFunction = jest.fn();
const roleAssumer = getDefaultRoleAssumer({}, [
{
applyToStack: (stack) => {
stack.add((next) => (args) => {
customMiddlewareFunction(args);
return next(args);
});
},
},
]);
const params: AssumeRoleCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
await Promise.all([roleAssumer(sourceCred, params), roleAssumer(sourceCred, params)]);
expect(customMiddlewareFunction).toHaveBeenCalledTimes(2); // make sure the middleware is not added to stack multiple times.
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(1, expect.objectContaining({ input: params }));
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(2, expect.objectContaining({ input: params }));
});
});

describe("getDefaultRoleAssumerWithWebIdentity", () => {
Expand Down Expand Up @@ -144,4 +167,27 @@ describe("getDefaultRoleAssumerWithWebIdentity", () => {
region,
});
});

it("should use the STS client middleware", async () => {
const customMiddlewareFunction = jest.fn();
const roleAssumerWithWebIdentity = getDefaultRoleAssumerWithWebIdentity({}, [
{
applyToStack: (stack) => {
stack.add((next) => (args) => {
customMiddlewareFunction(args);
return next(args);
});
},
},
]);
const params: AssumeRoleWithWebIdentityCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
WebIdentityToken: "token",
};
await Promise.all([roleAssumerWithWebIdentity(params), roleAssumerWithWebIdentity(params)]);
expect(customMiddlewareFunction).toHaveBeenCalledTimes(2); // make sure the middleware is not added to stack multiple times.
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(1, expect.objectContaining({ input: params }));
expect(customMiddlewareFunction).toHaveBeenNthCalledWith(2, expect.objectContaining({ input: params }));
});
});
Original file line number Diff line number Diff line change
@@ -1,25 +1,46 @@
import { Pluggable } from "@aws-sdk/types";

import {
DefaultCredentialProvider,
getDefaultRoleAssumer as StsGetDefaultRoleAssumer,
getDefaultRoleAssumerWithWebIdentity as StsGetDefaultRoleAssumerWithWebIdentity,
RoleAssumer,
RoleAssumerWithWebIdentity,
} from "./defaultStsRoleAssumers";
import { STSClient, STSClientConfig } from "./STSClient";
import { ServiceInputTypes, ServiceOutputTypes, STSClient, STSClientConfig } from "./STSClient";

const getCustomizableStsClientCtor = (
baseCtor: new (config: STSClientConfig) => STSClient,
customizations?: Pluggable<ServiceInputTypes, ServiceOutputTypes>[]
) => {
if (!customizations) return baseCtor;
else
return class CustomizableSTSClient extends baseCtor {
constructor(config: STSClientConfig) {
super(config);
for (const customization of customizations!) {
this.middlewareStack.use(customization);
}
}
};
};

/**
* The default role assumer that used by credential providers when sts:AssumeRole API is needed.
*/
export const getDefaultRoleAssumer = (
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {}
): RoleAssumer => StsGetDefaultRoleAssumer(stsOptions, STSClient);
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {},
stsPlugins?: Pluggable<ServiceInputTypes, ServiceOutputTypes>[]
): RoleAssumer => StsGetDefaultRoleAssumer(stsOptions, getCustomizableStsClientCtor(STSClient, stsPlugins));

/**
* The default role assumer that used by credential providers when sts:AssumeRoleWithWebIdentity API is needed.
*/
export const getDefaultRoleAssumerWithWebIdentity = (
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {}
): RoleAssumerWithWebIdentity => StsGetDefaultRoleAssumerWithWebIdentity(stsOptions, STSClient);
stsOptions: Pick<STSClientConfig, "logger" | "region" | "requestHandler"> = {},
stsPlugins?: Pluggable<ServiceInputTypes, ServiceOutputTypes>[]
): RoleAssumerWithWebIdentity =>
StsGetDefaultRoleAssumerWithWebIdentity(stsOptions, getCustomizableStsClientCtor(STSClient, stsPlugins));

/**
* The default credential providers depend STS client to assume role with desired API: sts:assumeRole,
Expand Down
29 changes: 29 additions & 0 deletions packages/credential-providers/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,35 @@ const credentialProvider = fromNodeProviderChain({
});
```
## Add Custom Headers to STS assume-role calls
You can specify the plugins--groups of middleware, to inject to the STS client.
For example, you can inject custom headers to each STS assume-role calls. It's
available in [`fromTemporaryCredentials()`](#fromtemporarycredentials),
[`fromWebToken()`](#fromwebtoken), [`fromTokenFile()`](#fromtokenfile), [`fromIni()`](#fromini).
Code example:
```javascript
const addConfusedDeputyMiddleware = (next) => (args) => {
args.request.headers["x-amz-source-account"] = account;
args.request.headers["x-amz-source-arn"] = sourceArn;
return next(args);
};
const confusedDeputyPlugin = {
applyToStack: (stack) => {
stack.add(addConfusedDeputyMiddleware, { step: "finalizeRequest" });
},
};
const provider = fromTemporaryCredentials({
// Required. Options passed to STS AssumeRole operation.
params: {
RoleArn: "arn:aws:iam::1234567890:role/Role",
},
clientPlugins: [confusedDeputyPlugin],
});
```
[getcredentialsforidentity_api]: https://docs.aws.amazon.com/cognitoidentity/latest/APIReference/API_GetCredentialsForIdentity.html
[getid_api]: https://docs.aws.amazon.com/cognitoidentity/latest/APIReference/API_GetId.html
[assumerole_api]: https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html
Expand Down
9 changes: 5 additions & 4 deletions packages/credential-providers/src/fromIni.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,14 @@ describe("fromIni", () => {
expect(getDefaultRoleAssumerWithWebIdentity).not.toBeCalled();
});

it("should use supplied sts options", () => {
it("should use supplied sts and plugins options", () => {
const profile = "profile";
const clientConfig = {
region: "US_BAR_1",
};
fromIni({ profile, clientConfig });
expect(getDefaultRoleAssumer).toBeCalledWith(clientConfig);
expect(getDefaultRoleAssumerWithWebIdentity).toBeCalledWith(clientConfig);
const plugin = { applyToStack: () => {} };
fromIni({ profile, clientConfig, clientPlugins: [plugin] });
expect(getDefaultRoleAssumer).toBeCalledWith(clientConfig, [plugin]);
expect(getDefaultRoleAssumerWithWebIdentity).toBeCalledWith(clientConfig, [plugin]);
});
});
10 changes: 7 additions & 3 deletions packages/credential-providers/src/fromIni.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { getDefaultRoleAssumer, getDefaultRoleAssumerWithWebIdentity, STSClientConfig } from "@aws-sdk/client-sts";
import { fromIni as _fromIni, FromIniInit as _FromIniInit } from "@aws-sdk/credential-provider-ini";
import { CredentialProvider } from "@aws-sdk/types";
import { CredentialProvider, Pluggable } from "@aws-sdk/types";

export interface FromIniInit extends _FromIniInit {
clientConfig?: STSClientConfig;
clientPlugins?: Pluggable<any, any>[];
}

/**
Expand Down Expand Up @@ -38,14 +39,17 @@ export interface FromIniInit extends _FromIniInit {
* },
* // Optional. Custom STS client configurations overriding the default ones.
* clientConfig: { region },
* // Optional. Custom STS client middleware plugin to modify the client default behavior.
* // e.g. adding custom headers.
* clientPlugins: [addFooHeadersPlugin],
* }),
* });
* ```
*/
export const fromIni = (init: FromIniInit = {}): CredentialProvider =>
_fromIni({
...init,
roleAssumer: init.roleAssumer ?? getDefaultRoleAssumer(init.clientConfig),
roleAssumer: init.roleAssumer ?? getDefaultRoleAssumer(init.clientConfig, init.clientPlugins),
roleAssumerWithWebIdentity:
init.roleAssumerWithWebIdentity ?? getDefaultRoleAssumerWithWebIdentity(init.clientConfig),
init.roleAssumerWithWebIdentity ?? getDefaultRoleAssumerWithWebIdentity(init.clientConfig, init.clientPlugins),
});
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,14 @@ describe(fromNodeProviderChain.name, () => {
expect(getDefaultRoleAssumerWithWebIdentity).not.toBeCalled();
});

it("should use supplied sts options", () => {
it("should use supplied sts options and plugins", () => {
const profile = "profile";
const clientConfig = {
region: "US_BAR_1",
};
fromNodeProviderChain({ profile, clientConfig });
expect(getDefaultRoleAssumer).toBeCalledWith(clientConfig);
expect(getDefaultRoleAssumerWithWebIdentity).toBeCalledWith(clientConfig);
const plugin = { applyToStack: () => {} };
fromNodeProviderChain({ profile, clientConfig, clientPlugins: [plugin] });
expect(getDefaultRoleAssumer).toBeCalledWith(clientConfig, [plugin]);
expect(getDefaultRoleAssumerWithWebIdentity).toBeCalledWith(clientConfig, [plugin]);
});
});

0 comments on commit 072dea3

Please sign in to comment.