## Example

In this simple example, we load an image, preprocess it, and classify it with an EfficientNet-B0 whose weights (including a two-class classifier head) are loaded from a local training checkpoint.

In [1]:
import json
from PIL import Image

import torch
from torchvision import transforms

from efficientnet_pytorch import EfficientNet

In [2]:
# Input resolution used by the preprocessing transforms.
# NOTE(review): EfficientNet.get_image_size(model_name) gives 224 for b0;
# 128 is used instead, presumably to match the resolution the checkpoint was
# trained at — confirm.
image_size = 128
# EfficientNet variant to build.
model_name = 'efficientnet-b0'

In [None]:
# Open image and force 3-channel RGB: the preprocessing normalizes exactly three
# channels and the model is built with in_channels=3, so a grayscale, RGBA or
# CMYK file would otherwise fail further down. The context manager closes the
# underlying file; convert() has already loaded the pixels into a new image.
with Image.open('verde.jpg') as raw_img:
    img = raw_img.convert('RGB')
img

In [4]:
# Preprocess image:
#   Resize(int)     -> scale so the shorter edge equals image_size
#   CenterCrop      -> square image_size x image_size crop
#   ToTensor        -> float tensor of shape (C, H, W) in [0, 1]
#   Normalize       -> per-channel (x - mean) / std with the ImageNet statistics
# then unsqueeze adds a leading batch dimension of size 1.
tfms = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225],
    ),
])
img = tfms(img).unsqueeze(dim=0)

In [5]:
# Load class names. The file holds a JSON object keyed by the class index as a
# string ("0", "1", ...); it is flattened into a list so labels_map[i] is the
# name of class i. The `with` block closes the file, and JSON is UTF-8 by spec,
# so the encoding is explicit rather than the platform default.
with open('labels_map.txt', encoding='utf-8') as labels_file:
    labels_by_key = json.load(labels_file)
labels_map = [labels_by_key[str(i)] for i in range(len(labels_by_key))]

In [None]:
num_classes = 2
# NOTE(review): absolute, machine-specific path — consider making it relative to
# a configurable directory so the notebook runs on other machines.
weights_path = "C:/git/EfficientNet-PyTorch/results/model_Best.pth.tar"

# Load the training checkpoint. map_location='cpu' lets a checkpoint saved with
# GPU tensors load on a CPU-only machine; nothing in this notebook moves tensors
# to a GPU. Despite its name, `state_dict` holds the whole checkpoint: the model
# weights are nested under its 'state_dict' key.
state_dict = torch.load(weights_path, map_location='cpu')
fc_weight = state_dict['state_dict']['_fc.weight']
# A Linear layer's weight has one row per output feature, i.e. per class.
print("Number of output classes of the weights loaded:", fc_weight.shape[0])
assert fc_weight.shape[0] == num_classes, (
    'Checkpoint head has {} classes but num_classes = {}'.format(fc_weight.shape[0], num_classes))


In [None]:
# Build the architecture with a `num_classes`-way head and load the checkpoint's
# weights into it, head included.
# `EfficientNet.from_pretrained(..., weights_path=..., num_classes=2)` is not
# usable here: per efficientnet_pytorch's source it only keeps the `_fc` head
# when num_classes == 1000 (otherwise it pops `_fc.weight`/`_fc.bias`, leaving
# the classifier randomly initialized), and it expects a flat state dict,
# whereas this checkpoint nests the weights under its 'state_dict' key.
model = EfficientNet.from_name(model_name, num_classes=num_classes)
# strict=True (the default) raises if any key is missing or unexpected.
model.load_state_dict(state_dict['state_dict'])

In [None]:
# Build the `num_classes`-way model and load the fine-tuned weights with
# explicit diagnostics. `state_dict` is the whole training checkpoint, so the
# model weights are taken from its nested 'state_dict' entry.
model = EfficientNet.from_name(model_name, num_classes=num_classes)

checkpoint_weights = state_dict['state_dict']
ret = model.load_state_dict(checkpoint_weights, strict=False)
# The checkpoint includes the trained `_fc` head, so nothing may be missing:
# a missing `_fc.*` key would leave the classifier randomly initialized.
assert not ret.missing_keys, 'Missing keys when loading pretrained weights: {}'.format(ret.missing_keys)
assert not ret.unexpected_keys, 'Unexpected keys when loading pretrained weights: {}'.format(ret.unexpected_keys)

print('Loaded pretrained weights for {}'.format(model_name))

In [None]:
# Classify with EfficientNet
# The model is built and loaded by hand: `EfficientNet.from_pretrained` only
# keeps the `_fc` head when num_classes == 1000, so with a 2-class model it
# would discard the trained classifier and predict with random head weights.
# NOTE(review): this checkpoint path differs from `weights_path`
# (results/model_Best.pth.tar) — confirm which checkpoint is intended.
checkpoint_path = "C:/git/EfficientNet-PyTorch/examples/imagenet/model_best.pth.tar"
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# Training checkpoints nest the weights under 'state_dict'; also accept a bare
# state dict.
model_weights = checkpoint.get('state_dict', checkpoint)
# Drop the 'module.' prefix that torch.nn.DataParallel adds to parameter names,
# if present, so the keys match a plain (unwrapped) model.
model_weights = {
    (key[len('module.'):] if key.startswith('module.') else key): value
    for key, value in model_weights.items()
}
model = EfficientNet.from_name(model_name, num_classes=num_classes)
model.load_state_dict(model_weights)  # strict: fails loudly on any key mismatch
model.eval()
with torch.no_grad():
    logits = model(img)

# torch.topk raises if k exceeds the size of the class dimension, so cap it.
top_k = min(5, logits.size(1))
probs = torch.softmax(logits, dim=1)[0]
preds = torch.topk(logits, k=top_k).indices.squeeze(0).tolist()

# NOTE(review): `labels_map` is indexed by this model's class index; if
# labels_map.txt describes a different label set than the `num_classes`
# classes the checkpoint was trained on, the printed names will be wrong.
print('-----')
for idx in preds:
    label = labels_map[idx]
    prob = probs[idx].item()
    print('{:<75} ({:.2f}%)'.format(label, prob*100))