In [1]:
import os
import zipfile
import numpy as np
from imageio import imread
from scipy.spatial.distance import cdist


class Test:
    '''
    One-shot classification demo driven by a zip archive of runs.

    Every archive member that sits exactly one directory deep (for example
    ``run01/<label file>``) is read as a label file. Each non-blank line of a
    label file holds two whitespace-separated archive paths: a test image
    followed by the training image that is its correct match.

    Images are read from the archive on demand, so the archive stays open for
    the lifetime of the object. Call ``close()`` or use the object in a
    ``with`` statement to release the underlying file handle.
    '''

    def __init__(self, zip_name):
        '''
        Open the archive and index its runs into ``self.cat_dict``.

        Input
        ====
        zip_name : anything accepted by zipfile.ZipFile (path or file object)
        '''
        self.__Z = zipfile.ZipFile(zip_name)
        try:
            self.cat_dict = self.get_cat()
        except BaseException:
            # do not leak the open handle when the label files are malformed
            self.__Z.close()
            raise

    def close(self):
        '''
        Close the underlying archive; images can no longer be loaded afterwards.
        '''
        self.__Z.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def get_cat(self):
        '''
        Parse every label file of the archive.

        Output
        ====
        cat_dict : {run name: {'test': array of paths, 'train': array of paths}}
                   where test[i] is correctly matched by train[i]

        Raises ValueError when a non-blank line does not hold exactly two paths.
        '''
        cat_dict = {}
        # namelist() is the public counterpart of the NameToInfo mapping
        for name in self.__Z.namelist():
            # label files live exactly one directory deep; skip everything else
            if name.count('/') != 1 or name.endswith('/'):
                continue
            run_name, _ = name.split('/')
            pair_names = [
                line.decode().split()
                for line in self.__Z.read(name).splitlines()
                if line.strip()  # tolerate blank lines (e.g. a trailing one)
            ]
            if not pair_names:
                # nothing to classify in an empty label file
                continue
            if any(len(pair) != 2 for pair in pair_names):
                raise ValueError(
                    name + ': every line must hold a test path and a train path')
            test_files, train_files = np.asanyarray(pair_names).T
            cat_dict[run_name] = {'test': test_files, 'train': train_files}
        return cat_dict

    def LoadImgAsPoints(self, fn):
        '''
        Load image file and return coordinates of 'inked' pixels in the binary image

        Input
        ====
        fn : archive path of the image

        Output
        ====
        D : [n x 2] float array of (row, col) coordinates of the pixels whose
            value is zero, shifted so that their mean is the origin
        '''
        I = imread(self.__Z.read(fn))
        # ink is stored as zero, so invert the boolean image before locating it
        I = np.logical_not(np.array(I, dtype=bool))
        (row, col) = I.nonzero()
        D = np.array([row, col]).T.astype(float)
        if len(D):
            # centering an empty point set would only emit a RuntimeWarning
            D = D - D.mean(axis=0)
        return D

    def ModHausdorffDistance(self, itemA, itemB):
        '''
        M.-P. Dubuisson, A. K. Jain (1994). A modified hausdorff distance for object matching.
        International Conference on Pattern Recognition, pp. 566-568.
        Compute the similarity between two images (smaller means more similar)

        Input
        ====
        itemA : [n x 2] coordinates of "inked" pixels
        itemB : [m x 2] coordinates of "inked" pixels

        Output
        ====
        the larger of the two mean nearest-neighbour distances (A to B, B to A)

        Raises ValueError when either item holds no points.
        '''
        D = cdist(itemA, itemB)
        if D.size == 0:
            # min() over an empty axis would fail with a far less helpful message
            raise ValueError(
                'ModHausdorffDistance needs at least one point in each item')
        mean_A = np.mean(D.min(axis=1))
        mean_B = np.mean(D.min(axis=0))
        return max(mean_A, mean_B)

    def classification_run(self, cat, ftype='cost'):
        '''
        Compute error rate for one run of one-shot classification

        Input
        ===============
        cat : name of the run (a key of self.cat_dict)
        ftype  : 'cost' if small values mean more similar, or 'score' if large values are more similar

        Output
        =============
        perror : percent errors (0 to 100% error)
        '''
        assert ftype in ('cost', 'score'), "ftype must be 'cost' or 'score'"
        train_files = np.asarray(self.cat_dict[cat]['train'])
        test_files = np.asarray(self.cat_dict[cat]['test'])
        ntrain = len(train_files)
        ntest = len(test_files)

        # load the images (and, if needed, extract features)
        train_items = [self.LoadImgAsPoints(f) for f in train_files]
        test_items = [self.LoadImgAsPoints(f) for f in test_files]

        # compute cost matrix: one row per test image, one column per train image
        costM = np.zeros((ntest, ntrain), float)
        for i, test_item in enumerate(test_items):
            for c, train_item in enumerate(train_items):
                costM[i, c] = self.ModHausdorffDistance(test_item, train_item)
        if ftype == 'cost':
            YHAT = np.argmin(costM, axis=1)
        else:
            YHAT = np.argmax(costM, axis=1)

        # compute the error rate; the correct answer for test image i is
        # train image i, and the comparison is made by file name
        correct = np.count_nonzero(train_files[YHAT] == train_files[:ntest])
        pcorrect = 100 * correct / ntest
        perror = 100 - pcorrect
        return perror

    def test(self, ftype='cost'):
        '''
        Run every classification run of the archive and print the error of each
        run followed by the average error.

        M.-P. Dubuisson, A. K. Jain (1994). A modified hausdorff distance for object matching.
        International Conference on Pattern Recognition, pp. 566-568.

        Input
        ====
        ftype : forwarded to classification_run ('cost' or 'score')

        Results
        =====
        38.8 percent errors
        '''
        print('One-shot classification demo with Modified Hausdorff Distance')
        perror = []
        for cat in self.cat_dict:
            run_error = self.classification_run(cat, ftype)
            perror.append(run_error)

            print(cat + " (error " + str(run_error) + "%)")
        total = np.mean(perror)
        print(" average error " + str(total) + "%")

In [2]:
# Archive holding the one-shot classification runs (path relative to the working directory).
zip_name = './python/one-shot-classification/all_runs.zip'
# 'cost': smaller distance values mean more similar images.
ftype = 'cost'

# Open the archive and index its runs, then print the error of every run and the average.
T = Test(zip_name)
T.test(ftype)

One-shot classification demo with Modified Hausdorff Distance
run01 (error 45.0%)
run02 (error 35.0%)
run03 (error 40.0%)
run04 (error 25.0%)
run05 (error 30.0%)
run06 (error 15.0%)
run07 (error 60.0%)
run08 (error 35.0%)
run09 (error 40.0%)
run10 (error 55.0%)
run11 (error 15.0%)
run12 (error 70.0%)
run13 (error 65.0%)
run14 (error 35.0%)
run15 (error 15.0%)
run16 (error 25.0%)
run17 (error 30.0%)
run18 (error 40.0%)
run19 (error 70.0%)
run20 (error 30.0%)
 average error 38.75%
