In [1]:
from torchvision.models import resnet101, ResNet101_Weights, resnet50, ResNet50_Weights
import torch
from torchvision.datasets import ImageFolder
from torchmetrics import Accuracy

In [2]:
# Build the ResNet-101 architecture with randomly initialised parameters;
# `weights=None` spells out the default, i.e. no pretrained weights are
# fetched here. The actual parameters are loaded from a checkpoint afterwards.
model = resnet101(weights=None)

In [3]:
# Restore the trained parameters from disk.
# `map_location='cpu'` makes the load succeed even when the checkpoint was
# saved from GPU tensors and this machine has no CUDA device; the notebook
# never moves the model or the images to another device, so CPU is the
# device inference runs on anyway.
model.load_state_dict(
    torch.load('../models/models_weights/resnet101.pth', map_location='cpu')
)
# Switch to inference mode (fixed batch-norm statistics, no dropout).
# Kept as the last expression so the notebook still displays the module repr.
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [4]:
# Preprocessing pipeline bundled with the ImageNet-1k V2 ResNet-101 weights;
# applied to every image as it is read from disk.
eval_transform = ResNet101_Weights.IMAGENET1K_V2.transforms()

# One sub-directory per class under the root folder (ImageFolder convention).
CDP_data = ImageFolder(
    '../data_retrieval/CDP_data/',
    transform=eval_transform,
)

In [5]:
# Batched loader used by the evaluation loops below.
# Shuffling is disabled: this loader is only iterated for evaluation, where
# the sample order has no influence on the metrics, and a fixed order keeps
# the separate passes over the data (and re-runs of the notebook) aligned.
test_loader = torch.utils.data.DataLoader(
    CDP_data,
    batch_size=16,
    shuffle=False,
    num_workers=4,
)

In [6]:
# Invert the dataset's class-name -> index mapping so that an integer label
# can be turned back into its class name.
idx_to_class = dict(
    (class_idx, class_name)
    for class_name, class_idx in CDP_data.class_to_idx.items()
)

In [7]:
# Per-class tallies, keyed by the integer class index (the keys of
# `idx_to_class`), not by the class name.
correct_pred = {class_idx: 0 for class_idx in idx_to_class}
total_pred = {class_idx: 0 for class_idx in idx_to_class}

with torch.no_grad():  # inference only, no autograd bookkeeping needed
    for images, labels in test_loader:
        outputs = model(images)
        # Index of the highest-scoring output for every sample in the batch.
        predictions = torch.max(outputs, 1).indices
        for label, prediction in zip(labels.tolist(), predictions.tolist()):
            # Compare the indices directly. Translating the prediction through
            # `idx_to_class` raises KeyError whenever the model predicts an
            # index that is not one of the dataset's classes, which can happen
            # if the network has more outputs than the dataset has classes.
            if label == prediction:
                correct_pred[label] += 1
            total_pred[label] += 1

In [8]:
# Print the accuracy (in percent) of every class tallied by the evaluation
# loop. Despite its name, `classname` holds the integer class index; the
# readable name is looked up through `idx_to_class`.
for classname in correct_pred:
    correct_count = correct_pred[classname]
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {idx_to_class[classname]:5s} is {accuracy:.1f} %')

Accuracy for class: E_Neo-Assyrian is 83.3 %
Accuracy for class: E_Neo-Babylonian is 73.9 %
Accuracy for class: GAR_Neo-Assyrian is 90.7 %
Accuracy for class: GAR_Neo-Babylonian is 83.8 %
Accuracy for class: KA_Neo-Assyrian is 96.9 %
Accuracy for class: KA_Neo-Babylonian is 50.0 %
Accuracy for class: KI_Neo-Assyrian is 92.3 %
Accuracy for class: KI_Neo-Babylonian is 53.7 %
Accuracy for class: MEŠ_Neo-Assyrian is 63.0 %
Accuracy for class: MEŠ_Neo-Babylonian is 88.9 %
Accuracy for class: NI_Neo-Assyrian is 91.2 %
Accuracy for class: NI_Neo-Babylonian is 80.0 %
Accuracy for class: RU_Neo-Assyrian is 96.3 %
Accuracy for class: RU_Neo-Babylonian is 55.6 %
Accuracy for class: TA_Neo-Assyrian is 70.0 %
Accuracy for class: TA_Neo-Babylonian is 84.6 %
Accuracy for class: TI_Neo-Assyrian is 81.8 %
Accuracy for class: TI_Neo-Babylonian is 60.0 %
Accuracy for class: U₂_Neo-Assyrian is 61.5 %
Accuracy for class: U₂_Neo-Babylonian is 83.9 %
Accuracy for class: ŠU_Neo-Assyrian is 84.6 %
Accuracy for

In [9]:
# Top-1 accuracy over the whole dataset, reported as a fraction in [0, 1]:
# all correct predictions divided by all evaluated samples.
n_correct = sum(correct_pred.values())
n_total = sum(total_pred.values())
print("Overall accuracy: ", n_correct / n_total)

Overall accuracy:  0.7637540453074434


In [10]:
# Collect the raw model outputs and the ground-truth labels of the whole
# dataset so that top-k metrics can be computed on them afterwards.
# The second dimension of `preds` must equal the width of the model's output
# (1000 here), since every row stores one output vector.
true_values = torch.zeros(len(CDP_data), dtype=torch.int64)
preds = torch.zeros(len(CDP_data), 1000)
counter = 0  # number of samples written so far

with torch.no_grad():  # inference only, no autograd bookkeeping needed
    for images, labels in test_loader:
        outputs = model(images)
        # Copy the whole batch at once instead of sample by sample.
        batch_end = counter + labels.size(0)
        preds[counter:batch_end] = outputs
        true_values[counter:batch_end] = labels
        counter = batch_end

# Every row of the pre-allocated buffers must have been filled; otherwise the
# leftover zero rows would silently distort the metrics computed from them.
assert counter == len(CDP_data), (
    f"collected {counter} samples, expected {len(CDP_data)}"
)
            

In [11]:
# Top-k accuracy metrics: a sample counts as correct when its true class is
# among the k highest-scoring outputs. `num_classes` matches the width of the
# collected prediction matrix.
accuracy_2 = Accuracy(task='multiclass', num_classes=1000, top_k=2)
accuracy_3 = Accuracy(task='multiclass', num_classes=1000, top_k=3)

In [12]:
# Evaluate both top-k metrics on the collected outputs and true labels,
# in the same order as before (top-2 first, then top-3).
top_2, top_3 = (
    metric(preds, true_values) for metric in (accuracy_2, accuracy_3)
)

In [13]:
# Report the top-k accuracies as plain Python floats (fractions in [0, 1]).
for caption, score in (
    ("Overall Top-2 accuracy: ", top_2),
    ("Overall Top-3 accuracy: ", top_3),
):
    print(caption, float(score))

Overall Top-2 accuracy:  0.8883495330810547
Overall Top-3 accuracy:  0.9288026094436646
