In [None]:
import os
import time
import json
import pickle
import hashlib
from collections import defaultdict
from collections import Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook
from skimage.util import montage
import cv2
from cv2 import img_hash

from sdcdup.utils import get_project_root
from sdcdup.utils import overlap_tag_maps
from sdcdup.utils import overlap_tag_pairs
from sdcdup.utils import generate_pair_tag_lookup
from sdcdup.utils import get_hamming_distance
from sdcdup.utils import get_hamming_distance_array
from sdcdup.features import SDCImageContainer

# Load the dotenv IPython extension and apply it; presumably this supplies the
# RAW_DATA_DIR / INTERIM_DATA_DIR variables read with os.getenv further down.
%load_ext dotenv
%dotenv
# Render matplotlib figures in the cell output.
%matplotlib inline
# Reload edited modules automatically so changes to sdcdup are picked up
# without restarting the kernel.
%reload_ext autoreload
%autoreload 2

# Named font sizes (in points) shared by every figure in this notebook.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE, BIGGEST_SIZE = 10, 12, 16, 20

# Apply the sizes to matplotlib's rc settings, one rc group per call.
plt.rc('font', size=BIGGEST_SIZE)                               # default text size
plt.rc('axes', titlesize=BIGGEST_SIZE, labelsize=BIGGEST_SIZE)  # axes title and x/y labels
plt.rc('xtick', labelsize=BIGGER_SIZE)                          # x tick labels
plt.rc('ytick', labelsize=BIGGER_SIZE)                          # y tick labels
plt.rc('legend', fontsize=MEDIUM_SIZE)                          # legend entries
plt.rc('figure', titlesize=BIGGEST_SIZE)                        # figure title

def montage_rgb(x):
    """Build a mosaic from a batch of multi-channel images, one channel at a time.

    `x` is indexed as ``x[:, :, :, i]``, so it must be 4-D with the channel on
    the last axis (presumably (n_images, height, width, channels) -- TODO
    confirm against skimage.util.montage).  Each channel slice is handed to
    `montage` separately and the resulting mosaics are stacked on a new last
    axis.
    """
    n_channels = x.shape[3]
    channel_mosaics = [montage(x[:, :, :, channel]) for channel in range(n_channels)]
    return np.stack(channel_mosaics, -1)

# Resolve the data directories relative to the project root; the directory
# names are read from the RAW_DATA_DIR and INTERIM_DATA_DIR environment variables.
# NOTE(review): os.getenv returns None for an unset variable, which makes
# os.path.join raise TypeError -- confirm both variables are always defined.
project_root = get_project_root()
train_image_dir = os.path.join(project_root, os.getenv('RAW_DATA_DIR'), 'train_768')
interim_data_dir = os.path.join(project_root, os.getenv('INTERIM_DATA_DIR'))
# Lookup consulted by generate_matches with (tile1_index, tile2_index) keys to
# obtain an overlap tag.
pair_tag_lookup = generate_pair_tag_lookup()
# NOTE(review): get_ticks is not among the imports at the top of this cell --
# confirm where it should come from; as written this line raises NameError on
# a fresh kernel.
ticks = get_ticks()

# Per-metric hash made entirely of 255-valued uint8 entries (32 entries for
# 'bmh32', 96 for 'bmh96'), stored as tuples so they can be used as dict keys.
matches_white = {
    metric: tuple(255 * np.ones(n_entries, dtype='uint8'))
    for metric, n_entries in (('bmh32', 32), ('bmh96', 96))
}

In [None]:
# Hash metric used to look for matching tiles, and the minimum score a tile
# comparison must reach to count as a match.
matches_metric = 'bmh32'
matches_threshold = 0.9

# Container holding the per-image metrics read by the cells below.
sdcic = SDCImageContainer()
sdcic.load_image_metrics(['md5', 'bmh32', 'bmh96'])
# os.listdir returns entries in arbitrary order; sorting makes every dict built
# from img_ids (and the `skip`-based group selection further down) reproducible
# across kernels and machines.
img_ids = sorted(os.listdir(train_image_dir))

## Solid-color tiles all have the same blockMeanHash, but each has a different md5 hash.

In [None]:
# Four solid 256x256 tiles.  Channel 0 is set for "blue" and channel 2 for
# "red", i.e. the channels are treated in BGR order.
black_tile = np.zeros((256, 256, 3), dtype=np.uint8)
white_tile = 255 + black_tile
blue_tile = black_tile.copy()
blue_tile[..., 0] = 255
red_tile = black_tile.copy()
red_tile[..., 2] = 255
color_tiles = [black_tile, white_tile, blue_tile, red_tile]

# For every tile print: the md5 of its raw bytes, the blockMeanHash of the
# whole tile, and the blockMeanHash of each channel concatenated together.
for tile in color_tiles:
    print(hashlib.md5(tile.tobytes()).hexdigest())
    print(img_hash.blockMeanHash(tile, mode=0)[0])
    # ...for each color channel
    channel_hashes = [
        img_hash.blockMeanHash(tile[..., channel], mode=0)
        for channel in range(tile.shape[-1])
    ]
    print(np.hstack(channel_hashes)[0])

## Find overlapping images with hashlib.md5
Update: The pixel values of two supposedly identical 256x256 crops are not always exactly equal (see below).

In [None]:
# Invert the per-image md5 lists: md5 digest -> ids of the images that own a
# tile with that digest.  An image id is appended once per matching tile, so it
# can appear more than once in a list.
md5hash_dict = defaultdict(list)
for img_id in tqdm_notebook(img_ids):
    tile_digests = sdcic.img_metrics['md5'][img_id]
    for tile_md5 in tile_digests:
        md5hash_dict[tile_md5].append(img_id)

In [None]:
# Histogram of group sizes: k -> number of md5 digests whose list holds
# exactly k entries.
dup_counts_dict = defaultdict(int)
for dups in md5hash_dict.values():
    dup_counts_dict[len(dups)] += 1

# (k, n) pairs in ascending k, displayed as the cell output.
sorted_bin_sizes = [(k, dup_counts_dict[k]) for k in sorted(dup_counts_dict)]
print('n images with k duplicates')
print('(k, n)')
sorted_bin_sizes

In [None]:
# Pick one md5 digest whose list holds exactly `batch_size` entries and show
# the images that share it.  `skip` chooses which such group: the loop stops at
# group number skip + 1, or keeps the last one when there are not that many.
batch_size = 9
skip = 365
ii = 0  # position in md5hash_dict, printed for reference
jj = 0  # number of size-`batch_size` groups already passed over

# Reset the selection so a stale `dups0` from an earlier run is never plotted.
dups0 = []
for hash_id, dups in md5hash_dict.items():
    ii += 1
    if len(dups) == batch_size:
        # Sorted so the displayed images do not depend on set iteration order,
        # which changes from one kernel to the next for strings.
        dups0 = sorted(set(dups))
        img_id = dups0[0]
        # Tile position of this digest inside the first image of the group.
        idx = sdcic.img_metrics['md5'][img_id].tolist().index(hash_id)
        print(hash_id, len(dups), ii)
        if jj == min(dup_counts_dict[len(dups)], skip):
            break
        jj += 1

if not dups0:
    raise ValueError(f'no md5 digest is shared by exactly {batch_size} tiles')

batch_limit = 9
# Zero-filled rather than np.empty: an image can own several tiles with the
# same digest, so the group may contain fewer than `batch_limit` distinct
# images and the unused montage slots must not show uninitialised memory.
samples_images = np.zeros((batch_limit, 768, 768, 3), dtype=np.float32)
for i, c_img_id in enumerate(dups0[:batch_limit]):
    c_img = cv2.cvtColor(sdcic.get_img(c_img_id), cv2.COLOR_BGR2RGB)
    samples_images[i] = c_img.astype(np.float32) / 255.0

batch_rgb = montage_rgb(samples_images)

fig, ax = plt.subplots(1, 1, figsize = (16, 16))
ax.imshow(batch_rgb, vmin=0, vmax=1)
plt.axis('off')
plt.show()

## Find overlapping images with cv2.img_hash.blockMeanHash
(Using only exact first matches)

In [None]:
# TODO: Use filter for all overlaps here?
# img_ids = filter_duplicates(img_ids)

# Map every tile hash of the selected metric to the set of images that contain
# a tile with exactly that hash.  Each hash is converted to a tuple so it can
# serve as a dict key.
bm0hash_dict = defaultdict(set)
for img_id in tqdm_notebook(img_ids):
    for h in sdcic.img_metrics[matches_metric][img_id]:
        bm0hash_dict[tuple(h)].add(img_id)

# Discard the all-255 hash registered in matches_white (presumably blank tiles
# that would otherwise link unrelated images -- TODO confirm).  The default
# keeps the cell from raising KeyError when no tile has that hash.
bm0hash_dict.pop(matches_white[matches_metric], None)

# Keep only hashes shared by more than one image, largest groups first, with
# each group's image ids in sorted order.
sorted_hash_dict = {
    key: sorted(dups)
    for key, dups in sorted(bm0hash_dict.items(), key=lambda x: len(x[1]), reverse=True)
    if len(dups) > 1
}

In [None]:
# Histogram of group sizes: k -> number of block-mean hashes shared by exactly
# k images (bm0hash_dict values are sets of image ids).
dup_counts_dict = defaultdict(int)
for dups in bm0hash_dict.values():
    dup_counts_dict[len(dups)] += 1

# (k, n) pairs in ascending k, displayed as the cell output.
sorted_bin_sizes = [(k, dup_counts_dict[k]) for k in sorted(dup_counts_dict)]
print('n images with k duplicates')
print('(k, n)')
sorted_bin_sizes

In [None]:
def generate_matches(sorted_hash_dict, sdcic, matches_metric, matches_threshold):
    """Collect candidate overlaps between images that share a tile hash.

    For every hash in `sorted_hash_dict`, each listed image's tile hashes are
    scored against that hash with `get_hamming_distance_array`; the tiles whose
    score is at least `matches_threshold` are that image's matching tiles.
    Every ordered image pair (img1_id < img2_id) then contributes one candidate
    per overlap tag found for a (tile1, tile2) combination.

    Parameters
    ----------
    sorted_hash_dict : dict
        Maps a hash (a tuple) to the list of image ids that contain it.
    sdcic : object
        Must expose ``img_metrics[matches_metric][img_id]``, the tile hashes of
        one image.
    matches_metric : str
        Key of the hash metric inside ``sdcic.img_metrics``.
    matches_threshold : float
        Minimum score for a tile to count as matching the shared hash.

    Returns
    -------
    set of tuple
        ``(img1_id, img2_id, overlap_tag)`` candidates.

    Notes
    -----
    Reads the module-level `pair_tag_lookup`.  Tile pairs with no entry in that
    lookup are skipped instead of producing a candidate whose tag is None.
    """
    test_matches = set()
    for hash_id, img_list in tqdm_notebook(sorted_hash_dict.items()):
        target_hash = np.asarray(hash_id)[None, :]

        # Matching tile indices depend only on the image, so compute them once
        # per image rather than once per image pair.
        matching_tiles = {}
        for img_id in img_list:
            scores = get_hamming_distance_array(
                sdcic.img_metrics[matches_metric][img_id],
                target_hash,
                normalize=True,
                as_score=True,
            )
            matching_tiles[img_id] = [
                tile_idx for tile_idx, score in enumerate(scores) if score >= matches_threshold
            ]

        for img1_id in img_list:
            tiles1 = matching_tiles[img1_id]
            for img2_id in img_list:
                # Visit each unordered pair once, with the smaller id first.
                if img2_id <= img1_id:
                    continue
                tiles2 = matching_tiles[img2_id]
                for t1 in tiles1:
                    for t2 in tiles2:
                        overlap_tag = pair_tag_lookup.get((t1, t2))
                        if overlap_tag is None:
                            continue
                        test_matches.add((img1_id, img2_id, overlap_tag))

    return test_matches

def generate_matches2(test_matches, sdcic, matches_metric, matches_threshold):
    """Filter candidate matches by their overlap scores.

    Each candidate is unpacked into the scoring function stored at
    ``sdcic.overlap_scores_config[matches_metric]['func']``.  A candidate is
    dropped when the smallest returned score is below `matches_threshold`;
    all others are returned as a set of tuples.  Candidates are visited in
    sorted order.
    """
    confirmed = set()
    for candidate in tqdm_notebook(sorted(test_matches)):
        score_func = sdcic.overlap_scores_config[matches_metric]['func']
        lowest_score = min(score_func(*candidate))
        if not (lowest_score < matches_threshold):
            confirmed.add(tuple(candidate))

    return confirmed

In [None]:
# Stage 1: collect candidate (img1_id, img2_id, overlap_tag) triples from the
# images that share a tile hash, then report how many were found.
test_matches = generate_matches(sorted_hash_dict, sdcic, matches_metric, matches_threshold)
print(len(test_matches))

In [None]:
# Stage 2: keep only the candidates whose lowest overlap score is not below
# matches_threshold, then report how many survive.
matches = generate_matches2(test_matches, sdcic, matches_metric, matches_threshold)
print(len(matches))

In [None]:
# Write the confirmed matches to the interim data directory, one row per match
# in sorted order.  The file name encodes the metric and threshold used.
matches_file = f'matches_{matches_metric}_{matches_threshold}.csv'
full_matches_file = os.path.join(interim_data_dir, matches_file)
# No column names are given, so the frame keeps pandas' default integer labels.
df = pd.DataFrame(sorted(matches))
df.to_csv(full_matches_file, index=False)

In [None]:
# Pick one block-mean hash shared by exactly `batch_size` images and show the
# first `batch_limit` of them.  `skip` chooses which such group: the loop stops
# at group number skip + 1, or keeps the last one when there are not that many.
batch_size = 18
skip = 5
ii = 0  # position in bm0hash_dict, printed for reference
jj = 0  # number of size-`batch_size` groups already passed over

# Reset the selection: without this, a `dups0` left over from the md5 cell
# above would be plotted silently whenever no group of this size exists.
dups0 = []
for hash_id, dups in bm0hash_dict.items():
    ii += 1
    if len(dups) == batch_size:
        # Sorted so the displayed images do not depend on set iteration order,
        # which changes from one kernel to the next for strings.
        dups0 = sorted(set(dups))
        img_id = dups0[0]
        # Tile position(s) of this hash inside the first image of the group.
        idx = np.where(np.all(sdcic.img_metrics[matches_metric][img_id] == np.asarray(hash_id), axis=1))[0]
        print(hash_id, len(dups), ii)
        if jj == min(dup_counts_dict[len(dups)], skip):
            break
        jj += 1

if not dups0:
    raise ValueError(f'no {matches_metric} hash is shared by exactly {batch_size} images')

batch_limit = 9
# Zero-filled rather than np.empty so that any montage slot left unused (when
# batch_limit exceeds the group size) is black instead of uninitialised memory.
samples_images = np.zeros((batch_limit, 768, 768, 3), dtype=np.float32)
for i, c_img_id in enumerate(dups0[:batch_limit]):
    c_img = cv2.cvtColor(sdcic.get_img(c_img_id), cv2.COLOR_BGR2RGB)
    samples_images[i] = c_img.astype(np.float32) / 255.0

batch_rgb = montage_rgb(samples_images)

fig, ax = plt.subplots(1, 1, figsize = (16, 16))
ax.imshow(batch_rgb, vmin=0, vmax=1)
plt.axis('off')
plt.show()