In [None]:
import os
from collections import Counter
from collections import namedtuple
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook
from sklearn import tree
import graphviz
from cv2 import img_hash

from sdcdup.utils import overlap_tag_pairs
from sdcdup.utils import overlap_tag_maps
from sdcdup.utils import generate_overlap_tag_slices
from sdcdup.utils import generate_tag_pair_lookup
from sdcdup.utils import get_project_root
from sdcdup.utils import fuzzy_compare
from sdcdup.utils import bce_loss
from sdcdup.utils import get_hamming_distance
from sdcdup.utils import load_duplicate_truth
from sdcdup.utils import get_tile
from sdcdup.utils import ImgMod
from sdcdup.features import SDCImageContainer
from sdcdup.visualization import get_ticks
from sdcdup.visualization import subtract_channel_average
from sdcdup.visualization import draw_overlap_bbox
from sdcdup.visualization import show_image
from sdcdup.visualization import ChannelShift

%load_ext dotenv
%dotenv
%matplotlib inline
%reload_ext autoreload
%autoreload 2

# Bounding-box colours as RGB triples; plot_image_pair draws GREEN for duplicates, RED otherwise.
RED = (244, 67, 54)  #F44336
GREEN = (76, 175, 80)  #4CAF50

# Font-size presets applied to matplotlib's global rc settings below.
SMALL_SIZE = 10
MEDIUM_SIZE = 12
BIGGER_SIZE = 16
BIGGEST_SIZE = 20
plt.rc('font', size=BIGGEST_SIZE)         # controls default text sizes
plt.rc('axes', titlesize=BIGGEST_SIZE)    # fontsize of the axes title
plt.rc('axes', labelsize=BIGGEST_SIZE)    # fontsize of the x and y labels
plt.rc('xtick', labelsize=BIGGER_SIZE)    # fontsize of the tick labels
plt.rc('ytick', labelsize=BIGGER_SIZE)    # fontsize of the tick labels
plt.rc('legend', fontsize=MEDIUM_SIZE)    # legend fontsize
plt.rc('figure', titlesize=BIGGEST_SIZE)  # fontsize of the figure title

# Project paths, all resolved relative to the project root.
project_root = get_project_root()
models_dir = os.path.join(project_root, 'models')
results_dir = os.path.join(project_root, 'notebooks', 'figures')
# NOTE(review): os.getenv returns None when RAW_DATA_DIR is unset, which makes os.path.join
# raise TypeError — confirm the environment loaded by %dotenv defines it.
train_image_dir = os.path.join(project_root, os.getenv('RAW_DATA_DIR'), 'train_768')
# Lookup tables and plot ticks shared by the cells below.
tag_pair_lookup = generate_tag_pair_lookup()
overlap_tag_slices = generate_overlap_tag_slices()
ticks = get_ticks()

In [None]:
# Load the labelled duplicate ground truth and report how many entries it holds.
# It is iterated later as a mapping from (img1_id, img2_id, img1_overlap_tag) to a truth label.
dup_truth = load_duplicate_truth()
print(len(dup_truth))

In [None]:
# Create the image container and restrict its matches to the ground-truth keys.
sdcic = SDCImageContainer()
sdcic.matches = list(dup_truth)

In [None]:
# No match files are passed in; only the 'dnn' score type is requested
# (the commented-out line lists the other score types).
matches_files = []
# score_types = ['bmh32', 'bmh96', 'hst', 'avg', 'pix', 'dnn']
score_types = ['dnn']
overlap_image_maps = sdcic.load_image_overlap_properties(matches_files, score_types=score_types)
print(len(overlap_image_maps))

## Evaluate all the ground truth examples on the model.

In [None]:
# Per-example record: model probability, thresholded prediction and BCE loss.
# NOTE(review): the typename 'dnn_stats' differs from the variable name DNN_Stats, so pickling
# instances would fail the attribute lookup — confirm whether that matters here.
DNN_Stats = namedtuple('dnn_stats', ['yprob', 'ypred', 'loss'])

# Score every ground-truth example, keyed by (img1_id, img2_id, img1_overlap_tag).
dup_dict = {}
for (img1_id, img2_id, img1_overlap_tag), ytrue in tqdm_notebook(dup_truth.items()):
    # The example's probability is the smallest value among its dnn scores.
    yprob = np.min(overlap_image_maps[(img1_id, img2_id, img1_overlap_tag)].dnn)
    # Threshold at 0.5; multiplying by 1 turns the boolean into an integer 0/1.
    ypred = (yprob > 0.5) * 1
    loss = bce_loss(ytrue, yprob)
    dup_dict[(img1_id, img2_id, img1_overlap_tag)] = DNN_Stats(yprob, ypred, loss)

## Try out a decision tree classifier for dup_truth

In [None]:
# Build parallel arrays for the classifier:
#   L - example keys, X - feature rows [ypred, loss], Y - ground-truth labels.
# NOTE(review): loss was computed from the true label (bce_loss(ytrue, yprob)), so using it as
# a feature exposes the label to the classifier — confirm this is intended for error analysis.
L = []
X = []
Y = []
for key, scores in dup_dict.items():
    L.append(key)
    X.append([
        scores.ypred,
        scores.loss,
#         min(overlap_image_maps[key].bmh32),
#         min(overlap_image_maps[key].bmh96),
#         max(overlap_image_maps[key].hst),
#         max(overlap_image_maps[key].avg),
#         max(overlap_image_maps[key].pix), 
    ])
    Y.append(dup_truth[key])

L = np.array(L)
X = np.array(X)  # X = [[0, 0], [1, 1]]
Y = np.array(Y)  # Y = [0, 1]

# Number of samples, number of labels, and the sum of the labels.
print(len(X), len(Y), sum(Y))

In [None]:
# Fit a decision tree on the ground-truth features.
# random_state is pinned because scikit-learn randomly permutes the candidate features at every
# split, so an unseeded tree can change shape between runs. The analysis further down selects
# samples by a hard-coded leaf index, which only stays meaningful if the tree is reproducible.
clf = tree.DecisionTreeClassifier(random_state=0)
clf = clf.fit(X, Y)

# Export the fitted tree as DOT source; feature_names must follow the column order of X.
dot_data = tree.export_graphviz(
    clf,
    feature_names=[
        'ypred',
        'loss',
#         'bmh32',
#         'bmh96',
#         'hst',
#         'avg',
#         'pix',
    ],
    filled=True,
    rounded=True,
    special_characters=True,
    leaves_parallel=True)

# Leaving `graph` as the last expression makes the notebook display the rendered tree.
graph = graphviz.Source(dot_data)
graph

In [None]:
# Save the tree as a PNG under results_dir, with the sample count embedded in the file name.
graph.render(f'decision_tree_{len(X)}', directory=results_dir, cleanup=True, format='png')

## Decision tree analysis (optional)

In [None]:
# Index of the tree node that each row of X ends up in.
all_nodes = clf.apply(X)

# Select the samples that land in node 6.
# NOTE(review): 6 is a hard-coded node id, presumably read off the rendered tree — confirm it
# still matches whenever the tree is refitted.
nodes = np.where(all_nodes == 6)
# argmin/argmax are called without an axis, so the returned positions index the flattened
# feature matrix (both columns mixed together) rather than rows.
np.argmin(X[nodes]), np.min(X[nodes]), np.argmax(X[nodes]), np.max(X[nodes])

In [None]:
# Inspect one sample from the selected node: its key, truth label and feature row,
# followed by the stored overlap properties for that key.
idx = 0
print(L[nodes][idx], Y[nodes][idx], X[nodes][idx])
# The key row is converted back to a tuple so it can be used for the dictionary lookup.
print(overlap_image_maps[tuple(L[nodes][idx])])

In [None]:
# Hand-picked hard cases. Each row is [img1_id, img2_id, img1_overlap_tag, is_dup], i.e. the
# positional arguments of plot_image_pair.
# NOTE(review): the trailing number on each row presumably is the tile count of the overlap,
# and the bracketed pairs look like feature values from an earlier run — confirm.
tricky_examples = [
    ['e28669903.jpg', 'ed2998ef7.jpg', '08', 1],  # 9
    ['66482462b.jpg', 'e2497099c.jpg', '08', 1],  # 9
    ['73fec0637.jpg', '8b0219c19.jpg', '08', 0],  # 9
    ['00ce2c1c0.jpg', '68ef625ba.jpg', '18', 1],  # 6
    ['01178499a.jpg', '7a7a0034a.jpg', '05', 1],  # 6
    ['1ebdf2f08.jpg', 'b1bfb768c.jpg', '05', 1],  # 6 [91.          0.99781223]
    ['d4f0aaa70.jpg', 'd84d4a78a.jpg', '05', 0],  # 6 [5.95230000e+04 9.98578088e-01] 
    ['012d8cca1.jpg', 'bc45cee87.jpg', '07', 1],  # 6
    ['2323bf875.jpg', 'b5da61fce.jpg', '07', 1],  # 6 [2.05663500e+06 9.98277186e-01]
    ['7f2be2b0a.jpg', '84dcdc7af.jpg', '07', 0],  # 6
    ['089858a56.jpg', '903a8b121.jpg', '38', 1],  # 6
    ['468bf9178.jpg', '6090b3a8b.jpg', '38', 1],  # 6 [1.30900000e+03 9.97640283e-01]
    ['d843fc5ca.jpg', 'e805070df.jpg', '38', 1],  # 6
    ['000194a2d.jpg', '384765ab2.jpg', '38', 1],  # 6
    ['0ef6cd331.jpg', 'e6a6f80cd.jpg', '38', 0],  # 6 [1.72270000e+04 9.98394555e-01]
    ['0a33ce967.jpg', '3964f0cee.jpg', '04', 1],  # 4
    ['d164aea52.jpg', 'fded6e12d.jpg', '04', 1],  # 4
    ['c3193fb05.jpg', 'cc68e7818.jpg', '15', 0],  # 4 [2.16300000e+04 9.98311792e-01]
    ['331987f64.jpg', '4869b48b6.jpg', '15', 0],  # 4
    ['0318fc519.jpg', 'b7feb225a.jpg', '37', 1],  # 4
    ['7234a3a53.jpg', 'dc6534704.jpg', '37', 1],  # 4
    ['de6fb187d.jpg', 'ea6dc23b7.jpg', '37', 1],  # 4 [223.           0.99544613]
    ['cd3c59923.jpg', 'efdd03319.jpg', '37', 0],  # 4 [6.70246000e+05 9.99894307e-01] 
    ['0c279107f.jpg', '3b1314d5d.jpg', '37', 0],  # 4
    ['42f02a4a4.jpg', '7d31648ff.jpg', '48', 0],  # 4
    ['204906e27.jpg', '892a69b4b.jpg', '02', 1],  # 3 [6.31644000e+05 9.97614902e-01]
    ['813c8ec35.jpg', 'caa94ffc3.jpg', '06', 0],  # 3 [1.76759000e+05 9.99834742e-01]
    ['0256ef90d.jpg', '46da51931.jpg', '06', 0],  # 3 [3.70260000e+05 9.99319673e-01]
    ['0ee790381.jpg', 'ac87bcee5.jpg', '06', 0],  # 3
    ['2f6c0deaa.jpg', 'e44a4f5b0.jpg', '28', 1],  # 3 [24.          0.99509307]
    ['0ef6cd331.jpg', '813c8ec35.jpg', '28', 0],  # 3 [1.79442000e+05 9.98195859e-01]
    ['4c56d2f00.jpg', 'dcd94e973.jpg', '68', 1],  # 3 [6.31635000e+05 9.97534103e-01]
    ['b645cd49b.jpg', 'f2e554691.jpg', '68', 1],  # 3 [3.76847000e+05 9.96659721e-01]
    ['b998c7415.jpg', 'd4d26f700.jpg', '68', 1],  # 3 [3.76847000e+05 9.96680501e-01]
    ['0ef6cd331.jpg', '3a9e579aa.jpg', '68', 0],  # 3 [1.62810000e+04 9.98394555e-01]
    ['a61b3e245.jpg', 'd84d4a78a.jpg', '68', 0],  # 3 [2.59134100e+06 9.99175738e-01]
    ['2095da0cb.jpg', '45b1a4561.jpg', '68', 0],  # 3
]

In [None]:
def get_tile_scores(tile1, tile2):
    """Score the similarity of two image tiles with several measures.

    Args:
        tile1: First tile, passed unchanged to fuzzy_compare and the cv2 img_hash functions.
        tile2: Second tile, handled the same way as tile1.

    Returns:
        A 4-tuple ``(fuzzy, hamming, moment_norm, moment_expnorm)``:
        the fuzzy_compare result, the normalised block-mean-hash Hamming score,
        the Euclidean norm of the colour-moment hash difference, and ``exp(-norm)``.
    """
    fuzzy_score = fuzzy_compare(tile1, tile2)

    # Block-mean hashes are compared through the project's Hamming helper.
    block_hashes = [img_hash.blockMeanHash(tile) for tile in (tile1, tile2)]
    hamming_score = get_hamming_distance(*block_hashes, normalize=True, as_score=True)

    # Colour-moment hashes are compared by the norm of their difference.
    moments1, moments2 = [img_hash.colorMomentHash(tile) for tile in (tile1, tile2)]
    moment_norm = np.linalg.norm(moments1 - moments2)

    return fuzzy_score, hamming_score, moment_norm, np.exp(-moment_norm)

def plot_image_pair(img1_id, img2_id, img1_overlap_tag, is_dup, draw_boxes=True):
    """Print per-tile similarity scores for an image pair, then plot both images side by side.

    For every pair of corresponding overlap tiles three score rows are printed:
    the tiles read from ``parent_rgb``, the tiles read from the images handed to
    subtract_channel_average, and the ``parent_rgb`` tiles minus their joint per-channel median.

    Args:
        img1_id: File name of the first image inside train_image_dir.
        img2_id: File name of the second image inside train_image_dir.
        img1_overlap_tag: Overlap tag of the first image; the second image's tag is
            looked up in overlap_tag_pairs.
        is_dup: Truthy selects GREEN overlap boxes, falsy selects RED.
        draw_boxes: Whether to draw the overlap bounding boxes before plotting.
    """
    def format_scores(scores):
        # One fixed-width column per score, separated by ", ".
        return ', '.join(f'{value:10.8f}' for value in scores)

    imgmod1 = ImgMod(os.path.join(train_image_dir, img1_id))
    imgmod2 = ImgMod(os.path.join(train_image_dir, img2_id))
    img1, img2 = imgmod1.parent_rgb, imgmod2.parent_rgb

    # NOTE(review): the return value is discarded, so this only has an effect if
    # subtract_channel_average modifies img1/img2 in place — confirm.
    subtract_channel_average(img1, img2, img1_overlap_tag, shift=ChannelShift('median', True))
    img2_overlap_tag = overlap_tag_pairs[img1_overlap_tag]

    tile_index_pairs = zip(overlap_tag_maps[img1_overlap_tag], overlap_tag_maps[img2_overlap_tag])
    for idx1, idx2 in tile_index_pairs:
        # Row 0: tiles taken from parent_rgb.
        raw1 = get_tile(imgmod1.parent_rgb, idx1)
        raw2 = get_tile(imgmod2.parent_rgb, idx2)
        raw_scores = get_tile_scores(raw1, raw2)

        # Row 1: tiles taken from the images passed to subtract_channel_average.
        shifted_scores = get_tile_scores(get_tile(img1, idx1), get_tile(img2, idx2))

        # Row 2: parent_rgb tiles minus the per-channel median over both tiles.
        # NOTE(review): m12_tile is uint8; if the tiles are uint8 as well the subtraction
        # wraps around for values below the median — presumably intended, confirm.
        m12_tile = np.median(np.vstack([raw1, raw2]), axis=(0, 1), keepdims=True).astype(np.uint8)
        median_scores = get_tile_scores(raw1 - m12_tile, raw2 - m12_tile)

        print(f'tile {idx1} / tile {idx2}')
        print(format_scores(raw_scores))
        print(format_scores(shifted_scores))
        print(format_scores(median_scores), m12_tile)

    if draw_boxes:
        bbox_thickness = 4
        bbox_color = GREEN if is_dup else RED
        for image, tag in ((img1, img1_overlap_tag), (img2, img2_overlap_tag)):
            draw_overlap_bbox(image, tag, bbox_thickness, bbox_color)

    fig, axes = plt.subplots(1, 2, figsize=(16, 8))
    for ax, image, image_id in zip(axes, (img1, img2), (img1_id, img2_id)):
        show_image(ax, image, image_id, ticks)

In [None]:
# Plot one hand-picked case (entry 22: 'cd3c59923.jpg' vs 'efdd03319.jpg', tag '37', label 0).
plot_image_pair(*tricky_examples[22], draw_boxes=False)

In [None]:
# Plot the sample selected from the tree node above; its truth label is passed as is_dup.
plot_image_pair(*L[nodes][idx], Y[nodes][idx], draw_boxes=False)