## Keypoint Detection Solution
---

In [1]:
# load packages
import cv2
import numpy as np
import os
import torch
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline

In [2]:
# Initiate SIFT detector; nfeatures=200 is the number of best-scoring
# keypoints SIFT retains per image.
# SIFT lives in the main cv2 module since OpenCV 4.4; older builds only ship
# it in the contrib `xfeatures2d` module, so fall back to that when needed.
if hasattr(cv2, "SIFT_create"):
    sift = cv2.SIFT_create(200)
else:
    sift = cv2.xfeatures2d.SIFT_create(200)

In [3]:
def getPatches(kps, img, size=32, num=500):
    """Crop a square patch around each of the first ``num`` keypoints.

    Args:
        kps: sequence of (x, y) keypoint coordinates (e.g. ``cv2.KeyPoint.pt``
            tuples), where x indexes image columns and y indexes image rows.
            Coordinates are truncated to int.
        img: 2-D single-channel image (numpy array or torch tensor) of
            shape (h, w).
        size: side length of each patch. For a centre coordinate ``c`` the
            patch covers the half-open range ``[c - size // 2, c - size // 2 + size)``.
        num: number of keypoints to process; ``kps`` must have at least
            this many entries.

    Returns:
        Float tensor of shape (num, 1, size, size). Any part of a patch that
        falls outside the image is left as zero.

    Raises:
        IndexError: if ``kps`` has fewer than ``num`` entries.
    """
    res = torch.zeros(num, 1, size, size)
    if isinstance(img, np.ndarray):
        img = torch.from_numpy(img)
    h, w = img.shape  # h = number of rows (y extent), w = number of columns (x extent)
    half = size // 2
    for i in range(num):
        cx, cy = kps[i]
        cx, cy = int(cx), int(cy)
        # Un-clipped patch window in image coordinates (half-open ranges).
        x0, y0 = cx - half, cy - half
        x1, y1 = x0 + size, y0 + size
        # Clip the window to the image bounds.
        src_x0, src_x1 = max(0, x0), min(w, x1)
        src_y0, src_y1 = max(0, y0), min(h, y1)
        if src_x0 >= src_x1 or src_y0 >= src_y1:
            continue  # window entirely outside the image: leave patch as zeros
        # The same region in patch coordinates is the image region shifted by
        # the window origin (x0, y0); image rows are y, columns are x.
        res[i, 0, src_y0 - y0:src_y1 - y0, src_x0 - x0:src_x1 - x0] = img[
            src_y0:src_y1, src_x0:src_x1
        ]
    return res

In [4]:
img_dir = "img"
kps_num = 200      # keypoints (and patches) kept per image
patch_size = 32    # side length of each square patch, in pixels

if not os.path.exists(img_dir):
    raise FileNotFoundError(f"image folder '{img_dir}' does not exist!")

# Sorted so that the image order -- and hence the first axis of the saved
# tensors -- is reproducible; os.listdir returns entries in arbitrary order.
image_files = sorted(
    name for name in os.listdir(img_dir)
    if name.lower().endswith((".jpg", ".jpeg"))
)
if not image_files:
    raise FileNotFoundError(f"No images in '{img_dir}'!")

# Size the outputs by the number of images actually found instead of a
# hard-coded count.
res = torch.zeros(len(image_files), kps_num, 2)  # (x, y) per keypoint
patches = torch.zeros(len(image_files), kps_num, 1, patch_size, patch_size)

for idx, file_name in enumerate(image_files):
    image_path = os.path.join(img_dir, file_name)
    image = cv2.imread(image_path)
    if image is None:  # cv2.imread returns None instead of raising on failure
        raise IOError(f"could not read image '{image_path}'")
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    ## find the keypoints with SIFT (descriptors are not used here)
    kps, _ = sift.detectAndCompute(gray, None)
    if len(kps) < kps_num:
        raise ValueError(
            f"'{file_name}': SIFT found {len(kps)} keypoints, need {kps_num}"
        )
    keypoints_img = [kp.pt for kp in kps[:kps_num]]
    res[idx] = torch.tensor(keypoints_img, dtype=torch.float32)

    ## extract patches
    patches[idx] = getPatches(keypoints_img, gray, size=patch_size, num=kps_num)

    ## plot keypoints on each image (uncomment to inspect)
    # img2 = cv2.drawKeypoints(gray, kps, None, color=(255,0,0), flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
    # plt.imshow(img2)
    # plt.show()

In [5]:
# Sanity check: show the shapes of the keypoint and patch tensors.
for tensor_out in (res, patches):
    print(tensor_out.shape)

torch.Size([5, 200, 2])
torch.Size([5, 200, 1, 32, 32])


In [6]:
## save tensors
output_dir_kps = "keypoints.pt"      # file path for the keypoint coordinates
output_dir_patches = "patches.pt"    # file path for the keypoint patches
for to_save, save_path in ((res, output_dir_kps), (patches, output_dir_patches)):
    torch.save(to_save, save_path)