In [None]:
import os
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook
from skimage.util import montage
import cv2

from sdcdup.utils import get_project_root
from sdcdup.features.image_features import SDCImageContainer

# Load variables from a `.env` file into the process environment
# (RAW_DATA_DIR is read with os.getenv further down in this cell).
%load_ext dotenv
%dotenv
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Automatically reload edited modules before each cell execution.
%reload_ext autoreload
%autoreload 2

# RGB colour tuples; the matching hex code is given next to each.
RED = (244, 67, 54)         # F44336
GREEN = (76, 175, 80)       # 4CAF50
LIGHT_BLUE = (3, 169, 244)  # 03A9F4

# Font-size scale used for the matplotlib defaults below.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE, BIGGEST_SIZE = 10, 12, 16, 20

# One (rc group, options) entry per matplotlib default being overridden.
for _rc_group, _rc_options in (
    ('font', {'size': BIGGEST_SIZE}),         # default text size
    ('axes', {'titlesize': BIGGEST_SIZE}),    # axes title
    ('axes', {'labelsize': BIGGEST_SIZE}),    # x and y axis labels
    ('xtick', {'labelsize': BIGGER_SIZE}),    # x tick labels
    ('ytick', {'labelsize': BIGGER_SIZE}),    # y tick labels
    ('legend', {'fontsize': MEDIUM_SIZE}),    # legend text
    ('figure', {'titlesize': BIGGEST_SIZE}),  # figure title
):
    plt.rc(_rc_group, **_rc_options)

def montage_rgb(x):
    """Tile a batch of multi-channel images into a single montage image.

    ``montage`` is applied to each channel slice ``x[:, :, :, i]``
    independently and the per-channel results are stacked along a new
    last axis.

    Args:
        x: 4-D array indexed as ``x[image, row, col, channel]``
            (this notebook passes ``(batch_limit, 768, 768, 3)`` float32).

    Returns:
        Array whose last axis has ``x.shape[3]`` entries, one montage
        per channel.
    """
    return np.stack([montage(x[:, :, :, i]) for i in range(x.shape[3])], -1)

project_root = get_project_root()

# RAW_DATA_DIR is expected in the environment (e.g. loaded from `.env`).
# Fail fast with an explicit message instead of the opaque TypeError that
# os.path.join raises when handed None.
raw_data_dir = os.getenv('RAW_DATA_DIR')
if raw_data_dir is None:
    raise RuntimeError(
        "Environment variable RAW_DATA_DIR is not set; "
        "define it in the environment or in the project's .env file."
    )

# Directory holding the training images read by the cells below.
train_image_dir = os.path.join(project_root, raw_data_dir, 'train_768')

In [None]:
# Build the project's image container and run its preprocessing step.
# NOTE(review): later cells read `sdcic.tile_md5hash_grids` and
# `sdcic.tile_bm0hash_grids`; presumably this call populates them --
# confirm in sdcdup.features.image_features.
sdcic = SDCImageContainer()
sdcic.preprocess_image_properties()

## Find overlapping images using hashlib
Update: The pixel values of two supposedly identical 256x256 crops are not always exactly equal (see below), so an exact hash such as MD5 can miss some true overlaps.

In [None]:

# Map each tile MD5 hash to the ids of the images in which it occurs.
# An image id is appended once per matching tile, so it can appear more
# than once in a group.
md5hash_dict = defaultdict(list)

# os.listdir returns entries in arbitrary order. Sorting keeps the dict's
# insertion order -- and therefore any later position-based sampling of
# the hash groups -- reproducible across runs and machines.
img_ids = sorted(os.listdir(train_image_dir))

for img_id in tqdm_notebook(img_ids):
    for h in sdcic.tile_md5hash_grids[img_id]:
        md5hash_dict[h].append(img_id)

# Histogram of group sizes: dup_counts_dict[k] is the number of hashes
# that occur exactly k times.
dup_counts_dict = defaultdict(int)
for dups in md5hash_dict.values():
    dup_counts_dict[len(dups)] += 1

sorted_bin_sizes = sorted(dup_counts_dict.items())
print('n images with k duplicates')
print('(k, n)')
sorted_bin_sizes

In [None]:
batch_size = 9    # only hash groups with exactly this many entries are considered
skip = 365        # number of matching groups to pass over before stopping
batch_limit = 9   # number of montage slots
ii = 0            # hash groups scanned so far
jj = 0            # matching groups passed over so far
dups0 = None      # reset so a stale selection from an earlier run is never displayed

# Zero-filled (not np.empty): after de-duplication a group may hold fewer
# than batch_limit distinct images, and unfilled slots must render as black
# rather than as uninitialised memory.
samples_images = np.zeros((batch_limit, 768, 768, 3), dtype=np.float32)

for hash_id, dups in md5hash_dict.items():
    ii += 1
    if len(dups) != batch_size:
        continue
    # Sorted so the chosen images and their montage order do not depend on
    # set iteration order, which varies between kernel sessions for strings.
    dups0 = sorted(set(dups))
    img_id = dups0[0]
    # Position of this hash within the first image's hash grid.
    idx = sdcic.tile_md5hash_grids[img_id].tolist().index(hash_id)
    print(hash_id, len(dups), ii)
    # Stop at the `skip`-th matching group; if fewer exist, the loop runs
    # out and the last matching group is kept.
    if jj == min(dup_counts_dict[len(dups)], skip):
        break
    jj += 1

if dups0 is None:
    raise ValueError(
        f'No MD5 hash occurs exactly {batch_size} times; choose a different batch_size.'
    )

for i, c_img_id in enumerate(dups0[:batch_limit]):
    # Convert BGR -> RGB for matplotlib and scale to [0, 1].
    c_img = cv2.cvtColor(sdcic.get_img(c_img_id), cv2.COLOR_BGR2RGB)
    samples_images[i] = c_img.astype(np.float32) / 255.0

batch_rgb = montage_rgb(samples_images)
print(samples_images.shape)
print(batch_rgb.shape, batch_rgb.dtype)

fig, ax = plt.subplots(1, 1, figsize=(16, 16))
ax.imshow(batch_rgb, vmin=0, vmax=1)
ax.axis('off')
plt.show()

## Find overlapping images using cv2.img_hash

In [None]:
# Map each tile hash from `sdcic.tile_bm0hash_grids` to the ids of the
# images in which it occurs. An image id is appended once per matching
# tile, so it can appear more than once in a group.
bm0hash_dict = defaultdict(list)

# os.listdir returns entries in arbitrary order. Sorting keeps the dict's
# insertion order -- and therefore any later position-based sampling of
# the hash groups -- reproducible across runs and machines.
img_ids = sorted(os.listdir(train_image_dir))

for img_id in tqdm_notebook(img_ids):
    for h in sdcic.tile_bm0hash_grids[img_id]:
        # Each grid row is converted to a tuple so it is hashable and can
        # be used as a dict key.
        bm0hash_dict[tuple(h)].append(img_id)

# Histogram of group sizes: dup_counts_dict[k] is the number of hashes
# that occur exactly k times.
dup_counts_dict = defaultdict(int)
for dups in bm0hash_dict.values():
    dup_counts_dict[len(dups)] += 1

sorted_bin_sizes = sorted(dup_counts_dict.items())
print('n images with k duplicates')
print('(k, n)')
sorted_bin_sizes

In [None]:
batch_size = 18   # only hash groups with exactly this many entries are considered
skip = 5          # number of matching groups to pass over before stopping
batch_limit = 9   # number of montage slots
ii = 0            # hash groups scanned so far
jj = 0            # matching groups passed over so far
dups0 = None      # reset so a stale selection from an earlier cell is never displayed

# Zero-filled (not np.empty): after de-duplication a group may hold fewer
# than batch_limit distinct images, and unfilled slots must render as black
# rather than as uninitialised memory.
samples_images = np.zeros((batch_limit, 768, 768, 3), dtype=np.float32)

for hash_id, dups in bm0hash_dict.items():
    ii += 1
    if len(dups) != batch_size:
        continue
    # Sorted so the chosen images and their montage order do not depend on
    # set iteration order, which varies between kernel sessions for strings.
    dups0 = sorted(set(dups))
    img_id = dups0[0]
    # Row indices of the first image's hash grid that equal this hash.
    idx = np.where(np.all(sdcic.tile_bm0hash_grids[img_id] == np.asarray(hash_id), axis=1))[0]
    print(hash_id, len(dups), ii)
    # Stop at the `skip`-th matching group; if fewer exist, the loop runs
    # out and the last matching group is kept.
    if jj == min(dup_counts_dict[len(dups)], skip):
        break
    jj += 1

if dups0 is None:
    raise ValueError(
        f'No image hash occurs exactly {batch_size} times; choose a different batch_size.'
    )

for i, c_img_id in enumerate(dups0[:batch_limit]):
    # Convert BGR -> RGB for matplotlib and scale to [0, 1].
    c_img = cv2.cvtColor(sdcic.get_img(c_img_id), cv2.COLOR_BGR2RGB)
    samples_images[i] = c_img.astype(np.float32) / 255.0

batch_rgb = montage_rgb(samples_images)
print(samples_images.shape)
print(batch_rgb.shape, batch_rgb.dtype)

fig, ax = plt.subplots(1, 1, figsize=(16, 16))
ax.imshow(batch_rgb, vmin=0, vmax=1)
ax.axis('off')
plt.show()