Based on this video: https://www.youtube.com/watch?v=15zlr2vJqKc

Model initialization

In [105]:
import torch
import torchvision.models as models

In [106]:
# Pretrained classifiers under comparison, keyed by the column label used in the
# results table. Each constructor is called with pretrained=True so the published
# weights are loaded.
modelDict = {
    label: build(pretrained=True)
    for label, build in (
        ('VGG', models.vgg19_bn),
        ('AlexNet', models.alexnet),
        ('ResNet', models.resnet152),
        ('GoogleNet', models.googlenet),
    )
}

Images

In [107]:
from PIL import Image
from skimage import io, transform
import matplotlib.pyplot as plt

# Test pictures keyed by the label that is later searched for (case-insensitively)
# inside the predicted class names.
images = {
    label: Image.open(path)
    for label, path in (
        ('Cat', '../Images/cat.jfif'),
        ('Bear', '../Images/bear_fish.jfif'),
        ('Skunk', '../Images/skunk.jfif'),
        ('Elephant', '../Images/slonind.jfif'),
        ('Dugong', '../Images/mkrava.jfif'),
        ('Agama', '../Images/agama.jfif'),
        ('Ambulance', '../Images/ambulance.jfif'),
    )
}

Image transformations

In [108]:
from torchvision import transforms
# Preprocessing pipeline applied to every picture before it is fed to a network.
# NOTE: this rebinds the name `transform`, shadowing the `skimage.transform`
# module imported in the image-loading cell.
transform = transforms.Compose([
    # Scale so the shorter side becomes 256 px; the aspect ratio is preserved
    # (an int size does NOT produce a 256x256 square).
    transforms.Resize(256),
    # Cut the central 224x224 patch.
    transforms.CenterCrop(224),
    # PIL image -> PyTorch tensor.
    transforms.ToTensor(),
    # Per-channel normalisation with the same mean/std as the training data.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

Class names

In [109]:
def _class_name(entry):
    """Return the second ", "-separated field of one line, with whitespace trimmed."""
    return entry.split(", ")[1].strip()


# Class-name lookup: position in the list is the class index used to interpret
# the model outputs. Presumably each line looks like "<index>, <name>" -- verify
# against imagenet_classes.txt.
with open('imagenet_classes.txt') as class_file:
    classes = [_class_name(entry) for entry in class_file]

Testing the networks and writing the results to a .csv file

In [110]:
import csv

# Open the results file with newline='' as the csv module requires: the writer
# emits its own '\r\n' row terminators, and without this flag text-mode newline
# translation on Windows turns them into '\r\r\n' (a blank line after every row).
# The handle is deliberately left open here: the scoring cell that follows writes
# one row per image and then closes it.
result = open('result.csv', 'w', newline='')
filewriter = csv.writer(result, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
# Header row: image label followed by one column per network.
filewriter.writerow(['Image', 'VGG', 'AlexNet', 'ResNet', 'GoogleNet'])

36

In [111]:
# Inference only: switch every network to evaluation mode once, up front,
# instead of on every (image, model) iteration.
for network in modelDict.values():
    network.eval()

TOP_K = 5  # how many of the highest-scoring classes are searched for the label

try:
    # Scoring needs no gradients; disabling autograd avoids building a graph
    # for each forward pass through these large networks.
    with torch.no_grad():
        for img in images:
            # JPEG files may be grayscale or CMYK, while the normalisation step
            # in `transform` is defined for exactly three channels: force RGB.
            img_t = transform(images[img].convert('RGB'))
            batch_t = torch.unsqueeze(img_t, 0)  # add the leading batch dimension
            row = [img]
            for model in modelDict:
                out = modelDict[model](batch_t)
                percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
                _, top_indices = torch.topk(out[0], TOP_K)
                # Record the confidence of the best-ranked top-k class whose
                # name contains the image label (case-insensitive); 0 if none.
                found = 0
                for idx in top_indices.tolist():
                    if img.lower() in classes[idx].lower():
                        found = percentage[idx].item()
                        break
                row.append(found)
            filewriter.writerow(row)
finally:
    # Always release the CSV file, even when an image or model fails part-way.
    result.close()
"DONE"

'DONE'

Visualize result

In [112]:
import pandas as pd

# Reload the scores written to disk: one row per image, one column per network.
results_path = 'result.csv'
df = pd.read_csv(results_path)

In [113]:
# Enable matplotlib's default interactive GUI backend for the plot that follows
# (the recorded output of this cell reports which backend was selected).
%matplotlib

Using matplotlib backend: Qt5Agg


In [115]:
# Grouped bar chart: one cluster of bars per test image, one bar per network.
ax = df.plot(x="Image", y=["VGG", "AlexNet", "ResNet", "GoogleNet"], kind="bar")
# Label the figure so it can be read on its own: the plotted values are softmax
# scores scaled to percent (0 when the label was not among the top guesses).
ax.set(ylabel="Confidence (%)", title="Confidence in the expected class, by network")
ax

<AxesSubplot:xlabel='Image'>