# ONNXによる推論処理

In [6]:
import json
import os
from typing import List

import joblib
import numpy as np
import onnxruntime as rt
import torch
from torchvision.models.resnet import resnet50
from PIL import Image

from ml.transformers import PytorchImagePreprocessTransformer, SoftmaxTransformer

In [39]:
# ImageNet class names; the list index is the class index predicted by the model
# (see labels[np.argmax(...)] at the end of the notebook).
json_path: str = "./data/image_net_labels.json"
# Explicit UTF-8 so the result does not depend on the platform's default encoding.
with open(json_path, "r", encoding="utf-8") as f:
    labels = json.load(f)

print(len(labels))

1000


In [9]:
# Locations of the exported ONNX model and the saved pre/post-processing
# transformers (.pkl files) inside the model directory.
model_directory = "./models/"

onnx_filename = "resnet50.onnx"
onnx_filepath = os.path.join(model_directory, onnx_filename)

# Plain string literals: there is nothing to interpolate, so no f-prefix.
preprocess_filename = "preprocess_transformer.pkl"
preprocess_filepath = os.path.join(model_directory, preprocess_filename)

postprocess_filename = "softmax_transformer.pkl"
postprocess_filepath = os.path.join(model_directory, postprocess_filename)

print(onnx_filepath)
print(preprocess_filepath)
print(postprocess_filepath)

./models/resnet50.onnx
./models/preprocess_transformer.pkl
./models/softmax_transformer.pkl


In [10]:
# Instantiate a fresh preprocessing transformer.
# NOTE(review): preprocess_filepath (./models/preprocess_transformer.pkl) and the
# imported joblib are never used; confirm whether joblib.load(preprocess_filepath)
# was intended here (only load pickles from trusted sources).
preprocess = PytorchImagePreprocessTransformer()

In [12]:
# Display the transformer object (a bare last expression is rendered as the cell output).
preprocess

In [11]:
# Instantiate a fresh post-processing transformer applied to the raw model output.
# NOTE(review): postprocess_filepath (./models/softmax_transformer.pkl) is never
# loaded; confirm whether joblib.load(postprocess_filepath) was intended.
postprocess = SoftmaxTransformer()

In [13]:
# Display the transformer object (a bare last expression is rendered as the cell output).
postprocess

In [16]:
# Load the sample image as 3-channel RGB, matching the model input
# (batch_size, 3, 224, 224). Image.open is lazy and keeps the file open;
# convert() inside the `with` block reads the pixels into a new in-memory
# image, so the file handle is released when the block exits.
image_path = "./data/cat.jpg"
with Image.open(image_path) as source_image:
    image = source_image.convert("RGB")

In [18]:
# Run the preprocessing transformer on the PIL image. The printed shape should
# match the ONNX model input (batch_size, 3, 224, 224) reported further down.
np_image = preprocess.transform(image)
print(np_image.shape)

(1, 3, 224, 224)


In [22]:
# Since ONNX Runtime 1.9, builds with several execution providers (e.g.
# onnxruntime-gpu) raise unless `providers` is passed explicitly. CPU is always
# available; prepend e.g. "CUDAExecutionProvider" to run on a GPU build.
providers = ["CPUExecutionProvider"]
sess = rt.InferenceSession(onnx_filepath, providers=providers)

## 推論処理の仕方（`sess.run` の使い方）

`outputs = sess.run(output_names, input_feed)`

- `output_names` は、取得したい出力テンソルの名前のリスト（`None` を渡すとすべての出力を返す）

- `input_feed` は、入力データを含む辞書で、キーは入力テンソルの名前、値は実際の入力データ（通常はNumPy配列）

- 戻り値 `outputs` は、`output_names` の順に並んだ出力（NumPy配列）のリスト。出力が1つだけでもリストで返るため、`outputs[0]` のように取り出す

- ONNXモデルの入力と出力はtensor型として定義されているが、ONNX Runtimeが内部でNumPy配列とtensorの変換を行うため、推論時にはNumPy配列をそのまま渡せばよい。ただし要素の型（dtype）はモデルの定義に合わせる必要があり、`tensor(float)` であれば `np.float32` の配列を渡す

In [23]:
# Inspect the model's I/O signature. Inference below feeds only the first input,
# so fail early with a clear message if the model expects more than one.
model_inputs, model_outputs = sess.get_inputs(), sess.get_outputs()
assert len(model_inputs) == 1, f"expected exactly 1 model input, got {len(model_inputs)}"
inp, out = model_inputs[0], model_outputs[0]
print(f"input name='{inp.name}' shape={inp.shape} type={inp.type}")
print(f"output name='{out.name}' shape={out.shape} type={out.type}")

input name='input' shape=['batch_size', 3, 224, 224] type=tensor(float)
output name='output' shape=['batch_size', 1000] type=tensor(float)


In [27]:
# The model input is declared as tensor(float), i.e. float32, and ONNX Runtime
# rejects other dtypes, so cast explicitly (copy=False: no copy if already float32).
input_feed = {inp.name: np_image.astype(np.float32, copy=False)}
# sess.run returns a list with one array per requested output name.
pred_onx = sess.run([out.name], input_feed)
pred_onx[0].shape

(1, 1000)

In [29]:
# Stack the list returned by sess.run into one array. With the single
# (1, 1000) output above this is (1, 1, 1000); the displayed result is
# (1, 1000), so SoftmaxTransformer presumably drops the extra leading axis.
# TODO confirm in ml.transformers.
stacked_outputs = np.array(pred_onx)
prediction = postprocess.transform(stacked_outputs)
prediction.shape

(1, 1000)

In [None]:
# Report the top-1 class name for the first (only) image in the batch.
top1_index = int(np.argmax(prediction[0]))
print(labels[top1_index])

Siamese cat
