# Distributed TensorFlow 2 with MultiWorkerMirroredStrategy
In this tutorial, you will train a TensorFlow model on the [CIFAR10](http://www.cs.toronto.edu/~kriz/cifar.html) dataset using distributed training with TensorFlow 2's `MultiWorkerMirroredStrategy` across an Azure Stack Hub CPU Kubernetes cluster.

## Prerequisites
* If you are using an Azure Machine Learning Notebook VM, you are all set. Otherwise, follow the [AZML-SDK-INSTALL](https://docs.microsoft.com/en-us/python/api/overview/azure/ml/install?view=azure-ml-py) instructions to install the Azure Machine Learning Python SDK and create an Azure ML `Workspace`.

In [None]:
from azureml.core import Dataset, Environment, Experiment, Workspace

from azureml.widgets import RunDetails
import os
import requests
import tempfile

## Initialize workspace

Initialize a [Workspace](https://docs.microsoft.com/azure/machine-learning/service/concept-azure-machine-learning-architecture#workspace) object from the existing workspace you created in the Prerequisites step. `Workspace.from_config()` creates a workspace object from the details stored in `config.json`.

In [None]:
ws = Workspace.from_config()
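
For reference, `config.json` is a small JSON file with three fields that identify the workspace. The sketch below writes such a file to a temporary directory and reads it back; the placeholder values are hypothetical and must be replaced with your own subscription, resource group, and workspace name.

```python
import json
import os
import tempfile

# Placeholder values (hypothetical); substitute your own workspace details.
workspace_config = {
    "subscription_id": "<subscription_id>",
    "resource_group": "<resource_group>",
    "workspace_name": "<workspace_name>",
}


def write_workspace_config(directory, config):
    """Write config.json into `directory` and return its path."""
    path = os.path.join(directory, "config.json")
    with open(path, "w") as f:
        json.dump(config, f, indent=4)
    return path


with tempfile.TemporaryDirectory() as tmpdir:
    config_path = write_workspace_config(tmpdir, workspace_config)
    # Read the file back to confirm it round-trips as valid JSON.
    with open(config_path) as f:
        loaded = json.load(f)
```

In a real project, the file would live next to the notebook rather than in a temporary directory.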

## Prepare dataset

Here we download the CIFAR-10 dataset from [cifar10-data](https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz). The downloaded data is then registered as a dataset in a datastore of the workspace.

To set up a datastore using an Azure Stack Hub storage account, refer to [Train_azure_arc](https://microsoft-my.sharepoint.com/:w:/p/penorouzi/EZ_GNID8J35LsUnZZAU82GQBNND8oE1cQYgYmB3AGG5Qww).

To register the dataset manually, refer to this [video](https://msit.microsoftstream.com/video/51f7a3ff-0400-b9eb-2703-f1eb38bc6232).

In [None]:
import tempfile
import os
import requests
from azureml.core import Dataset

from azureml.exceptions import UserErrorException

dataset_name = 'CIFAR-10'
try:
    # Reuse the dataset if it is already registered in the workspace.
    ds = Dataset.get_by_name(ws, name=dataset_name)
except UserErrorException:  # dataset not registered yet
    path = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'

    with tempfile.TemporaryDirectory() as tmpdir:
        os.makedirs(os.path.join(tmpdir, 'cifar-10'))

        # Download the archive and fail early on an HTTP error.
        response = requests.get(path, allow_redirects=True)
        response.raise_for_status()
        with open(os.path.join(tmpdir, 'cifar-10', path.split('/')[-1]), 'wb') as f:
            f.write(response.content)

        # Upload the archive to the datastore and register it as a dataset.
        datastore_name = "<my_data_store>"
        ds = Dataset.File.upload_directory(tmpdir, ws.datastores.get(datastore_name), overwrite=True)
        ds = ds.register(ws, dataset_name, 'CIFAR-10 images from https://www.cs.toronto.edu/~kriz/cifar.html')
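
The registered dataset holds the compressed archive as downloaded, so the training script has to unpack it before reading images. The following is a minimal sketch of that step, not the actual `train.py` (which is not shown here). The helper names `extract_archive` and `load_batch` are hypothetical. The demo builds a tiny stand-in archive that mimics the `cifar-10-batches-py/data_batch_1` layout of the real download, so it runs without network access.

```python
import io
import os
import pickle
import tarfile
import tempfile


def extract_archive(archive_path, target_dir):
    """Unpack a .tar.gz archive into target_dir and return the member names."""
    with tarfile.open(archive_path, "r:gz") as tar:
        names = tar.getnames()
        tar.extractall(target_dir)
    return names


def load_batch(batch_path):
    """Load one pickled CIFAR-10 batch file and return (data, labels)."""
    with open(batch_path, "rb") as f:
        # The batch files use byte-string keys such as b"data" and b"labels".
        batch = pickle.load(f, encoding="bytes")
    return batch[b"data"], batch[b"labels"]


with tempfile.TemporaryDirectory() as tmpdir:
    # Stand-in batch: two fake images of 3072 values each (32 x 32 x 3).
    # The real files store `data` as a NumPy array rather than nested lists.
    fake_batch = {b"data": [[0] * 3072, [255] * 3072], b"labels": [3, 7]}
    payload = pickle.dumps(fake_batch)

    archive_path = os.path.join(tmpdir, "cifar-10-python.tar.gz")
    with tarfile.open(archive_path, "w:gz") as tar:
        info = tarfile.TarInfo("cifar-10-batches-py/data_batch_1")
        info.size = len(payload)
        tar.addfile(info, io.BytesIO(payload))

    members = extract_archive(archive_path, tmpdir)
    data, labels = load_batch(
        os.path.join(tmpdir, "cifar-10-batches-py", "data_batch_1"))
```

With the real archive, `archive_path` would point into the directory mounted through `--dataset-path`.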
        

## Create or attach an existing ArcKubernetesCompute
You will need to create a [compute target](https://docs.microsoft.com/azure/machine-learning/service/concept-azure-machine-learning-architecture#compute-target) for training your model. In this tutorial, you attach an ArcKubernetesCompute target as the remote training compute resource. Make sure azureml-contrib is installed by following [AZML-SDK-INSTALL](https://docs.microsoft.com/en-us/python/api/overview/azure/ml/install?view=azure-ml-py).

In [None]:
from azureml.contrib.core.compute.arckubernetescompute import ArcKubernetesCompute

resource_id = "<resource_id>"
attach_config = ArcKubernetesCompute.attach_configuration(resource_id=resource_id)

attach_name = "arc_attach"
attach_result = ArcKubernetesCompute.attach(ws, attach_name, attach_config)

attach_result.wait_for_completion(show_output=True)

print(attach_result)



## Configure the training job and submit a run

Use `TensorflowConfiguration` to set the number of workers and the number of parameter servers to use.
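
TensorFlow's multi-worker strategies discover their peers through the `TF_CONFIG` environment variable, a JSON document with a `cluster` map and a `task` entry. Azure ML is expected to populate it on each node from the worker and parameter server counts. The sketch below only illustrates the shape of that document for three workers and one parameter server. The host names and the `build_tf_config` helper are hypothetical.

```python
import json


def build_tf_config(worker_hosts, ps_hosts, task_type, task_index):
    """Return a TF_CONFIG-style JSON string for one task in the cluster."""
    cluster = {"worker": list(worker_hosts)}
    if ps_hosts:
        cluster["ps"] = list(ps_hosts)
    return json.dumps({
        "cluster": cluster,
        "task": {"type": task_type, "index": task_index},
    })


# Hypothetical addresses; the real ones are assigned by the cluster.
workers = ["worker-0.example:2222", "worker-1.example:2222", "worker-2.example:2222"]
ps = ["ps-0.example:2222"]

# The document as the first worker would see it.
tf_config = build_tf_config(workers, ps, task_type="worker", task_index=0)

parsed = json.loads(tf_config)
num_workers = len(parsed["cluster"]["worker"])
# Without an explicit chief entry, worker 0 conventionally acts as chief.
is_chief = parsed["task"]["type"] == "worker" and parsed["task"]["index"] == 0
```

A training script can read the same fields from `os.environ["TF_CONFIG"]` to decide, for example, which worker writes checkpoints.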

In [None]:
from azureml.core import ScriptRunConfig, Run
from azureml.core.runconfig import TensorflowConfiguration

compute_target = ws.compute_targets[attach_name]

env = Environment.from_dockerfile(
    name='tf_2.4',
    dockerfile='./env/Dockerfile.gpu',
    conda_specification='./env/tf-24-env.yaml')

experiment_name = 'dist-tf2-on-aks-arc'
experiment = Experiment(workspace=ws, name=experiment_name)

worker_count = 3
src = ScriptRunConfig(source_directory='./scripts',
                      script='train.py',
                      arguments=[
                          '--dataset-path', ws.datasets[dataset_name].as_mount(),
                          '--epochs', 1,  # 80
                          '--global-batch-size', 256,
                          '--batches-per-epoch', 256,
                          '--alpha-init', 0.005,
                      ],
                      compute_target=compute_target,
                      environment=env,
                      # Configure the Azure ML TensorFlow job: workers and parameter servers.
                      distributed_job_config=TensorflowConfiguration(worker_count=worker_count, parameter_server_count=1))

rs_config = src.run_config.amlk8scompute.resource_configuration
rs_config.gpu_count = 0
rs_config.cpu_count = worker_count - rs_config.gpu_count
rs_config.memory_request_in_gb = 6

run = experiment.submit(config=src)

In [None]:
run.wait_for_completion(show_output=True) # this provides a verbose log
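
With synchronous data-parallel training, each step consumes one global batch that is divided among the workers. The arithmetic below uses the values passed to `train.py` above (`--global-batch-size 256`, `--batches-per-epoch 256`, three workers). An as-even-as-possible split is assumed for illustration; the actual sharding depends on how `train.py` builds its input pipeline, which is not shown here. The `split_global_batch` helper is hypothetical.

```python
def split_global_batch(global_batch_size, num_replicas):
    """Divide a global batch among replicas as evenly as possible."""
    base, remainder = divmod(global_batch_size, num_replicas)
    # The first `remainder` replicas each take one extra sample.
    return [base + 1 if i < remainder else base for i in range(num_replicas)]


global_batch_size = 256
batches_per_epoch = 256
worker_count = 3

per_replica = split_global_batch(global_batch_size, worker_count)
samples_per_epoch = global_batch_size * batches_per_epoch

print(per_replica)        # [86, 85, 85]
print(samples_per_epoch)  # 65536
```

Because 256 is not divisible by 3, choosing a global batch size that is a multiple of the worker count keeps the per-worker load uniform.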