In [35]:
import argparse
import glob
import numpy as np
import os
import time

import cv2
import torch
class SuperPointNet(torch.nn.Module):
    """ Pytorch definition of SuperPoint Network. """
    def __init__(self):
        super(SuperPointNet, self).__init__()
        self.relu = torch.nn.ReLU(inplace=True)
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        c1, c2, c3, c4, c5, d1 = 64, 64, 128, 128, 256, 256
        # Shared Encoder.
        self.conv1a = torch.nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
        self.conv1b = torch.nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
        self.conv2a = torch.nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
        self.conv2b = torch.nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
        self.conv3a = torch.nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
        self.conv3b = torch.nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
        self.conv4a = torch.nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
        self.conv4b = torch.nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
        # Detector Head.
        self.convPa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.convPb = torch.nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
        # Descriptor Head.
        self.convDa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.convDb = torch.nn.Conv2d(c5, d1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """ Forward pass that jointly computes unprocessed point and descriptor
        tensors.
        Input
          x: Image pytorch tensor shaped N x 1 x H x W.
        Output
          semi: Output point pytorch tensor shaped N x 65 x H/8 x W/8.
          desc: Output descriptor pytorch tensor shaped N x 256 x H/8 x W/8.
        """
        # Shared Encoder.
        x = self.relu(self.conv1a(x))
        x = self.relu(self.conv1b(x))
        x = self.pool(x)
        x = self.relu(self.conv2a(x))
        x = self.relu(self.conv2b(x))
        x = self.pool(x)
        x = self.relu(self.conv3a(x))
        x = self.relu(self.conv3b(x))
        x = self.pool(x)
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))
        # Detector Head.
        cPa = self.relu(self.convPa(x))
        semi = self.convPb(cPa)
        # Descriptor Head.
        cDa = self.relu(self.convDa(x))
        desc = self.convDb(cDa)
        dn = torch.norm(desc, p=2, dim=1) # Compute the norm.
        desc = desc.div(torch.unsqueeze(dn, 1)) # Divide by norm to normalize.
        
#         semi = torch::softmax(semi, 1);
#         semi = semi.slice(1, 0, 64);
#         semi = semi.permute({0, 2, 3, 1});  // [B, H/8, W/8, 64]
        semi = torch.softmax(semi, 1)
        semi = semi.narrow(1, 0, 64)
        semi = semi.permute((0, 2, 3, 1))

#         int Hc = semi.size(1);
#         int Wc = semi.size(2);
#         semi = semi.contiguous().view({-1, Hc, Wc, 8, 8});
#         semi = semi.permute({0, 1, 3, 2, 4});
#         semi = semi.contiguous().view({-1, Hc * 8, Wc * 8});  // [B, H, W]
        Hc = semi.size(1)
        Wc = semi.size(2)
        semi = semi.contiguous().view((-1, Hc, Wc, 8, 8))
        semi = semi.permute((0, 1, 3, 2, 4))
        semi = semi.contiguous().view((-1, Hc * 8, Wc * 8)) 
        
        print(semi.size(), desc.size())
        return semi, desc


In [None]:
model = SuperPointNet()
model.load_state_dict(torch.load("superpoint_v1.pth"))
model.eval()
x = torch.randn(1, 1, 208, 400, requires_grad=True)
model(x)

In [37]:
# Input to the model


# Export the model
torch.onnx.export(model,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "superpoint_v1.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=11,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['image'],   # the model's input names
                  output_names = ['semi', "desc"]) # the model's output names


torch.Size([1, 208, 400]) torch.Size([1, 256, 26, 50])


