In [None]:
import sagemaker

# The cells below use the SageMaker Python SDK v2 API (image_uris.retrieve,
# TrainingInput, instance_count/instance_type), so report the installed version.
print(sagemaker.__version__)

# Session bundles the boto3 session (region + credentials) used by every AWS call below.
sess = sagemaker.Session()

# All inputs and outputs live under s3://<bucket>/<prefix>/...
bucket = sess.default_bucket()
prefix = 'pascalvoc-segmentation'

In [None]:
# Every location shares the same bucket/prefix root.
s3_root = f's3://{bucket}/{prefix}'

# Input channels: images and their annotation masks, for training and validation.
s3_train_data = f'{s3_root}/input/train'
s3_validation_data = f'{s3_root}/input/validation'
s3_train_annotation_data = f'{s3_root}/input/train_annotation'
s3_validation_annotation_data = f'{s3_root}/input/validation_annotation'

# Where the training job writes its output.
s3_output_location = f'{s3_root}/output'

for channel_uri in (s3_train_data,
                    s3_validation_data,
                    s3_train_annotation_data,
                    s3_validation_annotation_data):
    print(channel_uri)

In [None]:
from sagemaker.image_uris import retrieve

# Resolve the built-in semantic-segmentation image URI for the session's region.
region = sess.boto_session.region_name
container = retrieve(framework='semantic-segmentation', region=region)
print(container)

In [None]:
from sagemaker import get_execution_role

# IAM role the training job (and later the endpoint) runs under.
# Resolve it once here and reuse it below instead of looking it up again.
role = get_execution_role()

# Estimator wrapping the built-in algorithm image; training output goes to
# s3_output_location.
seg = sagemaker.estimator.Estimator(
    container,
    role,
    instance_count=1,
    instance_type='ml.p3.2xlarge',
    output_path=s3_output_location,
)

In [None]:
# Hyperparameters for the built-in semantic-segmentation algorithm.
hyperparameters = {
    'backbone': 'resnet-50',
    'algorithm': 'fcn',
    'use_pretrained_model': True,
    # Keep in sync with num_classes used when displaying the predicted mask below.
    # NOTE(review): presumably 20 PASCAL VOC classes + background — confirm.
    'num_classes': 21,
    'epochs': 30,
    # NOTE(review): presumably the number of images in the train channel — confirm.
    'num_training_samples': 1464,
}
seg.set_hyperparameters(**hyperparameters)

In [None]:
from sagemaker import TrainingInput

# Input images are JPEG; annotation masks are PNG.
IMAGE_CONTENT_TYPE = 'image/jpeg'
MASK_CONTENT_TYPE = 'image/png'

train_data = TrainingInput(s3_train_data,
                           distribution='FullyReplicated',
                           content_type=IMAGE_CONTENT_TYPE)
validation_data = TrainingInput(s3_validation_data,
                                content_type=IMAGE_CONTENT_TYPE)
train_annotation = TrainingInput(s3_train_annotation_data,
                                 content_type=MASK_CONTENT_TYPE)
validation_annotation = TrainingInput(s3_validation_annotation_data,
                                      content_type=MASK_CONTENT_TYPE)

# Channel name -> input, passed to seg.fit().
data_channels = dict(
    train=train_data,
    validation=validation_data,
    train_annotation=train_annotation,
    validation_annotation=validation_annotation,
)

In [None]:
# Launch the training job on the four input channels defined above.
seg.fit(data_channels)

In [None]:
# Host the trained model behind a real-time endpoint on a single instance.
seg_predictor = seg.deploy(instance_type='ml.c5.2xlarge',
                           initial_instance_count=1)

In [None]:
!wget -O test.jpg https://upload.wikimedia.org/wikipedia/commons/e/ea/SilverMorgan.jpg
filename = 'test.jpg'

Let's load the downloaded image, re-save it as JPEG, and display it. In the following cell we read the file back as a bytearray and send it to our endpoint.

In [None]:
import matplotlib.pyplot as plt
import PIL

im = PIL.Image.open(filename)
im.save(filename, "JPEG")

%matplotlib inline
plt.imshow(im)
plt.axis('off')

In [None]:
import json

import boto3
import numpy as np

# Build the runtime client from the same boto3 session as `sess`, so it uses the
# same region and credentials as the session that deployed the endpoint.
runtime = sess.boto_session.client('sagemaker-runtime')

with open(filename, 'rb') as f:
    payload = bytearray(f.read())

# Request the segmentation mask back as a PNG image.
response = runtime.invoke_endpoint(
    EndpointName=seg_predictor.endpoint_name,
    ContentType='image/jpeg',
    Accept='image/png',
    Body=payload,
)

result = response['Body'].read()

Let's decode the PNG returned by the endpoint and display the segmentation mask.

In [None]:
import io

import numpy as np
import PIL
from PIL import Image

# Must match the num_classes hyperparameter the model was trained with.
num_classes = 21

# Decode the PNG bytes returned by the endpoint into an array
# (presumably one class index per pixel — confirm against the algorithm docs).
png_stream = io.BytesIO(result)
mask = np.array(Image.open(png_stream))

fig, ax = plt.subplots()
ax.imshow(mask, vmin=0, vmax=num_classes - 1, cmap='gray_r')
plt.show()

In [None]:
# Ask for the raw model output as protobuf records instead of a PNG mask.
response = runtime.invoke_endpoint(
    Body=payload,
    ContentType='image/jpeg',
    Accept='application/x-protobuf',
    EndpointName=seg_predictor.endpoint_name,
)
result = response['Body'].read()

# Save the response to disk so it can be read back with MXNet's RecordIO reader.
results_file = 'results.rec'
with open(results_file, 'wb') as out_file:
    out_file.write(result)

In [None]:
%%sh
pip install mxnet

In [None]:
from sagemaker.amazon.record_pb2 import Record
import mxnet as mx

# Read the first RecordIO record from the saved response and parse it into a
# Record message. Close the reader even if parsing fails.
rec = Record()
recordio = mx.recordio.MXRecordIO(results_file, 'r')
try:
    # ParseFromString fills `rec` in place; its return value is not the message.
    rec.ParseFromString(recordio.read())
finally:
    recordio.close()

# Flat tensor values and the shape needed to reshape them.
values = list(rec.features["target"].float32_tensor.values)
shape = list(rec.features["shape"].int32_tensor.values)

In [None]:
# Check the declared shape against the number of values, then rebuild the tensor.
print(shape)
print(len(values))
mask = np.array(values).reshape(shape)

In [None]:
# Scores along axis 1 for the pixel at (0, 0).
# NOTE(review): assumes mask is (1, num_classes, H, W) — confirm against `shape`.
pixel_probs = mask[0, :, 0, 0]
print(pixel_probs)
print(pixel_probs.argmax())

In [None]:
# Dimensions of the reconstructed output tensor.
print(np.shape(mask))

In [None]:
# Tear down hosted resources so they stop incurring charges.
# The model created by deploy() is not removed by delete_endpoint(). Delete it
# first, while the endpoint still exists, so the predictor can resolve it.
seg_predictor.delete_model()
seg_predictor.delete_endpoint()