In [None]:
# Load the SageMaker Python SDK and report which version is installed.
import sagemaker

print(sagemaker.__version__)

# Open a SageMaker session and ask it for its default S3 bucket.
sess = sagemaker.Session()
bucket = sess.default_bucket()  # bucket used for every S3 location in this notebook
# Key prefix placed under the bucket for this example's inputs and outputs.
prefix = 'pascalvoc-segmentation'

In [None]:
# Every S3 location hangs off the same bucket/prefix pair.
s3_root = f's3://{bucket}/{prefix}'

# Input channels (images and their annotations) and the training output location.
s3_train_data = f'{s3_root}/input/train'
s3_validation_data = f'{s3_root}/input/validation'
s3_train_annotation_data = f'{s3_root}/input/train_annotation'
s3_validation_annotation_data = f'{s3_root}/input/validation_annotation'
s3_output_location = f'{s3_root}/output'

# Echo the four input locations, one per line.
for s3_uri in (s3_train_data,
               s3_validation_data,
               s3_train_annotation_data,
               s3_validation_annotation_data):
    print(s3_uri)

In [None]:
import boto3
from sagemaker import image_uris

# Region configured for the default boto3 session; the image URI is looked up for that region.
region = boto3.Session().region_name
container = image_uris.retrieve(framework='semantic-segmentation', region=region)
print(container)

In [None]:
from sagemaker import get_execution_role

# IAM role the training job and endpoint will run under.
role = get_execution_role()

# Settings for a generic Estimator wrapping the semantic segmentation image.
estimator_settings = dict(
    image_uri=container,
    role=role,
    sagemaker_session=sess,
    instance_count=1,
    instance_type='ml.p3.8xlarge',
    output_path=s3_output_location,
)
seg = sagemaker.estimator.Estimator(**estimator_settings)

In [None]:
# Hyperparameters passed to the training job, collected in one place.
hyperparameters = {
    'backbone': 'resnet-50',
    'algorithm': 'fcn',
    'use_pretrained_model': True,
    'num_classes': 21,
    'epochs': 30,
    'num_training_samples': 1464,
}
seg.set_hyperparameters(**hyperparameters)

Now that the hyperparameters are set up, let us prepare the handshake between our data channels and the algorithm. To do this, we need to create `sagemaker.TrainingInput` objects from our data channels. These objects are then put in a simple dictionary, which the algorithm uses to train.

In [None]:
def _s3_channel(s3_uri, content_type):
    """Return a fully replicated, S3-prefix TrainingInput for one data channel."""
    return sagemaker.TrainingInput(s3_uri,
                                   distribution='FullyReplicated',
                                   content_type=content_type,
                                   s3_data_type='S3Prefix')


# Image channels are declared as JPEG, annotation channels as PNG.
train_data = _s3_channel(s3_train_data, 'image/jpeg')
validation_data = _s3_channel(s3_validation_data, 'image/jpeg')
train_annotation = _s3_channel(s3_train_annotation_data, 'image/png')
validation_annotation = _s3_channel(s3_validation_annotation_data, 'image/png')

# Channel name -> input definition, as handed to the estimator's fit() call.
data_channels = dict(train=train_data,
                     validation=validation_data,
                     train_annotation=train_annotation,
                     validation_annotation=validation_annotation)

In [None]:
# Launch training on the estimator, passing the data_channels dictionary as its inputs.
seg.fit(inputs=data_channels)

In [None]:
# Deploy the estimator to an endpoint with one ml.c5.2xlarge instance;
# the object returned by deploy() is kept as seg_predictor for the prediction cells.
seg_predictor = seg.deploy(initial_instance_count=1, instance_type='ml.c5.2xlarge')

In [None]:
!wget -O test.jpg https://upload.wikimedia.org/wikipedia/commons/e/ea/SilverMorgan.jpg
filename = 'test.jpg'

Let's take a look at the image first. We will then convert it to a bytearray before we supply it to our endpoint.

In [None]:
import matplotlib.pyplot as plt
import PIL

im = PIL.Image.open(filename)
im.save(filename, "JPEG")

%matplotlib inline
plt.imshow(im)
plt.axis('off')

In [None]:
from sagemaker.deserializers import BytesDeserializer
from sagemaker.serializers import IdentitySerializer

# In SageMaker Python SDK v2, Predictor.content_type and Predictor.accept are
# read-only properties derived from the serializer and deserializer, so
# assigning to them raises AttributeError. The request and response MIME types
# are configured on the serializer/deserializer instead.
seg_predictor.serializer = IdentitySerializer(content_type='image/jpeg')
seg_predictor.deserializer = BytesDeserializer(accept='image/png')

# Read the JPEG file as raw bytes; img is reused by the later protobuf request.
with open(filename, 'rb') as image:
    img = bytearray(image.read())

# Send the image to the endpoint; the response is the raw body returned for
# the 'image/png' accept type.
response = seg_predictor.predict(img)

Let's display the segmentation mask.

In [None]:
import io

import numpy as np
import PIL
from PIL import Image

num_classes = 21

# Decode the PNG bytes returned by the endpoint into a NumPy array.
png_stream = io.BytesIO(response)
mask = np.array(Image.open(png_stream))

# Render the mask in inverted grayscale, with the colour scale spanning
# the values 0 .. num_classes - 1.
fig, ax = plt.subplots()
ax.imshow(mask, vmin=0, vmax=num_classes - 1, cmap='gray_r')
plt.show()

In [None]:
from sagemaker.deserializers import BytesDeserializer
from sagemaker.serializers import IdentitySerializer

# In SageMaker Python SDK v2, Predictor.content_type and Predictor.accept are
# read-only properties, so the MIME types are set through the serializer and
# deserializer. This time the endpoint is asked for a protobuf response.
seg_predictor.serializer = IdentitySerializer(content_type='image/jpeg')
seg_predictor.deserializer = BytesDeserializer(accept='application/x-protobuf')
response = seg_predictor.predict(img)

# Save the raw response bytes so they can be read back as a RecordIO file.
results_file = 'results.rec'
with open(results_file, 'wb') as f:
    f.write(response)

In [None]:
from sagemaker.amazon.record_pb2 import Record
import mxnet as mx

# Read one record from the saved RecordIO file and parse it into `rec`.
# ParseFromString fills `rec` in place, so its return value is not kept;
# the reader is closed even if reading or parsing fails.
rec = Record()
recordio = mx.recordio.MXRecordIO(results_file, 'r')
try:
    rec.ParseFromString(recordio.read())
finally:
    recordio.close()

# Flat list of float values stored under "target", and the integer
# dimensions stored under "shape" (used in the next cell to reshape them).
values = list(rec.features["target"].float32_tensor.values)
shape = list(rec.features["shape"].int32_tensor.values)

In [None]:
# Show the reported dimensions and how many flat values came back.
print(shape)
print(len(values))

# Fold the flat value list into an array with the reported dimensions.
mask = np.array(values).reshape(shape)

In [None]:
# All entries along the second axis, at index 0 of the three other axes.
# NOTE(review): presumably the per-class probabilities of the top-left pixel — confirm the axis order.
pixel_probs = mask[0, :, 0, 0]
print(pixel_probs)

# Position of the largest entry in that vector.
print(pixel_probs.argmax())

In [None]:
# Print the dimensions of the reshaped mask array.
print(mask.shape)

In [None]:
# Delete the endpoint behind seg_predictor now that the predictions are done.
seg_predictor.delete_endpoint()