In [26]:
import io
import torch
import torch.onnx
import onnx
from unet import UNet
import onnxruntime
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
# from utils.dataset import BasicDataset

In [27]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
onnxpath = 'test_classsimple.onnx'
print("torch.__version__:", torch.__version__)
print("onnx.__version__:", onnx.__version__)
print("onnxruntime.__version__:", onnxruntime.__version__)

torch.__version__: 1.4.0
onnx.__version__: 1.10.1
onnxruntime.__version__: 1.9.0


In [28]:
onnx_model = onnx.load(onnxpath)
check = onnx.checker.check_model(onnx_model)
print("check:", check)

check: None


In [29]:
input = torch.ones(1,1,481,481)
print(type(onnx_model))

<class 'onnx.onnx_ml_pb2.ModelProto'>


In [30]:
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

In [31]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        
    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2,2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, self.num_flat_features(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
    # num_flat_features：计算张量x的总特征量（把每个数字都看出是一个特征，即特征总量）
    def num_flat_features(self, x):
        size = x.size()[1:]  # all dimensions except the batch dimension(批大小维度)
        num_features = 1
        for s in size:
            num_features *= s
        return num_features
net = Net()
net.load_state_dict(torch.load('net_params.pth'))
net.eval()

Net(
  (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [32]:
ort_session = onnxruntime.InferenceSession(onnxpath)
ort_x = torch.ones(1, 4, 481, 481, requires_grad=True)
print(ort_x)

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(ort_x)}
ort_outs = ort_session.run(None, ort_inputs)

torch_out = net(ort_x)
print(torch_out)
print(ort_outs)
# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

2021-11-29 15:59:46.275300463 [W:onnxruntime:, graph.cc:3391 CleanUnusedInitializers] Removing initializer 'actor.1.num_batches_tracked'. It is not used by any node and should be removed from the model.
2021-11-29 15:59:46.275604459 [W:onnxruntime:, graph.cc:3391 CleanUnusedInitializers] Removing initializer 'actor.10.num_batches_tracked'. It is not used by any node and should be removed from the model.
2021-11-29 15:59:46.275613407 [W:onnxruntime:, graph.cc:3391 CleanUnusedInitializers] Removing initializer 'actor.7.num_batches_tracked'. It is not used by any node and should be removed from the model.
2021-11-29 15:59:46.275617494 [W:onnxruntime:, graph.cc:3391 CleanUnusedInitializers] Removing initializer 'actor.4.num_batches_tracked'. It is not used by any node and should be removed from the model.


tensor([[[[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.]],

         [[1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          [1., 1., 1.,  ..., 1., 1., 1.],
          ...,
          [1., 1., 1.,  ..., 1., 1., 1.],
          

RuntimeError: Given groups=1, weight of size 6 1 5 5, expected input[1, 4, 481, 481] to have 1 channels, but got 4 channels instead