## Inspection of training and test set

To make sure that the test set does not include images that are similar to images in the training set, we removed, through manual inspection, all training images that are similar to test images. This notebook lets you visualize which images are similar to a given image in the test set and see which images we have removed.

Choices for the similarity metric and the set/subset of the images for inspection are provided with selection widgets in the [following section](#Check-similar-images-in-the-training-set-per-test-image).

The notebook is partially adapted from: [github.com/modestyachts/CIFAR-10.1/](https://github.com/modestyachts/CIFAR-10.1)

In [None]:
from IPython.core.display import HTML
from IPython.display import Javascript, display
from ipywidgets import widgets
import tqdm
import numpy as np
from matplotlib import image
import os
import json
import operator
from PIL import Image
import io
from IPython.core.debugger import set_trace

In [None]:
# Encode an image array as the byte string expected by an Image widget.
def np_to_png(a, fmt='png', scale=1):
    """Return the array ``a`` encoded as image bytes in format ``fmt``.

    The array is cast to uint8, converted to a PIL image, resized to a
    square of ``scale * 32`` pixels per side with nearest-neighbour
    resampling, and written to an in-memory buffer whose content is returned.
    """
    side = scale * 32
    picture = Image.fromarray(np.uint8(a)).resize((side, side), Image.NEAREST)
    buffer = io.BytesIO()
    picture.save(buffer, fmt)
    return buffer.getvalue()

The variables of the notebook are as follows:

- **json_dir**: Directory containing the json files for l2 distance and sub directories for structural similarity.
- **res_dir**: Directory containing the json files for the indices of the correctly classified test images.
- **test_imgs**: List of all the keys loaded from the json file. 
- **cur_test_img**: The filename of the test image that is to be inspected. Use **test_ind** to keep track of the index of the test images.
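- **cur_matches**: List of `(filename, distance)` tuples for the current test image, limited to the 100 closest entries and sorted by increasing distance so that the images are shown from closest to farthest. The filenames come from Tiny Images and the distance is the value of the selected similarity metric.
- **cur_similars**: Set of filenames that were manually marked as similar to the current test image (empty if there are none).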

In [None]:
def _read_json(path):
    """Return the parsed content of the json file at ``path``."""
    with open(path) as f:
        return json.load(f)


def _list_json_files(directory):
    """Return the paths of all ``*.json`` files directly inside ``directory``."""
    return [os.path.join(directory, fname)
            for fname in os.listdir(directory) if fname.endswith('.json')]


def load_sim_files(json_dir, sim_metric):
    """
    Loads the json files that contain the pre-recorded values for one of the following distance metrics:
        l2_distance: Euclidean distance between two images.
        structural_similarity: Structural similarity between two images as implemented in ssim from scikit image.

    Args:
        json_dir: Directory holding ``similar_imgs.json``, ``fname_to_idx.json``,
            ``sim_to_idx.json``, ``dist_to_candidate.json`` and the
            ``ssim_dist/`` sub directory (with its ``excluded/`` sub directory).
        sim_metric: Either ``'l2_distance'`` or ``'structural_similarity'``.

    Returns:
        A tuple ``(dist_dict, sim_dict, idx_dict, idx_sim_dict)``:
            dist_dict: test filename -> {image filename -> recorded value}.
            sim_dict: content of ``similar_imgs.json`` (manually selected similar images).
            idx_dict: content of ``fname_to_idx.json``; its ``'train'`` and ``'test'``
                entries map filenames to indices.
            idx_sim_dict: content of ``sim_to_idx.json`` (filename to index for eliminated images).

    Raises:
        AssertionError: If ``sim_metric`` is not one of the supported names.
    """
    assert sim_metric in ['l2_distance', 'structural_similarity']

    sim_dict = _read_json(os.path.join(json_dir, 'similar_imgs.json'))
    idx_dict = _read_json(os.path.join(json_dir, 'fname_to_idx.json'))
    idx_sim_dict = _read_json(os.path.join(json_dir, 'sim_to_idx.json'))

    if sim_metric == 'l2_distance':
        # l2 distances are stored in a single file that is already keyed by filename.
        dist_dict = _read_json(os.path.join(json_dir, 'dist_to_candidate.json'))
        return dist_dict, sim_dict, idx_dict, idx_sim_dict

    # The ssim files are keyed by (stringified) indices, so build index -> filename maps.
    train_names = {str(v): k for k, v in idx_dict['train'].items()}
    test_names = {str(v): k for k, v in idx_dict['test'].items()}
    excluded_names = {str(v): k for k, v in idx_sim_dict.items()}

    ssim_dir = os.path.join(json_dir, 'ssim_dist')

    # ssim values for training images
    dist_dict = {}
    for path in _list_json_files(ssim_dir):
        for test_idx, scores in _read_json(path).items():
            dist_dict[test_names[test_idx]] = {train_names[k]: v for k, v in scores.items()}

    # ssim values for excluded images; setdefault keeps a test image that only
    # appears in the excluded files from raising a KeyError.
    for path in _list_json_files(os.path.join(ssim_dir, 'excluded')):
        for test_idx, scores in _read_json(path).items():
            dist_dict.setdefault(test_names[test_idx], {}).update(
                {excluded_names[k]: v for k, v in scores.items()})

    return dist_dict, sim_dict, idx_dict, idx_sim_dict
    
def load_correct_classified(res_dir, test_dict):
    """
    Load subset of test images that are correctly classified with candidate training set,
    but wrongly classified with CIFAR-10 training set.

    Args:
        res_dir: Directory containing ``train_logs_corr.json`` (candidate training set)
            and ``train_logs_cifar_corr.json`` (CIFAR-10 training set). Each file
            holds a ``'correct'`` list of test image indices.
        test_dict: Mapping from test image filename to its index.

    Returns:
        Sorted list of the filenames whose index is listed as correct in the
        candidate file but not in the CIFAR-10 file.
    """
    cndt_file = os.path.join(res_dir, 'train_logs_corr.json')
    cifar_file = os.path.join(res_dir, 'train_logs_cifar_corr.json')

    with open(cndt_file) as f:
        cndt_dict = json.load(f)

    with open(cifar_file) as f:
        cifar_dict = json.load(f)

    # index (as string) -> filename
    reverse_dict = {str(v): k for k, v in test_dict.items()}

    # A set makes each membership test O(1) instead of scanning the whole list.
    cifar_correct = set(cifar_dict['correct'])

    return sorted(reverse_dict[str(idx)]
                  for idx in cndt_dict['correct'] if idx not in cifar_correct)

def get_next_image(ev):
    """Button callback that re-runs notebook cells to show the next test image.

    The injected JavaScript executes the cell range starting one cell before
    the currently selected cell and ending at the last cell of the notebook.
    ``ev`` is the argument passed in by the click handler and is not used.
    """
    rerun_js = 'IPython.notebook.execute_cell_range(IPython.notebook.get_selected_index()-1, IPython.notebook.ncells())'
    display(Javascript(rerun_js))


In [None]:
# Candidate training set (images similar to test images already excluded)
# together with the test set.
data = np.load('candidate.npz')
X_train = data['X_train']
X_test = data['X_test']

# Images that were eliminated from the candidate set.
data = np.load('sims.npz')
X_sim = data['X_sim']

#### Check similar images in the training set per test image
- `test_ind`: Index of the test set for inspection.

In [None]:
# Selection widgets for the inspection below.
#   Metric: similarity metric, either l2_distance or structural_similarity.
#   Images: which test images to step through, 'all' or 'diff'.
# Click 'Show Similars' afterwards to check similar images.
sim_metric = widgets.ToggleButtons(
    options=['l2_distance', 'structural_similarity'],
    tooltips=['l2 distance between images', 'ssim from scikit image'],
    description='Metric:',
    button_style='',
    disabled=False,
)
display(sim_metric)

idx_select = widgets.ToggleButtons(
    options=['all', 'diff'],
    tooltips=[
        'all test images and similars',
        'non-overlaps between candidate and cifar training sets',
    ],
    description='Images:',
    button_style='',
    disabled=False,
)
display(idx_select)

In [None]:
# Render a 'Show Similars' submit button whose form action runs all cells
# below this one. The HTML object must stay the last expression of the cell
# so that the notebook displays it.
HTML(
    '<script> </script> '
    '<form action="javascript:IPython.notebook.execute_cells_below()">'
    '<input type="submit" id="toggleButton" value="Show Similars">'
    '</form>'
)

In [None]:
# Load the similarity values and the lists of manually eliminated images
# for the options chosen with the selection widgets.
json_dir = os.path.join(os.getcwd(), 'sim_data')  # directory of json files
assert sim_metric.value in ['l2_distance', 'structural_similarity']
assert idx_select.value in ['all', 'diff']

# load_sim_files dispatches on the metric itself, so a single call covers both options.
dist_dict, sim_dict, idx_dict, idx_sim_dict = load_sim_files(json_dir, sim_metric.value)

# Compare the widget's *value*: comparing the widget object itself with a
# string is always False, which made every run take the 'diff' branch.
if idx_select.value == 'all':
    test_imgs = sorted(dist_dict)
else:
    res_dir = os.path.join(os.getcwd(), 'results/resnet/')
    test_imgs = load_correct_classified(res_dir, idx_dict['test'])

In [None]:
# Index of the test image to inspect first; edit the value on the next line.
test_ind = 0
# The index is increased by one before an image is loaded, so start one below.
test_ind = test_ind - 1

In [None]:
# Advance to the next test image.
test_ind = test_ind + 1
print('Inspecting test image #{}'.format(test_ind))
cur_test_img = test_imgs[test_ind]

# Keep the 100 entries with the smallest recorded value, in increasing order.
# NOTE(review): for structural_similarity a larger value presumably means
# "more similar" - confirm that ascending order is intended for that metric.
cur_matches = sorted(dist_dict[cur_test_img].items(), key=lambda item: item[1])[:100]

# Manually selected similar images for this test image (empty set if none).
cur_similars = set(sim_dict[cur_test_img]) if cur_test_img in sim_dict else set()

In [None]:
# Build a grid of image widgets: the test image first, followed by its matches.
num_images_to_show = len(cur_matches) + 1   # the matches plus the test image itself
img_offset = -1                             # grid slot 0 maps to index -1, i.e. the test image

assert img_offset >= -1

num_cols = 8
num_rows = 13                               # maximum number of rows per tab
num_per_tab = num_cols * num_rows
# Derive the number of tabs from the number of images (ceiling division), and
# below only create the rows that are actually needed. A fixed 13x8 grid
# indexes past the end of cur_matches whenever fewer matches are available.
num_tabs = max(1, -(-num_images_to_show // num_per_tab))
scale = 3

# filename -> label widget shown under the image
checkboxes = {}

tab_contents = []
for kk in tqdm.tqdm(range(num_tabs), desc='#{} - Setting up similar images by {}'.format(str(test_ind), sim_metric.value)):
    tab_start = kk * num_per_tab
    tab_end = min(tab_start + num_per_tab, num_images_to_show)
    rows = []
    for row_start in range(tab_start, tab_end, num_cols):
        cur_row = []
        for slot in range(row_start, min(row_start + num_cols, tab_end)):
            cur_index = img_offset + slot
            if cur_index < 0:
                # The test image under inspection.
                cur_img = widgets.Image(value=np_to_png(X_test[idx_dict['test'][cur_test_img]]), width=96, height=96)
                cur_checkbox = widgets.Checkbox(value=False, description='original', disabled=True, indent=False, layout=widgets.Layout(width='100px', height='28'))
                box_key = cur_test_img
            else:
                match_name, match_value = cur_matches[cur_index]
                if match_name in cur_similars or match_name in idx_sim_dict:
                    # Image that was eliminated from the candidate set; it is read from X_sim.
                    cur_img = widgets.Image(value=np_to_png(X_sim[idx_sim_dict[match_name]]), width=96, height=96)
                    cur_checkbox = widgets.HTML(value = str(np.around(match_value,2)) + " - <b><font color='red'>Excluded</b>", layout=widgets.Layout(width='100px', height='28'))
                else:
                    # Image that is still part of the candidate training set.
                    cur_img = widgets.Image(value=np_to_png(X_train[idx_dict['train'][match_name]]), width=96, height=96)
                    cur_checkbox = widgets.Checkbox(value=False, description=str(match_value), indent=False, layout=widgets.Layout(width='100px', height='28'))
                box_key = match_name
            cur_checkbox.width = '90px'
            checkboxes[box_key] = cur_checkbox
            cur_box = widgets.VBox([cur_img, cur_checkbox])
            cur_box.layout.align_items = 'center'
            cur_box.layout.padding = '6px'
            cur_row.append(cur_box)
        rows.append(widgets.HBox(cur_row))
    tab_contents.append(widgets.VBox(rows))

tab = widgets.Tab()
tab.children = tab_contents
for i in range(len(tab_contents)):
    tab.set_title(i, str(i))

# Button that moves on to the next test image.
next_img_button = widgets.Button(description="Next image")
next_img_button.on_click(get_next_image)
display(next_img_button)

display(tab)