In [None]:
############################################################################
## Copyright 2021 Hewlett Packard Enterprise Development LP
## Licensed under the Apache License, Version 2.0 (the "License"); you may
## not use this file except in compliance with the License. You may obtain
## a copy of the License at
##
##    http://www.apache.org/licenses/LICENSE-2.0
##
## Unless required by applicable law or agreed to in writing, software
## distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
## WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
## License for the specific language governing permissions and limitations
## under the License.
############################################################################


In [None]:
import os
import json

# Interactive mode is signalled by a non-empty FG_URL environment variable.
# Use .get() so an unset FG_URL is treated like an empty one (swarm mode)
# instead of aborting the notebook with a KeyError.
is_in_interactive_mode = os.environ.get('FG_URL', "") != ""

if is_in_interactive_mode:
    print("running in interactive mode")
    # Interactive runs keep data and results under FG_ANALYSIS_DIR.
    dataDir = os.path.join(os.environ['FG_ANALYSIS_DIR'], "data")
    resultDir = os.path.join(os.environ['FG_ANALYSIS_DIR'], "result")
else:
    print("running in swarm mode")
    # Swarm runs use fixed absolute mount points.
    dataDir = "/data"
    resultDir = "/result"


# Run settings are read from analysis_info.json in the current working directory.
analysisInfoJson = "analysis_info.json"
assert os.path.isfile(analysisInfoJson), "the 'analysis_info.json' is missing"

with open(analysisInfoJson, 'r') as f:
    analysisInfo = json.load(f)

# Every setting below has a default, so a missing (or null) "domain" section
# falls back to those defaults rather than failing with AttributeError on None.
domainProps = analysisInfo.get("domain") or {}
minimum_peers = domainProps.get("minimum-peers", 2)
max_epochs = domainProps.get("convergence-max-epochs", 2)
sync_interval = domainProps.get("convergence-sync-interval", 100)
batch_size = domainProps.get("convergence-batch-size", 32)

print("")
print(f"data directory   = {dataDir}")
print(f"result directory = {resultDir}")
print(f"minimum_peers    = {minimum_peers}")
print(f"max_epochs       = {max_epochs}")
print(f"sync_interval    = {sync_interval}")
print(f"batch_size       = {batch_size}")

In [None]:
import tensorflow as tf
import numpy as np
import time
import datetime
import os


In [None]:
if not is_in_interactive_mode:
    from swarmlearning.tf import SwarmCallback


In [None]:
def load_data(dataDir):
    """Read the MNIST train/test split from ``<dataDir>/mnist.npz``.

    # Arguments
        dataDir: directory that contains the ``mnist.npz`` archive.
    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    archive_path = os.path.join(dataDir, 'mnist.npz')

    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code; only load archives from a trusted source.
    with np.load(archive_path, allow_pickle=True) as archive:
        # Materialise every array before the archive is closed.
        train_split = (archive['x_train'], archive['y_train'])
        test_split = (archive['x_test'], archive['y_test'])

    return train_split, test_split


def _build_model():
    """Create and compile the dense MNIST classifier trained by `main`.

    Architecture: flatten a (28, 28) input -> 512-unit ReLU layer -> 20%
    dropout -> 10-way softmax. Compiled with the Adam optimizer, sparse
    categorical cross-entropy loss, and an accuracy metric.
    """
    layers = tf.keras.layers
    classifier = tf.keras.models.Sequential([
        layers.Flatten(input_shape=(28, 28)),
        layers.Dense(512, activation=tf.nn.relu),
        layers.Dropout(0.2),
        layers.Dense(10, activation=tf.nn.softmax),
    ])
    classifier.compile(optimizer='adam',
                       loss='sparse_categorical_crossentropy',
                       metrics=['accuracy'])
    return classifier


def _training_callbacks(val_data, model_name):
    """Return the list of Keras callbacks to pass to `model.fit`.

    Interactive mode uses no callbacks. Swarm mode uses a single SwarmCallback
    built from the module-level swarm settings (`sync_interval`,
    `minimum_peers`, `batch_size`), with `val_data` and `model_name` passed through.
    """
    if is_in_interactive_mode:
        return []
    return [SwarmCallback(sync_interval=sync_interval,
                          min_peers=minimum_peers,
                          val_data=val_data,
                          val_batch_size=batch_size,
                          model_name=model_name)]


def main():
    """Train the MNIST classifier and save it to ``<resultDir>/mnist_tf``.

    Loads data from `dataDir`, rescales inputs by 1/255, fits for `max_epochs`
    epochs with `batch_size`, and saves the model. In swarm mode a
    SwarmCallback is attached during training.
    """
    model_name = 'mnist_tf'

    (x_train, y_train), (x_test, y_test) = load_data(dataDir)
    # Rescale inputs; presumably raw 0-255 pixel intensities -> [0, 1].
    x_train = x_train / 255.0
    x_test = x_test / 255.0

    model = _build_model()
    callbacks = _training_callbacks((x_test, y_test), model_name)

    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=max_epochs,
              verbose=1,
              callbacks=callbacks)

    # Save model and weights
    model.save(os.path.join(resultDir, model_name))
    print('Saved the trained model!')



In [None]:
# Entry point: load the data, train the model and save it under resultDir.
main()