## Plot number of chosen segments vs. edit quality metric

In [None]:
# General imports
import torch
import numpy as np
import os, sys
import json
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Local imports
sys.path.insert(0, 'src')
from utils import read_json, read_lists, list_to_dict, ensure_dir
from utils.model_utils import prepare_device, quick_predict
from utils.df_utils import load_and_preprocess_csv
from utils.visualizations import histogram, bar_graph, plot, boxplot
from parse_config import ConfigParser
from data_loader import data_loaders
import model.model as module_arch

In [None]:
# Experiment selection: which class / selection size / runs to analyze
class_name = 'airplane'
n_select = 100
timestamp = '0127_103716'        # names the edit-trial run directory
paths_timestamp = '0126_161209'  # names the directory holding the saved path lists

# Fixed input locations: edit-trial config and the class-name list
config_path = 'configs/copies/cinic10_imagenet_segmentation_edit_trials.json'
class_list_path = os.path.join('metadata', 'cinic-10', 'class_names.txt')


In [None]:
# Load config file and derive the input/output paths used by the rest of the notebook
class_list = read_lists(class_list_path)
class_idx_dict = list_to_dict(class_list)

config_json = read_json(config_path)
K = config_json['editor']['K']

device, device_ids = prepare_device(config_json['n_gpu'])

# Format only the '<class_name>_<n_select>' path component. Formatting the whole joined
# path would treat any brace in another component (e.g. timestamp) as a format field.
trial_name = '{}_{}'.format(class_name, n_select)
root_dir = os.path.join('saved', 'edit', 'trials', 'CINIC10_ImageNet-VGG_16', trial_name, timestamp)
save_paths_dir = os.path.join('paths', 'edits', 'semantics', trial_name, paths_timestamp)

# Graphs are written underneath the trial's own directory
graph_save_dir = os.path.join(root_dir, 'graphs')

# Inputs produced by earlier steps: trial directories, results table, value image paths
trial_paths_path = os.path.join(root_dir, 'trial_paths.txt')
csv_path = os.path.join(root_dir, 'results_table.csv')
val_paths_path = os.path.join(save_paths_dir, 'value_images_softmax.txt')

show = False


In [None]:
# Load objects
# Results table; drop_duplicates=['ID'] is passed so the loader can drop rows with a
# repeated ID (presumably keeping one row per edit -- confirm in load_and_preprocess_csv)
df = load_and_preprocess_csv(
    csv_path=csv_path,
    drop_duplicates=['ID'])

# Path lists, read from the text files defined in the previous cell
trial_paths = read_lists(trial_paths_path)
val_paths = read_lists(val_paths_path)

print("Restoring trial_paths from {}".format(trial_paths_path))
print("Restoring results csv from {}".format(csv_path))
print("Restoring segmentation paths from {}".format(val_paths_path))

### Obtain number of segments modified for each edit and make a bar chart

In [None]:
def match_idx(element, test_elements):
    '''
    Find the position of the first entry of test_elements that equals element

    Arg(s):
        element : np.array
            array to look for
        test_elements : np.array or sequence of np.array
            candidates to compare against; iterated along the first axis, so a
            stacked array and a plain list of arrays are both accepted

    Returns:
        chosen_idx : int
            index of the first candidate for which np.array_equal(element, candidate)
            is True, or -1 if no candidate matches (including empty test_elements)
    '''
    for idx, candidate in enumerate(test_elements):
        if np.array_equal(element, candidate):
            return idx
    # Sentinel for "no match"; callers must handle it (np.bincount rejects negatives)
    return -1
    
def get_segment_number(segmentation_paths):
    '''
    Given list of paths to saved segmentation dictionaries (output of segment_semantically.ipynb),
    Return list of numbers showing segment number chosen

    Each dictionary is read with torch.load and must provide the keys
    'softmax_most_change_image' (the chosen image) and 'softmax_cum_modifications'
    (the cumulative images the chosen image is matched against).

    Arg(s):
        segmentation_paths : list[str]
            list of paths to segmentation objects

    Returns:
        chosen_idxs : list[int]
            list of indices of chosen segments, one per path;
            an entry is -1 when the chosen image matches none of the cumulative images
    '''
    chosen_idxs = []

    for segmentation_path in tqdm(segmentation_paths):
        # Load segmentation dictionary -> selected image and all cumulative images
        segmentation_dict = torch.load(segmentation_path)
        chosen_image = segmentation_dict['softmax_most_change_image']
        cumulative_images = segmentation_dict['softmax_cum_modifications']

        # Convert both operands the same way: np.array_equal cannot consume a CUDA
        # (or grad-tracking) tensor, so tensors are detached and moved to CPU numpy
        if torch.is_tensor(chosen_image):
            chosen_image = chosen_image.detach().cpu().numpy()
        if torch.is_tensor(cumulative_images):
            cumulative_images = cumulative_images.detach().cpu().numpy()

        # Find index that selected image matches cumulative images
        chosen_idxs.append(match_idx(chosen_image, cumulative_images))

    return chosen_idxs


In [None]:
# The segmentation results file is looked up in the same directory as each value image
segmentation_paths = [
    os.path.join(val_dir, 'cumulative_segment_results.pth')
    for val_dir in map(os.path.dirname, val_paths)
]

# Index of the chosen segment for every value image (same order as val_paths)
segment_idxs = get_segment_number(segmentation_paths)

In [None]:


# Single row of counts: entry i is how many edits have chosen segment index i
bins = np.bincount(segment_idxs)[np.newaxis, :]
# bins = np.stack([np.bincount(segment_idxs), np.bincount(segment_idxs)], axis=0)

# One tick label per bin: 0 .. (number of bins - 1)
bin_labels = list(range(bins.shape[-1]))

bar_graph_save_path = os.path.join(graph_save_dir, 'summary', 'n_segment_bar_graph.png')
bar_graph(
    data=bins,
    labels=bin_labels,
    title='Distribution of Number Segments Modified for Value Image in {} Class'.format(class_name),
    xlabel='Number of Segments Modified',
    ylabel='Counts of Edits',
    xlabel_rotation=0,
    save_path=bar_graph_save_path,
    show=False)
    

### Obtain corresponding edit quality and make a plot

In [None]:
# def histogram(data,
#               multi_method='side',
#               n_bins=10,
#               labels=None,
#               data_range=None,
#               alpha=1.0,
#               colors=None,
#               title=None,
#               xlabel=None,
#               ylabel=None,
#               marker=None,
#               save_path=None,
#               show=True):
#     '''
#     Plot histogram of data provided

#     Arg(s):
#         data : np.array or sequence of np.array
#             Data for histogram
#         n_bins : int
#             number of bins for histogram
#         labels : list[str]
#             label for each type of histogram (should be same number of sequences as data)
#         data_range : (float, float)
#             upper and lower range of bins (default is max and min)
#     '''
    
#     assert multi_method in ['side', 'overlap'], "Unrecognized multi_method: {}".format(multi_method)
    
#     if type(data) == np.ndarray and len(data.shape) == 2:
#         data = data.tolist()
#     n_data = len(data)
            
#     if labels is None:
#         labels = [None for i in range(n_data)]
#     if colors is None:
#         colors = [None for i in range(n_data)]
            
#     if type(data) == np.ndarray and len(data.shape) == 1:
#         plt.hist(data,
#                 bins=n_bins,
#                 label=labels[0],
#                 range=data_range,
#                 color=colors,
#                 edgecolor='black',
#                 alpha=alpha)
#     else:
#         # Overlapping histograms
#         if multi_method == 'overlap':
#             for cur_idx, cur_data in enumerate(data):
#                 plt.hist(cur_data,
#                      bins=n_bins,
#                      label=labels[cur_idx],
#                      range=data_range,
#                      color=colors[cur_idx],
#                      edgecolor='black',
#                     alpha=alpha)
#         # Side by side histogram
#         else:
#             plt.hist(data,
#                  bins=n_bins,
#                  label=labels,
#                  range=data_range,
#                  color=None,
#                  edgecolor='black',
#                  alpha=alpha)

#     # Marker is a vertical line marking original
#     if marker is not None:
#         plt.axvline(x=marker, color='r')

#     # Make legend
#     if labels is not None:
#         plt.legend()
#     # Set title and axes labels
#     if title is not None:
#         plt.title(title)
#     if xlabel is not None:
#         plt.xlabel(xlabel)
#     if ylabel is not None:
#         plt.ylabel(ylabel)

#     if save_path is not None:
#         ensure_dir(os.path.dirname(save_path))
#         plt.savefig(save_path)
#     if show:
#         plt.show()
#     plt.clf()

In [None]:
quality_measurement = 'Post Target Accuracy'

# Sanity check the rows of DF are same as segmentation_paths
# zip() stops at the shorter input, so a length mismatch would otherwise go unnoticed
# and the qualities below would be paired with the wrong segment indices
assert len(df['ID']) == len(segmentation_paths), \
    "Results table has {} rows but there are {} segmentation paths".format(
        len(df['ID']), len(segmentation_paths))
for edit_id, segmentation_path in zip(df['ID'], segmentation_paths):
    # Compare only the part of the ID before '_softmax' against the path
    edit_id = edit_id.split('_softmax')[0]
    assert edit_id in segmentation_path, \
        "Row ID {} does not match segmentation path {}".format(edit_id, segmentation_path)

# Post-edit quality per row, and the mean of the matching 'Pre ...' column as a baseline
edit_qualities = df[quality_measurement]
pre_edit_quality = df[quality_measurement.replace('Post', 'Pre')].mean()

xlabel = 'Number of Segments Modified'
ylabel = quality_measurement
title = '{} vs {}'.format(ylabel, xlabel)

# One point per edit (line=False); this figure is only shown, no save_path is given
plot(
    xs=[segment_idxs],
    ys=[edit_qualities],
    line=False,
    xlabel=xlabel,
    ylabel=ylabel,
    title=title,
    show=True)

# Group the edit qualities by the chosen segment index of each edit
box_whisker_dict = {}
for edit_quality, segment_idx in zip(edit_qualities, segment_idxs):
    box_whisker_dict.setdefault(segment_idx, []).append(edit_quality)

# Order the boxes by ascending segment index; plain dict order would follow the first
# appearance of each index in the data and scramble the x-axis
labels = sorted(box_whisker_dict)
data = [np.array(box_whisker_dict[segment_idx]) for segment_idx in labels]

# The mean pre-edit quality is passed as the highlighted reference value
boxplot_save_path = os.path.join(graph_save_dir, 'summary', 'boxplot_segment_{}.png'.format(quality_measurement.lower().replace(' ', '_')))
boxplot(
    data=data,
    labels=labels,
    xlabel=xlabel,
    ylabel=ylabel,
    highlight=pre_edit_quality,
    title=title,
    save_path=boxplot_save_path,
    show=True)
# Alternative view (disabled): histograms of edit quality per segment index
# for multi_method in ['side', 'overlap']:
#     histogram(
#         data=data,
#         multi_method=multi_method,
#         labels=labels,
#         n_bins=10,
#         data_range=(0, 1),
#         xlabel=quality_measurement,
#         ylabel='Number of Edits',
#         title='Distribution of Edit Quality Separated By Num. Segments Modified')

plt.close('all')

