Skip to content

Commit d8fcb50

Browse files
mmcclean-awsmergify[bot]
authored and committed
fix(aws-stepfunctions): refactor sagemaker tasks and fix default role issue (#3014)
* fix(aws-stepfunctions) refactor and fix default role issue * fix(aws-stepfunctions) removed console log statements and fixed s3 prefix error * fix(aws-stepfunctions) removed construct from constructor for sagemaker tasks. Changed ISubnet[] to SubnetSelection in props * fix(aws-stepfunctions) renamed cdk core package reference * Update tests
1 parent c020efa commit d8fcb50

File tree

7 files changed

+180
-99
lines changed

7 files changed

+180
-99
lines changed

packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-task-base-types.ts

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -197,20 +197,15 @@ export interface ResourceConfig {
197197
* @experimental
198198
*/
199199
export interface VpcConfig {
200-
/**
201-
* VPC security groups.
202-
*/
203-
readonly securityGroups: ec2.ISecurityGroup[];
204-
205200
/**
206201
* VPC id
207202
*/
208-
readonly vpc: ec2.Vpc;
203+
readonly vpc: ec2.IVpc;
209204

210205
/**
211206
* VPC subnets.
212207
*/
213-
readonly subnets: ec2.ISubnet[];
208+
readonly subnets?: ec2.SubnetSelection;
214209
}
215210

216211
/**

packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-train-task.ts

Lines changed: 84 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import ec2 = require('@aws-cdk/aws-ec2');
22
import iam = require('@aws-cdk/aws-iam');
33
import sfn = require('@aws-cdk/aws-stepfunctions');
4-
import { Construct, Duration, Stack } from '@aws-cdk/core';
4+
import { Duration, Lazy, Stack } from '@aws-cdk/core';
55
import { resourceArnSuffix } from './resource-arn-suffix';
66
import { AlgorithmSpecification, Channel, InputMode, OutputDataConfig, ResourceConfig,
77
S3DataType, StoppingCondition, VpcConfig, } from './sagemaker-task-base-types';
@@ -53,7 +53,7 @@ export interface SagemakerTrainTaskProps {
5353
/**
5454
* Tags to be applied to the train job.
5555
*/
56-
readonly tags?: {[key: string]: any};
56+
readonly tags?: {[key: string]: string};
5757

5858
/**
5959
* Identifies the Amazon S3 location where you want Amazon SageMaker to save the results of model training.
@@ -88,15 +88,6 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
8888
*/
8989
public readonly connections: ec2.Connections = new ec2.Connections();
9090

91-
/**
92-
* The execution role for the Sagemaker training job.
93-
*
94-
* @default new role for Amazon SageMaker to assume is automatically created.
95-
*/
96-
public readonly role: iam.IRole;
97-
98-
public readonly grantPrincipal: iam.IPrincipal;
99-
10091
/**
10192
* The Algorithm Specification
10293
*/
@@ -117,9 +108,15 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
117108
*/
118109
private readonly stoppingCondition: StoppingCondition;
119110

111+
private readonly vpc: ec2.IVpc;
112+
private securityGroup: ec2.ISecurityGroup;
113+
private readonly securityGroups: ec2.ISecurityGroup[] = [];
114+
private readonly subnets: string[];
120115
private readonly integrationPattern: sfn.ServiceIntegrationPattern;
116+
private _role?: iam.IRole;
117+
private _grantPrincipal?: iam.IPrincipal;
121118

122-
constructor(scope: Construct, private readonly props: SagemakerTrainTaskProps) {
119+
constructor(private readonly props: SagemakerTrainTaskProps) {
123120
this.integrationPattern = props.integrationPattern || sfn.ServiceIntegrationPattern.FIRE_AND_FORGET;
124121

125122
const supportedPatterns = [
@@ -143,8 +140,66 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
143140
maxRuntime: Duration.hours(1)
144141
};
145142

143+
// check that either algorithm name or image is defined
144+
if ((!props.algorithmSpecification.algorithmName) && (!props.algorithmSpecification.trainingImage)) {
145+
throw new Error("Must define either an algorithm name or training image URI in the algorithm specification");
146+
}
147+
148+
// set the input mode to 'File' if not defined
149+
this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ?
150+
( props.algorithmSpecification ) :
151+
( { ...props.algorithmSpecification, trainingInputMode: InputMode.FILE } );
152+
153+
// set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined
154+
this.inputDataConfig = props.inputDataConfig.map(config => {
155+
if (!config.dataSource.s3DataSource.s3DataType) {
156+
return Object.assign({}, config, { dataSource: { s3DataSource:
157+
{ ...config.dataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } } });
158+
} else {
159+
return config;
160+
}
161+
});
162+
163+
// add the security groups to the connections object
164+
if (props.vpcConfig) {
165+
this.vpc = props.vpcConfig.vpc;
166+
this.subnets = (props.vpcConfig.subnets) ?
167+
(this.vpc.selectSubnets(props.vpcConfig.subnets).subnetIds) : this.vpc.selectSubnets().subnetIds;
168+
}
169+
}
170+
171+
/**
172+
* The execution role for the Sagemaker training job.
173+
*
174+
* Only available after task has been added to a state machine.
175+
*/
176+
public get role(): iam.IRole {
177+
if (this._role === undefined) {
178+
throw new Error(`role not available yet--use the object in a Task first`);
179+
}
180+
return this._role;
181+
}
182+
183+
public get grantPrincipal(): iam.IPrincipal {
184+
if (this._grantPrincipal === undefined) {
185+
throw new Error(`Principal not available yet--use the object in a Task first`);
186+
}
187+
return this._grantPrincipal;
188+
}
189+
190+
/**
191+
* Add the security group to all instances via the launch configuration
192+
* security groups array.
193+
*
194+
* @param securityGroup: The security group to add
195+
*/
196+
public addSecurityGroup(securityGroup: ec2.ISecurityGroup): void {
197+
this.securityGroups.push(securityGroup);
198+
}
199+
200+
public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig {
146201
// set the sagemaker role or create new one
147-
this.grantPrincipal = this.role = props.role || new iam.Role(scope, 'SagemakerRole', {
202+
this._grantPrincipal = this._role = this.props.role || new iam.Role(task, 'SagemakerRole', {
148203
assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'),
149204
inlinePolicies: {
150205
CreateTrainingJob: new iam.PolicyDocument({
@@ -157,7 +212,7 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
157212
'logs:CreateLogGroup',
158213
'logs:DescribeLogStreams',
159214
'ecr:GetAuthorizationToken',
160-
...props.vpcConfig
215+
...this.props.vpcConfig
161216
? [
162217
'ec2:CreateNetworkInterface',
163218
'ec2:CreateNetworkInterfacePermission',
@@ -178,36 +233,23 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
178233
}
179234
});
180235

181-
if (props.outputDataConfig.encryptionKey) {
182-
props.outputDataConfig.encryptionKey.grantEncrypt(this.role);
236+
if (this.props.outputDataConfig.encryptionKey) {
237+
this.props.outputDataConfig.encryptionKey.grantEncrypt(this._role);
183238
}
184239

185-
if (props.resourceConfig && props.resourceConfig.volumeEncryptionKey) {
186-
props.resourceConfig.volumeEncryptionKey.grant(this.role, 'kms:CreateGrant');
240+
if (this.props.resourceConfig && this.props.resourceConfig.volumeEncryptionKey) {
241+
this.props.resourceConfig.volumeEncryptionKey.grant(this._role, 'kms:CreateGrant');
187242
}
188243

189-
// set the input mode to 'File' if not defined
190-
this.algorithmSpecification = ( props.algorithmSpecification.trainingInputMode ) ?
191-
( props.algorithmSpecification ) :
192-
( { ...props.algorithmSpecification, trainingInputMode: InputMode.FILE } );
193-
194-
// set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined
195-
this.inputDataConfig = props.inputDataConfig.map(config => {
196-
if (!config.dataSource.s3DataSource.s3DataType) {
197-
return Object.assign({}, config, { dataSource: { s3DataSource:
198-
{ ...config.dataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } } });
199-
} else {
200-
return config;
201-
}
202-
});
203-
204-
// add the security groups to the connections object
205-
if (this.props.vpcConfig) {
206-
this.props.vpcConfig.securityGroups.forEach(sg => this.connections.addSecurityGroup(sg));
244+
// create a security group if not defined
245+
if (this.vpc && this.securityGroup === undefined) {
246+
this.securityGroup = new ec2.SecurityGroup(task, 'TrainJobSecurityGroup', {
247+
vpc: this.vpc
248+
});
249+
this.connections.addSecurityGroup(this.securityGroup);
250+
this.securityGroups.push(this.securityGroup);
207251
}
208-
}
209252

210-
public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig {
211253
return {
212254
resourceArn: 'arn:aws:states:::sagemaker:createTrainingJob' + resourceArnSuffix.get(this.integrationPattern),
213255
parameters: this.renderParameters(),
@@ -218,7 +260,7 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
218260
private renderParameters(): {[key: string]: any} {
219261
return {
220262
TrainingJobName: this.props.trainingJobName,
221-
RoleArn: this.role.roleArn,
263+
RoleArn: this._role!.roleArn,
222264
...(this.renderAlgorithmSpecification(this.algorithmSpecification)),
223265
...(this.renderInputDataConfig(this.inputDataConfig)),
224266
...(this.renderOutputDataConfig(this.props.outputDataConfig)),
@@ -303,8 +345,8 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
303345

304346
private renderVpcConfig(config: VpcConfig | undefined): {[key: string]: any} {
305347
return (config) ? { VpcConfig: {
306-
SecurityGroupIds: config.securityGroups.map(sg => ( sg.securityGroupId )),
307-
Subnets: config.subnets.map(subnet => ( subnet.subnetId )),
348+
SecurityGroupIds: Lazy.listValue({ produce: () => (this.securityGroups.map(sg => (sg.securityGroupId))) }),
349+
Subnets: this.subnets,
308350
}} : {};
309351
}
310352

@@ -330,7 +372,7 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
330372
}),
331373
new iam.PolicyStatement({
332374
actions: ['iam:PassRole'],
333-
resources: [this.role.roleArn],
375+
resources: [this._role!.roleArn],
334376
conditions: {
335377
StringEquals: { "iam:PassedToService": "sagemaker.amazonaws.com" }
336378
}

packages/@aws-cdk/aws-stepfunctions-tasks/lib/sagemaker-transform-task.ts

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import ec2 = require('@aws-cdk/aws-ec2');
22
import iam = require('@aws-cdk/aws-iam');
33
import sfn = require('@aws-cdk/aws-stepfunctions');
4-
import { Construct, Stack } from '@aws-cdk/core';
4+
import { Stack } from '@aws-cdk/core';
55
import { resourceArnSuffix } from './resource-arn-suffix';
66
import { BatchStrategy, S3DataType, TransformInput, TransformOutput, TransformResources } from './sagemaker-task-base-types';
77

@@ -37,7 +37,7 @@ export interface SagemakerTransformProps {
3737
/**
3838
* Environment variables to set in the Docker container.
3939
*/
40-
readonly environment?: {[key: string]: any};
40+
readonly environment?: {[key: string]: string};
4141

4242
/**
4343
* Maximum number of parallel requests that can be sent to each instance in a transform job.
@@ -57,7 +57,7 @@ export interface SagemakerTransformProps {
5757
/**
5858
* Tags to be applied to the train job.
5959
*/
60-
readonly tags?: {[key: string]: any};
60+
readonly tags?: {[key: string]: string};
6161

6262
/**
6363
* Dataset to be transformed and the Amazon S3 location where it is stored.
@@ -82,13 +82,6 @@ export interface SagemakerTransformProps {
8282
*/
8383
export class SagemakerTransformTask implements sfn.IStepFunctionsTask {
8484

85-
/**
86-
* The execution role for the Sagemaker training job.
87-
*
88-
* @default new role for Amazon SageMaker to assume is automatically created.
89-
*/
90-
public readonly role: iam.IRole;
91-
9285
/**
9386
* Dataset to be transformed and the Amazon S3 location where it is stored.
9487
*/
@@ -98,10 +91,10 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask {
9891
* ML compute instances for the transform job.
9992
*/
10093
private readonly transformResources: TransformResources;
101-
10294
private readonly integrationPattern: sfn.ServiceIntegrationPattern;
95+
private _role?: iam.IRole;
10396

104-
constructor(scope: Construct, private readonly props: SagemakerTransformProps) {
97+
constructor(private readonly props: SagemakerTransformProps) {
10598
this.integrationPattern = props.integrationPattern || sfn.ServiceIntegrationPattern.FIRE_AND_FORGET;
10699

107100
const supportedPatterns = [
@@ -114,12 +107,9 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask {
114107
}
115108

116109
// set the sagemaker role or create new one
117-
this.role = props.role || new iam.Role(scope, 'SagemakerRole', {
118-
assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'),
119-
managedPolicies: [
120-
iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess')
121-
]
122-
});
110+
if (props.role) {
111+
this._role = props.role;
112+
}
123113

124114
// set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined
125115
this.transformInput = (props.transformInput.transformDataSource.s3DataSource.s3DataType) ? (props.transformInput) :
@@ -140,13 +130,35 @@ export class SagemakerTransformTask implements sfn.IStepFunctionsTask {
140130
}
141131

142132
public bind(task: sfn.Task): sfn.StepFunctionsTaskConfig {
133+
// create new role if doesn't exist
134+
if (this._role === undefined) {
135+
this._role = new iam.Role(task, 'SagemakerTransformRole', {
136+
assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'),
137+
managedPolicies: [
138+
iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess')
139+
]
140+
});
141+
}
142+
143143
return {
144144
resourceArn: 'arn:aws:states:::sagemaker:createTransformJob' + resourceArnSuffix.get(this.integrationPattern),
145145
parameters: this.renderParameters(),
146146
policyStatements: this.makePolicyStatements(task),
147147
};
148148
}
149149

150+
/**
151+
* The execution role for the Sagemaker training job.
152+
*
153+
* Only available after task has been added to a state machine.
154+
*/
155+
public get role(): iam.IRole {
156+
if (this._role === undefined) {
157+
throw new Error(`role not available yet--use the object in a Task first`);
158+
}
159+
return this._role;
160+
}
161+
150162
private renderParameters(): {[key: string]: any} {
151163
return {
152164
...(this.props.batchStrategy) ? { BatchStrategy: this.props.batchStrategy } : {},

0 commit comments

Comments
 (0)