In [28]:
import os
import pickle
import torch
import itertools
import torch.nn as nn
import torchvision.models as models
from src.NeuralModels import ConceptNetwork
from src.DataLoader import CaltechBirdsDataset 

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Results are read from the `results` directory that sits next to the
# current working directory (i.e. <cwd>/../results).
results_dir = os.path.join(os.getcwd(), os.pardir, 'results')

In [2]:
# Build one dataset, one DataLoader and one size entry per phase, keyed by phase name.
phases = ['Train', 'Test']
image_datasets, dataloaders, dataset_sizes = {}, {}, {}
for phase in phases:
    # The same dataset class serves both splits; `train` selects which one.
    image_datasets[phase] = CaltechBirdsDataset(train=(phase == 'Train'), bounding=True)
    # shuffle=False keeps sample order fixed, so positions in collected
    # label/prediction lists line up across runs.
    dataloaders[phase] = torch.utils.data.DataLoader(
        image_datasets[phase],
        batch_size=8,
        shuffle=False,
        num_workers=1,
    )
    dataset_sizes[phase] = len(image_datasets[phase])

In [3]:
# Map each concept name to a count: for every concept listed by the training
# dataset, the number of rows in its `attributes` table whose `concept_name`
# column equals that concept.
# NOTE(review): the hard-coded 200 for 'one_hot' is presumably the number of
# bird classes - confirm against the dataset.
train_set = image_datasets['Train']
concept_names = {'one_hot': 200}
for name in train_set.concept_names:
    belongs_to_concept = train_set.attributes['concept_name'] == name
    concept_names[name] = len(train_set.attributes.loc[belongs_to_concept])

In [4]:
# Load every pickled results file found in `results_dir`, keyed by concept name.
# File names are parsed purely by position: characters [8:-11] are taken as the
# concept name, and characters [-11:-4] must spell 'results' for the file to be
# loaded (the last four characters are presumably the file extension).
results_dict = dict()
# os.listdir returns entries in arbitrary order. Sorting makes the insertion
# order of results_dict - and of anything built by iterating over it -
# reproducible across runs and machines.
for results_file in sorted(os.listdir(results_dir)):
    file_type     = results_file[-11:-4]
    concept_name  = results_file[8:-11]
    if file_type != 'results':
        continue
    # NOTE: pickle.load can execute arbitrary code; only keep trusted,
    # locally produced files in results_dir.
    with open(os.path.join(results_dir, results_file), 'rb') as file:
        results = pickle.load(file)
    results_dict[concept_name] = results


In [26]:
# For every concept, collect the positions of misclassified samples and keep the
# concepts whose accuracy exceeds `threshold`.
#   incorrect_ids: concept name -> set of indices i where predictions[i] != labels[i]
#   concept_list:  concept names whose (rounded) accuracy is above the threshold
incorrect_ids = dict()
threshold     = 0.55
concept_list  = []
for concept_name, results in results_dict.items():
    labels, preds = results['labels'], results['predictions']
    count, correct, bad_ids = 0, 0, set()
    for i, label in enumerate(labels):
        # Samples labelled 0 are left out of the score for every concept except
        # 'one_hot' (presumably 0 means "no annotation" for attribute concepts
        # - TODO confirm).
        if concept_name == 'one_hot' or label != 0:
            count += 1
            if preds[i] == label:
                correct += 1
            else:
                bad_ids.add(i)
    incorrect_ids[concept_name] = bad_ids
    # A concept with no scorable sample has no defined accuracy; record None
    # instead of dividing by zero, and never select such a concept.
    # Accuracy is rounded to two decimals *before* the strict comparison, so a
    # raw accuracy of e.g. 0.554 does not pass a 0.55 threshold.
    accuracy = round(correct / count, 2) if count else None
    if accuracy is not None and accuracy > threshold:
        concept_list.append(concept_name)

In [40]:
# For every unordered triple of selected concepts, print the triple together
# with the number of samples that all three concepts got wrong.
# NOTE(review): concept_list[:-1] drops whichever concept happens to be last in
# the list; which one that is depends on the order concept_list was built in.
# Confirm the intended exclusion.
for concepts in itertools.combinations(concept_list[:-1], 3):
    error_sets = [incorrect_ids[name] for name in concepts]
    shared_errors = error_sets[0].intersection(*error_sets[1:])
    print(concepts, len(shared_errors))

('has_belly_color', 'has_belly_pattern', 'has_bill_length') 306
('has_belly_color', 'has_belly_pattern', 'has_bill_shape') 438
('has_belly_color', 'has_belly_pattern', 'has_breast_pattern') 688
('has_belly_color', 'has_belly_pattern', 'has_crown_color') 512
('has_belly_color', 'has_belly_pattern', 'has_eye_color') 219
('has_belly_color', 'has_belly_pattern', 'has_forehead_color') 540
('has_belly_color', 'has_belly_pattern', 'has_nape_color') 549
('has_belly_color', 'has_belly_pattern', 'has_primary_color') 568
('has_belly_color', 'has_belly_pattern', 'has_shape') 480
('has_belly_color', 'has_belly_pattern', 'has_size') 442
('has_belly_color', 'has_belly_pattern', 'has_throat_color') 614
('has_belly_color', 'has_bill_length', 'has_bill_shape') 334
('has_belly_color', 'has_bill_length', 'has_breast_pattern') 306
('has_belly_color', 'has_bill_length', 'has_crown_color') 320
('has_belly_color', 'has_bill_length', 'has_eye_color') 163
('has_belly_color', 'has_bill_length', 'has_forehead_col

In [34]:
# Run the model over the test loader for a single concept, collecting the
# ground-truth labels and the argmax predictions as flat Python lists and
# tallying how many predictions match.
# NOTE(review): `concept_name` and `model_ft` are not defined in this cell; they
# must already exist in the kernel (`concept_name` is possibly left over from an
# earlier loop). Confirm both before trusting the numbers below.
all_labels, all_preds = [], []
correct, count = 0, 0
for data_dict, inputs in dataloaders['Test']:
    inputs = inputs.to(device)
    labels = data_dict[concept_name]
    all_labels.extend(label.item() for label in labels)
    labels = labels.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model_ft(inputs)[concept_name]
        # Max over dimension 1; the indices of the maxima are the predictions.
        values, preds = outputs.max(dim=1)
    all_preds.extend(pred.item() for pred in preds)

    count += len(labels)
    correct += (preds == labels).sum().item()
    
    


In [35]:
# Show the first 15 collected ground-truth labels.
all_labels[:15]

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [36]:
# Show the first 15 collected predictions, for comparison with the labels.
all_preds[:15]

[185, 89, 91, 31, 27, 141, 75, 24, 184, 13, 115, 15, 12, 184, 91]