## Model initialization

In [None]:
import torch
import torchvision.models as models

In [None]:
# Pretrained ImageNet classifiers to compare. Each key is the label written
# next to that model's predictions in results.txt.
# NOTE(review): torchvision >= 0.13 deprecates `pretrained=True` in favour of
# the `weights=` argument; kept as-is for compatibility with older releases.
model_dict = dict(
    AlexNet=models.alexnet(pretrained=True),
    VGG=models.vgg19_bn(pretrained=True),
    GoogleNet=models.googlenet(pretrained=True),
    ResNet=models.resnet152(pretrained=True),
    ResNeXt=models.resnext101_32x8d(pretrained=True),
)

## Images

In [None]:
import matplotlib.pyplot as plt
import math
from PIL import Image
import os

# Root folder scanned by get_all_images(): every subfolder becomes one image
# group, every loose file becomes a single-image group.
directory = os.path.join(os.pardir, 'Images', 'Myimages')
# Group name -> list of images; populated by get_all_images().
images = dict()

def display_images(images, columns=5, width=20, row_height=3.5, font_size=20, title=""):
    """Show a sequence of images as a grid of subplots in a new figure.

    Args:
        images: sequence of images accepted by ``imshow`` (e.g. PIL images).
        columns: number of grid columns.
        width: figure width (matplotlib ``figsize`` units).
        row_height: height of each grid row (matplotlib ``figsize`` units).
        font_size: font size of the figure title.
        title: figure title, placed at x=0.14 in figure coordinates.
    """
    # At least one row so an empty group still gets a titled figure of
    # non-zero height.
    rows = max(1, math.ceil(len(images) / columns))
    fig = plt.figure(figsize=(width, row_height * rows))
    fig.suptitle(title, fontsize=font_size, x=0.14)

    for i, img in enumerate(images):
        # Lay the grid out with the same `rows` the figure was sized for, so no
        # spurious empty row appears when len(images) is a multiple of `columns`.
        ax = fig.add_subplot(rows, columns, i + 1)
        ax.imshow(img)
        ax.axis("off")

def display_all_images():
    """Draw one titled image grid per group in the module-level ``images`` dict.

    The group name is used as the grid title; drawing is delegated to
    ``display_images`` with its default layout.
    """
    for group_name, group_images in images.items():
        display_images(group_images, title=group_name)

def resize_all_images(size=(256, 256)):
    """Resize every image in the global ``images`` dict in place.

    Each list stored in ``images`` keeps its identity; its elements are
    replaced by the objects returned from ``img.resize(size)``.

    Args:
        size: target size passed to each image's ``resize`` method (for PIL
            images this is ``(width, height)``). Defaults to ``(256, 256)``.
    """
    global images
    for group in images.values():
        for i, img in enumerate(group):
            # resize() returns a new image, so write it back into the list slot.
            group[i] = img.resize(size)

def get_all_images():
    """Load every image under ``directory`` into the global ``images`` dict.

    Each subdirectory of ``directory`` becomes one entry, keyed by the
    subdirectory name, holding one image per file inside it. Each file that
    sits directly in ``directory`` becomes a single-image entry keyed by its
    filename without the extension (``cat.jpeg`` -> ``cat``).

    Hidden entries (names starting with ``.``, e.g. ``.DS_Store``) and
    directories nested inside a subdirectory are skipped, since they are not
    image files. Directory listings are sorted so the order of entries, and of
    images within an entry, is reproducible.

    All loaded images are then resized via ``resize_all_images()``.
    """
    global directory, images
    for entry in sorted(os.listdir(directory)):
        if entry.startswith('.'):
            continue
        entry_path = os.path.join(directory, entry)
        if os.path.isdir(entry_path):
            images_from_files = []
            for file_name in sorted(os.listdir(entry_path)):
                file_path = os.path.join(entry_path, file_name)
                if file_name.startswith('.') or not os.path.isfile(file_path):
                    continue
                images_from_files.append(Image.open(file_path))
            images[entry] = images_from_files
        else:
            # splitext strips an extension of any length (".jpg", ".jpeg",
            # ".png"), unlike slicing off a fixed 4 characters.
            images[os.path.splitext(entry)[0]] = [Image.open(entry_path)]
    resize_all_images()



In [None]:
# Populate the global `images` dict from `directory` (loads, then resizes, every image).
get_all_images()

## Image transformations for model inputs

In [None]:
from torchvision import transforms

# Preprocessing applied to each PIL image before it is fed to the models.
transform = transforms.Compose([
    # Force 3-channel RGB: grayscale ("L"), palette ("P") or RGBA images would
    # otherwise become 1- or 4-channel tensors, which the 3-element mean/std
    # given to Normalize below cannot be applied to.
    transforms.Lambda(lambda image: image.convert('RGB')),
    transforms.CenterCrop(227),         # Crop the image to 227x227 pixels from center
    transforms.ToTensor(),              # Convert image to PyTorch Tensor data type
    transforms.Normalize(               # Normalize each channel: (x - mean) / std
        mean=[0.485, 0.456, 0.406],     # Mean and std same as used on training data
        std=[0.229, 0.224, 0.225]
    )
])

## Getting class names from file

In [None]:
# Human-readable class names, indexed by model output position.
# NOTE(review): assumes each line looks like "<id>, <name>" — only the second
# ", "-separated field is kept; confirm against imagenet_classes.txt.
with open('imagenet_classes.txt', encoding='utf-8') as f:
    # Blank lines (e.g. an extra empty line at EOF) have no second field and
    # would raise IndexError, so they are skipped.
    classes = [line.split(", ")[1].strip() for line in f if line.strip()]

## Recognition and writing results to file

In [57]:
# For every image group, run each model on each image and write the TOP_K
# classes with their softmax probabilities (in percent) to results.txt.
TOP_K = 5  # number of highest-scoring classes reported per model

# Put every network into evaluation mode once, before any inference.
for net in model_dict.values():
    net.eval()

# no_grad: this is inference only, so skip building autograd graphs.
with open('results.txt', 'w') as f, torch.no_grad():
    for name, imageset in images.items():
        lines = [name + ":\n"]

        # Recognition and writing to file
        for picture in imageset:
            batch_t = torch.unsqueeze(transform(picture), 0)  # add a leading batch dimension
            for model_name, net in model_dict.items():
                lines.append("\t" + model_name + ":\n")
                out = net(batch_t)
                percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
                _, indices = torch.topk(out, TOP_K, dim=1)
                for idx in indices[0]:
                    lines.append("\t\t" + classes[int(idx)] + ": " + str(percentage[idx].item()) + "\n")
            lines.append("\n")

        f.write("".join(lines))
        f.write("\n")
'DONE'

'DONE'