In [26]:
import torch
from torchvision import transforms
from PIL import Image
from torch import argmax

from model import FaceModel

In [27]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 14 output units — one per identity listed in `classes` below.
model = FaceModel(14).to(device)
# map_location remaps tensors saved on another device (e.g. a checkpoint written
# on a GPU machine) onto `device`, so loading also works on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (torch >= 1.13 also accepts weights_only=True for plain state dicts).
model.load_state_dict(torch.load('model.pth', map_location=device))
# Switch to inference mode: disables training-time behavior such as dropout or
# batch-norm statistic updates, if FaceModel contains such layers.
model.eval()

# Inference-time preprocessing. Kept deterministic on purpose: random
# augmentations such as ColorJitter belong in the training pipeline only —
# applying them here makes the same image produce different predictions
# from run to run.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    # PIL image -> float tensor with values in [0, 1].
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 maps each channel from [0, 1] to [-1, 1]; presumably this
    # matches the normalization used during training — TODO confirm.
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Identity name -> model output index. Each name's index is its position in
# this list, so the list order presumably must match the label order used at
# training time — TODO confirm against the training script.
classes = {
    name: index
    for index, name in enumerate([
        'Ali_Khamenei', 'Angelina_Jolie', 'Barak_Obama',
        'Behnam_Bani', 'Donald_Trump', 'Emma_Watson',
        'Han_Hye_Jin', 'Kim_Jong_Un', 'Leyla_Hatami',
        'Lionel_Messi', 'Michelle_Obama', 'Morgan_Freeman',
        'Queen_Elizabeth', 'Scarlett_Johansson',
    ])
}

In [31]:
def predict(image_path):
    """Classify the face in a single image file.

    Uses the notebook-level ``model``, ``transform``, ``device`` and ``classes``.
    The model is used in whatever mode it is currently in; call ``model.eval()``
    beforehand for proper inference behavior.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        The key of ``classes`` whose index equals the argmax of the model output.

    Raises:
        IndexError: If the predicted index has no entry in ``classes``.
    """
    # Force 3-channel RGB: grayscale, RGBA or palette images would otherwise
    # clash with the 3-channel Normalize in `transform`. The context manager
    # closes the underlying file handle once the tensor has been built.
    with Image.open(image_path) as img:
        batch = transform(img.convert('RGB')).unsqueeze(0).to(device)

    # Inference needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        outputs = model(batch)

    predicted_index = argmax(outputs).item()
    return [name for name, index in classes.items() if index == predicted_index][0]

In [36]:
# Spot-check: the recorded output matches the folder label.
# NOTE(review): if dataset/ is the training data, this is not a held-out
# evaluation — TODO measure accuracy on unseen images.
predict('dataset/Behnam_Bani/Behnam-Bani-11_01.jpg')

'Behnam_Bani'

In [37]:
# Spot-check: the recorded output matches the folder label.
predict('dataset/Han_Hye_Jin/Han-Hye-Jin-08_01.jpg')

'Han_Hye_Jin'

In [41]:
# NOTE(review): the recorded output for this Michelle_Obama image is
# 'Scarlett_Johansson' — a misclassification, even on an image taken from
# the dataset directory. Worth investigating before trusting the model.
predict('dataset/Michelle_Obama/Michelle-Obama-10_01.jpg')

'Scarlett_Johansson'