In [97]:
import io
import torch
import torch.onnx
import onnx
from unet import UNet
import onnxruntime
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
# from utils.dataset import BasicDataset

In [98]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
pthfile = 'ckpt.test.pth'
onnxpath = 'test_conv_pool.onnx'
print("torch.__version__:", torch.__version__)
print("onnx.__version__:", onnx.__version__)
print("onnxruntime.__version__:", onnxruntime.__version__)

torch.__version__: 1.4.0
onnx.__version__: 1.10.1
onnxruntime.__version__: 1.9.0


In [99]:
onnx_model = onnx.load(onnxpath)
check = onnx.checker.check_model(onnx_model)
print("check:", check)

check: None


In [105]:
def crop_map(h, x, crop_size, mode="bilinear"):
    """
    Crops a tensor h centered around location x with size crop_size

    Inputs:
        h - (bs, F, H, W)
        x - (bs, 2) --- (x, y) locations
        crop_size - scalar integer

    Conventions for x:
        The origin is at the top-left, X is rightward, and Y is downward.
    """

    bs, _, H, W = h.size()
    Hby2 = (H - 1) / 2 if H % 2 == 1 else H // 2
    Wby2 = (W - 1) / 2 if W % 2 == 1 else W // 2
    start = -(crop_size - 1) / 2 if crop_size % 2 == 1 else -(crop_size // 2)
    end = start + crop_size - 1
    x_grid = (
        torch.arange(start, end + 1, step=1)
        .unsqueeze(0)
        .expand(crop_size, -1)
        .contiguous()
        .float()
    )
    y_grid = (
        torch.arange(start, end + 1, step=1)
        .unsqueeze(1)
        .expand(-1, crop_size)
        .contiguous()
        .float()
    )
    center_grid = torch.stack([x_grid, y_grid], dim=2).to(
        h.device
    )  # (crop_size, crop_size, 2)

    x_pos = x[:, 0] - Wby2  # (bs, )
    y_pos = x[:, 1] - Hby2  # (bs, )

    crop_grid = center_grid.unsqueeze(0).expand(
        bs, -1, -1, -1
    )  # (bs, crop_size, crop_size, 2)
    crop_grid = crop_grid.contiguous()

    # Convert the grid to (-1, 1) range
    crop_grid[:, :, :, 0] = (
        crop_grid[:, :, :, 0] + x_pos.unsqueeze(1).unsqueeze(2)
    ) / Wby2
    crop_grid[:, :, :, 1] = (
        crop_grid[:, :, :, 1] + y_pos.unsqueeze(1).unsqueeze(2)
    ) / Hby2

    h_cropped = F.grid_sample(h, crop_grid, mode=mode)

    return h_cropped

class Flatten(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.view(x.shape[0], -1)

class Global_Actor(nn.Module):
    def __init__(self, G):
        super().__init__()
        self.G = G
        self.actor = nn.Sequential(  # (8, G, G)
            nn.Conv2d(8, 8, 3, padding=1),  # (8, G, G)
            nn.BatchNorm2d(8),
            nn.ReLU(),
            nn.Conv2d(8, 4, 3, padding=1),  # (4, G, G)
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.Conv2d(4, 4, 5, padding=2),  # (4, G, G)
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.Conv2d(4, 2, 5, padding=2),  # (2, G, G)
            nn.BatchNorm2d(2),
            nn.ReLU(),
            nn.Conv2d(2, 1, 5, padding=2),  # (1, G, G)
            # nn.BatchNorm2d(1),
            Flatten(),  # (G*G, )
            # nn.Sigmoid(),  # added for non-negative
        )
        
        
    def _get_h12(self, inputs): # inputs needs to be a tensor, i.e., original inputs["map_at_t"], (bs, 4, M, M), channel 3 means one-hot pose, channel 0~1 means global map
        # x = inputs["pose_in_map_at_t"]  # (bs,2)
        # map_at_t (4, m, m)
        x = torch.nonzero(inputs[0][3]==1)
        h = inputs

        h_1 = crop_map(h, x[:, :2], self.G)
        h_2 = F.adaptive_max_pool2d(h, (self.G, self.G))

        h_12 = torch.cat([h_1, h_2], dim=1)

        return h_12

    def forward(self, inputs):
        x1 = self._get_h12(inputs)
        x2 = self.actor(x1)
        return x2

In [106]:
inputs = torch.randn(1, 4, 481, 481)

map3 = torch.zeros(1, 481, 481)
map3[0][20][20] = 1.
print(map3)
print(torch.nonzero(map3==1))
inputs[0][3] = map3
print(torch.nonzero(inputs[0][3]==1))

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])
tensor([[ 0, 20, 20]])
tensor([[20, 20]])


In [107]:
# eva ONNX model and torch model with calc same value
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

In [108]:
rl_net = Global_Actor(240)
rl_net.actor.load_state_dict(torch.load(pthfile, map_location='cpu'),strict=False)
rl_net.eval()

Global_Actor(
  (actor): Sequential(
    (0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Conv2d(8, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Conv2d(4, 4, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (7): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
    (9): Conv2d(4, 2, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (10): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU()
    (12): Conv2d(2, 1, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (13): Flatten()
  )
)

In [109]:
x = inputs
print(x.size())
torch_out = rl_net(x)
print("pth model output:", torch_out)

torch.Size([1, 4, 481, 481])
pth model output: tensor([[0.0042, 0.0042, 0.0042,  ..., 0.0042, 0.0042, 0.0042]],
       grad_fn=<ViewBackward>)


In [96]:
ort_session = onnxruntime.InferenceSession(onnxpath)

# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name:to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)

# compare ONNX Runtime and PyTorch results

print("onnx model output:", ort_outs)
print('tor_out: ', torch_out.shape)

np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNXRuntime, and the result looks good!")

InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Got invalid dimensions for input: input for the following indices
 index: 1 Got: 4 Expected: 8
 index: 2 Got: 480 Expected: 240
 index: 3 Got: 480 Expected: 240
 Please fix either the inputs or the model.

In [36]:
a = torch.randn(2,3,4)
print(a.size())
a = torch.unsqueeze(a, 1)
print(a.size())

torch.Size([2, 3, 4])
torch.Size([2, 1, 3, 4])
