Copyright (c) Microsoft Corporation. All rights reserved.

Licensed under the MIT License.

# Distributed TensorFlow 2 with MultiWorkerMirroredStrategy
In this tutorial, you will train a TensorFlow model on the [CIFAR-10](http://www.cs.toronto.edu/~kriz/cifar.html) dataset using distributed training with the TensorFlow 2 `MultiWorkerMirroredStrategy` API across an Azure Stack Hub Kubernetes cluster.

## Prerequisites

Please see the Prerequisites section of [this notebook](../pipeline/nyc-taxi-data-regression-model-building.ipynb).

In [None]:
from azureml.core import Dataset, Environment, Experiment, Workspace

from azureml.widgets import RunDetails
import os
import requests
import tempfile

## Initialize workspace

Initialize a [Workspace](https://docs.microsoft.com/azure/machine-learning/service/concept-azure-machine-learning-architecture#workspace) object from the existing workspace you created in the Prerequisites step. `Workspace.from_config()` creates a workspace object from the details stored in `config.json`. 

If you haven't done so already, fill in your workspace information in the `config.json` file.
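A common failure at this step is a `config.json` that still contains placeholder values. The following is a minimal, hypothetical pre-flight check (not part of the Azure ML SDK) that looks for the three keys `Workspace.from_config()` reads: `subscription_id`, `resource_group`, and `workspace_name`. Treating values that start with `<` as unfilled placeholders is an assumption of this sketch.

```python
import json
from pathlib import Path

# Keys that Workspace.from_config() reads from config.json
REQUIRED_KEYS = ("subscription_id", "resource_group", "workspace_name")


def check_workspace_config(path="config.json"):
    """Return a list of problems in a workspace config file (empty list if none)."""
    config_path = Path(path)
    if not config_path.is_file():
        return [f"{config_path} not found"]
    try:
        config = json.loads(config_path.read_text())
    except json.JSONDecodeError as err:
        return [f"{config_path} is not valid JSON: {err}"]
    if not isinstance(config, dict):
        return [f"{config_path} must contain a JSON object"]
    problems = []
    for key in REQUIRED_KEYS:
        value = config.get(key)
        # Assumed convention: "<...>" marks a placeholder that was never filled in
        if not isinstance(value, str) or not value.strip() or value.startswith("<"):
            problems.append(f"'{key}' is missing, empty, or still a placeholder")
    return problems
```

Calling `check_workspace_config()` before `Workspace.from_config()` surfaces these issues with a clearer message. Note that it inspects only the single file you point it at, whereas `Workspace.from_config()` can also search other directories.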

In [None]:
ws = Workspace.from_config()
print('Workspace name: ' + ws.name, 
      'Azure region: ' + ws.location, 
      'Subscription id: ' + ws.subscription_id, 
      'Resource group: ' + ws.resource_group, sep='\n')

## Prepare dataset

If the dataset is not already registered, the cell below downloads the CIFAR-10 archive from [cifar10-data](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz), uploads it to a datastore of the workspace, and registers it as a dataset.

To set up a datastore backed by an Azure Stack Hub storage account, refer to [Train_azure_arc](https://github.com/Azure/AML-Kubernetes/blob/master/docs/ASH/Train-AzureArc.md#create-and-configure-azure-stack-hubs-storage-account).

To register the dataset manually, refer to this [video](https://msit.microsoftstream.com/video/51f7a3ff-0400-b9eb-2703-f1eb38bc6232).
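Optionally, you can verify the downloaded archive before uploading it, so that a truncated download or an HTML error page is not registered as the dataset. This sketch is not part of the original notebook; the function names are hypothetical, and the constant is assumed to be the MD5 listed for `cifar-10-python.tar.gz` on the CIFAR-10 download page.

```python
import hashlib

# Assumed: MD5 listed for cifar-10-python.tar.gz on the CIFAR-10 page
CIFAR10_PYTHON_MD5 = "c58f30108f718f92721af3b95e74349a"


def md5_of_file(path, chunk_size=1 << 20):
    """Compute a file's MD5 hex digest in chunks, without loading it all into memory."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def is_valid_cifar10_archive(path, expected_md5=CIFAR10_PYTHON_MD5):
    """Return True if the file at `path` matches the expected checksum."""
    return md5_of_file(path) == expected_md5
```

You could call `is_valid_cifar10_archive(...)` on the file written into `tmpdir/cifar-10/` right before `Dataset.File.upload_directory`, and skip the upload if it returns `False`.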


In [None]:
import os
import tempfile

import requests
from azureml.core import Dataset
from azureml.exceptions import UserErrorException

dataset_name = 'CIFAR-10'
try:
    # Reuse the dataset if it is already registered in the workspace
    ds = Dataset.get_by_name(ws, name=dataset_name)
except UserErrorException:
    # Dataset not registered yet: download CIFAR-10 and upload it to a datastore
    path = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'

    with tempfile.TemporaryDirectory() as tmpdir:
        os.makedirs(os.path.join(tmpdir, 'cifar-10'))

        response = requests.get(path, allow_redirects=True)
        response.raise_for_status()  # fail early instead of uploading an error page
        with open(os.path.join(tmpdir, 'cifar-10', path.split('/')[-1]), 'wb') as f:
            f.write(response.content)

        datastore_name = "<my_data_store>"  # replace with the name of your datastore
        ds = Dataset.File.upload_directory(tmpdir, ws.datastores.get(datastore_name), overwrite=True)
        ds = ds.register(ws, dataset_name, 'CIFAR-10 images from https://www.cs.toronto.edu/~kriz/cifar.html')
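The cell above uploads the archive under a `cifar-10/` folder, so when `train.py` receives the mounted `--dataset-path`, it still has to locate and unpack the tarball. The training script is not shown in this notebook; the following is a hypothetical standard-library sketch of that step. It assumes the standard CIFAR-10 Python layout: a `cifar-10-batches-py/` folder holding `data_batch_1` to `data_batch_5` and `test_batch`, each a pickled dict with `b'data'` and `b'labels'` keys. For the real archive, `b'data'` is a NumPy array, so NumPy must be installed to unpickle it.

```python
import glob
import os
import pickle
import tarfile


def find_cifar10_archive(dataset_path):
    """Locate cifar-10-python.tar.gz anywhere under the mounted dataset directory."""
    pattern = os.path.join(dataset_path, "**", "cifar-10-python.tar.gz")
    matches = sorted(glob.glob(pattern, recursive=True))
    if not matches:
        raise FileNotFoundError(f"cifar-10-python.tar.gz not found under {dataset_path}")
    return matches[0]


def load_cifar10_batches(archive_path):
    """Unpickle every training/test batch in the archive.

    Returns {batch name: (data, labels)}. Only unpickle files from a trusted
    source such as the official CIFAR-10 archive.
    """
    batches = {}
    with tarfile.open(archive_path, "r:gz") as tar:
        for member in tar.getmembers():
            name = os.path.basename(member.name)
            if not member.isfile():
                continue
            if not (name.startswith("data_batch_") or name == "test_batch"):
                continue  # skip batches.meta, readme.html, and so on
            with tar.extractfile(member) as f:
                # The CIFAR-10 pickles use byte-string keys, hence encoding="bytes"
                batch = pickle.load(f, encoding="bytes")
            batches[name] = (batch[b"data"], batch[b"labels"])
    return batches
```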
        

## Attach an existing cluster as ArcKubernetesCompute

The attach code below depends on the Python package `azureml-contrib-k8s`, which is currently in private preview. Install the private preview branch of the Azure ML SDK by running the following command:

<pre>
pip install --disable-pip-version-check --extra-index-url https://azuremlsdktestpypi.azureedge.net/azureml-contrib-k8s-preview/D58E86006C65 azureml-contrib-k8s
</pre>

In [None]:
from azureml.contrib.core.compute.arckubernetescompute import ArcKubernetesCompute
from azureml.core.compute_target import ComputeTargetException

# ARM resource ID of your Azure Arc-enabled Kubernetes cluster
resource_id = "/subscriptions/<subscription_id>/resourceGroups/<resource_group>/providers/Microsoft.Kubernetes/connectedClusters/<cluster_name>"

attach_config = ArcKubernetesCompute.attach_configuration(
    resource_id=resource_id,
)

# Name under which the cluster appears as a compute target in the workspace
attach_name = "<my_arc_compute>"
try:
    arcK_target_result = ArcKubernetesCompute.attach(ws, attach_name, attach_config)
    arcK_target_result.wait_for_completion(show_output=True)
    print('arc attach success')
except ComputeTargetException as e:
    print(e)
    print('arc attach failed')

compute_target = ws.compute_targets[attach_name]
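The `resource_id` above follows the ARM resource ID format for Arc-connected Kubernetes clusters (provider `Microsoft.Kubernetes`, resource type `connectedClusters`). If you prefer not to edit the long string by hand, a small hypothetical helper like the one below can assemble it from its three parts; it is a sketch, not part of the SDK.

```python
def connected_cluster_resource_id(subscription_id, resource_group, cluster_name):
    """Build the ARM resource ID of an Azure Arc-enabled (connected) Kubernetes cluster."""
    for label, value in (("subscription_id", subscription_id),
                         ("resource_group", resource_group),
                         ("cluster_name", cluster_name)):
        # Each segment must be present and must not contain a path separator
        if not value or "/" in value:
            raise ValueError(f"{label} must be a non-empty string without '/'")
    return (f"/subscriptions/{subscription_id}"
            f"/resourceGroups/{resource_group}"
            "/providers/Microsoft.Kubernetes/connectedClusters/"
            f"{cluster_name}")
```

For example, `connected_cluster_resource_id("<subscription_id>", "<resource_group>", "<cluster_name>")` reproduces the placeholder string used above.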



### Configure the training job and submit a run

Use `TensorflowConfiguration` to set the number of workers and parameter servers for the job.
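For TensorFlow jobs, Azure ML sets the `TF_CONFIG` environment variable on each worker, and `MultiWorkerMirroredStrategy` reads it by default to discover its peers. `TF_CONFIG` is a JSON object shaped like `{"cluster": {"worker": ["host0:port", "host1:port"]}, "task": {"type": "worker", "index": 0}}`, where the host names here are illustrative. The hypothetical helper below summarizes it, which can be handy to log at the top of `train.py` when debugging a multi-worker run.

```python
import json
import os


def describe_tf_config(env=None):
    """Summarize the TF_CONFIG cluster spec seen by this process (None if unset)."""
    env = os.environ if env is None else env
    raw = env.get("TF_CONFIG")
    if not raw:
        return None  # single-process run: no cluster spec was provided
    tf_config = json.loads(raw)
    cluster = tf_config.get("cluster", {})
    task = tf_config.get("task", {})
    return {
        "num_workers": len(cluster.get("worker", [])),
        "num_ps": len(cluster.get("ps", [])),
        "task_type": task.get("type"),
        "task_index": task.get("index"),
    }
```

With `worker_count=2` and no parameter servers, each worker would be expected to report `num_workers == 2`, `num_ps == 0`, and a distinct `task_index`.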

In [None]:
from azureml.core import ScriptRunConfig, Run
from azureml.core.runconfig import TensorflowConfiguration

compute_target = ws.compute_targets[attach_name]

env = Environment.from_dockerfile(
    name='tf_2.4',
    dockerfile='./env/Dockerfile.gpu',
    conda_specification='./env/tf-24-env.yaml')

experiment_name = 'dist-tf2-on-aks-arc'
experiment = Experiment(workspace=ws, name=experiment_name)

worker_count = 2
src = ScriptRunConfig(source_directory='./scripts',
                      script='train.py',
                      arguments=[
                          '--dataset-path', ws.datasets[dataset_name].as_mount(),
                          '--epochs', 1,  # use more epochs (e.g. 80) for a full training run
                          '--global-batch-size', 256,
                          '--batches-per-epoch', 256,
                          '--alpha-init', 0.005,
                      ],
                      compute_target=compute_target,
                      environment=env,
                      # Azure ML sets TF_CONFIG for each worker. MultiWorkerMirroredStrategy
                      # uses synchronous all-reduce, so no parameter servers are needed.
                      distributed_job_config=TensorflowConfiguration(worker_count=worker_count,
                                                                     parameter_server_count=0))

rs_config = src.run_config.amlk8scompute.resource_configuration
rs_config.gpu_count = 1
rs_config.cpu_count = worker_count - rs_config.gpu_count
rs_config.memory_request_in_gb = 6

run = experiment.submit(config=src)
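It helps to check what the training arguments imply. Assuming `train.py` uses `--global-batch-size` the usual way with `MultiWorkerMirroredStrategy` (the global batch is split evenly across synchronous replicas, one replica per GPU, or one per worker on CPU-only nodes), the arithmetic works out as in this sketch. The function names are illustrative, not part of any API.

```python
def per_replica_batch_size(global_batch_size, num_workers, gpus_per_worker=1):
    """Split a global batch evenly across all synchronous replicas."""
    num_replicas = num_workers * max(gpus_per_worker, 1)  # CPU-only worker = 1 replica
    if global_batch_size % num_replicas:
        raise ValueError("global batch size should be divisible by the replica count")
    return global_batch_size // num_replicas


def examples_per_epoch(global_batch_size, batches_per_epoch):
    """Number of training examples consumed in one epoch."""
    return global_batch_size * batches_per_epoch


# Values passed to train.py in the cell above
print(per_replica_batch_size(256, num_workers=2, gpus_per_worker=1))  # 128
print(examples_per_epoch(256, 256))  # 65536
```

Since CIFAR-10 has 50,000 training images, 256 batches of 256 covers roughly 1.3 passes over the data per "epoch", provided the script repeats the dataset.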

In [None]:
run.wait_for_completion(show_output=True) # this provides a verbose log

In [None]:
#  the model is saved at path "outputs/001"
# register the model
model = run.register_model(model_name='cifar10tf', model_path='outputs/001')

The model named `cifar10tf` should now be registered in your Azure ML workspace.
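Every worker runs `train.py`, so only one of them should write the final model to `outputs/001`. The TensorFlow multi-worker Keras tutorial recommends that all workers call the save API, while only the chief writes to the real path and the others write to temporary directories. The sketch below follows that pattern and assumes `TF_CONFIG` contains only `worker` tasks, so worker 0 acts as chief. Whether this notebook's `train.py` does exactly this is not shown, and the function names are hypothetical.

```python
import os


def is_chief(task_type, task_id):
    """Worker 0 is the chief when TF_CONFIG defines only 'worker' tasks.

    task_type is None for a plain single-process run, which also counts as chief.
    """
    return task_type is None or (task_type == "worker" and task_id == 0)


def model_save_path(final_path, task_type, task_id):
    """Return the real path for the chief and a throwaway per-worker path otherwise."""
    if is_chief(task_type, task_id):
        return final_path
    parent, base = os.path.split(os.path.normpath(final_path))
    return os.path.join(parent, f"workertemp_{task_id}", base)
```

In a training script, each worker would save to `model_save_path("outputs/001", task_type, task_id)`, and non-chief workers would delete their `workertemp_*` directory afterwards so that only the chief's model is registered.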

## Test Registered Model

To test the trained model, create an AKS cluster (or use an existing one) and deploy the model to it with Azure ML.

In [None]:
from azureml.core import Environment, Workspace, Model, ComputeTarget
from azureml.core.compute import AksCompute
from azureml.core.model import InferenceConfig
from azureml.core.webservice import Webservice, AksWebservice
from azureml.core.compute_target import ComputeTargetException
import numpy as np
import json

In [None]:
ws = Workspace.from_config()

# Choose a name for your AKS cluster
aks_name = 'aks-service-2'

# Verify that cluster does not exist already
try:
    aks_target = ComputeTarget(workspace=ws, name=aks_name)
    is_new_compute = False
    print('Found existing cluster, use it.')
except ComputeTargetException:
    # Use the default configuration (can also provide parameters to customize)
    prov_config = AksCompute.provisioning_configuration()

    # Create the cluster
    aks_target = ComputeTarget.create(workspace=ws,
                                      name=aks_name,
                                      provisioning_configuration=prov_config)
    is_new_compute = True

if aks_target.get_status() != "Succeeded":
    aks_target.wait_for_completion(show_output=True)

### Deploy the model

In [None]:
env = Environment.from_conda_specification(name="tf_2.4", file_path="./env/tf-24-env.yaml")
inference_config = InferenceConfig(entry_script='score.py', environment=env)
deploy_config = AksWebservice.deploy_configuration()

model = ws.models["cifar10tf"]
service_name = 'cifartfservice1'
service = Model.deploy(workspace=ws,
                       name=service_name,
                       models=[model],
                       inference_config=inference_config,
                       deployment_config=deploy_config,
                       deployment_target=aks_target,
                       overwrite=True)

service.wait_for_deployment(show_output=True)

### Test with inputs

In [None]:
with open("cifar10_test_input.json", "r") as fp:
    inputs_json = json.load(fp)
inputs = json.dumps(inputs_json)
resp = service.run(inputs)
predicts = resp["predictions"]
predicts_str = json.dumps(predicts)

print(predicts_str)
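The format of `predictions` depends on `score.py`, which is not shown here. Assuming each prediction is a list of 10 scores in the standard CIFAR-10 label order, a hypothetical helper like the following turns the response into readable class names.

```python
# Standard CIFAR-10 label order (index 0 to 9)
CIFAR10_CLASSES = ("airplane", "automobile", "bird", "cat", "deer",
                   "dog", "frog", "horse", "ship", "truck")


def to_class_names(predictions):
    """Map per-image score vectors (10 scores each) to CIFAR-10 class names."""
    names = []
    for scores in predictions:
        if len(scores) != len(CIFAR10_CLASSES):
            raise ValueError(f"expected {len(CIFAR10_CLASSES)} scores, got {len(scores)}")
        best = max(range(len(scores)), key=scores.__getitem__)  # argmax
        names.append(CIFAR10_CLASSES[best])
    return names
```

Under that assumption, `print(to_class_names(predicts))` prints one class name per input image.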

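### Delete the newly created cluster

Note: Deleting the web service and the cluster is important if you want to avoid ongoing charges for them. The cell below cleans up only when this notebook created the cluster.

In [None]:
if is_new_compute:
    # Delete the web service first, then the AKS cluster that hosts it
    service.delete()
    aks_target.delete()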

## Next Steps

1. Learn how to [download model then upload to Azure Storage blobs](../AML-model-download-upload.ipynb)
2. Learn how to [inference using KFServing with model in Azure Storage Blobs](https://aka.ms/kfas)