In [141]:
import io
import torch
import torch.onnx
import onnx
from unet import UNet
import onnxruntime
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
# from utils.dataset import BasicDataset

In [142]:
# Files used by this notebook: the PyTorch checkpoint and the exported ONNX graph.
pthfile = 'ckpt.test.pth'
onnxpath = 'test_conv_pool.onnx'

# Prefer the first CUDA device when one is visible, otherwise stay on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Record the library versions the results below were produced with.
for _lib in (torch, onnx, onnxruntime):
    print("{}.__version__:".format(_lib.__name__), _lib.__version__)

torch.__version__: 1.4.0
onnx.__version__: 1.10.1
onnxruntime.__version__: 1.9.0


In [143]:
# Parse the exported graph from disk and run the ONNX checker on it.
onnx_model = onnx.load(onnxpath)

# check_model() signals problems by raising; its return value is None,
# so "check: None" below is the success case.
check = onnx.checker.check_model(onnx_model)
print("check: {}".format(check))

check: None


In [144]:
def crop_map(h, crop_size):
    """Crop a square window centred on the pose marked in channel 3.

    Each sample is embedded in a larger canvas (``crop_size // 2`` cells of
    border before the map, the remainder after it) so the window never leaves
    the canvas. The border is filled with 1 in channel 0 and 0 in every other
    channel.

    Args:
        h: tensor of shape (bs, ch, H, W) with ch >= 4. Channel 3 is searched
            for entries equal to 1; the first one returned by
            ``torch.nonzero`` is used as the window centre.
        crop_size: side length of the square window.

    Returns:
        Tensor of shape (bs, ch, crop_size, crop_size) with the dtype and
        device of ``h``.

    Raises:
        ValueError: if ``h`` is not 4-dimensional.
        IndexError: if ``h`` has fewer than 4 channels, or channel 3 of a
            sample has no entry equal to 1.
    """
    pose_channel = 3

    if h.dim() != 4:
        raise ValueError(
            "crop_map expects a (bs, ch, H, W) tensor, got {} dimension(s)".format(h.dim())
        )
    # Use the real input shape instead of hard-coded 1/4/481/481.
    bs, ch, H, W = h.size()
    if ch <= pose_channel:
        raise IndexError(
            "crop_map needs channel {} to hold the pose, got {} channel(s)".format(pose_channel, ch)
        )
    print(h.size())

    pad = crop_size // 2

    # Canvas inherits dtype/device from h so the result can be concatenated
    # with other tensors derived from h.
    map_tmp = h.new_zeros((bs, ch, H + crop_size, W + crop_size))
    map_tmp[:, 0] = 1  # outside-the-map cells are 1 in channel 0
    map_tmp[:, :, pad:pad + H, pad:pad + W] = h

    # Every cell is overwritten below, so no initial values are needed.
    output = h.new_empty((bs, ch, crop_size, crop_size))

    for b in range(bs):
        hits = torch.nonzero(map_tmp[b, pose_channel] == 1)
        if hits.size(0) == 0:
            raise IndexError(
                "crop_map found no entry equal to 1 in channel {} of sample {}".format(pose_channel, b)
            )
        # Kept as tensors (not Python ints), as in the original code.
        x_pos, y_pos = hits[0, 0], hits[0, 1]
        print("x_pos:{}, y_pos:{}".format(x_pos, y_pos))

        # start + crop_size equals pos + crop_size // 2 for even sizes and
        # pos + crop_size // 2 + 1 for odd sizes, so one slice covers both.
        top = x_pos - pad
        left = y_pos - pad
        output[b] = map_tmp[b, :, top:top + crop_size, left:left + crop_size]

    return output

class Flatten(nn.Module):
    """Collapse every non-batch dimension into one: (N, ...) -> (N, -1)."""

    def forward(self, x):
        batch = x.size(0)
        # view() reinterprets the existing storage; no copy is made.
        return x.view(batch, -1)

class Global_Actor(nn.Module):
    """Convolutional actor producing one score per cell of a G x G grid.

    The 8-channel network input is the concatenation of a G x G crop of the
    map around the pose (from ``crop_map``) and a 2x2 max-pooled copy of the
    whole map.
    """

    def __init__(self, G):
        super().__init__()
        self.G = G

        # (in_channels, out_channels, kernel_size) of each Conv-BN-ReLU stage.
        # padding = kernel_size // 2 keeps the spatial size at (G, G).
        hidden_stages = [(8, 8, 3), (8, 4, 3), (4, 4, 5), (4, 2, 5)]

        layers = []
        for c_in, c_out, k in hidden_stages:
            layers += [
                nn.Conv2d(c_in, c_out, k, padding=k // 2),  # (c_out, G, G)
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        # Output head: no BatchNorm and no Sigmoid, only a flatten.
        layers.append(nn.Conv2d(2, 1, 5, padding=2))  # (1, G, G)
        layers.append(Flatten())  # (G*G, )
        self.actor = nn.Sequential(*layers)

    def _get_h12(self, inputs):
        """Assemble the 8-channel network input from the raw map tensor.

        Per the original author's note, ``inputs`` is a tensor of shape
        (bs, 4, M, M) in which channel 3 is the one-hot pose and channels
        0~1 are the global map.
        """
        # Fine view: G x G window around the pose.
        local_view = crop_map(inputs, self.G)
        print(local_view.size())
        print(torch.nonzero(local_view[0][3] == 1))

        # Coarse view: the full map at half resolution
        # (the author noted adaptive_max_pool2d as an alternative).
        coarse_view = F.max_pool2d(inputs, (2, 2))

        return torch.cat([local_view, coarse_view], dim=1)

    def forward(self, inputs):
        """Return the flattened actor output for ``inputs``."""
        return self.actor(self._get_h12(inputs))

In [145]:
# Synthetic test input: four 481x481 channels of Gaussian noise.
inputs = torch.randn(1, 4, 481, 481)

# One-hot plane with a single 1 at (20, 20).
map3 = torch.zeros(1, 481, 481)
map3[0, 20, 20] = 1.0
print(map3)
print(torch.nonzero(map3 == 1))

# Replace channel 3 of the input with the one-hot plane and confirm the position.
inputs[0][3] = map3
print(torch.nonzero(inputs[0, 3] == 1))

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])
tensor([[ 0, 20, 20]])
tensor([[20, 20]])


In [146]:
# Helper for checking that the ONNX model and the torch model compute the same values.
def to_numpy(tensor):
    """Return ``tensor`` as a NumPy array on the CPU.

    A tensor that requires grad is detached first, because ``.numpy()``
    cannot be called on a tensor that is part of an autograd graph.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()

In [147]:
# Build the actor with a 240x240 crop window and restore its weights.
rl_net = Global_Actor(240)
load_result = rl_net.actor.load_state_dict(torch.load(pthfile, map_location='cpu'), strict=False)

# strict=False skips parameters whose names do not match instead of raising.
# Report them, otherwise the network may silently run on its random
# initialisation and every comparison below is meaningless.
if load_result.missing_keys or load_result.unexpected_keys:
    print("WARNING: checkpoint {!r} does not fully match rl_net.actor".format(pthfile))
    print("  missing keys (left at their initial values):", load_result.missing_keys)
    print("  unexpected keys (ignored):", load_result.unexpected_keys)

# eval() makes BatchNorm use its running statistics; it returns the module,
# so the notebook displays the architecture.
rl_net.eval()

Global_Actor(
  (actor): Sequential(
    (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(4, 4, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (7): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Conv2d(4, 2, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (10): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): Conv2d(2, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (13): Flatten()
  )
)

In [148]:
# Run the PyTorch model on the synthetic input and show what it returns.
x = inputs
print(x.shape)

torch_out = rl_net(x)
print("pth model output:", torch_out)
print(torch_out.size())

torch.Size([1, 4, 481, 481])
torch.Size([1, 4, 481, 481])
x_pos:140, y_pos:140
torch.Size([1, 4, 240, 240])
tensor([[120, 120]])
pth model output: tensor([[-0.0282, -0.0373, -0.0381,  ..., -0.0281, -0.0339, -0.0382]],
       grad_fn=<ViewBackward>)
torch.Size([1, 57600])


In [150]:
# Run the same input through ONNX Runtime.
ort_session = onnxruntime.InferenceSession(onnxpath)
input_name = ort_session.get_inputs()[0].name
ort_feed = {input_name: x.cpu().numpy().astype(np.float32)}
ort_outs = ort_session.run(None, ort_feed)

# Show the ONNX Runtime result next to the shape of the PyTorch result.
print("onnx model output:", ort_outs)
print('tor_out: ', torch_out.shape)

# The tolerance check is left disabled, as in the original cell:
# np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
# print("Exported model has been tested with ONNXRuntime, and the result looks good!")

onnx model output: [array([[ 0.00421813,  0.0104795 ,  0.02085524, ..., -0.00355027,
        -0.01217862, -0.01496383]], dtype=float32)]
tor_out:  torch.Size([1, 57600])
