In [None]:
import mlflow
import mlflow.azureml

import azureml.core
from azureml.core import Workspace


# Report the installed Azure ML SDK and MLflow versions so the environment
# used for this notebook is visible in the output.
print("SDK version:", azureml.core.VERSION)
print("MLflow version:", mlflow.version.VERSION)

In [None]:
# Obtain the Azure ML workspace handle via Workspace.from_config(); `ws` is
# reused by the tracking and deployment cells.
ws = Workspace.from_config()
# Bare last expression: the workspace details are rendered as the cell output.
ws.get_details()

In [None]:
# Azure ML settings: name of the compute target (cluster) for the training run.
compute_target = 'gpu-cluster'

In [None]:
# Point MLflow at the tracking URI exposed by the workspace.
tracking_uri = ws.get_mlflow_tracking_uri()
mlflow.set_tracking_uri(tracking_uri)

In [None]:
# Name of the MLflow experiment to use, then make it the active experiment.
experiment_name = "pytorch-mlflow-projects"
mlflow.set_experiment(experiment_name)

In [None]:
# Options passed to the MLflow "azureml" project backend. Reuse the
# `compute_target` chosen in the settings cell instead of repeating the
# cluster name, so changing the setting actually changes where the run executes.
backend_config = {"COMPUTE": compute_target, "USE_CONDA": False}

In [None]:
# Submit the MLflow project in the current directory (uri=".") through the
# "azureml" backend; the returned run object is kept in `run` for deployment.
run = mlflow.projects.run(
    uri=".",
    backend="azureml",
    backend_config=backend_config,
)

## Deploy model

In [None]:
from azureml.core.webservice import AciWebservice, Webservice

# Run-relative path of the model, combined with the run id into a runs:/ URI.
model_path = "model"
model_uri = "runs:/{}/{}".format(run.run_id, model_path)

# Resources and metadata for the ACI web service.
aci_config = AciWebservice.deploy_configuration(
    cpu_cores=2,
    memory_gb=5,
    tags={"data": "MNIST", "method": "pytorch"},
    description="Predict using webservice",
)

# Deploy the model from the training run; keeps both the web service handle
# and the second value returned by mlflow.azureml.deploy.
webservice, azure_model = mlflow.azureml.deploy(
    model_uri=model_uri,
    workspace=ws,
    deployment_config=aci_config,
    service_name="pytorch-mnist",
    model_name="pytorch_mnist",
)

In [None]:
from torchvision import datasets, transforms
import random
import numpy as np

# Preprocessing applied to each test image: convert to a tensor, then
# normalize with the constants (0.1307,) and (0.3081,).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# MNIST test split (train=False) stored under ../../data, with download enabled.
test_data = datasets.MNIST(
    '../../data',
    train=False,
    transform=mnist_transform,
    download=True,
)


def get_random_image(dataset=None):
    """Pick a random image from a dataset and wrap it as a scoring payload.

    Args:
        dataset: Indexable collection of ``(image_tensor, label)`` pairs.
            Defaults to the notebook-level ``test_data``.

    Returns:
        dict: ``{"data": rows}`` where ``rows`` is the selected image flattened
        into a single row (a nested list of shape ``1 x num_pixels``).

    Raises:
        ValueError: If the dataset is empty (raised by ``random.randrange``).
    """
    if dataset is None:
        dataset = test_data
    # randrange excludes the upper bound; randint(0, len(dataset)) could return
    # len(dataset) itself and index one past the end of the dataset.
    image_idx = random.randrange(len(dataset))
    image_as_tensor = dataset[image_idx][0]
    # Flatten to one row of pixels and convert to plain lists for JSON encoding.
    return {"data": image_as_tensor.numpy().reshape(1, -1).tolist()}

In [None]:
%matplotlib inline

import json
import matplotlib.pyplot as plt

test_image = get_random_image()

# Send the image to the deployed web service as a JSON string.
raw_response = webservice.run(json.dumps(test_image))

# Order the items of the first result by their value, largest first
# (presumably label -> score pairs; verify against the service output).
response = sorted(raw_response[0].items(), key=lambda item: item[1], reverse=True)
top_label = response[0][0]

print("Predicted label:", top_label)

# Show the submitted image as a 28x28 grayscale picture.
pixels = np.array(test_image["data"]).reshape(28, 28)
plt.imshow(pixels, cmap="gray")

## Clean up
You can delete the ACI deployment with a delete API call.

In [None]:
# Delete the web service deployed earlier in this notebook.
webservice.delete()