In [24]:
import os
import pickle
import torch
import torch.nn as nn
import torchvision.models as models
from src.NeuralModels import ConceptNetwork
from src.DataLoader import CaltechBirdsDataset 

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Directory holding the saved results files: ../results relative to the working directory.
results_dir = os.path.join(os.getcwd(), os.pardir, 'results')

In [25]:
phases = ['Train', 'Test']
# One dataset per phase; only the 'Train' phase reads the training split (bounding=True for both).
image_datasets = {
    phase: CaltechBirdsDataset(train=(phase == 'Train'), bounding=True)
    for phase in phases
}
# shuffle=False gives every loader a fixed iteration order.
dataloaders = {
    phase: torch.utils.data.DataLoader(
        image_datasets[phase],
        batch_size=8,
        shuffle=False,
        num_workers=1,
    )
    for phase in phases
}
dataset_sizes = {phase: len(dataset) for phase, dataset in image_datasets.items()}

In [16]:
# Maps each concept head to a size: 200 for 'one_hot', and for every attribute concept the
# number of rows in the training attribute table whose 'concept_name' matches it
# (presumably the number of output classes per head — TODO confirm).
concept_names = {'one_hot': 200}
concept_names.update({
    name: len(image_datasets['Train'].attributes.loc[
        image_datasets['Train'].attributes['concept_name'] == name
    ])
    for name in image_datasets['Train'].concept_names
})

In [37]:
# Load every pickled "results" file in results_dir, keyed by concept name.
# Filenames are parsed positionally: characters [-11:-4] must read 'results' (i.e. the name
# ends in 'results' plus a 4-character suffix, presumably '.pkl'), and the concept name is
# everything between the first 8 characters and the last 11.
# NOTE(review): pickle.load can execute arbitrary code — only point this at trusted files.
results_dict = {}
for results_file in os.listdir(results_dir):
    concept_name = results_file[8:-11]
    file_type = results_file[-11:-4]
    if file_type != 'results':
        continue
    with open(os.path.join(results_dir, results_file), 'rb') as file:
        results = pickle.load(file)
    results_dict[concept_name] = results


In [52]:
# Per-concept test accuracy from the saved results; concepts whose accuracy is at or above
# `threshold` are collected in `concept_list`.
threshold = 0.50
concept_list = []
for concept_name, results in results_dict.items():
    labels, preds = results['labels'], results['predictions']
    if len(labels) != len(preds):
        raise ValueError(
            f'{concept_name}: {len(labels)} labels but {len(preds)} predictions'
        )
    # Attribute concepts skip samples labelled 0 (presumably "not annotated" — TODO confirm);
    # for 'one_hot', label 0 is a valid target, so every sample is scored.
    count, correct = 0, 0
    for label, pred in zip(labels, preds):
        if concept_name == 'one_hot' or label != 0:
            count += 1
            if pred == label:
                correct += 1
    if count == 0:
        # No scorable samples: report it instead of dividing by zero.
        print(concept_name, 'has no scored samples; skipped')
        continue
    accuracy = correct / count
    print(concept_name, round(accuracy, 2))
    # Threshold on the unrounded value: rounding first would let e.g. 0.495 pass 0.50.
    if accuracy >= threshold:
        concept_list.append(concept_name)
print(concept_list)


has_back_color 0.51
has_back_pattern 0.49
has_belly_color 0.57
has_belly_pattern 0.67
has_bill_color 0.46
has_bill_length 0.73
has_bill_shape 0.58
has_breast_color 0.55
has_breast_pattern 0.64
has_crown_color 0.6
has_eye_color 0.85
has_forehead_color 0.58
has_head_pattern 0.33
has_leg_color 0.41
has_nape_color 0.57
has_primary_color 0.56
has_shape 0.59
has_size 0.6
has_tail_pattern 0.44
has_tail_shape 0.32
has_throat_color 0.59
has_underparts_color 0.55
has_under_tail_color 0.45
has_upperparts_color 0.51
has_upper_tail_color 0.4
has_wing_color 0.55
has_wing_pattern 0.52
has_wing_shape 0.14
one_hot 0.74
['has_back_color', 'has_belly_color', 'has_belly_pattern', 'has_bill_length', 'has_bill_shape', 'has_breast_color', 'has_breast_pattern', 'has_crown_color', 'has_eye_color', 'has_forehead_color', 'has_nape_color', 'has_primary_color', 'has_shape', 'has_size', 'has_throat_color', 'has_underparts_color', 'has_upperparts_color', 'has_wing_color', 'has_wing_pattern', 'one_hot']


In [30]:
        # NOTE(review): reads `results` left bound by an earlier loop over results_dict (hidden
        # state). This assignment produces no output, so the dict shown below is stale.
        labels = results['labels']
        preds = results['predictions']
        

{'labels': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  2,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  3,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  4,
  5,
  5,
  5,
  5,
  5,
  5,
  5,
  5,
  5,
  5,
  5,
  6,
  6,
  6,
  6,
  6,
  6,
  6,
  6,
  6,
  6,
  6,
  6,
  6,
  6,
  6,
  6,
  6,
  6,
  6,
  6,
  6,
  6,
  6,
  7,
  7,
  7,
  7,
  7,
  7,
  7,
  7,
  7,
  7,
  7,
  7,
  7,
  7,
  7,
  7,
  7,
  7,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,
  8,


In [34]:
# Run `model_ft` over the test split and collect labels/predictions for one concept head.
# NOTE(review): `model_ft` and `concept_name` are not defined in any visible cell; they come
# from kernel state (`concept_name` is whatever an earlier loop left bound).
# eval() switches layers such as Dropout/BatchNorm to inference behaviour; without it a model
# left in train mode gives different (typically worse) predictions.
model_ft.eval()
all_labels, all_preds = [], []
correct, count = 0, 0
with torch.no_grad():
    for data_dict, inputs in dataloaders['Test']:
        inputs = inputs.to(device)
        labels = data_dict[concept_name]
        all_labels += labels.tolist()
        labels = labels.to(device)

        outputs = model_ft(inputs)[concept_name]
        # Predicted class = argmax over dim 1 of this head's outputs.
        _, preds = torch.max(outputs, 1)
        all_preds += preds.tolist()
        count += len(labels)
        correct += torch.sum(preds == labels).item()
    
    


In [35]:
# First 15 ground-truth labels gathered by the evaluation loop.
all_labels[0:15]

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [36]:
# First 15 predictions gathered by the evaluation loop, for comparison with the labels.
all_preds[0:15]

[185, 89, 91, 31, 27, 141, 75, 24, 184, 13, 115, 15, 12, 184, 91]