In [5]:
# -*-coding:utf8-*-
import torch
import torch.nn.functional as F
import numpy as np
# from utils.tensor_op import pixel_shuffle
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

def pixel_shuffle(tensor, scale_factor):
    """
    Implementation of pixel shuffle using numpy

    Parameters:
    -----------
    tensor: input tensor, shape is [N, C, H, W]
    scale_factor: scale factor to up-sample tensor

    Returns:
    --------
    tensor: tensor after pixel shuffle, shape is [N, C/(r*r), r*H, r*W],
        where r refers to scale factor
    """
    num, ch, height, width = tensor.shape
    assert ch % (scale_factor * scale_factor) == 0

    new_ch = ch // (scale_factor * scale_factor)
    new_height = height * scale_factor
    new_width = width * scale_factor

    tensor = tensor.reshape(
        [num, new_ch, scale_factor, scale_factor, height, width])
    # new axis: [num, new_ch, height, scale_factor, width, scale_factor]
    tensor = tensor.permute(0, 1, 4, 2, 5, 3)
    tensor = tensor.reshape(num, new_ch, new_height, new_width)
    return tensor

class SuperPointNet(torch.nn.Module):
    """
    The magicleap definition of SuperPoint Network.
    Mainly for debug or export homography adaptations
    """
    def __init__(self, input_channel=1, grid_size=8):
        super(SuperPointNet, self).__init__()

        self.grid_size = grid_size

        self.relu = torch.nn.ReLU(inplace=True)
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)
        c1, c2, c3, c4, c5, d1 = 64, 64, 128, 128, 256, 256
        # Shared Encoder.
        self.conv1a = torch.nn.Conv2d(input_channel, c1, kernel_size=3, stride=1, padding=1)
        self.conv1b = torch.nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
        self.conv2a = torch.nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
        self.conv2b = torch.nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
        self.conv3a = torch.nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
        self.conv3b = torch.nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
        self.conv4a = torch.nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
        self.conv4b = torch.nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
        # Detector Head.
        self.convPa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.convPb = torch.nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)
        # Descriptor Head.
        self.convDa = torch.nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.convDb = torch.nn.Conv2d(c5, d1, kernel_size=1, stride=1, padding=0)
        #
        self.softmax = torch.nn.Softmax(dim=1)

    def forward(self, x):
        """ Forward pass that jointly computes unprocessed point and descriptor
        tensors.
        Input
          x: Image pytorch tensor shaped N x 1 x H x W.
        Output
          semi: Output point pytorch tensor shaped N x 65 x H/8 x W/8.
          desc: Output descriptor pytorch tensor shaped N x 256 x H/8 x W/8.
        """
        if isinstance(x, dict):
            x = x['img']

        # Shared Encoder.
        x = self.relu(self.conv1a(x))
        x = self.relu(self.conv1b(x))
        x = self.pool(x)
        x = self.relu(self.conv2a(x))
        x = self.relu(self.conv2b(x))
        x = self.pool(x)
        x = self.relu(self.conv3a(x))
        x = self.relu(self.conv3b(x))
        x = self.pool(x)
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))
        # Detector Head.
        cPa = self.relu(self.convPa(x))
        semi = self.convPb(cPa)
        #
        prob = self.softmax(semi)
        prob = prob[:, :-1, :, :]  # remove dustbin,[B,64,H,W]
        # Reshape to get full resolution heatmap.
        prob = pixel_shuffle(prob, self.grid_size)  # [B,1,H*8,W*8]
        prob = prob.squeeze(dim=1)#[B,H,W]

        # Descriptor Head, useless for export image key points
        cDa = self.relu(self.convDa(x))
        out = self.convDb(cDa)
        dn = torch.norm(out, p=2, dim=1)  # Compute the norm.
        desc_raw = out.div(torch.unsqueeze(dn, 1))  # Divide by norm to normalize.
        ##
        # # interpolation
        desc = F.interpolate(desc_raw, scale_factor=self.grid_size, mode='bilinear', align_corners=False)
        desc = F.normalize(desc, p=2, dim=1)  # normalize by channel

        prob = {'logits':semi, 'prob':prob}
        desc = {'desc_raw':desc_raw, 'desc':desc}
        return prob, desc




In [7]:
model = SuperPointNet()
model.load_state_dict(torch.load('./superpoint_v1.pth'))
print('Done')

Done


In [101]:
import cv2

img_path = './data/images/COCO_train2014_000000000009.jpg'

img = cv2.imread(img_path, 0)
print(img.shape)

(240, 320)


In [114]:
x = torch.tensor(img, dtype=torch.float32)
x = x.unsqueeze(0)
x = x.unsqueeze(0)

torch.Size([1, 1, 240, 320])
tensor([[[[ 27.,  27.,  27.,  ..., 170., 163., 162.],
          [ 24.,  26.,  27.,  ..., 174., 165., 165.],
          [ 24.,  26.,  28.,  ..., 171., 170., 168.],
          ...,
          [  5.,   6.,  16.,  ..., 101.,  69.,  30.],
          [  3.,   3.,  12.,  ...,  32.,  19.,  12.],
          [  2.,   4.,   7.,  ...,  11.,  10.,   4.]]]])


In [115]:
out = model(x)

In [119]:
print(out[0]['logits'].shape, out[0]['prob'].shape)
print(out[1]['desc_raw'].shape, out[1]['desc'].shape)

torch.Size([1, 65, 30, 40]) torch.Size([1, 240, 320])
torch.Size([1, 256, 30, 40]) torch.Size([1, 256, 240, 320])


In [129]:
test = out[1]['desc'][0]
print(test.shape)

torch.Size([256, 240, 320])


In [130]:
test2 = test.reshape([240,320,256])

In [146]:
t = torch.tensor([[[1,2, 3, 3],[4, 5, 6, 6],[0,0,0,0]],
                 [[7,8,9, 9],[10,11,12, 12],[0,0,0,0]]])
print(t)

tensor([[[ 1,  2,  3,  3],
         [ 4,  5,  6,  6],
         [ 0,  0,  0,  0]],

        [[ 7,  8,  9,  9],
         [10, 11, 12, 12],
         [ 0,  0,  0,  0]]])


In [147]:
t.shape

torch.Size([2, 3, 4])

In [148]:
t[0][0]

tensor([1, 2, 3, 3])

In [149]:
t = t.reshape(3,4,2)
print(t[0][0])

tensor([1, 2])


In [150]:
print(t)

tensor([[[ 1,  2],
         [ 3,  3],
         [ 4,  5],
         [ 6,  6]],

        [[ 0,  0],
         [ 0,  0],
         [ 7,  8],
         [ 9,  9]],

        [[10, 11],
         [12, 12],
         [ 0,  0],
         [ 0,  0]]])
