In [None]:
import os
import hashlib
from collections import Counter

import numpy as np
import matplotlib.pyplot as plt
from skimage.util import montage
import cv2
from cv2 import img_hash

from sdcdup.utils import tilebox_corners
from sdcdup.utils import overlap_tag_pairs
from sdcdup.utils import generate_overlap_tag_slices
from sdcdup.utils import boundingbox_corners
from sdcdup.utils import channel_shift
from sdcdup.utils import load_duplicate_truth

from test_friend_circles import SDCImageContainer
from test_friend_circles import load_image_overlap_properties

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load (or reload, if already loaded) IPython's autoreload extension.
%reload_ext autoreload
# Mode 2: reload all modules (except those excluded via %aimport) before running code,
# so edits to sdcdup / test_friend_circles are picked up without restarting the kernel.
%autoreload 2

# Bounding-box colours as RGB tuples (hex equivalents in the comments).
RED = (244, 67, 54)  # F44336
GREEN = (76, 175, 80)  # 4CAF50
LIGHT_BLUE = (3, 169, 244)  # 03A9F4

SMALL_SIZE = 10
MEDIUM_SIZE = 12
BIGGER_SIZE = 16
BIGGEST_SIZE = 20
plt.rc('font', size=BIGGEST_SIZE)         # controls default text sizes
plt.rc('axes', titlesize=BIGGEST_SIZE)    # fontsize of the axes title
plt.rc('axes', labelsize=BIGGEST_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=BIGGER_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=BIGGER_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=MEDIUM_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGEST_SIZE)  # fontsize of the figure title


def montage_rgb(x):
    """Montage a stack of multi-channel images, one channel at a time.

    ``x`` is indexed as ``x[:, :, :, channel]``; each channel slice is tiled with
    ``skimage.util.montage`` and the per-channel montages are stacked on the last axis.
    """
    return np.stack([montage(x[:, :, :, i]) for i in range(x.shape[3])], -1)


def montage_pad(x, *args, padding_width=10, **kwargs):
    """``skimage.util.montage`` with a default padding of 10 pixels between tiles.

    Extra positional/keyword arguments are forwarded to ``montage``. Unlike a fixed
    ``padding_width=10`` keyword, callers may override ``padding_width`` here.
    """
    return montage(x, *args, padding_width=padding_width, **kwargs)

train_image_dir = 'data/raw/train_768/'

# Interim pickle files handed to the SDCImageContainer preprocessing calls.
image_md5hash_grids_file = 'data/interim/image_md5hash_grids.pkl'
image_bm0hash_grids_file = 'data/interim/image_bm0hash_grids.pkl'
image_cm0hash_grids_file = 'data/interim/image_cm0hash_grids.pkl'
image_greycop_grids_file = 'data/interim/image_greycop_grids.pkl'
image_entropy_grids_file = 'data/interim/image_entropy_grids.pkl'
image_issolid_grids_file = 'data/interim/image_issolid_grids.pkl'
image_shipcnt_grids_file = 'data/interim/image_shipcnt_grids.pkl'

overlap_tag_slices = generate_overlap_tag_slices()

# Colour names accepted by the `bbox_color` argument of the pair-drawing functions.
bbox_color_map = dict(red=RED, blue=LIGHT_BLUE, green=GREEN)

# Axis ticks every `dtick` pixels from 0 through 768 inclusive.
dtick = 256
n_ticks = 768 // dtick + 1
ticks = list(range(0, n_ticks * dtick, dtick))

In [None]:
class ImgMod:
    """
    Reads a single image to be modified by hls.

    The BGR image is read lazily with ``cv2.imread`` on first access. The HLS and RGB
    views, and the channel-shifted variants, are computed on demand and cached.
    Calling :meth:`channel_shift` invalidates every cached shifted view.
    """

    def __init__(self, filename):
        self.filename = filename
        self.img_id = filename.split('/')[-1]

        # Current HLS shift; a gain of None means "no shift".
        self._hls_chan = None
        self._hls_gain = None

        # Lazily-populated caches.
        self._parent_bgr = None
        self._parent_hls = None
        self._parent_rgb = None
        self._cv2_hls = None
        self._cv2_bgr = None
        self._cv2_rgb = None

    def channel_shift(self, chan, gain):
        """Set the HLS channel shift and return the resulting RGB image.

        ``chan`` and ``gain`` are forwarded to ``sdcdup.utils.channel_shift``.
        Every shifted view is reset. Resetting only the HLS cache would let
        ``cv2_bgr``/``cv2_rgb`` keep returning the result of an earlier shift.
        """
        self._hls_chan = chan
        self._hls_gain = gain
        self._cv2_hls = None
        self._cv2_bgr = None
        self._cv2_rgb = None
        return self.cv2_rgb

    def scale(self, minval, maxval):
        """Linearly map ``[minval, maxval]`` of the BGR image onto ``[0, 255]`` as uint8.

        Values outside ``[minval, maxval]`` are clipped instead of wrapping around
        in the uint8 cast.

        Raises
        ------
        ValueError
            If ``maxval == minval`` (the mapping would divide by zero).
        """
        if maxval == minval:
            raise ValueError("scale() requires maxval != minval")
        m = 255.0 / (maxval - minval)
        res = m * (self.parent_bgr - minval)
        return np.around(np.clip(res, 0, 255)).astype(np.uint8)

    @property
    def shape(self):
        """Shape of the original BGR image."""
        return self.parent_bgr.shape

    @property
    def parent_bgr(self):
        """Original image as read by ``cv2.imread`` (BGR), cached after the first read."""
        if self._parent_bgr is None:
            self._parent_bgr = cv2.imread(self.filename)
        return self._parent_bgr

    @property
    def parent_hls(self):
        """Original image converted to HLS (full hue range), cached."""
        if self._parent_hls is None:
            self._parent_hls = self.to_hls(self.parent_bgr)
        return self._parent_hls

    @property
    def parent_rgb(self):
        """Original image converted to RGB, cached."""
        if self._parent_rgb is None:
            self._parent_rgb = self.to_rgb(self.parent_bgr)
        return self._parent_rgb

    @property
    def cv2_hls(self):
        """HLS image after the current channel shift (the unshifted HLS if no gain is set)."""
        if self._cv2_hls is None:
            if self._hls_gain is None:
                self._cv2_hls = self.parent_hls
            else:
                # Module-level sdcdup.utils.channel_shift, not the method of the same name.
                self._cv2_hls = channel_shift(self.parent_hls, self._hls_chan, self._hls_gain)
        return self._cv2_hls

    @property
    def cv2_bgr(self):
        """Shifted image converted back to BGR, cached."""
        if self._cv2_bgr is None:
            self._cv2_bgr = self.to_bgr(self.cv2_hls)
        return self._cv2_bgr

    @property
    def cv2_rgb(self):
        """Shifted image converted to RGB, cached."""
        if self._cv2_rgb is None:
            self._cv2_rgb = self.to_rgb(self.cv2_bgr)
        return self._cv2_rgb

    def to_hls(self, bgr):
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2HLS_FULL)

    def to_bgr(self, hls):
        return cv2.cvtColor(hls, cv2.COLOR_HLS2BGR_FULL)

    def to_rgb(self, bgr):
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

In [None]:
def plot_image_grid(img_ids, ncols, nrows):
    """Plot training images in an ``nrows`` x ``ncols`` grid, each titled with its id.

    Images are read from ``train_image_dir`` with OpenCV and converted BGR -> RGB.
    ``squeeze=False`` keeps the axes array 2-D, so single-row or single-column grids
    work too. Grid cells beyond ``len(img_ids)`` are left empty.

    Raises
    ------
    ValueError
        If ``img_ids`` has more entries than the grid has cells.
    FileNotFoundError
        If an image cannot be read.
    """
    img_ids = list(img_ids)
    if len(img_ids) > ncols * nrows:
        raise ValueError(f"{len(img_ids)} images do not fit in a {nrows}x{ncols} grid")

    _, axes = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4), squeeze=False)
    # axes.flat walks the grid row-major, matching (i // ncols, i % ncols).
    for ax, img_id in zip(axes.flat, img_ids):
        path = os.path.join(train_image_dir, img_id)
        bgr = cv2.imread(path)
        if bgr is None:  # cv2.imread returns None instead of raising on unreadable files
            raise FileNotFoundError(f"could not read image {path}")
        ax.imshow(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        ax.set_title(img_id)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
    plt.tight_layout()

In [None]:
sdcic = SDCImageContainer()

# Per-image grid property files, in the order preprocess_image_properties expects them.
image_property_grid_files = (
    image_md5hash_grids_file,
    image_bm0hash_grids_file,
    image_cm0hash_grids_file,
    image_greycop_grids_file,
    image_entropy_grids_file,
    image_issolid_grids_file,
)
sdcic.preprocess_image_properties(*image_property_grid_files)
sdcic.preprocess_label_properties(image_shipcnt_grids_file)

In [None]:
n_matching_tiles_list = [9, 6, 4, 3, 2, 1]
overlap_image_maps = load_image_overlap_properties(n_matching_tiles_list)
print(len(overlap_image_maps))

In [None]:
dup_truth = load_duplicate_truth()
print(len(dup_truth))

## Count exact-duplicate tiles (identical MD5 hashes) within and across images.

In [None]:
# Exact-duplicate tiles, keyed by tile hash.
#   Pass 1: a hash occurring more than once inside one image marks a duplicate tile;
#           the first image where that happens is kept as a plotting sample.
#   Pass 2: add every other image in which one of those hashes occurs exactly once.
dup_tiles = []
dup_tile_hashes = {}         # tile_hash -> {img_id: [tile index, ...]}
dup_tile_counts = Counter()  # tile_hash -> number of recorded (image, tile index) locations
dup_image_samples = []       # (img_id, tile_hash, tile_indexes), one per duplicate hash


def _record_dup_tile_locations(img_id, tile_hash, tile_indexes):
    """Record each not-yet-seen tile index of ``tile_hash`` in ``img_id`` and count it."""
    for idx in tile_indexes:
        locations = dup_tile_hashes[tile_hash].setdefault(img_id, [])
        if idx not in locations:
            locations.append(idx)
            dup_tile_counts[tile_hash] += 1


for img_id, tile_hashes in sdcic.tile_md5hash_grids.items():
    for tile_hash, n_hits in Counter(tile_hashes).items():
        if n_hits < 2:
            continue
        tile_indexes = np.where(tile_hashes == tile_hash)[0]
        if tile_hash not in dup_tile_hashes:
            dup_image_samples.append((img_id, tile_hash, tile_indexes))
            dup_tile_hashes[tile_hash] = {}
        _record_dup_tile_locations(img_id, tile_hash, tile_indexes)

dup_hashes = tuple(dup_tile_hashes)

for img_id, tile_hashes in sdcic.tile_md5hash_grids.items():
    for tile_hash, n_hits in Counter(tile_hashes).items():
        if n_hits == 1 and tile_hash in dup_hashes:
            _record_dup_tile_locations(img_id, tile_hash, np.where(tile_hashes == tile_hash)[0])

dup_image_counts = {tile_hash: len(locations) for tile_hash, locations in dup_tile_hashes.items()}

dup_tile_counts

In [None]:
bbox_thickness = 5
ncols = 3
# Size the grid from the number of samples instead of asserting a fixed ncols * nrows.
nrows = max(1, (len(dup_image_samples) + ncols - 1) // ncols)

# squeeze=False keeps `ax` 2-D even when there is a single row.
fig, ax = plt.subplots(nrows, ncols, figsize=(16, 5 * nrows), squeeze=False)
for i, (img_id, tile_hash, tile_indexes) in enumerate(dup_image_samples):
    img = cv2.cvtColor(sdcic.get_img(img_id), cv2.COLOR_BGR2RGB)

    # Print the corner pixels of the first duplicated tile (before any boxes are drawn).
    if len(tile_indexes) > 0:
        idx = tile_indexes[0]
        tile = sdcic.get_tile(img, idx)
        print(img_id, tile_hash, idx)
        print(tile[0, 0], tile[0, -1])
        print(tile[-1, 0], tile[-1, -1])

    # Outline every copy of the duplicated tile in this image.
    for idx in tile_indexes:
        bbox_pt1, bbox_pt2 = tilebox_corners[idx]
        cv2.rectangle(img, tuple(bbox_pt1), tuple(bbox_pt2), bbox_color_map['red'], bbox_thickness)

    panel = ax[i // ncols, i % ncols]
    panel.imshow(img)
    panel.set_title(img_id)
    panel.set_xticks(ticks)
    panel.set_yticks(ticks)

# Hide any grid cells left over after the last sample.
for panel in ax.flat[len(dup_image_samples):]:
    panel.axis('off')

plt.tight_layout()

In [None]:
# Solid-colour 256x256 tiles in OpenCV BGR order (channel 0 = blue, channel 2 = red),
# used to show the block-mean hash and MD5 hash of trivially uniform tiles.
tile_shape = (256, 256, 3)
black_tile = np.zeros(tile_shape, dtype=np.uint8)
white_tile = np.full(tile_shape, 255, dtype=np.uint8)
blue_tile = black_tile.copy()
blue_tile[..., 0] = 255
red_tile = black_tile.copy()
red_tile[..., 2] = 255

color_tiles = [black_tile, white_tile, blue_tile, red_tile]
for swatch in color_tiles:
    print(img_hash.blockMeanHash(swatch, mode=0)[0])
    print(hashlib.md5(swatch.tobytes()).hexdigest())

In [None]:
# Images that are mostly black, laid out five per row to match the 5x5 grid.
img_ids = [
    '03ffa7680.jpg', '8d5521663.jpg', '5a70ef013.jpg', '9a2f9d347.jpg', '37a912dca.jpg',
    '4add7aa1d.jpg', '3db3ef7cc.jpg', '73fec0637.jpg', '7df214d98.jpg', 'c2955cd21.jpg',
    'de018b2a8.jpg', '8ce769141.jpg', 'fc0e22a0a.jpg', '770c46cd4.jpg', 'd6e432b79.jpg',
    'd5d1b6fb8.jpg', '0e4d7dd93.jpg', '9ddeed533.jpg', 'addc11de0.jpg', '65418dfe4.jpg',
    '119d6a3d6.jpg', '1b287c905.jpg', 'b264b0f96.jpg', '996f92939.jpg', 'e5c3b1f59.jpg',
]
plot_image_grid(img_ids, ncols=5, nrows=5)

In [None]:
def _pair_overlap_slices(img1_overlap_tag):
    """Return (slice1, slice2): the overlap in image 1 and the paired overlap in image 2."""
    slice1 = overlap_tag_slices[img1_overlap_tag]
    slice2 = overlap_tag_slices[overlap_tag_pairs[img1_overlap_tag]]
    return slice1, slice2


def _pair_median_drops(imgmod1, imgmod2, slice1, slice2):
    """Return (m12, img1_drop, img2_drop) for an image pair.

    ``m12`` is the per-channel median over both overlap regions, cast to uint8 with
    the first two axes kept as size 1. The drop images are the full RGB images minus
    ``m12`` in uint8 arithmetic, so pixels darker than the median wrap around.
    """
    overlaps = np.vstack([imgmod1.parent_rgb[slice1], imgmod2.parent_rgb[slice2]])
    m12 = np.median(overlaps, axis=(0, 1), keepdims=True).astype(np.uint8)
    img1_drop = imgmod1.parent_rgb - m12
    img2_drop = imgmod2.parent_rgb - m12
    return m12, img1_drop, img2_drop


def _pair_base_images(imgmod1, imgmod2, m12, shift_brightness):
    """Return the two RGB images to draw on.

    Without ``shift_brightness`` these are the ImgMod ``parent_rgb`` arrays themselves,
    so drawing on them mutates those (caller-local) caches. With it, the 'L' channel is
    shifted by -100 when the overlap median is bright (channel sum >= 384), else by +100.
    """
    if shift_brightness:
        brightness_level = -100 if np.sum(m12) >= 384 else 100
        return (imgmod1.channel_shift('L', brightness_level),
                imgmod2.channel_shift('L', brightness_level))
    return imgmod1.parent_rgb, imgmod2.parent_rgb


def _pair_bbox_color(img1_id, img2_id, img1_overlap_tag, bbox_color):
    """Resolve the rectangle colour.

    An explicit ``bbox_color`` name is looked up in ``bbox_color_map``. Otherwise the
    colour is GREEN/RED for pairs with a truthy/falsy ``dup_truth`` entry, and
    LIGHT_BLUE for pairs without one.
    """
    if bbox_color:
        return bbox_color_map[bbox_color]
    truth_key = (img1_id, img2_id, img1_overlap_tag)
    if truth_key in dup_truth:
        return GREEN if dup_truth[truth_key] else RED
    return LIGHT_BLUE


def _pair_bbox_corners(img1_overlap_tag, bbox_thickness):
    """Return the rectangle corners for both images, moved by (thickness // 2) + 1 pixels.

    The first corner is shifted by +offset and the second by -offset. This assumes
    ``boundingbox_corners`` rows are (top-left, bottom-right); if so, the rectangle
    ends up inside the overlap.
    """
    offset = (bbox_thickness // 2) + 1
    offset_array = np.array([[offset], [-offset]])
    bbox1 = boundingbox_corners[img1_overlap_tag] + offset_array
    bbox2 = boundingbox_corners[overlap_tag_pairs[img1_overlap_tag]] + offset_array
    return bbox1, bbox2


def _draw_pair_bboxes(img1, img2, bbox1, bbox2, bbox_color, bbox_thickness):
    """Draw the overlap rectangle on both images in place."""
    for img, (pt1, pt2) in ((img1, bbox1), (img2, bbox2)):
        cv2.rectangle(img, tuple(pt1), tuple(pt2), bbox_color, bbox_thickness)


def _show_pair_panel(ax, img, title=None):
    """Show ``img`` on ``ax`` with the notebook's shared ticks and an optional title."""
    ax.imshow(img)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    if title is not None:
        ax.set_title(title)


def _save_pair_figure(fig, img1_id, img2_id, img1_overlap_tag):
    """Save ``fig`` as temp/<img1>_<img2>_<tag>.jpg unless that file already exists."""
    filename = os.path.join('temp', f"{img1_id}_{img2_id}_{img1_overlap_tag}.jpg")
    if os.path.exists(filename):
        print(f"{filename} already exists.")
    else:
        # fig.savefig does not create missing directories.
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        fig.savefig(filename)
        print(f"{filename} saved.")


def draw_image_pair(
    img1_id, img2_id,
    img1_overlap_tag="08",
    shift_brightness=False,
    plot_shift_wrap=True,
    plot_scores=False,
    bbox_color=None,
    save=False):
    """Plot two images side by side with their overlap region outlined.

    Parameters
    ----------
    img1_id, img2_id : str
        File names inside ``train_image_dir``.
    img1_overlap_tag : str
        Key into ``overlap_tag_slices`` for image 1. Image 2 uses the paired tag
        from ``overlap_tag_pairs``.
    shift_brightness : bool
        Shift the 'L' channel of both images by +/-100, then restore the original
        pixels inside the overlap so that the overlap stands out.
    plot_shift_wrap : bool
        Add a second row in which each overlap is replaced by its median-subtracted
        (uint8 wrap-around) pixels.
    plot_scores : bool
        Title the second row with min/max of ``scores.cor`` and ``scores.enp`` and
        the max of ``scores.pix`` from ``overlap_image_maps``. Ignored when
        ``plot_shift_wrap`` is False.
    bbox_color : str or None
        A key of ``bbox_color_map``. When None, the colour comes from ``dup_truth``.
    save : bool
        Save the figure under ``temp/`` (skipped if the file already exists).
    """
    if not plot_shift_wrap:
        # Score titles live on the second row, which is only drawn with plot_shift_wrap.
        plot_scores = False

    if plot_scores:
        scores = overlap_image_maps[(img1_id, img2_id)][img1_overlap_tag]

    imgmod1 = ImgMod(os.path.join(train_image_dir, img1_id))
    imgmod2 = ImgMod(os.path.join(train_image_dir, img2_id))
    slice1, slice2 = _pair_overlap_slices(img1_overlap_tag)
    m12, img1_drop, img2_drop = _pair_median_drops(imgmod1, imgmod2, slice1, slice2)
    img1, img2 = _pair_base_images(imgmod1, imgmod2, m12, shift_brightness)
    bbox_color = _pair_bbox_color(img1_id, img2_id, img1_overlap_tag, bbox_color)
    bbox_thickness = 4
    bbox1, bbox2 = _pair_bbox_corners(img1_overlap_tag, bbox_thickness)

    # Keep the overlap at its original brightness (a no-op unless shift_brightness is set).
    img1[slice1] = imgmod1.parent_rgb[slice1]
    img2[slice2] = imgmod2.parent_rgb[slice2]
    _draw_pair_bboxes(img1, img2, bbox1, bbox2, bbox_color, bbox_thickness)

    if plot_shift_wrap:
        fig, ((ax00, ax01), (ax10, ax11)) = plt.subplots(2, 2, figsize=(16, 16))
    else:
        fig, (ax00, ax01) = plt.subplots(1, 2, figsize=(16, 8))

    _show_pair_panel(ax00, img1, f'{img1_id}')
    _show_pair_panel(ax01, img2, f'{img2_id}')

    if plot_shift_wrap:
        # Second row: each overlap replaced by its median-subtracted (wrapped) pixels.
        img1[slice1] = img1_drop[slice1]
        img2[slice2] = img2_drop[slice2]
        _draw_pair_bboxes(img1, img2, bbox1, bbox2, bbox_color, bbox_thickness)

        title10 = title11 = None
        if plot_scores:
            title10 = f'cor: {np.min(scores.cor):7.5f} {np.max(scores.cor):7.5f}'
            title11 = f'enp: {np.min(scores.enp):5.3f} {np.max(scores.enp):5.3f} {max(scores.pix)}'
        _show_pair_panel(ax10, img1, title10)
        _show_pair_panel(ax11, img2, title11)

    if save:
        _save_pair_figure(fig, img1_id, img2_id, img1_overlap_tag)


def draw_wrap_pair(
    img1_id, img2_id,
    img1_overlap_tag="08",
    shift_brightness=False,
    bbox_color=None,
    save=False):
    """Plot only the median-subtracted (wrap-around) view of an image pair.

    The arguments have the same meaning as in ``draw_image_pair``. Each image is shown
    with its overlap replaced by the overlap median-subtracted pixels and outlined.
    """
    imgmod1 = ImgMod(os.path.join(train_image_dir, img1_id))
    imgmod2 = ImgMod(os.path.join(train_image_dir, img2_id))
    slice1, slice2 = _pair_overlap_slices(img1_overlap_tag)
    m12, img1_drop, img2_drop = _pair_median_drops(imgmod1, imgmod2, slice1, slice2)
    img1, img2 = _pair_base_images(imgmod1, imgmod2, m12, shift_brightness)
    bbox_color = _pair_bbox_color(img1_id, img2_id, img1_overlap_tag, bbox_color)
    bbox_thickness = 4
    bbox1, bbox2 = _pair_bbox_corners(img1_overlap_tag, bbox_thickness)

    fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(16, 8))

    img1[slice1] = img1_drop[slice1]
    img2[slice2] = img2_drop[slice2]
    _draw_pair_bboxes(img1, img2, bbox1, bbox2, bbox_color, bbox_thickness)

    _show_pair_panel(ax_left, img1, f'{img1_id}')
    _show_pair_panel(ax_right, img2, f'{img2_id}')

    if save:
        # NOTE(review): this uses the same file name as draw_image_pair, so whichever
        # of the two saves second reports "already exists" and does not write.
        _save_pair_figure(fig, img1_id, img2_id, img1_overlap_tag)


In [None]:
draw_image_pair('46b87e21c.jpg', 'f881c203f.jpg', '04', bbox_color='red')

In [None]:
draw_image_pair('46b87e21c.jpg', 'f881c203f.jpg', '07')

In [None]:
draw_image_pair('9b34f2f64.jpg', 'e8b058856.jpg', '05', bbox_color='red')

In [None]:
draw_image_pair('9b34f2f64.jpg', 'e8b058856.jpg', '18')

In [None]:
draw_image_pair('356f4c539.jpg', '6dd7430f6.jpg', '02')

In [None]:
draw_image_pair('b8ce38df4.jpg', 'ddfc36407.jpg')

In [None]:
draw_image_pair('8a0542232.jpg', 'ddfc36407.jpg')

In [None]:
draw_image_pair('8a0542232.jpg', 'ddfc36407.jpg', '68')

In [None]:
draw_image_pair('0ef6cd331.jpg', '2095da0cb.jpg')

In [None]:
draw_image_pair('0efcd3f26.jpg', '89a2baf91.jpg', plot_shift_wrap=False)

In [None]:
draw_image_pair('2c09a2423.jpg', 'b4eba96e8.jpg', plot_shift_wrap=False)

In [None]:
draw_image_pair('2556bfc6c.jpg', '2c09a2423.jpg', '28', plot_scores=True)

In [None]:
draw_image_pair('2556bfc6c.jpg', 'd3474ec95.jpg', plot_shift_wrap=False)

In [None]:
draw_image_pair('2b6c7fd55.jpg', 'c5d9bc753.jpg', '15')

In [None]:
draw_image_pair('536356d11.jpg', '783d9495a.jpg', '15', plot_scores=True)

In [None]:
draw_image_pair('536356d11.jpg', '88c2acaf8.jpg', '15', plot_scores=True)

In [None]:
draw_image_pair('861367193.jpg', 'e6e729afa.jpg', '03')

In [None]:
draw_image_pair('385df9573.jpg', '813a4728e.jpg', '58')

In [None]:
draw_image_pair('03a5fd8d2.jpg', '676f4cfd0.jpg', plot_scores=True)

In [None]:
# Image to inspect and the ceiling on its overlaps' max pixel score.
# Other (image, threshold) choices that were tried:
#   ('d049cb0be.jpg', 5000), ('c5d9bc753.jpg', 10000), ('2556bfc6c.jpg', 5000000)
img_id, pix_thresh = 'd3474ec95.jpg', 5000000

pset = set()         # every image id appearing in a kept overlay
overlay_set = set()  # (img1_id, img2_id, img1_overlap_tag) with max pixel score <= pix_thresh
for (img1_id, img2_id), overlap_map in overlap_image_maps.items():
    if img_id not in (img1_id, img2_id):
        continue
    # Each pair involving this image is expected to carry exactly one overlap tag.
    assert len(overlap_map) == 1
    img1_overlap_tag, scores = next(iter(overlap_map.items()))
    if np.max(scores.pix) > pix_thresh:
        continue
    print(img1_id, img2_id, img1_overlap_tag)
    overlay_set.add((img1_id, img2_id, img1_overlap_tag))
    pset.update((img1_id, img2_id))

# To look at the images involved: plot_image_grid(sorted(pset), 4, 4)

In [None]:
# Print the image-id pair of every kept overlay whose pixel scores all stay below 5000.
for overlay in sorted(overlay_set):
    pixel_scores = sdcic.gen_pixel_scores(*overlay)
    if np.max(pixel_scores) < 5000:
        print((overlay[0], overlay[1]))

In [None]:
# Two groups of nine images, three per row to mirror the 3x3 plots. The i-th image of
# group 1 is paired with the i-th image of group 2 when image_pairs is built.
overlap_group1 = [
    '3e98c83f7.jpg', 'b356b1f4a.jpg', 'd42dcdc8c.jpg',
    '861367193.jpg', '536356d11.jpg', '9b82e7a76.jpg',
    '783d9495a.jpg', 'cc29cb437.jpg', '385df9573.jpg',
]

overlap_group2 = [
    'd0e99b467.jpg', 'f0d46bbd8.jpg', 'd049cb0be.jpg',
    '813a4728e.jpg', 'a4e6f04a8.jpg', 'e80ae5e73.jpg',
    '88c2acaf8.jpg', '30d3278a2.jpg', 'e6e729afa.jpg',
]


In [None]:
plot_image_grid(overlap_group1, 3, 3)

In [None]:
plot_image_grid(overlap_group2, 3, 3)

In [None]:
image_pairs = [(g1, g2) for g1, g2 in zip(overlap_group1, overlap_group2)]

In [None]:
plot_image_grid(np.array(image_pairs).flatten(), 6, 3)

In [None]:
pixel_scores = []
for img1_id, img2_id in image_pairs:
    pixel_scores.append(sdcic.gen_pixel_scores(img1_id, img2_id, "08"))

print(np.asarray(pixel_scores))

In [None]:
print(np.asarray(pixel_scores).reshape((9, 3, 3)))

In [None]:
# Median-subtracted (wrap-around) view of every (group 1, group 2) image pair.
for pair in image_pairs:
    draw_wrap_pair(*pair)