In [2]:
import os
import torch
import torch.nn as nn
import torchvision.models as models
from src.NeuralModels import ConceptNetwork
from src.DataLoader import CaltechBirdsDataset 

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [31]:
phases = ['Train', 'Test']
image_datasets = {x:CaltechBirdsDataset(train=(x=='Train'), bounding=True) for x in phases}
dataloaders    = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=8, shuffle=False, num_workers=1) for x in phases}
dataset_sizes  = {x: len(image_datasets[x]) for x in phases}

In [32]:
concept_names = {'one_hot': 200}
for name in image_datasets['Train'].concept_names:
    concept_names[name] =  len(image_datasets['Train'].attributes.loc[(image_datasets['Train'].attributes['concept_name']==name)])

In [33]:
for model_dict in os.listdir(os.path.join(os.getcwd(), os.pardir, 'results')): 
    concept_name = model_dict[7:-4]
    if concept_name == 'one_hot':
        concept_set= {concept_name: concept_names[concept_name]}
model_ft = ConceptNetwork(concept_names=concept_set, freeze=False)
model_ft.load_state_dict(torch.load(os.path.join(os.getcwd(), os.pardir, 'results',model_dict), weights_only=True))
model_ft.to(device)
       
   

ConceptNetwork(
  (feature_detector): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine

In [34]:
all_labels, all_preds = [], []
correct, count = 0,0
for data_dict, inputs in dataloaders['Test']:
    inputs = inputs.to(device)
    labels = data_dict[concept_name]
    all_labels += [label.item() for label in labels]
    labels = labels.to(device)
    
    with torch.set_grad_enabled(False):
        outputs = model_ft(inputs)[concept_name]
        values, preds = torch.max(outputs, 1)
        all_preds += [pred.item() for pred in preds]
    count += len(labels)
    correct +=  torch.sum(preds == labels).item()
    
    


In [35]:
all_labels[:15]

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

In [36]:
all_preds[:15]

[185, 89, 91, 31, 27, 141, 75, 24, 184, 13, 115, 15, 12, 184, 91]