### Symmetric searchable encryption for reverse image search (index only)

#### Bloom filter implementation

In [None]:
import hmac
from base64 import urlsafe_b64encode
from hashlib import sha512
from math import floor, log
from bitarray import bitarray

def expand(func):
    """Encode a sequence of non-negative counts as a list of byte tokens.

    For the value ``x`` at 1-based position ``p``, the tokens ``(p-1)``,
    ``(p-2)``, ..., ``(p-x)`` are emitted, so a position contributes exactly
    ``x`` tokens (none when ``x`` is zero).

    Args:
        func: Iterable of integer counts.

    Returns:
        List of ASCII-encoded ``bytes`` tokens, ordered by position and then
        by level.
    """
    tokens = []
    for position, height in enumerate(func, start=1):
        for level in range(1, height + 1):
            tokens.append('({}-{})'.format(position, level).encode('ascii'))
    return tokens

def comparesets(a, b, m, k, pin, repeats=0):
    """Return the Bloom-filter estimate of the Jaccard distance of two sets.

    Args:
        a: Iterable of ``bytes`` items forming the first set.
        b: Iterable of ``bytes`` items forming the second set.
        m: Number of bits in each Bloom filter.
        k: Number of bits set per item in each Bloom filter.
        pin: Key passed to both Bloom filters.
        repeats: Accepted for interface compatibility; not used here.

    Returns:
        The value of ``jaccard_distance`` between the two filters.
    """
    filters = []
    for items in (a, b):
        bloom = BloomFilter(m, k, pin)
        bloom.addall(items)
        filters.append(bloom)
    first, second = filters
    return first.jaccard_distance(second)

def comparefunctions(f, g, m, k, pin):
    """Estimate how many ``expand`` tokens differ between ``f`` and ``g``.

    Both sequences are expanded into tokens and inserted into separate Bloom
    filters built with the same ``m``, ``k`` and ``pin``. The result is
    ``2 * |F u G| - |F| - |G|`` computed from the filters' cardinality
    estimates, i.e. an estimate of the size of the symmetric difference.

    Args:
        f: First sequence of integer counts.
        g: Second sequence of integer counts.
        m: Number of bits in each Bloom filter.
        k: Number of bits set per item in each Bloom filter.
        pin: Key passed to both Bloom filters.

    Returns:
        Integer estimate of the symmetric-difference size.
    """
    bloom_f = BloomFilter(m, k, pin)
    bloom_g = BloomFilter(m, k, pin)
    bloom_f.addall(expand(f))
    bloom_g.addall(expand(g))
    union_size = bloom_f.union_encoded_estimate(bloom_g)
    return 2 * union_size - bloom_f.encoded_estimate() - bloom_g.encoded_estimate()

class BloomFilter(object):
    """Partitioned Bloom filter whose bit positions depend on a secret key.

    The ``m``-bit array is treated as ``k`` consecutive slices of ``m // k``
    bits. Every item sets exactly one bit in each slice; the position is
    derived from an unkeyed SHA-512 digest (``h``) combined with an
    HMAC-SHA-512 digest keyed with ``pin`` (``g``).
    """

    def __init__(self, m=1024, k=1, pin=b'0'):
        """Create an empty filter.

        Args:
            m: Total number of bits.
            k: Number of slices, which is also the number of bits set per
                item.
            pin: ``bytes`` key used by the HMAC in ``g``.
        """
        self.m = m
        self.k = k
        self.pin = pin
        # When k does not divide m, the trailing m - k * slice bits are never
        # addressed by additem / __contains__.
        self.slice = self.m // self.k
        self.a = bitarray(m)
        self.a.setall(False)

    def __getitem__(self, i):
        """Return bit ``i`` of the underlying bit array."""
        return self.a[i]

    def __setitem__(self, i, v):
        """Set bit ``i`` of the underlying bit array to ``v``."""
        self.a[i] = v

    def __str__(self):
        """Return the bit array as a string of '0' and '1' characters."""
        return self.a.to01()

    def _check_compatible(self, other):
        """Raise if ``other`` does not share this filter's ``m`` and ``k``."""
        if self.m != other.m or self.k != other.k:
            raise Exception('Operation error',
                            'Length of bloom filters does not match.')

    def _with_bits(self, bits):
        """Return a new filter with this filter's parameters and ``bits``."""
        combined = BloomFilter(self.m, self.k, self.pin)
        combined.a = bits
        return combined

    def _indexes(self, item):
        """Yield the ``k`` bit positions (one per slice) for ``item``."""
        base = self.h(item)
        added = self.g(item)
        for i in range(1, self.k + 1):
            yield ((base + i * added) % self.slice) + (i - 1) * self.slice

    def __and__(self, other):
        """Return a new filter holding the bitwise AND of both bit arrays.

        Raises:
            Exception: If ``m`` or ``k`` differ between the two filters.
        """
        self._check_compatible(other)
        return self._with_bits(self.a & other.a)

    def __or__(self, other):
        """Return a new filter holding the bitwise OR of both bit arrays.

        Raises:
            Exception: If ``m`` or ``k`` differ between the two filters.
        """
        self._check_compatible(other)
        return self._with_bits(self.a | other.a)

    def __xor__(self, other):
        """Return a new filter holding the bitwise XOR of both bit arrays.

        Raises:
            Exception: If ``m`` or ``k`` differ between the two filters.
        """
        self._check_compatible(other)
        return self._with_bits(self.a ^ other.a)

    def __contains__(self, item):
        """Return True when every bit position of ``item`` is set."""
        return all(self.a[index] for index in self._indexes(item))

    def h(self, x):
        """Return the SHA-512 digest of ``x`` as a big-endian integer."""
        return int.from_bytes(sha512(x).digest(), byteorder='big')

    def g(self, x):
        """Return the HMAC-SHA-512 of ``x`` as a big-endian integer.

        The HMAC is keyed with this filter's own ``pin`` rather than with a
        module-level name, so filters created with different pins map the
        same item to different bit positions.
        """
        return int.from_bytes(hmac.new(self.pin, x, sha512).digest(), byteorder='big')

    def additem(self, item):
        """Insert ``item`` (``bytes``) by setting one bit in each slice."""
        for index in self._indexes(item):
            self.a[index] = True

    def addall(self, items):
        """Insert every element of the iterable ``items``."""
        for item in items:
            self.additem(item)

    def to_base64(self):
        """Return the bit array's bytes encoded as URL-safe base64."""
        return urlsafe_b64encode(self.a.tobytes())

    def ones(self):
        """Return the number of set bits."""
        return self.a.count(True)

    def zeros(self):
        """Return the number of cleared bits."""
        return self.a.count(False)

    def or_ones(self, other):
        """Return the number of set bits in ``self | other``."""
        return (self | other).ones()

    def or_zeros(self, other):
        """Return the number of cleared bits in ``self | other``."""
        return (self | other).zeros()

    def and_ones(self, other):
        """Return the number of set bits in ``self & other``."""
        return (self & other).ones()

    def and_zeros(self, other):
        """Return the number of cleared bits in ``self & other``."""
        return (self & other).zeros()

    def xor_ones(self, other):
        """Return the number of set bits in ``self ^ other``."""
        return (self ^ other).ones()

    def xor_zeros(self, other):
        """Return the number of cleared bits in ``self ^ other``."""
        return (self ^ other).zeros()

    def encoded_estimate(self):
        """Estimate how many distinct items have been inserted.

        Computes ``floor(-(m / k) * ln(1 - ones / m))``.

        Raises:
            ValueError: If every bit is set, because ``log(0)`` is undefined.
        """
        return int(floor(- (self.m / self.k) * log(1 - self.ones() / self.m)))

    def intersection_encoded_estimate(self, other):
        """Estimate the intersection size by inclusion-exclusion."""
        return self.encoded_estimate() + other.encoded_estimate() - self.union_encoded_estimate(other)

    def union_encoded_estimate(self, other):
        """Estimate the number of distinct items in the union of both filters.

        Same formula as ``encoded_estimate``, applied to ``self | other``.

        Raises:
            ValueError: If every bit of the union is set.
        """
        return int(floor(- (self.m / self.k) * log(1 - self.or_ones(other) / self.m)))

    def jaccard(self, other):
        """Return estimated intersection size divided by estimated union size.

        Raises:
            ZeroDivisionError: If the union estimate is zero.
        """
        return self.intersection_encoded_estimate(other) / self.union_encoded_estimate(other)

    def jaccard_distance(self, other):
        """Return ``1.0 - jaccard(other)``.

        Raises:
            ZeroDivisionError: If the union estimate is zero.
        """
        return 1.0 - self.jaccard(other)

    def tanimoto(self, other):
        """Return set bits of ``self & other`` over set bits of ``self | other``.

        Raises:
            ZeroDivisionError: If neither filter has any bit set.
        """
        return self.and_ones(other) / self.or_ones(other)

#### Database implementation

In [None]:
import glob
from PIL import Image
from math import floor
import numpy as np
import itertools
from tqdm.notebook import tqdm

def emd(a, b):
    """Return the sum of absolute running differences between ``a`` and ``b``.

    A running total of ``a[i] - b[i]`` is kept while walking the sequences
    and the absolute value of that total is accumulated at each step. For two
    1-D histograms of equal total mass this is the earth mover's distance.

    Args:
        a: Indexable sequence of numbers; its length drives the iteration.
        b: Indexable sequence of numbers with at least ``len(a)`` elements.

    Returns:
        The accumulated absolute running difference (``0`` for empty ``a``).
    """
    carried = 0
    total = 0
    for position, left in enumerate(a):
        carried = left - b[position] + carried
        total = total + abs(carried)
    return total

def emd_bloom(f, g):
    """Estimate the symmetric-difference size of two Bloom filters.

    Computes ``|F| + |G| - 2 * |F n G|`` from the filters' own cardinality
    estimates.

    Args:
        f: Filter providing ``encoded_estimate`` and
            ``intersection_encoded_estimate``.
        g: Filter providing ``encoded_estimate``.

    Returns:
        The estimated number of items present in exactly one of the filters.
    """
    size_f = f.encoded_estimate()
    size_g = g.encoded_estimate()
    shared = f.intersection_encoded_estimate(g)
    return size_f + size_g - 2 * shared

def dist(a, b):
    """Return the L1 (sum of absolute differences) distance of two sequences.

    The differences are accumulated over every position of ``a``; the
    previous implementation overwrote the accumulator on each step and so
    returned only the last element's difference.

    Args:
        a: Indexable sequence of numbers; its length drives the iteration.
        b: Indexable sequence of numbers with at least ``len(a)`` elements.

    Returns:
        Sum of ``abs(a[i] - b[i])`` over all ``i`` (``0`` for empty ``a``).

    Raises:
        IndexError: If ``b`` is shorter than ``a``.
    """
    return sum(abs(left - b[position]) for position, left in enumerate(a))

#We are converting to grayscale
def gs_histogram(p, bins, integral):
    """Return the grayscale histogram of an image, scaled to ``integral``.

    The image is opened, converted to 8-bit grayscale (mode ``'L'``) and
    flattened. Bin counts are normalised to fractions, multiplied by
    ``integral`` and floored, so the returned values sum to at most
    ``integral``.

    Args:
        p: Path (or file object) accepted by ``PIL.Image.open``.
        bins: Bin specification forwarded to ``np.histogram``.
        integral: Target total mass of the scaled histogram.

    Returns:
        ``np.int16`` array with one entry per bin.
    """
    grayscale = Image.open(p).convert('L')
    pixels = np.array(grayscale).flatten()
    # No explicit range is given, so np.histogram spans the pixel values
    # actually present in this image.
    counts = np.histogram(pixels, bins)[0]
    scaled = np.floor((counts / np.sum(counts)) * integral)
    return scaled.astype(np.int16)

def accumulative_histogram(fl):
    """Return the cumulative sum of the histogram ``fl``.

    Args:
        fl: Array-like of bin values.

    Returns:
        Array whose ``i``-th entry is the sum of ``fl[0]`` through ``fl[i]``.
    """
    running_totals = np.cumsum(fl)
    return running_totals
    
def feature_list_rescale(fl, factor):
    """Scale every feature by ``factor`` and floor the result to an int.

    Args:
        fl: Iterable of numeric features.
        factor: Multiplier applied to each feature.

    Returns:
        List of Python ints, one per input feature.
    """
    rescaled = []
    for value in fl:
        rescaled.append(int(floor(value * factor)))
    return rescaled

def build_db(icons):
    """Map every icon path to its grayscale feature list.

    NOTE(review): ``gs_feature_list`` is not defined in the visible part of
    this notebook; confirm it exists before calling this function.

    Args:
        icons: Iterable of image paths.

    Returns:
        Dict from icon path to the value of ``gs_feature_list(icon)``.
    """
    return {icon: gs_feature_list(icon) for icon in icons}

def build_db_bloom(icons, bins, m, k, integral, pin):
    """Build an index that maps each image path to its Bloom filter.

    For every image the scaled grayscale histogram is computed, accumulated,
    expanded into tokens and inserted into a fresh Bloom filter.

    Args:
        icons: Iterable of image paths.
        bins: Bin specification forwarded to ``gs_histogram``.
        m: Number of bits in each Bloom filter.
        k: Number of bits set per item in each Bloom filter.
        integral: Target total mass of each scaled histogram.
        pin: Key passed to every Bloom filter.

    Returns:
        Dict from image path to its ``BloomFilter``.
    """
    index = {}
    for icon in tqdm(icons):
        cumulative = accumulative_histogram(gs_histogram(icon, bins, integral))
        bloom = BloomFilter(m, k, pin)
        bloom.addall(expand(cumulative))
        index[icon] = bloom
    return index

def search(icon, db, comp, distance):
    """Return the keys of ``db`` whose features lie within ``distance``.

    NOTE(review): ``feature_list`` is not defined in the visible part of this
    notebook; confirm it exists before calling this function.

    Args:
        icon: Query image passed to ``feature_list``.
        db: Dict from key to stored features.
        comp: Callable ``comp(query_features, stored_features)`` returning a
            distance.
        distance: Inclusive upper bound on the accepted distance.

    Returns:
        List of matching keys, in ``db`` iteration order.
    """
    query_features = feature_list(icon)
    hits = []
    for name, stored in db.items():
        if comp(query_features, stored) <= distance:
            hits.append(name)
    return hits

def search_db_bloom(icon, bins, m, k, integral, pin, db, distance):
    """Return the indexed images within ``distance`` of the query image.

    The query image is encoded exactly like the indexed images (scaled
    grayscale histogram, accumulated, expanded, inserted into a Bloom filter)
    and compared against every stored filter with ``emd_bloom``. Each
    distance is computed once and reused for both the threshold test and the
    returned tuple.

    Args:
        icon: Path of the query image.
        bins: Bin specification forwarded to ``gs_histogram``.
        m: Number of bits in the query Bloom filter.
        k: Number of bits set per item in the query Bloom filter.
        integral: Target total mass of the scaled histogram.
        pin: Key passed to the query Bloom filter.
        db: Dict from image path to stored ``BloomFilter``.
        distance: Inclusive upper bound on the accepted distance.

    Returns:
        List of ``(path, distance)`` tuples in ``db`` iteration order.
    """
    cumulative = accumulative_histogram(gs_histogram(icon, bins, integral))
    query_filter = BloomFilter(m, k, pin)
    query_filter.addall(expand(cumulative))

    results = []
    for name, stored in db.items():
        separation = emd_bloom(query_filter, stored)
        if separation <= distance:
            results.append((name, separation))
    return results

Populate database with images from http://www.vision.caltech.edu/Image_Datasets/Caltech101/

In [None]:
# Histogram and Bloom-filter parameters shared by indexing and search.
bins = 16
factor = 100
m = 8000
k = 4
pin = b'1010'

# Index at most db_size images from the dataset directory.
db_size = 1024
all_images = glob.glob('./101_ObjectCategories/*/*.jpg')
icons = all_images[:db_size]

db = build_db_bloom(icons, bins, m, k, factor, pin)

Compute pairwise distances in the clear and using Bloom filters.

In [None]:
# Size the matrices by the number of images actually indexed, so a dataset
# with fewer than db_size images does not leave zero-padded rows and columns
# that would distort the comparison.
n_icons = len(icons)
d_clear = np.zeros((n_icons, n_icons)).astype(np.int16)
d_bloom = np.zeros((n_icons, n_icons)).astype(np.int16)

# Use the same bins / factor as the indexed database so both distance
# matrices describe the same histograms.
clear_histos = [gs_histogram(icon, bins, factor) for icon in tqdm(icons)]

# Fill the upper triangle, then mirror it; the diagonal term is subtracted
# once because adding the transpose counts it twice.
for i in tqdm(range(n_icons)):
    for j in range(i, n_icons):
        d_clear[i, j] = emd(clear_histos[i], clear_histos[j])
d_clear = d_clear + d_clear.T - (d_clear * np.eye(n_icons))

for i in tqdm(range(n_icons)):
    for j in range(i, n_icons):
        d_bloom[i, j] = emd_bloom(db[icons[i]], db[icons[j]])
d_bloom = d_bloom + d_bloom.T - (d_bloom * np.eye(n_icons))

In [None]:
# Mean absolute error between the clear and the Bloom-filter distances.
abs_error = np.abs(d_clear - d_bloom)
mae = abs_error.mean()
display(mae)

In [None]:
import matplotlib.pyplot as plt

# Heat map of the pairwise distances computed on the clear histograms.
plt.figure(dpi=120)
clear_heatmap = plt.imshow(d_clear, cmap='viridis')
plt.colorbar(clear_heatmap)

In [None]:
# Heat map of the pairwise distances estimated from the Bloom filters.
plt.figure(dpi=120)
bloom_heatmap = plt.imshow(d_bloom, cmap='viridis')
plt.colorbar(bloom_heatmap)

Single database search

In [None]:
# Index into all_images of the image to search for; change this one value to
# query a different image.
query = 0
query_image = all_images[query]
print(query_image)
print(search_db_bloom(query_image, bins, m, k, factor, pin, db, 10))

Search for all images to check for false positives

In [None]:
# Number of database hits (distance <= 10) for every image in the dataset.
matches = [
    len(search_db_bloom(image, bins, m, k, factor, pin, db, 10))
    for image in tqdm(all_images)
]

In [None]:
# Indices of the queries that matched more than one database entry.
match_counts = np.array(matches)
found = np.argwhere(match_counts > 1)
print(found)