In [30]:
%%writefile processing.py
import boto3
import json
import re
import base64
import random
import string
from datetime import datetime
import argparse
from pyspark.sql import SparkSession

def generate_inference_id(length):
    chars = string.ascii_letters + string.digits
    inference_id = ''.join(random.choice(chars) for _ in range(length))
    return inference_id

def encode_image(bucket_name, key):
    s3 = boto3.client('s3', region_name=AWS_REGION)  
    body = s3.get_object(Bucket=bucket_name, Key=key)['Body'].read()
    return base64.b64encode(body).decode("utf-8")    
  
def convert_image_to_payload(prompt, encoded_image):
    inputs = {"prompt": prompt, "image": encoded_image}   
    return json.dumps(inputs)

def upload_payload_to_s3(bucket_name, key, body):
    s3 = boto3.client('s3', region_name=AWS_REGION)  
    s3.put_object(Bucket=bucket_name, Key=key, Body=body)
    return key

def process_image_and_save_to_s3(prompt, file_s3_path, input_s3_prefix):   
    # s3 = boto3.client('s3')  
    # Use a regex to parse out the bucket name and key
    match = re.match("s3://(.+?)/(.+)", file_s3_path)
    bucket_name = match.group(1)
    key = match.group(2)
    # Split the key on '/' to get the file name
    path_parts = key.split("/")
    file_name = path_parts[-1]    
    input_name = file_name.replace('.jpg', '.json') 
    input_key = input_s3_prefix + '/' + input_name  
    encoded_body = convert_image_to_payload(prompt, encode_image(bucket_name, key))
    upload_payload_to_s3(bucket_name, input_key, encoded_body)    
    return f's3://{bucket_name}/{input_key}'

def invoke_endpoint(endpoint_name, input_location):
    sagemaker_client = boto3.client('sagemaker-runtime', region_name=AWS_REGION)
    response = sagemaker_client.invoke_endpoint_async(
        EndpointName = endpoint_name,
        InputLocation = input_location,
        InferenceId = generate_inference_id(40)     
    )
    request_time = datetime.strptime(response['ResponseMetadata']['HTTPHeaders']['date'], "%a, %d %b %Y %H:%M:%S %Z").strftime("%Y-%m-%dT%H:%M:%S.000Z")
    inference_id = response['InferenceId']
    output_location = response['OutputLocation']
    return inference_id, output_location, request_time

def ddb_registration(table_name, prompt, inference_id, endpoint_name, input_location, output_location, request_time):
    dynamodb = boto3.resource('dynamodb', region_name=AWS_REGION)
    table = dynamodb.Table(table_name)
    item = {
        'inference_id': inference_id,
        'prompt': prompt,
        'endpoint_name' : endpoint_name,
        'input_location' : input_location, 
        'output_location' : output_location, 
        'request_time' : request_time
        
    }
    table.put_item(Item=item)
    
def execute(file_s3_path):
    prompt = PROMPT
    input_s3_prefix = INPUT_S3_PREFIX
    endpoint_name = ENDPOINT_NAME
    table_name = TABLE_NAME   
    input_location = process_image_and_save_to_s3(prompt, file_s3_path, input_s3_prefix)
    inference_id, output_location, request_time = invoke_endpoint(endpoint_name, input_location)
    ddb_registration(table_name, prompt, inference_id, endpoint_name, input_location, output_location, request_time)
    
if __name__ == "__main__":    
    
    parser = argparse.ArgumentParser(description="app configuration")
    parser.add_argument("--prompt", type=str, help="prompt for the images")
    parser.add_argument("--endpoint_name", type=str, help="async endpoint name")
    parser.add_argument("--input_s3_prefix", type=str, help="the prefix of the s3 input for invocation")
    parser.add_argument("--s3_path", type=str, help="the s3 path to the raw images")
    parser.add_argument("--table_name", type=str, help="DDB table")
    parser.add_argument("--aws_region", type=str, help="AWS region")
     
    args, _ = parser.parse_known_args()
    print("Received arguments {}".format(args))
    
    PROMPT =  args.prompt
    ENDPOINT_NAME = args.endpoint_name
    INPUT_S3_PREFIX = args.input_s3_prefix
    S3_PATH = args.s3_path
    TABLE_NAME = args.table_name 
    AWS_REGION= args.aws_region
    
    spark = SparkSession.builder.appName('PySparkApp').getOrCreate()

    file_list = spark.sparkContext.wholeTextFiles(S3_PATH).map(lambda x: x[0])
    file_list.foreach(execute)

    spark.stop()




Overwriting processing.py


In [None]:
import sagemaker
sagemaker_session = sagemaker.Session()
bucket = sagemaker_session.default_bucket()
role = sagemaker.get_execution_role()
from time import gmtime, strftime

from sagemaker.spark.processing import PySparkProcessor
timestamp_prefix = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
prefix = "sagemaker/spark-preprocess-demo/{}".format(timestamp_prefix)


PROMPT = 'Question: what can I see in this photo? Answer:'
ENDPOINT_NAME = 'endpoint-blip2-flan-t5-xl-2023-11-10-23-41-58-817'
INPUT_S3_PREFIX = 'blip2/spark/testrun2' 
S3_PATH = 's3://<BUCKET>/blip2/images/'
TABLE_NAME = 'sparktable' 
AWS_REGION = 'us-west-2'

# Run the processing job
spark_processor = PySparkProcessor(
    base_job_name="sm-spark",
    framework_version="3.1",
    role=role,
    instance_count=2,
    instance_type="ml.m5.xlarge",
    max_runtime_in_seconds=1200,
)

spark_processor.run(
    submit_app="processing.py",
    arguments=[
        "--prompt",
        PROMPT,
        "--endpoint_name",
        ENDPOINT_NAME,
        "--input_s3_prefix",
        INPUT_S3_PREFIX,
        "--s3_path",
        S3_PATH,
        "--table_name",
        TABLE_NAME,
        "--aws_region",
        AWS_REGION
    ],
    spark_event_logs_s3_uri="s3://{}/{}/spark_event_logs".format(bucket, prefix),
    logs=True,
)

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /root/.config/sagemaker/config.yaml


INFO:sagemaker:Creating processing-job with name sm-spark-2023-11-11-02-40-45-056


...................