In [1]:
from torch.utils.data import DataLoader
from concept_model.dataset import CUBImageToClass

# Evaluation loader settings.
batch_size = 32
num_workers = 2

# Held-out CUB split (train=False); no shuffling, so batches are deterministic.
test_data = CUBImageToClass(train=False)
test_dataloader = DataLoader(
    test_data,
    batch_size=batch_size,
    num_workers=num_workers,
)

# Peek at a single batch purely as a sanity check of tensor shapes and label dtype.
first_batch = next(iter(test_dataloader), None)
if first_batch is not None:
    X, y = first_batch
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")


Shape of X [N, C, H, W]: torch.Size([32, 3, 299, 299])
Shape of y: torch.Size([32]) torch.int64


In [2]:
import torch
from concept_model.inference import ImageToAttributesModel, AttributesToClassModel

# Pick the compute device up front so both models can be placed on it explicitly,
# keeping them on the same device as any inputs moved to `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): assumes `.model` is a torch.nn.Module (it is later called and put
# into eval mode); Module.to() moves parameters in place and returns the module.
image_to_attributes_model = ImageToAttributesModel(
    "independent_image_to_attributes.pth"
).model.to(device)
attributes_to_class_model = AttributesToClassModel(
    "independent_attributes_to_class.pth"
).model.to(device)

print(f"Using {device} device")


Using cache found in /home/joanna/.cache/torch/hub/pytorch_vision_v0.10.0


Using cuda device


In [3]:
def test(dataloader, image_to_attributes_model, attributes_to_class_model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    image_to_attributes_model.eval()
    attributes_to_class_model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = image_to_attributes_model(X.to(device)), y.to(device)
            pred = attributes_to_class_model(torch.nn.Sigmoid()(X))
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    return test_loss, correct


In [4]:
# Score the independently trained image->attributes and attributes->class models
# together on the held-out split.
test_loss, correct = test(
    dataloader=test_dataloader,
    image_to_attributes_model=image_to_attributes_model,
    attributes_to_class_model=attributes_to_class_model,
    loss_fn=torch.nn.CrossEntropyLoss(),
)
accuracy_pct = 100 * correct
print(f"Test loss for independent model: {test_loss:>8f}")
print(f"Test accuracy for independent model: {accuracy_pct:>0.2f}%")


Test loss for independent model: 3.474282
Test accuracy for independent model: 25.73%
