*Partially adapted and modified from* [github.com/modestyachts/CIFAR-10.1/](https://github.com/modestyachts/CIFAR-10.1)

In [1]:
from ipywidgets import widgets
import tqdm
import numpy as np
from matplotlib import image
import os
import json
import operator
from PIL import Image
import io

- **dist_dict**: the dictionary loaded from the json file `dist_to_candidate.json`, mapping each test image filename to its candidate matches.
- **test_imgs**: list of all the keys of **dist_dict**, i.e. the test image filenames.
- **cur_test_img**: the filename of the test image being inspected. **test_ind** keeps track of its index in **test_imgs**.
- **cur_matches**: the top-k candidate matches for **cur_test_img**, given as (Tiny Images filename, L2 distance) tuples and sorted by increasing L2 distance so that the images are shown from closest to farthest.
- **cur_similars**: the set of candidate filenames listed for **cur_test_img** in `similar_imgs.json` (empty if it has no entry).

In [2]:
def np_to_png(a, fmt='png', scale=1):
    """Encode an image array as image-file bytes, upscaled by an integer factor.

    Parameters
    ----------
    a : array-like
        Image pixels; cast with ``np.uint8`` before encoding. Expected to be a
        layout accepted by ``PIL.Image.fromarray`` (presumably (H, W, 3) for the
        CIFAR arrays -- TODO confirm).
    fmt : str
        Output format passed to ``PIL.Image.Image.save`` (default ``'png'``).
    scale : int
        Integer factor applied to both the width and the height.

    Returns
    -------
    bytes
        The encoded image, e.g. for ``ipywidgets.Image(value=...)``.
    """
    pixels = np.uint8(a)
    img = Image.fromarray(pixels)
    # Scale the image's own size instead of assuming 32x32, so non-32x32
    # inputs keep their aspect ratio. Nearest-neighbour keeps pixels crisp.
    img = img.resize((scale * img.width, scale * img.height), Image.NEAREST)
    with io.BytesIO() as buf:
        img.save(buf, fmt)
        return buf.getvalue()

In [3]:
# All precomputed similarity metadata lives in ./sim_data.
json_dir = os.path.join(os.getcwd(), 'sim_data')
dist_file, similars_file, idx_file, idx_sim_file = [
    os.path.join(json_dir, fname)
    for fname in (
        'dist_to_candidate.json',
        'similar_imgs.json',
        'fname_to_idx.json',
        'sim_to_idx.json',
    )
]


def _read_json(path):
    """Parse and return the JSON document stored at *path*."""
    with open(path) as fh:
        return json.load(fh)


# Per test image filename: {candidate filename: distance}.
dist_dict = _read_json(dist_file)
# Per test image filename: candidate filenames listed in similar_imgs.json.
sim_dict = _read_json(similars_file)
# idx_dict['test'] / idx_dict['train']: filename -> row index into X_test / X_train.
idx_dict = _read_json(idx_file)
# Similar-image filename -> row index into X_sim.
idx_sim_dict = _read_json(idx_sim_file)

# Test image filenames, in the key order of dist_to_candidate.json.
test_imgs = list(dist_dict)

In [4]:
# Load the image arrays. np.load on an .npz returns an NpzFile that keeps the
# archive open; indexing it reads the array into memory, so the archive can be
# closed right after extracting what we need.
with np.load('candidate.npz') as candidate_npz:
    X_train, X_test = candidate_npz['X_train'], candidate_npz['X_test']

with np.load('sims.npz') as sims_npz:
    X_sim = sims_npz['X_sim']

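#### Check similar images in the training set per test image
- `test_ind`: index of the test image to inspect. It is set one below the desired value because the next cell increments it before use; re-running that cell steps to the following test image.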

In [5]:
# Start one below the image to inspect: the next cell increments before
# use, so its first run lands on test image 54.
test_ind = 54 - 1

In [6]:
# Advance to the next test image; re-run this cell to step through them.
test_ind = test_ind + 1
print(test_ind)
cur_test_img = test_imgs[test_ind]
# (candidate filename, distance) pairs, closest first.
cur_matches = sorted(dist_dict[cur_test_img].items(), key=lambda pair: pair[1])
# Candidates listed for this test image in similar_imgs.json (empty if none).
cur_similars = set(sim_dict.get(cur_test_img, ()))

54


In [7]:
# Build a grid of thumbnails for the current test image and its candidate
# matches. Grid index -1 is the test image itself (labelled 'original'); the
# remaining cells are the entries of `cur_matches`, closest first, each
# labelled with its distance. Matches listed in `cur_similars` start out
# checked. Every checkbox is stored in `checkboxes`, keyed by image filename.
num_images_to_show = len(cur_matches) + 1   # +1 for the original test image
img_offset = -1                             # grid index -1 is the original test image

assert img_offset >= -1

num_cols = 8
# Derive the row count from the number of images. A fixed row count breaks
# when the matches don't fit it: IndexError when there are too few, and an
# over-wide last row when there are too many.
num_rows = max(1, (num_images_to_show + num_cols - 1) // num_cols)
num_per_tab = num_cols * num_rows
num_tabs = 1
scale = 3                                   # thumbnail upscaling factor
thumb_px = 32 * scale                       # display size of a scaled 32x32 thumbnail

checkboxes = {}


def _thumbnail_cell(pixels, label, checked, key):
    """Return an image-over-checkbox VBox and register its checkbox under *key*."""
    # Encode at the display size so upscaling is nearest-neighbour, not
    # browser interpolation.
    img = widgets.Image(value=np_to_png(pixels, scale=scale), width=thumb_px, height=thumb_px)
    # Checkbox sizing goes through its Layout (a plain `.width` attribute is
    # not a widget trait and has no effect); CSS lengths need a unit.
    check = widgets.Checkbox(
        value=checked,
        description=label,
        indent=False,
        layout=widgets.Layout(width='100px', height='28px'),
    )
    checkboxes[key] = check
    cell = widgets.VBox([img, check])
    cell.layout.align_items = 'center'
    cell.layout.padding = '6px'
    return cell


tab_contents = []
for kk in tqdm.tqdm(range(num_tabs), desc='Setting up similar images'):
    rows = []
    for ii in range(num_rows):
        cur_row = []
        for jj in range(num_cols):
            cur_index = img_offset + kk * num_per_tab + ii * num_cols + jj
            if cur_index >= len(cur_matches):
                break  # no more matches; leave the last row short
            if cur_index < 0:
                cur_box = _thumbnail_cell(
                    X_test[idx_dict['test'][cur_test_img]], 'original', False, cur_test_img)
            else:
                match_name, match_dist = cur_matches[cur_index]
                is_similar = match_name in cur_similars
                # Images listed as similar are looked up in X_sim; all other
                # candidates in X_train.
                if is_similar:
                    pixels = X_sim[idx_sim_dict[match_name]]
                else:
                    pixels = X_train[idx_dict['train'][match_name]]
                cur_box = _thumbnail_cell(pixels, str(match_dist), is_similar, match_name)
            cur_row.append(cur_box)
        rows.append(widgets.HBox(cur_row))
    tab_contents.append(widgets.VBox(rows))

tab = widgets.Tab()
tab.children = tab_contents
for i in range(len(tab.children)):
    tab.set_title(i, str(i))
display(tab)

Setting up similar images: 100%|██████████| 1/1 [00:01<00:00,  1.45s/it]


Tab(children=(VBox(children=(HBox(children=(VBox(children=(Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x…