In [5]:
from torchvision.models import  ResNet101_Weights
from torchvision.datasets import ImageFolder
import torch
from torchmetrics import Accuracy
from sklearn.metrics import confusion_matrix
import sys
sys.path.append('../')
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
from models import two_conv_network
import matplotlib.pyplot as plt
import seaborn as sn
import pandas as pd
import numpy as np
import torch.nn as nn

In [6]:
# Load every image with the ResNet-101 preprocessing transforms, then build a
# reproducible, class-stratified split:
#   stage 1 -> 65 % train / 35 % hold-out
#   stage 2 -> hold-out split 57 % test / 43 % validation
# NOTE(review): the fractions and random_state presumably must match the notebook
# that trained the checkpoint loaded later; otherwise test images could overlap
# that model's training set.
data = ImageFolder('../data_retrieval/data_other', ResNet101_Weights.IMAGENET1K_V2.transforms())
targets = np.array(data.targets)

train_indices, test_val_indices = train_test_split(
    np.arange(len(targets)),
    stratify=targets,
    train_size=0.65,
    random_state=21,
)
test_val_targets = targets[test_val_indices]

# train_test_split returns (first, second) = (57 % share, 43 % share), so the
# first (larger) share of the hold-out becomes the test set.
test_indices, val_indices = train_test_split(
    test_val_indices,
    stratify=test_val_targets,
    train_size=0.57,
    random_state=21,
)

train_data = Subset(data, indices=train_indices)
val_test_data = Subset(data, indices=test_val_indices)
val_data = Subset(data, indices=val_indices)
test_data = Subset(data, indices=test_indices)

In [8]:
# Build the network sized by the number of ImageFolder class folders.
num_classes = len(test_data.dataset.classes)
model = two_conv_network.TwoLayerConvNet(num_classes)

In [9]:
# Restore the trained weights. map_location='cpu' is needed because this notebook
# never moves the model to a GPU: a checkpoint saved from CUDA tensors would
# otherwise fail to load on a CPU-only kernel.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(
    torch.load('../models/models_weights/two_conv_network_non_diagnostic.pth', map_location='cpu')
)

<All keys matched successfully>

In [10]:
# Put the network in inference mode; train(False) is the documented equivalent
# of eval() and likewise returns the module, which is echoed as the cell output.
model.train(False)

TwoLayerConvNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=100352, out_features=16, bias=True)
  (relu3): ReLU()
  (fc2): Linear(in_features=16, out_features=28, bias=True)
)

In [11]:
# Evaluation loader. Sample order does not affect accuracy, so shuffling only adds
# nondeterminism (batches differ between runs, which hampers debugging); keep it off.
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=16,
    shuffle=False,
    num_workers=4,
)

In [12]:
# Report how many images ended up in the test subset.
n_test_samples = len(test_data)
print("Size of test dataset:", n_test_samples)

Size of test dataset: 2547


In [13]:
# Invert ImageFolder's class_to_idx so integer labels / predictions can be mapped
# back to class-folder names.
idx_to_class = dict(
    zip(test_data.dataset.class_to_idx.values(), test_data.dataset.class_to_idx.keys())
)

In [14]:
# Display the integer-index -> class-folder-name mapping built in the previous cell.
idx_to_class

{0: 'AN_Neo-Assyrian',
 1: 'AN_Neo-Babylonian',
 2: 'A_Neo-Assyrian',
 3: 'A_Neo-Babylonian',
 4: 'AŠ_Neo-Assyrian',
 5: 'AŠ_Neo-Babylonian',
 6: 'BAD_Neo-Assyrian',
 7: 'BAD_Neo-Babylonian',
 8: 'DIŠ_Neo-Assyrian',
 9: 'DIŠ_Neo-Babylonian',
 10: 'GIŠ_Neo-Assyrian',
 11: 'GIŠ_Neo-Babylonian',
 12: 'IGI_Neo-Assyrian',
 13: 'IGI_Neo-Babylonian',
 14: 'I_Neo-Assyrian',
 15: 'I_Neo-Babylonian',
 16: 'MA_Neo-Assyrian',
 17: 'MA_Neo-Babylonian',
 18: 'MU_Neo-Assyrian',
 19: 'MU_Neo-Babylonian',
 20: 'NA_Neo-Assyrian',
 21: 'NA_Neo-Babylonian',
 22: 'NU_Neo-Assyrian',
 23: 'NU_Neo-Babylonian',
 24: 'UD_Neo-Assyrian',
 25: 'UD_Neo-Babylonian',
 26: 'ŠU₂_Neo-Assyrian',
 27: 'ŠU₂_Neo-Babylonian'}

In [15]:
# Count correct and total predictions per class on the test set. Both dicts are
# keyed by integer class index; use idx_to_class to get the folder name.
# The batch is unpacked directly into images/labels: naming the loop variable
# `data` (as before) rebinds the ImageFolder dataset from the split cell and
# silently breaks later cells or a re-run that rely on it.
correct_pred = {class_idx: 0 for class_idx in idx_to_class}
total_pred = {class_idx: 0 for class_idx in idx_to_class}
with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images)
        _, predictions = torch.max(outputs, dim=1)
        # Labels and predictions are both class indices, so compare them directly
        # (mapping both through idx_to_class first was a no-op round trip).
        for label, prediction in zip(labels.tolist(), predictions.tolist()):
            if label == prediction:
                correct_pred[label] += 1
            total_pred[label] += 1

In [16]:
# Print per-class accuracy (percent). A class with no test samples would make the
# division raise ZeroDivisionError, so it is reported as n/a instead.
for class_idx, correct_count in correct_pred.items():
    n_total = total_pred[class_idx]
    if n_total == 0:
        print(f'Accuracy for class: {idx_to_class[class_idx]:5s} is n/a (no test samples)')
        continue
    accuracy = 100 * float(correct_count) / n_total
    print(f'Accuracy for class: {idx_to_class[class_idx]:5s} is {accuracy:.1f} %')

Accuracy for class: AN_Neo-Assyrian is 60.6 %
Accuracy for class: AN_Neo-Babylonian is 32.1 %
Accuracy for class: A_Neo-Assyrian is 0.0 %
Accuracy for class: A_Neo-Babylonian is 30.4 %
Accuracy for class: AŠ_Neo-Assyrian is 0.0 %
Accuracy for class: AŠ_Neo-Babylonian is 54.4 %
Accuracy for class: BAD_Neo-Assyrian is 0.0 %
Accuracy for class: BAD_Neo-Babylonian is 0.0 %
Accuracy for class: DIŠ_Neo-Assyrian is 0.0 %
Accuracy for class: DIŠ_Neo-Babylonian is 79.6 %
Accuracy for class: GIŠ_Neo-Assyrian is 0.0 %
Accuracy for class: GIŠ_Neo-Babylonian is 0.0 %
Accuracy for class: IGI_Neo-Assyrian is 0.0 %
Accuracy for class: IGI_Neo-Babylonian is 0.0 %
Accuracy for class: I_Neo-Assyrian is 0.0 %
Accuracy for class: I_Neo-Babylonian is 0.0 %
Accuracy for class: MA_Neo-Assyrian is 0.0 %
Accuracy for class: MA_Neo-Babylonian is 0.0 %
Accuracy for class: MU_Neo-Assyrian is 0.0 %
Accuracy for class: MU_Neo-Babylonian is 0.0 %
Accuracy for class: NA_Neo-Assyrian is 0.0 %
Accuracy for class: NA_Neo

In [17]:
# Overall (sample-weighted) accuracy: fraction of all test images classified
# correctly, pooled across classes.
n_correct = sum(correct_pred.values())
n_seen = sum(total_pred.values())
print("Overall accuracy: ", n_correct / n_seen)

Overall accuracy:  0.27326266195524146
