In [None]:
import boto3
import sagemaker

# Create a SageMaker session and look up its default S3 bucket; the cells
# below build the dataset and output S3 paths from `bucket`.
session = sagemaker.Session()
bucket = session.default_bucket()

### Define channels

In [None]:
# Channel definitions: FullyReplicated distribution, File input mode.
# Defines the S3 locations under the 'imagenet' prefix plus the `train_data`
# and `validation_data` inputs consumed later by `s3_channels`.

from sagemaker.session import s3_input

prefix = 'imagenet'
s3_root = 's3://{}/{}'.format(bucket, prefix)
s3_train_path = s3_root + '/input/train/'
s3_val_path   = s3_root + '/input/validation/'
s3_output     = s3_root + '/output/'

# Both channels share exactly the same settings; only the S3 path differs.
file_mode_settings = {
    'distribution': 'FullyReplicated',
    'content_type': 'application/x-recordio',
    's3_data_type': 'S3Prefix',
    'input_mode': 'File',
}

train_data = s3_input(s3_train_path, **file_mode_settings)
validation_data = s3_input(s3_val_path, **file_mode_settings)

In [None]:
# Channel definitions: FullyReplicated distribution, Pipe input mode.
# This cell rebinds the same names as the File Mode cell (prefix, s3_*_path,
# s3_output, train_data, validation_data), so whichever of the two cells runs
# last determines what the training job uses.
# NOTE(review): the training path here ends in 'input/training/' while the
# File Mode cell uses 'input/train/' -- confirm both match the bucket layout.

from sagemaker.session import s3_input, ShuffleConfig

prefix = 'imagenet-split'
s3_root = 's3://{}/{}'.format(bucket, prefix)
s3_train_path = s3_root + '/input/training/'
s3_val_path   = s3_root + '/input/validation/'
s3_output     = s3_root + '/output/'

# Settings common to both channels.
pipe_mode_settings = {
    'distribution': 'FullyReplicated',
    'content_type': 'application/x-recordio',
    's3_data_type': 'S3Prefix',
    'input_mode': 'Pipe',
}

# Only the training channel gets a shuffle configuration (seed 59).
train_data = s3_input(s3_train_path,
                      shuffle_config=ShuffleConfig(59),
                      **pipe_mode_settings)

validation_data = s3_input(s3_val_path, **pipe_mode_settings)

In [None]:
# Echo the S3 locations selected by the channel cell that ran last.
for s3_location in (s3_train_path, s3_val_path, s3_output):
    print(s3_location)

In [None]:
# Channel name -> input definition; passed to ic.fit() in the training cell.
s3_channels = dict(train=train_data, validation=validation_data)

### Get the name of the image classification algorithm in our region

In [None]:
from sagemaker.amazon.amazon_estimator import get_image_uri

# Look up the container image for the image classification algorithm in the
# region of the current boto3 session, then display it.
algorithm_name = "image-classification"
algorithm_version = "latest"
region_name = boto3.Session().region_name

container = get_image_uri(region_name, algorithm_name, algorithm_version)
print(container)

### Configure the training job

In [None]:
# IAM role under which the training job runs.
role = sagemaker.get_execution_role()

# Training cluster: 8 x ml.p3dn.24xlarge, with output written to s3_output.
training_cluster = {
    'train_instance_count': 8,
    'train_instance_type': 'ml.p3dn.24xlarge',
    'output_path': s3_output,
}

ic = sagemaker.estimator.Estimator(container, role, **training_cluster)

### Set algorithm parameters

In [None]:
# Hyperparameters for the image classification estimator, collected in one
# mapping so they can be inspected before being applied.
hyperparameters = dict(
    # Model and dataset
    num_layers=50,                  # Train a Resnet-50 model
    use_pretrained_model=0,         # Train from scratch
    num_classes=1000,               # 1000 ImageNet classes
    num_training_samples=1281167,   # Number of training samples
    # Batch size, learning-rate schedule, distribution and augmentation
    mini_batch_size=2816,
    learning_rate=0.5,
    lr_scheduler_factor=0.5,
    lr_scheduler_step='30,60,90,120,150,180',
    kv_store='dist_sync',
    augmentation_type='crop',
    # Stopping criteria and reporting
    early_stopping=True,
    early_stopping_patience=30,
    top_k=3,
    epochs=200,
)

ic.set_hyperparameters(**hyperparameters)

### Train the model

In [None]:
# Start the training job using the 'train' and 'validation' channels
# collected in s3_channels.
ic.fit(inputs=s3_channels)

### Deploy the model

In [None]:
# Deploy the trained estimator to an endpoint backed by a single
# ml.t2.medium instance; the returned predictor is used for inference below.
endpoint_settings = {
    'initial_instance_count': 1,
    'instance_type': 'ml.t2.medium',
}

ic_predictor = ic.deploy(**endpoint_settings)

### Download a test image

In [None]:
# Dog
!wget -O /tmp/test.jpg http://www.vision.caltech.edu/Image_Datasets/Caltech256/images/056.dog/056_0010.jpg
file_name = '/tmp/test.jpg'
from IPython.display import Image
Image(file_name)

### Predict test image

In [None]:
# Read the downloaded image as raw bytes and wrap them in a bytearray.
with open(file_name, 'rb') as image_file:
    payload = bytearray(image_file.read())

# Tell the predictor that the request body is an image.
ic_predictor.content_type = 'application/x-image'

# Send the image to the endpoint and show the response exactly as returned.
result = ic_predictor.predict(payload)
print(result)

### Delete endpoint

In [None]:
# Remove the endpoint created by ic.deploy() now that prediction is finished.
ic_predictor.delete_endpoint()