In [None]:
import json
import os
from multiprocessing import Pool

import cairosvg
import imagehash
from PIL import Image
from tqdm.contrib.concurrent import process_map


def get_hash(icon, png_size=128, hash_size=16):
    """Attach a perceptual hash to ``icon``, rasterising its SVG first if needed.

    Args:
        icon: Dict with at least ``'path'`` (source SVG) and ``'png_path'``
            (where the rasterised PNG lives or should be written).
        png_size: Width and height, in pixels, of the PNG rendered from the SVG.
        hash_size: ``hash_size`` forwarded to ``imagehash.phash``.

    Returns:
        The same ``icon`` dict, mutated in place with a new ``'hash'`` entry.
    """
    # Render the PNG only when it is missing. An existing file is reused as-is,
    # even if it was produced with a different png_size.
    if not os.path.exists(icon['png_path']):
        cairosvg.svg2png(url=icon['path'],
                         write_to=icon['png_path'],
                         negate_colors=True,
                         output_width=png_size,
                         output_height=png_size)

    # Use a context manager so the PNG's file handle is released as soon as
    # the hash is computed instead of lingering until garbage collection.
    with Image.open(icon['png_path']) as image:
        icon['hash'] = imagehash.phash(image, hash_size=hash_size)
    return icon

def load_icons(base_path, sets=None):
    """Collect one record per SVG icon found under ``<base_path>/svg/<set>/``.

    Args:
        base_path: Directory containing the ``svg`` tree; PNG output
            directories are created under ``<base_path>/png``.
        sets: Optional iterable of set directory names to load. When ``None``,
            every sub-directory of ``<base_path>/svg`` is loaded.

    Returns:
        A list of dicts with keys ``'path'`` (the SVG file), ``'png_path'``
        (the matching PNG location) and ``'name'`` (file name without the
        ``.svg`` extension), ordered by set and then by file name.

    Side effects:
        Creates ``<base_path>/png/<set>`` for every set that is loaded.
    """
    svg_base = os.path.join(base_path, 'svg')
    png_base = os.path.join(base_path, 'png')

    if sets is None:
        # Only directories are icon sets; stray files next to them would make
        # the inner os.listdir() raise. Sorted for a reproducible order.
        set_ids = sorted(
            entry for entry in os.listdir(svg_base)
            if os.path.isdir(os.path.join(svg_base, entry)))
    else:
        # Explicitly requested sets are used verbatim; a missing one still
        # raises from os.listdir() below.
        set_ids = sets

    icons = []
    for set_id in set_ids:
        for icon_name in sorted(os.listdir(os.path.join(svg_base, set_id))):
            # Skip anything that is not an SVG (README, .DS_Store, ...).
            if not icon_name.endswith('.svg'):
                continue
            # Strip only the trailing extension; str.replace() would also
            # rewrite a '.svg' appearing in the middle of the name.
            stem = os.path.splitext(icon_name)[0]
            icons.append({
                "path": os.path.join(svg_base, set_id, icon_name),
                "png_path": os.path.join(png_base, set_id, stem + '.png'),
                "name": stem,
            })
        # ensure png dir exists
        os.makedirs(os.path.join(png_base, set_id), exist_ok=True)
    return icons

class JSONEncoder(json.JSONEncoder):
    """JSON encoder that serialises ``imagehash.ImageHash`` values via ``str()``."""

    def default(self, obj):
        """Return ``str(obj)`` for image hashes; defer everything else to the base class."""
        if not isinstance(obj, imagehash.ImageHash):
            # The base implementation raises TypeError for unsupported types.
            return super().default(obj)
        return str(obj)


if __name__ == '__main__':
    # Index every icon set under ./tmp/svg; pass e.g.
    # sets=['go', 'tb', 'ai', 'fa'] to load_icons to restrict the run.
    icons = load_icons('./tmp')

    # Hash the icons in parallel, handing them to the workers 100 at a time.
    results = process_map(get_hash, icons, chunksize=100)

    # Persist the records; the custom encoder turns ImageHash values into strings.
    with open('tmp/results.json', 'w') as out_file:
        json.dump(results, out_file, cls=JSONEncoder, ensure_ascii=False)

In [None]:
# Reload the hashed icon records from disk. The context manager closes the
# file handle, which the bare open() call left dangling.
with open('tmp/results.json') as results_file:
    results = json.load(results_file)

In [None]:
# Name of the icon whose look-alikes we want to find.
TARGET_NAME = 'GoAlert'

# next() already yields the matching element of `results` itself, so no extra
# index lookup is needed. `target` stays the same dict object that lives in
# `results`, so keys added to the records later are visible through it.
target = next((icon for icon in results if icon['name'] == TARGET_NAME), None)
if target is None:
    raise LookupError(f'no icon named {TARGET_NAME!r} in results')


In [None]:
# Rebuild hash objects from the stored hex strings so they can be compared;
# each record is updated in place with a 'restored_hash' entry.
for record in results:
    record.update(restored_hash=imagehash.hex_to_hash(record['hash']))

In [None]:
# Rank every icon by the difference between its hash and the target's hash,
# closest first. `results` itself is left in its original order.
res = sorted(
    results,
    key=lambda icon: icon['restored_hash'] - target['restored_hash'],
)


In [None]:
# Names of the ten closest matches (the first is normally the target itself).
[icon['name'] for icon in res[:10]]

In [None]:
# Group icon paths by their exact hash string.
hashes = {}
for entry in results:
    hashes.setdefault(entry['hash'], []).append(entry['path'])

# A cluster is any hash shared by more than one icon.
clusters = [paths for paths in hashes.values() if len(paths) > 1]

In [23]:
# display each cluster in a row
from IPython.display import display, HTML

def path_to_filename(path):
    """Return the file name of ``path`` without its directory or extension.

    Only the trailing extension is removed, so this works for both the SVG
    and the PNG path of an icon and never touches a ``.svg`` that appears in
    the middle of the name.
    """
    return os.path.splitext(os.path.basename(path))[0]

def _svg_to_png_path(path):
    """Map an icon's SVG path to its PNG twin; PNG paths pass through unchanged."""
    directory, filename = os.path.split(path)
    stem, ext = os.path.splitext(filename)
    if ext == '.svg':
        filename = stem + '.png'
    # Swap only the innermost directory component that is exactly 'svg'.
    # A blanket str.replace("svg", "png") would also mangle icon or set names
    # that merely contain the letters "svg".
    parts = directory.split(os.sep)
    for idx in range(len(parts) - 1, -1, -1):
        if parts[idx] == 'svg':
            parts[idx] = 'png'
            break
    return os.path.join(os.sep.join(parts), filename)


def display_cluster(paths, title):
    """Render the icons at ``paths`` as one labelled row of 64px thumbnails.

    Args:
        paths: Iterable of icon paths (SVG or PNG); each is shown through its
            PNG rendering and captioned with its file name.
        title: Heading displayed above the row.
    """
    def icon_tile(path):
        src = _svg_to_png_path(path)
        label = path_to_filename(path)
        return (
            '<div style="display:flex; flex-direction:column; '
            'align-items:center; justify-content:center;">'
            f'<img src="{src}" style="height: 64px; width: 64px" '
            f'title="{label}">{label}</div>'
        )

    tiles = ''.join(icon_tile(p) for p in paths)
    markup = (f'<h2>{title}</h2>'
              '<div style="display: flex; flex-direction: row; gap: 8px">'
              f'{tiles}</div>')
    display(HTML(markup))

# for i, cluster in enumerate(clusters):
#     display_cluster(cluster, title=f'Cluster {i} ({len(cluster)} icons)')

In [25]:
# Show the ten closest matches. Pass the SVG source paths, as the other
# display_cluster call in this notebook does: display_cluster derives the PNG
# location itself, and the captions then read as the icon name without '.png'.
display_cluster([icon['path'] for icon in res[:10]], title='Top 10')

In [None]:
# Browse the first 50 icons ordered by their stored hash value, ten per row.
sorted_results = sorted(results, key=lambda x: x['hash'])

for i in range(0, min(50, len(sorted_results)), 10):
    row = sorted_results[i:i + 10]
    # Count the icons in this row; the previous `len(cluster)` referred to a
    # name that is never defined and raised NameError on a fresh kernel.
    display_cluster([r['path'] for r in row], title=f'Cluster {i} ({len(row)} icons)')

In [None]:
def search_similar(hash, hashes, threshold=5):
    """Return the path lists of every hash within ``threshold`` of ``hash``.

    Args:
        hash: Query hash, either a hash object or its hex-string form.
        hashes: Mapping of hash (object or hex string) to a list of paths.
        threshold: Largest difference (as given by subtracting two hashes)
            that still counts as similar.

    Returns:
        A list with one entry per similar hash, each entry being that hash's
        list of paths. Entries whose hash equals the query are skipped.
    """
    def as_hash(value):
        # Hex strings, such as the keys loaded from results.json, cannot be
        # subtracted; turn them back into hash objects first.
        if isinstance(value, str):
            return imagehash.hex_to_hash(value)
        return value

    query = as_hash(hash)
    similar = []
    for candidate, paths in hashes.items():
        candidate_hash = as_hash(candidate)
        if query == candidate_hash:
            continue
        if query - candidate_hash <= threshold:
            similar.append(paths)
    return similar