In [None]:
import os
import gc
import json
from collections import Counter
from collections import defaultdict

import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd

from tqdm import tqdm_notebook
from skimage.util import montage
import cv2

from sdcdup.utils import get_project_root
from sdcdup.utils import idx2ijpair
from sdcdup.utils import overlap_tag_maps
from sdcdup.utils import overlap_tag_pairs
from sdcdup.utils import generate_third_party_overlaps
from sdcdup.utils import generate_tag_pair_lookup
from sdcdup.utils import pad_string
from sdcdup.utils import get_overlap_matches
from sdcdup.utils import merge_overlap_matches
from sdcdup.utils import load_duplicate_truth
from sdcdup.utils import update_duplicate_truth
from sdcdup.utils import update_tile_cliques
from sdcdup.utils import get_tile
from sdcdup.utils import add_tuples
from sdcdup.features import SDCImageContainer
from sdcdup.rebuild_overlap_groups import SDCImage
from sdcdup.rebuild_overlap_groups import PrettyEncoder
from sdcdup.rebuild_overlap_groups import check_overlap
from sdcdup.rebuild_overlap_groups import add_overlap
from sdcdup.rebuild_overlap_groups import get_puzzle_solution

%reload_ext autoreload
%autoreload 2

montage_rgb = lambda x: np.stack([montage(x[:, :, :, i]) for i in range(x.shape[3])], -1)
montage_rgb2 = lambda x, c, r: np.stack([montage(x[:, :, :, i], grid_shape=(c, r)) for i in range(x.shape[3])], -1)

plt.rcParams['savefig.pad_inches'] = 0

project_root = get_project_root()
interim_data_dir = os.path.join(project_root, os.getenv('INTERIM_DATA_DIR'))
results_dir = os.path.join(project_root, 'notebooks', 'figures', 'results')

if not os.path.exists(results_dir):
    os.makedirs(results_dir)
    
third_party_overlaps = generate_third_party_overlaps()
tag_pair_lookup = generate_tag_pair_lookup()

In [None]:
# Candidate-match CSVs produced by the bmh matchers. The comprehension expands,
# in this order, to:
#   matches_bmh32_0.9_offset.csv, matches_bmh96_0.9_offset.csv,
#   matches_bmh32_0.8.csv,        matches_bmh96_0.8.csv
matches_files = [
    f'matches_{hash_name}_{suffix}.csv'
    for suffix in ('0.9_offset', '0.8')
    for hash_name in ('bmh32', 'bmh96')
]

sdcic = SDCImageContainer()
sdcic.load_image_metrics(['bmh96'])
sdcic.matches = get_overlap_matches(matches_files)

In [None]:
# Per-pair overlap scores, keyed by (img1_id, img2_id, overlap_tag). Each entry
# exposes .avg, .pix and .dnn, which are used to rank and filter pairs below.
score_types = ['avg', 'pix', 'dnn']
overlap_image_maps = sdcic.load_image_overlap_properties(
    matches_files,
    score_types=score_types,
)
print(len(overlap_image_maps))

In [None]:
# Ground-truth duplicate labels keyed by (img1_id, img2_id, overlap_tag);
# a falsy (0) value marks the pair as not a duplicate.
dup_truth = load_duplicate_truth()
print(len(dup_truth))

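## Collect the hashes of solid tiles
That is, hashes of tiles in which every pixel has the same color. Candidate overlaps made up only of such tiles are skipped when building the overlap groups.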

In [None]:
# Seed with the known all-blue tile hashes, then add every tile whose entries
# in the per-image 'sol' metric are all non-negative.
solid_hashes = {'b06a8fb9', 'b8e3e4c9', '715bd1bf', '232b4413', '571ea5a6'}  # the blues
for img_id, tile_issolid_grid in sdcic.img_metrics['sol'].items():
    # First-axis indices with at least one non-negative entry are candidates;
    # keep only those whose entries are all non-negative.
    candidate_idxs = set(np.where(tile_issolid_grid >= 0)[0])
    solid_hashes.update(
        sdcic.img_metrics['md5'][img_id][idx]
        for idx in candidate_idxs
        if np.all(tile_issolid_grid[idx] >= 0)
    )
print(solid_hashes)

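## Build a lookup of known non-duplicate overlaps
For each image and overlap tag, collect the images that the ground truth marks as *not* duplicates at that overlap. Pairs labelled as duplicates with overlap tag `'08'` are added to this lookup as well.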

In [None]:
# non_dups[img_id][overlap_tag] -> set of image ids recorded against img_id at
# that overlap tag. It is passed to check_overlap when building groups below.
non_dups = defaultdict(lambda: defaultdict(set))
for (img1_id, img2_id, img12_overlap_tag), is_dup in tqdm_notebook(dup_truth.items()):
    # The same overlap seen from img2's side.
    img21_overlap_tag = overlap_tag_pairs[img12_overlap_tag]
    # Record labelled non-duplicates, and also duplicates at overlap tag '08'.
    # NOTE(review): presumably '08' is the full-image overlap — confirm.
    if not is_dup or img12_overlap_tag == '08':
        non_dups[img1_id][img12_overlap_tag].add(img2_id)
        non_dups[img2_id][img21_overlap_tag].add(img1_id)

print(len(non_dups))

In [None]:
# Rank every candidate overlap by a combined difference score:
#   mean(avg) / (16*16*3) + mean(pix) / (256*256*3).
# Pairs scoring above 30 are dropped. Ties on score are broken by
# 9 - len(avg), so pairs with more .avg entries (presumably more overlapping
# tiles, out of the 9 per image) sort first.
AVG_NORM = 16 * 16 * 3
PIX_NORM = 256 * 256 * 3
MAX_SORT_SCORE = 30

sort_scores = {}
for pair_key, overlap_scores in tqdm_notebook(overlap_image_maps.items()):
    score = (np.mean(overlap_scores.avg) / AVG_NORM
             + np.mean(overlap_scores.pix) / PIX_NORM)
    # Written as `not >` rather than `<=` so that NaN scores are kept, not dropped.
    if not score > MAX_SORT_SCORE:
        sort_scores[pair_key] = (score, 9 - len(overlap_scores.avg))

In [None]:
# Greedily link candidate image pairs into overlap groups. Pairs are visited in
# ascending order of their sort_scores value, so lower-scoring pairs claim an
# overlap slot first. A pair is linked only if it passes every filter below and
# check_overlap accepts it in both directions.
overlap_groups = {}  # img_id -> SDCImage; .overlaps maps overlap tag -> linked img_id (falsy if unset)
is_dup_truth = {}  # NOTE(review): not used anywhere below
missing_third_party_matches = set()  # union of the missing_matches sets returned by check_overlap
G = nx.Graph()  # one edge per accepted pair; its connected components are the overlap groups
n_non_dups = 0  # pairs skipped because dup_truth labels them as non-duplicates
n_with_solid = 0  # pairs skipped because one side's overlap tiles are all solid hashes
n_with_low_probs = 0  # pairs skipped because some dnn score is below 0.8
n_with_large_diff = 0  # check_overlap calls (either direction) that reported missing matches

# sort_score only determines the visiting order; it is not read inside the loop.
for (img1_id, img2_id, img12_overlap_tag), sort_score in tqdm_notebook(sorted(sort_scores.items(), key=lambda x: x[1])):
    
    # The same overlap seen from img2's side.
    img21_overlap_tag = overlap_tag_pairs[img12_overlap_tag]

    # Skip pairs the ground truth explicitly labels as not duplicates.
    if (img1_id, img2_id, img12_overlap_tag) in dup_truth:
        if dup_truth[(img1_id, img2_id, img12_overlap_tag)] == 0:
            n_non_dups += 1
            continue

    # An image is never linked to itself.
    if img1_id == img2_id:
        continue

    # Skip if every tile hash in img1's part of the overlap is a solid hash
    # (presumably such uniform tiles would match almost anything).
    img1_hashes = set(sdcic.img_metrics['md5'][img1_id][overlap_tag_maps[img12_overlap_tag]])
    if len(img1_hashes.difference(solid_hashes)) == 0:
        n_with_solid += 1
        continue
        
    # Same check for img2's part of the overlap.
    img2_hashes = set(sdcic.img_metrics['md5'][img2_id][overlap_tag_maps[img21_overlap_tag]])
    if len(img2_hashes.difference(solid_hashes)) == 0:
        n_with_solid += 1
        continue

    # Require every DNN score for this overlap to be at least 0.8.
    scores = overlap_image_maps[(img1_id, img2_id, img12_overlap_tag)]
    if np.min(scores.dnn) < 0.8:
        n_with_low_probs += 1
        continue

    # Each image holds at most one neighbour per overlap tag. A slot that is
    # already filled keeps its earlier (lower sort-score) link.
    if img1_id in overlap_groups:
        if overlap_groups[img1_id].overlaps[img12_overlap_tag]:
            continue

    if img2_id in overlap_groups:
        if overlap_groups[img2_id].overlaps[img21_overlap_tag]:
            continue

    # Entries are created before check_overlap runs, so an image can end up in
    # overlap_groups even if none of its pairs is ever accepted.
    if img1_id not in overlap_groups:
        overlap_groups[img1_id] = SDCImage(img1_id)
    if img2_id not in overlap_groups:
        overlap_groups[img2_id] = SDCImage(img2_id)

    # Validate the link from img1's side. Missing third-party matches are
    # collected and counted whether or not the overlap is accepted.
    good_overlap, missing_matches = check_overlap(img1_id, img2_id, img12_overlap_tag, overlap_groups, overlap_image_maps, non_dups)
    if len(missing_matches) > 0:
        missing_third_party_matches |= missing_matches
        n_with_large_diff += 1
    if not good_overlap:
        continue

    # Validate the link from img2's side.
    good_overlap, missing_matches = check_overlap(img2_id, img1_id, img21_overlap_tag, overlap_groups, overlap_image_maps, non_dups)
    if len(missing_matches) > 0:
        missing_third_party_matches |= missing_matches
        n_with_large_diff += 1
    if not good_overlap:
        continue

    # Accepted: record the link on both images and in the graph.
    add_overlap(img1_id, img2_id, img12_overlap_tag, overlap_groups)
    add_overlap(img2_id, img1_id, img21_overlap_tag, overlap_groups)
    G.add_edge(img1_id, img2_id)

In [None]:
# Filter statistics from the grouping loop. len(overlap_groups) also counts
# images whose candidate pairs were all rejected by check_overlap, because
# their SDCImage entries are created before the check.
filter_summary = [
    ('skipped: labelled non-duplicate', n_non_dups),
    ('skipped: all-solid overlap', n_with_solid),
    ('skipped: dnn score < 0.8', n_with_low_probs),
    ('check_overlap calls with missing matches', n_with_large_diff),
    ('images in overlap_groups', len(overlap_groups)),
    ('missing third-party matches', len(missing_third_party_matches)),
]
for label, count in filter_summary:
    print(f'{label:<42} {count:>8}')

In [None]:
# Persist the third-party matches that check_overlap reported as missing to a
# CSV in interim_data_dir; the file name encodes the number of rows.
# (pandas is imported with the other imports at the top of the notebook.)
missing_matches_filename = os.path.join(
    interim_data_dir, f'matches_dnn_{len(missing_third_party_matches)}.csv'
)
pd.DataFrame(sorted(missing_third_party_matches)).to_csv(missing_matches_filename, index=False)

In [None]:
# Histogram of overlap-group sizes: {number of images in a group: number of groups}.
neighbor_counts = Counter(len(component) for component in nx.connected_components(G))
sorted(neighbor_counts.items())

# Analyze a single overlap group

In [None]:
# Pick one connected component to inspect: the last one with exactly
# TARGET_GROUP_SIZE images. To inspect the group containing a specific image
# instead, test e.g. `'2df5f0080.jpg' in image_hashes` in the loop.
TARGET_GROUP_SIZE = 100

# Reset first so a selection left over from an earlier run is never reused.
image_hashes0 = None
for image_hashes in nx.connected_components(G):
    if len(image_hashes) == TARGET_GROUP_SIZE:
        image_hashes0 = sorted(image_hashes)
        print(len(image_hashes0))

if image_hashes0 is None:
    raise ValueError(
        f'no overlap group with exactly {TARGET_GROUP_SIZE} images; '
        'pick a size listed in neighbor_counts'
    )

In [None]:
# Montage of one batch of full-size images from the selected group.
j = 0  # batch index into image_hashes0
batch_limit = 100  # images per batch
batch_img_ids = image_hashes0[batch_limit*j:batch_limit*(j+1)]
if not batch_img_ids:
    raise ValueError(f'batch {j} is empty: the group has only {len(image_hashes0)} images')

# Allocate exactly one slot per image in the batch. A buffer sized by
# batch_limit leaves the trailing slots of a short batch as uninitialised
# np.empty memory, which would show up as garbage tiles in the montage.
samples_images = np.empty((len(batch_img_ids), 768, 768, 3), dtype=np.float32)
for i, c_img_id in enumerate(batch_img_ids):
    c_img = cv2.cvtColor(sdcic.get_img(c_img_id), cv2.COLOR_BGR2RGB)
    samples_images[i] = c_img.astype(np.float32) / 255.0

batch_rgb = montage_rgb(samples_images)

fig, ax = plt.subplots(1, 1, figsize=(16, 16))
ax.imshow(batch_rgb, vmin=0, vmax=1)
ax.axis('off')
plt.show()

In [None]:
# Show the selected image next to every neighbour linked in its .overlaps.
# To inspect a specific image instead, use e.g. overlap_groups['9b33aefea.jpg'].
test_group = overlap_groups[image_hashes0[10]]

# NOTE(review): this reads both `_id` and `img_id` from SDCImage; presumably
# they hold the same image id — confirm.
overlaps = sorted(
    {test_group._id}
    | {neighbor_id for _, neighbor_id in test_group.overlaps.items() if neighbor_id}
)
print(test_group.img_id, overlaps)

batch_limit = 25  # at most this many images go into the montage
samples_images = np.empty((len(overlaps), 768, 768, 3), dtype=np.float32)
for slot, c_img_id in enumerate(overlaps):
    rgb_img = cv2.cvtColor(sdcic.get_img(c_img_id), cv2.COLOR_BGR2RGB)
    samples_images[slot] = rgb_img.astype(np.float32) / 255.0

batch_rgb = montage_rgb(samples_images[:batch_limit])

fig, ax = plt.subplots(1, 1, figsize=(16, 16))
ax.imshow(batch_rgb, vmin=0, vmax=1)
ax.axis('off')
plt.show()

In [None]:
# Lay out the selected group on a grid. topo_dict maps each image id to its
# position (presumably in tile units, since tile offsets are added to it when
# rendering), and maxs[0] / maxs[1] bound that layout along each axis.
topo_dict, maxs = get_puzzle_solution(image_hashes0, overlap_groups)

In [None]:
# Reassemble the solved group tile by tile. Each image contributes its 9 tiles
# at its solved position; grid cells that no image covers keep the white (1.0)
# background. The +3 presumably accounts for the 3-tile extent of an image
# placed at the maximum position.
grid_rows = maxs[0] + 3
grid_cols = maxs[1] + 3
samples_images = np.ones((grid_rows * grid_cols, 256, 256, 3), dtype=np.float32)
for c_img_id, pos in topo_dict.items():
    rgb_img = cv2.cvtColor(sdcic.get_img(c_img_id), cv2.COLOR_BGR2RGB)
    for tile_idx in range(9):
        tile = get_tile(rgb_img, tile_idx)
        tile_ij = idx2ijpair[tile_idx]
        # Row-major slot in the grid: row pos[0] + i, column pos[1] + j.
        slot = grid_cols * (pos[0] + tile_ij[0]) + (pos[1] + tile_ij[1])
        samples_images[slot] = tile.astype(np.float32) / 255.0

# montage grid_shape is (rows, columns), matching the row-major slot index above.
batch_rgb = montage_rgb2(samples_images, grid_rows, grid_cols)

fig, ax = plt.subplots(1, 1, figsize=(16, 16))
ax.imshow(batch_rgb, vmin=0, vmax=1)
ax.axis('off')
fig.tight_layout()
plt.show()

# Export overlap groups

### Write overlap groups to JSON

In [None]:
# One entry per connected component of G, keyed by
#   overlap_<n images>_<maxs[0]>_<maxs[1]>_<first image id in sorted order>
# and mapping to that group's puzzle layout (topo_dict).
overlap_groups_dict = {}
for component in nx.connected_components(G):
    image_hashes0 = sorted(component)
    topo_dict, maxs = get_puzzle_solution(image_hashes0, overlap_groups)
    size_str = pad_string(str(len(component)), 3)
    max0_str = pad_string(str(maxs[0]), 2)
    max1_str = pad_string(str(maxs[1]), 2)
    group_key = f'overlap_{size_str}_{max0_str}_{max1_str}_{image_hashes0[0]}'
    overlap_groups_dict[group_key] = topo_dict

rebuild_overlap_groups_filename = os.path.join(project_root, 'rebuild_overlap_groups.json')
with open(rebuild_overlap_groups_filename, 'w') as ofs:
    json.dump(
        overlap_groups_dict,
        ofs,
        cls=PrettyEncoder,
        indent=2,
        separators=(', ', ': '),
        sort_keys=True,
    )

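### Save each large overlap group as a jpg montage
Only groups with at least 25 images are rendered; groups whose output file already exists in `results_dir` are skipped.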

In [None]:
# Render each sufficiently large group as a tile montage and save it under
# results_dir. Groups whose output file already exists are skipped, so an
# interrupted run can simply be resumed.
MIN_SAVED_GROUP_SIZE = 25

for image_hashes in nx.connected_components(G):
    n_images = len(image_hashes)
    if n_images < MIN_SAVED_GROUP_SIZE:
        continue
    image_hashes0 = sorted(image_hashes)
    topo_dict, maxs = get_puzzle_solution(image_hashes0, overlap_groups)
    size_str = pad_string(str(n_images), 3)
    max0_str = pad_string(str(maxs[0]), 2)
    max1_str = pad_string(str(maxs[1]), 2)

    fname = f'overlap_{size_str}_{max0_str}_{max1_str}_{image_hashes0[0]}'
    filename = os.path.join(results_dir, fname)
    # savefig infers the output format from the file extension (here inherited
    # from the first image id). Without an extension it appends
    # '.<savefig.format>' itself, so add that suffix here to keep the exists()
    # check pointed at the file that is actually written.
    if not os.path.splitext(filename)[1]:
        filename = f"{filename}.{plt.rcParams['savefig.format']}"
    if os.path.exists(filename):
        continue
    print(fname)

    # Grid extent; the first value is the number of montage rows (grid_shape[0]).
    grid_rows = maxs[0] + 3
    grid_cols = maxs[1] + 3
    # White (1.0) background for grid cells that no image covers.
    samples_images = np.ones((grid_rows * grid_cols, 256, 256, 3), dtype=np.float32)
    for c_img_id, pos in topo_dict.items():
        c_img = cv2.cvtColor(sdcic.get_img(c_img_id), cv2.COLOR_BGR2RGB)
        for idx in range(9):
            c_tile = get_tile(c_img, idx)
            ij = idx2ijpair[idx]
            # Row-major slot: row pos[0] + i, column pos[1] + j.
            i = grid_cols * (pos[0] + ij[0]) + (pos[1] + ij[1])
            samples_images[i] = c_tile.astype(np.float32) / 255.0
    batch_rgb = montage_rgb2(samples_images, grid_rows, grid_cols)

    fig, ax = plt.subplots(1, 1, figsize=(16, 16))
    ax.imshow(batch_rgb, vmin=0, vmax=1)
    ax.axis('off')  # hides ticks, tick labels and spines in one call
    fig.tight_layout()
    fig.savefig(filename)
    # Closing the figure releases it (cla()/clf() beforehand were redundant);
    # collect explicitly so the large per-group buffers are reclaimed between
    # iterations.
    plt.close(fig)
    gc.collect()