In [8]:
import os

import onnxruntime
import torch

In [9]:
from utils.constants import SAVED_MODELS_DIR, EXAMPLE_IMAGES_DIR, ENCODER_MODEL_NAME, CLASSIFIER_MODEL_NAME, EXAMPLE_IMAGES_NAMES

In [10]:
# Load the ONNX encoder that maps a preprocessed image to an embedding vector.
# os.path.join works whether or not SAVED_MODELS_DIR ends with a path separator
# (plain `+` silently produces a wrong path when it does not).
encoder_model_path = os.path.join(SAVED_MODELS_DIR, ENCODER_MODEL_NAME)
# Pin the execution provider explicitly: since ONNX Runtime 1.9, builds that ship more
# than one provider (e.g. onnxruntime-gpu) raise a ValueError when `providers` is omitted.
# CPU is sufficient here, and the downstream PyTorch classifier also runs on CPU.
onnx_session = onnxruntime.InferenceSession(
    encoder_model_path, providers=["CPUExecutionProvider"]
)

In [13]:
from model.inference import preprocess_image

# Index 3 selects the fourth example image (example_4.jpg in the run below).
img_path = os.path.join(EXAMPLE_IMAGES_DIR, EXAMPLE_IMAGES_NAMES[3])
print(img_path)
img = preprocess_image(img_path)

# The encoder expects a batch of one image shaped (1, 224, 224, 3), presumably
# (batch, height, width, channels) — TODO confirm against preprocess_image.
# Check it here so a preprocessing mismatch fails with a clear message instead of
# an opaque ONNX Runtime shape error in the inference cell.
EXPECTED_INPUT_SHAPE = (1, 224, 224, 3)
print(f"Image tensor shape: {img.shape}")
assert tuple(img.shape) == EXPECTED_INPUT_SHAPE, (
    f"unexpected input shape {tuple(img.shape)}, expected {EXPECTED_INPUT_SHAPE}"
)

../data/examples/example_4.jpg
Image tensor shape: (1, 224, 224, 3)


In [14]:
# Read the tensor names from the model's own signature rather than hard-coding them:
# they vary between exported models (this encoder's input is 'args_0', not 'input').
# The first declared output is taken to be the embedding.
encoder_input, encoder_output = onnx_session.get_inputs()[0], onnx_session.get_outputs()[0]
input_name = encoder_input.name
output_name = encoder_output.name
print(input_name, output_name, sep="\n")

args_0
dense


In [15]:
# Encode the image. Only the embedding output is requested, so run() returns a
# one-element list whose single entry is a numpy array.
feeds = {input_name: img}
onnx_output = onnx_session.run([output_name], feeds)
embedding = onnx_output[0]

# Hand the embedding to PyTorch as a float32 tensor for the classifier cell.
embedding_tensor = torch.tensor(embedding, dtype=torch.float32)
print("Embedding shape:", embedding_tensor.shape)

Embedding shape: torch.Size([1, 256])


In [16]:
from model.model_definition import CurrencyClassifier
from utils.helpers import get_currency_from_label

classifier_model_path = os.path.join(SAVED_MODELS_DIR, CLASSIFIER_MODEL_NAME)

# The classifier consumes the encoder's embedding, so its input width must equal the
# embedding width produced above (256 for this encoder); derive it instead of hard-coding.
embedding_dim = embedding_tensor.shape[1]
# Second constructor argument — presumably the number of currency classes
# (TODO confirm against CurrencyClassifier); it must match the saved checkpoint.
NUM_CLASSES = 17

model = CurrencyClassifier(embedding_dim, NUM_CLASSES)
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only machine;
# the model and embedding_tensor both live on the CPU here.
# NOTE(review): since only a state dict is loaded, consider weights_only=True (torch >= 1.13).
model.load_state_dict(torch.load(classifier_model_path, map_location="cpu"))
model.eval()

with torch.no_grad():
    output = model(embedding_tensor)
    predicted_class = torch.argmax(output, dim=1)

# .item() on the single-element index tensor yields a Python number; int() makes the type explicit.
pred_label = int(predicted_class.item())

pred = get_currency_from_label(pred_label)

print("Predicted class:", pred)

Predicted class: USD
