In [None]:
import os

# environment setup: load KEY=VALUE pairs from the local .env file into os.environ.
# Blank lines, `#` comment lines and lines without `=` are skipped. Surrounding
# whitespace (including the `\r` left by CRLF line endings) is stripped. Only the
# first `=` separates key from value, so values may themselves contain `=`.
with open(".env", "r") as key_file:
    for raw_line in key_file:
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        variable, _, value = line.partition("=")
        os.environ[variable.strip()] = value.strip()

In [None]:
# Import libraries
import matplotlib.pyplot as plt
import numpy as np
import sagemaker
import sagemaker, subprocess, boto3

from PIL import Image
from datetime import datetime
from sagemaker import s3, get_execution_role
from sagemaker.pytorch import PyTorchModel, PyTorchPredictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

# Model 1: SAM

In [None]:
# Get the right credentials, role and clients for SageMaker.
# A single boto3 session backs every client below and the SageMaker SDK session, so
# deploying, listing and deleting endpoints all happen in the same region.
REGION = "ap-southeast-1"
boto_session = boto3.Session(region_name=REGION)

sm_client = boto_session.client("sagemaker")
role = os.environ["SAGEMAKER_ROLE"]
print(f'Role: {role}')

INSTANCE_TYPE_SAM = 'ml.g4dn.xlarge'
INSTANCE_TYPE_INPAINTING = 'ml.g4dn.2xlarge'

# Package the inference code. check=True raises CalledProcessError if tar fails, instead
# of silently continuing and later uploading a missing or stale archive.
subprocess.run(["tar", "-cpzf", "code.tar.gz", "code/"], check=True)

s3_client = boto_session.client('s3')
s3_resource = boto_session.resource('s3')
sts = boto_session.client('sts')
AWS_ACCOUNT_ID = sts.get_caller_identity()["Account"]

# Use the first bucket whose name contains the base name; fall back to the base name
# itself. A separate loop variable keeps `bucket` a string in every case.
BUCKET_BASE_NAME = 'inpainting-test-s3'
bucket = next(
    (b["Name"] for b in s3_client.list_buckets()["Buckets"] if BUCKET_BASE_NAME in b["Name"]),
    BUCKET_BASE_NAME,
)
print(f'Bucket: {bucket}')

sess = sagemaker.Session(
    boto_session=boto_session,
    sagemaker_client=sm_client,
    default_bucket=bucket.split('s3://')[-1],
)

In [None]:
SAM_ENDPOINT_NAME = 'sam-pytorch-' + str(datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S-%f'))
%store SAM_ENDPOINT_NAME

prefix_sam = "SAM/demo-custom-endpoint"

model_data_sam = s3.S3Uploader.upload("code.tar.gz", f's3://{bucket}/{prefix_sam}')
print(f'Model Data: {model_data_sam}')

In [None]:
# Wrap the uploaded archive as a PyTorch model whose handler is inference_sam.py.
# The `env` entries are passed to the serving container as environment variables.
model_sam = PyTorchModel(entry_point='inference_sam.py',
                         model_data=model_data_sam,
                         framework_version='1.12',
                         py_version='py38',
                         role=role,
                         env={'TS_MAX_RESPONSE_SIZE': '2000000000', 'SAGEMAKER_MODEL_SERVER_TIMEOUT': '300'},
                         sagemaker_session=sess,
                         name='model-' + SAM_ENDPOINT_NAME)

print(f'SAM Endpoint Name: {SAM_ENDPOINT_NAME}')

# `deserializer` (singular) is the keyword Model.deploy accepts; it makes the returned
# predictor parse responses as JSON.
predictor_sam = model_sam.deploy(initial_instance_count=1,
                                 instance_type=INSTANCE_TYPE_SAM,
                                 deserializer=JSONDeserializer(),
                                 endpoint_name=SAM_ENDPOINT_NAME)

# Model 2: Inpainting

In [None]:
INPAINTING_ENDPOINT_NAME = 'inpainting-pytorch-' + str(datetime.utcnow().strftime('%Y-%m-%d-%H-%M-%S-%f'))
%store INPAINTING_ENDPOINT_NAME

prefix_inpainting = "InPainting/demo-custom-endpoint"

model_data_inpainting = s3.S3Uploader.upload("code.tar.gz", f"s3://{bucket}/{prefix_inpainting}")
print(f'Model Data: {model_data_inpainting}')

model_inpainting = PyTorchModel(entry_point='inference_inpainting.py',
                     model_data=model_data_inpainting, 
                     framework_version='1.12', 
                     py_version='py38',
                     role=role,
                     env={'TS_MAX_RESPONSE_SIZE':'2000000000', 'SAGEMAKER_MODEL_SERVER_TIMEOUT' : '300'},
                     sagemaker_session=sess,
                     name='model-'+INPAINTING_ENDPOINT_NAME)

print(f'InPainting Endpoint Name: {INPAINTING_ENDPOINT_NAME}')

predictor_inpainting = model_inpainting.deploy(initial_instance_count=1, 
                         instance_type=INSTANCE_TYPE_INPAINTING,
                         serializer=JSONSerializer(),
                         deserializers=JSONDeserializer(),
                         endpoint_name=INPAINTING_ENDPOINT_NAME,
                        #  volume_size=128
                         )

# Inference

In [None]:
# ListEndpoints returns only one page per call (10 endpoints by default), so walk every
# page and show (name, status, creation time) for each endpoint, newest first.
[
    (endpoint["EndpointName"], endpoint["EndpointStatus"], endpoint["CreationTime"])
    for page in sm_client.get_paginator("list_endpoints").paginate(SortBy="CreationTime", SortOrder="Descending")
    for endpoint in page["Endpoints"]
]

In [None]:
# Restore the SAM endpoint name saved with %store at deployment time (same as
# `%store -r SAM_ENDPOINT_NAME`). This works from a fresh kernel and always targets the
# SAM endpoint, never a stale or inpainting endpoint name.
get_ipython().run_line_magic("store", "-r SAM_ENDPOINT_NAME")

print(f'SAM Endpoint Name: {SAM_ENDPOINT_NAME}')

raw_image = Image.open("images/speaker.png").convert("RGB")

# Reuse the notebook's session so the endpoint is invoked in the region it was deployed to.
predictor_sam = PyTorchPredictor(endpoint_name=SAM_ENDPOINT_NAME,
                                 sagemaker_session=sess,
                                 deserializer=JSONDeserializer())

output_array = predictor_sam.predict(raw_image, initial_args={'Accept': 'application/json'})

# Convert the JSON-decoded response to an array and cast to uint8 so PIL can build the mask image.
mask_image = Image.fromarray(np.array(output_array).astype(np.uint8))
mask_image.save('images/speaker_mask.png')

# Plot the input image next to the predicted mask.
plot_images = [raw_image, mask_image]
titles = ['Original Product Image', 'Mask']
fig, ax = plt.subplots(1, len(plot_images), dpi=200)
for axis, img, title in zip(ax, plot_images, titles):
    axis.imshow(img)
    axis.axis('off')
    axis.set_title(title, fontsize=6)
plt.show()

In [None]:
# Restore the inpainting endpoint name saved with %store at deployment time (same as
# `%store -r INPAINTING_ENDPOINT_NAME`), so this cell also works from a fresh kernel.
get_ipython().run_line_magic("store", "-r INPAINTING_ENDPOINT_NAME")

raw_image = Image.open("images/speaker.png").convert("RGB")
mask_image = Image.open('images/speaker_mask.png').convert('RGB')
prompt_fr = "apple, books"
prompt_bg = "table"
negative_prompt = "longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, letters"

# Request payload; it is sent as JSON via the predictor's JSONSerializer.
inputs = {
    "image": np.array(raw_image),
    "mask": np.array(mask_image),
    "prompt_fr": prompt_fr,
    "prompt_bg": prompt_bg,
    "negative_prompt": negative_prompt,
}

# Reuse the notebook's session so the endpoint is invoked in the region it was deployed to.
predictor_inpainting = PyTorchPredictor(endpoint_name=INPAINTING_ENDPOINT_NAME,
                                        sagemaker_session=sess,
                                        serializer=JSONSerializer(),
                                        deserializer=JSONDeserializer())

output_array = predictor_inpainting.predict(inputs, initial_args={'Accept': 'application/json'})

# The first four response entries are, in order: generated product image, generated
# background, refined mask, post-processed image. Each is cast to uint8 for PIL.
gai_image, gai_background, gai_mask, post_image = (
    Image.fromarray(np.array(entry).astype(np.uint8)) for entry in output_array[:4]
)

# Plot the generated outputs side by side.
plot_images = [gai_mask, gai_background, gai_image, post_image]
titles = ['Refined Mask', 'Generated Background', 'Generated Product Image', 'Post Process Image']
fig, ax = plt.subplots(1, len(plot_images), dpi=200)
for axis, img, title in zip(ax, plot_images, titles):
    axis.imshow(img)
    axis.axis('off')
    axis.set_title(title, fontsize=5)
plt.show()

# save the generated image using PIL Image
post_image.save('images/speaker_generated.png')

# Cleanup

In [None]:
# ListEndpoints returns only one page per call (10 endpoints by default), so walk every
# page and show (name, status, creation time) for each endpoint, newest first.
[
    (endpoint["EndpointName"], endpoint["EndpointStatus"], endpoint["CreationTime"])
    for page in sm_client.get_paginator("list_endpoints").paginate(SortBy="CreationTime", SortOrder="Descending")
    for endpoint in page["Endpoints"]
]

In [None]:
# Restore the endpoint names saved with %store at deployment time (same as
# `%store -r NAME`), so cleanup targets this notebook's endpoints, not stale hard-coded names.
get_ipython().run_line_magic("store", "-r SAM_ENDPOINT_NAME")
get_ipython().run_line_magic("store", "-r INPAINTING_ENDPOINT_NAME")


def delete_endpoint_resources(endpoint_name):
    """Delete a SageMaker endpoint together with its endpoint config and model(s).

    The endpoint config name is read from the endpoint itself rather than assumed to
    equal the endpoint name. Config and model names must be looked up before the
    endpoint is deleted.

    Args:
        endpoint_name: Name of the endpoint to tear down.
    """
    endpoint_config_name = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointConfigName"]
    endpoint_config = sm_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
    print(endpoint_config)

    # Delete Endpoint
    sm_client.delete_endpoint(EndpointName=endpoint_name)

    # Delete Endpoint Configuration
    sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)

    # Delete the model(s) referenced by the config's production variants
    for production_variant in endpoint_config["ProductionVariants"]:
        sm_client.delete_model(ModelName=production_variant["ModelName"])


# Tear down both endpoints this notebook deploys; trim this tuple to keep one running.
for endpoint_name in (SAM_ENDPOINT_NAME, INPAINTING_ENDPOINT_NAME):
    delete_endpoint_resources(endpoint_name)