### See how a model performs on the congruent and incongruent test samples

In [None]:
import os, sys
import torch
import numpy as np
from PIL import Image
import pandas as pd
import matplotlib.pyplot as plt
from collections import OrderedDict

sys.path.insert(0, 'src')
from utils.visualizations import show_image_rows, make_grid, plot
from utils import read_lists

import model.metric as module_metric

In [None]:
dataset_type = '2_Spurious_MNIST'
data_root_dir = os.path.join('data')
# Both splits are stored under data/<dataset_type>/ and read with torch.load;
# the loaded objects are indexed by key ('images', 'labels', ...) further down.
train_path, test_path = (
    os.path.join(data_root_dir, dataset_type, split_file)
    for split_file in ('training.pt', 'test.pt'))
train_data = torch.load(train_path)
test_data = torch.load(test_path)

### Visualize sample train/test images, then load test-set colors and labels

In [None]:
# Show the first `n_show` images of each split in a grid, 5 per row,
# titled with their labels.
n_show = 20
for split_name, data in zip(('Train', 'Test'), (train_data, test_data)):
    imgs, labels = data['images'], data['labels']
    # Quick sanity check of the per-image shape and maximum pixel value.
    print(imgs[0].shape)
    print(np.amax(imgs[0]))
    show_imgs = make_grid(imgs[:n_show], items_per_row=5)
    show_labels = make_grid(labels[:n_show], items_per_row=5)
    show_image_rows(
        images=show_imgs,
        image_titles=show_labels,
        image_size=(1.5, 1.5),
        figure_title='{} {}'.format(dataset_type, split_name))


In [None]:
dataset_type = '2_Spurious_MNIST'
data_dir = os.path.join('data', dataset_type)

# Reload the test split and extract per-sample labels and colors as numpy arrays.
test_data_path = os.path.join(data_dir, 'test.pt')
test_data = torch.load(test_data_path)
test_labels, test_colors = (
    np.array(test_data[key]) for key in ('labels', 'colors'))

# Saved index sets that split the test samples into congruent / incongruent
# subsets (loaded as-is; NOTE(review): presumably integer index arrays — verify).
congruent_idxs_path, incongruent_idxs_path = (
    os.path.join(data_dir, file_name)
    for file_name in ('test_congruent_idxs.pt', 'test_incongruent_idxs.pt'))
congruent_idxs = torch.load(congruent_idxs_path)
incongruent_idxs = torch.load(incongruent_idxs_path)

In [171]:
# Text file listing the trial directories to evaluate (parsed by read_lists);
# its parent directory is used as trial_dir for aggregated outputs.
trial_paths_path = 'saved/edit_2_Spurious_MNIST/method_eac/debug/VGG_16-layernum/0321_140122/trial_paths copy.txt'
trial_dir = os.path.dirname(trial_paths_path)
trial_paths = read_lists(trial_paths_path)


In [None]:
### Print test set metrics for overall, congruent, and incongruent test set samples

In [172]:
def print_and_save_partitioned_results(pre_edit_predictions: np.array,
                                       post_edit_predictions: np.array,
                                       labels: np.array,
                                       row_data: dict,
                                       partition_name: str,
                                       metric_fns: list,
                                       mean_only: bool,
                                       unique_labels: list = None,
                                       ):
    """Compute, print, and record pre- vs post-edit metrics for one test partition.

    Args:
        pre_edit_predictions: predicted class per sample before the edit.
        post_edit_predictions: predicted class per sample after the edit.
        labels: ground-truth class per sample, aligned with the predictions.
        row_data: dict that metric values are written into (mutated in place).
        partition_name: prefix used in the printed header and in row_data keys
            (e.g. 'overall', 'congruent', 'incongruent').
        metric_fns: metric callables forwarded to module_metric.compute_metrics.
        mean_only: if True, skip metrics whose pre-edit value is not an
            np.float64 (i.e. report only scalar/mean metrics).
        unique_labels: class ids forwarded to compute_metrics; defaults to
            [0, 1, ..., 9].

    Returns:
        row_data, with one entry per reported metric and status, keyed
        '<partition_name> <pre|post> <metric_name>'.
    """
    if unique_labels is None:
        unique_labels = list(range(10))

    print("Calculating {} test set performance".format(partition_name))

    # Local dict: the previous version wrote into an undefined global
    # `metrics`, which only worked thanks to leftover notebook state.
    metrics = {}
    for status, predictions in (('pre', pre_edit_predictions),
                                ('post', post_edit_predictions)):
        metrics[status] = module_metric.compute_metrics(
            metric_fns=metric_fns,
            prediction=predictions,
            target=labels,
            unique_labels=unique_labels,
            save_mean=True)

    for metric_name in metrics['pre'].keys():
        if mean_only and not isinstance(metrics['pre'][metric_name], np.float64):
            continue
        metric_str = "{}: ".format(metric_name)
        for status in ['pre', 'post']:
            metric_value = metrics[status][metric_name]
            if np.isscalar(metric_value):
                metric_str += "{:.4f} ".format(metric_value)
            else:
                metric_str += "{} ".format(metric_value)
            # Always key by partition and status so pre/post values and
            # different partitions never overwrite one another.
            row_data['{} {} {}'.format(partition_name, status, metric_name)] = metric_value
            if status == 'pre':
                metric_str += "-> "
        print(metric_str)
    print("")

    return row_data
    
def print_summary(congruent_idxs: np.array,
                  incongruent_idxs: np.array,
                  pre_edit_predictions: np.array,
                  post_edit_predictions: np.array, 
                  test_labels: np.array,
                  mean_only=True):
    """Print pre- vs post-edit metrics on the whole test set and its two subsets.

    Args:
        congruent_idxs: indices (or mask) selecting the congruent test samples.
        incongruent_idxs: indices (or mask) selecting the incongruent test samples.
        pre_edit_predictions: predicted class per test sample before the edit.
        post_edit_predictions: predicted class per test sample after the edit.
        test_labels: ground-truth class per test sample.
        mean_only: forwarded to print_and_save_partitioned_results.

    Returns:
        OrderedDict of metric values filled in by print_and_save_partitioned_results
        for the 'overall', 'congruent' and 'incongruent' partitions.
    """
    # (Removed `n_trials = len(trial_predictions)`: it referenced an undefined
    # name and raised NameError in a fresh kernel; the value was never used.)
    row_data = OrderedDict()

    metric_names = [
        "accuracy",
        "per_class_accuracy",
        "precision",
        "recall",
        "f1",
        "predicted_class_distribution"]
    metric_fns = [getattr(module_metric, metric_name) for metric_name in metric_names]

    print("Overall test set performance")

    row_data = print_and_save_partitioned_results(
        pre_edit_predictions=pre_edit_predictions,
        post_edit_predictions=post_edit_predictions,
        labels=test_labels,
        row_data=row_data,
        partition_name='overall',
        metric_fns=metric_fns,
        mean_only=mean_only)

    # Same report restricted to the congruent / incongruent subsets.
    for congruency_str, idxs in (('congruent', congruent_idxs),
                                 ('incongruent', incongruent_idxs)):
        row_data = print_and_save_partitioned_results(
            pre_edit_predictions=pre_edit_predictions[idxs],
            post_edit_predictions=post_edit_predictions[idxs],
            labels=test_labels[idxs],
            row_data=row_data,
            partition_name=congruency_str,
            metric_fns=metric_fns,
            mean_only=mean_only)

    return row_data

### Compare pre- vs. post-edit metrics on each test-set partition

In [173]:
def _load_predictions(logits_path):
    """Load a saved logits tensor and return the argmax class of each row."""
    # map_location='cpu' lets logits saved during a GPU run load on a CPU-only machine.
    logits = torch.load(logits_path, map_location='cpu').cpu().numpy()
    return np.argmax(logits, axis=1)


csv_save_path = os.path.join(trial_dir, 'results.csv')
rows = []
for trial_path in trial_paths:
    pre_edit_trial_predictions = _load_predictions(
        os.path.join(trial_path, 'models', 'pre_edit_logits.pth'))
    post_edit_trial_predictions = _load_predictions(
        os.path.join(trial_path, 'models', 'post_edit_logits.pth'))

    row_data = OrderedDict()
    row_data['path'] = trial_path
    row_data.update(print_summary(
        congruent_idxs=congruent_idxs,
        incongruent_idxs=incongruent_idxs,
        pre_edit_predictions=pre_edit_trial_predictions,
        post_edit_predictions=post_edit_trial_predictions,
        test_labels=test_labels))
    rows.append(row_data)

# DataFrame.append was removed in pandas 2.0 (and was quadratic); build the
# frame once. Rows are indexed by trial path, like the old Series(name=trial_path).
df = pd.DataFrame(rows, index=[row['path'] for row in rows])

# Display with 'path' as the index; guard against an empty trial list.
df.set_index('path') if 'path' in df.columns else df

Overall test set performance
Calculating overall test set performance
accuracy: 0.4943 -> 0.4205 
per_class_accuracy_mean: 0.8989 -> 0.8841 
precision_mean: 0.5381 -> 0.7252 
recall_mean: 0.4942 -> 0.4134 
f1_mean: 0.5030 -> 0.4370 

Calculating congruent test set performance
accuracy: 0.9911 -> 0.7423 
per_class_accuracy_mean: 0.9982 -> 0.9485 
precision_mean: 0.9910 -> 0.8254 
recall_mean: 0.9911 -> 0.7247 
f1_mean: 0.9910 -> 0.6998 

Calculating incongruent test set performance
accuracy: 0.0000 -> 0.1004 
per_class_accuracy_mean: 0.8000 -> 0.8201 
precision_mean: 0.0000 -> 0.0113 
recall_mean: 0.0000 -> 0.0998 
f1_mean: 0.0000 -> 0.0203 

Overall test set performance
Calculating overall test set performance
accuracy: 0.4943 -> 0.4193 
per_class_accuracy_mean: 0.8989 -> 0.8839 
precision_mean: 0.5381 -> 0.7590 
recall_mean: 0.4942 -> 0.4184 
f1_mean: 0.5030 -> 0.4517 

Calculating congruent test set performance
accuracy: 0.9911 -> 0.7398 
per_class_accuracy_mean: 0.9982 -> 0.9480 
pr

  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
  return np.nan_to_num(TPs / (TPs + FPs))
  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
  return np.nan_to_num(TPs / (TPs + FPs)

Unnamed: 0_level_0,congruent post accuracy,congruent post f1_mean,congruent post per_class_accuracy_mean,congruent post precision_mean,congruent post recall_mean,congruent pre accuracy,congruent pre f1_mean,congruent pre per_class_accuracy_mean,congruent pre precision_mean,congruent pre recall_mean,...,overall post accuracy,overall post f1_mean,overall post per_class_accuracy_mean,overall post precision_mean,overall post recall_mean,overall pre accuracy,overall pre f1_mean,overall pre per_class_accuracy_mean,overall pre precision_mean,overall pre recall_mean
path,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
saved/edit_2_Spurious_MNIST/method_eac/debug/VGG_16-layernum/0321_140122/debug/results/edit_idx_0,0.742296,0.699775,0.948459,0.825356,0.724719,0.991136,0.99103,0.998227,0.991,0.991079,...,0.420526,0.436957,0.884105,0.725191,0.413414,0.494316,0.503034,0.898863,0.53812,0.494196
saved/edit_2_Spurious_MNIST/method_eac/debug/VGG_16-layernum/0321_140122/debug/results/edit_idx_1,0.739764,0.712135,0.947953,0.826337,0.733105,0.991136,0.99103,0.998227,0.991,0.991079,...,0.419263,0.451658,0.883853,0.759031,0.418427,0.494316,0.503034,0.898863,0.53812,0.494196
saved/edit_2_Spurious_MNIST/method_eac/debug/VGG_16-layernum/0321_140122/debug/results/edit_idx_2,0.79886,0.775528,0.959772,0.834393,0.788443,0.991136,0.99103,0.998227,0.991,0.991079,...,0.444105,0.465186,0.888821,0.676984,0.438741,0.494316,0.503034,0.898863,0.53812,0.494196


In [170]:
# Write the per-trial results table to csv_save_path and report where it went.
df.to_csv(csv_save_path)
print("Saved csv to", csv_save_path)

Saved csv to saved/edit_2_Spurious_MNIST/method_eac/VGG_16-layernum/0320_160517/results.csv


In [None]:
# Plot edited layer number vs congruent, incongruent, and overall post-edit accuracy.
plot_save_path = os.path.join(trial_dir, 'layer_v_accuracy.pdf')
labels = ['congruent post accuracy', 'incongruent post accuracy', 'overall post accuracy']
# One x value per trial row instead of a hard-coded 12, so the x and y series
# always have equal length. NOTE(review): assumes row i edits layer i + 1,
# i.e. trial_paths are ordered by layer — confirm.
n_layers = len(df)
xs = [list(range(1, n_layers + 1)) for _ in labels]
ys = [df[label] for label in labels]

plot(
    xs=xs,
    ys=ys,
    labels=labels,
    title='Editing Layer vs Accuracy',
    xlabel='Layer Edited',
    ylabel='Accuracy on Subset',
    save_path=plot_save_path)
