# Pretrained CNN classification

We use a pretrained CNN to classify images into ImageNet classes.

In [1]:
import numpy as np
from pathlib import Path
from PIL import Image
import torch

from lib.imagenet_classes import idx2label
from lib.plots import plot_transformed_val_input
from lib.cnn_classifiers import val_transform, get_cnn

In [2]:
from ipywidgets import interact


# Sorted string paths of every .jpg found recursively under data/sodas/query/;
# these are the choices offered in the image dropdown of the widget below.
soda_imgs = sorted(map(str, Path('data/sodas/query/').glob('**/*.jpg')))

@interact(
    model_name=['alexnet', 'vgg19', 'inception_v3', 'resnet50'],
    img_path=soda_imgs,
    pretrained=True
)
def classify_img(model_name, img_path, pretrained=True):
    """Classify one image with a CNN and print its top-10 predicted labels.

    Wired to ipywidgets ``interact``: the widgets choose the architecture,
    the query image, and the ``pretrained`` flag.

    Args:
        model_name: Architecture name forwarded to ``get_cnn``.
        img_path: Path of the image file to classify.
        pretrained: Forwarded to ``get_cnn``; presumably selects pretrained
            weights (TODO confirm in lib.cnn_classifiers).

    Prints the tensor shape at each stage, then the ten highest-scoring
    labels (looked up in ``idx2label``) with their softmax score in percent.
    """
    model = get_cnn(model_name, pretrained)
    model.eval()

    # Open in a context manager so the file handle is released promptly.
    # Convert to RGB because the CNNs presumably expect 3 input channels.
    # Without this, grayscale/CMYK JPEGs would fail in the transform or model.
    with Image.open(img_path) as raw:
        print('Original shape:', np.array(raw).shape)
        im = raw.convert('RGB')

    # Transform the image
    x = val_transform(im)
    print('Shape after transform:', x.shape)

    plot_transformed_val_input(im)

    # Add batch dimension (a batch of one image)
    x = x.unsqueeze(0)
    print('Input shape:', x.shape)

    # Pass through model; no gradients are needed for inference.
    with torch.no_grad():
        y = model(x)
    print('Output shape:', y.shape)

    # Remove only the batch dimension
    y = y.squeeze(0)
    print('Squeezed output shape:', y.shape)

    # Softmax to turn raw scores into percentage-style probabilities
    y = torch.nn.functional.softmax(y, dim=0)

    # Top-k matches; topk returns them sorted by descending probability.
    top_k = min(10, y.numel())
    top_probs, top_idxs = y.topk(top_k)

    print()
    print(f'Top {top_k} matches:')
    ranked = zip(top_probs.tolist(), top_idxs.tolist())
    for rank, (prob, idx) in enumerate(ranked, start=1):
        print(f'\t{rank}) {idx2label[idx]} ({prob * 100:.1f}%)')

interactive(children=(Dropdown(description='model_name', options=('alexnet', 'vgg19', 'inception_v3', 'resnet5…