In [None]:
import os
import operator
from collections import Counter
from collections import namedtuple
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook
import torch
from torch.utils import data

from sdcdup.utils import generate_overlap_tag_slices
from sdcdup.utils import generate_tag_pair_lookup
from sdcdup.utils import get_project_root
from sdcdup.utils import bce_loss
from sdcdup.utils import load_duplicate_truth
from sdcdup.utils import ImgMod
from sdcdup.data import EvalDataset as Dataset
from sdcdup.data import WrappedDataLoader
from sdcdup.models import load_checkpoint
from sdcdup.visualization import ChannelShift
from sdcdup.visualization import get_ticks
from sdcdup.visualization import show_image_pair

# Load the dotenv IPython extension and read the project's .env file so that
# environment variables such as RAW_DATA_DIR (read via os.getenv further down
# in this cell) are available to the kernel.
%load_ext dotenv
%dotenv
# Render matplotlib figures directly in the notebook output.
%matplotlib inline
# Reload edited project modules automatically before each cell execution.
%reload_ext autoreload
%autoreload 2

# RGB colour triples (hex equivalent in the trailing comment).
RED = (244, 67, 54)    # F44336
GREEN = (76, 175, 80)  # 4CAF50
BLUE = (3, 169, 244)   # 03A9F4

# Font-size scale shared by every figure in this notebook.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE, BIGGEST_SIZE = 10, 12, 16, 20

# Default text size.
plt.rc('font', size=BIGGEST_SIZE)
# Axes title and x/y label sizes.
plt.rc('axes', titlesize=BIGGEST_SIZE, labelsize=BIGGEST_SIZE)
# Tick label sizes on both axes.
plt.rc('xtick', labelsize=BIGGER_SIZE)
plt.rc('ytick', labelsize=BIGGER_SIZE)
# Legend text size.
plt.rc('legend', fontsize=MEDIUM_SIZE)
# Figure-level title size.
plt.rc('figure', titlesize=BIGGEST_SIZE)

# Project-relative locations of the trained models and the training images.
project_root = get_project_root()
models_dir = os.path.join(project_root, 'models')

# RAW_DATA_DIR must come from the environment (the %dotenv magic above loads a
# .env file). os.getenv returns None when it is missing, which would make
# os.path.join fail with an unhelpful TypeError, so fail early and clearly.
raw_data_dir = os.getenv('RAW_DATA_DIR')
if raw_data_dir is None:
    raise RuntimeError(
        'RAW_DATA_DIR is not set; define it in the environment or in the .env file.'
    )
train_image_dir = os.path.join(project_root, raw_data_dir, 'train_768')

# Lookup indexed by overlap tag; later cells iterate each entry as a sequence
# of (idx1, idx2) tile-index pairs.
img_overlap_index_maps = generate_tag_pair_lookup()
overlap_tag_slices = generate_overlap_tag_slices()

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Load the labelled duplicate pairs. Later cells use this as a mapping: its
# keys unpack to (img1_id, img2_id, img1_overlap_tag) and .items() yields the
# truth label for each key.
dup_truth = load_duplicate_truth(chunk_type='truth')
# Number of labelled (image pair, overlap tag) entries.
print(len(dup_truth))

In [None]:
def preprocess(x):
    """Reshape a batch to ``(N, 6, 256, 256)`` and move it onto ``device``.

    Args:
        x: tensor whose element count is a multiple of ``6 * 256 * 256``; the
            leading (batch) dimension is inferred.

    Returns:
        A view of ``x`` with shape ``(N, 6, 256, 256)`` placed on the
        module-level ``device``.
    """
    batch = x.view(-1, 6, 256, 256)
    return batch.to(device)

## Check the performance of the model

In [None]:
# One record per overlapping tile pair of a labelled image pair.
TilePairs = namedtuple('TilePairs', 'img1_id img2_id img1_overlap_tag overlap_idx idx1 idx2')

# Expand every labelled (img1, img2, overlap tag) entry into its individual
# tile pairs; overlap_idx is the position of the pair within that overlap tag.
tile_pairs = []
for img1_id, img2_id, img1_overlap_tag in tqdm_notebook(dup_truth):
    tile_pairs.extend(
        TilePairs(img1_id, img2_id, img1_overlap_tag, overlap_idx, idx1, idx2)
        for overlap_idx, (idx1, idx2) in enumerate(img_overlap_index_maps[img1_overlap_tag])
    )

print(len(tile_pairs))

In [None]:
# Evaluation pipeline: dataset -> batched loader -> per-batch `preprocess`.
# No shuffling is requested, so batches follow dataset order; a later cell
# zips the predictions against `tile_pairs`.
# NOTE(review): assumes EvalDataset indexes `tile_pairs` in order - confirm.
test_ds = Dataset(tile_pairs)
test_dl = WrappedDataLoader(
    data.DataLoader(test_ds, batch_size=256, num_workers=18),
    preprocess,
)
print('Total number of batches to evaluate: ', len(test_dl))

In [None]:
# Restore the trained checkpoint and place the model on `device`.
# `model.to(device)` handles both GPU and CPU-only hosts; an unconditional
# `model.cuda()` would raise on a machine without CUDA and is redundant when
# `device` is already the GPU.
model = load_checkpoint(os.path.join(models_dir, 'dup_model.2019_0802_2209.best.pth'))
model.to(device)

In [None]:
# Score every batch with the model in evaluation mode and without autograd.
model.eval()
with torch.no_grad():
    # Move each batch's output to the CPU immediately so predictions do not
    # accumulate on the accelerator for the whole run.
    yprobs0 = [model(xb).cpu() for xb in tqdm_notebook(test_dl)]
    # Stack the per-batch outputs row-wise, then flatten to a 1-D array.
    yprobs = np.vstack([yb.numpy() for yb in yprobs0]).reshape(-1)
# Use the vectorised ndarray reductions; the builtin min()/max() would iterate
# the array element by element in Python.
print(len(yprobs0), yprobs.shape, yprobs.min(), yprobs.max())

In [None]:
# zip() stops at the shorter input, so a length mismatch would silently leave
# some tile scores at their zero initial value. Check the alignment up front.
assert len(tile_pairs) == len(yprobs), (
    f'expected one score per tile pair, got {len(yprobs)} scores '
    f'for {len(tile_pairs)} tile pairs'
)

# Regroup the flat per-tile scores as
#   {(img1_id, img2_id): {img1_overlap_tag: array of per-tile scores}}
# where each array is indexed by the tile's overlap_idx.
overlap_cnn_tile_scores = defaultdict(dict)
for tp, yprob in zip(tile_pairs, yprobs):
    scores_by_tag = overlap_cnn_tile_scores[(tp.img1_id, tp.img2_id)]

    if tp.img1_overlap_tag not in scores_by_tag:
        # First tile seen for this overlap tag: allocate one slot per tile pair.
        n_overlapping_tiles = len(img_overlap_index_maps[tp.img1_overlap_tag])
        scores_by_tag[tp.img1_overlap_tag] = np.zeros(n_overlapping_tiles)

    scores_by_tag[tp.img1_overlap_tag][tp.overlap_idx] = yprob
print(len(overlap_cnn_tile_scores))

In [None]:
# Per-entry summary of the network's verdict for one (img1, img2, overlap tag).
DNN_Stats = namedtuple('dnn_stats', ['yprob', 'ypred', 'ytrue', 'loss', 'yconf'])

dup_dict = {}
for (img1_id, img2_id, img1_overlap_tag), ytrue in tqdm_notebook(dup_truth.items()):
    assert img1_id < img2_id

    tile_scores = overlap_cnn_tile_scores[(img1_id, img2_id)][img1_overlap_tag]
    # Distance of each tile score from the 0.5 decision threshold, rescaled to [0, 1].
    tile_conf = np.abs((tile_scores - 0.5) * 2)

    # The overlap is summarised by its lowest-scoring tile, so it is predicted
    # a duplicate only when every tile scores above 0.5.
    yprob = tile_scores.min()
    yconf = tile_conf.min()
    ypred = (yprob > 0.5) * 1
    assert ypred <= 1

    dup_dict[(img1_id, img2_id, img1_overlap_tag)] = DNN_Stats(
        yprob=yprob,
        ypred=ypred,
        ytrue=ytrue,
        loss=bce_loss(ytrue, yprob),
        yconf=yconf,
    )

In [None]:
# Collect the keys of entries the network got wrong, ordered from least to
# most confident, and count the ones it got right.
n_confident = 0
n_correct = 0
id_tags = []
for key, dnns in tqdm_notebook(sorted(dup_dict.items(), key=lambda x: x[1].yconf, reverse=False)):

    # Skip invalids, but print them out so we know which ones are.
    # NaN never compares equal to anything (not even itself), so `== np.nan`
    # can never match; np.isnan is the reliable check.
    if np.isnan(dnns.loss):
        print('nan ', dnns)
        continue
    if dnns.loss == np.inf:
        print('+inf', dnns)
        continue
    if dnns.loss == -np.inf:
        print('-inf', dnns)
        continue

#     Optional filter (disabled): skip the ones with high confidence.
#     if dnns.yconf > 0.02:
#         n_confident += 1
#         continue

#     Skip the ones the dnn got correct.
    if dnns.ypred == dnns.ytrue:
        n_correct += 1
        continue

    id_tags.append(key)
len(id_tags), n_confident, n_correct

In [None]:
# Count how often each (image id, tile index) takes part in a misclassified overlap.
tags_counter = Counter()
for img1_id, img2_id, img1_overlap_tag in id_tags:
    for idx1, idx2 in img_overlap_index_maps[img1_overlap_tag]:
        tags_counter.update([(img1_id, idx1), (img2_id, idx2)])
print(len(tags_counter))

# List tiles involved in more than three misclassifications, most frequent first.
for tile_key, n_hits in tags_counter.most_common():
    if n_hits > 3:
        print(tile_key, n_hits)

In [None]:
# Show one "page" of misclassified image pairs side by side.
aa = 0          # zero-based page index into id_tags
n_samples = 10  # maximum number of image pairs per page

test_files = id_tags[aa * n_samples: (aa + 1) * n_samples]#[::-1]
for f in test_files:
    # Fields printed in DNN_Stats order: yprob, ypred, ytrue, loss, yconf.
    print(f, '{:10.5} {} {} {:10.5} {}'.format(*dup_dict[f]))

draw_bboxes = False
bbox_thickness = 4
ticks = get_ticks()
# NOTE(review): not referenced in this cell; the shift passed to
# show_image_pair below is hard-coded as ChannelShift('median', True).
median_color_shift = True

# Size the grid to the pairs actually on this page: the last page may hold
# fewer than n_samples, and an empty page cannot be plotted at all.
n_rows = len(test_files)
if n_rows == 0:
    print('No image pairs to show for this page.')
else:
    # squeeze=False keeps m_axs two-dimensional even for a single row, so each
    # row always unpacks into (ax1, ax2).
    fig, m_axs = plt.subplots(n_rows, 2, figsize=(16, 8 * n_rows), squeeze=False)
    for (ax1, ax2), (img1_id, img2_id, img1_overlap_tag) in zip(m_axs, test_files):

        # Box colour encodes the truth label: green = duplicate, red = not a
        # duplicate, blue = no label available.
        if (img1_id, img2_id, img1_overlap_tag) in dup_truth:
            bbox_color = GREEN if dup_truth[(img1_id, img2_id, img1_overlap_tag)] else RED
        else:
            bbox_color = BLUE

        imgmod1 = ImgMod(os.path.join(train_image_dir, img1_id))
        imgmod2 = ImgMod(os.path.join(train_image_dir, img2_id))

        show_image_pair(ax1, ax2, imgmod1, imgmod2, img1_overlap_tag, draw_bboxes, bbox_thickness, bbox_color, img1_id, img2_id, ticks, shift=ChannelShift('median', True))

    plt.tight_layout()
# fig.savefig(os.path.join('temp', BASE_MODEL, f"{train_meta_filebase}_{score_str}_batch_{BATCH_NUM}_row_{aa+1}.jpg"))