## Create Edits that match post edit class distribution of target class to real edits

In [62]:
# General imports
import torch
import numpy as np
import os, sys
import json
from tqdm import tqdm
import matplotlib.pyplot as plt
import math
from sklearn import metrics
from scipy import stats
from datetime import datetime
import collections

import warnings
from argparse import Namespace
warnings.filterwarnings("ignore")

from torchvision import transforms
from PIL import Image

from helpers import classifier_helpers
import helpers.data_helpers as dh
import helpers.context_helpers as coh
import helpers.rewrite_helpers as rh
import helpers.vis_helpers as vh

%matplotlib inline

In [2]:
sys.path.insert(0, '/n/fs/ac-editing/model-editing/src')
print(sys.path)
import model.loss as module_loss
import model.metric as module_metric
from utils.model_utils import prepare_device

['/n/fs/ac-editing/model-editing/src', '/n/fs/ac-editing/model-editing/external_code/EditingClassifiersRepo', '/n/fs/ac-project/anaconda3/envs/editing/lib/python38.zip', '/n/fs/ac-project/anaconda3/envs/editing/lib/python3.8', '/n/fs/ac-project/anaconda3/envs/editing/lib/python3.8/lib-dynload', '', '/u/ac4802/.local/lib/python3.8/site-packages', '/n/fs/ac-project/anaconda3/envs/editing/lib/python3.8/site-packages']


In [3]:
DATASET_NAME = 'ImageNet'
LAYERNUM = 12
REWRITE_MODE = 'editing'
ARCH = 'vgg16'

### Load imagenet dataset

In [152]:
base_dataset, train_loader, val_loader = dh.get_dataset(DATASET_NAME, DATASET_PATH,
                                                        batch_size=32, workers=8)

==> Preparing dataset imagenet..


## Load model

In [5]:
ret = classifier_helpers.get_default_paths(DATASET_NAME, arch=ARCH)
DATASET_PATH, MODEL_PATH, MODEL_CLASS, ARCH, CD = ret
ret = classifier_helpers.load_classifier(MODEL_PATH, MODEL_CLASS, ARCH,
                            DATASET_NAME, LAYERNUM) 
model, context_model, target_model = ret[:3]
n_classes = len(CD) - 1

### Load test data

In [6]:
train_data, test_data = dh.get_vehicles_on_snow_data(DATASET_NAME, CD)

Test data stats...
ImageNet class: school bus; # Images: 22 

ImageNet class: motor scooter, scooter; # Images: 21 

ImageNet class: traffic light, traffic signal, stoplight; # Images: 9 

ImageNet class: fire engine, fire truck; # Images: 20 

ImageNet class: tank, army tank, armored combat vehicle, armoured combat vehicle; # Images: 17 

ImageNet class: racer, race car, racing car; # Images: 20 

ImageNet class: car wheel; # Images: 20 



In [12]:
# Load post edit counter
restore_dir = os.path.join('edited_checkpoints', 'vehicles_on_snow')
post_edit_counter = torch.load(os.path.join(restore_dir, 'post_edit_counter.pth'))
pre_edit_counter = torch.load(os.path.join(restore_dir, 'pre_edit_counter.pth'))

print(pre_edit_counter, post_edit_counter)
# Sort classes based on largest change
all_classes = list(set(list(pre_edit_counter.keys()) + list(post_edit_counter.keys())))
deltas = []
for class_idx in all_classes:
    pre_edit_count = pre_edit_counter[class_idx] if class_idx in pre_edit_counter else 0
    post_edit_count = post_edit_counter[class_idx] if class_idx in post_edit_counter else 0
    deltas.append(np.abs(post_edit_count - pre_edit_count))
print(deltas)
# Sorts in ascending order
sorted_idxs = np.argsort(deltas)
sorted_idxs = np.flip(sorted_idxs)

# Two lists with corresponding classes and amount they changed by
deltas_sorted = np.array(deltas)[sorted_idxs]
classes_sorted = np.array(all_classes)[sorted_idxs]

Counter({803: 23, 802: 21, 555: 18, 779: 14, 847: 8, 751: 8, 479: 8, 920: 6, 670: 5, 408: 4, 609: 2, 829: 1, 665: 1, 450: 1, 348: 1, 928: 1, 471: 1, 874: 1, 586: 1, 866: 1, 961: 1, 717: 1, 643: 1}) Counter({779: 17, 555: 17, 751: 17, 803: 15, 670: 9, 847: 9, 479: 9, 802: 8, 920: 7, 408: 3, 665: 2, 609: 2, 866: 2, 829: 1, 450: 1, 348: 1, 928: 1, 251: 1, 471: 1, 867: 1, 874: 1, 586: 1, 961: 1, 717: 1, 643: 1})
[0, 3, 1, 1, 1, 4, 0, 13, 8, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 9, 1]


### Get device and metric functions

In [7]:
metrics = ['accuracy', 'per_class_accuracy', 'recall', 'precision', 'f1', 'predicted_class_distribution']
metric_fns = [getattr(module_metric, met) for met in metrics]

device, device_ids = prepare_device(1)

### Define functions

In [8]:
def predict_with_bump(data_loader,
                      model,
                      loss_fn,
                      metric_fns,
                      device,
                      # target_class_idx=0,
                      bump_amounts,
                      output_save_path=None,
                      log_save_path=None):
    '''
    Run the model on the data_loader, calculate metrics, and log

    Arg(s):
        data_loader : torch Dataloader
            data to test on
        model : torch.nn.Module
            model to run
        loss_fn : module
            loss function
        metric_fns : list[model.metric modules]
            list of metric functions
        device : torch.device
            device to move data to
        bump_amounts :  np.array(float)
            array of bump amounts 
        output_save_path : str or None
            if not None, save model_outputs to save_path
        log_save_path : str or None
            if not None, save metrics to save_path

    Returns :
        log : dict{} of metrics
    '''

    # Hold data for calculating metrics
    outputs = []
    targets = []

    # Ensure model is in eval mode
    if model.training:
        model.eval()

    with torch.no_grad():
        for idx, item in enumerate(tqdm(data_loader)):
            if len(item) == 3:
                data, target, path = item
            else:
                data, target = item
            data, target = data.to(device), target.to(device)
            output = model(data)
            assert output.shape[1] == bump_amounts.shape[0], \
                "Logits shape: {} bump_amounts shape: {}".format(output.shape, bump_amounts.shape)
            # Store outputs and targets
            outputs.append(output)
            targets.append(target)
            # if idx == 1:
            #     break

    # Concatenate predictions and targets
    outputs = torch.cat(outputs, dim=0)
    targets = torch.cat(targets, dim=0)

    # Adjust output softmax by bump amount
    # if is instance(bump_amount, np.array):
    outputs += torch.tensor(bump_amounts).to(device)
    # else: 
    #     outputs[:, target_class_idx] += bump_amount

    # Calculate loss
    if loss_fn is not None:
        loss = loss_fn(outputs, targets).item()
        log.update({'loss': loss})
    n_samples = len(data_loader.sampler)

    # Calculate predictions based on argmax
    predictions = torch.argmax(outputs, dim=1)

    # Move predictions and target to cpu and convert to numpy to calculate metrics
    predictions = predictions.cpu().numpy()
    targets = targets.cpu().numpy()

    # Calculate metrics
    log = module_metric.compute_metrics(
        metric_fns=metric_fns,
        prediction=predictions,
        target=targets)

    # Add bump amount to log
    
    log.update({'bump_amounts': bump_amounts})

    if output_save_path is not None:
        ensure_dir(os.path.dirname(output_save_path))
        torch.save(outputs, output_save_path)

    if log_save_path is not None:
        ensure_dir(os.path.dirname(log_save_path))
        torch.save(log, log_save_path)

    return {
        'metrics': log,
        'logits': outputs
    }


In [50]:
def quick_predict_with_bump(bump_amounts,
                      data,
                      model,
                      debug=True):
    
    '''
    Arg(s):
        bump_amounts : C length np.array
            bump amount for each class
        data : (list(int), B x Ch x H x W torch.tensor)
            labels and image data tensors
        model : torch.nn.Module
            model to make prediction
    '''
    predictions = []
    accuracies = []
    all_logits = []
    total_correct = 0
    if not torch.is_tensor(bump_amounts):
        bump_amounts = torch.tensor(bump_amounts)
    bump_amounts = bump_amounts.cuda()
        
    for c, x in data.items():
        with torch.no_grad():
            logits = model(x.cuda())
            assert logits.shape[1] == bump_amounts.shape[0], "Logits shape: {} bump_amounts shape: {}".format(logits.shape, bump_amounts.shape)
            pred = logits.argmax(axis=1)
            predictions.append(pred)
            all_logits.append(logits)
        correct = [p for p in pred if p == c]
        total_correct += len(correct)
        acc = 100 * len(correct) / len(x)
        if debug:
            print(f'Class: {c}/{CD[c]} | Accuracy: {acc:.2f}',) 
        accuracies.append(acc)
    predictions = torch.cat(predictions, dim=0)
    all_logits = torch.cat(all_logits, dim=0)
    counter = collections.Counter(predictions.tolist())
    results = {
        'predictions': predictions,
        'logits': all_logits,
        'counter': counter,
        'accuracies': accuracies,
        'n_correct': total_correct
    }
    return results
        
def find_bump_bounds(n_target_predictions,
                    target_class_idx,
                    n_classes,
                    test_data,
                    model,
                    delta=2,
                    debug=True):
    # Get baseline predictions
    bump_amounts = np.zeros(n_classes)
    baseline_predictions, baseline_counter, _ = quick_predict_with_bump(
        bump_amounts=bump_amounts,
        data=test_data,
        model=model,
        debug=debug)
    
    cur_target_predictions = baseline_counter[target_class_idx]
    if debug:
        print("Baseline target predictions: {} Target n predictions: {}".format(
            cur_target_predictions, n_target_predictions))
    
    # If no need to bump, return now!
    if cur_target_predictions == n_target_predictions:
        return 0, 0
    elif cur_target_predictions < n_target_predictions:
        
        bump_lower_bound = 0
        target_predictions_lower = cur_target_predictions
        # Find upper bound
        if debug:
            print("Searching for upper bound...")
        while cur_target_predictions < n_target_predictions:
            bump_amounts[target_class_idx] += delta
            if debug:
                print("cur bump amount: {}".format(bump_amounts[target_class_idx]))
            cur_predictions, cur_counter, _ = quick_predict_with_bump(
                bump_amounts=bump_amounts,
                data=test_data,
                model=model,
                debug=debug)
            cur_target_predictions = cur_counter[target_class_idx]
            if debug:
                print("cur target_predictions: {} target n predictions: {}".format(
                    cur_target_predictions, n_target_predictions))
        bump_upper_bound = bump_amounts[target_class_idx]
        target_predictions_upper = cur_target_predictions
        
    elif cur_target_predictions > n_target_predictions:
        bump_upper_bound = 0
        target_predictions_upper = cur_target_predictions
        # Find lower bound
        
        if debug:
            print("Searching for lower bound...")
        while cur_target_predictions > n_target_predictions:
            bump_amounts[target_class_idx] -= delta
            if debug:
                print("cur bump amount: {}".format(bump_amounts[target_class_idx]))
            cur_predictions, cur_counter, _ = quick_predict_with_bump(
                bump_amounts=bump_amounts,
                data=test_data,
                model=model,
                debug=debug)
            cur_target_predictions = cur_counter[target_class_idx]
            if debug:
                print("cur target_predictions: {} target n predictions: {}".format(
                    cur_target_predictions, n_target_predictions))
        bump_lower_bound = bump_amounts[target_class_idx]
        target_predictions_lower = cur_target_predictions
    
    print("lower bound: {} target class predictions: {}".format(
        bump_lower_bound, target_predictions_lower))
    print("upper bound: {} target class predictions: {}".format(
        bump_upper_bound, target_predictions_upper))
    return bump_lower_bound, bump_upper_bound
    
def bump_amount_binary_search(n_target_predictions,
                              target_class_idx,
                              n_classes,
                              bump_lower_bound,
                              bump_upper_bound,
                              cushion,
                              test_data,
                              model,
                              debug=True):
    '''
    Perform binary search to find the right bump amount
    '''
    bump_amounts = np.zeros(n_classes)
    baseline_predictions, baseline_counter, _ = quick_predict_with_bump(
        bump_amounts=bump_amounts,
        data=test_data,
        model=model,
        debug=debug)
    
    cur_target_predictions = baseline_counter[target_class_idx]
    
    while abs(cur_target_predictions - n_target_predictions) > cushion:
        # Update bump amount
        cur_bump_amount = (bump_lower_bound + bump_upper_bound) / 2
        # Check before we undergo an infinite loop
        if cur_bump_amount == 0:
            print("cur_bump_amount is 0, exiting loop")
            break
            
        # predict using logit bump
        bump_amounts[target_class_idx] = cur_bump_amount
        cur_predictions, cur_counter, _ = quick_predict_with_bump(
            bump_amounts=bump_amounts,
            data=test_data,
            model=model,
            debug=debug)

        # Obtain num. predictions for target class and determine bin idx
        # post_class_distribution = log['predicted_class_distribution']
        cur_target_predictions = cur_counter[target_class_idx]
        if debug:
            print("cur_bump_amount: {}, cur_target_predictions: {}".format(cur_bump_amount, cur_target_predictions))
        
        # Update bump bounds of binary search
        if cur_target_predictions > n_target_predictions:
            bump_upper_bound = cur_bump_amount
            if debug:
                print("Updated upper bound to {}".format(bump_upper_bound))
        elif cur_target_predictions < n_target_predictions:
            bump_lower_bound = cur_bump_amount
            if debug:
                print("Updated lower bound to {}".format(bump_lower_bound))
    return bump_amounts

In [10]:
def match_bump(#pre_edit_counter,
               #post_edit_counter,
               n_classes,
               n_target_predictions,
               target_class_idx,
               # bump_amount_lower_bound,
               # bump_amount_upper_bound,
               # bumps_preds_metrics,
               results_save_dir,
               # data_loader,
               test_data,
               model,
               # loss_fn,
               # metric_fns,
               # device,
               cushion=5,
               # n_stop=10,
               debug=True):
    '''
    Given a number of predictions for target class, obtain bump amount to match and save post edit metrics in results_save_dir
    
    Arg(s):
        n_target_predictions : int
            Number of predictions to obtain for target class
        target_class_idx : int
            index of target class
        bumps_preds_metrics : dict
            saved data from match_bump_edits()
        results_save_dir : str
            directory to save results to
        data_loader : torch.utils.data.DataLoader
            validation data loader to obtain metrics for
        model : torch.nn.Module
            model
        loss_fn : module
            loss function
        metric_fns : list[model.metric modules]
            list of metric functions
        device : torch.device
            GPU device to run model on
        cushion : int
            how far away cur_n_target_predictions can be from n_target_predictions on either side to break loop (buffer)
        n_stop : int
            how many iterations when stuck at the same cur_n_target_predictions until to break the loop
        debug : bool
            control verbosity
        
            
    Returns: 
        None
    '''
    if debug:
        print("Target class idx: {}".format(target_class_idx))
    empty_bumps = np.zeros(n_classes)

    bump_lower_bound, bump_upper_bound = find_bump_bounds(
        n_target_predictions=n_target_predictions,
        target_class_idx=target_class_idx,
        n_classes=n_classes,
        test_data=test_data,
        model=model,
        debug=debug)
        
    bump_amounts = bump_amount_binary_search(
        n_target_predictions=n_target_predictions,
        target_class_idx=target_class_idx,
        n_classes=n_classes,
        bump_lower_bound=bump_lower_bound,
        bump_upper_bound=bump_upper_bound,
        cushion=cushion,
        test_data=test_data,
        model=model,
        debug=debug)
    
    return bump_amounts

### For each class with significant change on the test set, figure out how much to bump to achieve same # of predictions on test set

In [181]:
selected_idxs = np.where(deltas_sorted > 0)[0]
selected_deltas = deltas_sorted[selected_idxs]
selected_classes = classes_sorted[selected_idxs]

accumulated_bump_amounts = np.zeros(n_classes)
for selected_class in selected_classes:
    print("Finding bump for {} class ({})".format(CD[selected_class], selected_class))
    bump_amounts = match_bump(
        n_classes=n_classes,
        n_target_predictions=post_edit_counter[selected_class],
        target_class_idx=selected_class,
        # bumps_preds_metrics,
        results_save_dir=None,
        # data_loader,
        test_data=test_data,
        cushion=0,
        model=model,
        debug=False)

    pre_predictions, pre_counter, pre_accuracies = quick_predict_with_bump(
        bump_amounts=np.zeros(n_classes),
        data=test_data,
        model=model,
        debug=False)

    post_predictions, post_counter, post_accuracies = quick_predict_with_bump(
        bump_amounts=bump_amounts,
        data=test_data,
        model=model,
        debug=False)
    print("Target class: {} ({})".format(CD[selected_class], selected_class))
    print("Pre prediction bumps: {} Post bump # predictions: {}".format(
        pre_counter[selected_class],
        post_counter[selected_class]))
    print("Bump amount: {}".format(bump_amounts[selected_class]))

    for idx, c in enumerate(test_data):
        print(f'Class: {c}/{CD[c]} | Accuracy: {pre_accuracies[idx]:.2f} -> {post_accuracies[idx]:.2f}',) 
    print("---***---\n")
    
    accumulated_bump_amounts += bump_amounts


Finding bump for snowmobile class (802)
lower bound: -2.0 target class predictions: 6
upper bound: 0 target class predictions: 21
Target class: snowmobile (802)
Pre prediction bumps: 21 Post bump # predictions: 8
Bump amount: -1.5
Class: 779/school bus | Accuracy: 63.64 -> 63.64
Class: 670/motor scooter, scooter | Accuracy: 23.81 -> 38.10
Class: 920/traffic light, traffic signal, stoplight | Accuracy: 66.67 -> 66.67
Class: 555/fire engine, fire truck | Accuracy: 90.00 -> 90.00
Class: 847/tank, army tank, armored combat vehicle, armoured combat vehicle | Accuracy: 47.06 -> 47.06
Class: 751/racer, race car, racing car | Accuracy: 40.00 -> 65.00
Class: 479/car wheel | Accuracy: 40.00 -> 40.00
---***---

Finding bump for racer, race car, racing car class (751)
lower bound: 0 target class predictions: 8
upper bound: 2.0 target class predictions: 18
Target class: racer, race car, racing car (751)
Pre prediction bumps: 8 Post bump # predictions: 17
Bump amount: 1.875
Class: 779/school bus | A

### Compare the accumulated logit bumping results to the edited model

In [11]:
acc_predictions, acc_counter, acc_accuracies = quick_predict_with_bump(
    bump_amounts=accumulated_bump_amounts,
    data=test_data,
    model=model,
    debug=False)
for idx, c in enumerate(test_data):
        print(f'Class: {c}/{CD[c]} \n Accuracy change: {pre_accuracies[idx]:.2f} -> {acc_accuracies[idx]:.2f}\n',) 

NameError: name 'accumulated_bump_amounts' is not defined

## Load accumulated bump amounts and edited model to run on test set

In [17]:
logit_calibration_path = os.path.join(restore_dir, 'logit_calibration.pth')
accumulated_bump_amounts = torch.load(logit_calibration_path)
print(accumulated_bump_amounts.shape)

edited_model_path = os.path.join(restore_dir, 'edited_imagenet_vgg.pt.best')
edited_ret = classifier_helpers.get_default_paths(DATASET_NAME, arch=ARCH)
DATASET_PATH, MODEL_PATH, MODEL_CLASS, ARCH, CD = edited_ret
edited_ret = classifier_helpers.load_classifier(MODEL_PATH, MODEL_CLASS, ARCH,
                            DATASET_NAME, LAYERNUM) 

edited_model, edited_context_model, edited_target_model = edited_ret[:3]
edited_state_dict = torch.load(edited_model_path)
edited_model.load_state_dict(edited_state_dict)

(1000,)


<All keys matched successfully>

In [51]:
print("Original model: ")
original_results = quick_predict_with_bump(
    bump_amounts=np.zeros_like(accumulated_bump_amounts),
    data=test_data,
    model=model,
    debug=True)

print("Calibrated model: ")
calibrated_results = quick_predict_with_bump(
    bump_amounts=accumulated_bump_amounts,
    data=test_data,
    model=model,
    debug=True)

print("Edited model: ")
edited_results = quick_predict_with_bump(
    bump_amounts=np.zeros_like(accumulated_bump_amounts),
    data=test_data,
    model=edited_model,
    debug=True)



Class: 779/school bus | Accuracy: 63.64
Class: 670/motor scooter, scooter | Accuracy: 23.81
Class: 920/traffic light, traffic signal, stoplight | Accuracy: 66.67
Class: 555/fire engine, fire truck | Accuracy: 90.00
Class: 847/tank, army tank, armored combat vehicle, armoured combat vehicle | Accuracy: 47.06
Class: 751/racer, race car, racing car | Accuracy: 40.00
Class: 479/car wheel | Accuracy: 40.00
Class: 779/school bus | Accuracy: 63.64
Class: 670/motor scooter, scooter | Accuracy: 23.81
Class: 920/traffic light, traffic signal, stoplight | Accuracy: 66.67
Class: 555/fire engine, fire truck | Accuracy: 90.00
Class: 847/tank, army tank, armored combat vehicle, armoured combat vehicle | Accuracy: 47.06
Class: 751/racer, race car, racing car | Accuracy: 40.00
Class: 479/car wheel | Accuracy: 40.00
Class: 779/school bus | Accuracy: 77.27
Class: 670/motor scooter, scooter | Accuracy: 42.86
Class: 920/traffic light, traffic signal, stoplight | Accuracy: 77.78
Class: 555/fire engine, fire

In [54]:
# Save all results
original_results_save_path = os.path.join(restore_dir, 'vos_original_results.pth')
calibrated_results_save_path = os.path.join(restore_dir, 'vos_calibrated_results.pth')
edited_results_save_path = os.path.join(restore_dir, 'vos_edited_results.pth')

torch.save(original_results, original_results_save_path)
torch.save(calibrated_results, calibrated_results_save_path)
torch.save(edited_results, edited_results_save_path)



## Load results from original, edited, and calibrated model to analyze

In [55]:
# Load all results
original_results_save_path = os.path.join(restore_dir, 'vos_original_results.pth')
calibrated_results_save_path = os.path.join(restore_dir, 'vos_calibrated_results.pth')
edited_results_save_path = os.path.join(restore_dir, 'vos_edited_results.pth')

original_results = torch.load(original_results_save_path)
calibrated_results = torch.load(calibrated_results_save_path)
edited_results = torch.load(edited_results_save_path)

In [67]:
def get_IOU(predictions_a, 
            predictions_b, 
            target_class_idx, 
            modes=['binary']):
    
    IOUs = []
    
    if torch.is_tensor(predictions_a):
        predictions_a = predictions_a.cpu().numpy()
    if torch.is_tensor(predictions_b):
        predictions_b = predictions_b.cpu().numpy()


    for mode in modes:
        try:
            if mode == 'binary':
                # Binarize predictions based on target class
                binary_predictions_a = np.where(
                    predictions_a == target_class_idx,
                    1, 0)
                binary_predictions_b = np.where(
                    predictions_b == target_class_idx,
                    1, 0)
                
                IOU = metrics.jaccard_score(
                    y_true=binary_predictions_a,
                    y_pred=binary_predictions_b,
                    average=mode)
            else:
                IOU = metrics.jaccard_score(
                    y_true=predictions_a,
                    y_pred=predictions_b,
                    average=mode)
            IOUs.append(IOU)
        except Exception as e:
            print(e)
            continue
    return IOUs

def get_ranking(logits, target_class_idx):
    if not torch.is_tensor(logits):
        logits = torch.from_numpy(logits)
    
    softmax = torch.softmax(logits, dim=1)
    target_softmax = softmax[:, target_class_idx]
    ranking = target_softmax.argsort().argsort()
    
    return ranking.cpu().numpy()

def get_spearman(logits_a, logits_b, target_class_idx):
    ranking_a = get_ranking(
        logits=logits_a,
        target_class_idx=target_class_idx)
    ranking_b = get_ranking(
        logits=logits_b,
        target_class_idx=target_class_idx)
    
    if torch.is_tensor(ranking_a):
        ranking_a = ranking_a.cpu().numpy()
    if torch.is_tensor(ranking_b):
        ranking_b = ranking_b.cpu().numpy()
    spearman = stats.spearmanr(
        a=ranking_a,
        b=ranking_b)
    return spearman

#### Compare test set accuracy, IOU of calibrated and edited model on test set

In [73]:
edited_predictions = edited_results['predictions']
calibrated_predictions = calibrated_results['predictions']
IOUs = get_IOU(
    predictions_a=edited_predictions,
    predictions_b=calibrated_predictions,
    target_class_idx=None,
    modes=['weighted'])

n_samples = edited_predictions.shape[0]
print("IOU of edited and calibrated: {}".format(IOUs[0]))
print("Original overall accuracy: {}\nEdited overall accuracy: {}\nCalibrated overall accuracy: {}".format(
    original_results['n_correct'] / n_samples,
    edited_results['n_correct'] / n_samples,
    calibrated_results['n_correct'] / n_samples))

IOU of edited and calibrated: 0.7297347542544712
Original overall accuracy: 0.5193798449612403
Edited overall accuracy: 0.6589147286821705
Calibrated overall accuracy: 0.5193798449612403


In [74]:
#### Calculate Spearman's correlation coefficient for each class in 
calibrated_logits = calibrated_results['logits']
edited_logits = edited_results['logits']

correlations = []
for class_idx, _ in test_data.items():
    spearman = get_spearman(
        logits_a=calibrated_logits,
        logits_b=edited_logits,
        target_class_idx=class_idx)
    correlations.append(spearman.correlation)
    
correlations = np.array(correlations)
print("Mean Spearman's correlation: {} ({})".format(np.mean(correlations), np.std(correlations)))

Mean Spearman's correlation: 0.9770875926399185 (0.006029611771177532)


### Run original model on ImageNet validation set

In [160]:
pre_log = predict_with_bump(
    data_loader=val_loader,
    model=model,
    loss_fn=None,
    metric_fns=metric_fns,
    device=device,
    # target_class_idx=0,
    bump_amounts=np.zeros(n_classes),
    output_save_path=None,
    log_save_path=None)

print("Pre edit results:")
print("Overall Accuracy: {}".format(pre_log['metrics']['accuracy']))

100%|████████████████████████████████████████████████████| 1563/1563 [02:31<00:00, 10.31it/s]


Pre edit results:
{'TP': array([45, 44, 40, 41, 43, 39, 43, 37, 44, 50, 45, 47, 47, 48, 46, 46, 43,
       45, 48, 44, 42, 38, 48, 48, 47, 48, 32, 36, 45, 45, 40, 43, 19, 38,
       30, 31, 23, 44, 42, 42, 31, 46, 40, 43, 34, 39, 26, 39, 46, 33, 43,
       45, 37, 45, 30, 28, 42, 44, 38, 31, 19, 40, 29, 42, 33, 39, 27, 40,
       18, 49, 45, 47, 48, 21, 32, 42, 46, 41, 36, 41, 46, 41, 42, 47, 47,
       46, 38, 47, 48, 47, 49, 46, 48, 48, 46, 47, 45, 46, 45, 40, 49, 31,
       49, 37, 44, 44, 36, 44, 34, 45, 40, 44, 37, 41, 40, 41, 40, 40, 42,
       33, 32, 43, 40, 41, 27, 44, 36, 44, 42, 50, 48, 45, 45, 45, 35, 49,
       48, 45, 46, 49, 43, 46, 46, 47, 46, 46, 46, 43, 45, 43, 40, 36, 37,
       39, 43, 39, 45, 48, 30, 41, 45, 40, 40, 22, 44, 27, 34, 15, 37, 44,
       34, 40, 35, 35, 46, 30, 42, 40, 45, 35, 31, 45, 43, 40, 34, 35, 32,
       36, 31, 36, 39, 39, 39, 29, 41, 48, 37, 38, 43, 38, 36, 31, 32, 46,
       34, 38, 39, 44, 42, 40, 40, 41, 38, 43, 45, 44, 46, 46, 41, 37, 38,


In [184]:
post_log = predict_with_bump(
    data_loader=val_loader,
    model=model,
    loss_fn=None,
    metric_fns=metric_fns,
    device=device,
    # target_class_idx=0,
    bump_amounts=accumulated_bump_amounts,
    output_save_path=None,
    log_save_path=None)



100%|████████████████████████████████████████████████████| 1563/1563 [02:32<00:00, 10.27it/s]


In [185]:
pre_metrics = pre_log['metrics']
post_metrics = post_log['metrics']

print("Change in metrics:")
print("Overall accuracy: {:.4f} -> {:.4f}".format(pre_metrics['accuracy'], post_metrics['accuracy']))
for selected_class in selected_classes:
    print("Examining per class metrics for {} ({})".format(CD[selected_class], selected_class))
    print("accuracy: {:.4f} -> {:.4f}".format(
        pre_metrics['per_class_accuracy'][selected_class],
        post_metrics['per_class_accuracy'][selected_class]))
    print("recall: {:.4f} -> {:.4f}".format(
        pre_metrics['recall'][selected_class],
        post_metrics['recall'][selected_class]))
    print("precision: {:.4f} -> {:.4f}".format(
        pre_metrics['precision'][selected_class],
        post_metrics['precision'][selected_class]))
    print("---***---")

14
14
Change in metrics:
Overall accuracy: 0.7373 -> 0.7365
Examining per class metrics for snowmobile (802)
accuracy: 0.9997 -> 0.9999
recall: 0.9400 -> 0.9200
precision: 0.8246 -> 0.9388
---***---
Examining per class metrics for racer, race car, racing car (751)
accuracy: 0.9992 -> 0.9981
recall: 0.8200 -> 0.8600
precision: 0.5616 -> 0.3258
---***---
Examining per class metrics for snowplow, snowplough (803)
accuracy: 0.9996 -> 0.9997
recall: 0.9200 -> 0.8400
precision: 0.7667 -> 0.8936
---***---
Examining per class metrics for motor scooter, scooter (670)
accuracy: 0.9997 -> 0.9997
recall: 0.8400 -> 0.7600
precision: 0.8571 -> 0.9048
---***---
Examining per class metrics for school bus (779)
accuracy: 0.9997 -> 0.9997
recall: 0.9000 -> 0.9200
precision: 0.8491 -> 0.7931
---***---
Examining per class metrics for dalmatian, coach dog, carriage dog (251)
accuracy: 0.9999 -> 0.9998
recall: 0.9800 -> 0.9800
precision: 0.9074 -> 0.8305
---***---
Examining per class metrics for fire engine

### Analyze edited model (using images) in ImageNet validation set

In [179]:
ret = classifier_helpers.get_default_paths(DATASET_NAME, arch=ARCH)
DATASET_PATH, MODEL_PATH, MODEL_CLASS, ARCH, CD = ret
ret = classifier_helpers.load_classifier(MODEL_PATH, MODEL_CLASS, ARCH,
                            DATASET_NAME, LAYERNUM) 
edited_model, edited_context_model, edited_target_model = ret[:3]
edited_state_dict = torch.load('edited_checkpoints/vehicles_on_snow/imagenet_vgg.pt.best')
edited_model.load_state_dict(edited_state_dict)

<All keys matched successfully>

In [173]:
# Run on test set first to verify #'s
quick_predict_with_bump(
    bump_amounts=np.zeros(n_classes),
    data=test_data,
    model=edited_model,
    debug=True)
# Run edited model on the ImageNet validation set
edited_log = predict_with_bump(
    data_loader=val_loader,
    model=edited_model,
    loss_fn=None,
    metric_fns=metric_fns,
    device=device,
    # target_class_idx=0,
    bump_amounts=np.zeros(n_classes),
    output_save_path=None,
    log_save_path=None)

Class: 779/school bus | Accuracy: 77.27
Class: 670/motor scooter, scooter | Accuracy: 42.86
Class: 920/traffic light, traffic signal, stoplight | Accuracy: 77.78
Class: 555/fire engine, fire truck | Accuracy: 85.00
Class: 847/tank, army tank, armored combat vehicle, armoured combat vehicle | Accuracy: 52.94
Class: 751/racer, race car, racing car | Accuracy: 85.00
Class: 479/car wheel | Accuracy: 45.00


100%|████████████████████████████████████████████████████| 1563/1563 [02:31<00:00, 10.32it/s]


In [189]:
edited_metrics = edited_log['metrics']
print("Metric: [Pre] -> [Logit Bumped] / [Edited]")
print("Overall accuracy: {:.4f} -> {:.4f} / {:.4f}".format(
    pre_metrics['accuracy'], 
    post_metrics['accuracy'],
    edited_metrics['accuracy']))

for selected_class in selected_classes:
    print("Examining per class metrics for {} ({})".format(CD[selected_class], selected_class))
    print("accuracy: {:.4f} -> {:.4f} / {:.4f}".format(
        pre_metrics['per_class_accuracy'][selected_class],
        post_metrics['per_class_accuracy'][selected_class],
        edited_metrics['per_class_accuracy'][selected_class]))
    # print("recall: {:.4f} -> {:.4f} / {:.4f}".format(
    #     pre_metrics['recall'][selected_class],
    #     post_metrics['recall'][selected_class],
    #     edited_metrics['recall'][selected_class]))
    print("precision: {:.4f} -> {:.4f} / {:.4f}".format(
        pre_metrics['precision'][selected_class],
        post_metrics['precision'][selected_class],
        edited_metrics['precision'][selected_class]))
    print("---***---")

torch.Size([50000, 1000])
Metric: [Pre] -> [Logit Bumped] / [Edited]
Overall accuracy: 0.7373 -> 0.7365 / 0.7348
Examining per class metrics for snowmobile (802)
accuracy: 0.9997 -> 0.9999 / 0.9998
precision: 0.8246 -> 0.9388 / 0.8704
---***---
Examining per class metrics for racer, race car, racing car (751)
accuracy: 0.9992 -> 0.9981 / 0.9992
precision: 0.5616 -> 0.3258 / 0.5634
---***---
Examining per class metrics for snowplow, snowplough (803)
accuracy: 0.9996 -> 0.9997 / 0.9997
precision: 0.7667 -> 0.8936 / 0.8036
---***---
Examining per class metrics for motor scooter, scooter (670)
accuracy: 0.9997 -> 0.9997 / 0.9997
precision: 0.8571 -> 0.9048 / 0.8400
---***---
Examining per class metrics for school bus (779)
accuracy: 0.9997 -> 0.9997 / 0.9997
precision: 0.8491 -> 0.7931 / 0.8491
---***---
Examining per class metrics for dalmatian, coach dog, carriage dog (251)
accuracy: 0.9999 -> 0.9998 / 0.9998
precision: 0.9074 -> 0.8305 / 0.8596
---***---
Examining per class metrics for 

In [190]:
torch.save(model.state_dict(), os.path.join(restore_dir, 'original_imagenet_vgg.pt.best'))
torch.save(accumulated_bump_amounts, os.path.join(restore_dir, 'logit_calibration.pth'))

In [192]:
imagenet_val_results = {
    'original': pre_log,
    'edit': edited_log,
    'calibrated': post_log
}

torch.save(imagenet_val_results, os.path.join(restore_dir, 'imagenet_val_results.pth'))