In [None]:
from azureml.core import Workspace

# Load the Azure ML workspace from its local configuration file; `ws` is used by
# every later cell (MLflow tracking URI, model deployment).
ws = Workspace.from_config()
# Display the workspace as the cell output to confirm which workspace is in use.
ws

In [None]:
import git
from pathlib import Path

# Repository root, resolved by walking up from the notebook's directory, so the
# paths below do not depend on where Jupyter was launched.
prefix = Path(
    git.Repo(".", search_parent_directories=True).working_tree_dir
)

# Directory of the MLflow project that is submitted to Azure ML.
project_uri = prefix / "mlprojects" / "pytorch-mnist"

# Azure ML experiment name and the compute cluster the project runs on.
experiment_name = "pytorch-mnist-mlproject-example"
compute_target = "gpu-cluster"

In [None]:
import mlflow

# Route MLflow tracking calls to the workspace's MLflow tracking URI, and file
# subsequent runs under `experiment_name` (defined in the config cell).
mlflow.set_tracking_uri(ws.get_mlflow_tracking_uri())
mlflow.set_experiment(experiment_name)

In [None]:
# Options for the "azureml" MLflow project backend: COMPUTE names the cluster to
# run on; USE_CONDA is switched off.
backend_config = dict(COMPUTE=compute_target, USE_CONDA=False)

In [None]:
# Submit the MLflow project at `project_uri` through the Azure ML backend.
# The returned run object's `run_id` is used later to locate the logged model.
run = mlflow.projects.run(
    uri=str(project_uri),
    backend="azureml",
    backend_config=backend_config,
)

In [None]:
# Show the submitted project run; its `run_id` identifies the model to deploy below.
run

In [None]:
from azureml.core.webservice import AksWebservice

# Deployment settings for the AKS compute target "aks-cpu-deploy":
# descriptive metadata first, then the CPU cores / memory (GB) requested for the service.
aks_config = AksWebservice.deploy_configuration(
    description="Predict using webservice",
    tags={"data": "MNIST", "method": "pytorch"},
    compute_target_name="aks-cpu-deploy",
    cpu_cores=2,
    memory_gb=5,
)

In [None]:
import mlflow.azureml

# Deploy the model logged under the project run's "model" artifact path to AKS,
# using the deployment configuration built above. `webservice` is queried and
# later deleted by the cells below.
# NOTE(review): `mlflow.azureml` was deprecated and removed in MLflow 2.0 — confirm
# the pinned MLflow version still provides `mlflow.azureml.deploy`.
webservice, azure_model = mlflow.azureml.deploy(
    model_uri=f"runs:/{run.run_id}/model",
    workspace=ws,
    deployment_config=aks_config,
    service_name="pytorch-mnist-example",
    model_name="pytorch-mnist-example",
)

In [None]:
from torchvision import datasets, transforms
import random
import numpy as np

# MNIST test split, stored (and downloaded if missing) under <repo root>/data/tmp.
# `prefix` is a pathlib.Path, so the path must be built with `/`: the original
# `prefix + "/data/tmp"` raises TypeError because Path does not support `+`.
test_data = datasets.MNIST(
    str(prefix / "data" / "tmp"),
    train=False,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            # NOTE(review): 0.1307 / 0.3081 are the commonly quoted MNIST mean / std;
            # presumably they must match the normalization used at training time —
            # confirm against the training code in mlprojects/pytorch-mnist.
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    ),
    download=True,
)


def get_random_image(dataset=None):
    """Pick one random sample and package its image as a webservice payload.

    Args:
        dataset: Indexable, sized collection whose items are ``(image, label)``
            pairs with ``image`` exposing ``.numpy()`` (e.g. a torch tensor).
            Defaults to the notebook-level ``test_data``.

    Returns:
        ``{"data": [[x0, x1, ...]]}`` — the image flattened into a single row
        (a 1 x N nested list), ready to be ``json.dumps``-ed for scoring.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if dataset is None:
        dataset = test_data
    if len(dataset) == 0:
        raise ValueError("cannot sample an image from an empty dataset")
    # randrange excludes its upper bound, so every index is valid; the previous
    # random.randint(0, len(...)) was inclusive and could index one past the end.
    image_idx = random.randrange(len(dataset))
    image_as_tensor = dataset[image_idx][0]
    # Flatten to a single row: shape (1, N) -> [[x0, x1, ...]].
    return {"data": image_as_tensor.numpy().reshape(1, -1).tolist()}

In [None]:
%matplotlib inline

import json
import matplotlib.pyplot as plt

test_image = get_random_image()

# Score one random test image. The raw service output is kept separately from the
# ranked result instead of overwriting one variable with a different kind of value.
# NOTE(review): assumes the service returns a list whose first element maps
# label -> score — confirm against the deployed model's scoring schema.
raw_response = webservice.run(json.dumps(test_image))
# (label, score) pairs, highest score first.
response = sorted(raw_response[0].items(), key=lambda pair: pair[1], reverse=True)
predicted_label = response[0][0]

print("Predicted label:", predicted_label)

# Show the input image (the payload is 1 x 784, i.e. a flattened 28 x 28 image).
fig, ax = plt.subplots()
ax.imshow(np.array(test_image["data"]).reshape(28, 28), cmap="gray")
ax.set_title(f"Predicted label: {predicted_label}")
ax.axis("off");

In [None]:
# Clean up: delete the webservice deployed earlier in this notebook.
webservice.delete()