# Deploy a PyTorch model to a web service endpoint

description: deploy a PyTorch CNN model trained on MNIST data to AKS

In [None]:
from azureml.core import Workspace

# Load the Azure ML workspace from the config file located by Workspace.from_config().
ws = Workspace.from_config()
# Bare expression so the notebook displays the workspace details.
ws

In [None]:
import git
from pathlib import Path

# Absolute path of the enclosing git repository's working tree; data files
# further down are located relative to it.
_repo = git.Repo(".", search_parent_directories=True)
prefix = Path(_repo.working_tree_dir)

# Azure ML experiment whose runs hold the trained model artifact.
experiment_name = "pytorch-mnist-mlproject-example"

In [None]:
import mlflow

# Point MLflow tracking at the Azure ML workspace and select the experiment
# so MLflow operations below resolve against it.
tracking_uri = ws.get_mlflow_tracking_uri()
mlflow.set_tracking_uri(tracking_uri)
mlflow.set_experiment(experiment_name)

In [None]:
# Walk the experiment's runs and register the "model" artifact of the first
# completed run that has one. The ordering returned by get_runs() is presumably
# newest-first (confirm against the azureml docs). `run` and `model` are left
# pointing at the same run, whose id is used to deploy below.
model = None
for run in ws.experiments[experiment_name].get_runs():
    if run.get_status() != "Completed":
        continue  # skip failed / running / canceled runs
    try:
        model = run.register_model(experiment_name, model_path="model")
    except Exception as err:
        # Registration failed for this run (e.g. no "model" artifact); report it
        # instead of silently swallowing, then keep searching older runs.
        print(f"Skipping run {run.id}: could not register model ({err})")
        continue
    break

if model is None:
    raise RuntimeError(
        f"No completed run with a registrable 'model' artifact found in "
        f"experiment {experiment_name!r}"
    )

run

In [None]:
from azureml.core.webservice import AksWebservice

# Metadata attached to the deployed service.
_service_tags = {"data": "MNIST", "method": "pytorch"}

# Deployment settings for the AKS web service. "aks-cpu-deploy" is presumably an
# AKS compute target that already exists in the workspace — confirm before running.
aks_config = AksWebservice.deploy_configuration(
    compute_target_name="aks-cpu-deploy",
    cpu_cores=2,
    memory_gb=5,
    tags=_service_tags,
    description="Predict using webservice",
)

In [None]:
import mlflow.azureml
from random import randint

# A random 5-digit suffix makes the service name vary between notebook runs
# (presumably to avoid clashing with an earlier, undeleted service).
service_name = f"pytorch-mnist-{randint(10000, 99999)}"
# MLflow URI of the "model" artifact logged by the run selected above.
model_uri = f"runs:/{run.id}/model"

# NOTE(review): mlflow.azureml.deploy only exists in older MLflow releases —
# confirm the MLflow version this notebook is pinned to.
webservice, azure_model = mlflow.azureml.deploy(
    model_uri=model_uri,
    workspace=ws,
    deployment_config=aks_config,
    service_name=service_name,
    model_name="pytorch-mnist-example",
)

In [None]:
import pandas as pd
from random import randint

# Load one of the example digit images data/raw/mnist/{0..9}-example.csv at random.
img = pd.read_csv(
    prefix.joinpath("data", "raw", "mnist", f"{randint(0, 9)}-example.csv")
)
# Flatten every pixel into a single row, i.e. a batch of one image:
# {"data": [[p0, p1, ..., pN]]}. This is the same payload the former
# single-key dict comprehension produced, without its misleading loop.
data = {"data": img.to_numpy().reshape(1, -1).tolist()}

In [None]:
%matplotlib inline

import json
import numpy as np
import matplotlib.pyplot as plt

response = webservice.run(json.dumps(data))
response = sorted(response[0].items(), key=lambda x: x[1], reverse=True)

print("Predicted label:", response[0][0])
plt.imshow(np.array(img).reshape(28, 28), cmap="gray")

In [None]:
# Tear down the deployed web service created above. The model registered
# earlier in this notebook is not deleted here.
webservice.delete()