### See how a model performs on the congruent and incongruent test samples

In [1]:
import os, sys
import torch
import numpy as np
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
from collections import OrderedDict

sys.path.insert(0, 'src')
from utils.visualizations import show_image_rows, make_grid, plot
from utils import read_lists

import model.metric as module_metric

In [2]:
dataset_type = '2_Spurious_MNIST'
data_root_dir = os.path.join('data')
train_path = os.path.join(data_root_dir, dataset_type, 'training.pt')
train_data = torch.load(train_path)

test_path = os.path.join(data_root_dir, dataset_type, 'test.pt')
test_data = torch.load(test_path)

### Load colors and labels for test set

In [None]:
# train_imgs = train_data['images']
# train_labels = train_data['labels']

# test_imgs = test_data['images']
# test_labels = test_data['labels']

n_show = 20
for idx, data in enumerate([train_data, test_data]):
    imgs = data['images']
    labels = data['labels']
    print(imgs[0].shape)
    print(np.amax(imgs[0]))
    show_imgs = imgs[:n_show]
    show_labels = labels[:n_show]
    show_imgs = make_grid(show_imgs, items_per_row=5)
    show_labels = make_grid(show_labels, items_per_row=5)
    show_image_rows(
        images=show_imgs,
        image_titles=show_labels,
        image_size=(1.5, 1.5),
        figure_title='{} {}'.format(dataset_type, 'Train' if idx == 0 else 'Test'))


In [6]:
dataset_type = '2_Spurious_MNIST'

data_dir = os.path.join('data', dataset_type)
test_data_path = os.path.join(data_dir, 'test.pt')

test_data = torch.load(test_data_path)
test_labels = np.array(test_data['labels'])
test_colors = np.array(test_data['colors'])

congruent_idxs_path = os.path.join(data_dir, 'test_congruent_idxs.pt')
incongruent_idxs_path = os.path.join(data_dir, 'test_incongruent_idxs.pt')

congruent_idxs = torch.load(congruent_idxs_path)
incongruent_idxs = torch.load(incongruent_idxs_path)

In [3]:
trial_paths_path = 'saved/edit_2_Spurious_MNIST/method_eac/debug/LeNet/0323_110318/trial_paths.txt'
trial_paths = read_lists(trial_paths_path)
trial_dir = os.path.dirname(trial_paths_path)


In [None]:
### Print test set metrics for overall, congruent, and incongruent test set samples

In [10]:
def print_and_save_partitioned_results(pre_edit_predictions: np.array,
                                       post_edit_predictions: np.array,
                                       labels: np.array,
                                       row_data: dict,
                                       partition_name: str,
                                       metric_fns: list,
                                       mean_only: bool
                                       ):
    print("Calculating {} test set performance".format(partition_name))
    
    metrics = {}
    
    metrics['pre'] = module_metric.compute_metrics(
        metric_fns=metric_fns,
        prediction=pre_edit_predictions,
        target=labels,
        unique_labels=[l for l in range(10)],
        save_mean=True)
    
    metrics['post'] = module_metric.compute_metrics(
        metric_fns=metric_fns,
        prediction=post_edit_predictions,
        target=labels,
        unique_labels=[l for l in range(10)],
        save_mean=True)
    
    
    
    for metric_name in metrics['pre'].keys():
        if mean_only and not isinstance(metrics['pre'][metric_name], np.float64):
            continue
        metric_str = "{}: ".format(metric_name)
        # for i in range(n_trials):
        for status in ['pre', 'post']:
            metric_value = metrics[status][metric_name]
            if np.isscalar(metric_value):
                metric_str +="{:.4f} ".format(metric_value)
                row_data['{} {} {}'.format(partition_name, status, metric_name)] = metric_value
            else:
                metric_str +="{} ".format(metric_value)
                row_data[metric_name] = metric_value
            if status == 'pre':
                metric_str += "-> "
        print(metric_str)
    print("")
    
    return row_data
    
def print_summary(congruent_idxs: np.array,
                  incongruent_idxs: np.array,
                  pre_edit_predictions: np.array,
                  post_edit_predictions: np.array, 
                  test_labels: np.array,
                  mean_only=True):
    n_trials = len(pre_edit_predictions)
    assert len(post_edit_predictions) == n_trials
    
    row_data = OrderedDict()
    
    partition_labels = ['congruent', 'incongruent']
    metric_names = [
        "accuracy",
        "per_class_accuracy",
        "precision",
        "recall",
        "f1",
        "predicted_class_distribution"]
    metric_fns = [getattr(module_metric, metric_name) for metric_name in metric_names]

    print("Overall test set performance")
    
    row_data = print_and_save_partitioned_results(
        pre_edit_predictions=pre_edit_predictions,
        post_edit_predictions=post_edit_predictions,
        labels=test_labels,
        row_data=row_data,
        partition_name='overall',
        metric_fns=metric_fns,
        mean_only=mean_only)
    
    # Do the same but for congruent/incongruent subsets
    for label_idx, idxs in enumerate([congruent_idxs, incongruent_idxs]):
        congruency_str = 'congruent' if label_idx == 0 else 'incongruent'
        metrics = []
        
        partitioned_labels = test_labels[idxs]
        
        partitioned_pre_edit_predictions = pre_edit_predictions[idxs]
        partitioned_post_edit_predictions = post_edit_predictions[idxs]
        
        row_data = print_and_save_partitioned_results(
            pre_edit_predictions=partitioned_pre_edit_predictions,
            post_edit_predictions=partitioned_post_edit_predictions,
            labels=partitioned_labels,
            row_data=row_data,
            partition_name=congruency_str,
            metric_fns=metric_fns,
            mean_only=mean_only)
    
    return row_data

### Compare pre vs post edit on each partition

In [13]:
df = pd.DataFrame()
csv_save_path = os.path.join(trial_dir, 'results.csv')
for trial_path in trial_paths:
    # Load pre edit logits & get predictions
    pre_edit_trial_logits_path = os.path.join(trial_path, 'models', 'pre_edit_logits.pth')
    pre_edit_trial_logits = torch.load(pre_edit_trial_logits_path).cpu().numpy()
    pre_edit_trial_predictions = np.argmax(pre_edit_trial_logits, axis=1)

    # Load post edit logits & get predictions
    post_edit_trial_logits_path = os.path.join(trial_path, 'models', 'post_edit_logits.pth')
    post_edit_trial_logits = torch.load(post_edit_trial_logits_path).cpu().numpy()
    post_edit_trial_predictions = np.argmax(post_edit_trial_logits, axis=1)
    
    row_data = OrderedDict()
    row_data['path'] = trial_path
    row_data.update(print_summary(
        congruent_idxs=congruent_idxs,
        incongruent_idxs=incongruent_idxs,
        pre_edit_predictions = pre_edit_trial_predictions,
        post_edit_predictions = post_edit_trial_predictions,
        test_labels=test_labels,
        mean_only=False))
    
    df = df.append(pd.Series(row_data, name=trial_path))

df.set_index('path')

Overall test set performance
Calculating overall test set performance
TP: [450 538 520 497 452 422 436 472 456 453] -> [ 924 1080  928  934  913   71    0    3    0    0] 
TN: [8250 8296 8225 7702 7784 8342 8397 8265 7467 7968] -> [8017 8066 7754 7433 6695 8658 8591 8522 8576 8541] 
FPs: [ 320  119  293  838  784  316  195  257 1109  573] -> [ 553  349  764 1107 1873    0    1    0    0    0] 
FNs: [480 547 462 463 480 420 472 506 468 506] -> [  6   5  54  26  19 771 908 975 924 959] 
accuracy: 0.4943 -> 0.5108 
per_class_accuracy: [0.91578947 0.92989474 0.92052632 0.86305263 0.86694737 0.92252632
 0.92978947 0.91968421 0.834      0.88642105] -> [0.94115789 0.96273684 0.91389474 0.88073684 0.80084211 0.91884211
 0.90431579 0.89736842 0.90273684 0.89905263] 
per_class_accuracy_mean: 0.8989 -> 0.9022 
precision: [0.58441558 0.81887367 0.6396064  0.37228464 0.36569579 0.57181572
 0.69096672 0.64746228 0.2913738  0.44152047] -> [0.62559242 0.75577327 0.54846336 0.45761881 0.32770998 1.
 0.

  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))


Unnamed: 0_level_0,FNs,FPs,TN,TP,congruent post accuracy,congruent post f1_mean,congruent post per_class_accuracy_mean,congruent post precision_mean,congruent post recall_mean,congruent pre accuracy,...,overall post recall_mean,overall pre accuracy,overall pre f1_mean,overall pre per_class_accuracy_mean,overall pre precision_mean,overall pre recall_mean,per_class_accuracy,precision,predicted_class_distribution,recall
path,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
saved/edit_2_Spurious_MNIST/method_eac/debug/LeNet/0323_110318/results/edit_idx_0,"[5, 5, 40, 21, 16, 416, 466, 497, 466, 495]","[336, 149, 398, 631, 912, 0, 1, 0, 0, 0]","[3947, 4066, 3908, 3671, 3372, 4345, 4295, 426...","[474, 542, 416, 439, 462, 1, 0, 1, 0, 0]",0.531448,0.382514,0.90629,0.481599,0.512343,0.991136,...,0.497387,0.494316,0.50532,0.898863,0.542402,0.49418,"[0.9283914321713566, 0.9676606467870642, 0.908...","[0.5851851851851851, 0.784370477568741, 0.5110...","[810, 691, 814, 1070, 1374, 1, 1, 1, 0, 0]","[0.9895615866388309, 0.9908592321755028, 0.912..."
saved/edit_2_Spurious_MNIST/method_eac/debug/LeNet/0323_110318/results/edit_idx_1,"[465, 424, 328, 449, 455, 410, 459, 476, 465, ...","[350, 164, 310, 733, 749, 545, 246, 343, 557, ...","[3933, 4051, 3996, 3569, 3535, 3800, 4050, 392...","[14, 123, 128, 11, 23, 7, 7, 22, 1, 1]",0.907978,0.901745,0.981596,0.919174,0.905886,0.991136,...,0.485835,0.494316,0.50532,0.898863,0.542402,0.49418,"[0.8288534229315414, 0.876522469550609, 0.8660...","[0.038461538461538464, 0.42857142857142855, 0....","[364, 287, 438, 744, 772, 552, 253, 365, 558, ...","[0.029227557411273485, 0.22486288848263253, 0...."
saved/edit_2_Spurious_MNIST/method_eac/debug/LeNet/0323_110318/results/edit_idx_2,"[1, 2, 39, 44, 18, 417, 466, 498, 466, 495]","[305, 137, 347, 800, 854, 3, 0, 0, 0, 0]","[3978, 4078, 3959, 3502, 3430, 4342, 4296, 426...","[478, 545, 417, 416, 460, 0, 0, 0, 0, 0]",0.544111,0.404273,0.908822,0.463482,0.525927,0.991136,...,0.502242,0.494316,0.50532,0.898863,0.542402,0.49418,"[0.9357412851742966, 0.9708105837883242, 0.918...","[0.6104725415070242, 0.7991202346041055, 0.545...","[783, 682, 764, 1216, 1314, 3, 0, 0, 0, 0]","[0.9979123173277662, 0.9963436928702011, 0.914..."


In [12]:
df.to_csv(csv_save_path)
print("Saved csv to {}".format(csv_save_path))

Saved csv to saved/edit_2_Spurious_MNIST/method_eac/debug/LeNet/0323_110318/results.csv


In [None]:
# Print layer number vs congruent post accuracy, incongruent post accuracy, and overall post accuracy
plot_save_path = os.path.join(trial_dir, 'layer_v_accuracy.pdf')
labels = ['congruent post accuracy', 'incongruent post accuracy', 'overall post accuracy']
xs = [[i for i in range(1, 13)] for n in range(3)]
ys = [df[label] for label in labels]

plot(
    xs=xs,
    ys=ys,
    labels=labels,
    title='Editing Layer vs Accuracy',
    xlabel='Layer Edited',
    ylabel='Accuracy on Subset',
    save_path=plot_save_path)
