Skip to content

Commit

Permalink
feat(stepfunctions-tasks): allow BedrockInvokeModel to use JsonPath t…
Browse files Browse the repository at this point in the history
…o specify input/output S3 URIs
  • Loading branch information
shikha372 committed Jun 6, 2024
1 parent 623cedb commit d560c33
Show file tree
Hide file tree
Showing 11 changed files with 255 additions and 18 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,25 @@
]
]
}
},
{
"Action": [
"s3:GetObject",
"s3:PutObject"
],
"Effect": "Allow",
"Resource": {
"Fn::Join": [
"",
[
"arn:",
{
"Ref": "AWS::Partition"
},
":s3:::*"
]
]
}
}
],
"Version": "2012-10-17"
Expand Down Expand Up @@ -72,7 +91,19 @@
{
"Ref": "AWS::Region"
},
"::foundation-model/amazon.titan-text-express-v1\",\"Body\":{\"inputText\":\"Generate a list of five first names.\",\"textGenerationConfig\":{\"maxTokenCount\":100,\"temperature\":1}}}},\"Prompt2\":{\"End\":true,\"Type\":\"Task\",\"ResultPath\":\"$\",\"ResultSelector\":{\"names.$\":\"$.Body.results[0].outputText\"},\"Resource\":\"arn:",
"::foundation-model/amazon.titan-text-express-v1\",\"Body\":{\"inputText\":\"Generate a list of five first names.\",\"textGenerationConfig\":{\"maxTokenCount\":100,\"temperature\":1}}}},\"Prompt2\":{\"Next\":\"Prompt3\",\"Type\":\"Task\",\"ResultPath\":\"$\",\"ResultSelector\":{\"names.$\":\"$.Body.results[0].outputText\"},\"Resource\":\"arn:",
{
"Ref": "AWS::Partition"
},
":states:::bedrock:invokeModel\",\"Parameters\":{\"ModelId\":\"arn:",
{
"Ref": "AWS::Partition"
},
":bedrock:",
{
"Ref": "AWS::Region"
},
"::foundation-model/amazon.titan-text-express-v1\",\"Body\":{\"inputText.$\":\"States.Format('Alphabetize this list of first names:\\n{}', $.names)\",\"textGenerationConfig\":{\"maxTokenCount\":100,\"temperature\":1}}}},\"Prompt3\":{\"End\":true,\"Type\":\"Task\",\"InputPath\":\"$.names\",\"OutputPath\":\"$.names\",\"Resource\":\"arn:",
{
"Ref": "AWS::Partition"
},
Expand All @@ -84,7 +115,7 @@
{
"Ref": "AWS::Region"
},
"::foundation-model/amazon.titan-text-express-v1\",\"Body\":{\"inputText.$\":\"States.Format('Alphabetize this list of first names:\\n{}', $.names)\",\"textGenerationConfig\":{\"maxTokenCount\":100,\"temperature\":1}}}}},\"TimeoutSeconds\":30}"
"::foundation-model/amazon.titan-text-express-v1\",\"Input\":{\"S3Uri.$\":\"$.names\"},\"Output\":{\"S3Uri.$\":\"$.names\"}}}},\"TimeoutSeconds\":30}"
]
]
},
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,13 @@ const prompt2 = new BedrockInvokeModel(stack, 'Prompt2', {
resultPath: '$',
});

const chain = sfn.Chain.start(prompt1).next(prompt2);
const prompt3 = new BedrockInvokeModel(stack, 'Prompt3', {
model,
inputPath: sfn.JsonPath.stringAt('$.names'),
outputPath: sfn.JsonPath.stringAt('$.names'),
});

const chain = sfn.Chain.start(prompt1).next(prompt2).next(prompt3);

new sfn.StateMachine(stack, 'StateMachine', {
definitionBody: sfn.DefinitionBody.fromChainable(chain),
Expand Down
21 changes: 21 additions & 0 deletions packages/aws-cdk-lib/aws-stepfunctions-tasks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,27 @@ const task = new tasks.BedrockInvokeModel(this, 'Prompt Model', {
names: sfn.JsonPath.stringAt('$.Body.results[0].outputText'),
},
});
```
### Using Input Path

Provide S3 URI as an input or output path to invoke a model

```ts

import * as bedrock from 'aws-cdk-lib/aws-bedrock';

const model = bedrock.FoundationModel.fromFoundationModelId(
this,
'Model',
bedrock.FoundationModelIdentifier.AMAZON_TITAN_TEXT_G1_EXPRESS_V1,
);

const task = new tasks.BedrockInvokeModel(this, 'Prompt Model', {
model,
inputPath: sfn.JsonPath.stringAt('$.prompt'),
outputPath: sfn.JsonPath.stringAt('$.prompt'),
});

```

## CodeBuild
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,14 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {

constructor(scope: Construct, id: string, private readonly props: BedrockInvokeModelProps) {
super(scope, id, props);

this.integrationPattern = props.integrationPattern ?? sfn.IntegrationPattern.REQUEST_RESPONSE;

validatePatternSupported(this.integrationPattern, BedrockInvokeModel.SUPPORTED_INTEGRATION_PATTERNS);

const isBodySpecified = props.body !== undefined;
const isInputSpecified = props.input !== undefined && props.input.s3Location !== undefined;
//Either specific props.input with bucket name and object key or input s3 path
const isInputSpecified = (props.input !== undefined && props.input.s3Location !== undefined) || (props.inputPath !== undefined);

if (isBodySpecified && isInputSpecified) {
throw new Error('Either `body` or `input` must be specified, but not both.');
Expand All @@ -155,7 +157,21 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
}),
];

if (this.props.input !== undefined && this.props.input.s3Location !== undefined) {
if (this.props.inputPath !== undefined) {
policyStatements.push(
new iam.PolicyStatement({
actions: ['s3:GetObject'],
resources: [
Stack.of(this).formatArn({
region: '',
account: '',
service: 's3',
resource: '*',
}),
],
}),
);
} else if (this.props.input !== undefined && this.props.input.s3Location !== undefined) {
policyStatements.push(
new iam.PolicyStatement({
actions: ['s3:GetObject'],
Expand All @@ -172,7 +188,21 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
);
}

if (this.props.output !== undefined && this.props.output.s3Location !== undefined) {
if (this.props.outputPath !== undefined) {
policyStatements.push(
new iam.PolicyStatement({
actions: ['s3:PutObject'],
resources: [
Stack.of(this).formatArn({
region: '',
account: '',
service: 's3',
resource: '*',
}),
],
}),
);
} else if (this.props.output !== undefined && this.props.output.s3Location !== undefined) {
policyStatements.push(
new iam.PolicyStatement({
actions: ['s3:PutObject'],
Expand Down Expand Up @@ -207,11 +237,13 @@ export class BedrockInvokeModel extends sfn.TaskStateBase {
Body: this.props.body?.value,
Input: this.props.input?.s3Location ? {
S3Uri: `s3://${this.props.input.s3Location.bucketName}/${this.props.input.s3Location.objectKey}`,
} : undefined,
} : this.props.inputPath ? { S3Uri: this.props.inputPath } : undefined,
Output: this.props.output?.s3Location ? {
S3Uri: `s3://${this.props.output.s3Location.bucketName}/${this.props.output.s3Location.objectKey}`,
} : undefined,
} : this.props.outputPath ? { S3Uri: this.props.outputPath }: undefined,
}),
};
}

}

Loading

0 comments on commit d560c33

Please sign in to comment.