In [None]:
import os
import operator
from collections import Counter
from collections import namedtuple
from collections import defaultdict

import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook
from skimage.util import montage
import cv2
import torch

from sdcdup.utils import overlap_tag_pairs
from sdcdup.utils import overlap_tag_maps
from sdcdup.utils import generate_overlap_tag_slices
from sdcdup.utils import generate_tag_pair_lookup
from sdcdup.utils import channel_shift
from sdcdup.utils import load_duplicate_truth
from sdcdup.utils import update_duplicate_truth
from sdcdup.utils import update_tile_cliques

from sdcdup.features.image_features import SDCImageContainer
from sdcdup.features.image_features import load_image_overlap_properties
from sdcdup.models.dupnet import load_checkpoint

%matplotlib inline
%reload_ext autoreload
%autoreload 2

# Machine epsilon for float32.
EPS = np.finfo(np.float32).eps

# Colour triples; the trailing comment on each line is the matching hex code.
RED = (244, 67, 54) #F44336 
GREEN = (76, 175, 80) #4CAF50 
LIGHT_BLUE = (3, 169, 244) #03A9F4

# Global matplotlib font sizes (SMALL_SIZE is defined but not used below).
SMALL_SIZE = 10
MEDIUM_SIZE = 12
BIGGER_SIZE = 16
BIGGEST_SIZE = 20
plt.rc('font', size=BIGGEST_SIZE)         # controls default text sizes
plt.rc('axes', titlesize=BIGGEST_SIZE)    # fontsize of the axes title
plt.rc('axes', labelsize=BIGGEST_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=BIGGER_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=BIGGER_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=MEDIUM_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGEST_SIZE)  # fontsize of the figure title

# montage_rgb: builds one montage per slice along the last axis of a 4-D array
# and stacks the results along a new last axis.
montage_rgb = lambda x: np.stack([montage(x[:, :, :, i]) for i in range(x.shape[3])], -1)
# montage_pad: montage with a fixed padding_width of 10; extra arguments are forwarded.
montage_pad = lambda x, *args, **kwargs: montage(x, padding_width=10, *args, **kwargs)
# All-zero float32 array of shape (768, 768, 1).
zeros_mask = np.zeros((256*3, 256*3, 1), dtype=np.float32)

# Input locations, relative to the current working directory.
train_image_dir = 'data/raw/train_768/'
image_md5hash_grids_file = 'data/interim/image_md5hash_grids.pkl'
image_bm0hash_grids_file = 'data/interim/image_bm0hash_grids.pkl'
image_cm0hash_grids_file = 'data/interim/image_cm0hash_grids.pkl'
image_greycop_grids_file = 'data/interim/image_greycop_grids.pkl'
image_entropy_grids_file = 'data/interim/image_entropy_grids.pkl'
image_issolid_grids_file = 'data/interim/image_issolid_grids.pkl'
image_shipcnt_grids_file = 'data/interim/image_shipcnt_grids.pkl'

# NOTE(review): presumably maps each overlap tag to array slices — confirm in sdcdup.utils.
overlap_tag_slices = generate_overlap_tag_slices()

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class ImgMod:
    """
    Lazily loads one image with OpenCV and exposes HLS-shifted views of it.

    The source image is read on first access and cached as BGR
    (``parent_bgr``), HLS (``parent_hls``) and RGB (``parent_rgb``).
    ``brightness_shift`` selects an HLS channel and a gain; the shifted image
    is then available as ``cv2_hls``, ``cv2_bgr`` and ``cv2_rgb``, each
    computed on demand and cached until the next ``brightness_shift`` call.
    """

    def __init__(self, filename):
        """
        :param filename: path of the image file. The part after the last '/'
            is stored as ``img_id``. Nothing is read from disk here.
        """
        self.filename = filename
        self.img_id = filename.split('/')[-1]

        self._hls_chan = None
        self._hls_gain = None

        self._parent_bgr = None
        self._parent_hls = None
        self._parent_rgb = None
        self._cv2_hls = None
        self._cv2_bgr = None
        self._cv2_rgb = None

    def brightness_shift(self, chan, gain):
        """
        Select the HLS channel and gain to apply and return the shifted RGB image.

        :param chan: channel argument forwarded to ``channel_shift``.
        :param gain: gain argument forwarded to ``channel_shift``; ``None``
            means "no shift" (the parent image is used unchanged).
        :return: the value of ``cv2_rgb`` for the new settings.
        """
        self._hls_chan = chan
        self._hls_gain = gain
        # Invalidate every derived view, not only HLS: bgr/rgb are cached too,
        # so leaving them in place would return the result of an earlier call.
        self._cv2_hls = None
        self._cv2_bgr = None
        self._cv2_rgb = None
        return self.cv2_rgb

    def scale(self, minval, maxval):
        """
        Linearly rescale ``parent_bgr`` so that ``minval`` maps to 0 and
        ``maxval`` maps to 255.

        :param minval: input value mapped to 0.
        :param maxval: input value mapped to 255.
        :return: uint8 array with the same shape as ``parent_bgr``; values
            outside [minval, maxval] are clipped to 0 / 255.
        :raises ValueError: if ``maxval == minval`` (the mapping is undefined).
        """
        if maxval == minval:
            raise ValueError('maxval must differ from minval')
        m = 255.0 / (maxval - minval)
        # Work in float so that subtracting minval from uint8 data cannot wrap.
        res = m * (self.parent_bgr.astype(np.float64) - minval)
        # Clip before the cast: out-of-range floats do not convert to uint8 reliably.
        return np.around(np.clip(res, 0, 255)).astype(np.uint8)

    @property
    def shape(self):
        """Shape of the parent BGR image."""
        return self.parent_bgr.shape

    @property
    def parent_bgr(self):
        """Parent image as read by ``cv2.imread`` (loaded on first access).

        ``cv2.imread`` returns ``None`` for an unreadable file; that case is
        not checked here, so a later conversion will fail instead.
        """
        if self._parent_bgr is None:
            self._parent_bgr = cv2.imread(self.filename)
        return self._parent_bgr

    @property
    def parent_hls(self):
        """Parent image converted to HLS (cached)."""
        if self._parent_hls is None:
            self._parent_hls = self.to_hls(self.parent_bgr)
        return self._parent_hls

    @property
    def parent_rgb(self):
        """Parent image converted to RGB (cached)."""
        if self._parent_rgb is None:
            self._parent_rgb = self.to_rgb(self.parent_bgr)
        return self._parent_rgb

    @property
    def cv2_hls(self):
        """HLS image after the selected channel shift (cached).

        Equals ``parent_hls`` when no gain has been selected.
        """
        if self._cv2_hls is None:
            if self._hls_gain is None:
                self._cv2_hls = self.parent_hls
            else:
                self._cv2_hls = channel_shift(self.parent_hls, self._hls_chan, self._hls_gain)
        return self._cv2_hls

    @property
    def cv2_bgr(self):
        """Shifted image converted back to BGR (cached)."""
        if self._cv2_bgr is None:
            self._cv2_bgr = self.to_bgr(self.cv2_hls)
        return self._cv2_bgr

    @property
    def cv2_rgb(self):
        """Shifted image converted to RGB (cached)."""
        if self._cv2_rgb is None:
            self._cv2_rgb = self.to_rgb(self.cv2_bgr)
        return self._cv2_rgb

    def to_hls(self, bgr):
        """Convert a BGR image to HLS using ``cv2.COLOR_BGR2HLS_FULL``."""
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2HLS_FULL)

    def to_bgr(self, hls):
        """Convert an HLS image to BGR using ``cv2.COLOR_HLS2BGR_FULL``."""
        return cv2.cvtColor(hls, cv2.COLOR_HLS2BGR_FULL)

    def to_rgb(self, bgr):
        """Convert a BGR image to RGB."""
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

In [None]:
# Create the image container and hand it the six per-image grid pickle paths
# (md5 hash, bm0 hash, cm0 hash, greycop, entropy, issolid) and the ship-count
# grid path. Later cells read sdcic.tile_md5hash_grids, tile_bm0hash_grids,
# tile_entropy_grids, tile_issolid_grids and tile_shipcnt_grids.
# NOTE(review): presumably the grids are loaded from (or written to) these
# files when missing — confirm in SDCImageContainer.
sdcic = SDCImageContainer()
sdcic.preprocess_image_properties(
    image_md5hash_grids_file,
    image_bm0hash_grids_file,
    image_cm0hash_grids_file,
    image_greycop_grids_file,
    image_entropy_grids_file,
    image_issolid_grids_file)
sdcic.preprocess_label_properties(
    image_shipcnt_grids_file)

In [None]:
# Labelled overlaps. Other cells use dup_truth as a mapping
# {(img1_id, img2_id, img1_overlap_tag): is_dup}.
dup_truth = load_duplicate_truth()
print(len(dup_truth))

In [None]:
# score_types = ['bmh', 'cmh', 'con', 'hom', 'eng', 'cor', 'epy', 'enp', 'pix', 'px0', 'shp']
# NOTE(review): presumably the numbers of matching tiles per overlap to load — confirm.
n_matching_tiles_list = [9, 6, 4, 3, 2, 1]
# Other cells iterate this as {(img1_id, img2_id): {img1_overlap_tag: scores}}.
overlap_image_maps = load_image_overlap_properties(n_matching_tiles_list)
print(len(overlap_image_maps))

In [None]:
from torch.utils import data

# Lookup keyed by overlap tag; each entry is iterated elsewhere in this
# notebook as a sequence of (idx1, idx2) tile-index pairs.
img_overlap_index_maps = generate_tag_pair_lookup()
# One record per tile pair that is fed to the model.
TilePairs = namedtuple('TilePairs', 'img1_id img2_id img1_overlap_tag overlap_idx idx1 idx2')

def get_img(img_id):
    """Read image `img_id` from `train_image_dir` with cv2.imread.

    The result of cv2.imread is returned unchecked (it is None when the file
    cannot be read).
    """
    return cv2.imread(os.path.join(train_image_dir, img_id))
    

class Dataset(data.Dataset):
    """Map-style PyTorch dataset over tile pairs.

    Each item is a float32 tensor of shape (6, sz, sz): the RGB channels of
    tile ``idx1`` of image 1 followed by the RGB channels of tile ``idx2`` of
    image 2, scaled to [0, 1]. Tiles are addressed on a 3x3 grid of
    ``sz`` x ``sz`` squares, numbered row by row.
    """

    def __init__(self, tile_pairs,
                 image_transform=None,
                 in_shape=(6, 256, 256),
                 out_shape=(1,)):
        """Store the tile pairs and the (currently unused) configuration.

        :param tile_pairs: indexable of records with ``img1_id``, ``img2_id``,
            ``idx1`` and ``idx2`` attributes.
        :param image_transform: stored only; not applied by ``__getitem__``.
        :param in_shape: stored only.
        :param out_shape: stored only.
        """
        self.sz = 256
        self.tile_pairs = tile_pairs
        self.image_transform = image_transform
        # (row, col) of every tile on the 3x3 grid, in row-major order.
        self.ij = tuple((row, col) for row in range(3) for col in range(3))

        self.in_shape = in_shape
        self.out_shape = out_shape

    def __len__(self):
        """Number of tile pairs."""
        return len(self.tile_pairs)

    def __getitem__(self, index):
        """Load both images of pair ``index`` and return the stacked tiles."""
        pair = self.tile_pairs[index]

        tiles = []
        for img_id, tile_idx in ((pair.img1_id, pair.idx1), (pair.img2_id, pair.idx2)):
            bgr_tile = self.get_tile(get_img(img_id), *self.ij[tile_idx])
            rgb_tile = cv2.cvtColor(bgr_tile, cv2.COLOR_BGR2RGB)
            tiles.append(rgb_tile.astype(np.float32) / 255.)

        # (sz, sz, 6) -> (6, sz, sz)
        stacked = np.dstack(tiles).transpose((2, 0, 1))
        return torch.from_numpy(stacked)

    def get_tile(self, img, i, j):
        """Return the tile at grid row ``i``, column ``j`` of ``img`` (all channels)."""
        rows = slice(i * self.sz, (i + 1) * self.sz)
        cols = slice(j * self.sz, (j + 1) * self.sz)
        return img[rows, cols, :]

In [None]:
def preprocess(x):
    """View a batch as (-1, 6, 256, 256) and move it to the global `device`."""
    batch = x.view(-1, 6, 256, 256)
    return batch.to(device)

class WrappedDataLoader:
    """Iterable wrapper that applies `func` to every batch produced by `dl`."""

    def __init__(self, dl, func):
        """
        :param dl: the wrapped iterable (must support ``len``).
        :param func: callable applied to each batch as it is yielded.
        """
        self.dl = dl
        self.func = func

    def __len__(self):
        """Length of the wrapped iterable."""
        return len(self.dl)

    def __iter__(self):
        """Lazily yield ``func(batch)`` for each batch of the wrapped iterable."""
        yield from map(self.func, self.dl)

In [None]:
# Flatten every (image pair, overlap tag) into one TilePairs record per
# overlapping tile pair, keeping the position of the pair inside its overlap.
# (A filter skipping overlaps already in dup_truth used to sit in the tag
# loop; it is disabled, so labelled overlaps are included as well.)
tile_pairs = []
for (img1_id, img2_id), overlap_maps in tqdm_notebook(overlap_image_maps.items()):
    for img1_overlap_tag in overlap_maps:
        index_pairs = img_overlap_index_maps[img1_overlap_tag]
        tile_pairs.extend(
            TilePairs(img1_id, img2_id, img1_overlap_tag, position, first_idx, second_idx)
            for position, (first_idx, second_idx) in enumerate(index_pairs))
print(len(tile_pairs))

In [None]:
# Batches of 256 tile pairs loaded by 20 worker processes; the wrapper
# reshapes each batch and moves it to `device` as it is yielded.
test_ds = Dataset(tile_pairs)
test_dl = data.DataLoader(test_ds, batch_size=256, num_workers=20)
test_dl = WrappedDataLoader(test_dl, preprocess)
# Number of batches.
print(len(test_dl))

In [None]:
# Load the trained checkpoint and move it to `device`.
# `device` already falls back to the CPU when CUDA is unavailable, so the
# model must not be moved with an unconditional model.cuda(): that call
# raises on CPU-only machines before the fallback can take effect.
model = load_checkpoint('models/dup_model.best.pth')
model.to(device)

In [None]:
# Score every batch without tracking gradients, then flatten all batch
# outputs into one 1-D array aligned with tile_pairs.
# NOTE(review): assumes the model returns one probability per tile pair in
# each batch — confirm against the model's output shape.
model.eval()
with torch.no_grad():
    yprobs0 = [model(xb) for xb in tqdm_notebook(test_dl)]
    yprobs = np.vstack([l.cpu() for l in yprobs0]).reshape(-1)
    print(len(yprobs0), yprobs.shape)

In [None]:
# Indices of predictions within 0.472 of 0.5, i.e. 0.028 < p < 0.972.
yprobs_c = np.where(np.abs(yprobs - 0.5) < 0.472)[0]
print(yprobs_c.shape)

In [None]:
# Regroup the flat predictions by (image pair, overlap tag) and record every
# overlap that contains at least one weak prediction (an index in yprobs_c).
# tile_pairs lists the tiles of one overlap consecutively, so the flag is
# evaluated when the last tile of the overlap is reached.

# Set membership is O(1); `ii in ndarray` would scan the whole array each time.
weak_pred_indices = set(yprobs_c.tolist())

is_weak_pred = False
weak_preds = []
overlap_cnn_tile_scores = {}
for ii, (tp, yprob) in enumerate(zip(tile_pairs, yprobs)):
    if ii in weak_pred_indices:
        is_weak_pred = True
    # Looked up on every iteration so it always belongs to the current tag.
    n_overlapping_tiles = len(img_overlap_index_maps[tp.img1_overlap_tag])
    pair_scores = overlap_cnn_tile_scores.setdefault((tp.img1_id, tp.img2_id), {})
    cnn_scores = pair_scores.setdefault(tp.img1_overlap_tag, [None] * n_overlapping_tiles)
    cnn_scores[tp.overlap_idx] = yprob
    if tp.overlap_idx == n_overlapping_tiles - 1 and is_weak_pred:
        weak_preds.append((tp.img1_id, tp.img2_id, tp.img1_overlap_tag))
        is_weak_pred = False

len(weak_preds)

## Find overlaps with ships

In [None]:
# Image pairs where both images have a non-zero ship-count grid and at least
# one of their overlaps has no entry in dup_truth yet.
untested_image_pairs_with_ship_masks = []
for (img1_id, img2_id), overlap_maps in tqdm_notebook(overlap_image_maps.items()):
    # TODO: Find out which remaining tile pairs have masks but aren't in dup_truth.

    mask1 = sdcic.tile_shipcnt_grids[img1_id]
    mask2 = sdcic.tile_shipcnt_grids[img2_id]

    has_mask1 = np.sum(mask1) > 0
    has_mask2 = np.sum(mask2) > 0
    if not has_mask1 or not has_mask2:
        continue

    has_untested_overlap = any(
        (img1_id, img2_id, overlap_tag) not in dup_truth
        for overlap_tag in overlap_maps)
    if has_untested_overlap:
        untested_image_pairs_with_ship_masks.append((img1_id, img2_id))

len(overlap_image_maps), len(untested_image_pairs_with_ship_masks)

In [None]:
# Unlabelled overlaps, between two images that both have a non-zero
# ship-count grid, where the overlapping region of either image has a
# non-zero ship count.
untested_overlaps_with_ship_masks = []
for (img1_id, img2_id), overlap_maps in tqdm_notebook(overlap_image_maps.items()):
    # TODO: Find out which remaining tile pairs have masks but aren't in dup_truth.

    mask1 = sdcic.tile_shipcnt_grids[img1_id]
    mask2 = sdcic.tile_shipcnt_grids[img2_id]

    has_mask1 = np.sum(mask1) > 0
    has_mask2 = np.sum(mask2) > 0
    if not has_mask1 or not has_mask2:
        continue

    for img1_overlap_tag in overlap_maps:
        overlap_key = (img1_id, img2_id, img1_overlap_tag)
        if overlap_key in dup_truth:
            continue

        # The second image is sliced with the tag paired to img1's tag.
        img2_overlap_tag = overlap_tag_pairs[img1_overlap_tag]
        mask1_slice_total = np.sum(mask1[overlap_tag_maps[img1_overlap_tag]])
        mask2_slice_total = np.sum(mask2[overlap_tag_maps[img2_overlap_tag]])
        if mask1_slice_total + mask2_slice_total < 1:
            continue

        untested_overlaps_with_ship_masks.append(overlap_key)

len(overlap_image_maps), len(untested_overlaps_with_ship_masks)

## Find overlapping images using hashlib
Update: The pixel values of two supposedly identical 256x256 crops are not always exactly equal (see below).

In [None]:

# Group image ids by the md5 hash of each of their tiles, then count how many
# hashes are shared by exactly k tiles.
img_ids = os.listdir(train_image_dir)

md5hash_dict = defaultdict(list)
for img_id in tqdm_notebook(img_ids):
    for tile_hash in sdcic.tile_md5hash_grids[img_id]:
        md5hash_dict[tile_hash].append(img_id)

dup_counts_dict = defaultdict(int)
for sharing_imgs in md5hash_dict.values():
    dup_counts_dict[len(sharing_imgs)] += 1

sorted_bin_sizes = sorted(dup_counts_dict.items())
print('n images with k duplicates')
print('(k, n)')
sorted_bin_sizes

In [None]:
# Show a montage of the images that share one md5 tile hash.
# batch_size selects hashes shared by exactly that many tiles; the loop walks
# those hashes and stops once `skip` (or the number of such hashes, if
# smaller) of them have been passed, leaving the selected group in dups0.
batch_size = 9
skip = 365
ii = 0
jj = 0
batch_limit = 9
# zeros, not empty: a group may hold fewer than batch_limit unique images, and
# the unused slots must render as black rather than as uninitialised memory.
samples_images = np.zeros((batch_limit, 768, 768, 3), dtype=np.float32)
# Start from an empty selection so that a run with no matching group cannot
# reuse a dups0 left over from a previously executed cell.
dups0 = []

for hash_id, dups in md5hash_dict.items():
    ii += 1
    if len(dups) == batch_size:
        dups0 = list(set(dups))
        img_id = dups0[0]
        idx = sdcic.tile_md5hash_grids[img_id].tolist().index(hash_id)
        print(hash_id, len(dups), ii, sdcic.tile_entropy_grids[img_id][idx])
        if jj == min(dup_counts_dict[len(dups)], skip):
            break
        jj += 1

for i, c_img_id in enumerate(dups0[:batch_limit]):
    c_img = cv2.cvtColor(sdcic.get_img(c_img_id), cv2.COLOR_BGR2RGB)
    samples_images[i] = c_img.astype(np.float32) / 255.0

batch_rgb = montage_rgb(samples_images)
print(samples_images.shape)
print(batch_rgb.shape, batch_rgb.dtype)

fig, ax = plt.subplots(1, 1, figsize = (16, 16))
ax.imshow(batch_rgb, vmin=0, vmax=1)
plt.axis('off')
plt.show()

## Find overlapping images using cv2.img_hash

In [None]:
# Same grouping as for md5, keyed by the bm0 hash of each tile. The hash is
# converted to a tuple so that it can be used as a dict key.
img_ids = os.listdir(train_image_dir)

bm0hash_dict = defaultdict(list)
for img_id in tqdm_notebook(img_ids):
    for tile_hash in sdcic.tile_bm0hash_grids[img_id]:
        bm0hash_dict[tuple(tile_hash)].append(img_id)  # hex

dup_counts_dict = defaultdict(int)
for sharing_imgs in bm0hash_dict.values():
    dup_counts_dict[len(sharing_imgs)] += 1

sorted_bin_sizes = sorted(dup_counts_dict.items())
print('n images with k duplicates')
print('(k, n)')
sorted_bin_sizes

In [None]:
# Show a montage of the images that share one bm0 tile hash.
# batch_size selects hashes shared by exactly that many tiles; the loop walks
# those hashes and stops once `skip` (or the number of such hashes, if
# smaller) of them have been passed, leaving the selected group in dups0.
batch_size = 18
skip = 5
ii = 0
jj = 0
batch_limit = 9
# zeros, not empty: a group may hold fewer than batch_limit unique images, and
# the unused slots must render as black rather than as uninitialised memory.
samples_images = np.zeros((batch_limit, 768, 768, 3), dtype=np.float32)
# Start from an empty selection so that a run with no matching group cannot
# reuse a dups0 left over from a previously executed cell.
dups0 = []

for hash_id, dups in bm0hash_dict.items():
    ii += 1
    if len(dups) == batch_size:
        dups0 = list(set(dups))
        img_id = dups0[0]
        # Rows of the image's hash grid that equal this hash.
        idx = np.where(np.all(sdcic.tile_bm0hash_grids[img_id] == np.asarray(hash_id), axis=1))[0]
        print(hash_id, len(dups), ii, sdcic.tile_entropy_grids[img_id][idx])
        if jj == min(dup_counts_dict[len(dups)], skip):
            break
        jj += 1

for i, c_img_id in enumerate(dups0[:batch_limit]):
    c_img = cv2.cvtColor(sdcic.get_img(c_img_id), cv2.COLOR_BGR2RGB)
    samples_images[i] = c_img.astype(np.float32) / 255.0

batch_rgb = montage_rgb(samples_images)
print(samples_images.shape)
print(batch_rgb.shape, batch_rgb.dtype)

fig, ax = plt.subplots(1, 1, figsize = (16, 16))
ax.imshow(batch_rgb, vmin=0, vmax=1)
plt.axis('off')
# plt.savefig(os.path.join('models', BASE_MODEL, f"{train_meta_filebase}_{score_str}_batch_{BATCH_NUM}.jpg"))
plt.show()

## Explore duplicate detection using image gradients and cross entropy

In [None]:
def get_channel_entropy(ctr, img_size=1769472):  # 768x768x3
    """Value-weighted entropy of a histogram.

    :param ctr: mapping {value: count}.
    :param img_size: divisor that turns each count into a probability.
    :return: sum over values (in ascending order) of value * (-p * log(p)),
        where p = count / img_size. An empty mapping gives 0.0.
    """
    weighted_terms = []
    for value in sorted(ctr):
        prob = ctr[value] / img_size
        weighted_terms.append(value * (-prob * np.log(prob)))
    return np.sum(weighted_terms)

def get_entropy(img_id):
    """Gradient-histogram entropies of one training image.

    The image is read from `train_image_dir`, its gradient is taken along
    axes 0 and 1, and the absolute gradient values along each axis are
    histogrammed and passed to `get_channel_entropy`.

    :param img_id: file name of the image inside `train_image_dir`.
    :return: array with one entropy per gradient axis (two entries: axis 0,
        then axis 1).
    """
    img = cv2.imread(os.path.join(train_image_dir, img_id))
    # np.int was removed in NumPy 1.24; use an explicit signed 64-bit type so
    # that differences of uint8 pixels cannot wrap around.
    img_grad = np.gradient(img.astype(np.int64), axis=(0, 1))
    entropy_list = []
    # np.gradient returns one array per requested axis, not one per colour channel.
    for axis_grad in img_grad:
        ctr = Counter(np.abs(axis_grad).flatten())
        entropy_list.append(get_channel_entropy(ctr, img.size))
    return np.array(entropy_list)

def get_entropy1(img_id):
    """Histograms of the absolute image gradient, one per gradient axis.

    The gradient is taken along axes 0 and 1 with a sample spacing of 0.5 and
    its absolute value is cast to uint8 before counting.

    :param img_id: file name of the image inside `train_image_dir`.
    :return: list of two Counters (axis 0, then axis 1) mapping gradient
        magnitude to number of occurrences.
    """
    img = cv2.imread(os.path.join(train_image_dir, img_id))
    # np.int was removed in NumPy 1.24; use an explicit signed 64-bit type so
    # that differences of uint8 pixels cannot wrap around.
    img_grad = np.gradient(img.astype(np.int64), 0.5, axis=(0, 1))
    entropy_list = []
    # np.gradient returns one array per requested axis, not one per colour channel.
    for axis_grad in img_grad:
        # NOTE(review): with spacing 0.5 the one-sided differences at the image
        # border can exceed 255, which does not fit in uint8 — confirm intent.
        ctr = Counter(np.abs(axis_grad).astype(np.uint8).flatten())
        entropy_list.append(ctr)
    return entropy_list

def get_entropy2(img1_id, img2_id):
    """Entropy of the difference between two images' gradient histograms.

    For each gradient axis the histograms from `get_entropy1` are combined as
    (c1 - c2) + (c2 - c1), i.e. the absolute per-value count difference, and
    scored with `get_channel_entropy` using its default `img_size`.

    :return: array with one value per gradient axis.
    """
    counters1 = get_entropy1(img1_id)
    counters2 = get_entropy1(img2_id)
    return np.array([
        get_channel_entropy((c1 - c2) + (c2 - c1))
        for c1, c2 in zip(counters1, counters2)
    ])

In [None]:
# Walk the labelled overlaps and track two running extremes of max(scores.enp):
#   score_lim0 - the largest value seen for a non-duplicate (is_dup == 0)
#   score_lim1 - the smallest value seen for a duplicate    (is_dup == 1)
# Whenever either extreme moves, print the gradient-entropy diagnostics of the
# image pair next to its bmh/cmh scores.
score_lim0 = 0
score_lim1 = 1
for (img1_id, img2_id), overlap_maps in tqdm_notebook(overlap_image_maps.items()):
    if img1_id > img2_id:
        # sanity check
        raise ValueError(f'img1_id ({img1_id}) should be lexicographically smaller than img2_id ({img2_id})')
    for img1_overlap_tag, scores in overlap_maps.items():
        # Only labelled overlaps are of interest here.
        if (img1_id, img2_id, img1_overlap_tag) not in dup_truth:
            continue

        is_dup = dup_truth[(img1_id, img2_id, img1_overlap_tag)]

        if is_dup == 0 and np.max(scores.enp) > score_lim0:
            score_lim0 = np.max(scores.enp)
            print_score = True
        elif is_dup == 1 and np.max(scores.enp) < score_lim1:
            score_lim1 = np.max(scores.enp)
            print_score = True
        else:
            print_score = False

        if print_score:
            img1_entropy_vec = get_entropy(img1_id)
            img2_entropy_vec = get_entropy(img2_id)
            img1_entropy_vec_norm = np.linalg.norm(img1_entropy_vec)
            img2_entropy_vec_norm = np.linalg.norm(img2_entropy_vec)
            # Scale both vectors by the larger norm, then score their distance.
            n_vec = np.max([img1_entropy_vec_norm, img2_entropy_vec_norm])
            img1_scaled_vec = img1_entropy_vec / n_vec
            img2_scaled_vec = img2_entropy_vec / n_vec
            grad_score = 1.0 - np.linalg.norm(img1_scaled_vec - img2_scaled_vec)

            entropy2 = get_entropy2(img1_id, img2_id)
            entropy2_norm = np.linalg.norm(entropy2)

            print('')
            print(f'{is_dup}, {min(scores.bmh):7.5f}, {min(scores.cmh):7.5f}, {grad_score:7.5f}, {entropy2_norm}')
            print(img1_id, img1_entropy_vec, f'{img1_entropy_vec_norm}')
            print(img2_id, img2_entropy_vec, f'{img2_entropy_vec_norm}')
            # NOTE(review): these two calls recompute (and re-read from disk)
            # the vectors already printed on the two lines above.
            print(get_entropy(img1_id))
            print(get_entropy(img2_id))
            print(entropy2)
            print(np.max(scores.enp))


In [None]:
# Hand-picked image pair inspected in the next cell.
img1_id = '691d5afc2.jpg'
img2_id = '56417e7af.jpg'

In [None]:
# Compare the gradient-entropy vectors of the selected image pair: scale both
# by the larger of the two norms and print the distance between them.
img1_entropy_vec = get_entropy(img1_id)
img2_entropy_vec = get_entropy(img2_id)
img1_entropy_vec_norm = np.linalg.norm(img1_entropy_vec)
img2_entropy_vec_norm = np.linalg.norm(img2_entropy_vec)
# The larger of BOTH norms (taking img1's norm twice was a copy-paste slip).
n_vec = np.max([img1_entropy_vec_norm, img2_entropy_vec_norm])
img1_scaled_vec = img1_entropy_vec / n_vec
img2_scaled_vec = img2_entropy_vec / n_vec
print('')
print(img1_id, img1_entropy_vec, f'{img1_entropy_vec_norm}')
# Each image is reported with its own norm.
print(img2_id, img2_entropy_vec, f'{img2_entropy_vec_norm}')
print(f'{np.linalg.norm(img1_scaled_vec - img2_scaled_vec)}')

In [None]:
# For every blacklisted image pair, print the gradient-entropy vector of each
# image, the cosine similarity of the two vectors and the Euclidean distance
# between them.
# A separator longer than one character is treated as a regular expression,
# which only the python engine supports; naming the engine explicitly avoids
# the silent fallback and its ParserWarning.
df = pd.read_csv('data/processed/dup_blacklist_6.csv', sep=', ', engine='python')
for idx, row in df.iterrows():
    print(idx)
    img1_entropy_vec = get_entropy(row['ImageId1'])
    img1_entropy_vec_u = img1_entropy_vec / np.linalg.norm(img1_entropy_vec)
    print(row['ImageId1'], img1_entropy_vec)
    img2_entropy_vec = get_entropy(row['ImageId2'])
    img2_entropy_vec_u = img2_entropy_vec / np.linalg.norm(img2_entropy_vec)
    print(row['ImageId2'], img2_entropy_vec)
    print(np.dot(img1_entropy_vec_u, img2_entropy_vec_u), np.linalg.norm(img1_entropy_vec - img2_entropy_vec))

## Search for reasonable thresholds

In [None]:
# Per-tile bmh / cmh / pix scores of every overlap that is NOT yet labelled in
# dup_truth, keyed by (img1_id, img2_id, img1_overlap_tag, tile position).
# Plain dicts: every key is assigned explicitly, so no default factory is
# needed (a defaultdict() without a factory behaves like a dict anyway).
bmh_scores = {}
cmh_scores = {}
pix_scores = {}

for (img1_id, img2_id), overlap_maps in tqdm_notebook(overlap_image_maps.items()):
    if img1_id > img2_id:
        # sanity check
        raise ValueError(f'img1_id ({img1_id}) should be lexicographically smaller than img2_id ({img2_id})')
    for img1_overlap_tag, scores in overlap_maps.items():
        if (img1_id, img2_id, img1_overlap_tag) in dup_truth:
            continue

        for i in range(len(scores.bmh)):
            idx = (img1_id, img2_id, img1_overlap_tag, i)
            bmh_scores[idx] = scores.bmh[i]
            cmh_scores[idx] = scores.cmh[i]
            pix_scores[idx] = scores.pix[i]

overlap_scores_df = pd.DataFrame()
overlap_scores_df['bmh'] = pd.Series(bmh_scores)
overlap_scores_df['cmh'] = pd.Series(cmh_scores)
overlap_scores_df['pix'] = pd.Series(pix_scores)

overlap_scores_df.describe(percentiles=[.01, .05, .1, .25, .5, .75, .90, .95, .99])

In [None]:
# Flat lists of every per-tile score, one list per metric, over all overlaps
# that are not yet labelled in dup_truth.
bmh_arr, cmh_arr = [], []
con_arr, hom_arr, eng_arr, cor_arr = [], [], [], []
epy_arr, enp_arr = [], []
pix_arr, px0_arr, shp_arr = [], [], []

for (img1_id, img2_id), overlap_maps in tqdm_notebook(overlap_image_maps.items()):
    for img1_overlap_tag, scores in overlap_maps.items():
        already_labelled = (img1_id, img2_id, img1_overlap_tag) in dup_truth
        if already_labelled:
            continue

        bmh_arr.extend(scores.bmh)
        cmh_arr.extend(scores.cmh)
        con_arr.extend(scores.con)
        hom_arr.extend(scores.hom)
        eng_arr.extend(scores.eng)
        cor_arr.extend(scores.cor)
        epy_arr.extend(scores.epy)
        enp_arr.extend(scores.enp)
        pix_arr.extend(scores.pix)
        px0_arr.extend(scores.px0)
        shp_arr.extend(scores.shp)

In [None]:
# One column per metric, filled from the flat score lists.
overlap_limits_df = pd.DataFrame()
for metric_tag, metric_values in (
        ('bmh', bmh_arr),
        ('cmh', cmh_arr),
        ('con', con_arr),
        ('hom', hom_arr),
        ('eng', eng_arr),
        ('cor', cor_arr),
        ('epy', epy_arr),
        ('enp', enp_arr),
        ('pix', pix_arr),
        ('px0', px0_arr),
        ('shp', shp_arr)):
    overlap_limits_df[metric_tag] = pd.Series(metric_values)

In [None]:
# Summary statistics including the extreme tails (0.1% ... 99.9%).
overlap_limits_df.describe(percentiles=[.001, .01, .02, .05, .1, .25, .5, .75, .9, .95, .98, .99, 0.999])

In [None]:
# Coarser summary (10% ... 90%).
overlap_limits_df.describe(percentiles=[.1, .25, .5, .75, .9])

In [None]:
# Hand-set lower / upper limit per score metric; the filter cell counts how
# often a metric's min falls below `lower` or its max rises above `upper`.
#  |-----|--------------|-----|
# min  lower          upper  max

metric_tags = ['bmh', 'cmh', 'con', 'hom', 'eng', 'cor', 'epy', 'enp', 'pix', 'px0', 'shp']
Overlap_Scores_Lower_Limit = namedtuple('overlap_scores_lower_limit', metric_tags)
Overlap_Scores_Upper_Limit = namedtuple('overlap_scores_upper_limit', metric_tags)

# Values are positional, in metric_tags order.
# NOTE(review): presumably read off the describe() tables above — confirm.
osl_lower = Overlap_Scores_Lower_Limit(0., 0., 1e-5, 18e-6, 8e-6, 2e-6, 2e-6, 0.9995, 141, 0, 0)
osl_upper = Overlap_Scores_Upper_Limit(1., 1., 8e-5, 1.5e-4, 1e-4, 2e-5, 2e-5, 0.99993, 1859, 1e7, 1e7)

## Filter

In [None]:
# For every overlap that is not yet labelled in dup_truth:
#   * take the min and max of each score metric over the overlap's tiles,
#   * count, per metric, how often the min is below osl_lower / the max is
#     above osl_upper (the *_min_hits / *_max_hits counters),
#   * append one Overlap_Idx_Scores record to overlap_candidates.
# No overlap is filtered out here: the only filter is commented out below.
Overlap_Idx_Scores = namedtuple('overlap_idx_scores', [
    'idx', 
    'bmh_min', 'cmh_min', 'con_min', 'hom_min', 'eng_min', 'cor_min', 'epy_min', 'enp_min', 'pix_min', 'px0_min', 'shp_min', 
    'bmh_max', 'cmh_max', 'con_max', 'hom_max', 'eng_max', 'cor_max', 'epy_max', 'enp_max', 'pix_max', 'px0_max', 'shp_max'])

# Initial values only; each of these is overwritten for every overlap in the loop.
bmh_min = 0
cmh_min = 0
con_min = 0
hom_min = 0
eng_min = 0
cor_min = 0
epy_min = 0
enp_min = 0
pix_min = 0
px0_min = 0
shp_min = 0

bmh_max = 1
cmh_max = 1
con_max = 1
hom_max = 1
eng_max = 1
cor_max = 1
epy_max = 1
enp_max = 1
pix_max = 256*256*3*255
px0_max = 256*256
shp_max = 256*256

# Number of overlaps whose per-metric minimum is below the lower limit.
bmh_min_hits = 0
cmh_min_hits = 0
con_min_hits = 0
hom_min_hits = 0
eng_min_hits = 0
cor_min_hits = 0
epy_min_hits = 0
enp_min_hits = 0
pix_min_hits = 0
px0_min_hits = 0
shp_min_hits = 0

# Number of overlaps whose per-metric maximum is above the upper limit.
bmh_max_hits = 0
cmh_max_hits = 0
con_max_hits = 0
hom_max_hits = 0
eng_max_hits = 0
cor_max_hits = 0
epy_max_hits = 0
enp_max_hits = 0
pix_max_hits = 0
px0_max_hits = 0
shp_max_hits = 0

# These five names are initialised here but never updated in this cell.
flat_score_good = 0
flat_score_bad = 0
print_first_good = True
print_first_bad = True
n_not_dups = 0

overlap_candidates = []
for (img1_id, img2_id), overlap_maps in tqdm_notebook(overlap_image_maps.items()):
    for img1_overlap_tag, scores in overlap_maps.items():
        if (img1_id, img2_id, img1_overlap_tag) in dup_truth:
            continue

        # Number of limits violated by this overlap (only used by the
        # commented-out filter further down).
        constraint_hits = 0

        # --- lower limits: minimum over the overlap's tiles ---
        bmh_min = np.min(scores.bmh)
        if bmh_min < osl_lower.bmh:
            bmh_min_hits += 1
            constraint_hits += 1

        cmh_min = np.min(scores.cmh)
        if cmh_min < osl_lower.cmh:
            cmh_min_hits += 1
            constraint_hits += 1

        con_min = np.min(scores.con)
        if con_min < osl_lower.con:
            con_min_hits += 1
            constraint_hits += 1

        hom_min = np.min(scores.hom)
        if hom_min < osl_lower.hom:
            hom_min_hits += 1
            constraint_hits += 1

        eng_min = np.min(scores.eng)
        if eng_min < osl_lower.eng:
            eng_min_hits += 1
            constraint_hits += 1

        cor_min = np.min(scores.cor)
        if cor_min < osl_lower.cor:
            cor_min_hits += 1
            constraint_hits += 1

        epy_min = np.min(scores.epy)
        if epy_min < osl_lower.epy:
            epy_min_hits += 1
            constraint_hits += 1

        enp_min = np.min(scores.enp)
        if enp_min < osl_lower.enp:
            enp_min_hits += 1
            constraint_hits += 1

        pix_min = np.min(scores.pix)
        if pix_min < osl_lower.pix:
            pix_min_hits += 1
            constraint_hits += 1

        px0_min = np.min(scores.px0)
        if px0_min < osl_lower.px0:
            px0_min_hits += 1
            constraint_hits += 1

        shp_min = np.min(scores.shp)
        if shp_min < osl_lower.shp:
            shp_min_hits += 1
            constraint_hits += 1


        # --- upper limits: maximum over the overlap's tiles ---
        bmh_max = np.max(scores.bmh)
        if bmh_max > osl_upper.bmh:
            bmh_max_hits += 1
            constraint_hits += 1

        cmh_max = np.max(scores.cmh)
        if cmh_max > osl_upper.cmh:
            cmh_max_hits += 1
            constraint_hits += 1

        con_max = np.max(scores.con)
        if con_max > osl_upper.con:
            con_max_hits += 1
            constraint_hits += 1

        hom_max = np.max(scores.hom)
        if hom_max > osl_upper.hom:
            hom_max_hits += 1
            constraint_hits += 1

        eng_max = np.max(scores.eng)
        if eng_max > osl_upper.eng:
            eng_max_hits += 1
            constraint_hits += 1

        cor_max = np.max(scores.cor)
        if cor_max > osl_upper.cor:
            cor_max_hits += 1
            constraint_hits += 1

        epy_max = np.max(scores.epy)
        if epy_max > osl_upper.epy:
            epy_max_hits += 1
            constraint_hits += 1

        enp_max = np.max(scores.enp)
        if enp_max > osl_upper.enp:
            enp_max_hits += 1
            constraint_hits += 1

        pix_max = np.max(scores.pix)
        if pix_max > osl_upper.pix:
            pix_max_hits += 1
            constraint_hits += 1

        px0_max = np.max(scores.px0)
        if px0_max > osl_upper.px0:
            px0_max_hits += 1
            constraint_hits += 1

        shp_max = np.max(scores.shp)
        if shp_max > osl_upper.shp:
            shp_max_hits += 1
            constraint_hits += 1

#         if constraint_hits < 0:
#             continue

        # Record the overlap together with its per-metric extremes.
        idx = (img1_id, img2_id, img1_overlap_tag)
        overlap_scores = Overlap_Idx_Scores(
            idx, 
            bmh_min, cmh_min, con_min, hom_min, eng_min, cor_min, epy_min, enp_min, pix_min, px0_min, shp_min, 
            bmh_max, cmh_max, con_max, hom_max, eng_max, cor_max, epy_max, enp_max, pix_max, px0_max, shp_max)
        overlap_candidates.append(overlap_scores)

In [None]:
# Number of candidates, then the lower-limit and upper-limit hit counters.
# The px0 and shp counters are not included in either line.
print(len(overlap_candidates))
print(bmh_min_hits, cmh_min_hits, con_min_hits, hom_min_hits, eng_min_hits, cor_min_hits, epy_min_hits, enp_min_hits, pix_min_hits)
print(bmh_max_hits, cmh_max_hits, con_max_hits, hom_max_hits, eng_max_hits, cor_max_hits, epy_max_hits, enp_max_hits, pix_max_hits)

In [None]:
# NOTE(review): n_not_dups, flat_score_good and flat_score_bad are only
# initialised to 0 in the filter cell and never updated there.
print(len(dup_truth), n_not_dups, flat_score_good, flat_score_bad)

# Use duplicate_truth.txt and image_md5hash_grids.pkl to find untested duplicate and non-duplicate tiles.

## Create a list of flat (solid) hashes
(i.e. hashes of tiles in which every pixel has the same color)

In [None]:
img_overlap_index_maps = generate_tag_pair_lookup()

# Collect the md5 hash of every tile whose issolid entry is entirely >= 0.
# NOTE(review): presumably a negative issolid value marks a non-solid tile — confirm.
solid_hashes = set()
for img_id, tile_issolid_grid in sdcic.tile_issolid_grids.items():
    # Tile indices with at least one entry >= 0 ...
    candidate_tiles = set(np.where(tile_issolid_grid >= 0)[0])
    # ... of which only those with all entries >= 0 are kept.
    solid_hashes.update(
        sdcic.tile_md5hash_grids[img_id][tile_idx]
        for tile_idx in candidate_tiles
        if np.all(tile_issolid_grid[tile_idx] >= 0))

print(solid_hashes)

### Using dicts

In [None]:
# Propagate the overlap labels in dup_truth down to tile hashes:
#   tile_hash_dup_dict[h] - hashes known to show the same tile as h (h included)
#   tile_hash_dif_dict[h] - hashes known to show a different tile than h
tile_hash_dup_dict = defaultdict(set)
tile_hash_dif_dict = defaultdict(set)

for (img1_id, img2_id, img1_overlap_tag), is_dup in dup_truth.items():

    for idx1, idx2 in img_overlap_index_maps[img1_overlap_tag]:

        tile1_hash = sdcic.tile_md5hash_grids[img1_id][idx1]
        tile2_hash = sdcic.tile_md5hash_grids[img2_id][idx2]

        if is_dup:

            # Tile pairs involving a solid-colour tile are not recorded as dups.
            if tile1_hash in solid_hashes or tile2_hash in solid_hashes:
                continue

            tile_hash_dup_dict[tile1_hash].add(tile1_hash)
            tile_hash_dup_dict[tile2_hash].add(tile2_hash)
            tile_hash_dup_dict[tile1_hash].add(tile2_hash)
            tile_hash_dup_dict[tile2_hash].add(tile1_hash)

        else:
            # Byte-identical tiles inside a non-duplicate overlap are not
            # recorded as different.
            if tile1_hash == tile2_hash:
                continue

            tile_hash_dif_dict[tile1_hash].add(tile2_hash)
            tile_hash_dif_dict[tile2_hash].add(tile1_hash)

print(len(tile_hash_dup_dict), len(tile_hash_dif_dict))

# Sanity check: hashes cannot be simultaneously "a dup" and "not a dup" of tile1_hash
# (Looking up tile_hash_dif_dict[...] inserts an empty set for a missing key,
# because it is a defaultdict.)
for tile1_hash in tile_hash_dup_dict:
    if len(tile_hash_dup_dict[tile1_hash].intersection(tile_hash_dif_dict[tile1_hash])) != 0:
        print(tile1_hash, tile_hash_dup_dict[tile1_hash], tile_hash_dif_dict[tile1_hash])
    assert len(tile_hash_dup_dict[tile1_hash].intersection(tile_hash_dif_dict[tile1_hash])) == 0

# Sanity check: If B and C are dups of A, then make sure C not in tile_hash_dif_dict[B]
for tile1_hash, tile1_dups in tile_hash_dup_dict.items():
    for tile1_dup1 in sorted(tile1_dups):
        for tile1_dup2 in sorted(tile1_dups):
            if tile1_dup1 in tile_hash_dif_dict[tile1_dup2]:
                print(tile1_hash, tile1_dup1, tile_hash_dif_dict[tile1_dup2])
            assert tile1_dup1 not in tile_hash_dif_dict[tile1_dup2]

# Now we should be able to form cliques: (i.e. If A == B and B == C, then A == C)
# sorted() copies each set, so adding members while looping is safe.
# NOTE(review): this is a single pass over the dict; longer chains of
# duplicates may need repeated passes to be fully linked — confirm.
for tile1_hash, tile1_dups in tile_hash_dup_dict.items():
    for tile1_dup1 in sorted(tile1_dups):
        for tile1_dup2 in sorted(tile1_dups):
            if tile1_dup1 <= tile1_dup2:
                continue
            tile_hash_dup_dict[tile1_dup1].add(tile1_dup2)
            tile_hash_dup_dict[tile1_dup2].add(tile1_dup1)

# Histogram of dup-set sizes: (set size, number of hashes with that size).
neighbor_counts = Counter()
for tile1_hash, tile1_dups in tile_hash_dup_dict.items():
    neighbor_counts[len(tile1_dups)] += 1
list(sorted(neighbor_counts.items()))

In [None]:
# Label a candidate overlap as "not a duplicate" (0) as soon as one of its
# tile pairs is already known to differ.
auto_overlap_labels_0 = {}

for candidate in overlap_candidates:
    img1_id, img2_id, img1_overlap_tag = candidate.idx
    for idx1, idx2 in img_overlap_index_maps[img1_overlap_tag]:

        tile1_hash = sdcic.tile_md5hash_grids[img1_id][idx1]
        tile2_hash = sdcic.tile_md5hash_grids[img2_id][idx2]

        # tile_hash_dif_dict is a defaultdict: this lookup inserts an empty
        # set for a hash that has not been seen before.
        if tile1_hash in tile_hash_dif_dict[tile2_hash]:
            # The relation was stored in both directions.
            assert tile2_hash in tile_hash_dif_dict[tile1_hash]
            auto_overlap_labels_0[(img1_id, img2_id, img1_overlap_tag)] = 0
            break

print(len(auto_overlap_labels_0))

In [None]:
# Label a candidate overlap as "duplicate" (1) only if every one of its tile
# pairs is already known to be a duplicate pair.
auto_overlap_labels_1 = {}

for candidate in overlap_candidates:
    img1_id, img2_id, img1_overlap_tag = candidate.idx
    for idx1, idx2 in img_overlap_index_maps[img1_overlap_tag]:

        tile1_hash = sdcic.tile_md5hash_grids[img1_id][idx1]
        tile2_hash = sdcic.tile_md5hash_grids[img2_id][idx2]

        # tile_hash_dup_dict is a defaultdict: this lookup inserts an empty
        # set for a hash that has not been seen before.
        if tile1_hash in tile_hash_dup_dict[tile2_hash]:
            # The relation was stored in both directions.
            assert tile2_hash in tile_hash_dup_dict[tile1_hash]
            continue
        else:
            # One unknown pair is enough to reject the overlap.
            break
    else:
        # Reached only when the loop finished without a break.
        auto_overlap_labels_1[(img1_id, img2_id, img1_overlap_tag)] = 1

print(len(auto_overlap_labels_1))

### Using cliques via networkx

In [None]:
# Graph version of the tile-hash bookkeeping:
#   tile_hash_dup_cliques - grown through update_tile_cliques for duplicate pairs
#   tile_hash_dif_cliques - one edge per tile pair known to differ
tile_hash_dup_cliques = nx.Graph()
tile_hash_dif_cliques = nx.Graph()

for (img1_id, img2_id, img1_overlap_tag), is_dup in dup_truth.items():
    for idx1, idx2 in img_overlap_index_maps[img1_overlap_tag]:
        tile1_hash = sdcic.tile_md5hash_grids[img1_id][idx1]
        tile2_hash = sdcic.tile_md5hash_grids[img2_id][idx2]
        if not is_dup:
            # Byte-identical tiles are never recorded as different.
            if tile1_hash == tile2_hash:
                continue
            tile_hash_dif_cliques.add_edge(tile1_hash, tile2_hash)
            continue
        # Duplicate overlap: pairs that involve a solid tile are skipped.
        if tile1_hash in solid_hashes or tile2_hash in solid_hashes:
            continue
        update_tile_cliques(tile_hash_dup_cliques, tile1_hash, tile2_hash)

print(tile_hash_dup_cliques.number_of_nodes(), tile_hash_dif_cliques.number_of_nodes())

# Histogram of connected-component sizes: (size, number of components).
neighbor_counts = Counter(
    len(component) for component in nx.connected_components(tile_hash_dup_cliques))
list(sorted(neighbor_counts.items()))

#### Separately

In [None]:
# Label an overlap candidate 0 as soon as one of its tile pairs is joined by an
# edge in tile_hash_dif_cliques.
auto_overlap_labels_0 = {}

for candidate in overlap_candidates:
    img1_id, img2_id, img1_overlap_tag = candidate.idx
    for idx1, idx2 in img_overlap_index_maps[img1_overlap_tag]:

        tile1_hash = sdcic.tile_md5hash_grids[img1_id][idx1]
        tile2_hash = sdcic.tile_md5hash_grids[img2_id][idx2]

        # has_edge is a direct adjacency lookup and returns False when a hash is
        # not a node, so it replaces both the membership guard and the neighbor
        # set that used to be rebuilt for every tile pair.
        if tile_hash_dif_cliques.has_edge(tile1_hash, tile2_hash):
            auto_overlap_labels_0[(img1_id, img2_id, img1_overlap_tag)] = 0
            break

print(len(auto_overlap_labels_0))

In [None]:
# Label an overlap candidate 1 only when every one of its tile pairs is joined by
# an edge in tile_hash_dup_cliques.
auto_overlap_labels_1 = {}

for candidate in overlap_candidates:
    img1_id, img2_id, img1_overlap_tag = candidate.idx
    for idx1, idx2 in img_overlap_index_maps[img1_overlap_tag]:

        tile1_hash = sdcic.tile_md5hash_grids[img1_id][idx1]
        tile2_hash = sdcic.tile_md5hash_grids[img2_id][idx2]

        # has_edge is a direct adjacency lookup and returns False when a hash is
        # not a node, so it replaces both the membership guard and the neighbor
        # set that used to be rebuilt for every tile pair.
        if not tile_hash_dup_cliques.has_edge(tile1_hash, tile2_hash):
            break
    else:
        # Reached only when no tile pair broke out of the loop.
        auto_overlap_labels_1[(img1_id, img2_id, img1_overlap_tag)] = 1

print(len(auto_overlap_labels_1))

In [None]:
# Merge the 0-labels and the 1-labels into one dict; no key may carry both.
auto_overlap_labels = {}
assert auto_overlap_labels_0.keys().isdisjoint(auto_overlap_labels_1)
for labels in (auto_overlap_labels_0, auto_overlap_labels_1):
    auto_overlap_labels.update(labels)
print(len(auto_overlap_labels))

#### Combined

In [None]:
# Single pass that assigns both labels:
#   0  -> at least one tile pair has an edge in tile_hash_dif_cliques
#   1  -> every tile pair has an edge in tile_hash_dup_cliques
#   candidates with an unresolved tile pair and no "different" pair stay unlabeled
auto_overlap_labels = {}

for candidate in overlap_candidates:
    img1_id, img2_id, img1_overlap_tag = candidate.idx
    if (img1_id, img2_id, img1_overlap_tag) in auto_overlap_labels:
        continue
    is_dup = 1
    for idx1, idx2 in img_overlap_index_maps[img1_overlap_tag]:

        tile1_hash = sdcic.tile_md5hash_grids[img1_id][idx1]
        tile2_hash = sdcic.tile_md5hash_grids[img2_id][idx2]

        # has_edge is a direct adjacency lookup and returns False when a hash is
        # not a node, so no membership guard or per-pair neighbor set is needed.
        if tile_hash_dif_cliques.has_edge(tile1_hash, tile2_hash):
            is_dup = 0
            break
        elif tile_hash_dup_cliques.has_edge(tile1_hash, tile2_hash):
            continue
        else:
            # Unknown pair. Keep scanning instead of breaking: a later pair may
            # still prove the overlap is a non-duplicate.
            is_dup = -1

    if is_dup == -1:
        continue

    auto_overlap_labels[(img1_id, img2_id, img1_overlap_tag)] = is_dup

print(len(auto_overlap_labels))

In [None]:
# Pass the automatically derived labels to update_duplicate_truth and rebind
# dup_truth to whatever it returns. dup_truth is also the input used above to
# build the tile-hash graphs, so re-running earlier cells after this one will see
# the updated value.
# NOTE(review): update_duplicate_truth comes from sdcdup.utils; what auto=True does
# (e.g. whether/where labels are persisted) is not visible here — confirm.
dup_truth = update_duplicate_truth(auto_overlap_labels, auto=True)
len(dup_truth)

## Check the performance of DupNet

In [None]:
# Expand every labelled overlap into one TilePairs record per overlapping tile
# pair; ytrue carries the overlap's label once for each of those records.
ytrue = []
tile_pairs = []
for (img1_id, img2_id, img1_overlap_tag), is_dup in tqdm_notebook(dup_truth.items()):
    index_pairs = img_overlap_index_maps[img1_overlap_tag]
    tile_pairs.extend(
        TilePairs(img1_id, img2_id, img1_overlap_tag, overlap_idx, idx1, idx2)
        for overlap_idx, (idx1, idx2) in enumerate(index_pairs)
    )
    ytrue.extend([is_dup] * len(index_pairs))
print(len(tile_pairs), sum(ytrue), len(ytrue))

In [None]:
# Build the evaluation loader over all tile pairs (batches of 256, num_workers=12).
# NOTE(review): Dataset, data, WrappedDataLoader and preprocess are not defined in
# this part of the notebook — presumably `data` is torch.utils.data and preprocess
# is applied to each batch by the wrapper; confirm against their definitions.
test_ds = Dataset(tile_pairs)
test_dl = data.DataLoader(test_ds, batch_size=256, num_workers=12)
# test_dl is rebound here: from now on it is the wrapped loader, not the DataLoader.
test_dl = WrappedDataLoader(test_dl, preprocess)
print(len(test_dl))

In [None]:
# Load the trained DupNet checkpoint and place it on the configured device.
# The unconditional model.cuda() was dropped: model.to(device) already moves the
# model, and .cuda() raises on a machine without CUDA even when device is the CPU.
model = load_checkpoint('models/dup_model.best.pth')
model.to(device)

In [None]:
# Run DupNet over every tile pair with gradients disabled and flatten the
# per-batch outputs into one 1-D array of scores.
model.eval()
with torch.no_grad():
    yprobs0 = [model(xb) for xb in tqdm_notebook(test_dl)]
    yprobs = np.vstack([batch_probs.cpu() for batch_probs in yprobs0]).reshape(-1)
    # Use the vectorized reductions: the builtin min()/max() walk the array one
    # element at a time in Python.
    print(len(yprobs0), yprobs.shape, yprobs.min(), yprobs.max())

In [None]:
# Regroup the flat per-tile scores as
#   overlap_cnn_tile_scores[(img1_id, img2_id)][img1_overlap_tag] -> score array
# with one slot per overlapping tile pair, indexed by overlap_idx.
overlap_cnn_tile_scores = {}
for tp, yprob in zip(tile_pairs, yprobs):

    scores_by_tag = overlap_cnn_tile_scores.setdefault((tp.img1_id, tp.img2_id), {})

    # First score seen for this tag: allocate a zero-filled array for it.
    if tp.img1_overlap_tag not in scores_by_tag:
        scores_by_tag[tp.img1_overlap_tag] = np.zeros(
            len(img_overlap_index_maps[tp.img1_overlap_tag])
        )

    scores_by_tag[tp.img1_overlap_tag][tp.overlap_idx] = yprob

In [None]:
# Per-overlap summary of DupNet's output:
#   yprob - lowest tile score in the overlap
#   ypred - 1 if yprob > 0.5 else 0
#   ytrue - label from dup_truth
#   loss  - binary cross-entropy of yprob against ytrue
#   yconf - lowest per-tile |score - 0.5| * 2
#   pix   - max(scores.pix) from overlap_image_maps
DNN_Stats = namedtuple('dnn_stats', ['yprob', 'ypred', 'ytrue', 'loss', 'yconf', 'pix'])

dup_dict = {}
# The label is bound to `is_dup`, not `ytrue`, so this loop no longer overwrites
# the `ytrue` list of per-tile labels built together with tile_pairs.
for (img1_id, img2_id, img1_overlap_tag), is_dup in tqdm_notebook(dup_truth.items()):
    assert img1_id < img2_id

    if (img1_id, img2_id, img1_overlap_tag) in dup_dict:
        continue
    if (img1_id, img2_id) not in overlap_image_maps:
        continue
    if img1_overlap_tag not in overlap_image_maps[(img1_id, img2_id)]:
        continue
    scores = overlap_image_maps[(img1_id, img2_id)][img1_overlap_tag]
    if len(scores.pix) < 2:
        continue
    pix = max(scores.pix)

    # overlap_cnn_tile_scores is built from tile_pairs, which is built from the
    # same dup_truth, so every key is expected to be present here.
    dcnn_scores_raw = overlap_cnn_tile_scores[(img1_id, img2_id)][img1_overlap_tag]
    dcnn_conf_raw = np.abs((dcnn_scores_raw - 0.5) * 2) # confidence? (1: very, 0: not at all)
    yconf = np.min(dcnn_conf_raw)
    # The overlap is only as duplicate as its least-duplicate tile.
    yprob = np.min(dcnn_scores_raw)
    ypred = (yprob > 0.5) * 1
    assert ypred <= 1

    # Left unclipped on purpose: a score of exactly 0 or 1 on the wrong side
    # yields an infinite loss.
    if is_dup:
        bce = - is_dup * np.log(yprob)
    else:
        bce = - (1 - is_dup) * np.log(1 - yprob)

    dup_dict[(img1_id, img2_id, img1_overlap_tag)] = DNN_Stats(yprob, ypred, is_dup, bce, yconf, pix)

In [None]:
# Flat, sortable view of dup_dict: the same statistics with the dict key carried
# along as the first field.
DNN_Stats2 = namedtuple('dnn_stats', ['key', 'yprob', 'ypred', 'ytrue', 'loss', 'yconf', 'pix'])
dup_dict_flat = [
    DNN_Stats2(
        key=overlap_key,
        yprob=stats.yprob,
        ypred=stats.ypred,
        ytrue=stats.ytrue,
        loss=stats.loss,
        yconf=stats.yconf,
        pix=stats.pix,
    )
    for overlap_key, stats in tqdm_notebook(dup_dict.items())
]

In [None]:
# Collect the keys of every overlap DupNet mislabelled, highest loss first, and
# count the ones it got right.
n_correct = 0
id_tags = []

# NaN compares False against everything, so sorting on the raw loss leaves the
# order undefined once a NaN is present. Sorting on (is_nan, loss) puts NaN losses
# first and keeps all other entries in descending-loss order.
for dnns in tqdm_notebook(sorted(dup_dict_flat,
                                 key=lambda stats: (bool(np.isnan(stats.loss)), stats.loss),
                                 reverse=True)):

    # Skip the ones the dnn got correct.
    if dnns.ypred == dnns.ytrue:
        n_correct += 1
        continue

#     if dnns.key[2] != '08':
#         continue
#     if not dnns.ytrue:
#         continue
#     if (dnns.key[0], dnns.key[1]) not in overlap_image_maps:
#         continue

    # Report degenerate losses. `loss == np.nan` is always False, so NaN has to be
    # detected with np.isnan; the infinities use the matching numpy predicates.
    if np.isnan(dnns.loss):
        print('nan ', dnns)
    elif np.isposinf(dnns.loss):
        print('+inf', dnns)
    elif np.isneginf(dnns.loss):
        print('-inf', dnns)

#     Skip the ones the dnn was certain about.
#     if dnns.yprob < 0.01 or dnns.yprob > 0.99:
#         continue

    id_tags.append(dnns.key)
len(id_tags), n_correct

In [None]:
# Count how many mislabelled overlaps each image takes part in (an overlap counts
# once for each of its two images), then list the images in id order.
tags_counter = Counter()
for img1_id, img2_id, img1_overlap_tag in id_tags:
    tags_counter.update((img1_id, img2_id))
print(len(tags_counter))

for img_id, n_overlaps in sorted(tags_counter.items()):
    print(img_id, n_overlaps)

In [None]:
# Show one page of mislabelled overlaps side by side.
#   aa               - zero-based page index into id_tags
#   n_samples        - overlaps (figure rows) per page
#   use_median_shift - subtract the overlap's median colour inside the overlap
aa = 0
n_samples = 10
use_median_shift = True

test_files = id_tags[aa * n_samples: (aa + 1) * n_samples]#[::-1]
for f in test_files:
    print(f, '{:.5} {} {} {:.5} {:.5} {}'.format(*dup_dict[f]))

# Axis ticks every 256 pixels across 768 pixels.
dtick = 256
n_ticks = 768 // dtick + 1
ticks = [i * dtick for i in range(n_ticks)]

# squeeze=False keeps m_axs two-dimensional even when n_samples == 1; otherwise
# m_axs[ii] would be a single Axes and the (ax1, ax2) unpacking below would fail.
fig, m_axs = plt.subplots(n_samples, 2, figsize = (12, 6 * n_samples), squeeze=False)
for ii, (img1_id, img2_id, img1_overlap_tag) in enumerate(test_files):

    scores = overlap_image_maps[(img1_id, img2_id)][img1_overlap_tag]

    (ax1, ax2) = m_axs[ii]
    yprob, ypred, is_dup, loss, yconf, pix = dup_dict[(img1_id, img2_id, img1_overlap_tag)]

    # NOTE(review): ImgMod is defined outside this part of the notebook; it is
    # assumed to expose the RGB image as parent_rgb — confirm.
    imgmod1 = ImgMod(os.path.join(train_image_dir, img1_id))
    imgmod2 = ImgMod(os.path.join(train_image_dir, img2_id))

    # Overlap region in image 1 and the matching region in image 2.
    slice1 = overlap_tag_slices[img1_overlap_tag]
    slice2 = overlap_tag_slices[overlap_tag_pairs[img1_overlap_tag]]

    # Per-channel median over both overlap regions stacked together.
    m12 = np.median(np.vstack([imgmod1.parent_rgb[slice1], imgmod2.parent_rgb[slice2]]), axis=(0, 1), keepdims=True).astype(np.uint8)

    # Shift the whole image by -100 when the channel medians sum to at least 384,
    # otherwise by +100 (presumably so the untouched overlap stands out — confirm).
    brightness_level = -100 if np.sum(m12) >= 384 else 100
    img1 = imgmod1.brightness_shift('L', brightness_level)
    img2 = imgmod2.brightness_shift('L', brightness_level)

    if use_median_shift:
        # NOTE(review): if parent_rgb is uint8 this subtraction wraps around
        # instead of clipping — confirm that is the intended visualization.
        img1_drop = imgmod1.parent_rgb - m12
        img2_drop = imgmod2.parent_rgb - m12
    else:
        img1_drop = imgmod1.parent_rgb
        img2_drop = imgmod2.parent_rgb

    # Paste the (optionally median-shifted) pixels back over the overlap region.
    img1[slice1] = img1_drop[slice1]
    img2[slice2] = img2_drop[slice2]

    ax1.imshow(img1)
    ax1.set_title(f'{img1_id} {yprob:6.4} ({is_dup})')
    ax1.set_xticks(ticks)
    ax1.set_yticks(ticks)

    ax2.imshow(img2)
    ax2.set_title(f'{img2_id} {loss:4.2f} {max(scores.pix)}')
    ax2.set_xticks(ticks)
    ax2.set_yticks(ticks)

plt.tight_layout()
# fig.savefig(os.path.join('temp', BASE_MODEL, f"{train_meta_filebase}_{score_str}_batch_{BATCH_NUM}_row_{aa+1}.jpg"))