In [1]:
import torch
import pprint
import torchvision
from torchvision import models, transforms
from PIL import Image

In [2]:
# Input images to classify (paths are relative to the notebook's working directory).
image_names = ["candy4.jpeg", "candy5.jpeg"]


def _load_rgb(path):
    """Open the image at ``path`` as a 3-channel RGB copy and close the file.

    The preprocessing pipeline normalises with 3-element mean/std, so grayscale,
    RGBA or CMYK inputs must be converted to RGB first or Normalize fails on the
    channel mismatch. ``convert`` returns a fully loaded copy, so closing the
    underlying file handle afterwards is safe.
    """
    with Image.open(path) as im:
        return im.convert("RGB")


images = [_load_rgb(name) for name in image_names]

In [3]:
# Pretrained ResNet-101 loaded with its ImageNet-1k (V1) weights.
resnet = models.resnet101(weights=models.ResNet101_Weights.IMAGENET1K_V1)

# Evaluation-time preprocessing:
#   1. resize so the shorter side is 256 px,
#   2. take the central 224x224 crop,
#   3. convert the PIL image to a float tensor,
#   4. normalise each RGB channel with the mean/std used for the pretrained weights.
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

In [4]:
# Run every image through the preprocessing pipeline, then prepend a
# size-1 batch dimension so each tensor can be fed to the network on its own.
img_tensors = list(map(preprocess, images))
batch_tensors = [tensor.unsqueeze(0) for tensor in img_tensors]

In [5]:
# Put the network in evaluation mode (BatchNorm layers use their running
# statistics) and compute the logits for each single-image batch.
# torch.no_grad() skips building the autograd graph: it is never used for
# prediction and would otherwise keep every intermediate activation in memory.
resnet.eval()
with torch.no_grad():
    outputs = [resnet(batch_t) for batch_t in batch_tensors]

In [6]:
# Top-1 class index for each image: argmax over the class dimension (dim=1).
# Like torch.max(..., 1).indices, argmax returns the first maximal index on ties.
# (The former loop here discarded its result and also rebound `_`, which
# shadows IPython's last-output variable, so it has been removed.)
indices = [out.argmax(dim=1) for out in outputs]
print(indices)

[tensor([259]), tensor([259])]


In [7]:
# Number of highest-probability classes to report per image.
TOP_K = 10

# Read the class labels once and close the file before doing any model work.
# NOTE(review): assumes imagenet1000.txt holds one label per line, ordered by
# class index (line i describes class i). The printed output ("259: 'Pomeranian',"
# for index 259) is consistent with this; confirm if the file is replaced.
with open("imagenet1000.txt", "r", encoding="utf-8") as f:
    lines = [line.strip() for line in f]

for name, logits in zip(image_names, outputs):
    print(name)
    # Softmax over the class dimension of the single-row batch, as percentages.
    percentage = torch.nn.functional.softmax(logits, dim=1)[0] * 100
    # topk avoids sorting every class when only the best few are needed.
    # Softmax is monotonic, so this is the same ranking as sorting the logits (up to ties).
    top_pct, top_idx = torch.topk(percentage, k=min(TOP_K, percentage.numel()))
    # (label, percentage) pairs, best first. The name `temp` is kept because later cells may read it.
    temp = [(lines[idx], pct) for idx, pct in zip(top_idx.tolist(), top_pct.tolist())]
    pprint.pprint(temp)


candy4.jpeg
[("259: 'Pomeranian',", 99.87897491455078),
 ("151: 'Chihuahua',", 0.048462312668561935),
 ("261: 'keeshond',", 0.034530431032180786),
 ("154: 'Pekinese, Pekingese, Peke',", 0.01686878316104412),
 ("157: 'papillon',", 0.010804271325469017),
 ("152: 'Japanese spaniel',", 0.002236943459138274),
 ("265: 'toy poodle',", 0.0018744270782917738),
 ("155: 'Shih-Tzu',", 0.00035706613562069833),
 ("192: 'cairn, cairn terrier',", 0.00035551906330510974),
 ("223: 'schipperke',", 0.0003492470714263618)]
candy5.jpeg
[("259: 'Pomeranian',", 99.90487670898438),
 ("154: 'Pekinese, Pekingese, Peke',", 0.062284424901008606),
 ("151: 'Chihuahua',", 0.009642297402024269),
 ("263: 'Pembroke, Pembroke Welsh corgi',", 0.006276996340602636),
 ("157: 'papillon',", 0.00566014414653182),
 ("261: 'keeshond',", 0.004148251377046108),
 ("260: 'chow, chow chow',", 0.0019690582994371653),
 ("258: 'Samoyed, Samoyede',", 0.0015442997682839632),
 ("152: 'Japanese spaniel',", 0.0009459663415327668),
 ("778: 's