In [87]:
import cv2
import numpy as np
from skimage import io
from skimage.transform import rotate, resize, rescale
from skimage.util import img_as_bool, img_as_float, pad, invert, img_as_ubyte
from skimage.morphology import square, closing, dilation
from skimage.color import gray2rgb, rgb2gray
import time
from multiprocessing.dummy import Pool as ThreadPool 
import sys
import re



def crop_to_object(img):
    img = img_as_ubyte(closing(img, square(5)))
    ret,thresh = cv2.threshold(img,20,255,0)
    _, lc, _ = cv2.findContours(thresh,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
    assert len(lc) == 1
    rect = cv2.minAreaRect(lc[0])
    img = rotate(img, rect[2], resize = True)
    img = img_as_ubyte(img)
    ret,thresh = cv2.threshold(img,20,255,0)
    thresh = img_as_ubyte(closing(thresh, square(5)))
    _, lc, _ = cv2.findContours(thresh,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
    assert len(lc) == 1
    rect = cv2.minAreaRect(lc[0])
    box = cv2.boxPoints(rect)
    box = np.int0(box)
    min_y, max_y = min(box[:,1]), max(box[:,1])
    min_x, max_x = min(box[:,0]), max(box[:,0])
    return img[min_y:max_y,min_x:max_x]

def get_contours(img):
    
    img = pad(img,(5,5),'constant', constant_values=(0, 0))
    img = img_as_ubyte(closing(img, square(7)))
    ret,thresh = cv2.threshold(img,10,255,0)
    _, img_c, _ = cv2.findContours(thresh,cv2.RETR_TREE,cv2.CHAIN_APPROX_SIMPLE)
    if len(img_c) > 1:
        print("more than one contour")
        if len(img_c) == 2:
            if len(img_c[0]) > len(img_c[1]):
                return img_c[0]
            else:
                return img_c[1]
        assert len(img_c) <= 2
        #img_d = np.copy(thresh)    
        #img_d = gray2rgb(img_d)
        #cv2.drawContours(img_d, img_c, 0, (0,255,0), 3)
        #io.imshow(img_d)
    return img_c[0]


def rotate_basis(img): #return ok ud
    if img.shape[0]> img.shape[1]:
        img = np.rot90(img).copy()
    img = resize(img,(100,150))
    h,w = img.shape
    upper_z = np.count_nonzero(img[0:int(h/2)])
    bottom_z = np.count_nonzero(img[int(h/2):h])
    if bottom_z < upper_z:
        return np.rot90(img, 2).copy(), img.copy()
    else:
        return img.copy(), np.rot90(img,2).copy()

    
def calc_shape_context_distance(img_1, img_2):
    sd = cv2.createShapeContextDistanceExtractor()
    try:
        d2 = sd.computeDistance(img_1,img_2)
        return d2
    except Exception as e:
        print(e)
        return 100


    
    
def get_img_contours(img_path):
    if img_path not in imgs_rotations:
        img = io.imread(img_path)
        img = crop_to_object(img)
        ok, ud = rotate_basis(img)
        imgs_rotations[img_path] = {"ok":get_contours(ok), "ud":get_contours(ud)}
    return imgs_rotations[img_path]
    
def score_pair(img1_path, img2_path):
    if (img2_path,img1_path) in pair_scores:
        return pair_scores[(img2_path,img1_path)]
    start = time.time()
    l = get_img_contours(img1_path)["ok"]
    r = get_img_contours(img2_path)["ud"]
    score = calc_shape_context_distance(l,r) * -1
    pair_scores[(img1_path,img2_path)] = score
    return score


def score_similarity(pair):
    return (pair,score_pair(pair[0],pair[1]))

def rank_pairs(pairs):
    scores = pool.map(score_similarity, pairs)
    return sorted(scores, key=lambda x: -x[1])[0:5]

pair_scores = {}
imgs_rotations = {}
imgs_objects = {}
pool = ThreadPool(8) 

path = "data/set0/"#sys.argv[1]
N = 6 #sys.argv[2]
img_paths = list(map(lambda x: path + "/" + str(x) + ".png", np.arange(0,N)))
for img_path in img_paths:
    candidates = set(img_paths)
    candidates.remove(img_path)
    possible_pairs = list(itertools.product([img_path],candidates))
    rank = np.array(rank_pairs(possible_pairs))
    rank_f = list(map(lambda x: re.search('(\d+).png',x[1]).group(1), rank[:,0]))
    print(' '.join(rank_f))




  .format(dtypeobj_in, dtypeobj_out))
  warn("The default mode, 'constant', will be changed to 'reflect' in "


5 1 3 4 2
2 3 4 0 5
1 4 3 0 5
4 1 2 5 0
3 1 2 5 0
0 3 4 1 2


In [None]:
img = io.imread("data/set0/3.png")
img = crop_to_object(img)
ok, ud = rotate_basis(img)
print(ok.shape)
io.imshow(ok)

In [None]:
io.imshow(ud)

In [88]:
from skimage import io, img_as_float
import itertools
import numpy as np
import random
from skimage.morphology import square,  dilation
import cv2
from multiprocessing.dummy import Pool as ThreadPool 



    
def score_similarity(pair):
    return (pair,score_pair(pair[0],pair[1]))

def load_correct_pairs(set_path):
    result = []
    file = open(set_path + "correct.txt", mode="r")
    lines = file.readlines()
    for idx, line in enumerate(lines):
        result.append((set_path+ str(idx) + ".png", set_path+ line.strip() + ".png"))
    return np.array(result)


def rank_pairs(pairs):
    scores = pool.map(score_similarity, pairs)
    return sorted(scores, key=lambda x: -x[1])[0:5]


def score_rank(rank, correct_pair):
    for idx, elem in enumerate(rank):
        if np.all(elem[0] == correct_pair):
            return 1/(idx+1)
    return 0


def evaluate_set(set_path):
    print("########### {} ###########".format(set_path))
    correct_pairs = load_correct_pairs(set_path)
    score_sum = 0
    for idx, correct_pair in enumerate(correct_pairs):
        queried = correct_pair[0]
        candidates = set(correct_pairs[:,0])
        candidates.remove(queried)
        possible_pairs = list(itertools.product([queried],candidates))
        rank = rank_pairs(possible_pairs)
        score_for_rank = score_rank(rank,correct_pair)
        score_sum += score_for_rank
        #print("{}/{} score = {} matched = {} correct = {}".format(idx+1, len(correct_pairs), score_for_rank,rank[0],  correct_pair))    
    print("Total scores for set = {}/{}".format(score_sum, len(correct_pairs)))
        
sets_to_evaluate = ["data/set0/","data/set1/","data/set2/","data/set3/","data/set4/","data/set5/","data/set6/","data/set7/","data/set8/"]
#sets_to_evaluate = ["data/set0/"]
pool = ThreadPool(8) 
result = ""
for set_to_eval in sets_to_evaluate:
    start = time.time()
    evaluate_set(set_to_eval)
    print("{} {}[s]".format(set_to_eval, time.time() - start))







########### data/set0/ ###########


  .format(dtypeobj_in, dtypeobj_out))
  warn("The default mode, 'constant', will be changed to 'reflect' in "


Total scores for set = 6.0/6
data/set0/ 0.36753034591674805[s]
########### data/set1/ ###########
Total scores for set = 20.0/20
data/set1/ 2.313107490539551[s]
########### data/set2/ ###########


KeyboardInterrupt: 