In [1]:
import os

import wget

In [8]:
url = 'https://github.com/DataTalksClub/machine-learning-zoomcamp/releases/download/dl-models/clothing_classifier_mobilenet_v2_latest.onnx'

# Download only when the file is missing. wget.download never overwrites an
# existing file; it saves "<name> (1).onnx" etc. instead, so re-running this
# cell would re-fetch ~10 MB and pile up copies while the session cell keeps
# loading the original filename.
model_file = url.rsplit('/', 1)[-1]
if not os.path.exists(model_file):
    model_file = wget.download(url, out=model_file)
model_file

100% [..........................................................................] 9814018 / 9814018

'clothing_classifier_mobilenet_v2_latest.onnx'

In [4]:
import numpy as np

from keras_image_helper import create_preprocessor

In [5]:
import numpy as np
def preprocess_pytorch(X, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Turn a channels-last image batch into a normalized channels-first float32 batch.

    Parameters
    ----------
    X : array-like
        Image batch in NHWC layout (batch, height, width, channels) with pixel
        values in [0, 255]. The spatial size is whatever the preprocessor was
        configured with; it is not fixed by this function.
    mean, std : sequence of float, optional
        Per-channel statistics applied after scaling pixels to [0, 1]. The
        defaults are the values the notebook originally hard-coded (commonly
        used ImageNet statistics). Their length must equal the channel count.

    Returns
    -------
    np.ndarray
        float32 array in NCHW layout (batch, channels, height, width).

    Raises
    ------
    ValueError
        If X is not 4-D, or if its channel count does not match mean/std.
    """
    X = np.asarray(X)
    if X.ndim != 4:
        raise ValueError(f"expected a 4-D NHWC batch, got shape {X.shape}")

    mean = np.asarray(mean).reshape(1, -1, 1, 1)
    std = np.asarray(std).reshape(1, -1, 1, 1)
    # Without this check a 1-channel image would silently broadcast against
    # the 3-channel statistics and come out as a wrong 3-channel tensor.
    if mean.shape != std.shape or X.shape[-1] != mean.shape[1]:
        raise ValueError(
            f"channel mismatch: input has {X.shape[-1]} channels, "
            f"mean has {mean.shape[1]}, std has {std.shape[1]}"
        )

    # Scale to [0, 1] first: mean/std are expressed on that range.
    X = X / 255.0

    # NHWC -> NCHW: (batch, height, width, channels) -> (batch, channels, height, width),
    # so the statistics (shaped (1, C, 1, 1)) broadcast over the channel axis.
    X = X.transpose(0, 3, 1, 2)

    X = (X - mean) / std

    # Arithmetic above runs in float64 (the statistics are float64); cast back for the model.
    return X.astype(np.float32)


# Spatial size (height, width) the image is resized to before preprocess_pytorch
# runs. Named here instead of inlined so it is easy to find if the model changes.
TARGET_SIZE = (224, 224)

preprocessor = create_preprocessor(preprocess_pytorch, target_size=TARGET_SIZE)

In [6]:
# Sample test image (the URL slug suggests a pair of pants).
url = 'http://bit.ly/mlbookcamp-pants'
X = preprocessor.from_url(url)

# Fail fast with a readable message instead of an opaque ONNX Runtime error:
# preprocess_pytorch should yield a single-image, channels-first float32 batch.
assert X.ndim == 4 and X.shape[:2] == (1, 3), X.shape
assert X.dtype == np.float32, X.dtype

In [9]:
import onnxruntime as ort

# Load the downloaded model, pinned to the CPU execution provider so the
# notebook does not depend on any GPU provider being available.
onnx_model_path = "clothing_classifier_mobilenet_v2_latest.onnx"
session = ort.InferenceSession(
    onnx_model_path,
    providers=["CPUExecutionProvider"],
)

In [10]:
# session.run addresses tensors by name. Only the first input and first output
# are used below (presumably the model has exactly one of each; confirm).
inputs, outputs = session.get_inputs(), session.get_outputs()
input_name, output_name = inputs[0].name, outputs[0].name

In [11]:
# Feed the preprocessed batch to the model input and fetch only `output_name`.
result = session.run(
    output_names=[output_name],
    input_feed={input_name: X},
)

In [12]:
# Raw output of session.run: a list with one entry per requested output (here
# just one), holding a float32 array of shape (1, 10): one row for the single
# image and one unnormalized score per class (not probabilities; some are negative).
result

[array([[ 0.14996365, -1.7628496 , -3.336878  , -1.7682592 ,  5.3368697 ,
         -1.0429132 , -0.30727357,  0.7430814 , -1.7061669 , -3.9114783 ]],
       dtype=float32)]

In [13]:
# Scores of the first (and only) image, taken from the first requested output.
predictions = result[0][0].tolist()

# Label for each output index. Order matters: it presumably mirrors the label
# order used in training (the list is alphabetical; confirm against the
# training code).
classes = [
    'dress',
    'hat',
    'longsleeve',
    'outwear',
    'pants',
    'shirt',
    'shoes',
    'shorts',
    'skirt',
    't-shirt'
]

# zip() silently drops unmatched items, which would hide a label/model mismatch.
assert len(classes) == len(predictions), (len(classes), len(predictions))

dict(zip(classes, predictions))


{'dress': 0.1499636471271515,
 'hat': -1.7628495693206787,
 'longsleeve': -3.3368780612945557,
 'outwear': -1.7682591676712036,
 'pants': 5.336869716644287,
 'shirt': -1.0429131984710693,
 'shoes': -0.3072735667228699,
 'shorts': 0.7430813908576965,
 'skirt': -1.7061668634414673,
 't-shirt': -3.911478281021118}