In [None]:
"""
機械学習モデルの訓練および、tar.gzからのモデル抽出を行なう。モデルの構造やデータ変換のロジックはfashion_mnist.pyにある。
"""

import os
import pickle
import sagemaker
import keras

from keras.datasets import fashion_mnist
from sagemaker.tensorflow import TensorFlow

# Cache the Fashion-MNIST data returned by keras.datasets.fashion_mnist.load_data()
# (per Keras docs, a ((x_train, y_train), (x_test, y_test)) tuple) as a pickle in
# ./data, so that the next cell can upload the whole directory to S3.
os.makedirs('data', exist_ok=True)
# Use `with` so the file is flushed and closed before the upload cell runs.
# Leaving the handle open (it stayed referenced by a global) risked uploading a
# pickle whose tail was still in the write buffer.
with open('data/fashion_mnist.pickle', 'wb') as fout:
    pickle.dump(fashion_mnist.load_data(), fout)

In [None]:
# Push the local ./data directory into the session's default S3 bucket under
# the "data" key prefix. The returned S3 location is used below as the
# training input of the estimator.
session = sagemaker.Session()
bucket_name = session.default_bucket()
inputs = session.upload_data(
    path='data',
    key_prefix='data',
    bucket=bucket_name,
)

In [None]:
# Run fashion_mnist.py as a script-mode TensorFlow 1.12 (Python 3) training job
# on a single ml.m5.xlarge instance, feeding it the S3 data uploaded above.
# NOTE(review): `train_instance_count`, `train_instance_type` and `script_mode`
# are SageMaker Python SDK v1 argument names; confirm the SDK version pinned for
# this notebook before upgrading (v2 renamed/removed them).
role = sagemaker.get_execution_role()

tf_estimator = TensorFlow(
    entry_point='fashion_mnist.py',
    framework_version='1.12.0',
    py_version='py3',
    script_mode=True,
    role=role,
    train_instance_type='ml.m5.xlarge',
    train_instance_count=1,
)

tf_estimator.fit(inputs)

In [None]:
import tarfile
import boto3

# Download the latest training job's output archive from S3, unpack it into the
# working directory and load the Keras model saved by fashion_mnist.py.
job_name = tf_estimator.latest_training_job.name

# assumes the estimator used the default output_path (the session default
# bucket, with job outputs under "<job_name>/") — TODO confirm if output_path changes
s3 = boto3.resource('s3')
bucket = s3.Bucket(bucket_name)

# S3 keys always use '/' as the separator; os.path.join would emit '\' on
# Windows and reference a key that does not exist.
output_key = '/'.join([job_name, 'output', 'output.tar.gz'])
bucket.download_file(output_key, 'output.tar.gz')

# Close the archive deterministically instead of leaking the file handle.
# NOTE(review): extractall() trusts member paths as-is. The archive comes from
# our own training job, but on Python 3.12+ consider extractall(filter='data').
with tarfile.open('output.tar.gz', 'r:gz') as archive:
    archive.extractall()

# assumes the training script saved the model as model.h5 at the archive root — TODO confirm
model = keras.models.load_model('model.h5')

In [None]:
# Show the layer-by-layer architecture and parameter counts; the bare `model`
# expression only renders the generic object repr (class name and address).
model.summary()