In [12]:
# Imported here (not only under `__main__`) because `transform` is built at
# import time, before the script block's own imports have run; without this,
# running the file as a script raises NameError on `transforms`.
import torchvision.transforms as transforms

# Preprocessing applied to every input image before inference.
# NOTE(review): assumed to mirror the pipeline the ONNX model was trained /
# exported with — confirm against the training code.
transform = transforms.Compose([
    # Resize to exactly 224x224 (a (h, w) tuple does not preserve aspect ratio).
    transforms.Resize((224, 224)),
    # PIL image -> float tensor, channels first, values scaled to [0, 1].
    transforms.ToTensor(),
    # Per-channel normalization with the ImageNet mean/std; the 3-element
    # statistics mean the input must have exactly 3 (RGB) channels.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def preprocess_image(image_path):
    """Load an image from disk and turn it into a model-ready batch of one.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        A tensor produced by the module-level ``transform`` with a leading
        batch dimension added (shape (1, 3, 224, 224) with the current
        ``transform``).
    """
    from PIL import Image  # local import so the function works when the module is imported, not run

    # The context manager releases the file handle once the pixels are decoded.
    # Forcing RGB matters because the Normalize step carries 3-channel
    # statistics: grayscale, RGBA, palette or CMYK inputs would otherwise fail
    # or produce the wrong number of channels for the model.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    image_tensor = transform(image)
    return image_tensor.unsqueeze(0)  # add batch dimension -> (1, C, H, W)

def load_model(onnx_model_path, providers=None):
    """Create an ONNX Runtime inference session for a model file.

    Args:
        onnx_model_path: Path to the ``.onnx`` model.
        providers: Optional list of execution providers, e.g.
            ``["CPUExecutionProvider"]``. ``None`` (the default) leaves the
            choice to ONNX Runtime, exactly as before this parameter existed.

    Returns:
        An ``onnxruntime.InferenceSession`` ready for ``run``.
    """
    import onnxruntime as ort  # local import so the function works when the module is imported, not run

    # Exposed because ONNX Runtime builds with several providers enabled
    # (e.g. GPU builds) may require `providers` to be set explicitly.
    return ort.InferenceSession(onnx_model_path, providers=providers)


def predict(model, image_tensor):
    """Run one inference pass and return every model output.

    Args:
        model: An ONNX Runtime session (anything with ``get_inputs`` and ``run``).
        image_tensor: A torch tensor; it is converted with ``.numpy()``
            because ONNX Runtime is fed NumPy arrays.

    Returns:
        The list returned by ``model.run``; ``None`` as the output-name
        argument requests all outputs.
    """
    feed_name = model.get_inputs()[0].name  # the model's first declared input
    feed = {feed_name: image_tensor.numpy()}
    return model.run(None, feed)

if __name__ == "__main__":
    # Script entry point: classify a single image with the exported ONNX model.
    import torchvision.transforms as transforms
    from PIL import Image
    import numpy as np
    import onnxruntime as ort

    image_path = "viol.jfif"  # image to classify
    onnx_model_path = "resnet50_3e4_10_secondTry2__epoch_20_accuracy_test_97.4265.onnx"

    # Output index -> flower name.
    # NOTE(review): assumed to match the label order used in training — confirm.
    class_labels = [
        "annual mallow", "asian virginsbower", "barbados lily", "bull thistle", "buttercup",
        "california poppies", "calla lily", "canna lily", "coltsfoot", "common columbine",
        "common cornflag", "common daisy", "common dandelion", "common primroses", "corn poppy",
        "desert rose", "fritillaries", "garden petunia", "passionflower", "peruvian lily",
        "scarlet beebalm", "sunflower", "tea roses", "tiger lily", "violets", "wallflowers",
        "water lilies",
    ]

    image_batch = preprocess_image(image_path)
    session = load_model(onnx_model_path)
    # First model output; presumably raw class scores of shape (1, num_classes).
    logits = predict(session, image_batch)[0]

    # argmax over the flattened array: with a batch of one this is the class index.
    best_index = np.argmax(logits)
    if best_index >= len(class_labels):
        print("Predicted class index is out of range!")
    else:
        print(f"Predicted class label: {class_labels[best_index]}")

Predicted class label: violets
