In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import os
import cv2  
from tqdm import tqdm
from collections import Counter
from PIL import Image

In [2]:
def create_des(path):
    """Compute SIFT descriptors for every image file in a directory.

    Files are visited in sorted filename order; entries whose name starts
    with "." are skipped.

    Args:
        path: Directory containing the images.

    Returns:
        dict mapping each filename to the descriptor array returned by
        ``sift.detectAndCompute``. The value is ``None`` when the file could
        not be decoded as an image; SIFT itself may also yield ``None``
        (e.g. when it finds no keypoints), so callers must handle ``None``.
    """
    sift = cv2.SIFT_create()
    des_dict = {}
    for filename in tqdm(sorted(os.listdir(path))):
        if filename.startswith("."):
            continue
        gray = cv2.imread(os.path.join(path, filename), cv2.IMREAD_GRAYSCALE)
        if gray is None:
            # cv2.imread signals an unreadable / non-image file by returning
            # None; passing that on to SIFT would raise and abort the whole
            # (long) run. Keep the key so the file list stays complete.
            tqdm.write(f"create_des: could not read image {filename!r}, storing None")
            des_dict[filename] = None
            continue
        _, des = sift.detectAndCompute(gray, None)
        des_dict[filename] = des
    return des_dict

In [4]:
# Extract SIFT descriptors for both image folders (train first, then test).
train_des, test_des = (create_des(folder) for folder in ("train", "test"))

100%|███████████████████████████████████████| 7501/7501 [11:49<00:00, 10.58it/s]
100%|███████████████████████████████████████| 1201/1201 [01:40<00:00, 11.95it/s]


In [5]:
def create_flann(des, trees=5, checks=50):
    """Build and train a FLANN matcher over a collection of descriptor sets.

    Descriptor sets are added one per image, in the iteration order of
    ``des.values()``, so the ``imgIdx`` reported for a match is the position
    of the image in ``list(des)``.

    Args:
        des: dict mapping image filename to its descriptor array.
        trees: Number of KD-trees used by the index (default 5).
        checks: Value passed as the ``checks`` search parameter (default 50).

    Returns:
        The trained ``cv2.FlannBasedMatcher``.
    """
    # OpenCV's Python bindings do not export this enum; 1 selects the
    # KD-tree index, as used in the OpenCV feature-matching tutorial.
    FLANN_INDEX_KDTREE = 1
    index_params = {"algorithm": FLANN_INDEX_KDTREE, "trees": trees}
    search_params = {"checks": checks}
    flann = cv2.FlannBasedMatcher(index_params, search_params)
    for descriptors in tqdm(des.values()):
        # One add() call per image keeps imgIdx aligned with list(des).
        flann.add([descriptors])
    flann.train()
    return flann

In [6]:
# Index all training descriptors once; the matcher is reused for every query.
FLANN = create_flann(des=train_des)

100%|███████████████████████████████████| 7500/7500 [00:00<00:00, 168150.61it/s]


In [12]:
def sift_compare(path, flann, test_des, train_des, ratio=0.55, top_n=3):
    """Find the training images that best match each test image.

    For every non-hidden file in ``path`` (sorted order), its descriptors are
    matched against the FLANN index with ``k=2`` and filtered with a ratio
    test; each surviving match counts as one vote for the training image it
    belongs to.

    Args:
        path: Directory whose filenames are the keys to look up in ``test_des``.
        flann: Trained matcher whose image order matches ``list(train_des)``.
        test_des: dict mapping test filename to descriptor array (or ``None``).
        train_des: dict mapping training filename to descriptors; only its key
            order is used, to translate ``imgIdx`` back to a filename.
        ratio: A match is kept when ``m.distance < ratio * n.distance``.
        top_n: Maximum number of candidate training images kept per test image.

    Returns:
        dict mapping each test filename to a list of
        ``(train_filename, vote_count)`` tuples, best first. The list is empty
        when the image has no usable descriptors, matching fails, or no match
        passes the ratio test.
    """
    sift_result = {}
    train_image_list = list(train_des)
    for filename in tqdm(sorted(os.listdir(path))):
        if filename.startswith("."):
            continue

        query_des = test_des.get(filename)
        if query_des is None or len(query_des) == 0:
            # No descriptors available for this image: nothing to match.
            sift_result[filename] = []
            continue

        try:
            matches = flann.knnMatch(query_des, k=2)
        except cv2.error:
            # Only OpenCV matching failures are treated as "no match"; other
            # errors (including KeyboardInterrupt) are allowed to propagate.
            sift_result[filename] = []
            continue

        matches_counter = Counter()
        for pair in matches:
            # knnMatch may return fewer than k neighbours for a descriptor;
            # the ratio test needs both the best and the second-best match.
            if len(pair) < 2:
                continue
            m, n = pair
            if m.distance < ratio * n.distance:
                matches_counter[train_image_list[m.imgIdx]] += 1

        sift_result[filename] = matches_counter.most_common(top_n)

    return sift_result
        

In [13]:
# Match every test image against the trained index of training descriptors.
sift_result = sift_compare(
    path="test",
    flann=FLANN,
    test_des=test_des,
    train_des=train_des,
)

100%|███████████████████████████████████████| 1201/1201 [13:40<00:00,  1.46it/s]


In [None]:
# Visual sanity check: show each test image next to its top-3 matched
# training images. Debugging aid only; not needed to produce the output CSV.
for img_name, predict in sift_result.items():
    print(img_name)
    # One row of four panels: the query image followed by up to three matches.
    fig, axes = plt.subplots(1, 4, figsize=(10, 10))
    with Image.open(os.path.join("test", img_name)) as img:
        axes[0].imshow(img)
    for ax, (predict_img, _) in zip(axes[1:], predict):
        with Image.open(os.path.join("train", predict_img)) as img:
            ax.imshow(img)
    # Hide axes on every panel, including ones left empty when an image has
    # fewer than three matches.
    for ax in axes:
        ax.axis("off")
    plt.show()
    # Release the figure explicitly; this loop can create over a thousand.
    plt.close(fig)

In [14]:
# Build one [id, x, y] row per test image. When SIFT found a matching training
# image, reuse that image's coordinates from train.csv; otherwise fall back to
# the CNN prediction for the test image.
cnn_result = pd.read_csv("output_cnn.csv", index_col="id")
train_coord_csv = pd.read_csv("train.csv", index_col="id")
all_test_coords = []
for img_name, corresponding_image_list in sift_result.items():
    # Ids are filenames without their extension; splitext works for any
    # extension length, unlike slicing off a fixed four characters.
    test_id = os.path.splitext(img_name)[0]
    if corresponding_image_list:
        # Entries are (train_filename, vote_count), best match first.
        best_image = corresponding_image_list[0][0]
        # Unpacking assumes each CSV has exactly two value columns — TODO confirm.
        x, y = train_coord_csv.loc[os.path.splitext(best_image)[0]]
    else:
        x, y = cnn_result.loc[test_id]
    all_test_coords.append([test_id, x, y])
        

In [15]:
# Write the collected rows to the output CSV without the DataFrame index.
output_columns = ["id", "x", "y"]
df = pd.DataFrame(data=all_test_coords, columns=output_columns)
df.to_csv("output_sift.csv", index=False)