## Create visualizations for local metrics (neighborhoods of data)

In [12]:
# General imports
import torch
import numpy as np
import os, sys
import json
from tqdm import tqdm
import matplotlib.pyplot as plt

In [2]:
# Local imports
sys.path.insert(0, 'src')
from utils import read_json, read_lists, list_to_dict, ensure_dir, load_image, get_image_id
from utils.model_utils import prepare_device, quick_predict
from utils.knn_utils import display_image_paths
from utils.visualizations import plot
from parse_config import ConfigParser
from data_loader import data_loaders
import model.model as module_arch

In [3]:
# Define constants, paths
# Text file of class names; read with read_lists() when the config is loaded.
class_list_path = os.path.join('metadata', 'cinic-10', 'class_names.txt')

# JSON config for the edit trials; its 'editor' -> 'K' entry is read below.
config_path = 'configs/copies/cinic10_imagenet_segmentation_edit_trials.json'

In [4]:
# Load config file, models, and dataloader
# class_list is indexed by class index elsewhere in this notebook
# (class_list[label], class_list[prediction]).
class_list = read_lists(class_list_path)
class_idx_dict = list_to_dict(class_list)

config_json = read_json(config_path)
# K is used below to build the saved KNN filenames ('pre_edit_{K}-nn.pth') and
# as the largest neighborhood size when computing overlaps.
K = config_json['editor']['K']

config = ConfigParser(config_json)

device, device_ids = prepare_device(config['n_gpu'])

# Placeholders are filled, in order, with (class_name, n_select, timestamp) by the
# driver cells; the resulting file is read with read_lists() to get trial directories.
trial_paths_path_template = os.path.join('saved', 'edit', 'trials', 'CINIC10_ImageNet-VGG_16', '{}_{}', '{}', 'trial_paths.txt')


### Create visualizations for top 10 nearest neighbors

In [None]:
# Function to take in KNN object and display neighbors for all combos of key/val and logits/features/images
def show_nearest_neighbors(key_path,
                           value_path,
                           knn,
                           n_display,
                           image_id,
                           save_dir,
                           title_template,
                           show=True):
    '''
    Render one figure per (anchor, data mode) pair showing the anchor image
    followed by its first n_display neighbors.

    Arg(s):
        key_path : str
            path to the key image; its parent directory name is used as the
            ground-truth class shown for both anchors
        value_path : str
            path to the value image
        knn : dict
            maps each data mode to a dict with 'image_paths', 'distances',
            'predictions' and 'labels', each indexed by anchor (0 = key, 1 = val);
            knn['logits']['anchor_data'] holds the per-anchor data whose argmax
            is shown as the anchor's predicted class
        n_display : int
            number of neighbors to show per figure
        image_id : str
            identifier inserted into the figure title
        save_dir : str
            directory the figures are written to
        title_template : str
            format string receiving (n_display, anchor, image_id, data_mode);
            if it contains 'Pre' / 'Post' the filename is prefixed 'pre' / 'post'
        show : bool
            forwarded to display_image_paths
    Returns:
        None
    '''
    neighbor_title_template = 'GT: {}\nPred: {}\nDist: {:.3f}'

    # Filename prefix is derived from the wording of the title template.
    prepost = next(
        (tag.lower() for tag in ('Pre', 'Post') if tag in title_template),
        '')

    target_class = os.path.basename(os.path.dirname(key_path))
    anchor_logits = knn['logits']['anchor_data']
    predicted_key_class = class_list[np.argmax(anchor_logits[0])]
    predicted_value_class = class_list[np.argmax(anchor_logits[1])]

    # (anchor name, anchor image path, title shown above the anchor image)
    anchors = (
        ('key',
         key_path,
         'GT: {}\nPred: {}\nKey Image'.format(target_class, predicted_key_class)),
        ('val',
         value_path,
         'GT: {}\nPred: {}\nValue Image'.format(target_class, predicted_value_class)),
    )

    for anchor_idx, (anchor, anchor_path, anchor_title) in enumerate(anchors):
        for data_mode in knn.keys():
            neighbors = knn[data_mode]

            image_paths = neighbors['image_paths'][anchor_idx][:n_display]
            distances = neighbors['distances'][anchor_idx][:n_display]
            predictions = neighbors['predictions'][anchor_idx][:n_display]
            labels = neighbors['labels'][anchor_idx][:n_display]

            # One caption per neighbor: ground truth, prediction, distance to anchor
            image_titles = [
                neighbor_title_template.format(class_list[label], class_list[prediction], distance)
                for _, distance, prediction, label
                in zip(image_paths, distances, predictions, labels)
            ]

            # The anchor image always occupies the first slot
            image_paths.insert(0, anchor_path)
            image_titles.insert(0, anchor_title)

            figure_name = '{}-edit_{}_nn_visual_{}.png'.format(prepost, anchor, data_mode)
            display_image_paths(
                image_paths=image_paths,
                labels=image_titles,
                figure_title=title_template.format(n_display, anchor, image_id, data_mode),
                subplot_padding=2,
                save_path=os.path.join(save_dir, figure_name),
                show=show)
            

In [None]:
# Which batch of trials to visualize
timestamp = '0127_103716'
class_name = 'dog'
n_select = 100

# Display settings
n_display = 9
show = False

trial_paths_path = trial_paths_path_template.format(class_name, n_select, timestamp)
trial_paths = read_lists(trial_paths_path)

for trial_idx, trial_dir in enumerate(tqdm(trial_paths)):
    # Each trial stores its own config recording which key/value images were used
    config_path = os.path.join(trial_dir, 'models', 'config.json')
    config_dict = read_json(config_path)
    editor_config = config_dict['editor']

    key_path = editor_config['key_paths_file']
    value_path = editor_config['value_paths_file']

    # Image ID = key image ID plus the modification method (value image's parent directory)
    image_id = get_image_id(key_path)
    image_id += '-{}'.format(os.path.basename(os.path.dirname(value_path)))

    # Load pre and post edit KNNs
    pre_edit_knn = torch.load(os.path.join(trial_dir, 'models', 'pre_edit_{}-nn.pth'.format(K)))
    post_edit_knn = torch.load(os.path.join(trial_dir, 'models', 'post_edit_{}-nn.pth'.format(K)))

    save_dir = os.path.join(trial_dir, 'models', 'knn_visualizations')
    ensure_dir(save_dir)

    # Save graphics for pre edit neighbors, then post edit neighbors
    stages = (
        ("Pre Edit {} NN for {} [{}] (based on '{}')", pre_edit_knn),
        ("Post Edit {} NN for {} [{}] (based on '{}')", post_edit_knn),
    )
    for title_template, stage_knn in stages:
        show_nearest_neighbors(
            key_path=key_path,
            value_path=value_path,
            knn=stage_knn,
            n_display=n_display,
            image_id=image_id,
            title_template=title_template,
            save_dir=save_dir,
            show=show)


### Obtain overlaps in the neighbors

In [20]:
def get_overlaps(data1,
                 data2,
                 intervals,
                 relative=True):
    '''
    Given the data and intervals, compute the overlap between the first
    `interval` elements of each sequence, for every interval.

    Arg(s):
        data1 : list[any]
            sequence of hashable data
        data2 : list[any]
            second sequence of hashable data
        intervals : list[int]
            prefix lengths at which to calculate overlaps. Each interval is
            clipped to the length of the longer sequence, and processing
            stops after the first interval that reaches that length, so the
            result may be shorter than `intervals`.
        relative : bool
            if True, report the overlap as a percentage of the (clipped)
            interval rounded to 2 decimals; otherwise report the raw count
    Returns:
        list[float] : % overlaps at each processed interval (relative=True)
        list[int]   : # of shared elements at each processed interval (relative=False)
    '''
    max_len = max(len(data1), len(data2))
    overlaps = []
    for interval in intervals:
        # Never look past the end of the longer sequence
        interval = min(interval, max_len)

        overlap = len(set(data1[:interval]) & set(data2[:interval]))
        if relative:
            if interval == 0:
                # Empty prefix (interval of 0, or both sequences empty):
                # nothing can overlap, and dividing by 0 would raise.
                overlap = 0.0
            else:
                overlap = round(100 * overlap / interval, 2)

        overlaps.append(overlap)

        # Every later interval would be clipped to the same prefix
        if interval == max_len:
            break
    return overlaps

In [34]:
def plot_pre_post_overlap(pre_edit_knn,
                          post_edit_knn,
                          intervals,
                          trial_dir,
                          trial_id,
                          relative=True,
                          save_plots=True,
                          show=False):
    '''
    Plot how much the neighbors of each anchor overlap before vs. after the edit.

    One curve is drawn for each combination of anchor ('key', 'val') and
    data mode ('logits', 'features').

    Arg(s):
        pre_edit_knn : dict
            pre-edit KNN results; pre_edit_knn[data_mode]['image_paths'][anchor_idx]
            is the sequence of neighbor image paths for that anchor
        post_edit_knn : dict
            post-edit KNN results, same layout as pre_edit_knn
        intervals : list[int]
            neighborhood sizes at which to compute the overlap
        trial_dir : str
            trial directory; the plot is saved under
            <trial_dir>/models/knn_visualizations/ when save_plots is True
        trial_id : str
            identifier shown in the plot title
        relative : bool
            plot % overlap if True, otherwise the # of shared neighbors
        save_plots : bool
            whether to save the figure to disk
        show : bool
            forwarded to plot()
    Returns:
        None
    '''
    overlaps = []
    legends = []

    # Get overlap for all 4 combinations
    for anchor_idx, anchor in enumerate(['key', 'val']):
        for data_mode in ['logits', 'features']:
            pre_edit_paths = pre_edit_knn[data_mode]['image_paths'][anchor_idx]
            post_edit_paths = post_edit_knn[data_mode]['image_paths'][anchor_idx]

            cur_overlaps = get_overlaps(
                data1=pre_edit_paths,
                data2=post_edit_paths,
                intervals=intervals,
                relative=relative)

            overlaps.append(cur_overlaps)
            legends.append('{} {}'.format(anchor, data_mode))

    # Every curve shares the same x values. This must be built regardless of
    # `relative`: it was previously assigned only in the absolute branch, so
    # the default relative=True call failed before reaching plot().
    all_intervals = [intervals for _ in overlaps]

    # Add Upper Limit: in absolute mode draw the reference line from (0, 0) to (100, 100)
    if not relative:
        all_intervals.append((0, 100))
        overlaps.append((0, 100))
        legends.append('Upper Limit')

    if save_plots:
        save_plot_path = os.path.join(trial_dir, 'models', 'knn_visualizations', 'pre_post_neighbor_overlap.png')
    else:
        save_plot_path = None

    # Plot
    plot(
        xs=all_intervals,
        ys=overlaps,
        labels=legends,
        title='{} Overlap of Neighbors Pre and Post Edit\n[{}]'.format('Relative' if relative else 'Absolute', trial_id),
        xlabel='Neighborhood Size',
        ylabel='{} Overlap of Neighbors'.format('%' if relative else '#'),
        ylimits=(0, 100),
        scatter=False,
        line=True,
        save_path=save_plot_path,
        show=show)
    
            
    
def plot_logit_feature_overlap(pre_edit_knn,
                               post_edit_knn,
                               intervals,
                               trial_dir, 
                               trial_id,
                               relative=True,
                               save_plots=True,
                               show=False):
    '''
    Plot how much the feature-space neighbors and logit-space neighbors of
    each anchor overlap, separately for the pre-edit and post-edit KNN results.

    Arg(s):
        pre_edit_knn : dict
            pre-edit KNN results; knn[data_mode]['image_paths'][anchor_idx]
            is the sequence of neighbor image paths for that anchor
        post_edit_knn : dict
            post-edit KNN results, same layout as pre_edit_knn
        intervals : list[int]
            neighborhood sizes at which to compute the overlap
        trial_dir : str
            trial directory; the plot is saved under
            <trial_dir>/models/knn_visualizations/ when save_plots is True
        trial_id : str
            identifier shown in the plot title
        relative : bool
            plot % overlap if True, otherwise the # of shared neighbors
        save_plots : bool
            whether to save the figure to disk
        show : bool
            forwarded to plot()
    Returns:
        None
    '''
    # Curve order per anchor: pre edit first, then post edit
    stages = (
        ('pre edit {}', pre_edit_knn),
        ('post edit {}', post_edit_knn),
    )

    overlaps = []
    legends = []
    for anchor_idx, anchor in enumerate(['key', 'val']):
        for legend_template, stage_knn in stages:
            features_paths = stage_knn['features']['image_paths'][anchor_idx]
            logits_paths = stage_knn['logits']['image_paths'][anchor_idx]

            overlaps.append(get_overlaps(
                data1=features_paths,
                data2=logits_paths,
                intervals=intervals,
                relative=relative))
            legends.append(legend_template.format(anchor))

    # All curves are plotted against the same neighborhood sizes
    all_intervals = [intervals] * len(overlaps)

    # Absolute mode also gets a reference line from (0, 0) to (100, 100)
    if not relative:
        all_intervals.append((0, 100))
        overlaps.append((0, 100))
        legends.append('Upper Limit')

    save_plots_path = None
    if save_plots:
        save_plots_path = os.path.join(trial_dir, 'models', 'knn_visualizations', 'logit_feature_neighbor_overlap.png')

    overlap_kind = 'Relative' if relative else 'Absolute'
    overlap_unit = '%' if relative else '#'
    plot(
        xs=all_intervals,
        ys=overlaps,
        labels=legends,
        title='{} Overlap of Neighbors Between Logits and Features \n[{}]'.format(overlap_kind, trial_id),
        xlabel='Neighborhood Size',
        ylabel='{} Overlap of Neighbors'.format(overlap_unit),
        ylimits=(0, 100),
        scatter=False,
        line=True,
        save_path=save_plots_path,
        show=show)

#### Obtain overlaps in neighbors pre/post edit and overlaps between logits/features

In [35]:
# Neighborhood sizes to evaluate: 5, 15, 25, ... (below K), then K itself
n_step = 10
intervals = list(range(5, K, n_step))
intervals.append(K)

# Which batch of trials to process
timestamp = '0127_103716'
class_name = 'dog'
n_select = 100

# Plot settings
save_plots = True
show = False
relative = False

trial_paths_path = trial_paths_path_template.format(class_name, n_select, timestamp)
trial_paths = read_lists(trial_paths_path)

for trial_dir in tqdm(trial_paths):
    # Trial ID = last two path components, truncated at the first '_softmax'.
    # str.split returns the whole string when the separator is absent, so no
    # exception handling is needed here (the former bare `except: continue`
    # was unreachable and would only have hidden unrelated errors).
    trial_id = os.path.join(*trial_dir.split('/')[-2:])
    trial_id = trial_id.split('_softmax')[0]

    # Get Pre and Post KNN Objects
    pre_edit_knn = torch.load(os.path.join(trial_dir, 'models', 'pre_edit_{}-nn.pth'.format(K)))
    post_edit_knn = torch.load(os.path.join(trial_dir, 'models', 'post_edit_{}-nn.pth'.format(K)))

    # Plot overlaps for pre vs post edit
    plot_pre_post_overlap(
        pre_edit_knn=pre_edit_knn,
        post_edit_knn=post_edit_knn,
        intervals=intervals,
        trial_dir=trial_dir,
        trial_id=trial_id,
        relative=relative,
        save_plots=save_plots,
        show=show)

    # Plot overlaps for features vs logits
    plot_logit_feature_overlap(
        pre_edit_knn=pre_edit_knn,
        post_edit_knn=post_edit_knn,
        intervals=intervals,
        trial_dir=trial_dir,
        relative=relative,
        trial_id=trial_id,
        save_plots=save_plots,
        show=show)

    # Release the figures created for this trial before moving to the next one
    plt.close('all')


        

100%|███████████████████████████████████████████████████████████████████████| 55/55 [00:21<00:00,  2.57it/s]


In [29]:
# Pedal to the metal!