## Create logit-bump edits whose post-edit target-class prediction counts match those of real edits

In [4]:
# General imports
import torch
import numpy as np
import os, sys
import json
from tqdm import tqdm
import matplotlib.pyplot as plt
import math
from datetime import datetime
import collections

import warnings
from argparse import Namespace
warnings.filterwarnings("ignore")

from torchvision import transforms
from PIL import Image

from helpers import classifier_helpers
import helpers.data_helpers as dh
import helpers.context_helpers as coh
import helpers.rewrite_helpers as rh
import helpers.vis_helpers as vh

%matplotlib inline

In [None]:
# # Local imports
# sys.path.insert(0, 'src')
# from utils import read_json, read_lists, list_to_dict, ensure_dir, informal_log, write_lists, get_common_dir_path
# from utils.df_utils import load_and_preprocess_csv
# from utils.visualizations import histogram
# from utils.model_utils import prepare_device, quick_predict
# from parse_config import ConfigParser
# from test import predict, predict_with_bump
# import datasets.datasets as module_data
# import model.model as module_arch
# import model.metric as module_metric
# import model.loss as module_loss
# from trainer.editor import Editor

In [2]:
# Experiment configuration shared by the cells below
DATASET_NAME = 'ImageNet'  # passed to get_default_paths, load_classifier and the data helper
LAYERNUM = 12  # passed to load_classifier as its last argument
REWRITE_MODE = 'editing'  # NOTE(review): not referenced by any visible cell
ARCH = 'vgg16'  # passed as `arch`; rebound when the default paths are unpacked

## Load model

In [3]:
# Look up the default paths/settings for this dataset + architecture.
# CD is later indexed by class id when printing (presumably class id -> name; confirm in helpers).
ret = classifier_helpers.get_default_paths(DATASET_NAME, arch=ARCH)
DATASET_PATH, MODEL_PATH, MODEL_CLASS, ARCH, CD = ret
# Load the classifier; `ret` is reused and only its first three entries are kept.
ret = classifier_helpers.load_classifier(MODEL_PATH, MODEL_CLASS, ARCH,
                            DATASET_NAME, LAYERNUM) 
model, context_model, target_model = ret[:3]

### Load test data

In [None]:
# Fetch the train/test splits from the vehicles-on-snow data helper.
# The structure of the returned objects is defined in helpers.data_helpers (not visible here).
train_data, test_data = dh.get_vehicles_on_snow_data(DATASET_NAME, CD)

In [24]:
# Restore the per-class prediction counters saved before and after the edit
restore_dir = os.path.join('edited_checkpoints', 'vehicles_on_snow')
post_edit_counter = torch.load(os.path.join(restore_dir, 'post_edit_counter.pth'))
pre_edit_counter = torch.load(os.path.join(restore_dir, 'pre_edit_counter.pth'))

print(pre_edit_counter, post_edit_counter)

# Every class that appears in either counter
all_classes = list(set(list(pre_edit_counter.keys()) + list(post_edit_counter.keys())))

# Absolute change in count per class; a class absent from a counter counts as 0
deltas = []
for class_idx in all_classes:
    pre_edit_count = pre_edit_counter.get(class_idx, 0)
    post_edit_count = post_edit_counter.get(class_idx, 0)
    deltas.append(np.abs(post_edit_count - pre_edit_count))
print(deltas)

# argsort is ascending, so reverse it to put the largest change first
sorted_idxs = np.argsort(deltas)[::-1]

# Parallel arrays: the change amounts and the classes they belong to, largest change first
deltas_sorted = np.array(deltas)[sorted_idxs]
classes_sorted = np.array(all_classes)[sorted_idxs]

Counter({803: 23, 802: 21, 555: 18, 779: 14, 847: 8, 751: 8, 479: 8, 920: 6, 670: 5, 408: 4, 609: 2, 829: 1, 665: 1, 450: 1, 348: 1, 928: 1, 471: 1, 874: 1, 586: 1, 866: 1, 961: 1, 717: 1, 643: 1}) Counter({779: 17, 555: 17, 751: 17, 803: 15, 670: 9, 847: 9, 479: 9, 802: 8, 920: 7, 408: 3, 665: 2, 609: 2, 866: 2, 829: 1, 450: 1, 348: 1, 928: 1, 251: 1, 471: 1, 867: 1, 874: 1, 586: 1, 961: 1, 717: 1, 643: 1})
[0, 3, 1, 1, 1, 4, 0, 13, 8, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 9, 1]
int64
[802 751 803 670 779 251 555 867 866 408 479 920 665 847 928 450 829 961
 586 717 471 348 609 874 643]
[13  9  8  4  3  1  1  1  1  1  1  1  1  1  0  0  0  0  0  0  0  0  0  0
  0]


### Define functions

In [25]:
def predict_with_bump(bump_amounts,
                      data,
                      model):
    '''
    Add a fixed offset ("bump") to the model's logits and collect the bumped predictions.
    Prints the per-class accuracy of the bumped predictions as a side effect.

    Arg(s):
        bump_amounts : torch.Tensor or array-like
            offsets added to the logits; must either have exactly the logits' shape or be a
            1-D vector with one entry per class (broadcast over the batch)
        data : dict
            iterated with .items(); each key c is compared against the predictions and used
            to index the global CD, each value x is a batch of inputs passed to the model
        model : callable
            maps a batch of inputs (moved to the GPU first) to logits of shape (batch, n_classes)

    Returns:
        torch.Tensor
            argmax predictions after the bump for every batch, concatenated in the iteration
            order of data (an empty long tensor if data is empty)
    '''
    if not torch.is_tensor(bump_amounts):
        bump_amounts = torch.tensor(bump_amounts)

    predictions = []
    for c, x in data.items():
        # The original used `ch.no_grad()`, but `ch` is never imported; torch is imported as `torch`
        with torch.no_grad():
            logits = model(x.cuda())
            # Accept an exact shape match (original contract) or a per-class vector
            assert bump_amounts.shape == logits.shape or bump_amounts.shape == logits.shape[-1:], \
                "bump_amounts shape {} is incompatible with logits shape {}".format(
                    tuple(bump_amounts.shape), tuple(logits.shape))
            # Move the bump next to the logits so a CPU-built bump does not raise a device mismatch
            logits = logits + bump_amounts.to(device=logits.device, dtype=logits.dtype)
            pred = logits.argmax(dim=1)
            predictions.append(pred)
        n_correct = int((pred == c).sum())
        # Guard against an empty batch instead of dividing by zero
        acc = 100 * n_correct / len(x) if len(x) > 0 else float('nan')
        print(f'Class: {c}/{CD[c]} | Accuracy: {acc:.2f}',)

    # torch.cat raises on an empty list, so handle empty data explicitly
    if not predictions:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(predictions)
        
        

In [None]:
def match_bump(pre_edit_counter,
               post_edit_counter,
               n_target_predictions,
               target_class_idx,
               # bumps_preds_metrics,  (disabled parameter -- the body still reads this name)
               results_save_dir,
               # data_loader,  (disabled parameter -- the body still reads this name)
               test_data,
               model,
               loss_fn,
               metric_fns,
               device,
               cushion=5,
               n_stop=10,
               debug=True):
    '''
    Binary-search a logit bump amount so that the number of predictions for the target class
    lands within `cushion` of n_target_predictions.

    NOTE(review): this version looks half-migrated from the CINIC-10 version defined later in
    this notebook:
      * bumps_preds_metrics and data_loader are commented out of the signature but are still
        read in the body, so they resolve as globals (NameError if no such global exists);
      * pre_edit_counter, post_edit_counter, test_data and n_stop are accepted but never used;
      * predict_with_bump() is called with keywords (data_loader, loss_fn, bump_amount, ...)
        that the predict_with_bump(bump_amounts, data, model) defined in the preceding cell
        does not accept;
      * cur_bump_amount / log are unbound at the return if the loop body never runs, and log is
        unbound if the loop breaks on its first iteration.
    
    Arg(s):
        pre_edit_counter : unused in the body
        post_edit_counter : unused in the body
        n_target_predictions : int
            Number of predictions to obtain for target class
        target_class_idx : int
            index of target class
        results_save_dir : str
            directory to save results to
        test_data : unused in the body
        model : torch.nn.Module
            model
        loss_fn : module
            loss function
        metric_fns : list[model.metric modules]
            list of metric functions
        device : torch.device
            GPU device to run model on
        cushion : int
            how far away cur_n_target_predictions can be from n_target_predictions on either side to break loop (buffer)
        n_stop : int
            documented as the number of stuck iterations before giving up; not implemented
        debug : bool
            control verbosity
        
            
    Returns: 
        (cur_bump_amount, cur_n_target_predictions, log)
            last bump amount tried, the target class prediction count it produced, and the
            value returned by predict_with_bump() for it
    '''
    # Unpack bumps_preds_metrics
    # NOTE(review): bumps_preds_metrics is not a parameter of this version (see signature)
    pre_edit_metrics = bumps_preds_metrics['pre_edit_metrics']
    target_class_predictions = bumps_preds_metrics['target_class_predictions']
    bump_amounts = bumps_preds_metrics['bump_amounts']
    
    # Obtain number of total predictions and assert n_target_predictions is less
    n_predictions_total = np.sum(pre_edit_metrics['predicted_class_distribution'])
    assert n_target_predictions <= n_predictions_total, \
        "n_target_predictions ({}) must be less than total number of data in dataloader ({})".format(n_target_predictions, n_predictions_total)
    
    # Find the first entry whose prediction count exceeds n_target_predictions
    bin_high_idx = -1
    for bin_idx, target_class_prediction in enumerate(target_class_predictions):
        if target_class_prediction > n_target_predictions:
            bin_high_idx = bin_idx
            break
            
    if bin_high_idx == -1: # No entry exceeds the target (past upper end)
        # Index -1 selects the last bump amount: search between it and twice its value
        bump_amount_upper_bound = bump_amounts[bin_high_idx] * 2
        bump_amount_lower_bound = bump_amounts[bin_high_idx]
    # Fall into a bin from the histogram (or lower)
    else: 
        # n_predictions_upper_bound / n_predictions_lower_bound are assigned but never read
        n_predictions_upper_bound = target_class_predictions[bin_high_idx]
        bump_amount_upper_bound = bump_amounts[bin_high_idx]

        # Store lower bounds for bump_amount and n_predictions
        bin_low_idx = bin_high_idx - 1
        if bin_low_idx > -1: # There is an entry below the target: use it as the lower bound
            n_predictions_lower_bound = target_class_predictions[bin_low_idx]
            bump_amount_lower_bound = bump_amounts[bin_low_idx]
        else: # Even the first entry exceeds the target: fall back to a fixed lower bound
            n_predictions_lower_bound = 0
            bump_amount_lower_bound = -10 
        
    # Placeholder count used only for the first loop-condition check
    cur_n_target_predictions = 0
    if debug:
        print("target n_predictions: {}".format(n_target_predictions))
        print("Initial bounds for bump: ({}, {})".format(bump_amount_lower_bound, bump_amount_upper_bound))
    
    # Keep looping while the difference between current n_target_predictions and goal n_target_predictions is too large
    while abs(cur_n_target_predictions - n_target_predictions) > cushion:
        # Update bump amount to the midpoint of the current bounds
        cur_bump_amount = (bump_amount_lower_bound + bump_amount_upper_bound) / 2
        # Check before we undergo an infinite loop
        if cur_bump_amount == 0:
            print("cur_bump_amount is 0, exiting loop")
            break
            
        # predict using logit bump
        # NOTE(review): data_loader is not a parameter of this version (see signature)
        log = predict_with_bump(
                data_loader=data_loader,
                model=model,
                loss_fn=loss_fn,
                metric_fns=metric_fns,
                device=device,
                target_class_idx=target_class_idx,
                bump_amount=cur_bump_amount,
                output_save_path=os.path.join(results_save_dir, "post_edit_logits.pth"),
                log_save_path=os.path.join(results_save_dir, "post_edit_metrics.pth"))

        # Obtain num. predictions for target class and determine bin idx
        post_class_distribution = log['predicted_class_distribution']
        cur_n_target_predictions = post_class_distribution[target_class_idx]
        if debug:
            print("cur_bump_amount: {}, cur_n_target_predictions: {}".format(cur_bump_amount, cur_n_target_predictions))
        
        # Update bump bounds of binary search: too many predictions -> lower the upper bound,
        # too few -> raise the lower bound
        if cur_n_target_predictions > n_target_predictions:
            bump_amount_upper_bound = cur_bump_amount
            if debug:
                print("Updated upper bound to {}".format(bump_amount_upper_bound))
        elif cur_n_target_predictions < n_target_predictions:
            bump_amount_lower_bound = cur_bump_amount
            if debug:
                print("Updated lower bound to {}".format(bump_amount_lower_bound))
        
    if debug:
        print("final results: bump amount: {} n_target_predictions: {}".format(cur_bump_amount, cur_n_target_predictions))
    # if results_save_dir is not None:
    #     torch.save(log, os.path.join(results_save_dir, "post_edit_metrics.pth"))
        # torch.save(logits, os.path.join(results_save_dir, "post_edit_logits.pth"))
    return cur_bump_amount, cur_n_target_predictions, log

In [33]:
# Keep only the classes whose prediction count changed by more than 1
selected_idxs = np.where(deltas_sorted > 1)[0]
selected_deltas = deltas_sorted[selected_idxs]
selected_classes = classes_sorted[selected_idxs]

# TODO: call match_bump(...) for the selected classes once its arguments are settled.
# The cell previously ended with an unfinished `match_bump(`, a SyntaxError that
# prevented the three assignments above from running at all.

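### Legacy cells from the CINIC-10 setup (they rely on the helpers whose imports are commented out in the "Local imports" cell above)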

In [3]:
# Define constants, paths
# File read with read_lists() to obtain the class names
class_list_path = os.path.join('metadata', 'cinic-10', 'class_names.txt')

# Experiment config read with read_json(); the "noise" variant is kept for reference
# config_path = 'configs/copies/edit_experiments/cinic10_imagenet_bump_noise.json'
config_path = 'configs/copies/edit_experiments/cinic10_imagenet_bump_corresponding.json'

In [4]:
# Load config file, models, and dataloader
# NOTE(review): read_lists, list_to_dict, read_json, ConfigParser and prepare_device are only
# imported in the commented-out "Local imports" cell; unless they are brought into scope some
# other way, this cell raises NameError on a fresh kernel.
class_list = read_lists(class_list_path)
class_idx_dict = list_to_dict(class_list)  # later indexed by class name to get the class index

config_dict = read_json(config_path)
config = ConfigParser(config_dict)
device, device_ids = prepare_device(config_dict['n_gpu'])

In [5]:
# Load model
# NOTE(review): this rebinds `model`, replacing the classifier returned by
# classifier_helpers.load_classifier earlier in the notebook. module_arch is only imported
# in the commented-out "Local imports" cell.
layernum = config.config['layernum']
model = config.init_obj('arch', module_arch, layernum=layernum, device=device)
# Last expression of the cell, so its return value is what gets displayed below
model.eval()


ModelWrapperSanturkar(
  (model): VGG(
    (normalize): InputNormalize()
    (layer0): Sequential(
      (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (layer1): Sequential(
      (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
      (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (layer2): Sequential(
      (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU()
    )
    (layer3): Sequential(
      (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn): BatchNorm2d(128, eps=1e-05, momentum=0.

In [6]:
# Build the validation data loader
# config = ConfigParser(config_dict)
# Keyword arguments for the loader and the dataset are copied out of the parsed config
data_loader_args = dict(config.config["data_loader"]["args"])
dataset_args = dict(config["dataset_args"])

# The raw config dict names the files listing validation image paths and labels
val_image_paths = read_lists(config_dict['dataset_paths']['valid_images'])
val_labels = read_lists(config_dict['dataset_paths']['valid_labels'])

# Dataset constructed with return_paths=True, wrapped in a standard DataLoader
val_dataset = module_data.CINIC10Dataset(
    data_dir="",
    image_paths=val_image_paths,
    labels=val_labels,
    return_paths=True,
    **dataset_args)
val_paths_data_loader = torch.utils.data.DataLoader(val_dataset, **data_loader_args)

# Resolve the loss and metric callables from the names given in the config
loss_fn = getattr(module_loss, config['loss'])
metric_fns = [getattr(module_metric, metric_name) for metric_name in config['metrics']]

In [7]:
# Class whose prediction count the bump is matched against
target_class_name = 'airplane'
target_class_idx = class_idx_dict[target_class_name]
n_select = 100  # used in directory names of the form '<class>_<n_select>'
timestamp = '0214_112633'  # directory component of the trials read from root_dir
bumps_timestamp = '0208_112555'  # NOTE(review): not referenced by any visible cell
save_timestamp = '0213_143741'  # directory component of where bump results are written

# Only n_select is formatted here; the leading '{}' is left as a placeholder that later cells
# fill with a class name via .format(...)
root_dir = os.path.join('saved', 'edit', 'trials', 'CINIC10_ImageNet-VGG_16', '{}' + '_{}'.format(n_select), timestamp)
csv_path_template = os.path.join(root_dir, 'results_table.csv')


In [33]:
# def save_class_distribution(csv_path, 
#                             target_class_idx,
#                             target_class_name,
#                             show=False,
#                             histogram_save_path=None,
#                             data_save_path=None):
    
#     df = load_and_preprocess_csv(
#         csv_path=csv_path,
#         drop_duplicates=['ID']
#     )
    
#     # Obtain number of predictions for target class pre edit
#     pre_edit_class_distribution = df['Pre Class Dist'].to_numpy()
#     pre_edit_class_distribution = np.stack(pre_edit_class_distribution)
#     target_pre_edit_class_predictions = np.mean(pre_edit_class_distribution[:, target_class_idx])
    
#     # Obtain number of predictions for target class post edit for each trial
#     class_distribution = df['Post Class Dist'].to_numpy()
#     class_distribution = np.stack(class_distribution, axis=0)
#     target_class_distribution = class_distribution[:, target_class_idx]
#     # target_class_bins = np.bincount(target_class_distribution)
#     if histogram_save_path is None:
#         histogram_save_path = os.path.join(os.path.dirname(csv_path), 'graphs', 'summary', 'target_class_distribution.png')
#     title = 'Post Edit {} Class Distribution for {} Edits'.format(target_class_name, target_class_name)
#     xlabel = 'Num. {} Predictions Post Edit'.format(target_class_name)
#     ylabel = 'Num. Edits'
    
#     bin_values, bins, _ = histogram(
#         data=target_class_distribution,
#         n_bins=50,
#         title=title,
#         xlabel=xlabel,
#         ylabel=ylabel,
#         marker=target_pre_edit_class_predictions,
#         show=show,
#         save_path=histogram_save_path)
    
#     bin_tuples = []
#     for bin_idx in range(len(bins) - 1):
#         bin_tuples.append((bins[bin_idx], bins[bin_idx+1]))

#     save_data = {
#         "n_target_predictions": target_class_distribution,
#         "histogram_bin_values": bin_values,
#         "histogram_bins": bins
#     }
    
#     if data_save_path is None:
#         data_save_path = os.path.join(os.path.dirname(csv_path), 'target_class_distribution.pth')
#         torch.save(save_data, data_save_path)
    
#     print("Saved target class distribution & histogram data to {}".format(data_save_path))
    
#     plt.close('all')
#     return save_data

#### Loop over all classes and save histograms and distributions

In [None]:
# Save the target-class distribution and histogram for every class in class_list.
# NOTE(review): the only definition of save_class_distribution visible in this notebook is
# commented out in the cell above, so this loop raises NameError unless the function is
# restored or imported.
for class_name in class_list:
    # Fill the class-name placeholder left in csv_path_template
    csv_path = csv_path_template.format(class_name)
    save_class_distribution(
        csv_path=csv_path,
        target_class_idx=class_idx_dict[class_name],
        target_class_name=class_name,
        show=False)
    

In [None]:
# target_class_distribution_path = os.path.join(root_dir, 'target_class_distribution.pth').format(
#     target_class_name)
# target_class_distribution = torch.load(target_class_distribution_path)
# bin_lows = target_class_distribution['histogram_bins']

# bump_save_dir = os.path.join(os.path.dirname(os.path.dirname(config.save_dir)), save_timestamp, '{}_{}'.format(target_class_name, n_select))
# ensure_dir(bump_save_dir)
# # Save a copy of histogram info to save_dir
# torch.save(target_class_distribution, os.path.join(bump_save_dir, 'target_class_distribution.pth'))
# # Run bump experiments
# print("Obtaining class distribution for {} from {}".format(target_class_name, target_class_distribution_path))
# print("Saving results to {}".format(bump_save_dir))
# match_bump_edits(
#     data_loader=val_paths_data_loader,
#     model=model,
#     loss_fn=loss_fn,
#     metric_fns=metric_fns,
#     device=device,
#     bin_lows=bin_lows,
#     target_class_idx=target_class_idx,
#     save_dir=bump_save_dir)

In [None]:
# bump_save_dir = os.path.join(os.path.dirname(os.path.dirname(config.save_dir)), save_timestamp, '{}_{}'.format(target_class_name, n_select))
# metrics_save_path = os.path.join(bump_save_dir, 'bumps_preds_metrics.pth')

# bumped_target_class_dist = torch.load(metrics_save_path)

# bumped_hist_data = []


# target_class_dist_dict = torch.load(target_class_distribution_path)

# for n_target_predictions, bucket_value in zip(
#     bumped_target_class_dist['target_class_predictions'], 
#     target_class_distribution['histogram_bin_values']):
#     cur_data = [n_target_predictions for i in range(int(bucket_value))]
#     bumped_hist_data += cur_data
    
# bins = target_class_distribution['histogram_bins']
# bin_values = target_class_distribution['histogram_bin_values']

# histogram_save_path = os.path.join(
#     bump_save_dir, 
#     'graphs',
#     'summary',
#     'bumped_target_class_distribution.png')
# bump_bin_values, bump_bins, _= histogram(
#         data=bumped_hist_data,
#         n_bins=bins, #50,
#         title='Bumped Post Edit {} Class Distribution to Match {} Edits'.format(target_class_name, target_class_name),
#         xlabel='Num. {} Predictions Post Bump'.format(target_class_name),
#         ylabel='Num. Edits',
#         save_path=histogram_save_path,
#         show=True)

# assert (bin_values == bump_bin_values).all()

### Perform Bump for a Corresponding Edit

In [19]:
def match_bump(n_target_predictions,
               target_class_idx,
               bumps_preds_metrics,
               results_save_dir,
               data_loader,
               model,
               loss_fn,
               metric_fns,
               device,
               cushion=5,
               n_stop=10,
               debug=True):
    '''
    Binary-search the logit bump for the target class until the number of target class
    predictions is within `cushion` of n_target_predictions. Every evaluated bump is run through
    predict_with_bump(), which is handed paths inside results_save_dir for its logits and metrics.

    Arg(s):
        n_target_predictions : int
            Number of predictions to obtain for target class
        target_class_idx : int
            index of target class
        bumps_preds_metrics : dict
            saved data from match_bump_edits(); the keys 'pre_edit_metrics',
            'target_class_predictions' and 'bump_amounts' are read
        results_save_dir : str
            directory to save results to
        data_loader : torch.utils.data.DataLoader
            validation data loader to obtain metrics for
        model : torch.nn.Module
            model
        loss_fn : module
            loss function
        metric_fns : list[model.metric modules]
            list of metric functions
        device : torch.device
            GPU device to run model on
        cushion : int
            how far away cur_n_target_predictions can be from n_target_predictions on either side to break loop (buffer)
        n_stop : int or None
            how many consecutive iterations stuck at the same cur_n_target_predictions until
            the loop is abandoned (None disables the check)
        debug : bool
            control verbosity

    Returns:
        (cur_bump_amount, cur_n_target_predictions, log)
            cur_bump_amount : float
                last bump amount tried
            cur_n_target_predictions : int
                target class predictions obtained with the last evaluated bump
                (0 if no bump was evaluated)
            log : dict or None
                value returned by predict_with_bump() for the last evaluated bump
                (None if no bump was evaluated)
    '''
    # Unpack bumps_preds_metrics
    pre_edit_metrics = bumps_preds_metrics['pre_edit_metrics']
    target_class_predictions = bumps_preds_metrics['target_class_predictions']
    bump_amounts = bumps_preds_metrics['bump_amounts']

    # Obtain number of total predictions and assert n_target_predictions does not exceed it
    n_predictions_total = np.sum(pre_edit_metrics['predicted_class_distribution'])
    assert n_target_predictions <= n_predictions_total, \
        "n_target_predictions ({}) must be less than total number of data in dataloader ({})".format(n_target_predictions, n_predictions_total)

    # Find the first recorded entry whose prediction count exceeds n_target_predictions
    bin_high_idx = -1
    for bin_idx, target_class_prediction in enumerate(target_class_predictions):
        if target_class_prediction > n_target_predictions:
            bin_high_idx = bin_idx
            break

    if bin_high_idx == -1:
        # No recorded entry exceeds the target: search between the last bump and twice its value
        bump_amount_lower_bound = bump_amounts[-1]
        bump_amount_upper_bound = bump_amounts[-1] * 2
    elif bin_high_idx == 0:
        # Even the first recorded entry exceeds the target: fall back to a fixed lower bound
        bump_amount_lower_bound = -10
        bump_amount_upper_bound = bump_amounts[0]
    else:
        # Target lies between two recorded entries: bracket it with their bump amounts
        bump_amount_lower_bound = bump_amounts[bin_high_idx - 1]
        bump_amount_upper_bound = bump_amounts[bin_high_idx]

    if debug:
        print("target n_predictions: {}".format(n_target_predictions))
        print("Initial bounds for bump: ({}, {})".format(bump_amount_lower_bound, bump_amount_upper_bound))

    # Initialise everything that is returned so no exit path leaves a name unbound
    cur_bump_amount = None
    cur_n_target_predictions = 0
    log = None
    prev_n_target_predictions = None
    n_same = 0

    # Always evaluate at least one bump: the initial count of 0 is a placeholder, not a measurement
    while True:
        # Update bump amount to the midpoint of the current bounds
        cur_bump_amount = (bump_amount_lower_bound + bump_amount_upper_bound) / 2
        # Check before we undergo an infinite loop
        if cur_bump_amount == 0:
            print("cur_bump_amount is 0, exiting loop")
            break

        # predict using logit bump
        log = predict_with_bump(
                data_loader=data_loader,
                model=model,
                loss_fn=loss_fn,
                metric_fns=metric_fns,
                device=device,
                target_class_idx=target_class_idx,
                bump_amount=cur_bump_amount,
                output_save_path=os.path.join(results_save_dir, "post_edit_logits.pth"),
                log_save_path=os.path.join(results_save_dir, "post_edit_metrics.pth"))

        # Obtain num. predictions for target class
        post_class_distribution = log['predicted_class_distribution']
        cur_n_target_predictions = post_class_distribution[target_class_idx]
        if debug:
            print("cur_bump_amount: {}, cur_n_target_predictions: {}".format(cur_bump_amount, cur_n_target_predictions))

        # Update bump bounds of binary search: too many predictions -> lower the upper bound,
        # too few -> raise the lower bound
        if cur_n_target_predictions > n_target_predictions:
            bump_amount_upper_bound = cur_bump_amount
            if debug:
                print("Updated upper bound to {}".format(bump_amount_upper_bound))
        elif cur_n_target_predictions < n_target_predictions:
            bump_amount_lower_bound = cur_bump_amount
            if debug:
                print("Updated lower bound to {}".format(bump_amount_lower_bound))

        # Done once the count is within the allowed buffer of the target
        if abs(cur_n_target_predictions - n_target_predictions) <= cushion:
            break

        # Give up when the count has not moved for n_stop consecutive iterations; without this
        # the search never terminates if no bump can reach the cushion
        if cur_n_target_predictions == prev_n_target_predictions:
            n_same += 1
        else:
            n_same = 1
            prev_n_target_predictions = cur_n_target_predictions
        if n_stop is not None and n_same >= n_stop:
            print("Stuck at {} target predictions for {} iterations, exiting loop".format(
                cur_n_target_predictions, n_same))
            break

    if debug:
        print("final results: bump amount: {} n_target_predictions: {}".format(cur_bump_amount, cur_n_target_predictions))
    # if results_save_dir is not None:
    #     torch.save(log, os.path.join(results_save_dir, "post_edit_metrics.pth"))
        # torch.save(logits, os.path.join(results_save_dir, "post_edit_logits.pth"))
    return cur_bump_amount, cur_n_target_predictions, log
    
    

In [9]:
# # testing

# bumps_preds_metrics_path = os.path.join(
#     "saved/edit/experiments/bump_edits",
#     "bumps_preds_metrics",
#     "{}_{}_bumps_preds_metrics.pth".format(target_class_name, n_select))
# bumps_preds_metrics = torch.load(bumps_preds_metrics_path)

# n_target_predictions = 65000
# results_save_dir = os.path.join('temp')
# match_bump(n_target_predictions=n_target_predictions,
#            target_class_idx=target_class_idx,
#            bumps_preds_metrics=bumps_preds_metrics,
#            results_save_dir=results_save_dir,
#            data_loader=val_paths_data_loader,
#            model=model,
#            loss_fn=loss_fn,
#            metric_fns=metric_fns,
#            device=device,
#            cushion=5,
#            debug=True)

In [20]:
n_edits = 5  # NOTE(review): not referenced anywhere in the visible notebook
cushion = 10  # tolerance passed to match_bump and used in the sanity-check assert below
# Read the list of trial directories of the original edits for the target class
original_trials_trial_paths_path = os.path.join(root_dir.format(target_class_name), 'trial_paths.txt')
original_trial_paths = read_lists(original_trials_trial_paths_path)
# common_path is stripped from each trial path below to derive the trial ID
common_path = get_common_dir_path(original_trial_paths)
print("Length of trial paths: {}".format(len(original_trial_paths)))

# Data structure to store how much to bump for each n_predictions in target class
bump_amount_dictionary_path = os.path.join(
    'metadata',
    config_dict['name'], # CINIC10_ImageNet-VGG16
    'bump_amounts',
    '{}_{}'.format(target_class_name, n_select),
    'logit_bump_buffer_{}_dict.pth'.format(cushion))

# Reuse the bump amounts from a previous run if the file exists, otherwise start empty
if os.path.isfile(bump_amount_dictionary_path):
    bump_amount_dictionary = torch.load(bump_amount_dictionary_path)
else:
    bump_amount_dictionary = {}
ensure_dir(os.path.dirname(bump_amount_dictionary_path))

    
# Store histogram information
bumps_preds_metrics_path = os.path.join(
    "saved/edit/experiments/bump_edits",
    "bumps_preds_metrics",
    "{}_{}_bumps_preds_metrics.pth".format(target_class_name, n_select))
bumps_preds_metrics = torch.load(bumps_preds_metrics_path)

# Create directories and paths
result_root = os.path.join(
        os.path.dirname(os.path.dirname(config.save_dir)), 
        save_timestamp, 
        '{}_{}'.format(target_class_name, n_select))
ensure_dir(result_root)
progress_report_path = os.path.join(result_root, 'progress_reports.txt')
# Create file to store paths
logit_bump_trial_paths_path = os.path.join(result_root, 'trial_paths.txt')
# NOTE(review): the existing trial_paths.txt is deleted on every run, including runs where the
# loop below skips the first trials to resume -- the skipped trials' paths are then lost.
if os.path.isfile(logit_bump_trial_paths_path):
    os.remove(logit_bump_trial_paths_path)

# Iterate through all trial paths
# Trials with an index <= RESUME_AFTER_TRIAL_IDX are skipped (resume point of an earlier run).
# Set to -1 to process every trial.
RESUME_AFTER_TRIAL_IDX = 151

# Iterate through all trial paths
for trial_idx, trial_path in enumerate(original_trial_paths):
    if trial_idx <= RESUME_AFTER_TRIAL_IDX:
        continue
    # Obtain trial ID and create save directory for logit bump results
    trial_id = trial_path[len(common_path)+1:]
    logit_bump_trial_save_dir = os.path.join(
        result_root,
        'results',
        trial_id,
        'models')
    ensure_dir(logit_bump_trial_save_dir)

    # Obtain desired n_target_predictions from the original edit's post edit metrics
    trial_post_edit_metrics_path = os.path.join(
        trial_path,
        'models',
        'post_edit_metrics.pth')
    trial_post_edit_metrics = torch.load(trial_post_edit_metrics_path)
    n_target_predictions = trial_post_edit_metrics['predicted_class_distribution'][target_class_idx]

    # If bump amount is in dictionary, obtain that value
    if n_target_predictions in bump_amount_dictionary:
        informal_log("[{}] Found corresponding logit bump for {} in dictionary ({}/{})".format(
            datetime.now().strftime(r'%m%d_%H%M%S'),
            trial_id,
            trial_idx + 1,
            len(original_trial_paths)), progress_report_path)
        bump_amount = bump_amount_dictionary[n_target_predictions]

        log = predict_with_bump(
            data_loader=val_paths_data_loader,
            model=model,
            loss_fn=loss_fn,
            metric_fns=metric_fns,
            device=device,
            target_class_idx=target_class_idx,
            bump_amount=bump_amount,
            output_save_path=os.path.join(logit_bump_trial_save_dir, "post_edit_logits.pth"),
            log_save_path=os.path.join(logit_bump_trial_save_dir, "post_edit_metrics.pth"))
        # Sanity check that the n_target_predictions with bump is within buffer of target
        bump_n_target_predictions = log['predicted_class_distribution'][target_class_idx]
        assert abs(n_target_predictions - bump_n_target_predictions) <= cushion

    else:
        informal_log("[{}] Creating corresponding logit bump for {} ({}/{})".format(
            datetime.now().strftime(r'%m%d_%H%M%S'),
            trial_id,
            trial_idx + 1,
            len(original_trial_paths)), progress_report_path)
        # Find corresponding bump amount. n_target_predictions is rebound to the count the
        # search actually reached, and that reached count is the key cached below.
        bump_amount, n_target_predictions, metrics = \
            match_bump(
                n_target_predictions=n_target_predictions,
                target_class_idx=target_class_idx,
                bumps_preds_metrics=bumps_preds_metrics,
                results_save_dir=logit_bump_trial_save_dir,
                data_loader=val_paths_data_loader,
                model=model,
                loss_fn=loss_fn,
                metric_fns=metric_fns,
                device=device,
                cushion=cushion,
                debug=True)
        bump_amount_dictionary[n_target_predictions] = bump_amount

    # Write trial path to list
    informal_log(os.path.dirname(logit_bump_trial_save_dir), logit_bump_trial_paths_path)

    # Periodically checkpoint the bump dictionary on file
    if trial_idx % 10 == 0:
        torch.save(bump_amount_dictionary, bump_amount_dictionary_path)

# Final save: the periodic checkpoint only fires on multiples of 10, so bump amounts found in
# the trailing trials (e.g. indices 152-157) would otherwise never reach disk.
torch.save(bump_amount_dictionary, bump_amount_dictionary_path)

Length of trial paths: 158
[0217_140839] Creating corresponding logit bump for airplane-train-n02704645_17657/felzenszwalb_gaussian_softmax (153/158)
target n_predictions: 8733
Initial bounds for bump: (-10, 0.2)


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:30<00:00,  8.98it/s]


cur_bump_amount: -4.9, cur_n_target_predictions: 5558
Updated lower bound to -4.9


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:26<00:00, 10.23it/s]


cur_bump_amount: -2.35, cur_n_target_predictions: 7231
Updated lower bound to -2.35


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 39.05it/s]


cur_bump_amount: -1.075, cur_n_target_predictions: 8080
Updated lower bound to -1.075


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 38.34it/s]


cur_bump_amount: -0.4375, cur_n_target_predictions: 8595
Updated lower bound to -0.4375


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 36.87it/s]


cur_bump_amount: -0.11875, cur_n_target_predictions: 8906
Updated upper bound to -0.11875


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 35.26it/s]


cur_bump_amount: -0.278125, cur_n_target_predictions: 8734
Updated upper bound to -0.278125
final results: bump amount: -0.278125 n_target_predictions: 8734
saved/edit/experiments/corresponding_bump_edits/CINIC10_ImageNet-VGG_16/0213_143741/airplane_100/results/airplane-train-n02704645_17657/felzenszwalb_gaussian_softmax
[0217_141006] Creating corresponding logit bump for airplane-train-n03595860_736/felzenszwalb_masked_softmax (154/158)
target n_predictions: 18439
Initial bounds for bump: (3.637978807091713, 3.751665644813329)


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 37.47it/s]


cur_bump_amount: 3.694822225952521, cur_n_target_predictions: 18554
Updated upper bound to 3.694822225952521


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 39.52it/s]


cur_bump_amount: 3.666400516522117, cur_n_target_predictions: 18395
Updated lower bound to 3.666400516522117


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 37.35it/s]


cur_bump_amount: 3.680611371237319, cur_n_target_predictions: 18475
Updated upper bound to 3.680611371237319


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 36.99it/s]


cur_bump_amount: 3.673505943879718, cur_n_target_predictions: 18431
Updated lower bound to 3.673505943879718
final results: bump amount: 3.673505943879718 n_target_predictions: 18431
saved/edit/experiments/corresponding_bump_edits/CINIC10_ImageNet-VGG_16/0213_143741/airplane_100/results/airplane-train-n03595860_736/felzenszwalb_masked_softmax
[0217_141036] Creating corresponding logit bump for airplane-train-n03595860_736/felzenszwalb_gaussian_softmax (155/158)
target n_predictions: 17151
Initial bounds for bump: (3.2741809263825417, 3.4788172342814505)


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 37.68it/s]


cur_bump_amount: 3.376499080331996, cur_n_target_predictions: 16785
Updated lower bound to 3.376499080331996


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 37.15it/s]


cur_bump_amount: 3.4276581573067233, cur_n_target_predictions: 17049
Updated lower bound to 3.4276581573067233


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 35.36it/s]


cur_bump_amount: 3.453237695794087, cur_n_target_predictions: 17184
Updated upper bound to 3.453237695794087


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:08<00:00, 34.17it/s]


cur_bump_amount: 3.440447926550405, cur_n_target_predictions: 17120
Updated lower bound to 3.440447926550405


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 39.60it/s]


cur_bump_amount: 3.446842811172246, cur_n_target_predictions: 17153
Updated upper bound to 3.446842811172246
final results: bump amount: 3.446842811172246 n_target_predictions: 17153
saved/edit/experiments/corresponding_bump_edits/CINIC10_ImageNet-VGG_16/0213_143741/airplane_100/results/airplane-train-n03595860_736/felzenszwalb_gaussian_softmax
[0217_141114] Creating corresponding logit bump for airplane-train-n04160586_8239/felzenszwalb_masked_softmax (156/158)
target n_predictions: 25238
Initial bounds for bump: (4.604316927725449, 4.7482018317168695)


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 39.95it/s]


cur_bump_amount: 4.676259379721159, cur_n_target_predictions: 25292
Updated upper bound to 4.676259379721159


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:06<00:00, 40.07it/s]


cur_bump_amount: 4.640288153723304, cur_n_target_predictions: 25024
Updated lower bound to 4.640288153723304


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 39.01it/s]


cur_bump_amount: 4.658273766722232, cur_n_target_predictions: 25145
Updated lower bound to 4.658273766722232


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 35.34it/s]


cur_bump_amount: 4.667266573221696, cur_n_target_predictions: 25230
Updated lower bound to 4.667266573221696
final results: bump amount: 4.667266573221696 n_target_predictions: 25230
saved/edit/experiments/corresponding_bump_edits/CINIC10_ImageNet-VGG_16/0213_143741/airplane_100/results/airplane-train-n04160586_8239/felzenszwalb_masked_softmax
[0217_141143] Creating corresponding logit bump for airplane-train-n02691156_6453/felzenszwalb_masked_softmax (157/158)
target n_predictions: 17487
Initial bounds for bump: (3.4788172342814505, 3.637978807091713)


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:07<00:00, 37.93it/s]


cur_bump_amount: 3.5583980206865817, cur_n_target_predictions: 17773
Updated upper bound to 3.5583980206865817


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:08<00:00, 33.65it/s]


cur_bump_amount: 3.518607627484016, cur_n_target_predictions: 17533
Updated upper bound to 3.518607627484016


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 25.72it/s]


cur_bump_amount: 3.4987124308827333, cur_n_target_predictions: 17418
Updated lower bound to 3.4987124308827333


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:10<00:00, 26.66it/s]


cur_bump_amount: 3.5086600291833747, cur_n_target_predictions: 17481
Updated lower bound to 3.5086600291833747
final results: bump amount: 3.5086600291833747 n_target_predictions: 17481
saved/edit/experiments/corresponding_bump_edits/CINIC10_ImageNet-VGG_16/0213_143741/airplane_100/results/airplane-train-n02691156_6453/felzenszwalb_masked_softmax
[0217_141220] Found corresponding logit bump for airplane-train-n02691156_6453/felzenszwalb_gaussian_softmax in dictionary (158/158)


100%|█████████████████████████████████████████████████████████████████████| 274/274 [00:09<00:00, 28.63it/s]

saved/edit/experiments/corresponding_bump_edits/CINIC10_ImageNet-VGG_16/0213_143741/airplane_100/results/airplane-train-n02691156_6453/felzenszwalb_gaussian_softmax





In [14]:
# Save the bump-amount dictionary to disk at `bump_amount_dictionary_path`.
# The parent directory is created with the standard library rather than the
# project helper `ensure_dir`, so this cell does not depend on the local
# `utils` import being active in the kernel.
bump_amount_dictionary_dir = os.path.dirname(bump_amount_dictionary_path)
# `dirname` is '' when the path is a bare filename; os.makedirs('') would raise,
# and in that case the current directory already exists.
if bump_amount_dictionary_dir:
    os.makedirs(bump_amount_dictionary_dir, exist_ok=True)
torch.save(bump_amount_dictionary, bump_amount_dictionary_path)

### Replicate bump results on a coarse bin level (repeat bins)