In [1]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
from torchvision.models import vgg16
from torchvision.models import VGG16_Weights

from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler

from src.data_loader import FairFaceData, CelebData
from facenet_pytorch import InceptionResnetV1
import torch.utils.data as torchdata

use_gpu = torch.cuda.is_available()
if use_gpu:
    print("Using CUDA")

class Identity(torch.nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


In [2]:
# create model without fully connected layers

# model with vgg16
model = vgg16(weights=VGG16_Weights.IMAGENET1K_V1)
# remove fully connected layers
model.classifier = Identity()

# since all our images are the same size we probably done need the avgpool layer
model.avgpool = Identity()

# print(model)
for param in model.parameters():
      param.requires_grad = False

In [8]:
filename = '../data/img_align_celeba/202599.jpg'

# sample execution 
from PIL import Image
from torchvision import transforms

input_image = Image.open(filename)
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    input_batch = input_batch.to('cuda')
    model.to('cuda')

with torch.no_grad():
    output = model(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
print(output[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
print(probabilities)


tensor([0., 0., 0.,  ..., 0., 0., 0.])
tensor([1.0600e-07, 1.0600e-07, 1.0600e-07,  ..., 1.0600e-07, 1.0600e-07,
        1.0600e-07])
