1
1
import ec2 = require( '@aws-cdk/aws-ec2' ) ;
2
2
import iam = require( '@aws-cdk/aws-iam' ) ;
3
3
import sfn = require( '@aws-cdk/aws-stepfunctions' ) ;
4
- import { Construct , Duration , Stack } from '@aws-cdk/core' ;
4
+ import { Duration , Lazy , Stack } from '@aws-cdk/core' ;
5
5
import { resourceArnSuffix } from './resource-arn-suffix' ;
6
6
import { AlgorithmSpecification , Channel , InputMode , OutputDataConfig , ResourceConfig ,
7
7
S3DataType , StoppingCondition , VpcConfig , } from './sagemaker-task-base-types' ;
@@ -53,7 +53,7 @@ export interface SagemakerTrainTaskProps {
53
53
/**
54
54
* Tags to be applied to the train job.
55
55
*/
56
- readonly tags ?: { [ key : string ] : any } ;
56
+ readonly tags ?: { [ key : string ] : string } ;
57
57
58
58
/**
59
59
* Identifies the Amazon S3 location where you want Amazon SageMaker to save the results of model training.
@@ -88,15 +88,6 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
88
88
*/
89
89
public readonly connections : ec2 . Connections = new ec2 . Connections ( ) ;
90
90
91
- /**
92
- * The execution role for the Sagemaker training job.
93
- *
94
- * @default new role for Amazon SageMaker to assume is automatically created.
95
- */
96
- public readonly role : iam . IRole ;
97
-
98
- public readonly grantPrincipal : iam . IPrincipal ;
99
-
100
91
/**
101
92
* The Algorithm Specification
102
93
*/
@@ -117,9 +108,15 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
117
108
*/
118
109
private readonly stoppingCondition : StoppingCondition ;
119
110
111
+ private readonly vpc : ec2 . IVpc ;
112
+ private securityGroup : ec2 . ISecurityGroup ;
113
+ private readonly securityGroups : ec2 . ISecurityGroup [ ] = [ ] ;
114
+ private readonly subnets : string [ ] ;
120
115
private readonly integrationPattern : sfn . ServiceIntegrationPattern ;
116
+ private _role ?: iam . IRole ;
117
+ private _grantPrincipal ?: iam . IPrincipal ;
121
118
122
- constructor ( scope : Construct , private readonly props : SagemakerTrainTaskProps ) {
119
+ constructor ( private readonly props : SagemakerTrainTaskProps ) {
123
120
this . integrationPattern = props . integrationPattern || sfn . ServiceIntegrationPattern . FIRE_AND_FORGET ;
124
121
125
122
const supportedPatterns = [
@@ -143,8 +140,66 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
143
140
maxRuntime : Duration . hours ( 1 )
144
141
} ;
145
142
143
+ // check that either algorithm name or image is defined
144
+ if ( ( ! props . algorithmSpecification . algorithmName ) && ( ! props . algorithmSpecification . trainingImage ) ) {
145
+ throw new Error ( "Must define either an algorithm name or training image URI in the algorithm specification" ) ;
146
+ }
147
+
148
+ // set the input mode to 'File' if not defined
149
+ this . algorithmSpecification = ( props . algorithmSpecification . trainingInputMode ) ?
150
+ ( props . algorithmSpecification ) :
151
+ ( { ...props . algorithmSpecification , trainingInputMode : InputMode . FILE } ) ;
152
+
153
+ // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined
154
+ this . inputDataConfig = props . inputDataConfig . map ( config => {
155
+ if ( ! config . dataSource . s3DataSource . s3DataType ) {
156
+ return Object . assign ( { } , config , { dataSource : { s3DataSource :
157
+ { ...config . dataSource . s3DataSource , s3DataType : S3DataType . S3_PREFIX } } } ) ;
158
+ } else {
159
+ return config ;
160
+ }
161
+ } ) ;
162
+
163
+ // add the security groups to the connections object
164
+ if ( props . vpcConfig ) {
165
+ this . vpc = props . vpcConfig . vpc ;
166
+ this . subnets = ( props . vpcConfig . subnets ) ?
167
+ ( this . vpc . selectSubnets ( props . vpcConfig . subnets ) . subnetIds ) : this . vpc . selectSubnets ( ) . subnetIds ;
168
+ }
169
+ }
170
+
171
+ /**
172
+ * The execution role for the Sagemaker training job.
173
+ *
174
+ * Only available after task has been added to a state machine.
175
+ */
176
+ public get role ( ) : iam . IRole {
177
+ if ( this . _role === undefined ) {
178
+ throw new Error ( `role not available yet--use the object in a Task first` ) ;
179
+ }
180
+ return this . _role ;
181
+ }
182
+
183
+ public get grantPrincipal ( ) : iam . IPrincipal {
184
+ if ( this . _grantPrincipal === undefined ) {
185
+ throw new Error ( `Principal not available yet--use the object in a Task first` ) ;
186
+ }
187
+ return this . _grantPrincipal ;
188
+ }
189
+
190
+ /**
191
+ * Add the security group to all instances via the launch configuration
192
+ * security groups array.
193
+ *
194
+ * @param securityGroup: The security group to add
195
+ */
196
+ public addSecurityGroup ( securityGroup : ec2 . ISecurityGroup ) : void {
197
+ this . securityGroups . push ( securityGroup ) ;
198
+ }
199
+
200
+ public bind ( task : sfn . Task ) : sfn . StepFunctionsTaskConfig {
146
201
// set the sagemaker role or create new one
147
- this . grantPrincipal = this . role = props . role || new iam . Role ( scope , 'SagemakerRole' , {
202
+ this . _grantPrincipal = this . _role = this . props . role || new iam . Role ( task , 'SagemakerRole' , {
148
203
assumedBy : new iam . ServicePrincipal ( 'sagemaker.amazonaws.com' ) ,
149
204
inlinePolicies : {
150
205
CreateTrainingJob : new iam . PolicyDocument ( {
@@ -157,7 +212,7 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
157
212
'logs:CreateLogGroup' ,
158
213
'logs:DescribeLogStreams' ,
159
214
'ecr:GetAuthorizationToken' ,
160
- ...props . vpcConfig
215
+ ...this . props . vpcConfig
161
216
? [
162
217
'ec2:CreateNetworkInterface' ,
163
218
'ec2:CreateNetworkInterfacePermission' ,
@@ -178,36 +233,23 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
178
233
}
179
234
} ) ;
180
235
181
- if ( props . outputDataConfig . encryptionKey ) {
182
- props . outputDataConfig . encryptionKey . grantEncrypt ( this . role ) ;
236
+ if ( this . props . outputDataConfig . encryptionKey ) {
237
+ this . props . outputDataConfig . encryptionKey . grantEncrypt ( this . _role ) ;
183
238
}
184
239
185
- if ( props . resourceConfig && props . resourceConfig . volumeEncryptionKey ) {
186
- props . resourceConfig . volumeEncryptionKey . grant ( this . role , 'kms:CreateGrant' ) ;
240
+ if ( this . props . resourceConfig && this . props . resourceConfig . volumeEncryptionKey ) {
241
+ this . props . resourceConfig . volumeEncryptionKey . grant ( this . _role , 'kms:CreateGrant' ) ;
187
242
}
188
243
189
- // set the input mode to 'File' if not defined
190
- this . algorithmSpecification = ( props . algorithmSpecification . trainingInputMode ) ?
191
- ( props . algorithmSpecification ) :
192
- ( { ...props . algorithmSpecification , trainingInputMode : InputMode . FILE } ) ;
193
-
194
- // set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined
195
- this . inputDataConfig = props . inputDataConfig . map ( config => {
196
- if ( ! config . dataSource . s3DataSource . s3DataType ) {
197
- return Object . assign ( { } , config , { dataSource : { s3DataSource :
198
- { ...config . dataSource . s3DataSource , s3DataType : S3DataType . S3_PREFIX } } } ) ;
199
- } else {
200
- return config ;
201
- }
202
- } ) ;
203
-
204
- // add the security groups to the connections object
205
- if ( this . props . vpcConfig ) {
206
- this . props . vpcConfig . securityGroups . forEach ( sg => this . connections . addSecurityGroup ( sg ) ) ;
244
+ // create a security group if not defined
245
+ if ( this . vpc && this . securityGroup === undefined ) {
246
+ this . securityGroup = new ec2 . SecurityGroup ( task , 'TrainJobSecurityGroup' , {
247
+ vpc : this . vpc
248
+ } ) ;
249
+ this . connections . addSecurityGroup ( this . securityGroup ) ;
250
+ this . securityGroups . push ( this . securityGroup ) ;
207
251
}
208
- }
209
252
210
- public bind ( task : sfn . Task ) : sfn . StepFunctionsTaskConfig {
211
253
return {
212
254
resourceArn : 'arn:aws:states:::sagemaker:createTrainingJob' + resourceArnSuffix . get ( this . integrationPattern ) ,
213
255
parameters : this . renderParameters ( ) ,
@@ -218,7 +260,7 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
218
260
private renderParameters ( ) : { [ key : string ] : any } {
219
261
return {
220
262
TrainingJobName : this . props . trainingJobName ,
221
- RoleArn : this . role . roleArn ,
263
+ RoleArn : this . _role ! . roleArn ,
222
264
...( this . renderAlgorithmSpecification ( this . algorithmSpecification ) ) ,
223
265
...( this . renderInputDataConfig ( this . inputDataConfig ) ) ,
224
266
...( this . renderOutputDataConfig ( this . props . outputDataConfig ) ) ,
@@ -303,8 +345,8 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
303
345
304
346
private renderVpcConfig ( config : VpcConfig | undefined ) : { [ key : string ] : any } {
305
347
return ( config ) ? { VpcConfig : {
306
- SecurityGroupIds : config . securityGroups . map ( sg => ( sg . securityGroupId ) ) ,
307
- Subnets : config . subnets . map ( subnet => ( subnet . subnetId ) ) ,
348
+ SecurityGroupIds : Lazy . listValue ( { produce : ( ) => ( this . securityGroups . map ( sg => ( sg . securityGroupId ) ) ) } ) ,
349
+ Subnets : this . subnets ,
308
350
} } : { } ;
309
351
}
310
352
@@ -330,7 +372,7 @@ export class SagemakerTrainTask implements iam.IGrantable, ec2.IConnectable, sfn
330
372
} ) ,
331
373
new iam . PolicyStatement ( {
332
374
actions : [ 'iam:PassRole' ] ,
333
- resources : [ this . role . roleArn ] ,
375
+ resources : [ this . _role ! . roleArn ] ,
334
376
conditions : {
335
377
StringEquals : { "iam:PassedToService" : "sagemaker.amazonaws.com" }
336
378
}
0 commit comments