## Given an edit and its corresponding logit bump, examine the differences

In [None]:
# General imports
import torch
import numpy as np
import os, sys
import json
from tqdm import tqdm
from difflib import SequenceMatcher
import matplotlib.pyplot as plt

In [None]:
# Local imports
sys.path.insert(0, 'src')
from utils import read_json, read_lists, list_to_dict, ensure_dir
from utils.model_utils import prepare_device, quick_predict
from utils.visualizations import plot
from parse_config import ConfigParser
from data_loader import data_loaders
import model.model as module_arch

In [None]:
# Class metadata for CINIC-10: class names in index order plus a name -> index lookup
class_list_path = os.path.join('metadata', 'cinic-10', 'class_names.txt')
class_list = read_lists(class_list_path)
class_idx_dict = list_to_dict(class_list)

# Which edit experiment to analyze: the target class and the number of selected images
target_class_name = 'airplane'
n_select = 100
target_class_idx = class_idx_dict[target_class_name]

# Text files listing the per-trial output directories of (a) the edits and
# (b) the corresponding logit bumps. Both must come from the same
# '<target_class_name>_<n_select>' run, which the loop below checks.
edit_trial_paths_path = 'saved/edit/trials/CINIC10_ImageNet-VGG_16/airplane_100/0214_112633/trial_paths.txt'
logit_bump_trial_paths_path = 'saved/edit/experiments/corresponding_bump_edits/CINIC10_ImageNet-VGG_16/0213_143741/airplane_100/trial_paths_mini.txt'
_run_tag = '{}_{}'.format(target_class_name, n_select)
for _paths_file in (edit_trial_paths_path, logit_bump_trial_paths_path):
    assert _run_tag in _paths_file

edit_trial_paths, logit_bump_trial_paths = (
    read_lists(edit_trial_paths_path),
    read_lists(logit_bump_trial_paths_path),
)

In [None]:
# Logits of the pre-edit (unedited) model on the validation set; the shape is
# printed as a quick sanity check.
pre_edit_logits_path = '/'.join([
    'metadata',
    'CINIC10_ImageNet-VGG_16',
    'pre_edit_validation_set',
    'pre_edit_logits.pth',
])
pre_edit_logits = torch.load(pre_edit_logits_path)
print(pre_edit_logits.shape)

### Define a function to plot pre-edit vs. post-edit (edit and logit bump) logits/softmax

In [None]:
# Subsets of data points that plot_pre_post knows how to select
_VALID_SUBSETS = ('all', 'pred_target', 'logit_bump_target', 'pred_logit_bump_target', 'true_target')


def _to_numpy_scores(logits, plot_softmax):
    '''
    Convert logits to a numpy array, optionally applying softmax over dim 1

    Arg(s):
        logits : torch.tensor or np.ndarray
            N x C logits
        plot_softmax : bool
            if True, return the softmax over dim 1 instead of raw logits
    Returns:
        np.ndarray : N x C logits or softmax scores
    '''
    if plot_softmax:
        if not torch.is_tensor(logits):
            logits = torch.from_numpy(logits)
        return torch.softmax(logits, dim=1).cpu().numpy()
    if torch.is_tensor(logits):
        return logits.cpu().numpy()
    return logits


def _select_subset_rows(pre_edit_data,
                        post_edit_data,
                        logit_bump_data,
                        target_class_idx,
                        subset,
                        n_classes,
                        max_count_mismatch=10):
    '''
    Return a row selector (slice or 1-D index array) for the requested subset

    Arg(s):
        pre_edit_data, post_edit_data, logit_bump_data : N x C np.ndarray
            pre edit, post edit and logit bump scores
        target_class_idx : int
            index of target class
        subset : str
            one of _VALID_SUBSETS (see plot_pre_post)
        n_classes : int
            number of classes; only used for 'true_target'
        max_count_mismatch : int
            sanity-check tolerance on how many more points one model may
            predict as target than the other
    Returns:
        slice or np.ndarray : rows to keep
    Raises:
        AssertionError : if the edit and logit bump predict the target class
            for very different numbers of data points
    '''
    if subset == 'all':
        return slice(None)
    if subset == 'true_target':
        # Assumes rows are ordered by true class with an equal number per class
        n_per_class = pre_edit_data.shape[0] // n_classes
        start_idx = target_class_idx * n_per_class
        return slice(start_idx, start_idx + n_per_class)

    # np.flatnonzero gives a 1-D index array. np.where would give a 1-tuple,
    # whose len() is always 1 and would make the sanity check below vacuous.
    post_edit_idxs = np.flatnonzero(np.argmax(post_edit_data, axis=1) == target_class_idx)
    logit_bump_idxs = np.flatnonzero(np.argmax(logit_bump_data, axis=1) == target_class_idx)

    # Sanity check that the edit and its logit bump correspond: both should
    # predict the target class for roughly the same number of data points
    assert abs(len(post_edit_idxs) - len(logit_bump_idxs)) < max_count_mismatch, \
        'Post edit predicts target for {} points but logit bump for {}'.format(
            len(post_edit_idxs), len(logit_bump_idxs))

    if subset == 'pred_target':
        return post_edit_idxs
    if subset == 'logit_bump_target':
        return logit_bump_idxs
    # 'pred_logit_bump_target': both the edit and the logit bump predict target
    return np.intersect1d(post_edit_idxs, logit_bump_idxs)


def plot_pre_post(pre_edit_logits,
                  post_edit_logits,
                  logit_bump_logits,
                  target_class_idx,
                  image_id,
                  subset='all',
                  n_classes=10,
                  plot_softmax=False,
                  show=True,
                  save_path=None):
    '''
    Plot graph of post edit vs pre edit logits/softmax of the target class for data points in the validation set

    Arg(s):
        pre_edit_logits : N x C (70K x 10) torch.tensor or np.ndarray
            logits of the pre edit model
        post_edit_logits : N x C (70K x 10) torch.tensor or np.ndarray
            logits of the post edit model
        logit_bump_logits : N x C (70K x 10) torch.tensor or np.ndarray
            logits of the corresponding logit_bump to the post edit model
        target_class_idx : int
            index of target class
        image_id : str
            identifier of the edit trial, shown in the plot title
        subset : str
            which data points to plot:
                'all' : every data point
                'pred_target' : post edit model predicts target
                'logit_bump_target' : logit bump predicts target
                'pred_logit_bump_target' : both predict target
                'true_target' : true class is target (assumes rows are
                    ordered by class with an equal count per class)
        n_classes : int
            number of classes; only used for subset='true_target'
        plot_softmax : bool
            if True, plot the softmax instead of pure logits values
        show : bool
            whether or not to show graph
        save_path : str or None
            if not None, save plot to designated path
    Returns:
        (fig, axs) as returned by utils.visualizations.plot
    Raises:
        ValueError : if subset is not one of the values above
    '''
    if subset not in _VALID_SUBSETS:
        raise ValueError('Unknown subset {!r}; expected one of {}'.format(subset, _VALID_SUBSETS))

    # Convert everything to numpy (softmaxed if requested)
    pre_edit_scores = _to_numpy_scores(pre_edit_logits, plot_softmax)
    post_edit_scores = _to_numpy_scores(post_edit_logits, plot_softmax)
    logit_bump_scores = _to_numpy_scores(logit_bump_logits, plot_softmax)

    # Keep only the target-class column of the chosen rows
    rows = _select_subset_rows(
        pre_edit_scores,
        post_edit_scores,
        logit_bump_scores,
        target_class_idx=target_class_idx,
        subset=subset,
        n_classes=n_classes)
    pre_edit_target_data = pre_edit_scores[rows, target_class_idx]
    post_edit_target_data = post_edit_scores[rows, target_class_idx]
    logit_bump_target_data = logit_bump_scores[rows, target_class_idx]

    kind = 'Softmax' if plot_softmax else 'Logits'
    xs = [pre_edit_target_data, pre_edit_target_data]
    ys = [post_edit_target_data, logit_bump_target_data]
    legends = [
        'Post Edit {}'.format(kind),
        'Logit Bump {}'.format(kind)
    ]

    title = 'Pre vs Post Edit {} for \n{}\nSubset: {}'.format(
        kind,
        image_id,
        subset)
    xlabel = 'Pre Edit {}'.format(kind)
    ylabel = 'Post Edit {}'.format(kind)

    # Endpoints of the y=x reference line (only for logits). Skipped for an
    # empty subset, since np.amin/np.amax raise on empty arrays.
    if plot_softmax or pre_edit_target_data.size == 0:
        highlight = None
        highlight_label = None
    else:
        pre_edit_data_min = np.amin(pre_edit_target_data)
        pre_edit_data_max = np.amax(pre_edit_target_data)
        highlight = [(pre_edit_data_min, pre_edit_data_max), (pre_edit_data_min, pre_edit_data_max)]
        highlight_label = 'y=x'
    fig, axs = plot(
        xs=xs,
        ys=ys,
        labels=legends,
        title=title,
        xlabel=xlabel,
        ylabel=ylabel,
        highlight=highlight,
        highlight_label=highlight_label,
        alpha=0.1,
        marker_size=2,
        scatter=True,
        line=False,
        show=show,
        save_path=save_path)

    return fig, axs


In [None]:
# Save graphs next to the logit bump trial list
save_root = os.path.dirname(logit_bump_trial_paths_path)

# Subsets of data points to plot (see plot_pre_post for their definitions)
subsets = ['all', 'pred_target', 'logit_bump_target', 'pred_logit_bump_target', 'true_target']
n_trials = 10  # only plot the first n_trials trials
show = False
save_graphs_dir = os.path.join(save_root, 'graphs', 'pre_post_logits')

# Trial-outer / subset-inner so each trial's logits are loaded only once
for trial_idx, (edit_trial_path, logit_bump_trial_path) in enumerate(zip(edit_trial_paths, logit_bump_trial_paths)):
    if trial_idx >= n_trials:
        break
    # Extract image_id from the longest substring shared by both trial paths.
    # NOTE(review): the +1 / -len('_softmax') trimming assumes a specific
    # trial directory layout; confirm against the actual trial paths.
    s = SequenceMatcher(None, edit_trial_path, logit_bump_trial_path)
    match = s.find_longest_match(alo=0, ahi=len(edit_trial_path), blo=0, bhi=len(logit_bump_trial_path))
    image_id = edit_trial_path[match.a+1:match.a+match.size-len('_softmax')]

    # Create directory for this trial
    trial_dir = os.path.join(save_graphs_dir, image_id)
    ensure_dir(trial_dir)

    edit_trial_logits = torch.load(os.path.join(edit_trial_path, 'models', 'post_edit_logits.pth'))
    logit_bump_logits = torch.load(os.path.join(logit_bump_trial_path, 'models', 'post_edit_logits.pth'))

    for subset in subsets:
        for plot_softmax, suffix in ((False, 'logits'), (True, 'softmax')):
            fig, _ = plot_pre_post(
                pre_edit_logits=pre_edit_logits,
                post_edit_logits=edit_trial_logits,
                logit_bump_logits=logit_bump_logits,
                target_class_idx=target_class_idx,
                image_id=image_id,
                subset=subset,
                plot_softmax=plot_softmax,
                show=show,
                save_path=os.path.join(trial_dir, '{}_pre_post_{}.png'.format(subset, suffix)))
            # Release the figure; otherwise every plot stays open across the
            # loop. Assumes `plot` returns a matplotlib Figure — confirm.
            plt.close(fig)
    

In [None]:
# Reference: an example logit bump trial's post-edit logits path
print(
    "saved/edit/experiments/corresponding_bump_edits/CINIC10_ImageNet-VGG_16/0213_143741/"
    "airplane_100/results_mini/airplane-train-n03365231_4635/"
    "felzenszwalb_masked_softmax/models/post_edit_logits.pth"
)

In [None]:
# Pedal to the metal!