In [12]:
import pathlib
import json
import numpy as np

In [13]:
# Load the labels of the test set from the 'y_test' array of the .npz archive.
candidate_file = 'candidate.npz'
# np.load on an .npz returns a lazily-read NpzFile that keeps the archive open;
# using it as a context manager closes the file handle once the array has been
# read into memory.
with np.load(candidate_file) as candidate_data:
    test_labels = candidate_data['y_test']

In [14]:
# Load the indices of the correctly classified test images across different runs.
# Every *.json file under ./results/resnet is parsed; files whose *name* contains
# 'cifar' go to cifar_corr_list, all others go to cndt_corr_list.
correct_dirs = pathlib.Path.cwd() / 'results' / 'resnet'
cndt_corr_list, cifar_corr_list = [], []

# sorted(): glob order depends on the filesystem, so sort to keep the order of
# the runs (and therefore of the per-run counts shown below) reproducible.
for json_file in sorted(correct_dirs.glob('*.json')):
    with open(json_file, encoding='utf-8') as fn:
        json_dict = json.load(fn)
    # Match on the file name only: testing the full path would misclassify every
    # run as CIFAR whenever a parent directory happens to contain 'cifar'.
    if 'cifar' in json_file.name:
        cifar_corr_list.append(json_dict)
    else:
        cndt_corr_list.append(json_dict)

In [25]:
# Presumably the number of test images per class (the export cell normalises
# the counts by 200) -- not used in this cell itself.
image_per_cls = 200
num_cls = 10


def count_correct_per_class(run_list, labels, n_classes):
    """Count, for every class, how many correctly classified images each run has.

    Parameters
    ----------
    run_list : list of dict
        One dict per run; each must have a 'correct' entry holding the indices
        (anything accepted by ``int()``) of the correctly classified images.
    labels : indexable
        ``labels[k]`` is compared against the class index with ``==``.
    n_classes : int
        Classes are the integers ``0 .. n_classes - 1``.

    Returns
    -------
    dict
        ``{class_index: [count_for_run_0, count_for_run_1, ...]}`` with the
        counts in the same order as ``run_list``.
    """
    return {
        cls: [sum(1 for idx in run['correct'] if labels[int(idx)] == cls)
              for run in run_list]
        for cls in range(n_classes)
    }


# Same computation for both result sets; only the list of runs differs.
cndt_corr_dict = count_correct_per_class(cndt_corr_list, test_labels, num_cls)
cifar_corr_dict = count_correct_per_class(cifar_corr_list, test_labels, num_cls)

In [26]:
# Display the per-class counts for the non-CIFAR ("cndt") runs:
# class index -> list with one correct-image count per loaded run.
cndt_corr_dict

{0: [170, 166, 172, 171, 172],
 1: [184, 185, 183, 188, 186],
 2: [185, 184, 180, 183, 184],
 3: [172, 168, 170, 176, 175],
 4: [190, 188, 191, 188, 190],
 5: [172, 173, 177, 168, 174],
 6: [192, 190, 191, 189, 194],
 7: [194, 189, 193, 194, 193],
 8: [187, 189, 187, 185, 186],
 9: [188, 190, 188, 192, 188]}

In [27]:
# Display the per-class counts for the CIFAR runs:
# class index -> list with one correct-image count per loaded run.
cifar_corr_dict

{0: [163, 164, 164, 158, 167],
 1: [163, 164, 160, 163, 165],
 2: [171, 171, 176, 166, 168],
 3: [155, 150, 158, 153, 160],
 4: [180, 179, 177, 181, 181],
 5: [155, 161, 156, 163, 161],
 6: [180, 180, 183, 177, 177],
 7: [178, 171, 173, 178, 180],
 8: [173, 174, 181, 175, 176],
 9: [188, 184, 187, 184, 187]}

In [65]:
# Export the per-class mean accuracy and its standard deviation across runs to a
# space-delimited text file with the columns: class, mean, std.
cifar10_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# Space-separated class names (not used further in this cell).
str_classes = ' '.join(cifar10_classes)

csv_file = 'cifar_accuracy_per_class.csv'
# The file-name prefix selects which result set is exported.
csv_iter = cifar_corr_dict.values() if csv_file.startswith('cifar') else cndt_corr_dict.values()

# Turn the per-run correct counts into accuracies by dividing by the number of
# images per class, then aggregate over the runs (population std, ddof=0).
mean_list = [np.mean(np.array(corr) / image_per_cls) for corr in csv_iter]
std_list = [np.std(np.array(corr) / image_per_cls, ddof=0) for corr in csv_iter]

# Rows are labelled by position, so a length mismatch would silently attach the
# wrong class name to a row or drop rows from the file.
if len(mean_list) != len(cifar10_classes):
    raise ValueError('expected %d per-class rows, got %d'
                     % (len(cifar10_classes), len(mean_list)))

# Write the header and the labelled rows in one pass instead of writing the
# numeric columns first and re-reading the file to prepend the class names.
with open(csv_file, 'w') as fn:
    fn.write('class mean std\n')
    for lbl, mean, std in zip(cifar10_classes, mean_list, std_list):
        fn.write('%s %.4f %.4f\n' % (lbl, mean, std))