In [2]:
import onnxruntime as ort
# Relative path to the exported ONNX classifier; it is resolved against the
# current working directory of the kernel.
model_path = "CSOL-Utilities-resnet18-800x600.onnx"
# Create the inference session once; predict() below reads it through the
# module-level name `session`.
session = ort.InferenceSession(model_path)

In [3]:
from torchvision import transforms
# Target size handed to Resize below, in pixels.
width, height = 800, 600

# Preprocessing applied to every PIL image before inference.
transform = transforms.Compose(
    [
        # Add a 20-pixel constant (zero-valued) border on every side.
        transforms.Pad(20, fill=0, padding_mode="constant"),
        # Resize takes (height, width), in that order.
        transforms.Resize(size=(height, width)),
        # PIL image -> float tensor.
        transforms.ToTensor(),
        # Per-channel normalisation with the given mean / std.
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

In [4]:
from PIL import Image
import numpy as np
import torch

def predict(image_path, threshold=0.95):
    """Classify one screenshot with the ONNX model and print the result.

    The image is loaded as RGB, passed through the module-level ``transform``
    pipeline and fed to the module-level ONNX Runtime ``session``. The first
    model output is turned into probabilities with softmax; one line per label
    is printed, followed by a final ``Result:`` line.

    Args:
        image_path: Anything accepted by ``PIL.Image.open`` (typically a path).
        threshold: Minimum probability the top class must reach to be reported.
            Below it, the result line reads "未知" (unknown). Defaults to 0.95,
            the value that used to be hard-coded.

    Returns:
        None. All results are written to stdout.

    Raises:
        ValueError: If the model output is not 2-D with exactly one score per
            entry of the label list.
    """
    labels = [
        "大厅界面",
        "等候室界面",
        "加载界面",
        "场景界面",
        "结算界面",
    ]

    # Open through a context manager so the underlying file is released
    # deterministically instead of depending on PIL's lazy loading.
    with Image.open(image_path) as image:
        input_tensor = transform(image.convert("RGB")).unsqueeze(0)

    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    outputs = session.run([output_name], {input_name: input_tensor.numpy()})

    logits = torch.tensor(outputs[0])
    # Fail with a clear message when the model and the label list disagree,
    # rather than an IndexError or probabilities that silently skip classes.
    if logits.ndim != 2 or logits.shape[1] != len(labels):
        raise ValueError(
            f"model output has shape {tuple(logits.shape)}, "
            f"expected (batch, {len(labels)}) to match the label list"
        )

    # Softmax over the class axis; only the first (and only) batch row is used.
    probabilities = torch.nn.functional.softmax(logits, dim=1)
    for label, probability in zip(labels, probabilities[0].tolist()):
        print(f"{label}: {probability:.4f}")

    predicted_index = int(torch.argmax(probabilities[0]))
    predicted_value = probabilities[0, predicted_index].item()
    if predicted_value < threshold:
        print("Result: 未知")
    else:
        print(f"Result: {labels[predicted_index]}")

In [5]:
# NOTE(review): absolute path on one user's machine; point this at your own
# screenshot before re-running the cell.
p = "C:/Users/Silver/BaiduSyncdisk/Pictures/Screen Shots/2025/2025-09-08_21-48-45_497.png"
# Prints one probability line per class followed by the final "Result:" line.
predict(p)

大厅界面: 0.0001
等候室界面: 0.0000
加载界面: 0.9999
场景界面: 0.0000
结算界面: 0.0000
Result: 加载界面
