In [1]:
from torchvision.models import resnet101, ResNet101_Weights
from torchvision.datasets import ImageFolder
import torch
from torchmetrics import Accuracy, F1Score, Precision, Recall
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset, ConcatDataset
from torchvision.transforms import Grayscale, Compose
import matplotlib.pyplot as plt
import seaborn as sn
import pandas as pd
import numpy as np
import torch.nn as nn

In [2]:
data = ImageFolder('../data_retrieval/data_other', Compose([ResNet101_Weights.IMAGENET1K_V2.transforms(), Grayscale()]))
targets = data.targets
train_indices, test_val_indices = train_test_split(np.arange(len(targets)), stratify=targets, train_size=0.65, random_state=21)
train_data = Subset(data, indices=train_indices)
val_test_data = Subset(data, indices=test_val_indices)
targets = np.array(targets)
test_val_targets = targets[test_val_indices]
test_indices, val_indices = train_test_split(test_val_indices, stratify=test_val_targets, train_size=0.57, random_state=21)
val_data, test_data = Subset(data, indices=val_indices), Subset(data, indices=test_indices)

In [3]:
model = resnet101()

In [4]:
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(2048, len(test_data.dataset.classes), bias=True)

In [5]:
model.load_state_dict(torch.load('../models/models_weights/resnet101_grayscale_non_diagnostic.pth'))

<All keys matched successfully>

In [6]:
model.eval()

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [7]:
test_loader = torch.utils.data.DataLoader(test_data,
                                             batch_size=16, shuffle=True,
                                             num_workers=4)

In [8]:
print("Size of test dataset:", len(test_data))

Size of test dataset: 2547


In [9]:
idx_to_class = {v: k for k, v in test_data.dataset.class_to_idx.items()}

In [10]:
idx_to_class

{0: 'AN_Neo-Assyrian',
 1: 'AN_Neo-Babylonian',
 2: 'A_Neo-Assyrian',
 3: 'A_Neo-Babylonian',
 4: 'AŠ_Neo-Assyrian',
 5: 'AŠ_Neo-Babylonian',
 6: 'BAD_Neo-Assyrian',
 7: 'BAD_Neo-Babylonian',
 8: 'DIŠ_Neo-Assyrian',
 9: 'DIŠ_Neo-Babylonian',
 10: 'GIŠ_Neo-Assyrian',
 11: 'GIŠ_Neo-Babylonian',
 12: 'IGI_Neo-Assyrian',
 13: 'IGI_Neo-Babylonian',
 14: 'I_Neo-Assyrian',
 15: 'I_Neo-Babylonian',
 16: 'MA_Neo-Assyrian',
 17: 'MA_Neo-Babylonian',
 18: 'MU_Neo-Assyrian',
 19: 'MU_Neo-Babylonian',
 20: 'NA_Neo-Assyrian',
 21: 'NA_Neo-Babylonian',
 22: 'NU_Neo-Assyrian',
 23: 'NU_Neo-Babylonian',
 24: 'UD_Neo-Assyrian',
 25: 'UD_Neo-Babylonian',
 26: 'ŠU₂_Neo-Assyrian',
 27: 'ŠU₂_Neo-Babylonian'}

In [11]:
correct_pred = {classname: 0 for classname in idx_to_class}
total_pred = {classname: 0 for classname in idx_to_class}
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        outputs = model(images)
        _, predictions = torch.max(outputs, 1)
        for label, prediction in zip(labels, predictions):
            if idx_to_class[int(label)] == idx_to_class[(int(prediction))]:
                correct_pred[int(label)] += 1
            total_pred[int(label)] += 1

In [14]:
for classname, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[classname]
    print(f'Accuracy for class: {idx_to_class[classname]:5s} is {accuracy:.1f} %')

Accuracy for class: AN_Neo-Assyrian is 84.5 %
Accuracy for class: AN_Neo-Babylonian is 86.2 %
Accuracy for class: A_Neo-Assyrian is 65.9 %
Accuracy for class: A_Neo-Babylonian is 89.9 %
Accuracy for class: AŠ_Neo-Assyrian is 68.4 %
Accuracy for class: AŠ_Neo-Babylonian is 81.6 %
Accuracy for class: BAD_Neo-Assyrian is 80.4 %
Accuracy for class: BAD_Neo-Babylonian is 79.2 %
Accuracy for class: DIŠ_Neo-Assyrian is 74.6 %
Accuracy for class: DIŠ_Neo-Babylonian is 91.0 %
Accuracy for class: GIŠ_Neo-Assyrian is 75.8 %
Accuracy for class: GIŠ_Neo-Babylonian is 73.3 %
Accuracy for class: IGI_Neo-Assyrian is 88.4 %
Accuracy for class: IGI_Neo-Babylonian is 89.2 %
Accuracy for class: I_Neo-Assyrian is 85.7 %
Accuracy for class: I_Neo-Babylonian is 94.7 %
Accuracy for class: MA_Neo-Assyrian is 91.5 %
Accuracy for class: MA_Neo-Babylonian is 77.8 %
Accuracy for class: MU_Neo-Assyrian is 83.3 %
Accuracy for class: MU_Neo-Babylonian is 85.3 %
Accuracy for class: NA_Neo-Assyrian is 88.5 %
Accuracy f

In [15]:
print("Overall accuracy: ",sum(correct_pred.values()) / sum(total_pred.values()))

Overall accuracy:  0.8578720062819003
