Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(credential-providers): source accountId from credential providers #6019

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
41 changes: 37 additions & 4 deletions clients/client-sts/src/defaultStsRoleAssumers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,31 @@ export type RoleAssumer = (

const ASSUME_ROLE_DEFAULT_REGION = "us-east-1";

interface AssumedRoleUser {
/**
* The ARN of the temporary security credentials that are returned from the AssumeRole action.
*/
Arn?: string;

/**
* A unique identifier that contains the role ID and the role session name of the role that is being assumed.
*/
AssumedRoleId?: string;
}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added and left the interface here instead of a new file, since the overhead of copying a new file to clients seems more than having a single interface present here. if we find ourselves adding many STS interfaces, we should consider refactoring a little.

/**
* @internal
*/
const getAccountIdFromAssumedRoleUser = (assumedRoleUser?: AssumedRoleUser) => {
if (typeof assumedRoleUser?.Arn === "string") {
const arnComponents = assumedRoleUser.Arn.split(":");
if (arnComponents.length > 4 && arnComponents[4] !== "") {
return arnComponents[4];
}
}
return undefined;
};

/**
* @internal
*
Expand Down Expand Up @@ -84,17 +109,21 @@ export const getDefaultRoleAssumer = (
logger: logger as any,
});
}
const { Credentials } = await stsClient.send(new AssumeRoleCommand(params));
const { Credentials, AssumedRoleUser } = await stsClient.send(new AssumeRoleCommand(params));
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
throw new Error(`Invalid response from STS.assumeRole call with role ${params.RoleArn}`);
}

const accountId = getAccountIdFromAssumedRoleUser(AssumedRoleUser);

return {
accessKeyId: Credentials.AccessKeyId,
secretAccessKey: Credentials.SecretAccessKey,
sessionToken: Credentials.SessionToken,
expiration: Credentials.Expiration,
// TODO(credentialScope): access normally when shape is updated.
credentialScope: (Credentials as any).CredentialScope,
...((Credentials as any).CredentialScope && { credentialScope: (Credentials as any).CredentialScope }),
...(accountId && { accountId }),
};
};
};
Expand Down Expand Up @@ -134,17 +163,21 @@ export const getDefaultRoleAssumerWithWebIdentity = (
logger: logger as any,
});
}
const { Credentials } = await stsClient.send(new AssumeRoleWithWebIdentityCommand(params));
const { Credentials, AssumedRoleUser } = await stsClient.send(new AssumeRoleWithWebIdentityCommand(params));
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
throw new Error(`Invalid response from STS.assumeRoleWithWebIdentity call with role ${params.RoleArn}`);
}

const accountId = getAccountIdFromAssumedRoleUser(AssumedRoleUser);

return {
accessKeyId: Credentials.AccessKeyId,
secretAccessKey: Credentials.SecretAccessKey,
sessionToken: Credentials.SessionToken,
expiration: Credentials.Expiration,
// TODO(credentialScope): access normally when shape is updated.
credentialScope: (Credentials as any).CredentialScope,
...((Credentials as any).CredentialScope && { credentialScope: (Credentials as any).CredentialScope }),
...(accountId && { accountId }),
};
};
};
Expand Down
26 changes: 26 additions & 0 deletions clients/client-sts/test/defaultRoleAssumers.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,17 @@ describe("getDefaultRoleAssumer", () => {
);
});

it("should return accountId in the credentials", async () => {
const roleAssumer = getDefaultRoleAssumer();
const params: AssumeRoleCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
const assumedRole = await roleAssumer(sourceCred, params);
expect(assumedRole.accountId).toEqual("123");
});

it("should use the STS client config", async () => {
const logger = console;
const region = "some-region";
Expand Down Expand Up @@ -169,6 +180,10 @@ describe("getDefaultRoleAssumer", () => {
describe("getDefaultRoleAssumerWithWebIdentity", () => {
const assumeRoleResponse = `<Response xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleWithWebIdentityResult>
<AssumedRoleUser>
<AssumedRoleId>AROAZOX2IL27GNRBJHWC2:session</AssumedRoleId>
<Arn>arn:aws:sts::123456789012:assumed-role/assume-role-test/session</Arn>
</AssumedRoleUser>
<Credentials>
<AccessKeyId>key</AccessKeyId>
<SecretAccessKey>secrete</SecretAccessKey>
Expand Down Expand Up @@ -209,6 +224,17 @@ describe("getDefaultRoleAssumerWithWebIdentity", () => {
});
});

it("should return accountId in the credentials", async () => {
const roleAssumerWithWebIdentity = getDefaultRoleAssumerWithWebIdentity();
const params: AssumeRoleWithWebIdentityCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
WebIdentityToken: "token",
};
const assumedRole = await roleAssumerWithWebIdentity(params);
expect(assumedRole.accountId).toEqual("123456789012");
});

it("should use the STS client middleware", async () => {
const customMiddlewareFunction = jest.fn();
const roleAssumerWithWebIdentity = getDefaultRoleAssumerWithWebIdentity({}, [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,17 @@ describe("getDefaultRoleAssumer", () => {
);
});

it("should return accountId in the credentials", async () => {
const roleAssumer = getDefaultRoleAssumer();
const params: AssumeRoleCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
};
const sourceCred = { accessKeyId: "key", secretAccessKey: "secrete" };
const assumedRole = await roleAssumer(sourceCred, params);
expect(assumedRole.accountId).toEqual("123");
});
siddsriv marked this conversation as resolved.
Show resolved Hide resolved

it("should use the STS client config", async () => {
const logger = console;
const region = "some-region";
Expand Down Expand Up @@ -167,6 +178,10 @@ describe("getDefaultRoleAssumer", () => {
describe("getDefaultRoleAssumerWithWebIdentity", () => {
const assumeRoleResponse = `<Response xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleWithWebIdentityResult>
<AssumedRoleUser>
<AssumedRoleId>AROAZOX2IL27GNRBJHWC2:session</AssumedRoleId>
<Arn>arn:aws:sts::123456789012:assumed-role/assume-role-test/session</Arn>
</AssumedRoleUser>
<Credentials>
<AccessKeyId>key</AccessKeyId>
<SecretAccessKey>secrete</SecretAccessKey>
Expand Down Expand Up @@ -207,6 +222,17 @@ describe("getDefaultRoleAssumerWithWebIdentity", () => {
});
});

it("should return accountId in the credentials", async () => {
const roleAssumerWithWebIdentity = getDefaultRoleAssumerWithWebIdentity();
const params: AssumeRoleWithWebIdentityCommandInput = {
RoleArn: "arn:aws:foo",
RoleSessionName: "session",
WebIdentityToken: "token",
};
const assumedRole = await roleAssumerWithWebIdentity(params);
expect(assumedRole.accountId).toEqual("123456789012");
});

it("should use the STS client middleware", async () => {
const customMiddlewareFunction = jest.fn();
const roleAssumerWithWebIdentity = getDefaultRoleAssumerWithWebIdentity({}, [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,31 @@ export type RoleAssumer = (

const ASSUME_ROLE_DEFAULT_REGION = "us-east-1";

interface AssumedRoleUser {
/**
* The ARN of the temporary security credentials that are returned from the AssumeRole action.
*/
Arn?: string;

/**
* A unique identifier that contains the role ID and the role session name of the role that is being assumed.
*/
AssumedRoleId?: string;
}

/**
* @internal
*/
const getAccountIdFromAssumedRoleUser = (assumedRoleUser?: AssumedRoleUser) => {
if (typeof assumedRoleUser?.Arn === "string") {
const arnComponents = assumedRoleUser.Arn.split(":");
if (arnComponents.length > 4 && arnComponents[4] !== "") {
return arnComponents[4];
}
}
return undefined;
};

/**
* @internal
*
Expand Down Expand Up @@ -81,17 +106,21 @@ export const getDefaultRoleAssumer = (
logger: logger as any,
});
}
const { Credentials } = await stsClient.send(new AssumeRoleCommand(params));
const { Credentials, AssumedRoleUser } = await stsClient.send(new AssumeRoleCommand(params));
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
throw new Error(`Invalid response from STS.assumeRole call with role ${params.RoleArn}`);
}

const accountId = getAccountIdFromAssumedRoleUser(AssumedRoleUser);

return {
accessKeyId: Credentials.AccessKeyId,
secretAccessKey: Credentials.SecretAccessKey,
sessionToken: Credentials.SessionToken,
expiration: Credentials.Expiration,
// TODO(credentialScope): access normally when shape is updated.
credentialScope: (Credentials as any).CredentialScope,
...((Credentials as any).CredentialScope && { credentialScope: (Credentials as any).CredentialScope }),
...(accountId && { accountId }),
};
};
};
Expand Down Expand Up @@ -131,17 +160,21 @@ export const getDefaultRoleAssumerWithWebIdentity = (
logger: logger as any,
});
}
const { Credentials } = await stsClient.send(new AssumeRoleWithWebIdentityCommand(params));
const { Credentials, AssumedRoleUser } = await stsClient.send(new AssumeRoleWithWebIdentityCommand(params));
if (!Credentials || !Credentials.AccessKeyId || !Credentials.SecretAccessKey) {
throw new Error(`Invalid response from STS.assumeRoleWithWebIdentity call with role ${params.RoleArn}`);
}

const accountId = getAccountIdFromAssumedRoleUser(AssumedRoleUser);

return {
accessKeyId: Credentials.AccessKeyId,
secretAccessKey: Credentials.SecretAccessKey,
sessionToken: Credentials.SessionToken,
expiration: Credentials.Expiration,
// TODO(credentialScope): access normally when shape is updated.
credentialScope: (Credentials as any).CredentialScope,
...((Credentials as any).CredentialScope && { credentialScope: (Credentials as any).CredentialScope }),
...(accountId && { accountId }),
};
};
};
Expand Down
20 changes: 18 additions & 2 deletions packages/credential-provider-env/src/fromEnv.spec.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import { CredentialsProviderError } from "@smithy/property-provider";

import { ENV_EXPIRATION, ENV_KEY, ENV_SECRET, ENV_SESSION, fromEnv } from "./fromEnv";
import { ENV_ACCOUNT_ID, ENV_EXPIRATION, ENV_KEY, ENV_SECRET, ENV_SESSION, fromEnv } from "./fromEnv";

describe(fromEnv.name, () => {
const ORIGINAL_ENV = process.env;
const mockAccessKeyId = "mockAccessKeyId";
const mockSecretAccessKey = "mockSecretAccessKey";
const mockSessionToken = "mockSessionToken";
const mockExpiration = new Date().toISOString();
const mockAccountId = "123456789012";

beforeEach(() => {
process.env = {
Expand All @@ -16,6 +17,7 @@ describe(fromEnv.name, () => {
[ENV_SECRET]: mockSecretAccessKey,
[ENV_SESSION]: mockSessionToken,
[ENV_EXPIRATION]: mockExpiration,
[ENV_ACCOUNT_ID]: mockAccountId,
};
});

Expand All @@ -30,19 +32,33 @@ describe(fromEnv.name, () => {
secretAccessKey: mockSecretAccessKey,
sessionToken: mockSessionToken,
expiration: new Date(mockExpiration),
accountId: mockAccountId,
});
});

it("can create credentials without a session token or expiration", async () => {
it("can create credentials without a session token, accountId, or expiration", async () => {
delete process.env[ENV_SESSION];
delete process.env[ENV_EXPIRATION];
delete process.env[ENV_ACCOUNT_ID];
const receivedCreds = await fromEnv()();
expect(receivedCreds).toStrictEqual({
accessKeyId: mockAccessKeyId,
secretAccessKey: mockSecretAccessKey,
});
});

it("should include accountId when it is provided in environment variables", async () => {
process.env[ENV_ACCOUNT_ID] = mockAccountId;
const receivedCreds = await fromEnv()();
expect(receivedCreds).toHaveProperty("accountId", mockAccountId);
});

it("should not include accountId when it is not provided in environment variables", async () => {
delete process.env[ENV_ACCOUNT_ID]; // Ensure accountId is not set
const receivedCreds = await fromEnv()();
expect(receivedCreds).not.toHaveProperty("accountId");
});

it.each([ENV_KEY, ENV_SECRET])("throws if env['%s'] is not found", async (key) => {
delete process.env[key];
const expectedError = new CredentialsProviderError("Unable to find environment variable credentials.");
Expand Down
6 changes: 6 additions & 0 deletions packages/credential-provider-env/src/fromEnv.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ export const ENV_EXPIRATION = "AWS_CREDENTIAL_EXPIRATION";
* @internal
*/
export const ENV_CREDENTIAL_SCOPE = "AWS_CREDENTIAL_SCOPE";
/**
* @internal
*/
export const ENV_ACCOUNT_ID = "AWS_ACCOUNT_ID";

/**
* @internal
Expand All @@ -41,6 +45,7 @@ export const fromEnv =
const sessionToken: string | undefined = process.env[ENV_SESSION];
const expiry: string | undefined = process.env[ENV_EXPIRATION];
const credentialScope: string | undefined = process.env[ENV_CREDENTIAL_SCOPE];
const accountId: string | undefined = process.env[ENV_ACCOUNT_ID];

if (accessKeyId && secretAccessKey) {
return {
Expand All @@ -49,6 +54,7 @@ export const fromEnv =
...(sessionToken && { sessionToken }),
...(expiry && { expiration: new Date(expiry) }),
...(credentialScope && { credentialScope }),
...(accountId && { accountId }),
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const getMockStaticCredsProfile = () => ({
aws_secret_access_key: "mock_aws_secret_access_key",
aws_session_token: "mock_aws_session_token",
aws_credential_scope: "mock_aws_credential_scope",
aws_account_id: "mock_aws_account_id",
});

describe(isStaticCredsProfile.name, () => {
Expand Down Expand Up @@ -32,6 +33,12 @@ describe(isStaticCredsProfile.name, () => {
});
});

it.each(["aws_account_id"])("value at '%s' is not of type string | undefined", (key) => {
[true, null, 1, NaN, {}].forEach((value) => {
expect(isStaticCredsProfile({ ...getMockStaticCredsProfile(), [key]: value })).toEqual(false);
});
});

it("returns true for StaticCredentialsProfile", () => {
expect(isStaticCredsProfile(getMockStaticCredsProfile())).toEqual(true);
});
Expand All @@ -46,6 +53,7 @@ describe(resolveStaticCredentials.name, () => {
secretAccessKey: mockProfile.aws_secret_access_key,
sessionToken: mockProfile.aws_session_token,
credentialScope: mockProfile.aws_credential_scope,
accountId: mockProfile.aws_account_id,
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export interface StaticCredsProfile extends Profile {
aws_secret_access_key: string;
aws_session_token?: string;
aws_credential_scope?: string;
aws_account_id?: string;
}

/**
Expand All @@ -20,7 +21,8 @@ export const isStaticCredsProfile = (arg: any): arg is StaticCredsProfile =>
typeof arg === "object" &&
typeof arg.aws_access_key_id === "string" &&
typeof arg.aws_secret_access_key === "string" &&
["undefined", "string"].indexOf(typeof arg.aws_session_token) > -1;
["undefined", "string"].indexOf(typeof arg.aws_session_token) > -1 &&
["undefined", "string"].indexOf(typeof arg.aws_account_id) > -1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there doesn't seem to be any point to this check, or the aws_session_token one

Copy link
Contributor Author

@siddsriv siddsriv Jun 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

 ● isStaticCredsProfile › value at 'aws_session_token' is not of type string | undefined
  expect(received).toEqual(expected) // deep equality

    Expected: false
    Received: true

      30 |   it.each(["aws_session_token"])("value at '%s' is not of type string | undefined", (key) => {
      31 |     [true, null, 1, NaN, {}].forEach((value) => {
    > 32 |       expect(isStaticCredsProfile({ ...getMockStaticCredsProfile(), [key]: value })).toEqual(false);
         |                                                                                      ^
      33 |     });
      34 |   });
      35 |

      at src/resolveStaticCredentials.spec.ts:32:86
          at Array.forEach (<anonymous>)
      at src/resolveStaticCredentials.spec.ts:31:30

unit test seems to fail when i remove these checks (for both session token and accountId). we'll have to change the unit tests for these as well.


/**
* @internal
Expand All @@ -34,6 +36,7 @@ export const resolveStaticCredentials = (
accessKeyId: profile.aws_access_key_id,
secretAccessKey: profile.aws_secret_access_key,
sessionToken: profile.aws_session_token,
credentialScope: profile.aws_credential_scope,
...(profile.aws_credential_scope && { credentialScope: profile.aws_credential_scope }),
...(profile.aws_account_id && { accountId: profile.aws_account_id }),
});
};