Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

# Distributed TensorFlow 2 with MultiWorkerMirroredStrategy
In this tutorial, you will train a TensorFlow 2 model on the [CIFAR10](http://www.cs.toronto.edu/~kriz/cifar.html) dataset using distributed training with TensorFlow 2's `MultiWorkerMirroredStrategy` across an Azure Stack Hub CPU Kubernetes cluster.

## Prerequisites

*   A Kubernetes cluster deployed on Azure Stack Hub and connected to Azure through Azure Arc.

    For details on how to deploy a Kubernetes cluster on Azure Stack Hub and enable the Arc connection to Azure, follow [this guide](https://github.com/Azure/AML-Kubernetes/blob/master/docs/ASH/AML-ARC-Compute.md).

*   A datastore set up in the Azure Machine Learning workspace, backed by an Azure Stack Hub storage account.

    [This document](https://github.com/Azure/AML-Kubernetes/blob/master/docs/ASH/Train-AzureArc.md) is a detailed guide on how to create an Azure Machine Learning workspace, create an Azure Stack Hub storage account, and set up a datastore in the AML workspace backed by the ASH storage account.

*   The ability to run a notebook.

    If you are using an Azure Machine Learning Notebook VM, you are all set. Otherwise, first go through the configuration notebook located [here](https://github.com/Azure/MachineLearningNotebooks) if you haven't already. This sets you up with a working config file that has information on your workspace, subscription ID, etc.

In [3]:
from azureml.core import Dataset, Environment, Experiment, Workspace
import os
import requests

## Initialize workspace

Initialize a [Workspace](https://docs.microsoft.com/azure/machine-learning/service/concept-azure-machine-learning-architecture#workspace) object from the existing workspace you created in the Prerequisites step. `Workspace.from_config()` creates a workspace object from the details stored in `config.json`. 

If you haven't already done so, open the `config.json` file and fill in your workspace information.

In [4]:
ws = Workspace.from_config()
print('Workspace name: ' + ws.name, 
      'Azure region: ' + ws.location, 
      'Subscription id: ' + ws.subscription_id, 
      'Resource group: ' + ws.resource_group, sep='\n')

If you run your code in unattended mode, i.e., where you can't give a user input, then we recommend to use ServicePrincipalAuthentication or MsiAuthentication.
Please refer to aka.ms/aml-notebook-auth for different authentication mechanisms in azureml-sdk.


Workspace name: sl-ash2-mal
Azure region: eastus
Subscription id: 6b736da6-3246-44dd-a0b8-b5e95484633d
Resource group: sl-ash2
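`Workspace.from_config()` needs a `config.json` containing the subscription ID, resource group, and workspace name. The following is a minimal sketch for checking that file before calling it. The helper name `missing_config_keys` and the placeholder values are illustrative assumptions, not part of the Azure ML SDK.

```python
import json
import tempfile
from pathlib import Path

# Keys that Workspace.from_config() reads from config.json.
REQUIRED_KEYS = ("subscription_id", "resource_group", "workspace_name")


def missing_config_keys(config_path):
    """Return the required keys that are absent or empty in config.json."""
    with open(config_path, "r", encoding="utf-8") as fp:
        config = json.load(fp)
    return [key for key in REQUIRED_KEYS if not config.get(key)]


# Placeholder values for illustration only; substitute your own workspace details.
example_config = {
    "subscription_id": "<your-subscription-id>",
    "resource_group": "<your-resource-group>",
    "workspace_name": "<your-workspace-name>",
}

with tempfile.TemporaryDirectory() as tmp_dir:
    example_path = Path(tmp_dir) / "config.json"
    example_path.write_text(json.dumps(example_config, indent=4), encoding="utf-8")
    print(missing_config_keys(example_path))  # prints []
```

An empty list means all three fields are present; any key that is returned still needs to be filled in.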


## Prepare dataset

Download the CIFAR-10 dataset from [cifar10-data](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz). Create a folder named "cifar10-data" under the working directory of this notebook, then copy "cifar-10-python.tar.gz" into it. The following cell uploads "cifar-10-python.tar.gz" to the workspace datastore and registers it as a dataset in the workspace.

Upload and dataset registration take about 3 minutes.

To set up a datastore using an Azure Stack Hub storage account, refer to [Train_azure_arc](https://github.com/Azure/AML-Kubernetes/blob/master/docs/ASH/Train-AzureArc.md). To register the dataset manually, refer to this [video](https://msit.microsoftstream.com/video/51f7a3ff-0400-b9eb-2703-f1eb38bc6232).


In [5]:
import os
from azureml.core import Datastore, Dataset

dataset_name = 'cifar10'
datastore_name = "ashstore"

if dataset_name not in ws.datasets:
    datastore = Datastore.get(ws, datastore_name)

    # assuming cifar-10-python.tar.gz is in folder cifar10-data
    src_dir, target_path = 'cifar10-data', 'cifar10-data-ash'
    
    # upload data from local to AML datastore:
    datastore.upload(src_dir, target_path)

    # register data uploaded as AML dataset:
    datastore_paths = [(datastore, target_path)]
    cifar_ds = Dataset.File.from_files(path=datastore_paths)
    cifar_ds.register(ws, dataset_name, "CIFAR-10 images from https://www.cs.toronto.edu/~kriz/cifar.html")
        
dataset_ash = ws.datasets[dataset_name]
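The upload cell assumes the archive is already in the "cifar10-data" folder. The manual copy step can be sketched as below, using only the standard library. The helper name `stage_archive` is hypothetical, and the sketch assumes the archive has already been downloaded to some local path.

```python
import shutil
import tarfile
from pathlib import Path


def stage_archive(archive_path, data_dir="cifar10-data"):
    """Copy a downloaded archive into data_dir, creating the folder if needed."""
    archive_path = Path(archive_path)
    # Guard against partial or corrupted downloads before uploading anything.
    if not tarfile.is_tarfile(archive_path):
        raise ValueError(f"{archive_path} is not a readable tar archive")
    target_dir = Path(data_dir)
    target_dir.mkdir(parents=True, exist_ok=True)
    staged_path = target_dir / archive_path.name
    shutil.copy2(archive_path, staged_path)
    return staged_path
```

For example, `stage_archive("Downloads/cifar-10-python.tar.gz")` (the download location here is an assumption) would leave the file at `cifar10-data/cifar-10-python.tar.gz`, ready for `datastore.upload`.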

## Create or attach existing ArcKubernetesCompute

The attach code here depends on the Python package azureml-contrib-k8s, which is currently in private preview. Install the private preview branch of the AzureML SDK by running the following command:

<pre>
pip install --disable-pip-version-check --extra-index-url https://azuremlsdktestpypi.azureedge.net/azureml-contrib-k8s-preview/D58E86006C65 azureml-contrib-k8s
</pre>


Attaching the ASH cluster for the first time may take about 7 minutes. Subsequent attachments are much faster.

In [6]:
from azureml.contrib.core.compute.arckubernetescompute import ArcKubernetesCompute
from azureml.core import ComputeTarget
from azureml.core.compute_target import ComputeTargetException

resource_id = "/subscriptions/6b736da6-3246-44dd-a0b8-b5e95484633d/resourceGroups/sl-ash2/providers/Microsoft.Kubernetes/connectedClusters/sl-d2-o-arc"

attach_config = ArcKubernetesCompute.attach_configuration(
    resource_id= resource_id,
)

try:
    attach_name = "sl-d2-o-arc"
    arcK_target_result = ArcKubernetesCompute.attach(ws, attach_name, attach_config)
    arcK_target_result.wait_for_completion(show_output=True)
    print('arc attach  success')
except ComputeTargetException as e:
    print(e)
    print('arc attach  failed')

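# Retrieve the compute target by the name it is attached under in the workspace.
# Note that "d13" differs from the name used in the attach call above.
attach_name = "d13"
arcK_target = ComputeTarget(ws, attach_name)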

SucceededProvisioning operation finished, operation "Succeeded"
arc attach  success
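The `resource_id` passed to `attach_configuration` follows the usual Azure resource ID layout for an Arc-connected cluster. The sketch below assembles it from its parts so the subscription, resource group, and cluster name can be changed independently. The helper name `connected_cluster_resource_id` is hypothetical.

```python
def connected_cluster_resource_id(subscription_id, resource_group, cluster_name):
    """Build the Azure resource ID of an Arc-connected Kubernetes cluster."""
    return (
        f"/subscriptions/{subscription_id}"
        f"/resourceGroups/{resource_group}"
        "/providers/Microsoft.Kubernetes"
        f"/connectedClusters/{cluster_name}"
    )


# Values taken from the attach cell above.
resource_id = connected_cluster_resource_id(
    subscription_id="6b736da6-3246-44dd-a0b8-b5e95484633d",
    resource_group="sl-ash2",
    cluster_name="sl-d2-o-arc",
)
print(resource_id)
```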


## Configure the training job and Submit a run

### Create an experiment

In [7]:
experiment_name = 'dist-tf2-on-aks-arc'
experiment = Experiment(workspace=ws, name=experiment_name)

### Create an environment

In [8]:
env = Environment.from_dockerfile(
    name='tf_2.4',
    dockerfile='tf-script/Dockerfile.gpu',
    conda_specification='tf-script/tf-24-env.yaml')

### Configure the training job

Use `TensorflowConfiguration` to set the number of workers and parameter servers to use.
With `worker_count=3`, training for one epoch may take about 21 minutes on VMs comparable in size to Standard_DS3_v2.
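`MultiWorkerMirroredStrategy` finds its peers through the `TF_CONFIG` environment variable, a JSON document that lists the cluster's workers and identifies the current task. For a job configured with `TensorflowConfiguration`, Azure Machine Learning should set this variable on each node, so the training script normally does not build it by hand. The sketch below only illustrates its shape for a worker-only cluster. The host names, the port, and the helper name `build_tf_config` are assumptions.

```python
import json


def build_tf_config(worker_hosts, task_index):
    """Return the TF_CONFIG JSON string for one worker in a worker-only cluster."""
    if not 0 <= task_index < len(worker_hosts):
        raise IndexError("task_index must refer to one of worker_hosts")
    return json.dumps({
        "cluster": {"worker": list(worker_hosts)},
        "task": {"type": "worker", "index": task_index},
    })


# Hypothetical host names and port, for illustration only.
hosts = [f"worker-{i}.example.local:2222" for i in range(3)]
for index in range(len(hosts)):
    print(build_tf_config(hosts, index))
```

Every worker sees the same `cluster` entry and differs only in `task.index`.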


In [10]:
from azureml.core import ScriptRunConfig, Run
from azureml.core.runconfig import TensorflowConfiguration

worker_count= 3
src = ScriptRunConfig(source_directory='tf-script',
                      script='train.py',
                      arguments=[
                          '--dataset-path', dataset_ash.as_mount(),
                          '--epochs', 2,#80
                          '--global-batch-size', 256,
                          '--batches-per-epoch', 256,
                          '--alpha-init', 0.005,
                      ],
                      compute_target=arcK_target,
                      environment=env,
                      distributed_job_config=TensorflowConfiguration(worker_count=worker_count, parameter_server_count=1))#configuring AML TF config

rs_config = src.run_config.amlk8scompute.resource_configuration
rs_config.gpu_count = 0
rs_config.cpu_count = worker_count - rs_config.gpu_count
rs_config.memory_request_in_gb = 6
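The sketch below gives a rough sense of the argument values in this cell. It assumes `--batches-per-epoch` is the number of training steps per epoch and `--global-batch-size` is the batch size summed over all workers. The training script is not shown here, so both interpretations are assumptions.

```python
global_batch_size = 256    # value of --global-batch-size above
batches_per_epoch = 256    # value of --batches-per-epoch above
worker_count = 3
cifar10_train_images = 50_000  # size of the CIFAR-10 training set

# Samples consumed per epoch and how that compares with the training set size.
samples_per_epoch = global_batch_size * batches_per_epoch
passes_over_training_set = samples_per_epoch / cifar10_train_images

# The global batch does not divide evenly across three workers.
per_worker, remainder = divmod(global_batch_size, worker_count)

print(samples_per_epoch)          # 65536
print(passes_over_training_set)   # 1.31072
print(per_worker, remainder)      # 85 1
```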



### Submit job
Run your experiment by submitting your `ScriptRunConfig` object. Note that `submit` is asynchronous; the `wait_for_completion` call below blocks until the run finishes.

In [11]:
run = experiment.submit(config=src)
run.wait_for_completion(show_output=True) # this provides a verbose log

RunId: dist-tf2-on-aks-arc_1613702443_c5157ce0
Web View: https://ml.azure.com/experiments/dist-tf2-on-aks-arc/runs/dist-tf2-on-aks-arc_1613702443_c5157ce0?wsid=/subscriptions/6b736da6-3246-44dd-a0b8-b5e95484633d/resourcegroups/sl-ash2/workspaces/sl-ash2-mal

Streaming azureml-logs/55_azureml-execution-tvmps_16073b0ab0f5023d82c74e2a2fcec4a47b34a8622966f122479d5082ad253048_d.txt

2021-02-19T02:41:04Z Starting output-watcher...
2021-02-19T02:41:04Z IsDedicatedCompute == True, won't poll for Low Pri Preemption
Login Succeeded
Using default tag: latest
latest: Pulling from azureml/azureml_2a916f4728727b3b5b18dbee7e16a3e5
Digest: sha256:68ffd55d5df309a0bda6ab9bb4fe2c3a1739a44786f2902f8f3bd785a61b3e4c
Status: Image is up to date for viennaglobal.azurecr.io/azureml/azureml_2a916f4728727b3b5b18dbee7e16a3e5:latest
viennaglobal.azurecr.io/azureml/azureml_2a916f4728727b3b5b18dbee7e16a3e5:latest
2021-02-19T02:41:06Z Check if container dist-tf2-on-aks-arc_1613702443_c5157ce0 already exist exited wit

batch end 96
batch begin 97
batch end 97
batch begin 98
batch end 98
batch begin 99
batch end 99
batch begin 100
batch end 100
batch begin 101
batch end 101
batch begin 102
batch end 102
batch begin 103
batch end 103
batch begin 104
batch end 104
batch begin 105
batch end 105
batch begin 106
batch end 106
batch begin 107
batch end 107
batch begin 108
batch end 108
batch begin 109
batch end 109
batch begin 110
batch end 110
batch begin 111
batch end 111
batch begin 112
batch end 112
batch begin 113
batch end 113
batch begin 114
batch end 114
batch begin 115
batch end 115
batch begin 116
batch end 116
batch begin 117
batch end 117
batch begin 118
batch end 118
batch begin 119
batch end 119
batch begin 120
batch end 120
batch begin 121
batch end 121
batch begin 122
batch end 122
batch begin 123
batch end 123
batch begin 124
batch end 124
batch begin 125
batch end 125
batch begin 126
batch end 126
batch begin 127
batch end 127
batch begin 128
batch end 128
batch begin 129
batch end 129
bat

batch end 186
batch begin 187
batch end 187
batch begin 188
batch end 188
batch begin 189
batch end 189
batch begin 190
batch end 190
batch begin 191
batch end 191
batch begin 192
batch end 192
batch begin 193
batch end 193
batch begin 194
batch end 194
batch begin 195
batch end 195
batch begin 196
batch end 196
batch begin 197
batch end 197
batch begin 198
batch end 198
batch begin 199
batch end 199
batch begin 200
batch end 200
batch begin 201
batch end 201
batch begin 202
batch end 202
batch begin 203
batch end 203
batch begin 204
batch end 204
batch begin 205
batch end 205
batch begin 206
batch end 206
batch begin 207
batch end 207
batch begin 208
batch end 208
batch begin 209
batch end 209
batch begin 210
batch end 210
batch begin 211
batch end 211
batch begin 212
batch end 212
batch begin 213
batch end 213
batch begin 214
batch end 214
batch begin 215
batch end 215
batch begin 216
batch end 216
batch begin 217
batch end 217
batch begin 218
batch end 218
batch begin 219
batch end 

{'runId': 'dist-tf2-on-aks-arc_1613702443_c5157ce0',
 'target': 'd13',
 'status': 'Completed',
 'startTimeUtc': '2021-02-19T02:40:59.963486Z',
 'endTimeUtc': '2021-02-19T02:56:56.49202Z',
 'properties': {'_azureml.ComputeTargetType': 'amlcompute',
  'ContentSnapshotId': '569d190c-18bd-463f-9c25-4a366de60d75',
  'azureml.git.repository_uri': 'git@github.com:lisongshan007/AML-Kubernetes.git',
  'mlflow.source.git.repoURL': 'git@github.com:lisongshan007/AML-Kubernetes.git',
  'azureml.git.branch': 'master',
  'mlflow.source.git.branch': 'master',
  'azureml.git.commit': '9af78d232734dace4df94ed28a5e6888ed3905e4',
  'mlflow.source.git.commit': '9af78d232734dace4df94ed28a5e6888ed3905e4',
  'azureml.git.dirty': 'False',
  'ProcessInfoFile': 'azureml-logs/process_info.json',
  'ProcessStatusFile': 'azureml-logs/process_status.json'},
 'inputDatasets': [{'dataset': {'id': '24d86baf-41f9-40e1-ae0d-5ae8b4c9cdfc'}, 'consumptionDetails': {'type': 'RunInput', 'inputName': 'input__358e1d54', 'mechan
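The run details include `startTimeUtc` and `endTimeUtc`, which give the wall-clock duration of the run. The sketch below assumes both timestamps carry a fractional-seconds part, as they do in the output above. The helper name `run_duration_seconds` is hypothetical.

```python
from datetime import datetime

TIMESTAMP_FORMAT = "%Y-%m-%dT%H:%M:%S.%fZ"


def run_duration_seconds(details):
    """Return the wall-clock duration of a run from its details dictionary."""
    start = datetime.strptime(details["startTimeUtc"], TIMESTAMP_FORMAT)
    end = datetime.strptime(details["endTimeUtc"], TIMESTAMP_FORMAT)
    return (end - start).total_seconds()


# Timestamps copied from the output above.
details = {
    "startTimeUtc": "2021-02-19T02:40:59.963486Z",
    "endTimeUtc": "2021-02-19T02:56:56.49202Z",
}
minutes, seconds = divmod(run_duration_seconds(details), 60)
print(f"{int(minutes)} min {seconds:.1f} s")  # 15 min 56.5 s
```

In a live session the same dictionary should be available from `run.get_details()`.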

### Register the model

In [12]:
#  the model is saved at path "outputs/001"
registered_model_name = 'cifar10tf'
model = run.register_model(model_name=registered_model_name, model_path='outputs/001')

The machine learning model named "cifar10tf" should be registered in your AML workspace.

## Test Registered Model

To test the trained model, you can create an AKS cluster (or use an existing one) to serve the model through an AML deployment.

In [13]:
from azureml.core import Environment, Workspace, Model, ComputeTarget
from azureml.core.compute import AksCompute
from azureml.core.model import InferenceConfig
from azureml.core.webservice import Webservice, AksWebservice
from azureml.core.compute_target import ComputeTargetException
import numpy as np
import json

### Provision the AKS Cluster

This is a one-time setup. You can reuse this cluster for multiple deployments after it has been created. If you delete the cluster or the resource group that contains it, you will have to recreate it. It may take about 5 minutes to create a new AKS cluster.

In [15]:
ws = Workspace.from_config()

# Choose a name for your AKS cluster
aks_name = 'aks-service-2'

# Verify that cluster does not exist already
try:
    aks_target = ComputeTarget(workspace=ws, name=aks_name)
    is_new_compute  = False
    print('Found existing cluster, use it.')
except ComputeTargetException:
    # Use the default configuration (can also provide parameters to customize)
    prov_config = AksCompute.provisioning_configuration()

    # Create the cluster
    aks_target = ComputeTarget.create(workspace = ws, 
                                    name = aks_name, 
                                    provisioning_configuration = prov_config)
    is_new_compute  = True
    
print("using compute target: ", aks_target.name)

Found existing cluster, use it.
using compute target:  aks-service-2


### Deploy the model

In [16]:
env = Environment.from_conda_specification(name="tf_2.4", file_path="tf-script/tf-24-env.yaml")
inference_config = InferenceConfig(entry_script='score_tf.py', environment=env)
deploy_config = AksWebservice.deploy_configuration()

#registered_model_name = "cifar10tf_80"
model = ws.models[registered_model_name]
service_name = 'cifartfservice1'
service = Model.deploy(workspace=ws,
                       name=service_name,
                       models=[model],
                       inference_config=inference_config,
                       deployment_config=deploy_config,
                       deployment_target=aks_target,
                       overwrite=True)

service.wait_for_deployment(show_output=True)

Tips: You can try get_logs(): https://aka.ms/debugimage#dockerlog or local deployment: https://aka.ms/debugimage#debug-locally to debug if deployment takes longer than 10 minutes.
Running.........................................................................................
Succeeded
AKS service creation operation finished, operation "Succeeded"


### Test with inputs

For testing purposes, the first five images from the test batch have been extracted, and a sample input built from them is included in `cifar10_test_input_tf.json`. The model outputs a probability for each of the ten class labels. Let's take a look at these pictures:

In [17]:
from PIL import Image

img_files = ["test_img_0_cat.jpg", "test_img_1_ship.jpg","test_img_2_ship.jpg","test_img_3_plane.jpg","test_img_4_frog.jpg"]

for img_file in img_files:
    image = Image.open("test_imgs/{}".format(img_file))
    image.show()

After some preprocessing, these five images are converted to JSON as input for the trained model. The outputs are the per-image probabilities for each class.

In [18]:
with open("cifar10_test_input_tf.json", "r") as fp:
    inputs_json = json.load(fp)
inputs = json.dumps(inputs_json)
resp = service.run(inputs)
predicts = resp["predictions"]
print(predicts)

[[0.010631756857037544, 0.05487808212637901, 0.08862863481044769, 0.17179299890995026, 0.02946522831916809, 0.10520273447036743, 0.3316628038883209, 0.005262687802314758, 0.10701419413089752, 0.09546088427305222], [0.004351824056357145, 0.9802891612052917, 3.8142397897900082e-06, 1.57823469635332e-05, 5.642276391881751e-06, 6.758522204108885e-07, 2.414366463199258e-05, 6.766516889911145e-06, 0.008264408446848392, 0.007037722039967775], [0.09286026656627655, 0.3799227774143219, 0.00039583438774570823, 0.0008613724494352937, 0.0007697520195506513, 0.0001935697946464643, 0.001001340220682323, 0.0006281483802013099, 0.43021687865257263, 0.09315010905265808], [0.3268948495388031, 0.11965004354715347, 0.00022240476391743869, 0.000488616235088557, 0.0002872534969355911, 1.9691684428835288e-05, 0.00026881249505095184, 7.29254461475648e-05, 0.5433928370475769, 0.00870265532284975], [5.594985486823134e-05, 0.0007109963917173445, 0.0064894272945821285, 0.013348071835935116, 0.08071228116750717, 0

Then you can easily get the predictions of labels:

In [19]:
import numpy as np
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')
np_predicts = np.array(predicts)
pred_indexes = np.argmax(np_predicts, 1)

predict_labels = [classes[i] for i in pred_indexes]
print(predict_labels)

['frog', 'car', 'ship', 'ship', 'frog']
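The predicted labels can be compared with the labels embedded in the test image file names. This is a small sketch that assumes the last part of each file name is the true class. The helper name `label_from_filename` is hypothetical.

```python
def label_from_filename(file_name):
    """Extract the class label from names like 'test_img_0_cat.jpg'."""
    stem = file_name.rsplit(".", 1)[0]
    return stem.rsplit("_", 1)[1]


img_files = ["test_img_0_cat.jpg", "test_img_1_ship.jpg", "test_img_2_ship.jpg",
             "test_img_3_plane.jpg", "test_img_4_frog.jpg"]
predict_labels = ["frog", "car", "ship", "ship", "frog"]  # output printed above

expected_labels = [label_from_filename(name) for name in img_files]
matches = [pred == exp for pred, exp in zip(predict_labels, expected_labels)]
accuracy = sum(matches) / len(matches)

print(expected_labels)  # ['cat', 'ship', 'ship', 'plane', 'frog']
print(accuracy)         # 0.4
```

Two of the five images are classified correctly here. The low score may be due to the short 2-epoch training run configured above.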


### Delete the newly created cluster

Note: This is important if you wish to avoid the cost of this cluster

In [20]:
if is_new_compute:
    # Delete the web service first, then the cluster that hosts it
    service.delete()
    aks_target.delete()

## Next Steps

1. Learn how to [download model then upload to Azure Storage blobs](../AML-model-download-upload.ipynb)
2. Learn how to [inference using KFServing with model in Azure Storage Blobs](https://aka.ms/kfas)