In [76]:
import torch
import torch.nn as nn
import torch.nn.functional as F
# !pip3 install onnx
import onnx
# !pip3 install onnx-simplifier
import onnxsim
import numpy as np

In [3]:
class SuperPointNet(torch.nn.Module):
  """SuperPoint network: a shared convolutional encoder feeding two heads.

  The encoder is four pairs of 3x3 conv + ReLU layers with a 2x2 max-pool
  between consecutive pairs (three pools in total). The detector head emits
  65 channels; the descriptor head emits 256 channels that are L2-normalised
  along the channel axis.
  """

  def __init__(self):
    super().__init__()
    self.relu = torch.nn.ReLU(inplace=True)
    self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
    c1, c2, c3, c4, c5, d1 = 64, 64, 128, 128, 256, 256

    def conv3x3(n_in, n_out):
      # 3x3 convolution with padding 1, so height and width are preserved.
      return torch.nn.Conv2d(n_in, n_out, kernel_size=3, stride=1, padding=1)

    def conv1x1(n_in, n_out):
      # Pointwise convolution used as the final layer of each head.
      return torch.nn.Conv2d(n_in, n_out, kernel_size=1, stride=1, padding=0)

    # Shared encoder. Attribute names double as state-dict keys, so they must
    # stay as they are for existing checkpoints to load.
    self.conv1a = conv3x3(1, c1)
    self.conv1b = conv3x3(c1, c1)
    self.conv2a = conv3x3(c1, c2)
    self.conv2b = conv3x3(c2, c2)
    self.conv3a = conv3x3(c2, c3)
    self.conv3b = conv3x3(c3, c3)
    self.conv4a = conv3x3(c3, c4)
    self.conv4b = conv3x3(c4, c4)
    # Detector head.
    self.convPa = conv3x3(c4, c5)
    self.convPb = conv1x1(c5, 65)
    # Descriptor head.
    self.convDa = conv3x3(c4, c5)
    self.convDb = conv1x1(c5, d1)

  def forward(self, x):
    """ Compute the raw detector map and the normalised descriptor map.
    Input
      x: Image pytorch tensor shaped N x 1 x H x W.
    Output
      semi: Output point pytorch tensor shaped N x 65 x H/8 x W/8.
      desc: Output descriptor pytorch tensor shaped N x 256 x H/8 x W/8.
    """
    encoder_stages = (
        (self.conv1a, self.conv1b),
        (self.conv2a, self.conv2b),
        (self.conv3a, self.conv3b),
        (self.conv4a, self.conv4b),
    )
    feat = x
    for stage_idx, (conv_a, conv_b) in enumerate(encoder_stages):
      # Every stage except the first starts with a 2x2 max-pool.
      if stage_idx > 0:
        feat = self.pool(feat)
      feat = self.relu(conv_a(feat))
      feat = self.relu(conv_b(feat))

    # Detector head.
    semi = self.convPb(self.relu(self.convPa(feat)))

    # Descriptor head, followed by per-location L2 normalisation over channels.
    desc = self.convDb(self.relu(self.convDa(feat)))
    channel_norm = torch.norm(desc, p=2, dim=1).unsqueeze(1)
    desc = desc.div(channel_norm)
    return semi, desc

In [4]:
# Count the learnable parameters of the original SuperPoint architecture.
model = SuperPointNet()
# Tensor.numel() returns the element count directly as a Python int, so no
# temporary list or NumPy round-trip is needed.
num_params = sum(p.numel() for p in model.parameters())
num_params

1300865

In [63]:
class double_conv(nn.Module):
    '''(conv => BN => ReLU) * 2

    Two stacked 3x3 convolutions with padding 1 (height and width are kept),
    each followed by batch normalisation and an in-place ReLU.

    Args:
        in_ch: number of channels of the input feature map.
        out_ch: number of channels produced by both convolutions.
    '''
    def __init__(self, in_ch, out_ch):
        super(double_conv, self).__init__()

        # Layer order fixes the state-dict keys (conv.0 ... conv.4), so it must
        # not change if existing checkpoints are to keep loading.
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        '''Apply both conv/BN/ReLU stages to `x` and return the result.'''
        return self.conv(x)


class inconv(nn.Module):
    '''Input stem: one double_conv block applied without any pooling.'''

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        '''Return the double_conv output for `x`.'''
        return self.conv(x)


class down(nn.Module):
    '''Downsampling stage: 2x2 max-pool followed by a double_conv block.'''

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = [nn.MaxPool2d(2), double_conv(in_ch, out_ch)]
        self.mpconv = nn.Sequential(*stages)

    def forward(self, x):
        '''Pool `x` by a factor of two, then run the double_conv block.'''
        return self.mpconv(x)


In [64]:
class SuperPointNet_gauss2(torch.nn.Module):
    """SuperPoint variant with batch normalisation in the encoder and heads.

    The encoder is an `inconv` stem followed by three `down` stages; both
    heads apply conv -> BN -> ReLU -> conv -> BN.

    Args:
        subpixel_channel: accepted for signature compatibility; not used here.
    """

    def __init__(self, subpixel_channel=1):
        super().__init__()
        c1, c2, c3, c4, c5, d1 = 64, 64, 128, 128, 256, 256
        det_h = 65  # number of detector output channels

        # Encoder: full-resolution stem plus three pooled stages.
        self.inc = inconv(1, c1)
        self.down1 = down(c1, c2)
        self.down2 = down(c2, c3)
        self.down3 = down(c3, c4)
        self.relu = nn.ReLU(inplace=True)

        # Detector head.
        self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.bnPa = nn.BatchNorm2d(c5)
        self.convPb = nn.Conv2d(c5, det_h, kernel_size=1, stride=1, padding=0)
        self.bnPb = nn.BatchNorm2d(det_h)
        # Descriptor head.
        self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.bnDa = nn.BatchNorm2d(c5)
        self.convDb = nn.Conv2d(c5, d1, kernel_size=1, stride=1, padding=0)
        self.bnDb = nn.BatchNorm2d(d1)
        # Holds the dict returned by the most recent forward() call.
        self.output = None

    def forward(self, x):
        """ Run the encoder and both heads.
        Input
          x: Image pytorch tensor shaped N x 1 x patch_size x patch_size.
        Output
          dict with two entries:
          'semi': point tensor shaped N x 65 x H/8 x W/8.
          'desc': descriptor tensor shaped N x 256 x H/8 x W/8, L2-normalised
                  along the channel axis.
          The same dict is also stored on `self.output`.
        """
        # Encoder (BN first, then ReLU, inside each block).
        features = self.inc(x)
        for stage in (self.down1, self.down2, self.down3):
            features = stage(features)

        # Detector head.
        det_hidden = self.relu(self.bnPa(self.convPa(features)))
        semi = self.bnPb(self.convPb(det_hidden))
        # Descriptor head.
        desc_hidden = self.relu(self.bnDa(self.convDa(features)))
        desc = self.bnDb(self.convDb(desc_hidden))

        # Divide each descriptor by its L2 norm over channels.
        channel_norm = torch.norm(desc, p=2, dim=1).unsqueeze(1)
        desc = desc.div(channel_norm)

        self.output = {'semi': semi, 'desc': desc}
        return self.output


In [65]:
# Count the learnable parameters of the BatchNorm SuperPoint variant.
model = SuperPointNet_gauss2()
# Tensor.numel() returns the element count directly as a Python int, so no
# temporary list or NumPy round-trip is needed.
num_params = sum(p.numel() for p in model.parameters())
num_params

1304067

*MobileNet v2*

In [8]:
class InvertedResidual(nn.Module):
    """Inverted-residual block: (pointwise expand) -> depthwise -> pointwise linear.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise 3x3 convolution; must be 1 or 2.
        expand_ratio: the hidden width is round(inp * expand_ratio). When it
            equals 1 the initial pointwise expansion is skipped.
        use_batch_norm: insert a BatchNorm2d after every convolution.
        onnx_compatible: use nn.ReLU instead of nn.ReLU6 as the activation.

    The input is added to the output when stride == 1 and inp == oup.
    """

    def __init__(self, inp, oup, stride, expand_ratio, use_batch_norm=True, onnx_compatible=False):
        super().__init__()
        activation = nn.ReLU if onnx_compatible else nn.ReLU6

        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        def stage(conv, channels, activate):
            # One conv optionally followed by BN and/or the activation, in
            # exactly that order so the Sequential indices stay stable.
            modules = [conv]
            if use_batch_norm:
                modules.append(nn.BatchNorm2d(channels))
            if activate:
                modules.append(activation(inplace=True))
            return modules

        layers = []
        if expand_ratio != 1:
            # pw
            layers += stage(nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False), hidden_dim, True)
        # dw
        layers += stage(
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
            hidden_dim,
            True,
        )
        # pw-linear (no activation after the projection)
        layers += stage(nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), oup, False)
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block, adding the skip connection when it is enabled."""
        out = self.conv(x)
        if self.use_res_connect:
            return x + out
        return out



class double_conv_mobilenet(nn.Module):
    '''(InvertedResidual => ReLU) * 2

    Both blocks use stride 1, batch normalisation and the ONNX-compatible
    (plain ReLU) activation; the second block keeps the channel count.
    '''
    def __init__(self, in_ch, out_ch, expand_ratio):
        super().__init__()

        # Settings shared by both MobileNet v2 blocks.
        block_kwargs = dict(expand_ratio=expand_ratio, use_batch_norm=True, onnx_compatible=True)
        self.conv = nn.Sequential(
            InvertedResidual(in_ch, out_ch, 1, **block_kwargs),
            nn.ReLU(inplace=True),
            InvertedResidual(out_ch, out_ch, 1, **block_kwargs),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        '''Run both inverted-residual blocks on `x`.'''
        return self.conv(x)


In [9]:
class SuperPointNet_mobilenet(torch.nn.Module):
    """SuperPoint variant whose encoder is built from MobileNet v2 blocks.

    The encoder is a `double_conv_mobilenet` stem (expand ratio 1) followed by
    three max-pooled `double_conv_mobilenet` stages (expand ratio 2). The
    heads apply conv -> BN -> ReLU -> conv -> BN.

    Args:
        subpixel_channel: accepted for signature compatibility; not used here.
    """

    def __init__(self, subpixel_channel=1):
        super().__init__()
        c1, c2, c3, c4, c5, d1 = 64, 64, 128, 128, 256, 256
        det_h = 65  # number of detector output channels

        def pooled_stage(in_ch, out_ch):
            # 2x2 max-pool followed by a MobileNet v2 double block.
            return nn.Sequential(
                nn.MaxPool2d(2),
                double_conv_mobilenet(in_ch, out_ch, expand_ratio=2),
            )

        self.inc = double_conv_mobilenet(1, c1, expand_ratio=1)
        self.down1 = pooled_stage(c1, c2)
        self.down2 = pooled_stage(c2, c3)
        self.down3 = pooled_stage(c3, c4)

        self.relu = nn.ReLU(inplace=True)
        # Detector head.
        self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.bnPa = nn.BatchNorm2d(c5)
        self.convPb = nn.Conv2d(c5, det_h, kernel_size=1, stride=1, padding=0)
        self.bnPb = nn.BatchNorm2d(det_h)
        # Descriptor head.
        self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.bnDa = nn.BatchNorm2d(c5)
        self.convDb = nn.Conv2d(c5, d1, kernel_size=1, stride=1, padding=0)
        self.bnDb = nn.BatchNorm2d(d1)
        # Holds the dict returned by the most recent forward() call.
        self.output = None

    def forward(self, x):
        """ Run the encoder and both heads.
        Input
          x: Image pytorch tensor shaped N x 1 x patch_size x patch_size.
        Output
          dict with two entries:
          'semi': point tensor shaped N x 65 x H/8 x W/8.
          'desc': descriptor tensor shaped N x 256 x H/8 x W/8, L2-normalised
                  along the channel axis.
          The same dict is also stored on `self.output`.
        """
        features = self.inc(x)
        for stage in (self.down1, self.down2, self.down3):
            features = stage(features)

        # Detector head.
        det_hidden = self.relu(self.bnPa(self.convPa(features)))
        semi = self.bnPb(self.convPb(det_hidden))
        # Descriptor head.
        desc_hidden = self.relu(self.bnDa(self.convDa(features)))
        desc = self.bnDb(self.convDb(desc_hidden))

        # Divide each descriptor by its L2 norm over channels.
        channel_norm = torch.norm(desc, p=2, dim=1).unsqueeze(1)
        desc = desc.div(channel_norm)

        self.output = {'semi': semi, 'desc': desc}
        return self.output

In [10]:
# Count the learnable parameters of the MobileNet v2 SuperPoint variant.
model = SuperPointNet_mobilenet()
# Tensor.numel() returns the element count directly as a Python int, so no
# temporary list or NumPy round-trip is needed.
num_params = sum(p.numel() for p in model.parameters())
num_params

949838

SqueezeNet


In [11]:
class Fire(nn.Module):
    """Fire module with batch normalisation after every convolution.

    A 1x1 "squeeze" convolution reduces the input to `squeeze_planes`
    channels; two parallel "expand" branches (1x1 and 3x3 with padding 1) are
    then applied to it and concatenated along the channel axis, so the output
    has expand1x1_planes + expand3x3_planes channels.
    """

    def __init__(self, inplanes, squeeze_planes,
                 expand1x1_planes, expand3x3_planes):
        super().__init__()
        self.inplanes = inplanes

        # Squeeze: conv -> BN -> ReLU.
        self.squeeze = nn.Conv2d(inplanes, squeeze_planes, kernel_size=1)
        self.squeeze_bn = nn.BatchNorm2d(squeeze_planes)
        self.squeeze_activation = nn.ReLU(inplace=True)

        # 1x1 expand branch.
        self.expand1x1 = nn.Conv2d(squeeze_planes, expand1x1_planes, kernel_size=1)
        self.expand1x1_bn = nn.BatchNorm2d(expand1x1_planes)
        self.expand1x1_activation = nn.ReLU(inplace=True)

        # 3x3 expand branch (padding keeps the spatial size).
        self.expand3x3 = nn.Conv2d(squeeze_planes, expand3x3_planes, kernel_size=3, padding=1)
        self.expand3x3_bn = nn.BatchNorm2d(expand3x3_planes)
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        """Squeeze `x`, run both expand branches and concatenate their outputs."""
        squeezed = self.squeeze_activation(self.squeeze_bn(self.squeeze(x)))
        branch_1x1 = self.expand1x1_activation(self.expand1x1_bn(self.expand1x1(squeezed)))
        branch_3x3 = self.expand3x3_activation(self.expand3x3_bn(self.expand3x3(squeezed)))
        return torch.cat([branch_1x1, branch_3x3], 1)

In [12]:
class SuperPointNet_squeezenet(torch.nn.Module):
    """SuperPoint variant whose downsampling stages are built from Fire modules.

    The stem is two 3x3 conv -> BN -> ReLU layers; each of the three
    downsampling stages is a 2x2 max-pool followed by two Fire modules. The
    heads apply conv -> BN -> ReLU -> conv -> BN.

    Args:
        subpixel_channel: accepted for signature compatibility; not used here.
    """

    def __init__(self, subpixel_channel=1):
        super().__init__()
        c1, c2, c3, c4, c5, d1 = 64, 64, 128, 128, 256, 256
        det_h = 65  # number of detector output channels

        def fire_stage(in_ch, squeeze_ch, out_ch):
            # Max-pool, then two Fire modules whose two expand branches each
            # produce half of the output channels.
            half = out_ch // 2
            return nn.Sequential(
                nn.MaxPool2d(2),
                Fire(in_ch, squeeze_ch, half, half),
                Fire(out_ch, squeeze_ch, half, half),
            )

        self.inc = nn.Sequential(
            nn.Conv2d(1, c1, kernel_size=3, padding=1),
            nn.BatchNorm2d(c1),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1, c1, kernel_size=3, padding=1),
            nn.BatchNorm2d(c1),
            nn.ReLU(inplace=True),
        )
        self.down1 = fire_stage(c1, 16, c2)
        self.down2 = fire_stage(c2, 32, c3)
        self.down3 = fire_stage(c3, 48, c4)

        self.relu = nn.ReLU(inplace=True)
        # Detector head.
        self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.bnPa = nn.BatchNorm2d(c5)
        self.convPb = nn.Conv2d(c5, det_h, kernel_size=1, stride=1, padding=0)
        self.bnPb = nn.BatchNorm2d(det_h)
        # Descriptor head.
        self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.bnDa = nn.BatchNorm2d(c5)
        self.convDb = nn.Conv2d(c5, d1, kernel_size=1, stride=1, padding=0)
        self.bnDb = nn.BatchNorm2d(d1)
        # Holds the dict returned by the most recent forward() call.
        self.output = None

    def forward(self, x):
        """ Run the encoder and both heads.
        Input
          x: Image pytorch tensor shaped N x 1 x patch_size x patch_size.
        Output
          dict with two entries:
          'semi': point tensor shaped N x 65 x H/8 x W/8.
          'desc': descriptor tensor shaped N x 256 x H/8 x W/8, L2-normalised
                  along the channel axis.
          The same dict is also stored on `self.output`.
        """
        features = self.inc(x)
        for stage in (self.down1, self.down2, self.down3):
            features = stage(features)

        # Detector head.
        det_hidden = self.relu(self.bnPa(self.convPa(features)))
        semi = self.bnPb(self.convPb(det_hidden))
        # Descriptor head.
        desc_hidden = self.relu(self.bnDa(self.convDa(features)))
        desc = self.bnDb(self.convDb(desc_hidden))

        # Divide each descriptor by its L2 norm over channels.
        channel_norm = torch.norm(desc, p=2, dim=1).unsqueeze(1)
        desc = desc.div(channel_norm)

        self.output = {'semi': semi, 'desc': desc}
        return self.output

In [13]:
# Count the learnable parameters of the SqueezeNet SuperPoint variant.
model = SuperPointNet_squeezenet()
# Tensor.numel() returns the element count directly as a Python int, so no
# temporary list or NumPy round-trip is needed.
num_params = sum(p.numel() for p in model.parameters())
num_params

847939

ResNet-18

In [77]:
from models.resnet import resnet18

# from models.SubpixelNet import SubpixelNet
class SuperPointNet_resnet18(torch.nn.Module):
    """SuperPoint variant that uses the project's `resnet18` as its encoder.

    The backbone output is fed straight into both heads, whose first
    convolutions take 128 input channels. Each head applies
    conv -> BN -> ReLU -> conv -> BN.

    Args:
        subpixel_channel: accepted for signature compatibility; not used here.
    """

    def __init__(self, subpixel_channel=1):
        super().__init__()
        c4, c5, d1 = 128, 256, 256
        det_h = 65  # number of detector output channels

        self.feature = resnet18()
        self.relu = nn.ReLU(inplace=True)
        # Detector head.
        self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.bnPa = nn.BatchNorm2d(c5)
        self.convPb = nn.Conv2d(c5, det_h, kernel_size=1, stride=1, padding=0)
        self.bnPb = nn.BatchNorm2d(det_h)
        # Descriptor head.
        self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.bnDa = nn.BatchNorm2d(c5)
        self.convDb = nn.Conv2d(c5, d1, kernel_size=1, stride=1, padding=0)
        self.bnDb = nn.BatchNorm2d(d1)
        # Holds the dict returned by the most recent forward() call.
        self.output = None

    def forward(self, x):
        """ Run the backbone and both heads.
        Input
          x: Image pytorch tensor shaped N x 1 x patch_size x patch_size.
        Output
          dict with two entries:
          'semi': point tensor shaped N x 65 x H/8 x W/8.
          'desc': descriptor tensor shaped N x 256 x H/8 x W/8, L2-normalised
                  along the channel axis.
          The same dict is also stored on `self.output`.
        """
        features = self.feature(x)

        # Detector head.
        det_hidden = self.relu(self.bnPa(self.convPa(features)))
        semi = self.bnPb(self.convPb(det_hidden))
        # Descriptor head.
        desc_hidden = self.relu(self.bnDa(self.convDa(features)))
        desc = self.bnDb(self.convDb(desc_hidden))

        # Divide each descriptor by its L2 norm over channels.
        channel_norm = torch.norm(desc, p=2, dim=1).unsqueeze(1)
        desc = desc.div(channel_norm)

        self.output = {'semi': semi, 'desc': desc}
        return self.output


In [78]:
# Count the learnable parameters of the ResNet-18 SuperPoint variant.
model = SuperPointNet_resnet18()
# Tensor.numel() returns the element count directly as a Python int, so no
# temporary list or NumPy round-trip is needed.
num_params = sum(p.numel() for p in model.parameters())
num_params

1942147

Export to ONNX


In [42]:
# Load the pretrained weights from ./pretrained/superpoint_v1.pth into the
# reference architecture.
net1 = SuperPointNet()
# map_location pins every tensor to the CPU, as the other checkpoint-loading
# cells in this notebook do, so the file also loads on a machine without CUDA.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
state_dict1 = torch.load('./pretrained/superpoint_v1.pth', map_location=torch.device('cpu'))
net1.load_state_dict(state_dict1)

<All keys matched successfully>

In [66]:
# Build the BatchNorm SuperPoint variant and restore its weights from the
# training checkpoint; map_location keeps all tensors on the CPU and eval()
# switches the module to inference mode.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
# NOTE: `net` is rebound by every checkpoint-loading cell, so later cells use
# whichever of them was executed last.
net = SuperPointNet_gauss2()
checkpoint = torch.load('./logs/superpoint_coco_gauss/checkpoints/superPointNet_150000_checkpoint.pth.tar', map_location=torch.device('cpu'))
net.load_state_dict(checkpoint['model_state_dict'])
net.eval()

SuperPointNet_gauss2(
  (inc): inconv(
    (conv): double_conv(
      (conv): Sequential(
        (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
  )
  (down1): down(
    (mpconv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): double_conv(
        (conv): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4)

In [41]:
# Build the MobileNet v2 SuperPoint variant and restore its weights from the
# training checkpoint; map_location keeps all tensors on the CPU and eval()
# switches the module to inference mode.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
# NOTE: `net` is rebound by every checkpoint-loading cell, so later cells use
# whichever of them was executed last.
net = SuperPointNet_mobilenet()
checkpoint = torch.load('./logs/superpoint_coco_mobilenet_v2/checkpoints/superPointNet_150000_checkpoint.pth.tar', map_location=torch.device('cpu'))
net.load_state_dict(checkpoint['model_state_dict'])
net.eval()

SuperPointNet_mobilenet(
  (inc): double_conv_mobilenet(
    (conv): Sequential(
      (0): InvertedResidual(
        (conv): Sequential(
          (0): Conv2d(1, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(1, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): ReLU(inplace=True)
      (2): InvertedResidual(
        (conv): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, a

In [52]:
# Build the SqueezeNet SuperPoint variant and restore its weights from the
# training checkpoint; map_location keeps all tensors on the CPU and eval()
# switches the module to inference mode.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
# NOTE: `net` is rebound by every checkpoint-loading cell, so later cells use
# whichever of them was executed last.
net = SuperPointNet_squeezenet()
checkpoint = torch.load('./logs/superpoint_coco_squeezenet/checkpoints/superPointNet_150000_checkpoint.pth.tar', map_location=torch.device('cpu'))
net.load_state_dict(checkpoint['model_state_dict'])
net.eval()

SuperPointNet_squeezenet(
  (inc): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
  )
  (down1): Sequential(
    (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (1): Fire(
      (squeeze): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1))
      (squeeze_bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (squeeze_activation): ReLU(inplace=True)
      (expand1x1): Conv2d(16, 32, kernel_size=(1, 1), stride=(1, 1))
      (expand1x1_bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (expand1x1_activation): ReLU(inplace=True)
      (exp

In [79]:
# Build the ResNet-18 SuperPoint variant and restore its weights from the
# training checkpoint; map_location keeps all tensors on the CPU and eval()
# switches the module to inference mode.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
# NOTE: `net` is rebound by every checkpoint-loading cell, so later cells use
# whichever of them was executed last.
net = SuperPointNet_resnet18()
checkpoint = torch.load('./logs/superpoint_coco_resnet18/checkpoints/superPointNet_150000_checkpoint.pth.tar', map_location=torch.device('cpu'))
net.load_state_dict(checkpoint['model_state_dict'])
net.eval()

SuperPointNet_resnet18(
  (feature): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [81]:
# Run the loaded network on an all-ones 1 x 1 x 120 x 392 dummy image; the same
# tensor is later passed to torch.onnx.export as the example input.
x = torch.ones(1, 1, 120, 392, requires_grad=True)
torch_out = net(x)
# Row 0 of detector channel 0 for the first image, displayed as a reference
# value for the exported model.
torch_out['semi'][0,0,0]

tensor([7.1143, 7.1408, 7.4010, 7.3743, 7.1654, 7.0392, 6.9852, 6.8115, 6.7343,
        6.6291, 6.5956, 6.5610, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692,
        6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692,
        6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692,
        6.5692, 6.5692, 6.5541, 6.5457, 6.5599, 6.6784, 6.8644, 7.1354, 7.3488,
        7.6130, 7.8650, 7.4608, 6.3813], grad_fn=<SelectBackward0>)

In [82]:
# Export the currently loaded `net` to ONNX using the dummy input `x`.
# forward() returns a dict with 'semi' and 'desc'; the two output_names are
# presumably assigned in that order — TODO confirm against the exported graph.
# Only axes 2 and 3 of the input are declared dynamic; the batch and channel
# axes are not listed in dynamic_axes.
# NOTE(review): the output file name is hard-coded to the ResNet-18 variant, so
# it is only accurate when the ResNet-18 loading cell was the last one run.
torch.onnx.export(net,               # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "sp_resnet18_b1.onnx",   # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=11,          # the ONNX version to export the model to
                  training = torch.onnx.TrainingMode.EVAL,
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output_det', 'output_desc'],
                  dynamic_axes={
                      "input":{
                          2: "Height",
                          3: "Width"
                      }
                  }
)

In [83]:

# Checks: reload the exported file from disk and validate it with the ONNX
# checker before simplifying it.
model_onnx = onnx.load("sp_resnet18_b1.onnx")  # load onnx model
onnx.checker.check_model(model_onnx)  # check onnx model
# print(onnx.helper.printable_graph(model_onnx.graph))  # print

In [84]:
# Simplify the exported graph with onnx-simplifier, supplying a
# 1 x 1 x 120 x 392 shape for the 'input' tensor, and save the result under a
# separate file name. The assert stops the cell if the simplifier's own check
# reports a failure.
# NOTE(review): `dynamic_input_shape` / `input_shapes` may have been renamed or
# removed in newer onnx-simplifier releases — confirm against the installed version.
model_onnx_simplified, check = onnxsim.simplify(
    model_onnx,
    dynamic_input_shape=True,
    input_shapes={'input': [1,1,120,392]})
assert check, 'assert check failed'
onnx.save(model_onnx_simplified, "sp_resnet18_b1_simplified.onnx")

# model_onnx, check = onnxsim.simplify(
#     model_onnx,
#     dynamic_input_shape=False)
# assert check, 'assert check failed'
# onnx.save(model_onnx, "superpoint_pretrained_b2.onnx")

In [26]:
import cv2
cv2.__version__

'4.5.4'

In [85]:
# Load the simplified ONNX model with OpenCV's DNN module. The commented-out
# lines point at alternative model files.
# superpoint = cv2.dnn.readNetFromONNX("./onnx/sp-gauss-sparse-loss/superpoint_gauss_sparse_loss_simplified.onnx")
# superpoint = cv2.dnn.readNetFromONNX("./onnx/sp-gauss-sparse-loss/superpoint_pretrained_b1.onnx")
superpoint = cv2.dnn.readNetFromONNX("./sp_resnet18_b1_simplified.onnx")


In [37]:
# Build the OpenCV input from an all-ones 120 x 392 float32 image, i.e. the
# same content as the dummy tensor used on the PyTorch side. The commented-out
# lines read a real grayscale image instead.
img = np.ones((120, 392), dtype=np.float32)
# img = cv2.imread("assets/icl_snippet/250.png",cv2.IMREAD_GRAYSCALE)
# img = cv2.imread("image_00/data/0000000010.png", cv2.IMREAD_GRAYSCALE)
# Sizes passed to OpenCV are (width, height), hence (392, 120) for a
# 120 x 392 array.
img = cv2.resize(img, (392, 120)).astype(np.float32)
blob = cv2.dnn.blobFromImage(img, size=(392, 120))

In [38]:
# Peek at the top-left 5 x 5 corner to confirm the placeholder image content.
img[:5, :5]

array([[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]], dtype=float32)

In [86]:
# Feed the blob to the OpenCV network and collect all of its unconnected
# output layers in one forward pass.
superpoint.setInput(blob)
result = superpoint.forward(superpoint.getUnconnectedOutLayersNames())

In [87]:
# Same [0, 0, 0] slice as the PyTorch 'semi' reference, taken from the first
# OpenCV output, for a side-by-side comparison.
result[0][0,0,0]

array([7.114337 , 7.1408205, 7.401045 , 7.374346 , 7.1654186, 7.0391836,
       6.985178 , 6.8115234, 6.734315 , 6.6291046, 6.5956297, 6.5609856,
       6.5691776, 6.5691776, 6.5691776, 6.5691776, 6.5691776, 6.5691776,
       6.5691776, 6.5691776, 6.5691776, 6.5691776, 6.5691776, 6.5691776,
       6.5691776, 6.5691776, 6.5691776, 6.5691776, 6.5691776, 6.5691776,
       6.5691776, 6.5691776, 6.5691776, 6.5691776, 6.5691776, 6.5691776,
       6.5691776, 6.5691776, 6.554138 , 6.545664 , 6.559906 , 6.678446 ,
       6.8643665, 7.1353664, 7.3487816, 7.6130466, 7.8649883, 7.4607544,
       6.381258 ], dtype=float32)

In [32]:
# Wrap the NumPy image as a tensor with two leading singleton axes (batch and
# channel) so it can be fed to `net`.
# NOTE(review): this cell's execution count (32) is lower than that of the cell
# defining `img` (37) — re-run the notebook in order to make sure the tensor
# really corresponds to the current `img`.
img_tensor = torch.Tensor(img).unsqueeze(0).unsqueeze(0)
img_tensor.shape

torch.Size([1, 1, 120, 392])

In [88]:
# Reference forward pass through the PyTorch model on the tensor built from the
# image that was given to the OpenCV network.
pytorch_result = net(img_tensor)

In [61]:
# Row 0 of descriptor channel 0 for the first image.
# NOTE(review): this cell's execution count (61) predates the cell that
# produces `pytorch_result` (88), so the displayed values may be stale.
pytorch_result['desc'][0,0,0]

tensor([-0.1762, -0.0966, -0.0629, -0.0491, -0.0465, -0.0619, -0.0619, -0.0619,
        -0.0619, -0.0619, -0.0619, -0.0619, -0.0619, -0.0619, -0.0619, -0.0619,
        -0.0619, -0.0619, -0.0619, -0.0619, -0.0619, -0.0619, -0.0619, -0.0619,
        -0.0619, -0.0619, -0.0619, -0.0619, -0.0619, -0.0619, -0.0619, -0.0619,
        -0.0619, -0.0619, -0.0619, -0.0619, -0.0619, -0.0619, -0.0619, -0.0619,
        -0.0619, -0.0619, -0.0619, -0.0619, -0.0706, -0.0902, -0.1196, -0.1152,
        -0.0064], grad_fn=<SelectBackward0>)

In [89]:
# Row 0 of detector channel 0 for the first image; compare with the OpenCV
# output slice displayed earlier.
pytorch_result['semi'][0,0,0]

tensor([7.1143, 7.1408, 7.4010, 7.3743, 7.1654, 7.0392, 6.9852, 6.8115, 6.7343,
        6.6291, 6.5956, 6.5610, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692,
        6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692,
        6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692, 6.5692,
        6.5692, 6.5692, 6.5541, 6.5457, 6.5599, 6.6784, 6.8644, 7.1354, 7.3488,
        7.6130, 7.8650, 7.4608, 6.3813], grad_fn=<SelectBackward0>)