In [1]:
# Try image classification using AlexNet on one of the jigsaw images
from torchvision import models
import torch

In [2]:
# Load AlexNet with ImageNet-pretrained weights.
# torchvision >= 0.13 deprecates `pretrained=True` in favour of the `weights=`
# enum; IMAGENET1K_V1 is the checkpoint the legacy flag selected, so results are
# unchanged. Older torchvision has no *_Weights enums, hence the fallback.
if hasattr(models, 'AlexNet_Weights'):
    alexnet = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
else:
    alexnet = models.alexnet(pretrained=True)

In [3]:
# Preprocessing pipeline for the pretrained ImageNet models:
#   1. scale the image so its shorter side is 256 px,
#   2. cut out the central 224x224 patch,
#   3. convert the PIL image to a float tensor,
#   4. normalize each RGB channel with the ImageNet mean/std that torchvision's
#      pretrained weights expect.
from torchvision import transforms

transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [20]:
# Import Pillow and load one of the food images.
from PIL import Image
import glob

# glob() returns matches in filesystem-dependent order, so sort the list to make
# the index below select the same file on every run and machine.
ims = sorted(glob.glob('storage_theme/food/*.jpg'))
IMAGE_INDEX = 1  # position (in sorted order) of the image to classify
if len(ims) <= IMAGE_INDEX:
    raise FileNotFoundError(
        f"need at least {IMAGE_INDEX + 1} .jpg files in storage_theme/food/, "
        f"found {len(ims)}"
    )
im = ims[IMAGE_INDEX]

# Force 3-channel RGB: a grayscale or CMYK JPEG would not match the 3-channel
# mean/std used by Normalize in the preprocessing transform. The `with` block
# closes the file once the converted copy has been made.
with Image.open(im) as src:
    img = src.convert('RGB')

In [21]:
# Preprocess the image, then prepend a size-1 batch axis so the model receives
# a one-image batch.
img_t = transform(img)
batch_t = img_t.unsqueeze(0)

In [22]:
# Run AlexNet on the one-image batch.
# eval() puts layers such as dropout into inference mode; no_grad() skips
# building the autograd graph, which pure inference never uses.
alexnet.eval()
with torch.no_grad():
    out = alexnet(batch_t)

In [23]:
# Decode AlexNet's output into class names and print the five highest-scoring.
# imagenet_classes.txt holds one label per line; line i names class index i.
with open('imagenet_classes.txt') as label_file:
    labels = [entry.strip() for entry in label_file]

# Softmax over the class dimension for the single image, expressed in percent.
percentage = 100 * torch.nn.functional.softmax(out, dim=1)[0]

# Rank classes by raw score (softmax is monotonic, so the order is the same).
_, indices = torch.sort(out, descending=True)
top5 = [(labels[i], percentage[i].item()) for i in indices[0][:5]]
print(top5)

[("509: 'confectionery, confectionary, candy store',", 74.10552215576172), ("582: 'grocery store, grocery, food market, market',", 11.73461627960205), ("692: 'packet',", 5.971068382263184), ("800: 'slot, one-armed bandit',", 5.235652446746826), ("917: 'comic book',", 1.4645171165466309)]


In [22]:
# Try ResNet-101 too, on the same preprocessed batch.
# First, load the ImageNet-pretrained model. `weights=` replaces the deprecated
# `pretrained=True` in torchvision >= 0.13; IMAGENET1K_V1 is the checkpoint the
# legacy flag selected, so predictions are unchanged. Older torchvision has no
# *_Weights enums, hence the fallback.
if hasattr(models, 'ResNet101_Weights'):
    resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)
else:
    resnet = models.resnet101(pretrained=True)

# Second, put the network in eval mode (inference behaviour for layers such as
# batch norm and dropout).
resnet.eval()

# Third, carry out model inference. no_grad() skips autograd bookkeeping that
# inference does not need. Note: this overwrites AlexNet's `out`.
with torch.no_grad():
    out = resnet(batch_t)

# Fourth, show the top 5 classes predicted by the model with their softmax
# probabilities in percent.
percentage = torch.nn.functional.softmax(out, dim=1)[0] * 100
_, indices = torch.sort(out, descending=True)
[(labels[idx], percentage[idx].item()) for idx in indices[0][:5]]

[("454: 'bookshop, bookstore, bookstall',", 72.2619400024414),
 ("860: 'tobacco shop, tobacconist shop, tobacconist',", 23.93340492248535),
 ("865: 'toyshop',", 0.7846974730491638),
 ("509: 'confectionery, confectionary, candy store',", 0.7568069696426392),
 ("624: 'library',", 0.5270415544509888)]