In [None]:
!pip install tensorflow==2.4.1 -U

In [None]:
import tensorflow as tf
from sagemaker.tensorflow import TensorFlow
from tensorflow.keras.datasets import mnist
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.models import load_model
import sagemaker

In [None]:
# Confirm which TensorFlow build the kernel actually imported.
print(tf.version.VERSION)

In [None]:
TRAIN_X_PATH = './train_X.npy'
TEST_X_PATH = './test_X.npy'
TRAIN_Y_PATH = './train_y.npy'
TEST_Y_PATH = './test_y.npy'
NUM_CLASSES = 10  # MNIST digits 0-9


def _preprocess(images, labels):
    """Turn raw MNIST arrays into model-ready inputs and targets.

    images: pixel array of shape (N, H, W), presumably uint8 in [0, 255]
        as returned by mnist.load_data().
    labels: integer class ids of shape (N,).

    Returns (X, y):
        X -- pixels mapped linearly so that 0 -> -1 and 255 -> 1, with a
             trailing channel axis (channels-last, shape (N, H, W, 1)).
        y -- one-hot targets of shape (N, NUM_CLASSES).
    """
    X = (images - 127.5) / 127.5
    X = X[..., np.newaxis]  # channels-last: add a single channel axis
    y = np.eye(NUM_CLASSES)[labels]  # row i of the identity = one-hot for class i
    return X, y


# Keep the raw arrays under their own names, so re-running this cell is idempotent.
(train_X_raw, train_y_raw), (test_X_raw, test_y_raw) = mnist.load_data()
train_X, train_y = _preprocess(train_X_raw, train_y_raw)
test_X, test_y = _preprocess(test_X_raw, test_y_raw)

for npy_path, array in ((TRAIN_X_PATH, train_X), (TEST_X_PATH, test_X),
                        (TRAIN_Y_PATH, train_y), (TEST_Y_PATH, test_y)):
    np.save(npy_path, array)

In [None]:
# Sanity check: image tensors are (N, 28, 28, 1) and labels are (N, 10) for MNIST.
(train_X.shape, train_y.shape, test_X.shape, test_y.shape)

In [None]:
role = sagemaker.get_execution_role()
sess = sagemaker.session.Session()
# All four .npy files go under the same key prefix of the session's default bucket.
bucket = sess.default_bucket()
local_npy_paths = (TRAIN_X_PATH, TRAIN_Y_PATH, TEST_X_PATH, TEST_Y_PATH)
train_X_uri, train_y_uri, test_X_uri, test_y_uri = [
    sess.upload_data(path=local_path, bucket=bucket, key_prefix='sagemaker/mnist')
    for local_path in local_npy_paths
]
for uploaded_uri in (train_X_uri, train_y_uri, test_X_uri, test_y_uri):
    print(uploaded_uri)

## Train the DCGAN generator

In [None]:
# DCGAN training job. The `sagemaker_s3_output` hyperparameter is handed to
# src/dcgan_train.py (presumably where it writes intermediate results; confirm
# in the script). Build it from the session's default bucket instead of a
# bucket hard-coded to one AWS account/region, so this runs in any account.
dcgan_intermediate_uri = f's3://{sess.default_bucket()}/mnist_dcgan_intermediate'

estimator = TensorFlow(
    entry_point='./src/dcgan_train.py',
    role=role,
    instance_count=1,
    instance_type='ml.p3.2xlarge',
    framework_version='2.4.1',
    py_version='py37',
    hyperparameters={
        'sagemaker_s3_output': dcgan_intermediate_uri,
        'epochs': 5,
    },
)

In [None]:
%%time
estimator.fit({
    'train': train_X_uri,
})

In [None]:
# S3 location of the model.tar.gz produced by the DCGAN training job.
dcgan_job_description = estimator.latest_training_job.describe()
generator_model_uri = dcgan_job_description['ModelArtifacts']['S3ModelArtifacts']
print(generator_model_uri)

## Train the digit classifiers

In [None]:
# Classifier training job: same container and instance settings as the DCGAN
# job, 10 epochs. Note this rebinds `estimator`; the DCGAN artifact URI was
# already captured in `generator_model_uri`.
estimator = TensorFlow(
    entry_point='./src/classifier_train.py',
    hyperparameters=dict(epochs=10),
    role=role,
    framework_version='2.4.1',
    py_version='py37',
    instance_type='ml.p3.2xlarge',
    instance_count=1,
)

In [None]:
%%time

print(train_X_uri[:-11])

estimator.fit({
    'train': train_X_uri[:-11],
})

In [None]:
# S3 location of the model.tar.gz produced by the classifier training job.
classifier_job_description = estimator.latest_training_job.describe()
classifier_model_uri = classifier_job_description['ModelArtifacts']['S3ModelArtifacts']
print(classifier_model_uri)

## Check the models

In [None]:
!aws s3 cp {generator_model_uri} .
!mkdir -p ./src/ggv2/components/artifacts/com.example.Publisher/1.0.0
!tar zxvf model.tar.gz -C ./src/ggv2/components/artifacts/com.example.Publisher/1.0.0
!rm model.tar.gz 
!aws s3 cp {classifier_model_uri} .
!mkdir -p ./src/ggv2/components/artifacts/com.example.Subscriber/1.0.0
!mkdir -p ./src/ggv2/components/artifacts/com.example.Subscriber/1.0.1
!tar zxvf model.tar.gz -C ./src/ggv2/components/artifacts/com.example.Subscriber/1.0.0
!mv ./src/ggv2/components/artifacts/com.example.Subscriber/1.0.0/2.h5 ./src/ggv2/components/artifacts/com.example.Subscriber/1.0.1/2.h5
!rm model.tar.gz

In [None]:
# Load the generator and both classifier versions from the extracted artifacts.
artifacts_root = './src/ggv2/components/artifacts'
generator = load_model(f'{artifacts_root}/com.example.Publisher/1.0.0/1.h5')
classifier1 = load_model(f'{artifacts_root}/com.example.Subscriber/1.0.0/1.h5')
classifier2 = load_model(f'{artifacts_root}/com.example.Subscriber/1.0.1/2.h5')

In [None]:
# Print the layer-by-layer architecture of the downloaded generator.
generator.summary()

In [None]:
# Architectures of both classifier versions loaded above:
# Subscriber 1.0.0 (1.h5) and Subscriber 1.0.1 (2.h5).
# (No single `classifier` variable exists, so refer to each model explicitly.)
classifier1.summary()
classifier2.summary()

In [None]:
# Feed one random latent tensor through the generator and show the image it produces.
# NOTE(review): the (1, 7, 7, 1) latent shape must match the generator's input
# layer; confirm it against generator.summary().
SAMPLE_SEED = 0  # fixed so the displayed sample is reproducible; change for another digit
rng = np.random.default_rng(SAMPLE_SEED)
pred_X = rng.uniform(-1, 1, (1, 7, 7, 1))
pred_y = generator.predict(pred_X)

fig, ax = plt.subplots()
ax.imshow(pred_y[0, :, :, 0], cmap='gray')
ax.set_title('Generated sample')
ax.axis('off')
plt.show()

In [None]:
# Digit predicted by each classifier version for the generated image.
# (No single `classifier` variable exists; query both loaded models.)
# NOTE(review): assumes each classifier outputs (batch, 10) class scores,
# matching the one-hot labels prepared for training. Confirm with the summaries.
{
    'Subscriber 1.0.0 (1.h5)': int(np.argmax(classifier1.predict(pred_y)[0])),
    'Subscriber 1.0.1 (2.h5)': int(np.argmax(classifier2.predict(pred_y)[0])),
}