# Import a TensorFlow model into SageMaker

In [None]:
!pip -q install keras matplotlib

In [None]:
!pip -q install tensorflow==2.3.1

In [None]:
import os
import keras
import numpy as np
from keras.datasets import fashion_mnist

# Fetch the Fashion-MNIST training and validation splits through Keras.
(x_train, y_train), (x_val, y_val) = fashion_mnist.load_data()

# Store each split as an archive holding 'image' and 'label' arrays.
# np.savez adds the .npz suffix, giving training.npz / validation.npz.
data_dir = './fashion-mnist'
os.makedirs(data_dir, exist_ok=True)
np.savez(os.path.join(data_dir, 'training'), image=x_train, label=y_train)
np.savez(os.path.join(data_dir, 'validation'), image=x_val, label=y_val)

### Train a model locally

In [None]:
# Hyperparameters for the local training run.
# batch_size is per GPU; the training cell scales it up when gpu_count > 1.
gpu_count, batch_size, epochs = 1, 128, 10

In [None]:
import math

import tensorflow as tf
import numpy as np

from model import FMNISTModel

print("TensorFlow version", tf.__version__)

# Load the splits written by the data-preparation cell. np.load returns an
# NpzFile that keeps the archive open until it is closed. Read both arrays
# inside a `with` block so each archive is opened once and then closed.
with np.load('fashion-mnist/training.npz') as data:
    x_train, y_train = data['image'], data['label']
with np.load('fashion-mnist/validation.npz') as data:
    x_val, y_val = data['image'], data['label']

# Add extra dimension for channel: (28, 28) --> (28, 28, 1)
x_train = x_train[..., tf.newaxis]
x_val = x_val[..., tf.newaxis]


def preprocess(x, y):
    """Normalize a batch of pixel values to [0, 1] and one-hot encode labels.

    Applied after batching, so ``y`` holds one class id per image. The
    one-hot result is reshaped to (batch, 10).
    """
    pixels = tf.divide(tf.cast(x, tf.float32), 255.0)
    one_hot = tf.reshape(tf.one_hot(y, 10), (-1, 10))
    return pixels, one_hot


# Scale the batch size by the GPU count. A separate variable is used so that
# re-running this cell does not keep multiplying the `batch_size` setting.
# NOTE(review): no tf.distribute strategy is created here, so a larger batch
# by itself does not spread work across GPUs. Confirm this is intended.
global_batch_size = batch_size * max(gpu_count, 1)

# Training pipeline: shuffle, batch, preprocess, then repeat indefinitely.
# Without the shuffle, every epoch would see identical batches in the same
# order. fit() bounds each epoch via steps_per_epoch.
train = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
         .shuffle(x_train.shape[0])
         .batch(global_batch_size)
         .map(preprocess)
         .repeat())

# Validation pipeline: same preprocessing, no shuffling.
val = (tf.data.Dataset.from_tensor_slices((x_val, y_val))
       .batch(global_batch_size)
       .map(preprocess)
       .repeat())

# Build model
model = FMNISTModel()
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Step counts must be whole numbers. Round up so that one epoch is exactly
# one pass over the data, including the final partial batch produced by
# batching before repeat().
train_steps = math.ceil(x_train.shape[0] / global_batch_size)
val_steps = math.ceil(x_val.shape[0] / global_batch_size)

model.fit(train,
          epochs=epochs,
          steps_per_epoch=train_steps,
          validation_data=val,
          validation_steps=val_steps)

# Save for TensorFlow Serving under version directory `1`, which is the
# directory the packaging cell archives.
model.save('byo-tensorflow/1')
   

### Package the model for SageMaker

In [None]:
import sagemaker

# Prefix used both as the local model directory and as the S3 key prefix.
prefix = 'byo-tensorflow'

# SageMaker session and its default S3 bucket.
sess = sagemaker.Session()
bucket = sess.default_bucket()
print(bucket)

In [None]:
%%sh -s $prefix
# $1 is the notebook's `prefix`, i.e. the directory the model was saved under.
# Archive its version directory `1` as model-tf.tar.gz inside that directory.
cd $1
tar -czvf model-tf.tar.gz 1

In [None]:
# Upload the archive created by the packaging cell under the `prefix` key.
# The returned location is used as model_data when creating the model.
archive = '{}/model-tf.tar.gz'.format(prefix)
model_path = sess.upload_data(path=archive, key_prefix=prefix)
print(model_path)

### Deploy the model on SageMaker

In [None]:
from sagemaker.tensorflow.model import TensorFlowModel

# Wrap the uploaded model archive in a SageMaker TensorFlow model.
# framework_version matches the TensorFlow version used for local training.
role = sagemaker.get_execution_role()
tf_model = TensorFlowModel(model_data=model_path,
                           role=role,
                           framework_version='2.3.1')

In [None]:
from time import strftime,gmtime

# Endpoint name carries a UTC timestamp (to the second) to tell runs apart.
timestamp = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
tf_endpoint_name = 'tf-{}-{}'.format(prefix, timestamp)

# Host the model on one ml.t2.medium instance.
tf_predictor = tf_model.deploy(initial_instance_count=1,
                               instance_type='ml.t2.medium',
                               endpoint_name=tf_endpoint_name)
print(tf_endpoint_name)

### Predict with the model

In [None]:
%matplotlib inline
import random
import matplotlib.pyplot as plt

num_samples = 5
# Draw distinct random validation indices over the full range. range(n)
# already excludes n, so subtracting 1 would make the last image unreachable.
indices = random.sample(range(x_val.shape[0]), num_samples)
images = x_val[indices] / 255  # same [0, 1] scaling as the training preprocessing
labels = y_val[indices]

# Show each sample with its ground-truth class id as the title.
for i, (image, label) in enumerate(zip(images, labels), start=1):
    plt.subplot(1, num_samples, i)
    plt.imshow(image.reshape(28, 28), cmap='gray')
    plt.title(label)
    plt.axis('off')

# Request payload: a (num_samples, 28, 28, 1) batch, matching the channel
# axis added for training.
payload = images.reshape(num_samples, 28, 28, 1)

In [None]:
# Send the sampled images to the endpoint. The 'predictions' entry of the
# response is turned into an array, and argmax over axis 1 gives one
# predicted class id per image.
response = tf_predictor.predict(payload)
prediction = np.asarray(response['predictions'])
predicted_label = np.argmax(prediction, axis=1)
print('Predicted labels are: {}'.format(predicted_label))

### Create a predictor for an existing endpoint

In [None]:
from sagemaker.tensorflow.model import TensorFlowPredictor

# Attach a second predictor to the already-running endpoint by name,
# serializing requests as JSON.
json_serializer = sagemaker.serializers.JSONSerializer()
another_predictor = TensorFlowPredictor(endpoint_name=tf_endpoint_name,
                                        serializer=json_serializer)

In [None]:
# Send the same payload through the second predictor. It targets the same
# endpoint, so the labels are expected to match the previous prediction.
response = another_predictor.predict(payload)
prediction = np.asarray(response['predictions'])
predicted_label = np.argmax(prediction, axis=1)
print('Predicted labels are: {}'.format(predicted_label))

### Update the endpoint

In [None]:
# Move the endpoint to two instances of the same type used for the initial
# deployment.
another_predictor.update_endpoint(instance_type='ml.t2.medium',
                                  initial_instance_count=2)

In [None]:
# Clean up: delete the endpoint that both tf_predictor and another_predictor
# point at, so its hosting instances are released.
# NOTE(review): the SageMaker model created by deploy() is not removed here.
# Confirm whether a separate model cleanup step is wanted.
tf_predictor.delete_endpoint()