## Keypoint Detection
---
There are 10 images given in the ***images_corrected*** folder. Your task is to detect ***200*** keypoints for each of them using the SIFT detector.

Below is the tutorial to follow for generating SIFT keypoints:

SIFT: https://docs.opencv.org/3.3.0/da/df5/tutorial_py_sift_intro.html

Let's take a look at these images first!

---

In [32]:
# load packages
import cv2
import numpy as np
import os
import torch
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

%matplotlib inline

In [None]:
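img_dir = "images_corrected"
sifts = {}
if os.path.exists(img_dir):
    # `os.listdir(...) is []` is always False, so test for emptiness directly
    img_names = [f for f in os.listdir(img_dir) if f.endswith("jpg")]
    if not img_names:
        print("No images!")
    # create the detector once; on OpenCV >= 4.4 SIFT is also available as cv2.SIFT_create
    sift = cv2.xfeatures2d.SIFT_create(nfeatures=200)
    for img_name in img_names:
        image = cv2.imread(os.path.join(img_dir, img_name))
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

        kp = sift.detect(gray, None)
        # keep the 200 strongest keypoints, sorted from highest to lowest response
        kp = sorted(kp, key=lambda k: k.response, reverse=True)[:200]
        sifts[img_name] = torch.Tensor(np.array([k.pt for k in kp]))

        vis = cv2.drawKeypoints(gray, kp, None, flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)
        # OpenCV draws in BGR order, while matplotlib expects RGB
        plt.imshow(cv2.cvtColor(vis, cv2.COLOR_BGR2RGB))
        plt.show()
else:
    print("image folder does not exist!")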

In [58]:
sorted_sifts = {k: v for k, v in sorted(sifts.items(), key=lambda x: int(x[0].split('.')[0]))}
total_sifts = torch.stack(tuple(sorted_sifts.values()), dim=0)
print(total_sifts.shape)
torch.save(total_sifts, 'SIFT.pth')



torch.Size([10, 200, 2])
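
The numeric sort key used when building `sorted_sifts` matters because plain string sorting does not give the 1-10 order. A minimal sketch, assuming file names of the form `1.jpg` to `10.jpg` (the list below is made up for illustration and does not read the folder):

```python
# Hypothetical file names, matching the 1-10.jpg naming used in this notebook.
names = [f"{i}.jpg" for i in range(1, 11)]

# Plain string sorting compares character by character, so "10.jpg" lands before "2.jpg".
lexicographic = sorted(names)

# Sorting on the integer stem restores the intended 1..10 order.
numeric = sorted(names, key=lambda name: int(name.split(".")[0]))

print(lexicographic[:3])
print(numeric[:3])
```

If the folder ever contains names without an integer stem, the `int(...)` conversion would raise a `ValueError`, so this key only suits the naming scheme assumed here.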


### The Keypoints:
Upon running SIFT on the 10 images, for each image, the 200 highest-response SIFT keypoints should be kept. These keypoints should be saved as a torch tensor of size (10 x 200 x 2) in a file called "SIFT.pth". Please ensure that the tensor stores the 10 images in order by image name (i.e., 1-10.jpg) and that the 200 keypoints are sorted from highest response to lowest response.
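
One way to meet the ordering requirement is to sort the detected keypoints on their `response` attribute before truncating to 200. The sketch below illustrates the idea only: `FakeKeyPoint` and `top_by_response` are hypothetical names, and the namedtuple is a stand-in for `cv2.KeyPoint` that carries just the `pt` and `response` fields so the example runs without OpenCV.

```python
from collections import namedtuple

# Stand-in for cv2.KeyPoint with only the two attributes used here
# (an assumption made so the sketch runs without OpenCV).
FakeKeyPoint = namedtuple("FakeKeyPoint", ["pt", "response"])

def top_by_response(keypoints, n):
    """Return the (x, y) coordinates of the n strongest keypoints, strongest first."""
    ranked = sorted(keypoints, key=lambda k: k.response, reverse=True)
    return [k.pt for k in ranked[:n]]

# Made-up detections for illustration.
detected = [
    FakeKeyPoint(pt=(10.0, 20.0), response=0.03),
    FakeKeyPoint(pt=(55.5, 40.0), response=0.09),
    FakeKeyPoint(pt=(70.0, 15.0), response=0.01),
    FakeKeyPoint(pt=(33.0, 90.0), response=0.05),
]
top2 = top_by_response(detected, 2)
print(top2)
```

With real detections, the same function would be called with `n=200` on the list returned by `sift.detect`.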

Note that the detected keypoints are represented as x and y coordinates in the image. For example, 10 keypoints of image ***3.jpg*** are:

In [4]:
#keypoints = list([(10.0, 10),  (16.0, 15.5), (15, 16), (1585, 16), (15, 1024), (100, 106), (150, 160), (715, 716), (315, 916), (815, 640)])
keypoints = [(1040.0224609375, 300.8042907714844), (399.89947509765625, 235.52102661132812), (1011.2779541015625, 283.0083923339844), (950.94677734375, 291.58880615234375), (333.86993408203125, 234.14942932128906), (949.2930297851562, 488.50238037109375), (1006.3855590820312, 481.1703796386719), (933.950927734375, 306.55450439453125), (1006.9017333984375, 275.14898681640625), (1007.01953125, 288.50469970703125)]
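
Because each keypoint is an (x, y) pair while image arrays are indexed as [row, column], x selects the column and y selects the row. A small sketch of that convention, using a hypothetical 3x4 nested list in place of a real image (`pixel_at` is an illustrative helper, not part of the assignment):

```python
# Hypothetical 3-row by 4-column "image" stored as nested lists; each entry
# records its own (row, column) position so the lookup is easy to check.
image = [[(row, col) for col in range(4)] for row in range(3)]

def pixel_at(image, keypoint):
    """Look up the pixel under an (x, y) keypoint, truncating to integer coordinates."""
    x, y = keypoint
    return image[int(y)][int(x)]  # row index is y, column index is x

value = pixel_at(image, (3.7, 1.2))
print(value)
```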

Then we can obtain the patches with these keypoints using the following helper function:


In [59]:
def getPatches(kps, img, size=32, num=500):
    # Returns a (num, 1, size, size) tensor of patches centred on the keypoints.
    # Regions that fall outside the image stay zero; kps must hold at least num entries.
    res = torch.zeros(num, 1, size, size)
    if isinstance(img, np.ndarray):
        img = torch.from_numpy(img)
    h, w = img.shape      # rows (y, vertical) first, then columns (x, horizontal)
    dd = int(size/2)
    for i in range(num):
        cx, cy = kps[i]   # keypoints are (x, y): x indexes columns, y indexes rows
        cx, cy = int(cx), int(cy)
        xmin, xmax = max(0, cx - dd), min(w, cx + dd)
        ymin, ymax = max(0, cy - dd), min(h, cy + dd)

        # offset of the crop inside the patch when the keypoint is near the top/left border
        xmin_res = dd - min(dd, cx)
        ymin_res = dd - min(dd, cy)

        cropped_img = img[ymin: ymax, xmin: xmax]
        ch, cw = cropped_img.shape
        res[i, 0, ymin_res: ymin_res+ch, xmin_res: xmin_res+cw] = cropped_img

    return res
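
To see how `getPatches` handles keypoints close to the image border, it can help to look at the clipping arithmetic along a single axis. The sketch below mirrors that arithmetic in plain Python; `patch_bounds` is a hypothetical helper introduced only for illustration, and the axis length of 100 is an arbitrary example value.

```python
def patch_bounds(center, extent, size=32):
    """Return (lo, hi, offset) for one axis of a patch.

    lo:hi is the slice taken from the image, and offset is where that slice
    starts inside the size-wide patch.
    """
    half = size // 2
    lo = max(0, center - half)          # clip at the start of the axis
    hi = min(extent, center + half)     # clip at the end of the axis
    offset = half - min(half, center)   # non-zero only when clipped at the start
    return lo, hi, offset

interior = patch_bounds(50, extent=100)
near_start = patch_bounds(5, extent=100)
near_end = patch_bounds(95, extent=100)
print(interior, near_start, near_end)
```

Near the start of an axis the crop is shifted by `offset` so that the keypoint stays at the patch centre; near the end the crop is simply shorter and the rest of the patch keeps its zero padding.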




### Let's plot these patches

In [65]:
img_dir = "images_corrected"
all_patches = []
if os.path.exists(img_dir):
    # iterate in the numeric filename order of sorted_sifts so that
    # all_patches lines up with SIFT.pth (os.listdir order is arbitrary)
    for img_name, kps in sorted_sifts.items():
        image = cv2.imread(os.path.join(img_dir, img_name))
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        patches = getPatches(kps, gray, size=32, num=200)
        print(patches.shape)
        all_patches.append(patches)

all_patches = torch.stack(all_patches, dim=0)
print(all_patches.shape)
output_dir = "SIFT_patches.pth"
torch.save(all_patches, output_dir)


torch.Size([200, 1, 32, 32])
torch.Size([200, 1, 32, 32])
torch.Size([200, 1, 32, 32])
torch.Size([200, 1, 32, 32])
torch.Size([200, 1, 32, 32])
torch.Size([200, 1, 32, 32])
torch.Size([200, 1, 32, 32])
torch.Size([200, 1, 32, 32])
torch.Size([200, 1, 32, 32])
torch.Size([200, 1, 32, 32])
torch.Size([10, 200, 1, 32, 32])


In [None]:
img = cv2.imread('images_corrected/3.jpg')
gray= cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
# gray = img[:, :, 0]
patches = getPatches(keypoints, gray,size=32, num=10)
for patch in patches:
    im = patch[0].numpy()
    plt.imshow(im, cmap="gray")   # patches are single-channel, so use a gray colormap
    plt.show()


### Save the patches with PyTorch
You will extract a 32x32 image patch for each of the 200 keypoints in each image. These patches will be saved as a torch tensor of size (10 x 200 x 1 x 32 x 32) in a file called "SIFT_patches.pth". Here 10 is the number of images (stored in order of image name, i.e., 1-10.jpg), 200 is the number of keypoints per image, and 1x32x32 is the grayscale image patch around a keypoint.

An example of how to save such a tensor is shown below using the example extracted patches, where the tensor returned by ***getPatches()*** is the one that you should store in a list:

In [7]:
all_patches = []
all_patches.append(patches)
all_patches = torch.stack(all_patches, dim=0)
output_dir = "patches.pth"         # modify it to SIFT_patches.pth
torch.save(all_patches, output_dir)

### Test with your saved patches

In [66]:
test_patches = torch.load("SIFT_patches.pth")   # the file written by the patch-extraction cell
print(type(test_patches))
print(test_patches.shape)
# the tensor should have size [10, 200, 1, 32, 32]: 10 images (in the order 1-10) with 200 keypoints each

<class 'torch.Tensor'>
torch.Size([10, 200, 1, 32, 32])
