### Analyze the data from `setup/setup_datasets.py`

In [37]:
import os, sys
import torch
import numpy as np
from PIL import Image

sys.path.insert(0, 'src')
from utils.visualizations import show_image_rows, make_grid
import model.metric as module_metric

In [None]:
dataset_type = '2_Spurious_MNIST'
# Root directory holding the generated datasets. This name was previously
# never assigned anywhere in the notebook, so the cell raised NameError on a
# fresh kernel. 'data' matches the directory the later test-set cell reads
# from -- NOTE(review): confirm this is where setup/setup_datasets.py writes.
data_root_dir = 'data'

# Load the serialized training split.
train_path = os.path.join(data_root_dir, dataset_type, 'training.pt')
train_data = torch.load(train_path)

# Load the serialized test split.
test_path = os.path.join(data_root_dir, dataset_type, 'test.pt')
test_data = torch.load(test_path)

### Load colors and labels for test set

In [None]:
# Preview the first `n_show` samples of the training and test splits, each
# in its own figure titled with the dataset name and the split.
n_show = 20
for idx, data in enumerate([train_data, test_data]):
    split_name = 'Train' if idx == 0 else 'Test'
    imgs, labels = data['images'], data['labels']

    # Sanity check: shape and maximum value of the first image in the split.
    first_img = imgs[0]
    print(first_img.shape)
    print(np.amax(first_img))

    # Images and their labels go through make_grid with the same
    # items_per_row so each title lines up with its image.
    show_imgs = make_grid(imgs[:n_show], items_per_row=5)
    show_labels = make_grid(labels[:n_show], items_per_row=5)
    show_image_rows(
        images=show_imgs,
        image_titles=show_labels,
        image_size=(1.5, 1.5),
        figure_title='{} {}'.format(dataset_type, split_name))


In [40]:
dataset_type = '2_Spurious_MNIST'
data_dir = os.path.join('data', dataset_type)

# Locations of the test split and of the two index files that partition it
# into congruent / incongruent samples.
test_data_path = os.path.join(data_dir, 'test.pt')
congruent_idxs_path = os.path.join(data_dir, 'test_congruent_idxs.pt')
incongruent_idxs_path = os.path.join(data_dir, 'test_incongruent_idxs.pt')

# Test split: keep the labels and colors as numpy arrays.
test_data = torch.load(test_data_path)
test_labels, test_colors = (
    np.array(test_data[key]) for key in ('labels', 'colors'))

# Index sets selecting the congruent and incongruent test samples.
congruent_idxs, incongruent_idxs = (
    torch.load(path) for path in (congruent_idxs_path, incongruent_idxs_path))

In [38]:
# Identify the saved test run whose logits are analyzed below.
trial_timestamp = '0316_093312'
model_arch = 'VGG_16'
trial_dir = os.path.join('saved', 'test', '{}-{}'.format(dataset_type, model_arch), trial_timestamp)

trial_logits_path = os.path.join(trial_dir, 'log', 'logits.pth')
# map_location='cpu' lets the file load on machines without a GPU even if the
# logits were saved from a CUDA device; detach() guards against a tensor that
# still requires grad, which .numpy() would reject.
trial_logits = torch.load(trial_logits_path, map_location='cpu').detach().numpy()
# Predicted class = index of the largest logit along axis 1.
trial_predictions = np.argmax(trial_logits, axis=1)


In [None]:
### Print test set metrics for overall, congruent, and incongruent test set samples

In [41]:
partition_labels = ['congruent', 'incongruent']
metric_names = [
    "accuracy",
    "per_class_accuracy",
    "precision",
    "recall",
    "f1",
    "predicted_class_distribution"]
# Resolve each metric name to the function of the same name in model.metric.
metric_fns = [getattr(module_metric, metric_name) for metric_name in metric_names]
# Class ids 0-9 passed to compute_metrics as `unique_labels`.
unique_labels = list(range(10))


def report_metrics(header, predictions, targets):
    """Print `header`, then every metric computed for `predictions` vs `targets`.

    The header is printed *before* calling compute_metrics because that call
    appears to emit its own diagnostic lines (TP/TN/FP/FN counts); printing
    the header first keeps those lines under the section they belong to.

    Args:
        header: Line printed above the metrics.
        predictions: Predicted class per sample.
        targets: Ground-truth class per sample.

    Returns:
        The mapping of metric name -> value returned by compute_metrics.
    """
    print(header)
    computed = module_metric.compute_metrics(
        metric_fns=metric_fns,
        prediction=predictions,
        target=targets,
        unique_labels=unique_labels,
        save_mean=True)
    for metric_name, metric in computed.items():
        print("{}: {}".format(metric_name, metric))
    return computed


metrics = report_metrics(
    "Overall test set performance", trial_predictions, test_labels)
print("")

# Repeat the report on the congruent and incongruent subsets of the test set.
for partition_label, idxs in zip(partition_labels, [congruent_idxs, incongruent_idxs]):
    partitioned_labels = test_labels[idxs]
    partitioned_predictions = trial_predictions[idxs]

    metrics = report_metrics(
        "Metrics for {} indices in test set".format(partition_label),
        partitioned_predictions,
        partitioned_labels)
    print("")


Overall test set performance
TP: [448 537 520 498 452 422 437 473 454 454]
TN: [8329 8374 8213 7556 7783 8245 8366 7921 7886 8022]
FPs: [241  41 305 984 785 413 226 601 690 519]
FNs: [482 548 462 462 480 420 471 505 470 505]
accuracy: 0.4942105263157895
per_class_accuracy: [0.92389474 0.938      0.91926316 0.84778947 0.86684211 0.91231579
 0.92663158 0.88357895 0.87789474 0.89221053]
per_class_accuracy_mean: 0.8988421052631578
precision: [0.65021771 0.92906574 0.63030303 0.33603239 0.36540016 0.50538922
 0.65912519 0.44040968 0.39685315 0.46659815]
precision_mean: 0.537939442183942
recall: [0.48172043 0.49493088 0.52953157 0.51875    0.48497854 0.50118765
 0.48127753 0.48364008 0.49134199 0.4734098 ]
recall_mean: 0.4940768471198444
predicted_class_distribution: [ 689  578  825 1482 1237  835  663 1074 1144  973]
f1: [0.55342804 0.64582081 0.57553957 0.40786241 0.41678193 0.50327967
 0.55633355 0.46101365 0.43907157 0.4699793 ]
f1_mean: 0.5029110470741379

Metrics for congruent indices 

  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
