In [1]:
import numpy as np
import torch
import torch.nn.functional as func
import sys
import cv2
from HIFT_core import DeSTR

In [2]:
def SIFT_detect(img, nfeatures=1500, contrastThreshold=0.04):
    """ Compute SIFT feature points.

    Args:
        img: image handed unchanged to the OpenCV SIFT detector.
        nfeatures: forwarded to the SIFT constructor.
        contrastThreshold: forwarded to the SIFT constructor.

    Returns:
        A float64 numpy array of shape [N, 3] whose rows are
        (pt[1], pt[0], response) of each detected keypoint, i.e. (y, x,
        response) since OpenCV stores pt as (x, y). The shape is [0, 3]
        when nothing is detected.
    """
    # SIFT now lives in the main OpenCV module; the xfeatures2d entry point
    # is deprecated and only kept as a fallback for older builds.
    sift_factory = getattr(cv2, 'SIFT_create', None)
    if sift_factory is None:
        sift_factory = cv2.xfeatures2d.SIFT_create
    sift = sift_factory(nfeatures=nfeatures,
                        contrastThreshold=contrastThreshold)
    keypoints = sift.detect(img, None)
    keypoints = [[k.pt[1], k.pt[0], k.response] for k in keypoints]
    # The reshape keeps the array 2-D even for an empty detection list, so
    # callers can always slice columns.
    return np.array(keypoints, dtype=np.float64).reshape(-1, 3)

def keypoints_to_grid(keypoints, img_size):
    """
    Map pixel keypoints to a normalized sampling grid.

    Takes a tensor [N, 2] (or a batched tensor [B, N, 2]) of keypoints and
    rescales every coordinate with the matching entry of img_size, so that
    0 maps to -1 and img_size maps to 1. The two coordinates are then
    swapped and the result is returned with shape [B, N, 1, 2], which is
    the grid layout consumed by torch.nn.functional.grid_sample.
    """
    num_kp = keypoints.shape[-2]
    size = torch.tensor(img_size, dtype=torch.float,
                        device=keypoints.device)
    normalized = keypoints.float() * 2. / size - 1.
    # Reverse the order of the two coordinates (index 1 first, then 0).
    swapped = torch.stack((normalized[..., 1], normalized[..., 0]), dim=-1)
    return swapped.view(-1, num_kp, 1, 2)  # B*N*1*2

def _adapt_weight_names(state_dict):
    """ Adapt the weight names when the training and testing are done
    with a different GPU configuration (with/without DataParallel). """
    train_parallel = list(state_dict.keys())[0][:7] == 'module.'
    test_parallel = torch.cuda.device_count() > 1
    new_state_dict = {}
    if train_parallel and (not test_parallel):
        # Need to remove 'module.' from all the variable names
        for k, v in state_dict.items():
            new_state_dict[k[7:]] = v
    elif test_parallel and (not train_parallel):
        # Need to add 'module.' to all the variable names
        for k, v in state_dict.items():
            new_k = 'module.' + k
            new_state_dict[new_k] = v
    else:  # Nothing to do
        new_state_dict = state_dict
    return new_state_dict

def _match_state_dict(old_state_dict, new_state_dict):
    """ Return a new state dict that has exactly the same entries
            as old_state_dict and that is updated with the values of
            new_state_dict whose entries are shared with old_state_dict.
            This allows loading a pre-trained network. """
    return ({k: new_state_dict[k] if k in new_state_dict else v
             for (k, v) in old_state_dict.items()},
            old_state_dict.keys() == new_state_dict.keys())

def mutual_nn_matching_torch(desc1, desc2, threshold=None):
    """ Mutual nearest neighbor matching between two descriptor sets.

    The similarity of two descriptors is their dot product. Descriptor i
    of desc1 and descriptor j of desc2 are matched when j is the most
    similar entry of desc2 for i and i is the most similar entry of desc1
    for j.

    Args:
        desc1: tensor [N1, D].
        desc2: tensor [N2, D].
        threshold: when not None, only matches whose similarity is strictly
            greater than this value are kept. A threshold of 0 is honored.

    Returns:
        A tuple (matches, scores): matches is an int64 tensor [M, 2] of
        (index in desc1, index in desc2) pairs and scores is a tensor [M]
        holding the similarity of each match. Both are empty (M == 0) when
        either descriptor set is empty.
    """
    device = desc1.device
    if desc1.shape[0] == 0 or desc2.shape[0] == 0:
        # Same layout as the regular return value: [0, 2] indices and a
        # 1-D score vector, both on the device of the inputs.
        return (torch.empty((0, 2), dtype=torch.int64, device=device),
                torch.empty((0,), dtype=desc1.dtype, device=device))

    similarity = torch.einsum('id, jd->ij', desc1, desc2)

    # Best candidate in desc2 for every row, and in desc1 for every column.
    best_scores, nn12 = similarity.max(dim=1)
    nn21 = similarity.max(dim=0)[1]
    ids1 = torch.arange(0, similarity.shape[0], device=device)
    # Keep i only if the best candidate of i points back to i.
    mask = (ids1 == nn21[nn12])
    matches = torch.stack([ids1[mask], nn12[mask]]).t()
    scores = best_scores[mask]
    # Compare against None so that a threshold of 0 is not ignored.
    if threshold is not None:
        keep = scores > threshold
        matches = matches[keep]
        scores = scores[keep]
    return matches, scores

In [3]:
# load model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint_path = '/home/ray/tim_mshift/HIFT/hift.pth'
descriptor = DeSTR()
# The checkpoint is mapped to the CPU first; the model is moved to
# `device` only after the weights are loaded.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
adapt_dict = _adapt_weight_names(checkpoint['model_state_dict'])
net_dict = descriptor.state_dict()
updated_state_dict, same_net = _match_state_dict(net_dict, adapt_dict)
descriptor.load_state_dict(updated_state_dict)
descriptor = descriptor.to(device)
if same_net:
    print("Success in loading model!")
else:
    # _match_state_dict keeps the freshly initialised value of every model
    # entry that is absent from the checkpoint, so a key mismatch must not
    # go unnoticed.
    missing_keys = sorted(set(net_dict) - set(adapt_dict))
    unused_keys = sorted(set(adapt_dict) - set(net_dict))
    print("Warning: checkpoint and model entries differ "
          "({} model entries not found in the checkpoint, "
          "{} checkpoint entries unused).".format(len(missing_keys),
                                                  len(unused_keys)))
# Switch to inference mode; as the last expression of the cell this also
# displays the module structure.
descriptor.eval()

Success in loading model!


DeSTR(
  (descriptor): Descriptor(
    (backbone): HTNet(
      (conv): Sequential(
        (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (1): PReLU(num_parameters=1)
        (2): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): DyReLU(
          (fc1): Linear(in_features=64, out_features=16, bias=True)
          (relu): ReLU(inplace=True)
          (fc2): Linear(in_features=16, out_features=256, bias=True)
          (sigmoid): Sigmoid()
        )
        (5): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (6): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (7): DyReLU(
          (fc1): Linear(in_features=128, out_features=32, bias=True)
          (relu): ReLU(inplace=True)
          (fc2): Linear(in_features=32, out_features=512, bias=True)
          (sig

In [4]:
# keypoint extraction
# Flag 0 loads the images as single-channel grayscale.
img1 = cv2.imread('/home/ray/tim_mshift/day.jpg',0)
img2 = cv2.imread('/home/ray/tim_mshift/night.jpg',0)
# cv2.imread returns None instead of raising when a file cannot be read;
# fail here with a clear message rather than later inside the detector.
if img1 is None or img2 is None:
    raise FileNotFoundError("Could not read one of the input images.")
keypoints1 = SIFT_detect(img1, nfeatures=1500, contrastThreshold=0.04)
keypoints2 = SIFT_detect(img2, nfeatures=1500, contrastThreshold=0.04)
# The first two columns returned by SIFT_detect are (y, x); they are turned
# into sampling grids normalized by the image (height, width).
grid_points1 = keypoints_to_grid(torch.tensor(keypoints1[:, :2], dtype=torch.float, device=device),img1.shape[:2])
grid_points2 = keypoints_to_grid(torch.tensor(keypoints2[:, :2], dtype=torch.float, device=device),img2.shape[:2])
# Keep only the coordinates, reordered to (x, y); the response column is
# dropped here.
keypoints1 = keypoints1[:, [1, 0]]
keypoints2 = keypoints2[:, [1, 0]]

[ WARN:0] DEPRECATED: cv.xfeatures2d.SIFT_create() is deprecated due SIFT tranfer to the main repository. https://github.com/opencv/opencv/issues/16736


In [7]:
# descriptor extraction
# Add batch and channel axes ([H, W] -> [1, 1, H, W]) and scale to [0, 1].
img1_tensor = torch.tensor(img1[None][None], dtype=torch.float, device=device)/255.
img2_tensor = torch.tensor(img2[None][None], dtype=torch.float, device=device)/255.
with torch.no_grad():
    # Call the module instance rather than .forward directly.
    outputs1 = descriptor(img1_tensor)
    # grid_sample returns [1, C, N, 1] here. Only the trailing and the batch
    # axes are removed, so a single keypoint (N == 1) still gives an [N, C]
    # tensor; a bare squeeze() would also drop the keypoint axis.
    desc1 = func.grid_sample(outputs1, grid_points1).squeeze(-1).squeeze(0).transpose(1, 0)
    outputs2 = descriptor(img2_tensor)
    desc2 = func.grid_sample(outputs2, grid_points2).squeeze(-1).squeeze(0).transpose(1, 0)
print(desc1.shape, desc2.shape)

torch.Size([1500, 128]) torch.Size([1355, 128])


In [12]:
# mutual nearest neighbor matching
matches, score = mutual_nn_matching_torch(desc1, desc2)
print(matches.shape)
# Bring the index pairs to numpy so they can select rows of the keypoint
# arrays: column 0 indexes keypoints1, column 1 indexes keypoints2.
matches = matches.cpu().numpy()
match1 = np.take(keypoints1, matches[:, 0], axis=0)
match2 = np.take(keypoints2, matches[:, 1], axis=0)

torch.Size([422, 2])


In [13]:
# ransac
# findHomography returns (homography, mask); the mask has one row per match
# and its single column marks the matches kept by RANSAC.
homography, inlier_mask = cv2.findHomography(match1, match2, cv2.RANSAC)
if inlier_mask is None:
    # Without a mask there is nothing to filter with; stop with a clear
    # message instead of a TypeError from indexing None.
    raise RuntimeError("cv2.findHomography did not return an inlier mask.")
inliers = inlier_mask[:, 0].astype(bool)
match1 = match1[inliers]
match2 = match2[inliers]
print(match1.shape)

(80, 2)


In [15]:
# Pixel coordinates must be integers for drawing (values are truncated).
match1 = match1.astype(np.int32)
match2 = match2.astype(np.int32)
# Reload both images in color for the visualization.
img1 = cv2.imread('/home/ray/tim_mshift/day.jpg')
img2 = cv2.imread('/home/ray/tim_mshift/night.jpg')
# cv2.imread returns None instead of raising when a file cannot be read.
if img1 is None or img2 is None:
    raise FileNotFoundError("Could not read one of the input images.")
h, w = img1.shape[:2]
num = match1.shape[0]
# Side-by-side canvas. It equals np.concatenate([img1, img2], axis=1) when
# both images have the same height; otherwise the shorter image is padded
# with black rows at the bottom instead of raising.
canvas_h = max(h, img2.shape[0])
draw_image = np.zeros((canvas_h, w + img2.shape[1]) + img1.shape[2:],
                      dtype=img1.dtype)
draw_image[:h, :w] = img1
draw_image[:img2.shape[0], w:] = img2
# Plain Python ints are used for the point coordinates because some OpenCV
# builds reject numpy scalar types. Points of the second image are shifted
# by the width of the first one.
points1 = [(int(x), int(y)) for x, y in match1]
points2 = [(int(x) + w, int(y)) for x, y in match2]
for pt1, pt2 in zip(points1, points2):
    cv2.circle(draw_image, pt1, 1, (0, 0, 255), 2)
    cv2.circle(draw_image, pt2, 1, (0, 0, 255), 2)
# Lines are drawn in a second pass so they lie on top of every circle.
for pt1, pt2 in zip(points1, points2):
    cv2.line(draw_image, pt1, pt2, (0, 255, 0), 1)
cv2.imwrite('match_pair.jpg', draw_image)

True