In [6]:
import cv2
import numpy as np
from sklearn.cluster import DBSCAN, OPTICS
from sklearn.cluster import KMeans


# Minimum size the largest DBSCAN cluster of good matches must exceed
# (strictly greater than) before a homography is estimated.
# NOTE(review): a later cell reassigns this name; functions in this cell read
# the module-level value at call time, so the last executed assignment wins.
MIN_MATCH_COUNT = 5

# Value passed as `algorithm` in the FLANN index params.
# Presumably OpenCV's FLANN enum value for the KD-tree index — confirm against the OpenCV docs.
FLANN_INDEX_KDTREE = 1

# Ratio-test factor: a match is kept only when its best distance is below
# `threshold` times the second-best distance.
threshold = 0.9


#------------------------------------------------------------#
def close_1(img, points, h, w):
    """
    Check that the detected region keeps roughly the query's proportions.

    The bounding box of the four transformed corners is clipped to the image,
    and its height/width are compared with the query's height/width: the
    region is accepted when both scale factors agree within a relative
    tolerance of 0.23.

    Parameters
    ----------
    img : np.ndarray
        Image the corners were projected into; only its first two shape
        entries (rows, columns) are used.
    points : sequence of 8 ints
        Flattened corners ``[x0, y0, x1, y1, x2, y2, x3, y3]``.
    h, w : int
        Height and width of the query image.

    Returns
    -------
    bool
        True when the clipped region is non-empty and its vertical and
        horizontal scale factors are close to each other.
    """
    # Even positions are x (column) values, odd positions are y (row) values.
    cols = points[0:8:2]
    rows = points[1:8:2]

    img_h, img_w = img.shape[:2]

    # Clip to the image instead of slicing: a negative coordinate used as a
    # slice start would wrap around to the far edge and give an empty crop.
    row_min = min(max(min(rows), 0), img_h)
    row_max = min(max(max(rows), 0), img_h)
    col_min = min(max(min(cols), 0), img_w)
    col_max = min(max(max(cols), 0), img_w)

    h_img2 = row_max - row_min
    w_img2 = col_max - col_min

    # A zero-area region would otherwise pass, because 0/h is "close" to 0/w.
    if h_img2 <= 0 or w_img2 <= 0:
        return False

    return bool(np.isclose(h_img2 / h, w_img2 / w, 0.23))
#------------------------------------------------------------#


#------------------------------------------------------------#
def find_and_polly(img, key_points_query, description_query, query, sift, flann):
    """
    Locate one occurrence of ``query`` inside ``img`` and paint it over.

    Steps: detect SIFT features on ``img``, match them against the query
    descriptors (k=2 nearest neighbours plus ratio test with the module-level
    ``threshold``), cluster the matched image points with DBSCAN, keep the
    largest real cluster, estimate a homography from it and project the query
    corners into ``img``. When the projected region passes ``close_1`` it is
    filled with 255 so that a subsequent call can find the next occurrence.

    Parameters
    ----------
    img : np.ndarray
        Image to search. It is passed to ``cv2.fillPoly`` on success, which
        draws into the array it is given, so the caller's array is modified.
    key_points_query, description_query
        Key points and descriptors of the query, as returned by
        ``sift.detectAndCompute``.
    query : np.ndarray
        Query image; only ``query.shape[0]`` and ``query.shape[1]`` are used.
    sift
        Feature detector exposing ``detectAndCompute``.
    flann
        Matcher exposing ``knnMatch``.

    Returns
    -------
    list
        ``[found, image, bbox]``. On success ``bbox`` is
        ``[x0, y0, x2, y2]``: the projected positions of the query's
        ``(0, 0)`` and ``(w-1, h-1)`` corners. On failure ``found`` is False
        and ``bbox`` is ``[-1, -1, -1, -1]``.
    """
    # finding key points and descriptors on img
    key_points_img, descriptors_img = sift.detectAndCompute(img, None)

    # knnMatch with k=2 needs descriptors on both sides and at least two
    # candidates in the image; otherwise there is nothing to match
    if description_query is None or descriptors_img is None or len(descriptors_img) < 2:
        return [False, img, [-1, -1, -1, -1]]

    matches = flann.knnMatch(description_query, descriptors_img, k=2)

    # ratio test; pairs with fewer than two neighbours are skipped
    good = [
        pair[0] for pair in matches
        if len(pair) == 2 and pair[0].distance < threshold * pair[1].distance
    ]
    if not good:
        # DBSCAN cannot be fitted on an empty point list
        return [False, img, [-1, -1, -1, -1]]

    # image-side coordinates of the surviving matches
    dst_pt = [key_points_img[m.trainIdx].pt for m in good]

    labels = DBSCAN(eps=100, min_samples=3).fit_predict(dst_pt)

    # count members per cluster, ignoring label -1 (DBSCAN's noise bucket):
    # scattered outliers are not a cluster and must not shadow a real one
    cluster_sizes = {}
    for label in labels:
        if label == -1:
            continue
        cluster_sizes[label] = cluster_sizes.get(label, 0) + 1
    if not cluster_sizes:
        return [False, img, [-1, -1, -1, -1]]

    max_el = max(cluster_sizes, key=cluster_sizes.get)
    d_match_array = [m for m, label in zip(good, labels) if label == max_el]

    # not enough support for a reliable homography
    if len(d_match_array) <= MIN_MATCH_COUNT:
        return [False, img, [-1, -1, -1, -1]]

    src_pts = np.float32([key_points_query[m.queryIdx].pt for m in d_match_array]).reshape(-1, 1, 2)
    dst_pts = np.float32([key_points_img[m.trainIdx].pt for m in d_match_array]).reshape(-1, 1, 2)

    M, _mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
    if M is None:
        # the estimator could not find a homography for these points
        return [False, img, [-1, -1, -1, -1]]

    # project the query corners (TL, BL, BR, TR) into the image
    h, w = query.shape[0], query.shape[1]
    pts = np.float32([[0, 0], [0, h - 1], [w - 1, h - 1], [w - 1, 0]]).reshape(-1, 1, 2)
    dst = cv2.perspectiveTransform(pts, M)

    pts_transformed = np.int32(dst).reshape(8).tolist()
    if not close_1(img, pts_transformed, h, w):
        return [False, img, [-1, -1, -1, -1]]

    # fillPoly to fill the already found area
    img2 = cv2.fillPoly(img, [np.int32(dst)], 255)
    bbox = pts_transformed[:2] + pts_transformed[4:6]

    return [True, img2, bbox]
#----------------------------------------------------------------------------#


#------------------------------------------------------------#
def predict_image_1(img: np.ndarray, query: np.ndarray) -> list:
    """
    Collect the bounding boxes of every occurrence of ``query`` found in ``img``.

    Both inputs are converted to grayscale with ``cv2.COLOR_RGB2GRAY``. The
    query's SIFT features and a FLANN matcher are built once, then
    ``find_and_polly`` is called repeatedly on the progressively painted-over
    image until it reports no further match.

    Returns the list of boxes in the order they were found (empty when the
    first search already fails).
    """
    gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    gray_query = cv2.cvtColor(query, cv2.COLOR_RGB2GRAY)

    # SIFT detector, shared by the query and every search pass
    detector = cv2.SIFT_create()
    query_kp, query_desc = detector.detectAndCompute(gray_query, None)

    # FLANN matcher: index parameters first, search parameters second
    matcher = cv2.FlannBasedMatcher(
        {"algorithm": FLANN_INDEX_KDTREE, "trees": 5},
        {"checks": 40},
    )

    found_boxes = []
    current = gray_img
    while True:
        matched, current, bbox = find_and_polly(
            current, query_kp, query_desc, gray_query, detector, matcher
        )
        if not matched:
            return found_boxes
        found_boxes.append(bbox)
#------------------------------------------------------------#

In [19]:
import cv2
import numpy as np
from sklearn.cluster import DBSCAN
from sklearn.cluster import KMeans
from collections import Counter

# Minimum size the largest DBSCAN cluster of good matches must exceed
# (strictly greater than) before a homography is estimated.
# NOTE(review): this reassigns names that an earlier cell also defines, so any
# earlier function reading these globals sees the values below once this cell has run.
MIN_MATCH_COUNT = 7

# Value passed as `algorithm` in the FLANN index params
# (presumably OpenCV's KD-tree index id — confirm against the OpenCV docs).
FLANN_INDEX_KDTREE = 1
# Ratio-test factor: keep a match when best distance < threshold * second-best distance.
threshold = 0.9


def in_rect(points, h, w):
    """
    Check that a projected quadrilateral keeps roughly the query's proportions.

    Parameters
    ----------
    points : sequence of 8 ints
        Flattened corners ``[x0, y0, x1, y1, x2, y2, x3, y3]``.
    h, w : int
        Height and width of the query image.

    Returns
    -------
    bool
        True when the bounding box of the corners has non-zero area and its
        vertical scale (height / h) is close to its horizontal scale
        (width / w) within a relative tolerance of 0.2.
    """
    # odd positions are y (row) values, even positions are x (column) values
    rows, cols = points[1:8:2], points[:7:2]

    h_2 = max(rows) - min(rows)
    w_2 = max(cols) - min(cols)

    # A collapsed quadrilateral would otherwise be accepted, since 0/h is
    # "close" to 0/w; rejecting it keeps a degenerate homography from being
    # reported as a detection.
    if h_2 == 0 or w_2 == 0:
        return False

    return bool(np.isclose(h_2 / h, w_2 / w, 0.2))


def search(img, query):
    """
    Find one occurrence of ``query`` in ``img`` and black it out.

    SIFT features of both images are matched with FLANN (k=2 plus ratio test
    using the module-level ``threshold``), the matched image points are
    clustered with DBSCAN, and a homography is estimated from the largest
    real cluster when it has more than ``MIN_MATCH_COUNT`` members. The
    query's corners are projected into ``img``; if the projected shape passes
    ``in_rect`` it is filled with 0 so the next call can find another
    occurrence.

    Parameters
    ----------
    img : np.ndarray
        Image to search. It is passed to ``cv2.fillPoly`` on success, which
        draws into the array it is given, so the caller's array is modified.
    query : np.ndarray
        Query image; must be 2-D because ``query.shape`` is unpacked into
        ``(h, w)``.

    Returns
    -------
    tuple
        ``(found, image, box)``. On success ``box`` is ``[x0, y0, x2, y2]``:
        the projected positions of the query's ``(0, 0)`` and ``(w-1, h-1)``
        corners. On failure ``found`` is False and ``box`` is None.
    """
    sift = cv2.SIFT_create()
    key_points_img, descriptors_img = sift.detectAndCompute(img, None)
    key_points_query, description_query = sift.detectAndCompute(query, None)

    # knnMatch with k=2 needs descriptors on both sides and at least two
    # candidates in the image; otherwise there is nothing to match
    if description_query is None or descriptors_img is None or len(descriptors_img) < 2:
        return False, img, None

    index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
    search_params = dict(checks=40)

    flann = cv2.FlannBasedMatcher(index_params, search_params)
    matches = flann.knnMatch(description_query, descriptors_img, k=2)

    # ratio test; pairs with fewer than two neighbours are skipped
    good = [
        pair[0] for pair in matches
        if len(pair) == 2 and pair[0].distance < threshold * pair[1].distance
    ]
    if not good:
        # DBSCAN cannot be fitted on an empty point list
        return False, img, None

    # image-side coordinates of the surviving matches
    dist = [key_points_img[m.trainIdx].pt for m in good]
    labels = DBSCAN(eps=110, min_samples=3).fit_predict(dist)

    # label -1 is DBSCAN's noise bucket, not a cluster: leave it out so
    # scattered outliers cannot shadow a real cluster
    labels_cnt = Counter(label for label in labels if label != -1)
    if not labels_cnt:
        return False, img, None

    most_frequent = max(labels_cnt, key=labels_cnt.get)
    cluster = [m for m, label in zip(good, labels) if label == most_frequent]
    if len(cluster) <= MIN_MATCH_COUNT:
        return False, img, None

    src_pts = np.float32([key_points_query[m.queryIdx].pt for m in cluster]).reshape(-1, 1, 2)
    dst_pts = np.float32([key_points_img[m.trainIdx].pt for m in cluster]).reshape(-1, 1, 2)

    M, _mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
    if M is None:
        # the estimator could not find a homography for these points
        return False, img, None

    # project the query corners (TL, BL, BR, TR) into the image
    h, w = query.shape
    pts = np.float32([[0, 0], [0, h - 1], [w - 1, h - 1], [w - 1, 0]]).reshape(-1, 1, 2)
    dst = cv2.perspectiveTransform(pts, M)

    pts_transformed = np.int32(dst).reshape(8).tolist()
    if not in_rect(pts_transformed, h, w):
        return False, img, None

    # paint over the detected area so it is not matched again
    img2 = cv2.fillPoly(img, [np.int32(dst)], 0)
    box = pts_transformed[:2] + pts_transformed[4:6]

    return True, img2, box


def predict_image(img: np.ndarray, query: np.ndarray) -> list:
    """
    Return the boxes of all occurrences of ``query`` that ``search`` finds in ``img``.

    Both inputs are converted to grayscale with ``cv2.COLOR_RGB2GRAY``; the
    search then runs on a copy of the gray image, repeating until ``search``
    reports no further match. Boxes are returned in discovery order.
    """
    gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    gray_query = cv2.cvtColor(query, cv2.COLOR_RGB2GRAY)

    # work on a copy: each successful search paints over the found area
    canvas = gray_img.copy()

    found_boxes = []
    while True:
        matched, canvas, box = search(canvas, gray_query)
        if box:
            found_boxes.append(box)
        if not matched:
            break

    return found_boxes


In [20]:
# Load one scene image and one template, then run the detector on the pair.
# NOTE(review): cv2.imread is documented to return BGR data (and None when the
# file cannot be read), while predict_image converts with COLOR_RGB2GRAY —
# confirm the intended channel order.
# NOTE(review): the saved output below shows float 4-tuples, whereas
# predict_image as defined above returns lists of ints — the output looks
# stale; re-run after restarting the kernel.
img_1 = cv2.imread('train/train_0.jpg')
img_2 = cv2.imread('train/template_0_0.jpg')
predict_image(img_1, img_2)

[(0.1887796613905165,
  0.46721563339233396,
  0.08911378648546008,
  0.10624017715454101),
 (0.18798156314425998,
  0.6155833244323731,
  0.07697185940212674,
  0.10519037246704102),
 (0.15862456427680122,
  0.7457643508911133,
  0.09892429775661893,
  0.11660776138305665),
 (0.1805374993218316,
  0.3056957483291626,
  0.08909077114529079,
  0.10990607738494873),
 (0.2789219538370768,
  0.4694417953491211,
  0.08449501461452907,
  0.10391902923583984)]

In [21]:
# Run the earlier variant on the same image/template pair for comparison.
# NOTE(review): going by the execution counts, the later constants cell had
# already run, so predict_image_1 presumably used MIN_MATCH_COUNT = 7 here
# rather than the 5 defined next to it — verify after Restart & Run All.
predict_image_1(img_1, img_2)

[[135, 598, 200, 734],
 [134, 786, 190, 919],
 [129, 391, 194, 535],
 [116, 958, 187, 1112],
 [200, 600, 261, 735],
 [190, 952, 259, 1107],
 [117, 1159, 182, 1303],
 [185, 154, 251, 319],
 [196, 394, 256, 537],
 [194, 1158, 262, 1291]]