## Create visualizations for local metrics (neighborhoods of data)

In [1]:
# General imports
import torch
import numpy as np
import os, sys
import json
# from tqdm import tqdm

In [3]:
# Local imports
sys.path.insert(0, 'src')
from utils import read_json, read_lists, list_to_dict, ensure_dir, load_image
from utils.model_utils import prepare_device, quick_predict
from utils.knn_utils import display_image_paths
from parse_config import ConfigParser
from data_loader import data_loaders
import model.model as module_arch

In [4]:
# Define constants, paths
class_list_path = os.path.join('metadata', 'cinic-10', 'class_names.txt')

config_path = 'configs/copies/cinic10_imagenet_segmentation_edit_trials.json'

In [5]:
# Load config file, models, and dataloader
class_list = read_lists(class_list_path)
class_idx_dict = list_to_dict(class_list)

config_json = read_json(config_path)
K = config_json['editor']['K']

config = ConfigParser(config_json)

device, device_ids = prepare_device(config['n_gpu'])

trial_paths_path_template = os.path.join('saved', 'edit', 'trials', 'CINIC10_ImageNet-VGG_16', '{}_{}', '{}', 'trial_paths.txt')


### Create visualizations for top 10 nearest neighbors

In [19]:
# Function to take in KNN object and display neighbors
def show_nearest_neighbors(knn,
                           n_display,
                           save_dir,
                           title=None):
    data_modes = knn.keys()
    image_title_template = 'GT: {}\nPred: {}\nDist: {}'
    key_image, val_image = knn['features']['anchor_data']
    print(key_image.shape)
    for anchor_idx, anchor in enumerate(['key', 'val']):
        anchor_image = knn[data_modes[0]]['anchor_data']
        # print(anchor_image.keys()
        for data_mode in data_modes:
            knn_mode = knn[data_mode]
            # Extract image_paths, distances, predictions
            # key_image_paths, val_image_paths = knn_mode['image_paths'][:n_display]
            # key_distances, val_distances = knn_mode['distances'][:n_display]
            # key_predictions, val_predictions = knn_mode['predictions'][:n_display]

            # image_paths = knn_mode['image_paths'][:n_display]
            # distances = knn_mode['distances'][:n_display]
            # predictions = knn_mode['predictions'][:n_display]
        
            image_paths = knn_mode['image_paths'][anchor_idx][:n_display]
            distances = knn_mode['distances'][anchor_idx][:n_display]
            predictions = knn_mode['predictions'][anchor_idx][:n_display]
            labels = knn_mode['labels'][anchor_idx][:n_display]
            print(labels[0])
            
            images = []
            image_titles = []
            
            # Form labels
            for image_path, distance, prediction, label in zip(image_paths, distances, predictions, labels):
                # images.append(load_images(image_path))
                image_title = image_title_template.format(class_list[label], class_list[prediction], distance)
                image_titles.append(image_title)
            
            save_path = os.path.join(save_dir, '{}_nn_visual_{}.png'.format(anchor, data_mode))
            display_image_paths(
                image_paths=image_paths,
                labels=image_titles,
                figure_title=title,
                save_path=save_path)
            
            print(len(image_paths))

In [20]:
timestamp = '0127_103716'
class_name = 'airplane'
n_select = 100

trial_paths_path = trial_paths_path_template.format(class_name, n_select, timestamp)

# save_visualizations_dir = os.path.join(os.path.dirname(trial_paths_path), 'visualizations')
# ensure_dir(save_visualizations_dir)
trial_paths = read_lists(trial_paths_path)

for trial_idx, trial_dir in enumerate(trial_paths):
    config_path = os.path.join(trial_dir, 'models', 'config.json')
    config_dict = read_json(config_path)
    
    key_path = config_dict['editor']['key_paths_file']
    value_path = config_dict['editor']['value_paths_file']
    # Load pre and post edit KNNs
    pre_edit_knn = torch.load(os.path.join(trial_dir, 'models', 'pre_edit_{}-nn.pth'.format(K)))
    post_edit_knn = torch.load(os.path.join(trial_dir, 'models', 'post_edit_{}-nn.pth'.format(K)))
    print(pre_edit_knn['features'].keys())
    # print(pre_edit_knn['features']['anchor_data'].keys())
    # assert (pre_edit_knn['features']['anchor_data'] == pre_edit_knn['logits']['anchor_data'])
    show_nearest_neighbors(
        pre_edit_knn, 
        n_display=10,
        title='Pre Edit Neighbors',
        save_dir=os.path.join(trial_dir, 'models', 'knn_visualizations', 'pre_edit'))
    # Visualize each 
    if trial_idx == 0:
        break

dict_keys(['indices', 'distances', 'image_paths', 'labels', 'predictions', 'anchor_data', 'neighbor_data'])
(512,)


TypeError: 'dict_keys' object is not subscriptable

In [None]:
# Load datasets
# data_loader_args = dict(config.config["data_loader"]["args"])

In [None]:
# # Load model
# layernum = config.config['layernum']
# model = config.init_obj('arch', module_arch, layernum=layernum)

In [None]:
# Function definitions

In [None]:
# Pedal to the metal!