### Analyze the data from `setup/setup_datasets.py`

In [33]:
import os, sys
import torch
import numpy as np
from PIL import Image

sys.path.insert(0, 'src')
from utils.visualizations import show_image_rows, make_grid
import model.metric as module_metric

In [None]:
# Root directory that holds one sub-directory per dataset. This name was used
# below without being defined anywhere in the notebook, so the cell raised
# NameError on a fresh kernel. 'data' is the same root that the later cell
# loading the test labels joins with dataset_type.
data_root_dir = 'data'
dataset_type = '2_Spurious_MNIST'

# Load the serialized training and test splits. Later cells index the loaded
# objects with the 'images' and 'labels' keys.
train_path = os.path.join(data_root_dir, dataset_type, 'training.pt')
train_data = torch.load(train_path)

test_path = os.path.join(data_root_dir, dataset_type, 'test.pt')
test_data = torch.load(test_path)

### Load colors and labels for test set

In [None]:
# Preview the first n_show samples of each split together with their labels.
n_show = 20
for idx, data in enumerate((train_data, test_data)):
    # The training split comes first (idx 0), the test split second.
    split_name = 'Test' if idx else 'Train'
    imgs, labels = data['images'], data['labels']

    # Quick sanity check on the first image: its shape and its largest value.
    print(imgs[0].shape)
    print(np.amax(imgs[0]))

    # Take the leading samples, then let make_grid group them 5 per row.
    first_imgs, first_labels = imgs[:n_show], labels[:n_show]
    show_imgs = make_grid(first_imgs, items_per_row=5)
    show_labels = make_grid(first_labels, items_per_row=5)

    show_image_rows(
        images=show_imgs,
        image_titles=show_labels,
        image_size=(1.5, 1.5),
        figure_title='{} {}'.format(dataset_type, split_name))


In [24]:
dataset_type = '2_Spurious_MNIST'
data_dir = os.path.join('data', dataset_type)

# Locations of the test split and of the two index files that partition it
# into congruent and incongruent samples.
test_data_path = os.path.join(data_dir, 'test.pt')
congruent_idxs_path = os.path.join(data_dir, 'test_congruent_idxs.pt')
incongruent_idxs_path = os.path.join(data_dir, 'test_incongruent_idxs.pt')

# Test split: keep the labels and the per-sample color indices as NumPy arrays.
test_data = torch.load(test_data_path)
test_labels = np.array(test_data['labels'])
test_colors = np.array(test_data['color_idxs'])

# Index sets later used to select the congruent / incongruent test samples.
congruent_idxs = torch.load(congruent_idxs_path)
incongruent_idxs = torch.load(incongruent_idxs_path)

In [32]:
# Identify the evaluation run whose saved logits are analyzed below.
model_arch = 'VGG_16'
trial_timestamp = '0314_114855'
run_name = '{}-{}'.format(dataset_type, model_arch)
trial_dir = os.path.join('saved', 'test', run_name, trial_timestamp)

# Load the saved logits, move them to the CPU and convert them to NumPy.
trial_logits_path = os.path.join(trial_dir, 'log', 'logits.pth')
logits_tensor = torch.load(trial_logits_path)
trial_logits = logits_tensor.cpu().numpy()

# Predicted class per sample = index of the largest logit along axis 1.
trial_predictions = trial_logits.argmax(axis=1)


In [None]:
### Print test set metrics for overall, congruent, and incongruent test set samples

In [36]:
# Names of the metrics to report; each one is looked up as an attribute of
# the project's metric module.
metric_names = [
    "accuracy",
    "per_class_accuracy",
    "precision",
    "recall",
    "f1",
    "predicted_class_distribution"]
metric_fns = [getattr(module_metric, name) for name in metric_names]
partition_labels = ['congruent', 'incongruent']


def _evaluate(prediction, target):
    """Run compute_metrics on one prediction/target pair with class labels 0-9."""
    return module_metric.compute_metrics(
        metric_fns=metric_fns,
        prediction=prediction,
        target=target,
        unique_labels=list(range(10)),
        save_mean=True)


def _print_metrics(results):
    """Print every entry of a metrics mapping as 'name: value', one per line."""
    for metric_name, metric in results.items():
        print("{}: {}".format(metric_name, metric))


# Whole test set first.
print("Overall test set performance")
metrics = _evaluate(trial_predictions, test_labels)
_print_metrics(metrics)

print("")
# Then the same report restricted to each partition of the test set.
for partition_name, idxs in zip(partition_labels, [congruent_idxs, incongruent_idxs]):
    partitioned_labels = test_labels[idxs]
    partitioned_predictions = trial_predictions[idxs]

    metrics = _evaluate(partitioned_predictions, partitioned_labels)
    print("Metrics for {} indices in test set".format(partition_name))

    _print_metrics(metrics)
    print("")


Overall test set performance
TP: [471 562 546 530 480 447 467 495 484 477]
TN: [8765 8822 8644 7953 8202 8675 8798 8343 8305 8452]
FPs: [ 255   43  324 1037  816  433  244  629  721  539]
FNs: [509 573 486 480 502 445 491 533 490 532]
accuracy: 0.4959
per_class_accuracy: [0.9236 0.9384 0.919  0.8483 0.8682 0.9122 0.9265 0.8838 0.8789 0.8929]
per_class_accuracy_mean: 0.8991800000000001
precision: [0.64876033 0.92892562 0.62758621 0.33822591 0.37037037 0.50795455
 0.65682138 0.44039146 0.40165975 0.46948819]
precision_mean: 0.5390183759944495
recall: [0.48061224 0.49515419 0.52906977 0.52475248 0.48879837 0.50112108
 0.4874739  0.48151751 0.49691992 0.47274529]
recall_mean: 0.4958164743442034
predicted_class_distribution: [ 726  605  870 1567 1296  880  711 1124 1205 1016]
f1: [0.55216882 0.64597701 0.57413249 0.41133101 0.4214223  0.50451467
 0.55961654 0.46003717 0.44424048 0.47111111]
f1_mean: 0.5044551597509661

Metrics for congruent indices in test set
TP: [471 562 546 530 480 447 4

  return np.nan_to_num(2 * precisions * recalls / (precisions + recalls))
