In [None]:
import os
import operator
from collections import Counter
from collections import namedtuple

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook
import cv2
import torch
from torch.utils import data

from sdcdup.utils import overlap_tag_pairs
from sdcdup.utils import generate_overlap_tag_slices
from sdcdup.utils import generate_tag_pair_lookup
from sdcdup.utils import channel_shift
from sdcdup.utils import load_duplicate_truth
from sdcdup.utils import ImgMod
from sdcdup.utils import WrappedDataLoader
from sdcdup.utils import EvalDataset as Dataset

from sdcdup.models.dupnet import load_checkpoint

%matplotlib inline
%reload_ext autoreload
%autoreload 2

# RGB colour triples (hex equivalents noted alongside).
RED, GREEN, LIGHT_BLUE = (
    (244, 67, 54),   # F44336
    (76, 175, 80),   # 4CAF50
    (3, 169, 244),   # 03A9F4
)

# Font-size ladder used for the matplotlib defaults below.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE, BIGGEST_SIZE = 10, 12, 16, 20

# (rc group, setting name, size), applied in order.
for rc_group, rc_key, rc_size in (
    ('font', 'size', BIGGEST_SIZE),         # default text size
    ('axes', 'titlesize', BIGGEST_SIZE),    # axes title
    ('axes', 'labelsize', BIGGEST_SIZE),    # x and y labels
    ('xtick', 'labelsize', BIGGER_SIZE),    # x tick labels
    ('ytick', 'labelsize', BIGGER_SIZE),    # y tick labels
    ('legend', 'fontsize', MEDIUM_SIZE),    # legend
    ('figure', 'titlesize', BIGGEST_SIZE),  # figure title
):
    plt.rc(rc_group, **{rc_key: rc_size})

# SENDTOENV
train_image_dir = 'data/raw/train_768/'

# Lookups keyed by overlap tag; used by the scoring and plotting cells below.
overlap_tag_slices = generate_overlap_tag_slices()
img_overlap_index_maps = generate_tag_pair_lookup()

# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Ground-truth labels. Later cells iterate this as a mapping whose keys are
# (img1_id, img2_id, img1_overlap_tag) triples and whose values are the
# is-duplicate label for that overlap.
dup_truth = load_duplicate_truth()
print(len(dup_truth))

In [None]:
def preprocess(x):
    """Reshape a batch to (-1, 6, 256, 256) and move it onto `device`."""
    batch = x.view(-1, 6, 256, 256)
    return batch.to(device)

## Check the performance of the model

In [None]:
TilePairs = namedtuple(
    'TilePairs',
    ['img1_id', 'img2_id', 'img1_overlap_tag', 'overlap_idx', 'idx1', 'idx2'],
)

# One TilePairs record per overlapping tile pair of every labelled overlap;
# ytrue repeats the overlap-level label once for each of its tile pairs.
ytrue = []
tile_pairs = []
for (img1_id, img2_id, img1_overlap_tag), is_dup in tqdm_notebook(dup_truth.items()):
    tile_index_pairs = img_overlap_index_maps[img1_overlap_tag]
    tile_pairs.extend(
        TilePairs(img1_id, img2_id, img1_overlap_tag, overlap_idx, idx1, idx2)
        for overlap_idx, (idx1, idx2) in enumerate(tile_index_pairs)
    )
    ytrue.extend([is_dup] * len(tile_index_pairs))
print(len(tile_pairs), sum(ytrue), len(ytrue))

In [None]:
# Build the evaluation dataset and its loader in one go. The wrapper is handed
# `preprocess`; presumably it applies it to every batch (confirm in
# sdcdup.utils.WrappedDataLoader).
test_ds = Dataset(tile_pairs)
test_dl = WrappedDataLoader(
    data.DataLoader(test_ds, batch_size=256, num_workers=18),
    preprocess,
)
print(len(test_dl))

In [None]:
# Load the trained checkpoint and place it on `device`.
# No unconditional model.cuda() here: `device` already falls back to the CPU
# when CUDA is unavailable, and .cuda() would raise on a CPU-only host.
model = load_checkpoint('models/dup_model.2019_0802_2209.best.pth')
model.to(device)

In [None]:
# Inference only: switch to eval mode and disable autograd tracking.
model.eval()
with torch.no_grad():
    # One model output per batch, in loader order.
    yprobs0 = list(map(model, tqdm_notebook(test_dl)))
    # Bring every batch to the host, stack them and flatten to a 1-D array
    # (zipped one-to-one with tile_pairs further down).
    yprobs = np.vstack([batch_out.cpu() for batch_out in yprobs0]).reshape(-1)
    print(len(yprobs0), yprobs.shape, min(yprobs), max(yprobs))

In [None]:
# Nested lookup: (img1_id, img2_id) -> overlap tag -> array of per-tile scores,
# indexed by each tile pair's position within the overlap.
overlap_cnn_tile_scores = {}
for tp, yprob in zip(tile_pairs, yprobs):
    pair_scores = overlap_cnn_tile_scores.setdefault((tp.img1_id, tp.img2_id), {})

    if tp.img1_overlap_tag not in pair_scores:
        # First tile seen for this overlap: allocate one slot per tile pair.
        pair_scores[tp.img1_overlap_tag] = np.zeros(len(img_overlap_index_maps[tp.img1_overlap_tag]))

    pair_scores[tp.img1_overlap_tag][tp.overlap_idx] = yprob

DNN_Stats = namedtuple('dnn_stats', ['yprob', 'ypred', 'ytrue', 'loss', 'yconf'])

# Collapse the per-tile scores into one record per
# (img1_id, img2_id, img1_overlap_tag) key.
# The label loop variable is called `is_dup` so that it does not overwrite the
# notebook-level `ytrue` list (one label per tile pair) built in an earlier cell.
dup_dict = {}
for (img1_id, img2_id, img1_overlap_tag), is_dup in tqdm_notebook(dup_truth.items()):
    assert img1_id < img2_id

    dcnn_scores_raw = overlap_cnn_tile_scores[(img1_id, img2_id)][img1_overlap_tag]
    # Distance of each tile score from 0.5, rescaled to [0, 1]
    # (1: score at 0 or 1, 0: score exactly 0.5).
    dcnn_conf_raw = np.abs((dcnn_scores_raw - 0.5) * 2)
    yconf = np.min(dcnn_conf_raw)
    # The overlap is scored by its weakest tile: the minimum tile probability.
    yprob = np.min(dcnn_scores_raw)
    ypred = (yprob > 0.5) * 1
    assert ypred <= 1

    # Binary cross-entropy of the overlap-level probability against the label.
    # np.log(0) yields -inf, so the loss is +inf when the prediction is exactly
    # wrong; the filtering cell further down reports those entries.
    if is_dup:
        bce = - is_dup * np.log(yprob)
    else:
        bce = - (1 - is_dup) * np.log(1 - yprob)

    dup_dict[(img1_id, img2_id, img1_overlap_tag)] = DNN_Stats(yprob, ypred, is_dup, bce, yconf)

In [None]:
# Same statistics as DNN_Stats, with the dup_dict key carried along as the
# first field so the records can be sorted and filtered as a flat list.
DNN_Stats2 = namedtuple('dnn_stats', ['key', 'yprob', 'ypred', 'ytrue', 'loss', 'yconf'])
dup_dict_flat = [
    DNN_Stats2(key, *stats)
    for key, stats in tqdm_notebook(dup_dict.items())
]

In [None]:
n_confident = 0
n_correct = 0
id_tags = []
# Walk the records from least to most confident and keep the keys of the
# overlaps the model got wrong.
for dnns in tqdm_notebook(sorted(dup_dict_flat, key=operator.attrgetter('yconf'), reverse=False)):

    # Skip invalids, but print them out so we know which ones are.
    # NaN never compares equal to anything (itself included), so `== np.nan`
    # can never be True; np.isnan is the reliable test.
    if np.isnan(dnns.loss):
        print('nan ', dnns)
        continue
    if np.isposinf(dnns.loss):
        print('+inf', dnns)
        continue
    if np.isneginf(dnns.loss):
        print('-inf', dnns)
        continue

#     Skip the ones the dnn was certain about.
#     if dnns.yconf > 0.02:
#         n_confident += 1
#         continue

    # Skip the ones the dnn got correct.
    if dnns.ypred == dnns.ytrue:
        n_correct += 1
        continue

    id_tags.append(dnns.key)
len(id_tags), n_confident, n_correct

In [None]:
# Count how often each (image id, tile index) takes part in a misclassified
# overlap, counting both images of every tile pair.
tags_counter = Counter()
for img1_id, img2_id, img1_overlap_tag in id_tags:
    for idx1, idx2 in img_overlap_index_maps[img1_overlap_tag]:
        tags_counter.update([(img1_id, idx1), (img2_id, idx2)])
print(len(tags_counter))

# Most frequent tiles first; only those seen more than three times are shown.
for tile_key, n_hits in tags_counter.most_common():
    if n_hits <= 3:
        continue
    print(tile_key, n_hits)

In [None]:
aa = 0
n_samples = 10
use_median_shift = True
shift_brightness = False

# Page `aa` of the misclassified overlaps, n_samples entries per page.
test_files = id_tags[aa * n_samples: (aa + 1) * n_samples]#[::-1]
for f in test_files:
    # DNN_Stats has five fields (yprob, ypred, ytrue, loss, yconf), so the
    # format string carries exactly five placeholders; a sixth one raises
    # IndexError.
    print(f, '{:10.5} {} {} {:10.5} {:.5}'.format(*dup_dict[f]))

# Tick marks every `dtick` pixels across the 768-pixel image.
dtick = 256
n_ticks = 768 // dtick + 1
ticks = [i * dtick for i in range(n_ticks)]

# One row per selected overlap (at least one, so subplots() always gets a valid
# grid). squeeze=False keeps m_axs two-dimensional even for a single row, so
# m_axs[ii] always unpacks into two axes.
n_rows = max(1, len(test_files))
fig, m_axs = plt.subplots(n_rows, 2, figsize=(16, 8 * n_rows), squeeze=False)
for ii, (img1_id, img2_id, img1_overlap_tag) in enumerate(test_files):

    (ax1, ax2) = m_axs[ii]

    imgmod1 = ImgMod(os.path.join(train_image_dir, img1_id))
    imgmod2 = ImgMod(os.path.join(train_image_dir, img2_id))

    # Region of each image that takes part in this overlap.
    slice1 = overlap_tag_slices[img1_overlap_tag]
    slice2 = overlap_tag_slices[overlap_tag_pairs[img1_overlap_tag]]

    # Median over the first two axes of both overlap regions stacked together.
    m12 = np.median(np.vstack([imgmod1.parent_rgb[slice1], imgmod2.parent_rgb[slice2]]), axis=(0, 1), keepdims=True).astype(np.uint8)

    if shift_brightness:
        # Darken when the summed median is at least 384, brighten otherwise.
        brightness_level = -100 if np.sum(m12) >= 384 else 100
        img1 = imgmod1.brightness_shift('L', brightness_level)
        img2 = imgmod2.brightness_shift('L', brightness_level)
    else:
        # NOTE(review): if parent_rgb is a plain attribute, img1/img2 alias it
        # and the slice assignment below writes into it in place - confirm that
        # ImgMod does not rely on it afterwards.
        img1 = imgmod1.parent_rgb
        img2 = imgmod2.parent_rgb

    if use_median_shift:
        # NOTE(review): if parent_rgb is uint8 this subtraction wraps around
        # instead of clipping - presumably intended to make the overlap stand
        # out; confirm.
        img1_drop = imgmod1.parent_rgb - m12
        img2_drop = imgmod2.parent_rgb - m12
    else:
        img1_drop = imgmod1.parent_rgb
        img2_drop = imgmod2.parent_rgb

    # Only the overlapping region is replaced by the median-shifted pixels.
    img1[slice1] = img1_drop[slice1]
    img2[slice2] = img2_drop[slice2]

    for ax, img, img_id in ((ax1, img1, img1_id), (ax2, img2, img2_id)):
        ax.imshow(img)
        ax.set_title(f'{img_id}')
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)

plt.tight_layout()
# fig.savefig(os.path.join('temp', BASE_MODEL, f"{train_meta_filebase}_{score_str}_batch_{BATCH_NUM}_row_{aa+1}.jpg"))