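## Keypoint Detection Solution

Detect SURF keypoints in every image of the image folder and keep the first `kps_num` keypoints per image. Save their coordinates and the square grayscale patches around them to `keypoints.pt` and `patches.pt`.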
---

In [19]:
# load packages
import cv2
import numpy as np
import os
import torch
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline

In [20]:
# SURF detector (opencv-contrib xfeatures2d module, non-free build required).
# A Hessian threshold of 30 is far below OpenCV's default of 100, so many more,
# weaker keypoints are reported.
surf = cv2.xfeatures2d.SURF_create(hessianThreshold=30)

In [21]:
def getPatches(kps, img, size=32, num=500):
    """Cut a square patch around each keypoint of a 2-D image.

    Args:
        kps: sequence of (x, y) keypoint coordinates. x indexes columns
            (horizontal, bounded by the image width) and y indexes rows
            (vertical, bounded by the image height). Coordinates are
            truncated with int().
        img: 2-D image (h, w) as a numpy array or torch tensor.
        size: side length of each output patch in pixels.
        num: number of patch slots to allocate; only the first
            ``min(num, len(kps))`` keypoints are processed.

    Returns:
        Float tensor of shape (num, 1, size, size). Patch ``i`` contains
        ``img[y - size//2 : y + size//2, x - size//2 : x + size//2]`` placed so
        the keypoint lands at index ``size//2``. Pixels of the window that fall
        outside the image, and slots without a keypoint, are zero.
    """
    res = torch.zeros(num, 1, size, size)
    if isinstance(img, np.ndarray):
        img = torch.from_numpy(img)
    h, w = img.shape  # image is indexed as [row (y), column (x)]
    half = size // 2
    for i in range(min(num, len(kps))):
        cx, cy = (int(c) for c in kps[i])
        # Window in image coordinates, clipped to the image. Slice ends are
        # exclusive, so the unclipped window spans exactly 2 * half pixels.
        xmin, xmax = max(0, cx - half), min(w, cx + half)
        ymin, ymax = max(0, cy - half), min(h, cy + half)
        # The same window expressed in patch coordinates, shifted so that the
        # keypoint (cx, cy) maps to (half, half); clipping keeps both sizes equal.
        xmin_res, xmax_res = half - min(half, cx), half + min(half, w - cx)
        ymin_res, ymax_res = half - min(half, cy), half + min(half, h - cy)
        res[i, 0, ymin_res:ymax_res, xmin_res:xmax_res] = img[ymin:ymax, xmin:xmax]
    return res

In [22]:
# Configuration: swap the commented lines here (and the sort key below) to process the query set.
img_dir = "../image_retrieval/images"     # for images
# img_dir = "../image_retrieval/query"      # for query
kps_num = 30            # keypoints kept per image
patch_size = 32         # side length of each extracted patch, in pixels
plot_keypoints = False  # set True to display the detected keypoints on each image

# Fail with an exception instead of exit(): in a notebook exit() terminates the
# kernel, and exit code 0 would report success for a failure.
if not os.path.isdir(img_dir):
    raise FileNotFoundError(f"image folder not exists: {img_dir}")

# Build a filtered list. Removing items from a list while iterating over it skips
# the element after each removal, letting non-JPG names reach the int() sort key.
files = [f for f in os.listdir(img_dir) if f.endswith("JPG")]
if not files:
    raise FileNotFoundError(f"No images in {img_dir}!")

sorted_files = sorted(files, key=lambda x: int(x.split('.')[0]))                # for images
# sorted_files = sorted(files, key=lambda x: int(x.split('.')[0].split('q')[1]))    # for query
print(sorted_files)

# Size the outputs from the files actually found, so a folder of a different size
# neither overflows the tensors nor leaves all-zero rows at the end.
img_num = len(sorted_files)
res = torch.zeros(img_num, kps_num, 2)
patches = torch.zeros(img_num, kps_num, 1, patch_size, patch_size)

for idx, file_name in enumerate(sorted_files):
    image_path = os.path.join(img_dir, file_name)
    image = cv2.imread(image_path)
    if image is None:  # cv2.imread returns None for missing/unreadable files
        raise OSError(f"could not read image: {image_path}")
    img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    ## find the keypoints with SURF (descriptors are not used in this cell)
    kps, des = surf.detectAndCompute(img, None)
    # Keep the first kps_num keypoints in the order SURF returns them. Images with
    # fewer keypoints keep zero rows instead of raising IndexError.
    keypoints_img = [kp.pt for kp in kps[:kps_num]]
    n_found = len(keypoints_img)
    if n_found < kps_num:
        print(f"warning: {file_name} has only {n_found} keypoints")
    if n_found == 0:
        continue
    res[idx, :n_found] = torch.FloatTensor(keypoints_img)

    ## extract patches
    patches[idx, :n_found] = getPatches(keypoints_img, img, size=patch_size, num=n_found)

    ## optionally plot keypoints on each image
    if plot_keypoints:
        img2 = cv2.drawKeypoints(img, kps, None, color=(255, 0, 0),
                                 flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
        plt.imshow(img2)
        plt.show()

['1.JPG', '2.JPG', '3.JPG', '4.JPG', '5.JPG', '6.JPG', '7.JPG', '8.JPG', '9.JPG', '10.JPG', '11.JPG', '12.JPG', '13.JPG', '14.JPG', '15.JPG', '16.JPG', '17.JPG', '18.JPG', '19.JPG', '20.JPG', '21.JPG', '22.JPG', '23.JPG', '24.JPG', '25.JPG', '26.JPG', '27.JPG', '28.JPG', '29.JPG', '30.JPG', '31.JPG', '32.JPG', '33.JPG', '34.JPG', '35.JPG', '36.JPG', '37.JPG', '38.JPG', '39.JPG', '40.JPG', '41.JPG', '42.JPG', '43.JPG', '44.JPG', '45.JPG', '46.JPG', '47.JPG', '48.JPG', '49.JPG', '50.JPG', '51.JPG', '52.JPG', '53.JPG', '54.JPG', '55.JPG', '56.JPG', '57.JPG', '58.JPG', '59.JPG', '60.JPG', '61.JPG', '62.JPG', '63.JPG', '64.JPG', '65.JPG', '66.JPG', '67.JPG', '68.JPG', '69.JPG', '70.JPG', '71.JPG', '72.JPG', '73.JPG', '74.JPG', '75.JPG', '76.JPG', '77.JPG', '78.JPG', '79.JPG', '80.JPG', '81.JPG', '82.JPG', '83.JPG', '84.JPG', '85.JPG', '86.JPG', '87.JPG', '88.JPG', '89.JPG', '90.JPG', '91.JPG', '92.JPG', '93.JPG', '94.JPG', '95.JPG', '96.JPG', '97.JPG', '98.JPG', '99.JPG', '100.JPG', '101.JP

In [23]:
# Sanity-check the extracted tensors before saving them:
# res is (images, keypoints, xy) and patches is (images, keypoints, 1, patch, patch).
print(res.shape)
print(patches.shape)
assert res.shape == (img_num, kps_num, 2), res.shape
assert patches.shape == (img_num, kps_num, 1, patch_size, patch_size), patches.shape

torch.Size([140, 30, 2])
torch.Size([140, 30, 1, 32, 32])


In [24]:
## save tensors
# Persist keypoint coordinates and patches as .pt files in the working directory.
output_dir_kps, output_dir_patches = "keypoints.pt", "patches.pt"
for tensor_to_save, save_path in ((res, output_dir_kps), (patches, output_dir_patches)):
    torch.save(tensor_to_save, save_path)