/
create-transform-job.ts
290 lines (257 loc) · 9.75 KB
/
create-transform-job.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
import * as ec2 from '@aws-cdk/aws-ec2';
import * as iam from '@aws-cdk/aws-iam';
import * as sfn from '@aws-cdk/aws-stepfunctions';
import { Size, Stack, Token } from '@aws-cdk/core';
import { Construct } from 'constructs';
import { integrationResourceArn, validatePatternSupported } from '../private/task-utils';
import { BatchStrategy, ModelClientOptions, S3DataType, TransformInput, TransformOutput, TransformResources } from './base-types';
import { renderEnvironment, renderTags } from './private/utils';
/**
* Properties for creating an Amazon SageMaker transform job task
*
*/
export interface SageMakerCreateTransformJobProps extends sfn.TaskStateBaseProps {
/**
* Transform Job Name.
*/
readonly transformJobName: string;
/**
* Role for the Transform Job.
*
* @default - A role is created with `AmazonSageMakerFullAccess` managed policy
*/
readonly role?: iam.IRole;
/**
* Number of records to include in a mini-batch for an HTTP inference request.
*
* @default - No batch strategy
*/
readonly batchStrategy?: BatchStrategy;
/**
* Environment variables to set in the Docker container.
*
* @default - No environment variables
*/
readonly environment?: { [key: string]: string };
/**
* Maximum number of parallel requests that can be sent to each instance in a transform job.
*
* @default - Amazon SageMaker checks the optional execution-parameters to determine the settings for your chosen algorithm.
* If the execution-parameters endpoint is not enabled, the default value is 1.
*/
readonly maxConcurrentTransforms?: number;
/**
* Maximum allowed size of the payload, in MB.
*
* @default 6
*/
readonly maxPayload?: Size;
/**
* Name of the model that you want to use for the transform job.
*/
readonly modelName: string;
/**
* Configures the timeout and maximum number of retries for processing a transform job invocation.
*
* @default - 0 retries and 60 seconds of timeout
*/
readonly modelClientOptions?: ModelClientOptions;
/**
* Tags to be applied to the train job.
*
* @default - No tags
*/
readonly tags?: { [key: string]: string };
/**
* Dataset to be transformed and the Amazon S3 location where it is stored.
*/
readonly transformInput: TransformInput;
/**
* S3 location where you want Amazon SageMaker to save the results from the transform job.
*/
readonly transformOutput: TransformOutput;
/**
* ML compute instances for the transform job.
*
* @default - 1 instance of type M4.XLarge
*/
readonly transformResources?: TransformResources;
}
/**
* Class representing the SageMaker Create Transform Job task.
*
*/
export class SageMakerCreateTransformJob extends sfn.TaskStateBase {
private static readonly SUPPORTED_INTEGRATION_PATTERNS: sfn.IntegrationPattern[] = [
sfn.IntegrationPattern.REQUEST_RESPONSE,
sfn.IntegrationPattern.RUN_JOB,
];
protected readonly taskPolicies?: iam.PolicyStatement[];
protected readonly taskMetrics?: sfn.TaskMetricsConfig;
/**
* Dataset to be transformed and the Amazon S3 location where it is stored.
*/
private readonly transformInput: TransformInput;
/**
* ML compute instances for the transform job.
*/
private readonly transformResources: TransformResources;
private readonly integrationPattern: sfn.IntegrationPattern;
private _role?: iam.IRole;
constructor(scope: Construct, id: string, private readonly props: SageMakerCreateTransformJobProps) {
super(scope, id, props);
this.integrationPattern = props.integrationPattern || sfn.IntegrationPattern.REQUEST_RESPONSE;
validatePatternSupported(this.integrationPattern, SageMakerCreateTransformJob.SUPPORTED_INTEGRATION_PATTERNS);
// set the sagemaker role or create new one
if (props.role) {
this._role = props.role;
}
// set the S3 Data type of the input data config objects to be 'S3Prefix' if not defined
this.transformInput = props.transformInput.transformDataSource.s3DataSource.s3DataType
? props.transformInput
: Object.assign({}, props.transformInput, {
transformDataSource: { s3DataSource: { ...props.transformInput.transformDataSource.s3DataSource, s3DataType: S3DataType.S3_PREFIX } },
});
// set the default value for the transform resources
this.transformResources = props.transformResources || {
instanceCount: 1,
instanceType: ec2.InstanceType.of(ec2.InstanceClass.M4, ec2.InstanceSize.XLARGE),
};
this.taskPolicies = this.makePolicyStatements();
}
/**
* @internal
*/
protected _renderTask(): any {
return {
Resource: integrationResourceArn('sagemaker', 'createTransformJob', this.integrationPattern),
Parameters: sfn.FieldUtils.renderObject(this.renderParameters()),
};
}
/**
* The execution role for the Sagemaker transform job.
*
* Only available after task has been added to a state machine.
*/
public get role(): iam.IRole {
if (this._role === undefined) {
throw new Error('role not available yet--use the object in a Task first');
}
return this._role;
}
private renderParameters(): { [key: string]: any } {
return {
...(this.props.batchStrategy ? { BatchStrategy: this.props.batchStrategy } : {}),
...renderEnvironment(this.props.environment),
...(this.props.maxConcurrentTransforms ? { MaxConcurrentTransforms: this.props.maxConcurrentTransforms } : {}),
...(this.props.maxPayload ? { MaxPayloadInMB: this.props.maxPayload.toMebibytes() } : {}),
...this.props.modelClientOptions ? this.renderModelClientOptions(this.props.modelClientOptions) : {},
ModelName: this.props.modelName,
...renderTags(this.props.tags),
...this.renderTransformInput(this.transformInput),
TransformJobName: this.props.transformJobName,
...this.renderTransformOutput(this.props.transformOutput),
...this.renderTransformResources(this.transformResources),
};
}
private renderModelClientOptions(options: ModelClientOptions): { [key: string]: any } {
const retries = options.invocationsMaxRetries;
if (!Token.isUnresolved(retries) && retries? (retries < 0 || retries > 3): false) {
throw new Error(`invocationsMaxRetries should be between 0 and 3. Received: ${retries}.`);
}
const timeout = options.invocationsTimeout?.toSeconds();
if (!Token.isUnresolved(timeout) && timeout? (timeout < 1 || timeout > 3600): false) {
throw new Error(`invocationsTimeout should be between 1 and 3600 seconds. Received: ${timeout}.`);
}
return {
ModelClientConfig: {
InvocationsMaxRetries: retries ?? 0,
InvocationsTimeoutInSeconds: timeout ?? 60,
},
};
}
private renderTransformInput(input: TransformInput): { [key: string]: any } {
return {
TransformInput: {
...(input.compressionType ? { CompressionType: input.compressionType } : {}),
...(input.contentType ? { ContentType: input.contentType } : {}),
DataSource: {
S3DataSource: {
S3Uri: input.transformDataSource.s3DataSource.s3Uri,
S3DataType: input.transformDataSource.s3DataSource.s3DataType,
},
},
...(input.splitType ? { SplitType: input.splitType } : {}),
},
};
}
private renderTransformOutput(output: TransformOutput): { [key: string]: any } {
return {
TransformOutput: {
S3OutputPath: output.s3OutputPath,
...(output.encryptionKey ? { KmsKeyId: output.encryptionKey.keyArn } : {}),
...(output.accept ? { Accept: output.accept } : {}),
...(output.assembleWith ? { AssembleWith: output.assembleWith } : {}),
},
};
}
private renderTransformResources(resources: TransformResources): { [key: string]: any } {
return {
TransformResources: {
InstanceCount: resources.instanceCount,
InstanceType: sfn.JsonPath.isEncodedJsonPath(resources.instanceType.toString())
? resources.instanceType.toString() : `ml.${resources.instanceType}`,
...(resources.volumeEncryptionKey ? { VolumeKmsKeyId: resources.volumeEncryptionKey.keyArn } : {}),
},
};
}
private makePolicyStatements(): iam.PolicyStatement[] {
const stack = Stack.of(this);
// create new role if doesn't exist
if (this._role === undefined) {
this._role = new iam.Role(this, 'SagemakerTransformRole', {
assumedBy: new iam.ServicePrincipal('sagemaker.amazonaws.com'),
managedPolicies: [iam.ManagedPolicy.fromAwsManagedPolicyName('AmazonSageMakerFullAccess')],
});
}
// https://docs.aws.amazon.com/step-functions/latest/dg/sagemaker-iam.html
const policyStatements = [
new iam.PolicyStatement({
actions: ['sagemaker:CreateTransformJob', 'sagemaker:DescribeTransformJob', 'sagemaker:StopTransformJob'],
resources: [
stack.formatArn({
service: 'sagemaker',
resource: 'transform-job',
resourceName: '*',
}),
],
}),
new iam.PolicyStatement({
actions: ['sagemaker:ListTags'],
resources: ['*'],
}),
new iam.PolicyStatement({
actions: ['iam:PassRole'],
resources: [this.role.roleArn],
conditions: {
StringEquals: { 'iam:PassedToService': 'sagemaker.amazonaws.com' },
},
}),
];
if (this.integrationPattern === sfn.IntegrationPattern.RUN_JOB) {
policyStatements.push(
new iam.PolicyStatement({
actions: ['events:PutTargets', 'events:PutRule', 'events:DescribeRule'],
resources: [
stack.formatArn({
service: 'events',
resource: 'rule',
resourceName: 'StepFunctionsGetEventsForSageMakerTransformJobsRule',
}),
],
}),
);
}
return policyStatements;
}
}